# PyTorch Lightning for ResNet using galaxy_datasets

## Imports

In [5]:
import os
from enum import Enum

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import lightning as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import albumentations as A

import torch
from torch import Tensor

from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

from custom_models.Jia_ResNet import JiaResnet50
from ChiralityClassifier import ChiralityClassifier

## Options

In [None]:
class class_mode(Enum):
    """Selects which set of label columns the notebook trains against."""

    # Two targets: P_CW and P_ACW
    S_or_Z = 0
    # Three targets: P_CW, P_ACW and a combined P_OTHER
    S_or_Z_or_O = 1

# --- Run switches ---
USE_GPU = True
USE_DATA_SUBSET = False
SAVE_PATH = "../Models"

# Label scheme used when building the catalogue (see class_mode)
MODE = class_mode.S_or_Z_or_O

# Vote fraction a galaxy must exceed to be picked as a CW / ACW / EL example
THRESHOLD = 0.8
# Maximum number of CW, ACW and EL galaxies to select
N_CW = 5000
N_ACW = 5000
N_EL = 5000

IMG_SIZE = 160 # This is the output size of the generated image array

# (catalogue CSV, image root folder): small local subset or the full data set
CATALOG_PATH, DATA_PATH = (
    ('../Data/subset_gz1_desi_cross_cat.csv', '../Data/Subset')
    if USE_DATA_SUBSET
    else ('../Data/gz1_desi_cross_cat.csv', '/share/nas2/walml/galaxy_zoo/decals/dr8/jpg')
)

# Lets PyTorch trade some float32 matmul precision for speed (see PyTorch docs)
torch.set_float32_matmul_precision("medium")

## GPU Test

In [None]:
# Report the PyTorch build and pick the device: CUDA when available, else CPU
print(f"Using pytorch {torch.__version__}")
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"CPU cores available on device: {os.cpu_count()}")
if has_cuda:
    # Name and current memory figures (in GiB) for GPU 0
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, query in (
        ('Allocated:', torch.cuda.memory_allocated),
        ('Cached:   ', torch.cuda.memory_reserved),
    ):
        print(label, round(query(0)/1024**3,1), 'GB')
print('Using device:', device)

## Reading in data

### Building catalog

In [None]:
# Load the cross-matched catalogue and pick out the confidently classified
# galaxies: rows whose vote fraction for a class exceeds THRESHOLD.
catalog = pd.read_csv(CATALOG_PATH)
very_CW_galaxies = catalog[catalog['P_CW']>THRESHOLD]
very_ACW_galaxies = catalog[catalog['P_ACW']>THRESHOLD]
very_EL_galaxies = catalog[catalog['P_EL']>THRESHOLD]
print(f"Number of galaxies in GZ1 catalogue: {catalog.shape[0]}")
print(f"Very CW: {very_CW_galaxies.shape[0]}, Very ACW: {very_ACW_galaxies.shape[0]}, Very EL: {very_EL_galaxies.shape[0]}")

# Keep at most N_CW / N_ACW / N_EL rows per class, in catalogue order.
galaxy_subset = pd.concat([very_CW_galaxies[0:N_CW],very_ACW_galaxies[0:N_ACW],very_EL_galaxies[0:N_EL]])
# reset_index() renumbers the rows and keeps the old labels in an 'index' column.
catalog = galaxy_subset.reset_index()


if MODE == class_mode.S_or_Z:
    # Select only S or Z: drop the confident ellipticals, using the same
    # THRESHOLD that selected them rather than a separate hard-coded 0.8.
    # .copy() makes the result an independent frame so that columns added to
    # `catalog` later are not written to a slice of the previous frame.
    catalog = catalog[catalog['P_EL']<THRESHOLD].copy()
    #Select features (clockwise and anti-clockwise probabilities)
    classes = ['P_CW','P_ACW']

elif MODE == class_mode.S_or_Z_or_O:
    #Select only S or Z or other: everything that is not CW/ACW is summed into P_OTHER
    catalog['P_OTHER'] = catalog['P_EL']+catalog['P_EDGE']+catalog['P_DK']+catalog['P_MG']
    classes = ['P_CW','P_ACW','P_OTHER']

else:
    # Fail here instead of leaving Y / classes / num_classes undefined.
    raise ValueError(f"Unsupported MODE: {MODE!r}")

# Label frame and class count both follow from the selected label columns.
Y = catalog[classes]
num_classes = len(classes)

print(f"Loaded {catalog.shape[0]} galaxy images")

### Building file path list

