# Test of Pytorch Lightning for ResNet using galaxy_datasets

## Imports

In [1]:
from enum import Enum
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from scipy import ndimage
from PIL import Image
import lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader,TensorDataset,Subset, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torcheval.metrics import BinaryAccuracy

  from .autonotebook import tqdm as notebook_tqdm


## Options

In [2]:
class class_mode(Enum):
    S_or_Z = 0
    S_or_Z_or_O = 1


USE_GPU = True
CATALOG_PATH = '../Data/subset_gz1_desi_cross_cat.csv'
DATA_PATH = '../Data/Subset'
#Number of CW, ACW and EL to select
THRESHOLD = 0.8
N_CW = 500
N_ACW = 500
N_EL = 500

#On galahad
CATALOG_PATH = '../Data/gz1_desi_cross_cat.csv'
DATA_PATH = '/share/nas2/walml/galaxy_zoo/decals/dr8/jpg'
MODE = class_mode.S_or_Z_or_O

## GPU Test

In [3]:
#Run processes on CPU or GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
print('Using device:', device)

NVIDIA A100-PCIE-40GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
Using device: cuda


## Reading in data

### Building catalog

In [4]:
catalog = pd.read_csv(CATALOG_PATH)
very_CW_galaxies = catalog[catalog['P_CW']>THRESHOLD]
very_ACW_galaxies = catalog[catalog['P_ACW']>THRESHOLD]
very_EL_galaxies = catalog[catalog['P_EL']>THRESHOLD]
print(f"Number of galaxies in GZ1 catalogue: {catalog.shape[0]}")
print(f"Very CW: {very_CW_galaxies.shape[0]}, Very ACW: {very_ACW_galaxies.shape[0]}, Very EL: {very_EL_galaxies.shape[0]}")

galaxy_subset = pd.concat([very_CW_galaxies[0:N_CW],very_ACW_galaxies[0:N_ACW],very_EL_galaxies[0:N_EL]])
catalog = galaxy_subset.reset_index()


if MODE == class_mode.S_or_Z:
    #Select only S or Z 
    catalog = catalog[catalog['P_EL']<0.8]
    #Select features (clockwise and anti-clockwise probabilities)
    Y = catalog[['P_CW','P_ACW']]
    classes = [r'P_CW',r'P_ACW']
    num_classes = 2

elif MODE == class_mode.S_or_Z_or_O:
    #Select only S or Z or other
    catalog['P_OTHER'] = catalog['P_EL']+catalog['P_EDGE']+catalog['P_DK']+catalog['P_MG']
    Y = catalog[['P_CW','P_ACW','P_OTHER']]
    classes = ['P_CW','P_ACW','P_OTHER']
    num_classes = 3

print(f"Loaded {catalog.shape[0]} galaxy images")

Number of galaxies in GZ1 catalogue: 647837
Number of very CW galaxies in GZ1 catalogue: 14243
Number of very ACW galaxies in GZ1 catalogue: 15420
Number of very EL galaxies in GZ1 catalogue: 143858
Loaded 1500 galaxy images


### Building file path list

In [5]:
def get_file_paths(catalog_to_convert,folder_path ):
    brick_ids = catalog_to_convert['dr8_id'].str.split("_",expand=True)[0]
    dr8_ids = catalog_to_convert['dr8_id']
    file_locations = folder_path+'/'+brick_ids+'/'+dr8_ids+'.jpg'
    print(f"Created {file_locations.shape[0]} galaxy filepaths")
    return file_locations

catalog['file_loc'] = get_file_paths(catalog,DATA_PATH)


Created 1500 galaxy filepaths


## Resnet classifier module

