<a href="https://colab.research.google.com/github/kotOcelot/Milky_Way_generator/blob/main/sdss_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Импорты

In [None]:
from astropy.io import fits
import os
import math
import pandas as pd
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import shutil
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset, random_split
import torch
from PIL import Image
from torch.utils.data import random_split
from sklearn.preprocessing import StandardScaler
from torch import nn
import torch.optim as optim
from scipy.optimize import curve_fit
from torchvision.models import resnet18, ResNet18_Weights
from torchsummary import summary
from collections import OrderedDict
import random
from astropy.wcs import WCS
from astropy import units as u
from astropy.coordinates import SkyCoord
from IPython.display import clear_output

def set_random_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and configures cuDNN for deterministic behaviour, so repeated
    runs with the same seed are as reproducible as possible.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. The call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run,
    # which would defeat the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Clear whatever this cell has printed so far.
clear_output()


### Картинки

In [None]:

from collections import defaultdict

from IPython.display import clear_output


class ProgressPlotter:

    """
    Collect scalar metrics during training and plot them in a notebook.

    Groups contain a list of variables to output, like ["loss", "accuracy"].
    If groups is None all recorded variables are plotted.

    Title is an experiment id like "Relu_Adam_lr003".
    All newly collected data is bound to the current title unless an
    explicit tag is given.
    """

    def __init__(self, title="default", groups=None) -> None:
        # Maps group name -> {tag: list of values}.
        self._history_dict = defaultdict(dict)
        self.set_title(title)
        self.groups = self.get_groups(groups)

    def get_groups(self, groups):
        """Normalise ``groups`` into an iterable of group names.

        ``None`` yields a live view of every group recorded so far, a
        single string is wrapped in a list, anything else is returned
        unchanged.
        """
        if groups is None:
            return self._history_dict.keys()
        if isinstance(groups, str):
            return [groups]
        return groups

    def set_title(self, title):
        """Switch the current experiment title, resetting its data."""
        for g in self._history_dict:
            self._history_dict[g][title] = []  # reset data
        self.title = title

    # group e.g. "loss_val" tag e.g. "experiment_1"
    def add_scalar(self, group: str, value, tag=None) -> None:
        """Append one value to the series ``tag`` inside ``group``."""
        tag = self.title if tag is None else tag
        self._history_dict[group].setdefault(tag, []).append(value)

    def add_row(self, group: str, value, tag=None) -> None:
        """Replace the whole series ``tag`` inside ``group`` with ``value``."""
        tag = self.title if tag is None else tag
        self._history_dict[group][tag] = value

    def display_keys(self, ax, data):
        """Draw every series of one group onto ``ax``."""
        ax.grid()
        history_len = 0
        for key, series in data.items():
            ax.plot(series, label=key)
            history_len = max(history_len, len(series))
        # Decorate once, after all series are drawn, so ticks and legend
        # reflect the longest series rather than whichever came first.
        if len(data) > 1:
            ax.legend(loc="best")
        if history_len < 50:
            ax.set_xlabel("step")
            ax.set_xticks(np.arange(history_len))
            ax.set_xticklabels(np.arange(history_len))

    def display(self, groups=None):
        """Clear the cell output and redraw one subplot per group.

        Args:
            groups: Group name, iterable of group names, or None to use
                the groups chosen at construction time.
        """
        clear_output()
        groups = list(self.groups if groups is None else self.get_groups(groups))
        n_groups = len(groups)
        if n_groups == 0:
            return  # nothing recorded yet; avoids dividing by zero below
        fig, ax = plt.subplots(1, n_groups, figsize=(max(1, 48 // n_groups), 3))
        if n_groups == 1:
            ax = [ax]
        history = self.history_dict
        for i, g in enumerate(groups):
            ax[i].set_ylabel(g)
            # A requested group may not have received any data yet.
            self.display_keys(ax[i], history.get(g, {}))
        fig.tight_layout()
        plt.show()

    @property
    def history_dict(self):
        """Plain-dict snapshot of the collected history."""
        return dict(self._history_dict)

In [None]:
def plot(output, target, maps, labels = ['Output', 'Target', 'Input']):
  """Show output, target and input images side by side, one figure per item.

  When ``len(output)`` is not 1, the three arguments are indexed item by
  item. When it is exactly 1, each argument is drawn as a whole. A leading
  size-one axis is squeezed from every image before drawing.

  Args:
    output: Tensor (or sequence of tensors) drawn in the first panel.
    target: Tensor (or sequence of tensors) drawn in the second panel.
    maps: Tensor (or sequence of tensors) drawn in the third panel.
    labels: Three panel names, used both as mosaic keys and as titles.
  """
  per_item = len(output) != 1
  for k in range(len(output)):
    if per_item:
      triple = (output[k], target[k], maps[k])
    else:
      triple = (output, target, maps)
    rec_img, ref_img, map_img = [torch.squeeze(t, 0) for t in triple]
    # The third panel additionally drops every remaining size-one axis.
    map_img = torch.squeeze(map_img)
    fig, ax = plt.subplot_mosaic([labels], figsize=(14, 5))
    for pos, img in enumerate((rec_img, ref_img, map_img)):
      panel = ax[labels[pos]]
      panel.imshow(img.cpu().detach().numpy())
      panel.set_title(labels[pos])
    plt.show()


In [None]:
def plot_dataset(pics, maps, labels, num = 10):
  """Show the first ``num`` (picture, map) pairs of a dataset.

  Each pair gets its own two-panel figure; both panels are titled with
  the first six characters of the matching label.

  Args:
    pics: Indexable collection of image tensors for the 'ha' panel.
    maps: Indexable collection of image tensors for the 'stellar' panel.
    labels: Indexable collection of sliceable labels (e.g. strings).
    num: Maximum number of pairs to show; clipped to the data length.
  """
  # Never index past the shortest input.
  count = min(num, len(pics), len(maps), len(labels))
  for k in range(count):
      label = labels[k][:6]
      rband_orig = torch.squeeze(pics[k], 0)
      # Use the k-th map; the previous code squeezed the builtin ``map``.
      map_orig = torch.squeeze(maps[k], 0)
      fig, ax = plt.subplot_mosaic([
          ['ha', 'stellar']
      ], figsize=(14, 5))
      ax['ha'].imshow(rband_orig.cpu().detach().numpy())
      ax['stellar'].imshow(torch.squeeze(map_orig).cpu().detach().numpy())
      ax['ha'].set_title(label)
      ax['stellar'].set_title(label)
      plt.show()


### Получение параметров

In [None]:
# Mount Google Drive at /content/drive; the data paths used below
# all live under that mount point (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# list_files = []
# with open('/content/drive/MyDrive/SDSS Manga data/list.dat') as f:
#   for line in f:
#       list_files.append(line.replace('%', '')[:-1])
# print(len(list_files))

In [None]:
# rgb_x = []
# rgb_y = []
# list_rgb = []
# with open('/content/drive/MyDrive/SDSS Manga data/rgb_orig/rgb_list_2.dat') as f:
#      for line in f:
#        if line[0] == '%':
#          name = line[:14].replace('%', '')
#          if (int(line[18:22])>=120) and (int(line[15:18])>=120) and (np.sum(np.isin(list_files, name)) == 1):
#            list_rgb.append(name)
#            rgb_y.append(int(line[15:18]))
#            rgb_x.append(int(line[18:22]))
# print(np.mean(rgb_y), np.mean(rgb_x))
# print(len(list_rgb))

In [None]:
# list_disp = []
# disp_medians = []
# with open('/content/drive/MyDrive/SDSS Manga data/st_disp/star_disp_list.dat') as f:
#      for line in f:
#       if line[0] == '%':
#           l = 0
#           ln = line[15:-1].split(' ')
#           for f in range(len(ln)):
#             if ln[f]!='':
#               l += 1
#               if l == 4:
#                 if float(ln[f])!=-350:
#                   disp_medians.append(ln[f])
#                   list_disp.append(line[:14].replace('%', ''))
# print(len(list_disp))

In [None]:
# list_ff = []
# mass = []
# lage = []
# mage = []
# lz = []
# mz = []
# with open('/content/drive/MyDrive/SDSS Manga data/firefly_list.dat') as f:
#   for line in f:
#       if line[0] == '%':
#           l = 0
#           ln = line[15:-1].split(' ')
#           new_line = []
#           for f in range(len(ln)):
#             if ln[f]!='':
#               new_line.append(float(ln[f]))
#           if np.sum(np.array(new_line) == -9999) == 0:
#               list_ff.append(line[:14].replace('%', ''))
#               mass.append(new_line[0])
#               lage.append(new_line[1])
#               mage.append(new_line[2])
#               lz.append(new_line[3])
#               mz.append(new_line[4])
# print(len(mage))

In [None]:
# list_ff = np.array(list_ff)
# list_disp = np.array(list_disp)
# list_rgb = np.array(list_rgb)
# int1 = np.intersect1d(list_ff, list_disp)
# int2 = np.intersect1d(int1, list_rgb)
# print(len(int2))

In [None]:
# list_dm = []
# list_mass = []
# list_lage = []
# list_mage = []
# list_lz = []
# list_mz = []
# for name in int2:
#   ind = np.where(list_disp == name)[0][0]
#   list_dm.append(disp_medians[ind])
#   ind = np.where(list_ff == name)[0][0]
#   list_mass.append(mass[ind])
#   list_lage.append(lage[ind])
#   list_mage.append(mage[ind])
#   list_lz.append(lz[ind])
#   list_mz.append(mz[ind])

In [None]:
# int2 = np.array(int2)
# list_dm = np.array(list_dm)
# list_mass = np.array(list_mass)
# list_lage = np.array(list_lage)
# list_mage = np.array(list_mage)
# list_lz = np.array(list_lz)
# list_mz = np.array(list_mz)

In [None]:
# int2 = int2[:, np.newaxis]
# list_dm = list_dm[:, np.newaxis]
# list_mass = list_mass[:, np.newaxis]
# list_lage = list_lage[:, np.newaxis]
# list_mage = list_mage[:, np.newaxis]
# list_lz = list_lz[:, np.newaxis]
# list_mz = list_mz[:, np.newaxis]


In [None]:
#arr = np.vstack((int2, list_dm,list_mass,list_lage,list_mage,list_lz,list_mz))

In [None]:
# df = pd.DataFrame(arr.T, columns=['name', 'disp, kms', 'mass, log m/msun', 'lw age, log Myr', 'mw age, log Myr', 'lw z, z/h', 'mw z, z/h'])


In [None]:
#df.to_csv('/content/drive/MyDrive/SDSS Manga data/galaxies_data.csv')

In [None]:
# Load the pre-built galaxy parameter table; the first CSV column is the row index.
df = pd.read_csv('/content/drive/MyDrive/SDSS Manga data/galaxies_data.csv', index_col = 0)

In [None]:
# Row count of the loaded table.
len(df)

9833

In [None]:
# First 1000 entries of the 'name' column; used below to build the per-galaxy image file names.
list_files_test = df.name[:1000]

NameError: ignored

### Загрузка картинок

In [None]:
# Read each selected galaxy's '<name>_rgb.npy' array from Drive and keep it
# as a float32 tensor with its last axis moved to the front.
rgbs = []
for name_f in list_files_test:
  name = '{}_rgb.npy'.format(name_f)
  raw = np.load('/content/drive/MyDrive/SDSS Manga data/rgb_orig/{}'.format(name))
  rgb = torch.tensor(raw).type(torch.FloatTensor).permute(2, 0, 1)
  rgbs.append(rgb)

### Датасет

In [None]:
class DS_pics(Dataset):
    """Dataset pairing images with four scalar regression targets.

    The i-th image is paired with the i-th *row* of ``df`` (by position,
    not by index label), so ``df`` may be any slice of a larger table.

    Args:
        images: Indexable collection of images.
        df: DataFrame with the columns 'mw z, z/h', 'mw age, log Myr',
            'mass, log m/msun' and 'disp, kms'.
        transform: Optional callable applied to every image on access.
    """

    def __init__(self, images, df, transform = None):
        super().__init__()
        self.images = images
        self.metal = df['mw z, z/h']
        self.age = df['mw age, log Myr']
        self.mass = df['mass, log m/msun']
        self.disp = df['disp, kms']
        self.transform = transform

    def __getitem__(self, indx):
        """Return ``(image, metal, age, mass, disp)`` for position ``indx``."""
        image = self.images[indx]
        if self.transform:
            image = self.transform(image)

        # .iloc is positional. Plain [] is label-based, which misaligns
        # (or raises KeyError) when df is a slice whose labels do not
        # start at 0.
        targets = tuple(
            np.float32(column.iloc[indx])
            for column in (self.metal, self.age, self.mass, self.disp)
        )
        return (image, *targets)

    def __len__(self):
        """Number of images; extra rows in ``df`` are ignored."""
        return len(self.images)

In [None]:
# Hold out the first two galaxies (images and table rows) as a tiny test set;
# everything after them is used for training and validation.
rgb_test, rgb_res = rgbs[:2], rgbs[2:]
df_test, df_res = df.iloc[:2], df.iloc[2:]

In [None]:
# Every image is resized to 241x241 on access.
transform = transforms.Compose(
    [transforms.Resize(size=(241, 241))]
)

# The same preprocessing is used for the main pool and the held-out pair.
galaxies = DS_pics(images=rgb_res, df=df_res, transform=transform)
gal_test = DS_pics(images=rgb_test, df=df_test, transform=transform)

In [None]:
# Up to 800 samples go to training and the rest to validation. The
# validation size is derived from the dataset length so the split does not
# raise when the number of loaded galaxies differs from 998.
n_train = min(800, len(galaxies))
gal_train, gal_val = random_split(galaxies, [n_train, len(galaxies) - n_train])

In [None]:
# Batches of 8; only the training stream is reshuffled every epoch.
train_loader = DataLoader(gal_train, shuffle=True, batch_size=8)
val_loader = DataLoader(gal_val, shuffle=False, batch_size=8)
test_loader = DataLoader(gal_test, shuffle=False, batch_size=8)

### Картинка -> числа



In [None]:
class Two_seq(nn.Module):
    """Four parallel convolutional branches feeding one shared regression head.

    Each target (mass, dispersion, metallicity, age) has its own stack of
    three unpadded 3x3 convolutions. The flattened features of every branch
    are adaptively average-pooled to 2048 values and passed through the
    same three-layer linear head, giving one scalar per sample and target.
    """

    @staticmethod
    def _conv_branch():
        """Build one 3 -> 64 -> 32 -> 16 conv / batch-norm / ReLU stack."""
        stages = []
        in_ch = 3
        for out_ch in (64, 32, 16):
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=0))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(True))
            in_ch = out_ch
        return nn.Sequential(*stages)

    def __init__(self):
        super().__init__()
        # Branches are created in this order so that seeded weight
        # initialisation and the state_dict key order do not change.
        for branch_name in ("mass", "disp", "metal", "age"):
            setattr(self, branch_name, self._conv_branch())
        self.avg = nn.AdaptiveAvgPool1d((2048))
        self.fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.Linear(1024, 512),
            nn.Linear(512, 1)
        )

    def _regress(self, feature_map):
        """Flatten a branch output, pool it to 2048 values and apply the head."""
        flat = feature_map.view(feature_map.size()[0], -1)
        pooled = self.avg(flat)
        return torch.squeeze(self.fc(pooled), 1)

    def forward(self, rgb):
        """Return ``(metal, age, mass, disp)`` predictions for a batch of images."""
        disp_maps = self.disp(rgb)
        mass_maps = self.mass(rgb)
        age_maps = self.age(rgb)
        metal_maps = self.metal(rgb)

        return (
            self._regress(metal_maps),
            self._regress(age_maps),
            self._regress(mass_maps),
            self._regress(disp_maps),
        )

