# Pytorch Lightning for ResNet using galaxy_datasets

## Imports

In [68]:
import os
from enum import Enum

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import lightning as pl
import albumentations as A

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch import Tensor
from torcheval.metrics import BinaryAccuracy
from torchvision.models.resnet import BasicBlock, Bottleneck 

from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
from typing import Any, Callable, List, Optional, Type, Union, Tuple

## Options

In [69]:
class class_mode(Enum):
    S_or_Z = 0
    S_or_Z_or_O = 1


USE_GPU = True
USE_DATA_SUBSET = False
SAVE_PATH = "../Models"

MODE = class_mode.S_or_Z_or_O

#Number of CW, ACW and EL to select
THRESHOLD = 0.8
N_CW = 5000
N_ACW = 5000
N_EL = 5000

IMG_SIZE = 160 # This is the output size of the generated image array

if USE_DATA_SUBSET:
    CATALOG_PATH = '../Data/subset_gz1_desi_cross_cat.csv'
    DATA_PATH = '../Data/Subset'
else:
    CATALOG_PATH = '../Data/gz1_desi_cross_cat.csv'
    DATA_PATH = '/share/nas2/walml/galaxy_zoo/decals/dr8/jpg'

torch.set_float32_matmul_precision("medium")

## GPU Test

In [70]:
#Run processes on CPU or GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"CPU cores available on device: {os.cpu_count()}")
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
print('Using device:', device)

CPU cores available on device: 24
NVIDIA A100-PCIE-40GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
Using device: cuda


## Reading in data

### Building catalog

In [71]:
catalog = pd.read_csv(CATALOG_PATH)
very_CW_galaxies = catalog[catalog['P_CW']>THRESHOLD]
very_ACW_galaxies = catalog[catalog['P_ACW']>THRESHOLD]
very_EL_galaxies = catalog[catalog['P_EL']>THRESHOLD]
print(f"Number of galaxies in GZ1 catalogue: {catalog.shape[0]}")
print(f"Very CW: {very_CW_galaxies.shape[0]}, Very ACW: {very_ACW_galaxies.shape[0]}, Very EL: {very_EL_galaxies.shape[0]}")

galaxy_subset = pd.concat([very_CW_galaxies[0:N_CW],very_ACW_galaxies[0:N_ACW],very_EL_galaxies[0:N_EL]])
catalog = galaxy_subset.reset_index()


if MODE == class_mode.S_or_Z:
    #Select only S or Z 
    catalog = catalog[catalog['P_EL']<0.8]
    #Select features (clockwise and anti-clockwise probabilities)
    Y = catalog[['P_CW','P_ACW']]
    classes = [r'P_CW',r'P_ACW']
    num_classes = 2

elif MODE == class_mode.S_or_Z_or_O:
    #Select only S or Z or other
    catalog['P_OTHER'] = catalog['P_EL']+catalog['P_EDGE']+catalog['P_DK']+catalog['P_MG']
    Y = catalog[['P_CW','P_ACW','P_OTHER']]
    classes = ['P_CW','P_ACW','P_OTHER']
    num_classes = 3

print(f"Loaded {catalog.shape[0]} galaxy images")

Number of galaxies in GZ1 catalogue: 647837
Very CW: 14243, Very ACW: 15420, Very EL: 143858
Loaded 15000 galaxy images


### Building file path list

In [72]:
def get_file_paths(catalog_to_convert,folder_path ):
    brick_ids = catalog_to_convert['dr8_id'].str.split("_",expand=True)[0]
    dr8_ids = catalog_to_convert['dr8_id']
    file_locations = folder_path+'/'+brick_ids+'/'+dr8_ids+'.jpg'
    print(f"Created {file_locations.shape[0]} galaxy filepaths")
    return file_locations

catalog['file_loc'] = get_file_paths(catalog,DATA_PATH)


Created 15000 galaxy filepaths


## Modified Resnet50 based on Jia et al (2023)

In [73]:
class JiaResnet(models.resnet.ResNet):
    def __init__(self,
        block: Type[Union[BasicBlock, Bottleneck]],
        use_max_pool: bool = True,
        use_avg_pool: bool = True,
        avg_pool_size: Tuple[int] = (1, 1),
        add_fc: Optional[List[int]] = None, *args, **kwargs):
        
        super().__init__(block, *args, **kwargs)

        self.avgpool = nn.AdaptiveAvgPool2d(avg_pool_size)
        pool_expansion = 1
        if not use_avg_pool:
            pool_expansion = 16 if use_max_pool else 64
        else:
            pool_expansion = np.prod(avg_pool_size) 

        self.fc = self._make_fc(512 * block.expansion * pool_expansion, num_classes, add_fc)

        self.use_max_pool = use_max_pool
        self.use_avg_pool = use_avg_pool

    def _make_fc(self, in_features: int, out_features: int, add_fc: Optional[List[int]]):
        if add_fc is None:
            return nn.Linear(in_features, out_features)
        else:
            add_fc.insert(0, in_features)
            add_fc.append(out_features)
            fc_layers = []
            for i in range(len(add_fc) - 1):
                fc_layers.append(nn.Linear(add_fc[i], add_fc[i + 1]))
                if i != len(add_fc) - 2:
                    fc_layers.append(nn.Tanh())
            return nn.Sequential(*fc_layers)
        
    def _forward_impl(self, x: Tensor) -> Tensor: #Override forward
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.use_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.use_avg_pool:
            x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
    
    def predict(self, x: Tensor) -> Tensor: #Override predict
        x_i = torch.flip(x, (-1,))
        a = self(x)
        a_i = self(x_i)
        return torch.cat((a[..., 0:1], a_i[..., 0:1], 0.5 * (a[..., 1:2] + a_i[..., 1:2])), dim=-1)


