# Loading importing

In [1]:
!nvidia-smi

import os
import argparse
import datetime
import time
import cv2
import random
import torch
import numpy as np
import glob
import matplotlib.pyplot as plt


os.environ['CUDA_VISIBLE_DEVICES'] = '2'

Fri Sep 16 19:29:21 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:18:00.0 Off |                  Off |
| 70%   89C    P2   246W / 300W |  11154MiB / 48685MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    Off  | 00000000:3B:00.0 Off |                  Off |
|100%   85C    P2   119W / 300W |  45210MiB / 48685MiB |      0%      Default |
|       

In [2]:
# fix random seeds for reproducibility
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) 
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

P: 75 (22.73%) N: 255 (77.27%) Total: 330
P: 14 (20.00%) N: 56  (80.00%) Total: 70
P: 21 (21.00%) N: 79  (79.00%) Total: 100

# Dataset

In [3]:
pwd

'/workspace/sunggu/7.KOHI/Multi_task_learning_tutorials'

In [4]:
DATA_DIR = './SSIM_seg/'

x_train_list = glob.glob(os.path.join(DATA_DIR, 'train/*'))
y_train_list = glob.glob(os.path.join(DATA_DIR, 'trainannot/*'))

x_valid_list = glob.glob(os.path.join(DATA_DIR, 'val/*'))
y_valid_list = glob.glob(os.path.join(DATA_DIR, 'valannot/*'))

x_test_list  = glob.glob(os.path.join(DATA_DIR, 'test/*'))
y_test_list  = glob.glob(os.path.join(DATA_DIR, 'testannot/*'))

In [5]:
x_train_list[0]

'./SSIM_seg/train/1.2.276.0.7230010.3.1.4.8323329.300.1517875162.258081.png'

In [6]:
x_valid_list[0]

'./SSIM_seg/val/1.2.276.0.7230010.3.1.4.8323329.1173.1517875166.626582.png'

In [7]:
from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader


# Dataset Class
class Dataset(BaseDataset):    
    def __init__(self, images_list, labels_list, transform):
        self.images_list  = images_list
        self.labels_list  = labels_list
        self.transform    = transform
    
    def __getitem__(self, i):
        # read data
        image = cv2.imread(self.images_list[i], cv2.IMREAD_GRAYSCALE)
        mask  = cv2.imread(self.labels_list[i], cv2.IMREAD_GRAYSCALE)
        path  = self.images_list[i]
        
        mask  = np.expand_dims(mask, axis=0)
            
        # apply transform
        sample = self.transform(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']        
        
        return image, mask, path
        
    def __len__(self):
        return len(self.images_list)

In [8]:
from albumentations.pytorch.transforms import ToTensorV2
import albumentations as albu


def minmax_normalize(image, **kwargs):
    if len(np.unique(image)) != 1:  # Sometimes it cause the nan inputs...
        image = image.astype('float32')
        image -= image.min()
        image /= image.max() 
    return image


train_transform = albu.Compose([
    albu.HorizontalFlip(p=0.5),
    albu.ShiftScaleRotate(scale_limit=0.10, shift_limit=0.10, rotate_limit=15, p=0.5),
    albu.GaussNoise(p=0.2),
    albu.OneOf(
        [
            albu.CLAHE(p=1),
            albu.RandomBrightnessContrast(p=1),
            albu.RandomGamma(p=1),
        ],
        p=0.3,
    ),
    albu.OneOf(
        [
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.3,
    ),
    albu.Lambda(image=minmax_normalize, always_apply=True),
    ToTensorV2(),    
])

valid_transform = albu.Compose([        
    albu.Lambda(image=minmax_normalize, always_apply=True),
    ToTensorV2(),
])


In [9]:
dataset_train = Dataset(images_list=x_train_list, labels_list=y_train_list, transform=train_transform)
dataset_valid = Dataset(images_list=x_valid_list, labels_list=y_valid_list, transform=valid_transform)

data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=10, num_workers=4, shuffle=True, pin_memory=True, drop_last=True)
data_loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=1,  num_workers=4, shuffle=True, pin_memory=True, drop_last=False)

In [None]:
# # same image with different random transforms

# batch = next(iter(train_loader))
# x = batch['x'][0]
# y_seg = batch['y_seg'][0]
# y_cls = batch['y_cls'][0]

# print(x.shape,y_seg.shape,y_cls.shape)
# print(torch.unique(y_seg),y_cls)
# visualize(image=x, mask=y_seg)

# Model

In [10]:
from arch.smart_net import *
from losses import MTL_Loss

# Model
model        = MTL_2_Net(encoder_name='resnet18').to('cuda')     

# Loss
criterion    = MTL_Loss(name='MTL_2')

