# First attempt

In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import tqdm
import sys
import timm
import torch
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, Dataset

# Absolute path of the directory the notebook is run from; data paths below are
# built relative to it.
PATH = Path('.').resolve()

# One seed for every library that draws random numbers in this notebook.
# Seeding torch's global RNG fixes the starting state for weight initialisation
# and DataLoader shuffling on a fresh "Restart & Run All".
# NOTE(review): GPU kernels may still be non-deterministic — confirm if exact
# bit-for-bit reproducibility is required.
SEED = 42
torch.manual_seed(SEED)
rng = np.random.RandomState(SEED)

In [2]:
# Load the catalogue; SpecObjID is read as a string so the long integer IDs
# can be used verbatim as image file names.
df = pd.read_csv(f'{PATH}/data/sdss_64k.csv', dtype={'SpecObjID': str})
label_cols = [f'VAE{n}' for n in range(1, 7)]

# Standardize the six VAE columns (these are the regression targets returned by
# GalaxyDataset). The statistics are taken over the full catalogue, i.e. before
# the galaxy cut below.
mean = df[label_cols].mean()
std = df[label_cols].std()
df[label_cols] = (df[label_cols] - mean) / std

# isolate galaxies: keep rows with any of the primTarget bits 64, 128 or 256 set
# NOTE(review): presumably these are the SDSS galaxy target flags — confirm.
is_galaxy = (df.primTarget & (64 | 128 | 256)) > 0
gal = df.loc[is_galaxy].copy()
gal

Unnamed: 0,plate,mjd,fiberid,spec_class,VAE1,VAE2,VAE3,VAE4,VAE5,VAE6,SpecObjID,ra,dec,z,primTarget,snMedian,rChi2
0,266,51630,1,0,-0.368723,-0.917751,0.501150,-0.276343,0.462915,-0.062147,299489677444933632,146.714210,-1.041304,0.021222,96,39.046780,1.500621
1,266,51630,2,0,-1.158448,-0.605929,-0.147964,1.074922,-1.062042,-1.431129,299489952322840576,146.744130,-0.652191,0.203783,96,9.905860,1.302946
2,266,51630,4,0,-0.136203,-0.322451,0.598281,-0.124769,1.267600,0.908781,299490502078654464,146.628570,-0.765137,0.064656,64,17.981470,1.227630
3,266,51630,6,1,1.208109,0.956779,-0.228299,-0.030195,-0.627785,0.611048,299491051834468352,146.631670,-0.988278,0.052654,64,7.963448,1.404003
4,266,51630,7,0,-0.609212,-0.846412,-0.850208,0.566458,-1.246353,-1.169481,299491326712375296,146.919450,-0.990492,0.213861,64,7.951364,1.332148
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
63701,425,51898,602,0,-0.894158,-2.039106,0.413478,0.213428,-0.726911,0.577897,478672968751278080,23.780133,15.450872,0.072433,64,15.980110,1.097326
63702,425,51898,603,0,-1.086042,-1.317682,2.593123,-2.214063,0.349297,0.036920,478673243629185024,23.731774,15.375771,0.071264,96,11.848070,1.116136
63703,425,51898,604,0,-0.502692,-0.267302,-0.165126,0.010934,-1.216511,1.484335,478673518507091968,23.732863,15.356637,0.073216,64,15.784980,0.969229
63704,425,51898,606,0,-1.649470,-0.838885,-0.447896,1.096694,-0.792008,0.449849,478674068262905856,23.760782,15.487749,0.072091,96,31.279550,1.138224


In [3]:
# Load one image array straight from disk and look at its shape.
# NOTE(review): layout is presumably (channels, height, width) — confirm.
image = np.load(f"{PATH}/data/ps1_npy_images/466156677649950720.npy")
image.shape

(4, 224, 224)

In [4]:
# remove wrong-sized inputs
# (e.g. 466156677649950720 loads with 4 channels, whereas batches are 5-channel)
wrong_sized_ids = [
    "421190773933893632",
    "466155853016229888",
    "466156127894136832",
    "466156677649950720",
]