In [14]:
class ResNetClassifier(pl.LightningModule):
    resnets = {
        18: models.resnet18,
        34: models.resnet34,
        50: models.resnet50,
        101: models.resnet101,
        152: models.resnet152,
    }
    optimizers = {"adam": optim.Adam, "sgd": optim.SGD}

    def __init__(
        self,
        num_classes,
        resnet_version,
        #datamodule,
        #train_dataset,
        #val_dataset,
        #test_dataset,
        optimizer="adam",
        lr=1e-3,
        batch_size=16
    ):
        super().__init__()

        self.num_classes = num_classes
        #self.train_dataset = train_dataset
        #self.val_dataset = val_dataset
        #self.test_dataset = test_dataset
        #self.datamodule = datamodule
        self.lr = lr
        self.batch_size = batch_size
        self.optimizer = optimizer #self.optimizers[optimizer]
        self.loss_fn = nn.CrossEntropyLoss()
        self.acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.model = self.resnets[resnet_version]()

    def forward(self, X):
        return self.model(X)

    def configure_optimizers(self):
        optimizer_class = self.optimizers[self.optimizer]
        return optimizer_class(self.parameters(), lr=self.lr)

    def _step(self, batch):
        x, y = batch
        preds = self(x)
        print(preds)
        loss = self.loss_fn(preds, y)
        acc = self.acc(preds, y)
        return loss, acc

    # def _dataloader(self, data_set, shuffle=False):
    #     # Can update this to deal with paths and process images

    #     return DataLoader(data_set, batch_size=self.batch_size, shuffle=shuffle)

    # def train_dataloader(self):
    #     return self._dataloader(self.train_dataset, shuffle=True)

    def training_step(self, batch, batch_idx):
        loss, acc = self._step(batch)
        # perform logging
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    # def val_dataloader(self):
    #     return self._dataloader(self.val_dataset)

    def validation_step(self, batch, batch_idx):
        loss, acc = self._step(batch)
        # perform logging
        self.log("val_loss", loss, on_epoch=True, prog_bar=False, logger=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    # def test_dataloader(self):
    #     return self._dataloader(self.test_dataset)

    def test_step(self, batch, batch_idx):
        loss, acc = self._step(batch)
        # perform logging
        self.log("test_loss", loss, on_step=True, prog_bar=True, logger=True)
        self.log("test_acc", acc, on_step=True, prog_bar=True, logger=True)


## Code to run

In [10]:
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

datamodule = GalaxyDataModule(
    label_cols=['P_CW','P_ACW','P_OTHER'],
    catalog=catalog,
    train_fraction=0.7,
    val_fraction=0.15,
    test_fraction=0.15,
    resize_after_crop=160,
    greyscale=False
)

#datamodule.prepare_data()
#datamodule.setup()
# print(len(datamodule.train_dataloader().dataset))
# print(len(datamodule.test_dataloader().dataset))
# print(len(datamodule.val_dataloader().dataset))
# for images, labels in datamodule.train_dataloader():
#     print(images.shape, labels.shape)
#     break


In [15]:
RUN_TEST = False #Run trained model on test dataset
from lightning.pytorch.callbacks import Callback

model = ResNetClassifier(
    num_classes=3,
    resnet_version=50,
    #train_path=args.train_set,
    #val_path=args.val_set,
    #test_path=args.test_set,
    #train_dataset = train_dataset,
    #val_dataset = val_dataset,
    #test_dataset = test_dataset,
    optimizer="adam",
    lr=0.0001,
    batch_size=60,
)

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
stopping_callback = EarlyStopping(monitor="val_loss", mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=2,
    devices=1,
    callbacks=[stopping_callback]
)

save_path = "./models"


trainer.fit(model,train_dataloaders=datamodule.train_dataloader(),val_dataloaders=datamodule.val_dataloader() )

if RUN_TEST:
    trainer.test(model,test_dataloader=datamodule.test_dataloader())
    
torch.save(trainer.model.resnet_model.state_dict(), save_path + "/trained_model.pt")

/share/nas2/npower/miniconda3/envs/mphys-galaxy/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/nas2/npower/miniconda3/envs/mphys-galaxy/lib/ ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type               | Params
-----------------------------------------------
0 | loss_fn | CrossEntropyLoss   | 0     
1 | acc     | MulticlassAccuracy | 0     
2 | model   | ResNet             | 25.6 M
-----------------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.228   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]tensor([[  21.6169,  224.9325,  301.4130,  ..., -266.0323,   69.1480,
          312.2371],
        [  77.1853,  342.0848,  547.2253,  ..., -403.4380,   70.7112,
          613.3783],
        [  17.4018,  277.2077,  345.9117,  ..., -321.6996,   17.6744,
          465.7989],
        ...,
        [  25.0531,  500.0874,  844.5449,  ..., -669.0000,  -24.4091,
          875.6603],
        [  11.1815,  357.7877,  497.5670,  ..., -393.1017,  120.8533,
          599.5856],
        [  18.7194,  245.9380,  370.5876,  ..., -274.7961,   38.0122,
          416.9408]], device='cuda:0')


RuntimeError: 0D or 1D target tensor expected, multi-target not supported