def JiaResnet50(**kwargs: Any) -> JiaResnet:
    model = JiaResnet(block=Bottleneck, layers=[3, 4, 6, 3], use_max_pool=True,
     use_avg_pool=True, avg_pool_size=(1, 1), add_fc=[512, 512, 64, 64],**kwargs)
    return model

## Resnet classifier module

In [74]:
class ResNetClassifier(pl.LightningModule):
    model_versions = {
        "resnet18": models.resnet18,
        "resnet34": models.resnet34,
        "resnet50": models.resnet50,
        "resnet101": models.resnet101,
        "resnet152": models.resnet152,
        "jiaresnet50": JiaResnet50
    }
    optimizers = {"adamw": optim.AdamW, "sgd": optim.SGD}
    schedulers = {"steplr": optim.lr_scheduler.StepLR}

    def __init__(
        self,
        num_classes,
        model_version,
        optimizer="adamw",
        scheduler  ="steplr",
        lr=1e-3,
        weight_decay=0,
        step_size=5,
        gamma=0.85,
        batch_size=16
    ):
        super().__init__()

        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.step_size = step_size
        self.gamma = gamma
        self.batch_size = batch_size
        self.optimizer = self.optimizers[optimizer]
        self.scheduler = self.schedulers[scheduler]
        self.loss_fn = nn.CrossEntropyLoss()
        self.acc = self.accuracy_metric #Accuracy(task="multiclass", num_classes=num_classes)
        self.model = self.model_versions[model_version](num_classes=num_classes)

    def forward(self, X):
        return self.model(X)

    def configure_optimizers(self):
        optimizer_class = self.optimizer(self.parameters(), lr=self.lr,weight_decay=self.weight_decay)
        scheduler = self.scheduler(optimizer_class, step_size=self.step_size, gamma=self.gamma)
        return {
        "optimizer": optimizer_class,
        "lr_scheduler": {"scheduler": scheduler},
        }

    def _step(self, batch):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)
        acc = self.acc(preds, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        #time here
        loss, acc = self._step(batch)
        # perform logging
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._step(batch)
        # perform logging
        self.log("val_loss", loss, on_epoch=True, prog_bar=False, logger=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        loss, acc = self._step(batch)
        # perform logging
        self.log("test_loss", loss, on_step=True, prog_bar=True, logger=True)
        self.log("test_acc", acc, on_step=True, prog_bar=True, logger=True)

    def accuracy_metric(self,predicted_labels,true_labels):
        #Takes in softmaxed labels, checks if max column is the same

        true_highest_prob = torch.argmax(true_labels, dim=1)
        predicted_highest_prob = torch.argmax(predicted_labels, dim=1)   
        
        metric = BinaryAccuracy()
        metric.update(predicted_highest_prob,true_highest_prob)
        test_accuracy = metric.compute()
        return test_accuracy

## Code to run

In [75]:
def generate_transforms(resize_after_crop=IMG_SIZE):

    transforms_to_apply = [
        A.ToFloat(), #Converts from 0-255 to 0-1

        A.Resize( #Resizes to
            height=resize_after_crop,
            width=resize_after_crop,
            interpolation=1,
            always_apply=True
        ),
    ]

    return A.Compose(transforms_to_apply)

datamodule = GalaxyDataModule(
    label_cols=['P_CW','P_ACW','P_OTHER'],
    catalog=catalog,
    train_fraction=0.7,
    val_fraction=0.15,
    test_fraction=0.15,
    custom_albumentation_transform=generate_transforms(),
    batch_size=200,
    num_workers=11,
    #prefetch_factor=4,
)

datamodule.prepare_data()
datamodule.setup()

In [76]:
RUN_TEST = False
model = ResNetClassifier(
    num_classes=2, #2 for Jia et al version
    model_version="jiaresnet50",
    optimizer="adamw",
    scheduler  ="steplr",
    lr=0.0001,
    weight_decay=0,
    step_size=5,
    gamma=0.85,
    batch_size=60,
)

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
#stopping_callback = EarlyStopping(monitor="val_loss", mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=50,
    devices=1,
    #callbacks=[stopping_callback]
)

trainer.fit(model,train_dataloaders=datamodule.train_dataloader(),val_dataloaders=datamodule.val_dataloader() )

if RUN_TEST:
    trainer.test(model,test_dataloader=datamodule.test_dataloader())
    
torch.save(trainer.model.state_dict(), SAVE_PATH + "/trained_model.pt")

/share/nas2/npower/miniconda3/envs/mphys-galaxy/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/nas2/npower/miniconda3/envs/mphys-galaxy/lib/ ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type             | Params
---------------------------------------------
0 | loss_fn | CrossEntropyLoss | 0     
1 | model   | JiaResnet        | 24.9 M
---------------------------------------------
24.9 M    Trainable params
0         Non-trainable params
24.9 M    Total params
99.428    Total estimated model params size (MB)


Epoch 49: 100%|██████████| 53/53 [01:01<00:00,  0.87it/s, v_num=13, train_loss_step=0.227, train_acc_step=0.640, val_acc=0.637, train_loss_epoch=0.232, train_acc_epoch=0.665]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 53/53 [01:04<00:00,  0.82it/s, v_num=13, train_loss_step=0.227, train_acc_step=0.640, val_acc=0.637, train_loss_epoch=0.232, train_acc_epoch=0.665]