# errors='ignore' keeps this cell idempotent: re-running it after the rows are
# already gone no longer raises KeyError.
gal = (
    gal.set_index('SpecObjID')
    .drop(wrong_sized_ids, errors='ignore')
    .reset_index()
)
gal.shape


(60229, 17)

# Helper function

In [5]:
def _open_npy(fn):
    """Load the ``.npy`` file ``fn`` as a float32 tensor.

    The array is passed through ``np.nan_to_num`` before conversion, so NaN
    entries come back as 0.
    """
    array = np.load(fn)
    array = np.nan_to_num(array)
    return torch.from_numpy(array).float()


class GalaxyDataset(Dataset):
    """Pairs each catalogue row with its image array on disk.

    Parameters
    ----------
    df : DataFrame with a ``SpecObjID`` column and columns ``VAE1`` .. ``VAE6``.
    image_data_dir : directory holding one ``<SpecObjID>.npy`` file per row.
    transform : optional callable applied to the image tensor.

    Indexing yields ``(image, targets)``: the image as a float tensor and the
    row's six VAE values as a NumPy array.
    """

    def __init__(self, df, image_data_dir="./data/ps1_npy_images", transform=None):
        self.data_dir = image_data_dir
        self.transform = transform

        # image paths are derived from the object IDs, in catalogue order
        self.object_ids = df.SpecObjID.astype(str).values
        self.image_files = [f"{image_data_dir}/{oid}.npy" for oid in self.object_ids]

        label_columns = [f"VAE{n}" for n in range(1, 7)]
        self.spec_labels = df[label_columns].values

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = _open_npy(self.image_files[idx])
        labels = self.spec_labels[idx]

        if self.transform:
            image = self.transform(image)
        return image, labels

# Loader over all galaxies with no transform attached, so batches hold the raw
# (un-normalized) pixel values.
dl = DataLoader(
    GalaxyDataset(gal, image_data_dir=f"{PATH}/data/ps1_npy_images"),
    batch_size=64,
    shuffle=True,
    num_workers=16,
)

In [6]:
# Pull a single batch from the loader to check the image and target shapes.
xb, yb = next(iter(dl))
xb.shape, yb.shape

(torch.Size([64, 5, 224, 224]), torch.Size([64, 6]))

In [7]:
# Per-channel mean and standard deviation: reduce over the batch dimension and
# the two spatial dimensions, leaving one value per channel.
# These come from the single batch drawn above, not from the whole dataset.
image_means = xb.mean(dim=(0, 2, 3))
image_stds = xb.std(dim=(0, 2, 3))
image_means, image_stds

(tensor([123.6845,  86.1325, 149.2548,  69.4890,  31.9974]),
 tensor([2710.9961, 2086.4534, 4439.1387, 1942.0453,  759.2886]))

In [8]:
# Training: random horizontal/vertical flips as augmentation, then per-channel
# standardization with the batch statistics computed earlier.
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.Normalize(image_means, image_stds)
])

# Validation: standardization only. Random flips are left out so that the
# validation loss is deterministic for a given model.
transforms_valid = transforms.Compose([
    transforms.Normalize(image_means, image_stds)
])

In [9]:
# 80/20 train/validation split. Drawing from the seeded `rng` makes the split
# reproducible on a fresh run; `sample` and `drop` already return new frames,
# so no explicit copies are needed.
train = gal.sample(frac=0.8, random_state=rng)
valid = gal.drop(train.index)

# the two parts must partition the galaxy table exactly
assert len(train) + len(valid) == len(gal)

In [10]:
# Point both datasets at the same absolute image directory used for `dl`,
# instead of relying on the cwd-relative default.
image_dir = f"{PATH}/data/ps1_npy_images"

dset_train = GalaxyDataset(train, image_data_dir=image_dir, transform=transforms_train)
dset_valid = GalaxyDataset(valid, image_data_dir=image_dir, transform=transforms_valid)

# Shuffle only the training data; validation order does not need to be random.
dl_train = DataLoader(dset_train, batch_size=64, shuffle=True)
dl_valid = DataLoader(dset_valid, batch_size=64, shuffle=False)