# Optimizer & LR Schedule   
optimizer    = torch.optim.AdamW(params=model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of Learnable Params:', n_parameters)   

Number of Learnable Params: 18077613


# Loop

In [None]:
import json
import math
import torch
import utils
from torch import nn
import torch.nn.functional as F

from metrics import *
from losses import soft_dice_score
from sklearn.metrics import roc_auc_score

print_freq = 10
output_dir = './checkpoints/mtl_2/'
device     = 'cuda'

# Whole LOOP
for epoch in range(0, 200):
    
    ################################################################################################
    # Training 
    ################################################################################################
    
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ", n=10)
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    
    for batch_data in metric_logger.log_every(data_loader_train, print_freq, header):
        
        input   = batch_data[0].to(device).float()
        seg_gt  = batch_data[1].to(device).float()
        cls_gt  = seg_gt.flatten(1).bool().any(dim=1).float().unsqueeze(1)
        
        cls_pred, seg_pred, rec_pred = model(input)

        loss, loss_detail = criterion(cls_pred=cls_pred, seg_pred=seg_pred, rec_pred=rec_pred, cls_gt=cls_gt, seg_gt=seg_gt, rec_gt=input)
        
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss_value)
        if loss_detail is not None:
            metric_logger.update(**loss_detail)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    
    train_stats = {k: round(meter.global_avg, 7) for k, meter in metric_logger.meters.items()}
    print("Averaged train_stats: ", train_stats)
    
    ################################################################################################
    # Validation
    ################################################################################################
    
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ", n=1)
    header = 'Valid:'

    total_cls_pred  = torch.tensor([])
    total_cls_true  = torch.tensor([])
    
    for batch_data in metric_logger.log_every(data_loader_valid, print_freq, header):
        
        input   = batch_data[0].to(device).float()
        seg_gt  = batch_data[1].to(device).float()
        cls_gt  = seg_gt.flatten(1).bool().any(dim=1).float().unsqueeze(1)

        cls_pred, seg_pred, rec_pred = model(input)

        loss, loss_detail = criterion(cls_pred=cls_pred, seg_pred=seg_pred, rec_pred=rec_pred, cls_gt=cls_gt, seg_gt=seg_gt, rec_gt=input)
    
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))

        # LOSS
        metric_logger.update(loss=loss_value)        
        if loss_detail is not None:
            metric_logger.update(**loss_detail)

        # Post-processing
        cls_pred = torch.sigmoid(cls_pred)
        seg_pred = torch.sigmoid(seg_pred)

        total_cls_pred  = torch.cat([total_cls_pred, cls_pred.detach().cpu()])
        total_cls_true  = torch.cat([total_cls_true, cls_gt.detach().cpu()])

        # Metrics SEG
        if seg_gt.any():
            dice = soft_dice_score(output=seg_pred.round(), target=seg_gt, smooth=0.0)    # pred_seg must be round() !! 
            metric_logger.update(dice=dice.item())     

    # Metric CLS
    auc            = roc_auc_score(y_true=total_cls_true, y_score=total_cls_pred)
    tp, fp, fn, tn = get_stats(total_cls_pred.round().long(), total_cls_true.long(), mode="binary")        
    f1             = f1_score(tp, fp, fn, tn, reduction="macro")
    acc            = accuracy(tp, fp, fn, tn, reduction="macro")
    sen            = sensitivity(tp, fp, fn, tn, reduction="macro")
    spe            = specificity(tp, fp, fn, tn, reduction="macro")

    metric_logger.update(auc=auc, f1=f1, acc=acc, sen=sen, spe=spe)          
    
    valid_stats = {k: round(meter.global_avg, 7) for k, meter in metric_logger.meters.items()}
    print("Averaged valid_stats: ", valid_stats)
    
    ################################################################################################
    # Save & Log
    ################################################################################################
    
    checkpoint_paths = output_dir + '/epoch_' + str(epoch) + '_checkpoint.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }, checkpoint_paths)

    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                **{f'valid_{k}': v for k, v in valid_stats.items()},
                'epoch': epoch}

    if output_dir:
        with open(output_dir + "/log.txt", "a") as f:
            f.write(json.dumps(log_stats) + "\n")

