In [None]:
# Install required libs
!sudo pip install -U segmentation-models-pytorch albumentations scikit-image monai livelossplot --user --upgrade

In [None]:
# !sudo pip install scikit-image

# Loading data

In [None]:
pwd

In [None]:
!git clone https://github.com/babbu3682/Multi-task-learning-tutorials.git 'test'

In [None]:
cd test

In [None]:
!cat SSIM_cls.tar.gz* | tar -zxvpf -
!cat SSIM_seg.tar.gz* | tar -zxvpf -

# Loading importing

In [1]:
!nvidia-smi

import os
import argparse
import datetime
import time
import cv2
import random
import torch
import numpy as np
import glob
import matplotlib.pyplot as plt


os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Fri Sep 16 19:28:28 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:18:00.0 Off |                  Off |
| 71%   73C    P8    28W / 300W |      3MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    Off  | 00000000:3B:00.0 Off |                  Off |
|100%   93C    P2   249W / 300W |  45210MiB / 48685MiB |     99%      Default |
|       

In [2]:
# fix random seeds for reproducibility
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) 
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

P: 75 (22.73%) N: 255 (77.27%) Total: 330
P: 14 (20.00%) N: 56  (80.00%) Total: 70
P: 21 (21.00%) N: 79  (79.00%) Total: 100

# Dataset

In [3]:
pwd

'/workspace/sunggu/7.KOHI/Multi_task_learning_tutorials'

In [4]:
DATA_DIR = './SSIM_cls/'

x_train_list = glob.glob(os.path.join(DATA_DIR, 'train/*'))
y_train_list = glob.glob(os.path.join(DATA_DIR, 'trainannot/*'))

x_valid_list = glob.glob(os.path.join(DATA_DIR, 'val/*'))
y_valid_list = glob.glob(os.path.join(DATA_DIR, 'valannot/*'))

x_test_list  = glob.glob(os.path.join(DATA_DIR, 'test/*'))
y_test_list  = glob.glob(os.path.join(DATA_DIR, 'testannot/*'))

In [5]:
x_train_list[0]

'./SSIM_cls/train/1.2.276.0.7230010.3.1.4.8323329.2775.1517875174.639059.png'

In [6]:
x_valid_list[0]

'./SSIM_cls/val/1.2.276.0.7230010.3.1.4.8323329.1764.1517875169.306622.png'

In [7]:
from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader


# Dataset Class
class Dataset(BaseDataset):    
    def __init__(self, images_list, labels_list, transform):
        self.images_list  = images_list
        self.labels_list  = labels_list
        self.transform    = transform
    
    def __getitem__(self, i):
        # read data
        image = cv2.imread(self.images_list[i], cv2.IMREAD_GRAYSCALE)
        label = np.load(self.labels_list[i])
        path  = self.images_list[i]

        # apply transform
        image = self.transform(image=image)['image']

        return image, label, path
        
    def __len__(self):
        return len(self.images_list)

In [8]:
from albumentations.pytorch.transforms import ToTensorV2
import albumentations as albu


def minmax_normalize(image, **kwargs):
    if len(np.unique(image)) != 1:  # Sometimes it cause the nan inputs...
        image = image.astype('float32')
        image -= image.min()
        image /= image.max() 
    return image


train_transform = albu.Compose([
    albu.HorizontalFlip(p=0.5),
    albu.ShiftScaleRotate(scale_limit=0.10, shift_limit=0.10, rotate_limit=15, p=0.5),
    albu.GaussNoise(p=0.2),
    albu.OneOf(
        [
            albu.CLAHE(p=1),
            albu.RandomBrightnessContrast(p=1),
            albu.RandomGamma(p=1),
        ],
        p=0.3,
    ),
    albu.OneOf(
        [
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.3,
    ),
    albu.Lambda(image=minmax_normalize, always_apply=True),
    ToTensorV2(),    
])

valid_transform = albu.Compose([        
    albu.Lambda(image=minmax_normalize, always_apply=True),
    ToTensorV2(),
])


In [9]:
dataset_train = Dataset(images_list=x_train_list, labels_list=y_train_list, transform=train_transform)
dataset_valid = Dataset(images_list=x_valid_list, labels_list=y_valid_list, transform=valid_transform)

data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=10, num_workers=4, shuffle=True, pin_memory=True, drop_last=True)
data_loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=1,  num_workers=4, shuffle=True, pin_memory=True, drop_last=False)

