#### PyTorch Lightning script for Project 2

https://lightning.ai/docs/pytorch/stable/data/datamodule.html

- 13 March 2024
- By Jack Li


    Import basic packages and create the output folders

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import h5py
h5py._errors.unsilence_errors()
from mpl_toolkits.axes_grid1 import make_axes_locatable


# Create every output folder used by this notebook up front.
# os.makedirs(..., exist_ok=True) is idempotent (safe on re-run) and avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
RESULT_DIRS = ('./vis', './vis_results', './model256_weights')
for result_dir in RESULT_DIRS:
    os.makedirs(result_dir, exist_ok=True)


    Import Lightning: Import the necessary modules from PyTorch Lightning

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split


import lightning as L

import albumentations as albu

import segmentation_models_pytorch as smp
import numpy as np
from datetime import datetime
from tqdm.notebook import tqdm

import torch.nn.functional as F
import pandas as pd

from itertools import product
import h5py


from torch.utils.data import RandomSampler

    Define the LightningModule: create a class that inherits from `L.LightningModule` (Lightning is imported as `L`). It contains the model architecture and the training, validation, and test logic.



In [3]:
class MyLightningModel(L.LightningModule):
    """U-Net regressor mapping a 4-channel input to a 1-channel output in [0, 1].

    Wraps a ``segmentation_models_pytorch.Unet`` with a ResNet-34 encoder
    (randomly initialised, ``encoder_weights=None``) and a sigmoid output.
    Train, validation and test steps all minimise / report the MSE between the
    prediction and the target ``masks``.

    Args:
        lr: Adam learning rate. Defaults to 5e-4, the value previously hard-coded.
    """

    def __init__(self, lr=5e-4):
        super().__init__()

        self.model = smp.Unet(
            encoder_name='resnet34',
            encoder_weights=None,
            in_channels=4,
            classes=1,
            activation='sigmoid'
        )
        self.l2_loss = torch.nn.MSELoss()
        self.l1_loss = torch.nn.L1Loss()  # not used by any step; kept for experiments

        # Store constructor arguments (lr) in self.hparams and in checkpoints.
        self.save_hyperparameters()

    def forward(self, x):
        """Run the U-Net on ``x`` (expects 4 input channels)."""
        return self.model(x)

    def _shared_step(self, batch):
        """Return the MSE loss for one ``(imgs, masks)`` batch."""
        imgs, masks = batch
        preds = self.model(imgs)
        # With classes=1 the model emits one output channel. A bare ``.squeeze()``
        # also removes the batch dim when B == 1 and lets MSELoss silently
        # broadcast mismatched shapes (e.g. (B,H,W) vs (B,1,H,W) -> (B,B,H,W)).
        # Reshaping to the target layout is exact for a 1-channel output and
        # raises loudly on a genuine size mismatch.
        preds = preds.reshape(masks.shape)
        return self.l2_loss(preds, masks)

    def training_step(self, batch, batch_idx):
        """Compute and log the training MSE; the returned loss is backpropagated."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` (the metric monitored for early stopping)."""
        val_loss = self._shared_step(batch)
        self.log('val_loss', val_loss, prog_bar=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute and log ``test_loss`` on the held-out split."""
        test_loss = self._shared_step(batch)
        self.log("test_loss", test_loss, prog_bar=True)
        return test_loss

    def configure_optimizers(self):
        """Adam over the U-Net parameters with the learning rate from ``hparams``."""
        return torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)




    Define the LightningDataModule: for custom data loading, create a class that inherits from `L.LightningDataModule` (Lightning is imported as `L`). It contains the dataset construction and DataLoader logic.

In [4]:

from dataset import MyDataset

    
class MyDataModule(L.LightningDataModule):
    """Train/val/test ``MyDataset`` splits over the ``256modelruns`` HDF5 files.

    Sixteen files ``256modelruns/Pe1_K1_{i}_{j}.hdf5`` (i, j in 0..3) are split in
    a fixed order:
    - the first ``n_training_samples`` files are used for training;
    - the next ``n_valid_samples`` files are used for validation;
    - the last ``n_test_samples`` files are used for testing.

    ``example_dataset`` (the first 3 files, training augmentation) is exposed for
    plotting.

    Args:
        augmentation: indexable of three transforms ``(train, val, test)``, or
            ``None`` to pass ``None`` to every dataset.
        preprocessing: forwarded unchanged to every ``MyDataset``.
        batch_size: batch size for all dataloaders.
        num_train_samples: samples drawn (with replacement) per training epoch.
        num_workers: DataLoader worker processes. Workers are kept persistent only
            when this is > 0, because DataLoader rejects persistent_workers=True
            with num_workers=0.
    """

    def __init__(self, augmentation=None, preprocessing=None, batch_size=32,
                 num_train_samples=10000, num_workers=2):
        super().__init__()
        self.batch_size = batch_size
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.num_train_samples = num_train_samples
        self.num_workers = num_workers
        self.n_training_samples = 10
        self.n_valid_samples = 2
        self.n_test_samples = 4

    def setup(self, stage=None):
        """Build the example/train/val/test datasets (for every ``stage``)."""
        # 4 x 4 parameter grid -> 16 run files in a deterministic order.
        file_list = [
            f'256modelruns/Pe1_K1_{idx1}_{idx2}.hdf5'
            for idx1, idx2 in product(range(4), repeat=2)
        ]

        n_train = self.n_training_samples
        n_val = self.n_valid_samples
        n_test = self.n_test_samples
        if n_train + n_val + n_test > len(file_list):
            raise ValueError(
                f"train/val/test splits ({n_train}+{n_val}+{n_test}) exceed the "
                f"{len(file_list)} available files and would overlap"
            )

        # The default augmentation=None used to crash on None[0].
        # NOTE(review): assumes MyDataset treats a None augmentation as
        # "no augmentation" -- confirm in dataset.py.
        augs = self.augmentation if self.augmentation is not None else (None, None, None)

        self.example_dataset = MyDataset(file_list[:3], augs[0], self.preprocessing)
        self.train_dataset = MyDataset(file_list[:n_train], augs[0], self.preprocessing)
        self.val_dataset = MyDataset(file_list[n_train:n_train + n_val], augs[1], self.preprocessing)
        # Explicit start index: file_list[-0:] would return the whole list when n_test == 0.
        self.test_dataset = MyDataset(file_list[len(file_list) - n_test:], augs[2], self.preprocessing)

    def train_dataloader(self):
        """Training loader: ``num_train_samples`` random draws (with replacement) per epoch."""
        train_sampler = RandomSampler(self.train_dataset, replacement=True,
                                      num_samples=self.num_train_samples)
        return DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                          num_workers=self.num_workers, drop_last=True,
                          persistent_workers=self.num_workers > 0)

    def val_dataloader(self):
        """Validation loader over every validation sample, in order."""
        # drop_last=False: dropping the final partial batch would silently exclude
        # samples from val_loss (and yield no batches at all if len < batch_size).
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False, drop_last=False,
                          persistent_workers=self.num_workers > 0)

    def test_dataloader(self):
        """Test loader over every test sample, in order."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          persistent_workers=self.num_workers > 0)


    Training loop: create an `L.Trainer` and use it to fit and then test the LightningModule.

* From the command line, run `tensorboard --logdir=lightning_logs/`.

If you're using a notebook environment such as Colab, Kaggle, or Jupyter, launch TensorBoard with these commands:

%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [5]:


def get_training_augmentation(height=256, width=256):
    """Return the training-time transform: resize to ``height`` x ``width``.

    The defaults keep the original 256 x 256 size. Both sides should stay
    divisible by 32 for the U-Net encoder.

    Random flips are currently disabled. NOTE(review): if re-enabled, mirroring
    the inputs probably also requires negating the matching velocity channel
    (Ux for a horizontal flip, Uy for a vertical one) -- confirm before use.
    """
    train_transform = [
        albu.Resize(height, width),
        # albu.HorizontalFlip(p=0.5),
        # albu.VerticalFlip(p=0.5),
    ]
    return albu.Compose(train_transform)

def get_validation_augmentation(height=256, width=256):
    """Return the val/test transform: resize to ``height`` x ``width``.

    The defaults (256 x 256) make the image shape divisible by 32, as the U-Net
    encoder requires.
    """
    test_transform = [
        albu.Resize(height, width),
    ]
    return albu.Compose(test_transform)

In [6]:
from lightning.pytorch.callbacks import Callback

class VisualizationCallback(Callback):
    """Log example inputs and targets to TensorBoard when fitting starts.

    For samples 0, 10, 20 and 30 of ``datamodule.example_dataset``, the four input
    channels (concentration, eps, Ux, Uy) and the target (dissolution) are
    logged as one 5-image batch per sample.
    """

    def setup(self, trainer, pl_module, stage):
        """Lightning hook: plot the examples once, at the start of ``fit``.

        The previous ``on_setup`` is not a Lightning hook name, so the callback
        never ran. Lightning calls ``LightningDataModule.setup`` before the
        callbacks' ``setup``, so ``example_dataset`` already exists here.
        """
        if stage == "fit" and trainer.datamodule is not None:
            self.on_setup(trainer, pl_module, trainer.datamodule)

    def on_setup(self, trainer, pl_module, datamodule):
        """Log samples 0, 10, 20, 30 of ``datamodule.example_dataset``."""
        for idx_ in range(4):
            current_timestep = 10 * idx_
            print(f'plotting for time step: {current_timestep}')
            image, mask = datamodule.example_dataset[current_timestep]  # get some sample
            self.visualize_tensorboard(trainer, pl_module, image, mask, current_timestep)

    def visualize_tensorboard(self, trainer, pl_module, image, mask, current_timestep):
        """Add the 4 input channels plus the target as a ``(5, 1, H, W)`` image batch.

        ``image`` is indexed channel-first (``image[c]``), as in the original.
        Each panel is min-max scaled to [0, 1] because ``add_images`` expects
        float images in that range; absolute scales are therefore not preserved.
        Nothing is logged if the active logger has no ``add_images``
        (e.g. a CSV logger or no logger).
        """
        experiment = getattr(trainer.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_images"):
            return

        # Inputs may arrive as numpy arrays (e.g. straight from albumentations).
        image = torch.as_tensor(image)
        mask = torch.as_tensor(mask)
        panels = [
            image[0, :, :].squeeze(),  # concentration
            image[1, :, :].squeeze(),  # eps
            image[2, :, :].squeeze(),  # Ux
            image[3, :, :].squeeze(),  # Uy
            mask.squeeze(),            # dissolution
        ]

        frames = []
        for panel in panels:
            panel = panel.detach().float().cpu()
            lo, hi = panel.min(), panel.max()
            frames.append((panel - lo) / (hi - lo) if hi > lo else torch.zeros_like(panel))

        # add_images needs a 4-D batch; the stacked panels are (5, H, W) -> (5, 1, H, W).
        grid = torch.stack(frames, dim=0).unsqueeze(1)
        experiment.add_images(
            f'Visualization/Time_Step_{current_timestep}',
            grid,
            global_step=trainer.global_step,
            dataformats="NCHW",
        )


In [7]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks import DeviceStatsMonitor
from lightning.pytorch.profilers import AdvancedProfiler

from pre_processing import get_preprocessing

# Transforms for the (train, val, test) splits, in that order.
augumentations = [get_training_augmentation(), get_validation_augmentation(), get_validation_augmentation()]

model = MyLightningModel()

data_module = MyDataModule(augmentation=augumentations, preprocessing=get_preprocessing())

# Notes:
# - profiler="advanced" runs cProfile on every hook, which adds noticeable
#   overhead and very long reports; use "simple" (or drop it) for full runs.
# - Mixed precision: https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html
# - fast_dev_run=True gives a quick smoke test; num_sanity_val_steps=2 checks
#   validation before the long training loop.
trainer = L.Trainer(
    max_epochs=20,
    profiler="advanced",
    callbacks=[
        VisualizationCallback(),
        DeviceStatsMonitor(),
        ModelSummary(max_depth=-1),
        TQDMProgressBar(refresh_rate=10),
        # val_loss is an MSE, so lower is better. mode="max" would stop training
        # whenever the loss failed to *increase* for 3 epochs.
        EarlyStopping(monitor="val_loss", min_delta=0.00, patience=3, verbose=False, mode="min"),
    ],
)

trainer.fit(model, data_module)

# Evaluate the model (final weights after fitting) on the held-out test split.
trainer.test(model, data_module)


Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024-03-14 00:43:43.829408: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

    | Name                                        | Type             | Params
-----------------------------------------------------------------------------------
0   | model                                       | Unet             | 24.4 M
1   | model.encoder                               | ResNetEncoder    | 21.3 M
2   | model.encoder.conv1                    

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

#### Data plots

* TODO: consider moving this plotting into a Lightning callback.

In [None]:

def visualize(**images):
    """Display every keyword-argument image side by side in one row.

    Each keyword becomes its panel's title, with underscores replaced by spaces
    and title-cased (``time_step`` -> ``Time Step``). Tick marks are hidden.
    """
    panel_count = len(images)
    fig = plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        ax = fig.add_subplot(1, panel_count, position)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(label.replace('_', ' ').title())
        ax.imshow(picture)
    plt.show()

# example_dataset is created by MyDataModule.setup(). trainer.fit() normally
# calls it, but set it up here if needed so this cell also works on its own
# (e.g. after a kernel restart without re-running training).
if not hasattr(data_module, 'example_dataset'):
    data_module.setup()

for idx_ in range(4):
    current_timestep = 10 * idx_
    print(f'plotting for time step: {current_timestep}')
    image, mask = data_module.example_dataset[current_timestep]  # one (input, target) sample
    visualize(
        concentration=image[0, :, :].squeeze(),
        eps=image[1, :, :].squeeze(),
        Ux=image[2, :, :].squeeze(),
        Uy=image[3, :, :].squeeze(),
        dissolution=mask.squeeze(),
    )

# TODO: plot model predictions against the targets as well.

In [None]:
# def matshow_error(pred, truth, figsize=(40, 18), scale=False, title=None, filename=None):
#     fig, ax = plt.subplots(1, 3, figsize=figsize)
    
#     v_max = max(truth.max(), pred.max())
#     v_min = min(truth.min(), pred.min())

#     if scale:
#         im = ax[0].matshow(pred, vmin=0, vmax=1, cmap=plt.get_cmap('Reds'))# 'inferno_r'))
#     else:
#         im = ax[0].matshow(pred, vmin=v_min, vmax=v_max, cmap=plt.get_cmap('Reds'))# 'inferno_r'))
#     # im.set_clim(0.0, 0.3)
#     ax[0].set_title(f'{title} prediction')
#     divider = make_axes_locatable(ax[0])
#     cax = divider.append_axes("right", size="5%", pad=0.05)
#     cbar = plt.colorbar(im, cax=cax)

#     if scale:
#         im = ax[1].matshow(truth, vmin=0, vmax=1, cmap=plt.get_cmap('Reds'))# 'inferno_r'))
#     else:
#         im = ax[1].matshow(truth, vmin=v_min, vmax=v_max, cmap=plt.get_cmap('Reds'))# 'inferno_r'))
#     # im.set_clim(0.0, 0.3)
#     ax[1].set_title(f'{title} reference')
#     divider = make_axes_locatable(ax[1])
#     cax = divider.append_axes("right", size="5%", pad=0.05)
#     plt.colorbar(im, cax=cax)

#     # error = np.abs(pred-truth)
#     error = pred-truth

#     im = ax[2].matshow(error, cmap=plt.get_cmap('seismic')) #.get_cmap('RdGy'))
#     max_abs_error = np.max(np.abs(error))
#     # Set the color limits dynamically centered around zero
#     clim = (-max_abs_error, max_abs_error)
#     im.set_clim(clim)

#     ax[2].set_title(f'{title} error')
#     divider = make_axes_locatable(ax[2])
#     cax = divider.append_axes("right", size="5%", pad=0.05)
#     plt.colorbar(im, cax=cax)
#     plt.tight_layout()
#     if filename is not None:
#         plt.savefig(filename)
#     plt.show()

# for sample_idx in range(1): #12):
#     for time_step in [0, 1, 3, 7, 10, 20, 40, 60, 90, 99]:
#         preds = preds_list_train[sample_idx*100+time_step, :, :]
#         masks = masks_list_train[sample_idx*100+time_step, :, :]
#         # matshow2(scaling_func(preds), scaling_func(masks), title=f'train sample: {sample_idx}, scaled prediction eps', filename='original_eps.pdf')
#         matshow_error(
#             preds,
#             masks, 
#             title=f'train sample: {sample_idx}, timestep: {time_step}, eps: ', 
#             filename=f'vis_results/training_eps_{sample_idx}_{time_step}.pdf',
#             figsize=(15, 7))
        
# for sample_idx in range(1): #4):
#     for time_step in [0, 1, 3, 7, 10, 20, 40, 60, 90, 99]:
#         preds = preds_list_val[sample_idx*100+time_step, :, :]
#         masks = masks_list_val[sample_idx*100+time_step, :, :]
#         # matshow2(scaling_func(preds), scaling_func(masks), title=f'validation sample: {sample_idx}, scaled prediction eps', filename='original_eps.pdf')
#         matshow_error(
#             preds,
#             masks,
#             title=f'validation sample: {sample_idx}, timestep: {time_step}, eps: ', 
#             filename=f'vis_results/validation_eps_{sample_idx}_{time_step}.pdf',
#             figsize=(15, 7))