In [22]:
def create_model(model_type='convnext', in_channels=5, out_features=6):
    """Build a CNN whose first and last layers are resized for this task.

    Parameters
    ----------
    model_type : {'convnext', 'resnet'}
        'convnext' builds timm's ``convnext_tiny`` with ``pretrained=False``;
        'resnet' builds ``torchvision.models.resnet34()``.
    in_channels : int
        Number of input image channels for the replaced first convolution.
    out_features : int
        Number of regression outputs for the replaced final linear layer.

    Returns
    -------
    torch.nn.Module

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names (previously this
        fell through to ``return model`` and raised ``UnboundLocalError``).
    """
    if model_type == 'convnext':
        # almost 16 GB of RAM used at bs = 64
        model = timm.create_model('convnext_tiny', pretrained=False)
        # swap the stem convolution and the head for the requested sizes
        model.stem[0] = torch.nn.Conv2d(in_channels, 96, kernel_size=(4, 4), stride=(4, 4))
        model.head.fc = torch.nn.Linear(in_features=768, out_features=out_features, bias=True)

    elif model_type == 'resnet':
        # uses >10 GB of RAM at bs = 128
        model = torchvision.models.resnet34()
        # swap the first convolution and the classifier for the requested sizes
        model.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model.fc = torch.nn.Linear(in_features=512, out_features=out_features, bias=True)

    else:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'convnext' or 'resnet'"
        )

    return model

In [23]:
# Number of passes over the training set; used for both the LR schedule and
# the loop so the two cannot drift apart.
n_epochs = 20

# Use the first CUDA device when one is available, otherwise fall back to the
# CPU so the cell does not crash on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = create_model()
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)

# one-cycle schedule (stepped once per batch, hence steps_per_epoch)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-3,
    steps_per_epoch=len(dl_train),
    epochs=n_epochs,
)

for epoch in range(1, n_epochs + 1):
    # train
    # =====
    model.train()
    train_loss = 0

    # continue grabbing batches of training data
    for xb, yb in tqdm.tqdm(dl_train):
        xb = xb.to(device).float()
        yb = yb.to(device).float()
        prediction = model(xb)
        loss = torch.nn.functional.mse_loss(prediction, yb)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    # validate
    # ========
    valid_loss = 0
    model.eval()

    # no need to compute gradients, since we're not trying to update the model
    # this lets us perform another forward pass after one epoch of training
    with torch.no_grad():
        for xb, yb in tqdm.tqdm(dl_valid):
            xb = xb.to(device).float()
            yb = yb.to(device).float()
            prediction = model(xb)
            loss = torch.nn.functional.mse_loss(prediction, yb)
            valid_loss += loss.item()

    # see how well we did (mean of the per-batch losses)
    train_loss = train_loss / len(dl_train)
    valid_loss = valid_loss / len(dl_valid)

    print(f'Epoch {epoch}    Train: {train_loss:.4f}      Valid: {valid_loss:.4f}')

100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [09:56<00:00,  1.26it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:37<00:00,  5.05it/s]


Epoch 1    Train: 0.8124      Valid: 0.7070


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:09<00:00,  1.24it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:35<00:00,  5.27it/s]


Epoch 2    Train: 0.6751      Valid: 0.6513


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:07<00:00,  1.24it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:35<00:00,  5.35it/s]


Epoch 3    Train: 0.6423      Valid: 0.6644


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:03<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.54it/s]


Epoch 4    Train: 0.6278      Valid: 0.6183


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:04<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.51it/s]


Epoch 5    Train: 0.6197      Valid: 0.6192


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:01<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.48it/s]


Epoch 6    Train: 0.6188      Valid: 0.6222


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:33<00:00,  5.71it/s]


Epoch 7    Train: 0.6100      Valid: 0.5936


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:35<00:00,  5.37it/s]


Epoch 8    Train: 0.5921      Valid: 0.5853


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.49it/s]


Epoch 9    Train: 0.5819      Valid: 0.5742


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.43it/s]


Epoch 10    Train: 0.5724      Valid: 0.5637


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.46it/s]


Epoch 11    Train: 0.5635      Valid: 0.5808


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:01<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:33<00:00,  5.58it/s]


Epoch 12    Train: 0.5517      Valid: 0.5548


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:03<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.43it/s]


