# Pytorch Lightning Framework for training S+Z Galaxy Classifiers

## Imports

In [25]:
import os
import gc
from enum import Enum
import pandas as pd
import torch
from torch.utils.data import random_split
import lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger,CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


from ChiralityClassifier import ChiralityClassifier
from dataset_utils import *

## Options

In [26]:
class datasets(Enum):
    FULL_DATASET = 0 #Use all 600,000 galaxies in GZ1 catalog
    CUT_DATASET = 1 #Use cut of 200,000 galaxies, with pre-selected test data and downsampled train data
    BEST_SUBSET = 2 #Select N best S,Z & other galaxies, evenly split
    LOCAL_SUBSET = 3 #Use local cache of 1500 galaxies
    FULL_DESI_DATASET = 4 #Use all 7 million galaxies in DESI catalog, minus those that appear in cut catalog

class modes(Enum):
    TRAIN = 0 #Train on a dataset
    TEST = 1 #Test an existing saved model on a dataset
    PREDICT = 2 #Use an existing saved model on an unlabelled dataset

DATASET = datasets.BEST_SUBSET #Select which dataset to run
MODE = modes.PREDICT #Select which mode

# Models:
#resnet18,resnet34,resnet50,resnet101,resnet152,
#jiaresnet50,LeNet,G_ResNet18,G_LeNet,
MODEL_NAME = "resnet18"
CUSTOM_ID = ""

USE_TENSORBOARD = False #Log to tensorboard as well as csv logger
SAVE_MODEL = False #Save model weights to .pt file
REPEAT_RUNS = 1 #Set to 1 for 1 run
IMG_SIZE = 160 #This is the output size of the generated image array
BATCH_SIZE = 100 #Number of images per batch
NUM_WORKERS = 11 #Number of workers in dataloader (no of CPU cores - 1)

PATHS = dict(
    METRICS_PATH = "../Metrics",
    LOG_PATH= "../Code/lightning_logs",
    FULL_DATA_PATH =  "/share/nas2/walml/galaxy_zoo/decals/dr8/jpg",
    LOCAL_SUBSET_DATA_PATH =  "../Data/Subset",
    FULL_CATALOG_PATH =  "../Data/gz1_desi_cross_cat.csv",
    FULL_DESI_CATALOG_PATH =  "../Data/desi_full_cat.csv",
    CUT_CATALOG_TEST_PATH =  "../Data/gz1_desi_cross_cat_testing.csv",
    CUT_CATALOG_TRAIN_PATH = "../Data/gz1_desi_cross_cat_train_val_downsample.csv",
    BEST_SUBSET_CATALOG_PATH =  "../Data/gz1_desi_cross_cat_best_subset.csv",
    LOCAL_SUBSET_CATALOG_PATH =  "../Data/gz1_desi_cross_cat_local_subset.csv",
)

torch.set_float32_matmul_precision("medium")
if len(CUSTOM_ID) == 0:
    MODEL_ID = f"{MODEL_NAME}_{DATASET.name.lower()}"
else:
     MODEL_ID = f"{MODEL_NAME}_{DATASET.name.lower()}_{CUSTOM_ID}"

PRETRAINED_MODEL_PATH = f"{PATHS['METRICS_PATH']}/{MODEL_ID}/version_{0}/model.pt"
if MODE != modes.TRAIN:
    USE_TENSORBOARD = False #Don"t log to tensorboard if not training
    SAVE_MODEL = False #Don"t save weights if testing or predicting model
    REPEAT_RUNS = 1 #Don"t repeat runs if if testing or predicting model

## GPU Test

In [27]:
print(f"Using pytorch {torch.__version__}. CPU cores available on device: {os.cpu_count()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(f"Allocated Memory: {round(torch.cuda.memory_allocated(0)/1024**3,1)} GB")
    print(f"Cached Memory: {round(torch.cuda.memory_reserved(0)/1024**3,1)} GB")
print("Using device:", device)

Using pytorch 2.2.1. CPU cores available on device: 24
NVIDIA A100-PCIE-40GB
Allocated Memory: 0.0 GB
Cached Memory: 0.0 GB
Using device: cuda


## Reading in data

### Building catalog

In [28]:
datamodule = generate_datamodule(DATASET,MODE,PATHS,datasets,modes,IMG_SIZE,BATCH_SIZE,NUM_WORKERS)

Created 15000 galaxy filepaths


## Code to run

In [29]:
datamodule.prepare_data()
datamodule.setup()

In [30]:
for run in range(0,REPEAT_RUNS):
    
    save_dir = f"{PATHS['METRICS_PATH']}/{MODEL_ID}/version_{run}"
    create_folder(save_dir)

    model = ChiralityClassifier(
        num_classes=(2 if (MODEL_NAME=="jiaresnet50") else 3), #2 for Jia et al version
        model_version=MODEL_NAME,
        optimizer="adamw",
        scheduler  ="steplr",
        lr=0.0001,
        weight_decay=0,
        step_size=5,
        gamma=0.85,
        weights=(PRETRAINED_MODEL_PATH if MODE != modes.TRAIN else None),
        model_save_path=f"{save_dir}/model.pt",
        graph_save_path=(f"{save_dir}/val_matrix.png" if MODE == modes.TRAIN else f"{save_dir}/{MODE.name.lower()}_matrix.png")
    )

    tb_logger = TensorBoardLogger(PATHS["LOG_PATH"], name=MODEL_ID,version=(f"{run}_val" if MODE == modes.TRAIN else f"{run}_{MODE.name.lower()}"))
    csv_logger = CSVLogger(PATHS["LOG_PATH"],name=MODEL_ID,version=(f"{run}_val" if MODE == modes.TRAIN else f"{run}_{MODE.name.lower()}"))
    trainer = pl.Trainer(
        accelerator=("gpu" if device.type=="cuda" else "cpu"),
        max_epochs=60,
        devices=1,
        logger=([tb_logger,csv_logger] if USE_TENSORBOARD else csv_logger),
        default_root_dir=f"{PATHS['LOG_PATH']}/{MODEL_ID}",
        enable_checkpointing=False,
        #profiler="pytorch"
        #callbacks=EarlyStopping(monitor="val_loss", mode="min")
    )

    #compiled_model = torch.compile(model, backend="eager")
    
    if MODE==modes.TRAIN:
        trainer.fit(model,train_dataloaders=datamodule.train_dataloader(),val_dataloaders=datamodule.val_dataloader())
        trainer.test(model,dataloaders=datamodule.val_dataloader())

        if SAVE_MODEL:
            torch.save(trainer.model.state_dict(), model.model_save_path)
        
    elif MODE==modes.TEST:
        trainer.test(model,dataloaders=datamodule.test_dataloader())
           
    elif MODE==modes.PREDICT:
        trainer.predict(model,dataloaders=datamodule.predict_dataloader())        

FileNotFoundError: [Errno 2] No such file or directory: '../Metrics/jiaresnet50_best_subset/version_0/model.pt'

In [None]:
#Dereference all objects, clear cuda cache and run garbage collection
datamodule=None
model=None
trainer=None
with torch.no_grad():
    torch.cuda.empty_cache()
gc.collect()

6928