In [1]:
import os

In [2]:
# Clone the private project repository (branch arash/training).
# SECURITY: never hardcode the access token here. The token that was previously embedded
# in this URL has been exposed and must be revoked on GitHub. Provide a fresh token via the
# GITHUB_TOKEN environment variable (e.g. Colab secrets) or the interactive prompt.
import getpass
github_token = os.environ.get('GITHUB_TOKEN') or getpass.getpass('GitHub access token: ')
# -f: do not error out when the directory does not exist yet (fresh runtime).
!rm -rf ssl-medical-imaging
!git clone -b arash/training https://{github_token}@github.com/naga-karthik/ssl-medical-imaging
# git stores the clone URL (including the token) in .git/config; replace it with a token-free URL.
!git -C ssl-medical-imaging remote set-url origin https://github.com/naga-karthik/ssl-medical-imaging

Cloning into 'ssl-medical-imaging'...
remote: Enumerating objects: 289, done.[K
remote: Counting objects: 100% (289/289), done.[K
remote: Compressing objects: 100% (201/201), done.[K
remote: Total 289 (delta 148), reused 177 (delta 80), pack-reused 0[K
Receiving objects: 100% (289/289), 3.34 MiB | 5.90 MiB/s, done.
Resolving deltas: 100% (148/148), done.


In [3]:
# Install the project's requirements. %pip (unlike !pip) always installs into the
# interpreter that runs this notebook's kernel.
%pip install -q -r /content/ssl-medical-imaging/requirements.txt -f https://download.pytorch.org/whl/torch_stable.html

In [4]:
if not os.path.isdir('/content/ACDC'):
  # Download the ACDC archive from Google Drive only when it has not been extracted yet.
  # The full Drive URL is used because the `--id` flag is deprecated in newer gdown releases.
  # Absolute paths keep the download/extraction in /content (where the guard above looks)
  # regardless of the current working directory; -o avoids an interactive overwrite prompt.
  !gdown "https://drive.google.com/uc?id=1-DAdhFAG-N57YW_UZEsN2Yz2PugvgvxP" -O /content/ACDC.zip
  !unzip -q -o /content/ACDC.zip -d /content

In [5]:
if not os.path.isdir('/content/Task05_Prostate'):
  # Download the Task05_Prostate archive from Google Drive only when it has not been extracted yet.
  # The full Drive URL is used because the `--id` flag is deprecated in newer gdown releases.
  # Absolute paths keep the download/extraction in /content (where the guard above looks)
  # regardless of the current working directory; -o avoids an interactive overwrite prompt.
  !gdown "https://drive.google.com/uc?id=1F6zonQztBaNg8SX0rdhWdUDnH03tmuTY" -O /content/Task05_Prostate.zip
  !unzip -q -o /content/Task05_Prostate.zip -d /content

In [6]:
%%writefile /content/ssl-medical-imaging/supervised_train.py
# utility packages
import os
import time
import argparse

import numpy as np
import matplotlib.pyplot as plt
timestamp = time.time()

# machine learning packages
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# dataloaders and segmentation models
from seg_models import SegUnetFullModel, SegUnetEncoder_and_ProjectorG1, SegUnetDecoder
from Dataloader.init_data import acdc, md_prostate
from Dataloader.dataloader import DataloaderRandom
from Dataloader.experiments_paper import data_init_acdc, data_init_prostate_md
from loss import Loss

# Default data locations (paths on one specific cluster account). Override them with
# --img_path / --seg_path on any other machine, e.g. /content/ACDC on Colab.
img_path = "/home/GRAMES.POLYMTL.CA/u114716/ssl_project/datasets/ACDC"
seg_path = "/home/GRAMES.POLYMTL.CA/u114716/ssl_project/datasets/ACDC"

# Dataset descriptions selectable with --dataset, keyed by their command-line name.
_DATASETS = {'acdc': acdc, 'md_prostate': md_prostate}


def _dataset_arg(name):
    """argparse ``type=`` hook mapping a --dataset name to its imported dataset description.

    argparse always passes command-line values as strings, so without this hook
    ``-data md_prostate`` would store the literal string 'md_prostate' instead of the
    ``md_prostate`` object. The non-string default (``acdc``) is not routed through
    ``type=`` by argparse, so the default is unaffected.

    Raises:
        argparse.ArgumentTypeError: for an unknown name (argparse turns this into a usage error).
    """
    try:
        return _DATASETS[name]
    except KeyError:
        raise argparse.ArgumentTypeError(
            "unknown dataset %r (choose from: %s)" % (name, ", ".join(sorted(_DATASETS)))) from None


parser = argparse.ArgumentParser(description="Supervised Strategy")

# all the arguments for the dataset, model, and training hyperparameters
parser.add_argument('--exp_name', default='ce loss test1', type=str, help='Name of the experiment/run')
# dataset
parser.add_argument('-data', '--dataset', default=acdc, type=_dataset_arg,
                    help='Dataset to use: acdc or md_prostate')
parser.add_argument('-nti', '--num_train_imgs', default='tr52', type=str, help='Number of training images, options tr1, tr8 or tr52')
parser.add_argument('-cti', '--comb_train_imgs', default='c1', type=str, help='Combintation of Train imgs., options c1, c2, cr3, cr4, cr5')
parser.add_argument('--img_path', default=img_path, type=str, help='Absolute path of the training data')
parser.add_argument('--seg_path', default=seg_path, type=str, help='Same as path of training data')
# model
parser.add_argument('-in_ch', '--in_channels', default=1, type=int, help='Number of input channels')
# type=int: without it, values given on the command line would be kept as strings.
parser.add_argument('-num_flt', '--num_filters_list', nargs='+', type=int, default=[1, 16, 32, 64, 128, 256], help='List containing no. of filters for Conv Layers')
parser.add_argument('-num_fc', '--fc_units_list', nargs='+', type=int, default=[3200, 1024], help='List containing no. of units in FC layers')
parser.add_argument('-g1_dim', '--g1_out_dim', default=128, type=int, help='Output dimension for the projector head')
parser.add_argument('-nc', '--num_classes', default=4, type=int, help='Number of classes to segment')
# optimization
parser.add_argument('-p', '--precision', default=32, type=int, help='Precision for training')
parser.add_argument('-ep', '--epochs', default=100, type=int, help='Number of epochs to train')
parser.add_argument('-bs', '--batch_size', default=256, type=int, help='Batch size')
parser.add_argument('-nw', '--num_workers', default=4, type=int, help='Number of worker processes')
parser.add_argument('-gpus', '--num_gpus', default=1, type=int, help="Number of GPUs to use")
parser.add_argument('-lr', '--learning_rate', default=5e-4, type=float, help="Learning rate to use")
parser.add_argument('-wd', '--weight_decay', default=1e-3, type=float, help='Default weight decay')

cfg = parser.parse_args()

class SegModel(pl.LightningModule):
    """Fully supervised segmentation model trained with pixel-wise cross-entropy.

    Wraps ``SegUnetFullModel``, builds the train/validation/test datasets from ``cfg``
    (the parsed command-line arguments) and implements the LightningModule hooks that
    ``pl.Trainer`` calls (steps, optimizer/scheduler, dataloaders).
    """

    def __init__(self, cfg):
        super(SegModel, self).__init__()
        self.cfg = cfg
        self.net = SegUnetFullModel(
            in_channels=self.cfg.in_channels,
            num_filters_list=self.cfg.num_filters_list,
            fc_units=self.cfg.fc_units_list,
            g1_out_dim=self.cfg.g1_out_dim,
            num_classes=self.cfg.num_classes
        )

        # Patient-id splits selected by --num_train_imgs / --comb_train_imgs.
        # NOTE(review): these always come from data_init_acdc, even when --dataset is
        # md_prostate (data_init_prostate_md is imported but unused) -- TODO confirm intent.
        self.train_ids_acdc = data_init_acdc.train_data(self.cfg.num_train_imgs, self.cfg.comb_train_imgs)
        self.val_ids_acdc = data_init_acdc.val_data(self.cfg.num_train_imgs, self.cfg.comb_train_imgs)
        self.test_ids_acdc = data_init_acdc.test_data()

        self.train_dataset = DataloaderRandom(self.cfg.dataset, self.train_ids_acdc, self.cfg.img_path,
                                              preprocessed_data=True, seg_path=self.cfg.seg_path)
        self.valid_dataset = DataloaderRandom(self.cfg.dataset, self.val_ids_acdc, self.cfg.img_path,
                                              preprocessed_data=True, seg_path=self.cfg.seg_path)
        self.test_dataset = DataloaderRandom(self.cfg.dataset, self.test_ids_acdc, self.cfg.img_path,
                                             preprocessed_data=True, seg_path=self.cfg.seg_path)

        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the network and return its second output (the softmax)."""
        return self.net(x)[1]  # returns the softmax

    def compute_loss(self, batch):
        """Compute the cross-entropy loss for one ``(imgs, gts)`` batch.

        Returns:
            ``(loss, preds, imgs, gts)``: the loss, the network's second output, and the
            batch images / labels cast to float / long.
        """
        imgs, gts = batch
        imgs, gts = imgs.float(), gts.long()
        logits, preds = self.net(imgs)
        # Assumes gts is (B, 1, H, W): squeeze(1) drops the channel dim so the target is
        # the (B, H, W) class-index map CrossEntropyLoss expects -- TODO confirm shape.
        loss = self.ce_loss(logits, gts.squeeze(1))
        return loss, preds, imgs, gts

    def _log_figure(self, key, preds, imgs, gts):
        """Log a prediction visualization to wandb under ``key`` (global-zero process only)."""
        # Under DDP only the global-zero process is expected to own an active wandb run;
        # wandb.log on the other ranks would fail, and visualize() would be wasted work there.
        if self.global_rank != 0:
            return
        fig = visualize(preds, imgs, gts)
        wandb.log({key: fig})
        # A new figure is created every epoch (and every test batch); close it once logged so
        # matplotlib does not keep hundreds of figures alive over a long run.
        plt.close(fig)

    def training_step(self, batch, batch_nb):
        loss, preds, imgs, gts = self.compute_loss(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True)

        if batch_nb == 0:  # once per epoch
            self._log_figure("Training Output Visualizations", preds, imgs, gts)
        return loss

    def validation_step(self, batch, batch_nb):
        loss, preds, imgs, gts = self.compute_loss(batch)
        self.log('valid_loss', loss, on_step=False, on_epoch=True)

        if batch_nb == 0:  # once per epoch
            self._log_figure("Validation Output Visualizations", preds, imgs, gts)

    def test_step(self, batch, batch_nb):
        loss, preds, imgs, gts = self.compute_loss(batch)
        self.log('test_loss', loss, on_step=False, on_epoch=True)

        # qualitative results on wandb, one figure per test batch
        self._log_figure("Test Output Visualizations", preds, imgs, gts)

    def configure_optimizers(self):
        """AdamW with cosine annealing that restarts every 40 scheduler steps (floor lr 1e-5)."""
        optimizer = optim.AdamW(params=self.parameters(), lr=self.cfg.learning_rate, weight_decay=self.cfg.weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=40, eta_min=1e-5)

        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Shuffled training loader; drops the last partial batch only if a full batch exists."""
        # With drop_last=True and fewer training samples than batch_size the loader would
        # yield zero batches and nothing would be trained.
        drop_last = len(self.train_dataset) >= self.cfg.batch_size
        return DataLoader(self.train_dataset, batch_size=self.cfg.batch_size,
                          shuffle=True, drop_last=drop_last, num_workers=self.cfg.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.cfg.batch_size,
                          shuffle=False, drop_last=False, num_workers=self.cfg.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.cfg.batch_size,
                          shuffle=False, drop_last=False, num_workers=self.cfg.num_workers)
        
def visualize(preds, imgs, gts, num_imgs=10):
    """Build a 3-row figure: input image, colored ground truth, colored prediction.

    Args:
        preds: per-class scores with the class axis at dim 1 (assumed (B, C, H, W)
            softmax output -- TODO confirm). Not modified.
        imgs: input images; dim 1 is squeezed away before plotting (assumed (B, 1, H, W)).
        gts: integer label maps; dim 1 is squeezed away (assumed (B, 1, H, W)).
        num_imgs: number of columns; each shows a sample drawn at random (with
            replacement) from the batch.

    Returns:
        The matplotlib Figure. The caller is responsible for logging/showing and closing it.

    Raises:
        ValueError: if there are more classes than the 8 available colors.
    """
    # One RGB color per class index; class 0 is drawn black.
    main_colors = torch.tensor([
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 1, 0],
        [0, 1, 1],
        [1, 0, 1],
        [1, 1, 1],
    ]).float()

    # getting ready for post processing (CPU, no autograd)
    imgs = imgs.detach().cpu().squeeze(dim=1).numpy()
    gts = gts.detach().cpu().squeeze(dim=1)
    preds = preds.detach().cpu()

    num_classes = preds.shape[1]
    if num_classes > main_colors.shape[0]:
        raise ValueError("visualize supports at most %d classes, got %d"
                         % (main_colors.shape[0], num_classes))
    colors = main_colors[:num_classes]

    # coloring the predictions: keep only the winning class score at each pixel.
    # Done out of place -- when preds already lives on the CPU, detach().cpu() shares
    # storage with the caller's tensor, so an in-place write would corrupt it.
    max_scores = preds.max(dim=1, keepdim=True)[0]
    preds_top = preds.masked_fill(preds < max_scores, 0)
    preds_colored = torch.tensordot(preds_top, colors, dims=[[1], [0]]).numpy()
    # coloring the ground truth masks
    gts_onehot = F.one_hot(gts, num_classes=num_classes).permute(0, 3, 1, 2)
    gts_colored = torch.tensordot(gts_onehot.float(), colors, dims=[[1], [0]]).numpy()

    # squeeze=False keeps axs 2-D even when num_imgs == 1.
    fig, axs = plt.subplots(3, num_imgs, figsize=(9, 3), squeeze=False)
    fig.suptitle('Original --> Ground Truth --> Prediction')
    for col in range(num_imgs):
        img_num = np.random.randint(0, len(imgs))
        axs[0, col].imshow(imgs[img_num], cmap='gray'); axs[0, col].axis('off')
        axs[1, col].imshow(gts_colored[img_num]); axs[1, col].axis('off')
        axs[2, col].imshow(preds_colored[img_num]); axs[2, col].axis('off')
    return fig

def main(cfg):
    """Train SegModel with the settings in ``cfg``, then test the best validation checkpoint.

    Logs to Weights & Biases (project 'supervised-train-arash'), keeps the single
    checkpoint with the lowest ``valid_loss``, and records the learning rate every epoch.
    """
    # experiment tracker (you need to sign in with your account)
    wandb_logger = pl.loggers.WandbLogger(
                            name='%s <- %d'%(cfg.exp_name, timestamp),
                            group= '%s'%(cfg.exp_name),
                            log_model=True, # save best model using checkpoint callback
                            project='supervised-train-arash',
                            entity='ssl-medical-imaging',
                            config=cfg,
    )

    # to save the best model on validation
    checkpoint = pl.callbacks.ModelCheckpoint(
        filename="best_model"+str(timestamp),
        monitor="valid_loss",
        save_top_k=1,
        mode="min",
        save_last=False,
        save_weights_only=True,
    )
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')

    trainer = pl.Trainer(
        devices=cfg.num_gpus,
        accelerator="gpu", strategy="ddp",
        logger=wandb_logger,
        callbacks=[checkpoint, lr_monitor],
        max_epochs=cfg.epochs,
        precision=cfg.precision,
    )

    model = SegModel(cfg)

    trainer.fit(model)
    print("------- Training Done! -------")

    print("------- Testing Begins! -------")
    # Evaluate the lowest-valid_loss checkpoint kept by `checkpoint` above, not the weights
    # left after the final epoch (which are not necessarily the best ones).
    trainer.test(model, ckpt_path="best")


if __name__ == '__main__':
    main(cfg)

Overwriting /content/ssl-medical-imaging/supervised_train.py


In [None]:
# Launch supervised training for 500 epochs on the ACDC folder extracted above (used as both
# the image and the segmentation path); all other options keep the script's argparse defaults.
!python /content/ssl-medical-imaging/supervised_train.py --img_path /content/ACDC --seg_path /content/ACDC -ep 500

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
{'name': 'ACDC', 'dimension': (192, 192), 'resolution': (1.367, 1.367), 'num_class': 4} final shape (487, 2, 192, 192)
{'name': 'ACDC', 'dimension': (192, 192), 'resolution': (1.367, 1.367), 'num_class': 4} final shape (90, 2, 192, 192)
{'name': 'ACDC', 'dimension': (192, 192), 'resolution': (1.367, 1.367), 'num_class': 4} final shape (182, 2, 192, 192)
initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33marash0ash[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Tracking run with