##### Visulize

In [None]:
# # same image with different random transforms

# batch = next(iter(train_loader))
# x = batch['x'][0]
# y_seg = batch['y_seg'][0]
# y_cls = batch['y_cls'][0]

# print(x.shape,y_seg.shape,data_loader_train.shape)
# print(torch.unique(y_seg),y_cls)
# visualize(image=x, mask=y_seg)

# Model

In [10]:
from arch.smart_net import *
from losses import MTL_Loss

# Model
model        = STL_1_Net(encoder_name='resnet18').to('cuda')
# Loss
criterion    = MTL_Loss(name='STL_CLS')

# Optimizer & LR Schedule   
optimizer    = torch.optim.AdamW(params=model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of Learnable Params:', n_parameters)   

Number of Learnable Params: 11170753


# Loop

In [11]:
import json
import math
import torch
import utils
from torch import nn
import torch.nn.functional as F

from metrics import *
from losses import soft_dice_score
from sklearn.metrics import roc_auc_score

print_freq = 10
output_dir = './checkpoints/stl_cls/'
device     = 'cuda'

# Whole LOOP
for epoch in range(0, 200):
    
    ################################################################################################
    # Training 
    ################################################################################################
    
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ", n=10)
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    
    for batch_data in metric_logger.log_every(data_loader_train, print_freq, header):
        
        input  = batch_data[0].to(device).float()
        cls_gt = batch_data[1].to(device).float()

        cls_pred = model(input)

        loss, loss_detail = criterion(cls_pred=cls_pred, cls_gt=cls_gt)
        
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss_value)
        if loss_detail is not None:
            metric_logger.update(**loss_detail)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    
    train_stats = {k: round(meter.global_avg, 7) for k, meter in metric_logger.meters.items()}
    print("Averaged train_stats: ", train_stats)
    
    ################################################################################################
    # Validation
    ################################################################################################
    
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ", n=1)
    header = 'Valid:'

    total_cls_pred  = torch.tensor([])
    total_cls_true  = torch.tensor([])
    
    for batch_data in metric_logger.log_every(data_loader_valid, print_freq, header):
        
        input  = batch_data[0].to(device).float()
        cls_gt = batch_data[1].to(device).float()

        cls_pred = model(input)

        loss, loss_detail = criterion(cls_pred=cls_pred, cls_gt=cls_gt)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))

        # LOSS
        metric_logger.update(loss=loss_value)        
        if loss_detail is not None:
            metric_logger.update(**loss_detail)

        # post-processing
        cls_pred = torch.sigmoid(cls_pred)

        total_cls_pred  = torch.cat([total_cls_pred, cls_pred.detach().cpu()])
        total_cls_true  = torch.cat([total_cls_true, cls_gt.detach().cpu()])


    # Metric CLS
    auc            = roc_auc_score(y_true=total_cls_true, y_score=total_cls_pred)
    tp, fp, fn, tn = get_stats(total_cls_pred.round().long(), total_cls_true.long(), mode="binary")        
    f1             = f1_score(tp, fp, fn, tn, reduction="macro")
    acc            = accuracy(tp, fp, fn, tn, reduction="macro")
    sen            = sensitivity(tp, fp, fn, tn, reduction="macro")
    spe            = specificity(tp, fp, fn, tn, reduction="macro")

    metric_logger.update(auc=auc, f1=f1, acc=acc, sen=sen, spe=spe)          
    
    valid_stats = {k: round(meter.global_avg, 7) for k, meter in metric_logger.meters.items()}
    print("Averaged valid_stats: ", valid_stats)
    
    ################################################################################################
    # Save & Log
    ################################################################################################
    
    checkpoint_paths = output_dir + '/epoch_' + str(epoch) + '_checkpoint.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }, checkpoint_paths)

    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                **{f'valid_{k}': v for k, v in valid_stats.items()},
                'epoch': epoch}

    if output_dir:
        with open(output_dir + "/log.txt", "a") as f:
            f.write(json.dumps(log_stats) + "\n")


