#### PyTorch Lightning script for understanding Task 2

https://lightning.ai/docs/pytorch/stable/data/datamodule.html

- 13 March 2024
- By Jack Li


    IMPORT BASIC PACKAGES

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import h5py
h5py._errors.unsilence_errors()
from mpl_toolkits.axes_grid1 import make_axes_locatable


# Create every folder used for storing results up front.
# `os.makedirs(..., exist_ok=True)` is a no-op when the directory already
# exists, so this cell is safe to re-run and avoids the check-then-create
# race of `os.path.exists` followed by `os.mkdir`.
for output_dir in ('./vis', './vis_results', './model256_weights'):
    os.makedirs(output_dir, exist_ok=True)


    Import Lightning: Import the necessary modules from PyTorch Lightning

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl

import albumentations as albu


    Define LightningModule: Create a LightningModule class that inherits from pl.LightningModule. This class will contain your model architecture and training logic.



In [3]:
import torch
import segmentation_models_pytorch as smp
import numpy as np
from datetime import datetime
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd
import pytorch_lightning as pl

class MyLightningModel(pl.LightningModule):
    """U-Net regression model wrapped as a LightningModule.

    The network is a ``segmentation_models_pytorch`` U-Net with a ResNet-34
    encoder (no pretrained weights), 4 input channels and one
    sigmoid-activated output channel. Training and validation both use the
    mean-squared error between the prediction and the target.

    Args:
        learning_rate: Step size passed to the Adam optimizer. Defaults to
            0.001, the value that was previously hard-coded.
    """

    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate

        self.model = smp.Unet(
            encoder_name='resnet34',
            encoder_weights=None,
            in_channels=4,
            classes=1,
            activation='sigmoid'
        )
        self.l2_loss = torch.nn.MSELoss()

    def forward(self, x):
        """Run the U-Net on a batch of inputs and return its raw output."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Return the MSE loss for one ``(imgs, masks)`` batch.

        Shared by the training and validation steps so both evaluate the
        model in exactly the same way.
        """
        imgs, masks = batch
        # Remove only the channel axis (dim 1). A bare `.squeeze()` would also
        # drop the batch axis whenever a batch holds a single sample, leaving
        # the prediction with a different rank than the target.
        # assumes masks is (batch, H, W) -- TODO confirm against MyDataset
        preds = self(imgs).squeeze(1)
        return self.l2_loss(preds, masks)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and show it in the progress bar."""
        val_loss = self._compute_loss(batch)
        self.log('val_loss', val_loss, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)




    Define LightningDataModule: If you're using custom data loaders, create a LightningDataModule class that inherits from pl.LightningDataModule. This class will contain your data loading logic.

In [4]:
from itertools import product
import h5py
import numpy as np
from torch.utils.data import DataLoader, Dataset, RandomSampler
import pytorch_lightning as pl

from torch.utils.data import Dataset
from dataset import MyDataset

    
class MyDataModule(pl.LightningDataModule):
    """LightningDataModule that serves the ``256modelruns`` HDF5 files.

    Sixteen file names ``256modelruns/Pe1_K1_{i}_{j}.hdf5`` (``i, j`` in 0..3)
    are generated; the first ``n_training_samples`` go to the training
    dataset and the remainder to the validation dataset.

    Args:
        augmentation: Either ``None`` (no augmentation for either split) or
            an indexable pair ``[train_augmentation, val_augmentation]``.
        preprocessing: Passed unchanged to ``MyDataset`` for both splits.
        batch_size: Batch size of the training dataloader.
    """

    def __init__(self, augmentation=None, preprocessing=None, batch_size=4):
        super().__init__()
        self.batch_size = batch_size
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.n_training_samples = 12

    def setup(self, stage=None):
        """Build the training and validation datasets from the file list."""
        # Same order as iterating product(range(4), repeat=2): (0,0), (0,1), ...
        file_list = [
            f'256modelruns/Pe1_K1_{idx1}_{idx2}.hdf5'
            for idx1, idx2 in product(range(4), repeat=2)
        ]

        # The constructor default is augmentation=None; indexing it directly
        # raised "TypeError: 'NoneType' object is not subscriptable".
        # NOTE(review): assumes MyDataset accepts None as "no augmentation" -- confirm.
        if self.augmentation is None:
            train_augmentation, val_augmentation = None, None
        else:
            train_augmentation = self.augmentation[0]
            val_augmentation = self.augmentation[1]

        self.train_dataset = MyDataset(
            file_list[:self.n_training_samples], train_augmentation, self.preprocessing
        )
        self.val_dataset = MyDataset(
            file_list[self.n_training_samples:], val_augmentation, self.preprocessing
        )

    def train_dataloader(self):
        """Return the training dataloader (incomplete final batch is dropped)."""
        # NOTE(review): no shuffle/sampler is set, so batches follow dataset
        # order; a RandomSampler was previously tried here -- confirm intent.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=1, drop_last=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed batch size of 4, no shuffling)."""
        return DataLoader(self.val_dataset, batch_size=4, num_workers=1, shuffle=False, drop_last=True)
    
    



    Training Loop with Trainer: Create a pl.Trainer object and use it to train your LightningModule.



In [5]:




def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    Returns:
        An ``albu.Compose`` that resizes to 256x256, then applies a
        horizontal flip and a vertical flip, each with probability 0.5.
    """
    # Flagged "not needed" by the original author
    # (presumably the inputs are already 256x256 -- confirm).
    resize = albu.Resize(256, 256)
    flips = (albu.HorizontalFlip(p=0.5), albu.VerticalFlip(p=0.5))
    return albu.Compose([resize, *flips])

def get_validation_augmentation():
    """Build the albumentations pipeline applied to validation samples.

    Only a resize to 256x256 is applied, which keeps the image shape
    divisible by 32.

    Returns:
        An ``albu.Compose`` containing the single resize transform.
    """
    return albu.Compose([albu.Resize(256, 256)])

In [6]:

import torch
import numpy as np
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
from torch.utils.data import RandomSampler
from datetime import datetime

import torch
import segmentation_models_pytorch as smp
from tqdm import tqdm
import pandas as pd


from pre_processing import get_preprocessing

# Seed the Python, NumPy and PyTorch global RNGs (and dataloader workers)
# before the model is built, so weight initialisation and training are
# repeatable after Restart Kernel -> Run All.
pl.seed_everything(42, workers=True)

model = MyLightningModel()
data_module = MyDataModule(
    augmentation=[get_training_augmentation(), get_validation_augmentation()],
    preprocessing=get_preprocessing(),
)

# Train for a fixed number of epochs with Lightning's default settings.
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, data_module)




GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024-03-13 16:02:39.284440: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

  | Name    | Type    | Params
------------------------------------
0 | model   | Unet    | 24.4 M
1 | l2_loss | MSELoss | 0     
------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.758    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/captainjack/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/captainjack/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

/Users/captainjack/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