Epoch 13    Train: 0.5423      Valid: 0.5533


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.50it/s]


Epoch 14    Train: 0.5315      Valid: 0.5417


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:03<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.52it/s]


Epoch 15    Train: 0.5194      Valid: 0.5399


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.41it/s]


Epoch 16    Train: 0.5090      Valid: 0.5213


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:03<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:33<00:00,  5.62it/s]


Epoch 17    Train: 0.4960      Valid: 0.5196


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.41it/s]


Epoch 18    Train: 0.4855      Valid: 0.5182


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.47it/s]


Epoch 19    Train: 0.4793      Valid: 0.5176


100%|███████████████████████████████████████████████████████████████████████████████| 753/753 [10:02<00:00,  1.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 189/189 [00:34<00:00,  5.50it/s]

Epoch 20    Train: 0.4761      Valid: 0.5151





In [24]:
# torch.save(model.state_dict(), './models/pure-pytorch_resnet34-20epochs.pth');
# torch.save(model.state_dict(), './models/pure-pytorch_convnext-tiny-20epochs.pth');

In [26]:
# Longer run: a fresh model trained for 60 epochs with the same optimizer and
# one-cycle settings. The epoch count is defined once and shared by the LR
# schedule and the loop.
n_epochs = 60

# Use the first CUDA device when one is available, otherwise fall back to the
# CPU so the cell does not crash on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = create_model()
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)

# one-cycle schedule (stepped once per batch, hence steps_per_epoch)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-3,
    steps_per_epoch=len(dl_train),
    epochs=n_epochs,
)

for epoch in range(1, n_epochs + 1):
    # train
    # =====
    model.train()
    train_loss = 0

    # continue grabbing batches of training data
    for xb, yb in dl_train:
        xb = xb.to(device).float()
        yb = yb.to(device).float()
        prediction = model(xb)
        loss = torch.nn.functional.mse_loss(prediction, yb)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    # validate
    # ========
    valid_loss = 0
    model.eval()

    # no need to compute gradients, since we're not trying to update the model
    # this lets us perform another forward pass after one epoch of training
    with torch.no_grad():
        for xb, yb in dl_valid:
            xb = xb.to(device).float()
            yb = yb.to(device).float()
            prediction = model(xb)
            loss = torch.nn.functional.mse_loss(prediction, yb)
            valid_loss += loss.item()

    # see how well we did (mean of the per-batch losses)
    train_loss = train_loss / len(dl_train)
    valid_loss = valid_loss / len(dl_valid)

    print(f'Epoch {epoch}    Train: {train_loss:.4f}      Valid: {valid_loss:.4f}')

Epoch 1    Train: 0.8116      Valid: 0.7632
Epoch 2    Train: 0.6781      Valid: 0.6604
Epoch 3    Train: 0.6344      Valid: 0.6425
Epoch 4    Train: 0.6155      Valid: 0.6283
Epoch 5    Train: 0.6078      Valid: 0.6149
Epoch 6    Train: 0.6005      Valid: 0.5931
Epoch 7    Train: 0.5973      Valid: 0.5886
Epoch 8    Train: 0.5935      Valid: 0.5805
Epoch 9    Train: 0.5893      Valid: 0.6002
Epoch 10    Train: 0.5894      Valid: 0.5974
Epoch 11    Train: 0.5921      Valid: 0.5898
Epoch 12    Train: 0.5814      Valid: 0.5765
Epoch 13    Train: 0.5779      Valid: 0.5931
Epoch 14    Train: 0.5828      Valid: 0.5970
Epoch 15    Train: 0.5710      Valid: 0.5933
Epoch 16    Train: 0.5669      Valid: 0.5659
Epoch 17    Train: 0.5622      Valid: 0.5781
Epoch 18    Train: 0.5629      Valid: 0.5675
Epoch 19    Train: 0.5557      Valid: 0.5656
Epoch 20    Train: 0.5519      Valid: 0.5530
Epoch 21    Train: 0.5450      Valid: 0.5519
Epoch 22    Train: 0.5414      Valid: 0.5456
Epoch 23    Train: 

In [27]:
# torch.save(model.state_dict(), './models/pure-pytorch_convnext-tiny-60epochs.pth');