Epoch: [0]  [ 0/33]  eta: 0:00:28  lr: 0.000100  loss: 2.2062 (2.2062)  CLS_Loss: 0.8414 (0.8414)  SEG_Loss: 0.6883 (0.6883)  REC_Loss: 0.6766 (0.6766)  time: 0.8579  data: 0.6162  max mem: 3590
Epoch: [0]  [10/33]  eta: 0:00:04  lr: 0.000100  loss: 1.6213 (1.7730)  CLS_Loss: 0.6420 (0.7967)  SEG_Loss: 0.5226 (0.4925)  REC_Loss: 0.4522 (0.4837)  time: 0.2134  data: 0.0561  max mem: 3804
Epoch: [0]  [20/33]  eta: 0:00:02  lr: 0.000100  loss: 1.4848 (1.5501)  CLS_Loss: 0.6211 (0.7082)  SEG_Loss: 0.4752 (0.4557)  REC_Loss: 0.3191 (0.3862)  time: 0.1495  data: 0.0001  max mem: 3804
Epoch: [0]  [30/33]  eta: 0:00:00  lr: 0.000100  loss: 1.1940 (1.4644)  CLS_Loss: 0.5642 (0.6850)  SEG_Loss: 0.3560 (0.4468)  REC_Loss: 0.2463 (0.3326)  time: 0.1496  data: 0.0002  max mem: 3804
Epoch: [0]  [32/33]  eta: 0:00:00  lr: 0.000100  loss: 1.2248 (1.4599)  CLS_Loss: 0.6028 (0.6918)  SEG_Loss: 0.3560 (0.4444)  REC_Loss: 0.2341 (0.3237)  time: 0.1493  data: 0.0001  max mem: 3804
Epoch: [0] Total time: 0:

Valid:  [60/70]  eta: 0:00:00  loss: 0.3776 (0.7219)  CLS_Loss: 0.2825 (0.5945)  SEG_Loss: 0.0000 (0.0303)  REC_Loss: 0.0986 (0.0971)  dice: 1.3389 (1.2743)  time: 0.0143  data: 0.0001  max mem: 4345
Valid:  [69/70]  eta: 0:00:00  loss: 0.3546 (0.6894)  CLS_Loss: 0.2546 (0.5598)  SEG_Loss: 0.0000 (0.0310)  REC_Loss: 0.1025 (0.0986)  dice: 1.2452 (1.2559)  time: 0.0143  data: 0.0001  max mem: 4345
Valid: Total time: 0:00:01 (0.0249 s / it)
Averaged valid_stats:  {'loss': 0.6894331, 'CLS_Loss': 0.5598408, 'SEG_Loss': 0.0310394, 'REC_Loss': 0.0985529, 'dice': 1.2559432, 'auc': 0.4489796, 'f1': 0.0, 'acc': 0.7857143, 'sen': 0.0, 'spe': 0.9821429}
Epoch: [3]  [ 0/33]  eta: 0:00:22  lr: 0.000100  loss: 0.8479 (0.8479)  CLS_Loss: 0.6614 (0.6614)  SEG_Loss: 0.0848 (0.0848)  REC_Loss: 0.1018 (0.1018)  time: 0.6928  data: 0.5431  max mem: 4349
Epoch: [3]  [10/33]  eta: 0:00:04  lr: 0.000100  loss: 1.0947 (1.1036)  CLS_Loss: 0.6955 (0.6894)  SEG_Loss: 0.1869 (0.3120)  REC_Loss: 0.1025 (0.1023)  t