Epoch: [0]  [ 0/80]  eta: 0:00:47  lr: 0.000100  loss: 0.7343 (0.7343)  CLS_Loss: 0.7343 (0.7343)  time: 0.5990  data: 0.4865  max mem: 1157
Epoch: [0]  [10/80]  eta: 0:00:06  lr: 0.000100  loss: 0.5632 (0.6660)  CLS_Loss: 0.5632 (0.6660)  time: 0.0885  data: 0.0444  max mem: 1288
Epoch: [0]  [20/80]  eta: 0:00:03  lr: 0.000100  loss: 0.5956 (0.7238)  CLS_Loss: 0.5956 (0.7238)  time: 0.0380  data: 0.0001  max mem: 1288
Epoch: [0]  [30/80]  eta: 0:00:02  lr: 0.000100  loss: 0.6437 (0.7090)  CLS_Loss: 0.6437 (0.7090)  time: 0.0386  data: 0.0001  max mem: 1288
Epoch: [0]  [40/80]  eta: 0:00:02  lr: 0.000100  loss: 0.5757 (0.6602)  CLS_Loss: 0.5757 (0.6602)  time: 0.0385  data: 0.0001  max mem: 1288
Epoch: [0]  [50/80]  eta: 0:00:01  lr: 0.000100  loss: 0.5803 (0.6484)  CLS_Loss: 0.5803 (0.6484)  time: 0.0385  data: 0.0001  max mem: 1288
Epoch: [0]  [60/80]  eta: 0:00:00  lr: 0.000100  loss: 0.5803 (0.6448)  CLS_Loss: 0.5803 (0.6448)  time: 0.0390  data: 0.0001  max mem: 1288
Epoch: [0]  [

Valid:  [ 99/100]  eta: 0:00:00  loss: 0.1578 (0.5418)  CLS_Loss: 0.1578 (0.5418)  time: 0.0057  data: 0.0001  max mem: 1387
Valid: Total time: 0:00:01 (0.0114 s / it)
Averaged valid_stats:  {'loss': 0.5417567, 'CLS_Loss': 0.5417567, 'auc': 0.6365132, 'f1': 0.0, 'acc': 0.76, 'sen': 0.0, 'spe': 1.0}
Epoch: [3]  [ 0/80]  eta: 0:00:34  lr: 0.000100  loss: 0.5023 (0.5023)  CLS_Loss: 0.5023 (0.5023)  time: 0.4252  data: 0.3909  max mem: 1387
Epoch: [3]  [10/80]  eta: 0:00:05  lr: 0.000100  loss: 0.4382 (0.4455)  CLS_Loss: 0.4382 (0.4455)  time: 0.0739  data: 0.0357  max mem: 1387
Epoch: [3]  [20/80]  eta: 0:00:03  lr: 0.000100  loss: 0.4382 (0.4984)  CLS_Loss: 0.4382 (0.4984)  time: 0.0387  data: 0.0001  max mem: 1387
Epoch: [3]  [30/80]  eta: 0:00:02  lr: 0.000100  loss: 0.4694 (0.5253)  CLS_Loss: 0.4694 (0.5253)  time: 0.0390  data: 0.0001  max mem: 1387
Epoch: [3]  [40/80]  eta: 0:00:01  lr: 0.000100  loss: 0.5447 (0.5340)  CLS_Loss: 0.5447 (0.5340)  time: 0.0391  data: 0.0001  max mem: 

Valid:  [ 60/100]  eta: 0:00:00  loss: 0.1936 (0.6383)  CLS_Loss: 0.1936 (0.6383)  time: 0.0107  data: 0.0001  max mem: 1387
Valid:  [ 70/100]  eta: 0:00:00  loss: 0.1548 (0.5717)  CLS_Loss: 0.1548 (0.5717)  time: 0.0109  data: 0.0001  max mem: 1387
Valid:  [ 80/100]  eta: 0:00:00  loss: 0.1567 (0.5961)  CLS_Loss: 0.1567 (0.5961)  time: 0.0108  data: 0.0001  max mem: 1387
Valid:  [ 90/100]  eta: 0:00:00  loss: 0.1812 (0.5939)  CLS_Loss: 0.1812 (0.5939)  time: 0.0106  data: 0.0001  max mem: 1387
Valid:  [ 99/100]  eta: 0:00:00  loss: 0.2101 (0.5780)  CLS_Loss: 0.2101 (0.5780)  time: 0.0105  data: 0.0001  max mem: 1387
Valid: Total time: 0:00:01 (0.0160 s / it)
Averaged valid_stats:  {'loss': 0.5779929, 'CLS_Loss': 0.5779929, 'auc': 0.6315789, 'f1': 0.0, 'acc': 0.76, 'sen': 0.0, 'spe': 1.0}
Epoch: [6]  [ 0/80]  eta: 0:00:35  lr: 0.000100  loss: 0.2102 (0.2102)  CLS_Loss: 0.2102 (0.2102)  time: 0.4437  data: 0.4075  max mem: 1387
Epoch: [6]  [10/80]  eta: 0:00:06  lr: 0.000100  loss: 0.45

Valid:  [ 20/100]  eta: 0:00:02  loss: 0.1561 (0.3735)  CLS_Loss: 0.1561 (0.3735)  time: 0.0103  data: 0.0001  max mem: 1387
Valid:  [ 30/100]  eta: 0:00:01  loss: 0.1896 (0.5235)  CLS_Loss: 0.1896 (0.5235)  time: 0.0104  data: 0.0001  max mem: 1387
Valid:  [ 40/100]  eta: 0:00:01  loss: 0.2253 (0.5484)  CLS_Loss: 0.2253 (0.5484)  time: 0.0105  data: 0.0001  max mem: 1387
Valid:  [ 50/100]  eta: 0:00:00  loss: 0.1962 (0.5682)  CLS_Loss: 0.1962 (0.5682)  time: 0.0106  data: 0.0001  max mem: 1387
Valid:  [ 60/100]  eta: 0:00:00  loss: 0.1933 (0.5081)  CLS_Loss: 0.1933 (0.5081)  time: 0.0108  data: 0.0001  max mem: 1387
Valid:  [ 70/100]  eta: 0:00:00  loss: 0.1429 (0.5244)  CLS_Loss: 0.1429 (0.5244)  time: 0.0096  data: 0.0001  max mem: 1387
Valid:  [ 80/100]  eta: 0:00:00  loss: 0.1588 (0.5511)  CLS_Loss: 0.1588 (0.5511)  time: 0.0075  data: 0.0001  max mem: 1387
Valid:  [ 90/100]  eta: 0:00:00  loss: 0.2551 (0.5305)  CLS_Loss: 0.2551 (0.5305)  time: 0.0062  data: 0.0001  max mem: 1387


Epoch: [11]  [70/80]  eta: 0:00:01  lr: 0.000100  loss: 0.5122 (0.5193)  CLS_Loss: 0.5122 (0.5193)  time: 0.0892  data: 0.0002  max mem: 1387
Epoch: [11]  [79/80]  eta: 0:00:00  lr: 0.000100  loss: 0.5017 (0.5264)  CLS_Loss: 0.5017 (0.5264)  time: 0.0858  data: 0.0001  max mem: 1387
Epoch: [11] Total time: 0:00:08 (0.1084 s / it)
Averaged train_stats:  {'lr': 0.0001, 'loss': 0.5263824, 'CLS_Loss': 0.5263824}
Valid:  [  0/100]  eta: 0:00:35  loss: 1.8008 (1.8008)  CLS_Loss: 1.8008 (1.8008)  time: 0.3570  data: 0.3385  max mem: 1387
Valid:  [ 10/100]  eta: 0:00:03  loss: 0.2339 (0.8611)  CLS_Loss: 0.2339 (0.8611)  time: 0.0436  data: 0.0309  max mem: 1387
Valid:  [ 20/100]  eta: 0:00:02  loss: 0.2585 (0.6458)  CLS_Loss: 0.2585 (0.6458)  time: 0.0122  data: 0.0002  max mem: 1387
Valid:  [ 30/100]  eta: 0:00:01  loss: 0.2702 (0.6043)  CLS_Loss: 0.2702 (0.6043)  time: 0.0120  data: 0.0002  max mem: 1387
Valid:  [ 40/100]  eta: 0:00:01  loss: 0.2702 (0.5411)  CLS_Loss: 0.2702 (0.5411)  time:

Epoch: [14]  [30/80]  eta: 0:00:06  lr: 0.000100  loss: 0.4497 (0.4893)  CLS_Loss: 0.4497 (0.4893)  time: 0.1048  data: 0.0002  max mem: 1387
Epoch: [14]  [40/80]  eta: 0:00:04  lr: 0.000100  loss: 0.4418 (0.4861)  CLS_Loss: 0.4418 (0.4861)  time: 0.0976  data: 0.0001  max mem: 1387
Epoch: [14]  [50/80]  eta: 0:00:03  lr: 0.000100  loss: 0.4255 (0.4897)  CLS_Loss: 0.4255 (0.4897)  time: 0.0679  data: 0.0001  max mem: 1387
Epoch: [14]  [60/80]  eta: 0:00:01  lr: 0.000100  loss: 0.4352 (0.4941)  CLS_Loss: 0.4352 (0.4941)  time: 0.0472  data: 0.0001  max mem: 1387
Epoch: [14]  [70/80]  eta: 0:00:00  lr: 0.000100  loss: 0.5323 (0.5050)  CLS_Loss: 0.5323 (0.5050)  time: 0.0528  data: 0.0001  max mem: 1387
Epoch: [14]  [79/80]  eta: 0:00:00  lr: 0.000100  loss: 0.5323 (0.5068)  CLS_Loss: 0.5323 (0.5068)  time: 0.0542  data: 0.0001  max mem: 1387
Epoch: [14] Total time: 0:00:07 (0.0879 s / it)
Averaged train_stats:  {'lr': 0.0001, 'loss': 0.5068171, 'CLS_Loss': 0.5068171}
Valid:  [  0/100]  e

KeyboardInterrupt: 

# Log check

In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt

def read_log(path):
    log_list = []
    lines = open(path, 'r').read().splitlines() 
    for i in range(len(lines)):
        exec('log_list.append('+lines[i] + ')')
    return  log_list

In [None]:
log_list = read_log(path = '/workspace/sunggu/6.Kakao/checkpoints/[Baseline]Face_Net_DAC_fold0_upsample/log.txt')
print(log_list[0].keys())
result_dict = {}

for key in log_list[0].keys():
    exec( "result_dict['"+str(key)+"']" + " = [ log_list[i]['"+str(key)+"'] for i in range(len(log_list)) ]")

for key in result_dict.keys():
    plt.plot(result_dict[key])
    plt.title(key)
    print("###########################################################")
    print("Argsort = ", np.argsort(result_dict[key])[:5])
    print("Value   = ", [result_dict[key][i] for i in np.argsort(result_dict[key])[:5]])
    plt.show()
    
    if key == 'valid_loss':
        print("Valid_Loss = ", np.argsort(result_dict[key])[:3])

In [None]:
# Log check

import glob
import numpy as np
import matplotlib.pyplot as plt

def read_log(path):
    log_list = []
    lines = open(path, 'r').read().splitlines() 
    for i in range(len(lines)):
        exec('log_list.append('+lines[i] + ')')
    return  log_list

51, 36, 35

log_list = read_log(path = '/workspace/sunggu/6.Kakao/checkpoints/[Baseline]Face_Net_DAC_fold0_upsample/log.txt')
print(log_list[0].keys())
result_dict = {}

for key in log_list[0].keys():
    exec( "result_dict['"+str(key)+"']" + " = [ log_list[i]['"+str(key)+"'] for i in range(len(log_list)) ]")

for key in result_dict.keys():
    plt.plot(result_dict[key])
    plt.title(key)
    print("###########################################################")
    print("Argsort = ", np.argsort(result_dict[key])[:5])
    print("Value   = ", [result_dict[key][i] for i in np.argsort(result_dict[key])[:5]])
    plt.show()
    
    if key == 'valid_loss':
        print("Valid_Loss = ", np.argsort(result_dict[key])[:3])

# TEST

In [None]:
# Resume

print("Loading... Resume")
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])        
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])        
args.start_epoch = checkpoint['epoch'] + 1  
try:
    log_path = os.path.dirname(args.resume)+'/log.txt'
    lines    = open(log_path,'r').readlines()
    val_loss_list = []
    for l in lines:
        exec('log_dict='+l.replace('NaN', '0'))
        val_loss_list.append(log_dict['valid_loss'])
    print("Epoch: ", np.argmin(val_loss_list), " Minimum Val Loss ==> ", np.min(val_loss_list))