### Обучение

In [None]:
def train(model, criterion, optimizer, num_ep, scheduler=None):
    """Train ``model`` for ``num_ep`` epochs, plotting mean losses live.

    Reads the module-level ``train_loader``, ``val_loader`` and ``device``.
    Every batch is expected to unpack as ``(image, *targets)``; the list of
    targets is handed to ``criterion`` together with the model output.

    Args:
        model: Network to optimise; called with the image batch only.
        criterion: Callable ``criterion(output, targets)`` returning a
            scalar loss tensor.
        optimizer: Optimiser updating ``model``'s parameters.
        num_ep: Number of epochs to run.
        scheduler: Optional scheduler whose ``step`` takes a metric; it is
            stepped once per epoch with the mean validation loss.

    Returns:
        dict: Loss history collected by the progress plotter
        (group -> tag -> list of per-epoch mean losses).
    """
    pp = ProgressPlotter(title="baseline", groups=["loss_train", "loss_val"])
    for _ in range(num_ep):
        model.train()
        ep_loss_train = 0.0
        for (pic, *param) in train_loader:  # metal, age, mass, disp
            # Tensor.to is not in-place: keep the moved copies.
            param = [item.to(device) for item in param]
            optimizer.zero_grad()
            output = model(pic.to(device))
            loss = criterion(output, param)
            loss.backward()
            optimizer.step()
            ep_loss_train += loss.item()

        model.eval()
        ep_loss_val = 0.0
        with torch.no_grad():
            for (pic, *param) in val_loader:
                param = [item.to(device) for item in param]
                output = model(pic.to(device))
                # Only a plain tensor output can be squeezed; a tuple of
                # per-target tensors is passed on unchanged.
                if isinstance(output, torch.Tensor):
                    output = torch.squeeze(output, 1)
                ep_loss_val += criterion(output, param).item()

        # max(..., 1) keeps an empty loader from dividing by zero.
        mean_train = ep_loss_train / max(len(train_loader), 1)
        mean_val = ep_loss_val / max(len(val_loader), 1)

        if scheduler:
            # One step per epoch on the epoch-level metric, not one per
            # validation batch.
            scheduler.step(mean_val)

        pp.add_scalar(group = "loss_train", value = mean_train, tag = "loss_train")
        pp.add_scalar(group = "loss_val", value = mean_val, tag = "loss_val")
        pp.display()
    return pp.history_dict