Valid:  [40/70]  eta: 0:00:00  loss: 0.3479 (0.8900)  CLS_Loss: 0.2538 (0.8274)  SEG_Loss: 0.0000 (-0.0361)  REC_Loss: 0.0869 (0.0987)  dice: 1.3670 (1.4974)  time: 0.0141  data: 0.0001  max mem: 4349
Valid:  [50/70]  eta: 0:00:00  loss: 0.3479 (0.8218)  CLS_Loss: 0.2538 (0.7606)  SEG_Loss: 0.0000 (-0.0359)  REC_Loss: 0.0868 (0.0971)  dice: 1.5132 (1.5299)  time: 0.0142  data: 0.0001  max mem: 4349
Valid:  [60/70]  eta: 0:00:00  loss: 0.2688 (0.8130)  CLS_Loss: 0.1886 (0.7346)  SEG_Loss: 0.0000 (-0.0175)  REC_Loss: 0.0871 (0.0960)  dice: 1.3670 (1.4030)  time: 0.0147  data: 0.0001  max mem: 4349
Valid:  [69/70]  eta: 0:00:00  loss: 0.2688 (0.7632)  CLS_Loss: 0.1759 (0.6810)  SEG_Loss: 0.0000 (-0.0139)  REC_Loss: 0.0914 (0.0961)  dice: 1.3367 (1.3856)  time: 0.0146  data: 0.0001  max mem: 4349
Valid: Total time: 0:00:01 (0.0240 s / it)
Averaged valid_stats:  {'loss': 0.7631899, 'CLS_Loss': 0.6809594, 'SEG_Loss': -0.013901, 'REC_Loss': 0.0961315, 'dice': 1.3856257, 'auc': 0.2908163, 'f1'

Valid:  [20/70]  eta: 0:00:01  loss: 0.3639 (0.6602)  CLS_Loss: 0.2819 (0.5990)  SEG_Loss: 0.0000 (-0.0058)  REC_Loss: 0.0659 (0.0670)  dice: 1.1942 (1.1525)  time: 0.0157  data: 0.0001  max mem: 4349
Valid:  [30/70]  eta: 0:00:01  loss: 0.3295 (0.6471)  CLS_Loss: 0.2675 (0.6154)  SEG_Loss: 0.0000 (-0.0379)  REC_Loss: 0.0706 (0.0697)  dice: 1.2187 (1.3061)  time: 0.0172  data: 0.0001  max mem: 4349
Valid:  [40/70]  eta: 0:00:00  loss: 0.3283 (0.5813)  CLS_Loss: 0.2555 (0.5570)  SEG_Loss: 0.0000 (-0.0444)  REC_Loss: 0.0665 (0.0686)  dice: 1.2783 (1.3675)  time: 0.0166  data: 0.0001  max mem: 4349
Valid:  [50/70]  eta: 0:00:00  loss: 0.2523 (0.5173)  CLS_Loss: 0.1773 (0.4838)  SEG_Loss: 0.0000 (-0.0357)  REC_Loss: 0.0680 (0.0692)  dice: 1.2783 (1.3675)  time: 0.0152  data: 0.0001  max mem: 4349
Valid:  [60/70]  eta: 0:00:00  loss: 0.2420 (0.5353)  CLS_Loss: 0.1607 (0.5009)  SEG_Loss: 0.0000 (-0.0353)  REC_Loss: 0.0719 (0.0697)  dice: 1.2187 (1.3329)  time: 0.0148  data: 0.0001  max mem: 

Exception ignored in: <function _releaseLock at 0x7fa2d10893a0>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/logging/__init__.py", line 223, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


# Log check

In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt

def read_log(path):
    log_list = []
    lines = open(path, 'r').read().splitlines() 
    for i in range(len(lines)):
        exec('log_list.append('+lines[i] + ')')
    return  log_list

In [None]:
51, 36, 35

In [None]:
log_list = read_log(path = '/workspace/sunggu/6.Kakao/checkpoints/[Baseline]Face_Net_DAC_fold0_upsample/log.txt')
print(log_list[0].keys())
result_dict = {}

for key in log_list[0].keys():
    exec( "result_dict['"+str(key)+"']" + " = [ log_list[i]['"+str(key)+"'] for i in range(len(log_list)) ]")

for key in result_dict.keys():
    plt.plot(result_dict[key])
    plt.title(key)
    print("###########################################################")
    print("Argsort = ", np.argsort(result_dict[key])[:5])
    print("Value   = ", [result_dict[key][i] for i in np.argsort(result_dict[key])[:5]])
    plt.show()
    
    if key == 'valid_loss':
        print("Valid_Loss = ", np.argsort(result_dict[key])[:3])

In [None]:
# Log check

import glob
import numpy as np
import matplotlib.pyplot as plt

def read_log(path):
    log_list = []
    lines = open(path, 'r').read().splitlines() 
    for i in range(len(lines)):
        exec('log_list.append('+lines[i] + ')')
    return  log_list

51, 36, 35

log_list = read_log(path = '/workspace/sunggu/6.Kakao/checkpoints/[Baseline]Face_Net_DAC_fold0_upsample/log.txt')
print(log_list[0].keys())
result_dict = {}

for key in log_list[0].keys():
    exec( "result_dict['"+str(key)+"']" + " = [ log_list[i]['"+str(key)+"'] for i in range(len(log_list)) ]")

for key in result_dict.keys():
    plt.plot(result_dict[key])
    plt.title(key)
    print("###########################################################")
    print("Argsort = ", np.argsort(result_dict[key])[:5])
    print("Value   = ", [result_dict[key][i] for i in np.argsort(result_dict[key])[:5]])
    plt.show()
    
    if key == 'valid_loss':
        print("Valid_Loss = ", np.argsort(result_dict[key])[:3])

# TEST

In [None]:
# Resume

print("Loading... Resume")
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])        
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])        
args.start_epoch = checkpoint['epoch'] + 1  
try:
    log_path = os.path.dirname(args.resume)+'/log.txt'
    lines    = open(log_path,'r').readlines()
    val_loss_list = []
    for l in lines:
        exec('log_dict='+l.replace('NaN', '0'))
        val_loss_list.append(log_dict['valid_loss'])
    print("Epoch: ", np.argmin(val_loss_list), " Minimum Val Loss ==> ", np.min(val_loss_list))