In [None]:
def get_file_paths(catalog_to_convert: pd.DataFrame, folder_path: str) -> pd.Series:
    """Build the JPG path of every galaxy in a catalogue.

    Each path has the form ``<folder_path>/<brick>/<dr8_id>.jpg``, where
    ``<brick>`` is the part of ``dr8_id`` before the first underscore.

    Args:
        catalog_to_convert: Catalogue with a string ``dr8_id`` column.
        folder_path: Root folder holding one sub-folder per brick.

    Returns:
        A Series of path strings aligned with the catalogue's index.
        An empty catalogue yields an empty Series.
    """
    dr8_ids = catalog_to_convert['dr8_id']
    # Only the leading piece is needed, so split once and take element 0.
    # This avoids expanding every piece into a DataFrame and still works
    # when the catalogue has no rows (the expanded frame would then have
    # no column 0 to select).
    brick_ids = dr8_ids.str.split("_", n=1).str[0]
    file_locations = folder_path+'/'+brick_ids+'/'+dr8_ids+'.jpg'
    print(f"Created {file_locations.shape[0]} galaxy filepaths")
    return file_locations

# Store each galaxy's image path in a 'file_loc' column of the catalogue
# (presumably the column GalaxyDataModule reads image paths from — confirm).
catalog['file_loc'] = get_file_paths(catalog,DATA_PATH)

## Code to run

In [None]:
def generate_transforms(resize_after_crop=IMG_SIZE):
    """Build the albumentations pipeline applied to the galaxy images.

    Args:
        resize_after_crop: Side length, in pixels, of the square output
            image. Defaults to IMG_SIZE.

    Returns:
        An ``A.Compose`` running ``ToFloat`` followed by ``Resize``.
    """
    # Converts from 0-255 to 0-1
    to_float = A.ToFloat()

    # Resizes to resize_after_crop x resize_after_crop
    # NOTE(review): interpolation=1 looks like OpenCV's INTER_LINEAR flag — confirm
    resize = A.Resize(
        height=resize_after_crop,
        width=resize_after_crop,
        interpolation=1,
        always_apply=True,
    )

    return A.Compose([to_float, resize])

# Data module over the prepared catalogue with a 70/15/15 train/val/test split.
datamodule = GalaxyDataModule(
    # Use the label columns chosen for MODE. A fixed three-column list would
    # request 'P_OTHER', which only exists in S_or_Z_or_O mode.
    label_cols=classes,
    catalog=catalog,
    train_fraction=0.7,
    val_fraction=0.15,
    test_fraction=0.15,
    custom_albumentation_transform=generate_transforms(),
    batch_size=200,
    num_workers=11,
)

# Called by hand because the dataloaders are requested directly from the
# datamodule instead of handing the datamodule itself to the Trainer.
datamodule.prepare_data()
datamodule.setup()

In [None]:
# Evaluate on the test split after training
RUN_TEST = False
# Train through torch.compile. Off by default: the compiled module used to be
# built here but never handed to the Trainer, so training has always run on
# the plain module.
COMPILE_MODEL = False

model = ChiralityClassifier(
    num_classes=num_classes, #Follows MODE (3 with P_OTHER; 2 for Jia et al version)
    model_version="G_ResNet18",
    optimizer="adamw",
    scheduler  ="steplr",
    lr=0.0001,
    weight_decay=0,
    step_size=5,
    gamma=0.85,
    # NOTE(review): the datamodule is built with batch_size=200 — confirm which
    # value ChiralityClassifier expects here.
    batch_size=60,
)


#stopping_callback = EarlyStopping(monitor="val_loss", mode="min")

trainer = pl.Trainer(
    # Honour the USE_GPU option instead of always requesting a GPU
    accelerator="gpu" if USE_GPU else "cpu",
    max_epochs=60,
    devices=1,
    #callbacks=[stopping_callback]
)

# Create the output folder before training so a missing directory cannot
# make the final save fail after all epochs have run.
os.makedirs(SAVE_PATH, exist_ok=True)

# The compiled wrapper shares its parameters with `model`.
model_to_fit = torch.compile(model, backend="aot_eager") if COMPILE_MODEL else model

trainer.fit(
    model_to_fit,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

if RUN_TEST:
    # Trainer.test takes the loaders through `dataloaders`
    trainer.test(model_to_fit, dataloaders=datamodule.test_dataloader())

# Save from the original module so the state-dict keys carry no wrapper
# prefix whether or not the model was compiled.
torch.save(model.state_dict(), os.path.join(SAVE_PATH, "trained_model.pt"))