In [None]:
class RMSELoss(nn.Module):
    """Root-mean-square error over a group of predictions and targets.

    Both arguments may be a tensor or a sequence of equally shaped tensors
    (for example one tensor per regression target); sequences are stacked
    before the error is computed over all elements.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    @staticmethod
    def _to_tensor(values):
        """Return ``values`` as one tensor, stacking a sequence if needed."""
        if isinstance(values, torch.Tensor):
            return values
        return torch.stack([torch.as_tensor(item) for item in values])

    def forward(self, yhat, y):
        """Return ``sqrt(mean((yhat - y) ** 2))`` as a scalar tensor."""
        # .float() / .to() keep the autograd graph and work on any device.
        # The legacy torch.FloatTensor(tensor) constructor rejects CUDA and
        # non-float32 tensors.
        yhat = self._to_tensor(yhat).float()
        y = self._to_tensor(y).to(device=yhat.device, dtype=yhat.dtype)
        return torch.sqrt(self.mse(yhat, y))


In [None]:
from torchvision.models import resnet50, ResNet50_Weights

# Seed all RNGs before the model is built so that its initial weights repeat between runs.
set_random_seed(42)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Two_seq()
model.to(device)
criterion = RMSELoss()
# NOTE(review): lr=0.03 looks high for AdamW — confirm it is intentional.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.03, weight_decay = 0.02)

In [None]:
# Train for 35 epochs without a scheduler; the loss plot is redrawn after every epoch.
pr = train(model, criterion, optimizer, 35)

KeyboardInterrupt: ignored

In [None]:
# Print predictions next to the true values for the held-out galaxies.
model.eval()
# Inference only: skip building the autograd graph so the printed tensors
# carry no grad_fn and no activation memory is kept.
with torch.no_grad():
  for pic, *label in test_loader:
    output = model(pic.to(device))
    print(output, label)

(tensor([11.3226, 11.3764], grad_fn=<SqueezeBackward1>), tensor([11.2548, 11.3764], grad_fn=<SqueezeBackward1>), tensor([11.3764, 11.3764], grad_fn=<SqueezeBackward1>), tensor([12.5839, 11.4122], grad_fn=<SqueezeBackward1>)) [tensor([-0.0279,  0.1465]), tensor([0.6133, 0.6826]), tensor([ 9.7665, 10.0043]), tensor([78.8419, 69.6468])]
