# Loading importing

In [1]:
!nvidia-smi

import os
import argparse
import datetime
import time
import cv2
import random
import torch
import numpy as np
import glob
import matplotlib.pyplot as plt


os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Fri Sep 16 19:28:56 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:18:00.0 Off |                  Off |
| 55%   79C    P2   107W / 300W |   4378MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    Off  | 00000000:3B:00.0 Off |                  Off |
|100%   88C    P2   247W / 300W |  45210MiB / 48685MiB |     98%      Default |
|       

In [2]:
# fix random seeds for reproducibility
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) 
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

P: 75 (22.73%) N: 255 (77.27%) Total: 330
P: 14 (20.00%) N: 56  (80.00%) Total: 70
P: 21 (21.00%) N: 79  (79.00%) Total: 100

# Dataset

In [3]:
pwd

'/workspace/sunggu/7.KOHI/Multi_task_learning_tutorials'

In [4]:
DATA_DIR = './SSIM_seg/'

x_train_list = glob.glob(os.path.join(DATA_DIR, 'train/*'))
y_train_list = glob.glob(os.path.join(DATA_DIR, 'trainannot/*'))

x_valid_list = glob.glob(os.path.join(DATA_DIR, 'val/*'))
y_valid_list = glob.glob(os.path.join(DATA_DIR, 'valannot/*'))

x_test_list  = glob.glob(os.path.join(DATA_DIR, 'test/*'))
y_test_list  = glob.glob(os.path.join(DATA_DIR, 'testannot/*'))

In [5]:
x_train_list[0]

'./SSIM_seg/train/1.2.276.0.7230010.3.1.4.8323329.300.1517875162.258081.png'

In [6]:
x_valid_list[0]

'./SSIM_seg/val/1.2.276.0.7230010.3.1.4.8323329.1173.1517875166.626582.png'

In [7]:
from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader


# Dataset Class
class Dataset(BaseDataset):    
    def __init__(self, images_list, labels_list, transform):
        self.images_list  = images_list
        self.labels_list  = labels_list
        self.transform    = transform
    
    def __getitem__(self, i):
        # read data
        image = cv2.imread(self.images_list[i], cv2.IMREAD_GRAYSCALE)
        mask  = cv2.imread(self.labels_list[i], cv2.IMREAD_GRAYSCALE)
        path  = self.images_list[i]
        
        mask  = np.expand_dims(mask, axis=0)
            
        # apply transform
        sample = self.transform(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']        
        
        return image, mask, path
        
    def __len__(self):
        return len(self.images_list)

In [8]:
from albumentations.pytorch.transforms import ToTensorV2
import albumentations as albu


def minmax_normalize(image, **kwargs):
    if len(np.unique(image)) != 1:  # Sometimes it cause the nan inputs...
        image = image.astype('float32')
        image -= image.min()
        image /= image.max() 
    return image


train_transform = albu.Compose([
    albu.HorizontalFlip(p=0.5),
    albu.ShiftScaleRotate(scale_limit=0.10, shift_limit=0.10, rotate_limit=15, p=0.5),
    albu.GaussNoise(p=0.2),
    albu.OneOf(
        [
            albu.CLAHE(p=1),
            albu.RandomBrightnessContrast(p=1),
            albu.RandomGamma(p=1),
        ],
        p=0.3,
    ),
    albu.OneOf(
        [
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.3,
    ),
    albu.Lambda(image=minmax_normalize, always_apply=True),
    ToTensorV2(),    
])

valid_transform = albu.Compose([        
    albu.Lambda(image=minmax_normalize, always_apply=True),
    ToTensorV2(),
])


In [9]:
dataset_train = Dataset(images_list=x_train_list, labels_list=y_train_list, transform=train_transform)
dataset_valid = Dataset(images_list=x_valid_list, labels_list=y_valid_list, transform=valid_transform)

data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=10, num_workers=4, shuffle=True, pin_memory=True, drop_last=True)
data_loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=1,  num_workers=4, shuffle=True, pin_memory=True, drop_last=False)

In [None]:
# # same image with different random transforms

# batch = next(iter(train_loader))
# x = batch['x'][0]
# y_seg = batch['y_seg'][0]
# y_cls = batch['y_cls'][0]

# print(x.shape,y_seg.shape,y_cls.shape)
# print(torch.unique(y_seg),y_cls)
# visualize(image=x, mask=y_seg)

# Model

In [10]:
from arch.smart_net import *
from losses import MTL_Loss

# Model
model        = STL_2_Net(encoder_name='resnet18').to('cuda')     

# Loss
criterion    = MTL_Loss(name='STL_SEG')

# Optimizer & LR Schedule   
optimizer    = torch.optim.AdamW(params=model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of Learnable Params:', n_parameters)   

Number of Learnable Params: 14322198


# Loop

In [None]:
import json
import math
import torch
import utils
from torch import nn
import torch.nn.functional as F

from metrics import *
from losses import soft_dice_score
from sklearn.metrics import roc_auc_score

print_freq = 10
output_dir = './checkpoints/stl_seg/'
device     = 'cuda'

# Whole LOOP
for epoch in range(0, 200):
    
    ################################################################################################
    # Training 
    ################################################################################################
    
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ", n=10)
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    
    for batch_data in metric_logger.log_every(data_loader_train, print_freq, header):
        
        input   = batch_data[0].to(device).float()
        seg_gt  = batch_data[1].to(device).float()

        seg_pred = model(input)

        loss, loss_detail = criterion(seg_pred=seg_pred, seg_gt=seg_gt)
        
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss_value)
        if loss_detail is not None:
            metric_logger.update(**loss_detail)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    
    train_stats = {k: round(meter.global_avg, 7) for k, meter in metric_logger.meters.items()}
    print("Averaged train_stats: ", train_stats)
    
    ################################################################################################
    # Validation
    ################################################################################################
    
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ", n=1)
    header = 'Valid:'

    
    for batch_data in metric_logger.log_every(data_loader_valid, print_freq, header):
        
        input  = batch_data[0].to(device).float()
        seg_gt = batch_data[1].to(device).float()

        seg_pred = model(input)

        loss, loss_detail = criterion(seg_pred=seg_pred, seg_gt=seg_gt)
    
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))

        # LOSS
        metric_logger.update(loss=loss_value)        
        if loss_detail is not None:
            metric_logger.update(**loss_detail)

        # post-processing
        seg_pred = torch.sigmoid(seg_pred)


        # Metrics SEG
        if seg_gt.any():
            dice = soft_dice_score(output=seg_pred.round(), target=seg_gt, smooth=0.0)    # pred_seg must be round() !! 
            metric_logger.update(dice=dice.item())     
    
    valid_stats = {k: round(meter.global_avg, 7) for k, meter in metric_logger.meters.items()}
    print("Averaged valid_stats: ", valid_stats)
    
    ################################################################################################
    # Save & Log
    ################################################################################################
    
    checkpoint_paths = output_dir + '/epoch_' + str(epoch) + '_checkpoint.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }, checkpoint_paths)

    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                **{f'valid_{k}': v for k, v in valid_stats.items()},
                'epoch': epoch}

    if output_dir:
        with open(output_dir + "/log.txt", "a") as f:
            f.write(json.dumps(log_stats) + "\n")

Epoch: [0]  [ 0/33]  eta: 0:00:26  lr: 0.000100  loss: 0.3585 (0.3585)  SEG_Loss: 0.3585 (0.3585)  time: 0.7908  data: 0.5204  max mem: 2750
Epoch: [0]  [10/33]  eta: 0:00:04  lr: 0.000100  loss: 0.4467 (0.4196)  SEG_Loss: 0.4467 (0.4196)  time: 0.2016  data: 0.0475  max mem: 2917
Epoch: [0]  [20/33]  eta: 0:00:02  lr: 0.000100  loss: 0.2561 (0.3276)  SEG_Loss: 0.2561 (0.3276)  time: 0.1330  data: 0.0002  max mem: 2917
Epoch: [0]  [30/33]  eta: 0:00:00  lr: 0.000100  loss: 0.1983 (0.3209)  SEG_Loss: 0.1983 (0.3209)  time: 0.1124  data: 0.0001  max mem: 2917
Epoch: [0]  [32/33]  eta: 0:00:00  lr: 0.000100  loss: 0.2331 (0.3298)  SEG_Loss: 0.2331 (0.3298)  time: 0.1099  data: 0.0001  max mem: 2917
Epoch: [0] Total time: 0:00:04 (0.1463 s / it)
Averaged train_stats:  {'lr': 0.0001, 'loss': 0.32983, 'SEG_Loss': 0.32983}
Valid:  [ 0/70]  eta: 0:00:24  loss: 0.0000 (0.0000)  SEG_Loss: 0.0000 (0.0000)  time: 0.3543  data: 0.3229  max mem: 2917
Valid:  [10/70]  eta: 0:00:02  loss: 0.0000 (0.08

Valid: Total time: 0:00:01 (0.0221 s / it)
Averaged valid_stats:  {'loss': -0.0512388, 'SEG_Loss': -0.0512388, 'dice': 1.265149}
Epoch: [4]  [ 0/33]  eta: 0:00:24  lr: 0.000100  loss: -0.4186 (-0.4186)  SEG_Loss: -0.4186 (-0.4186)  time: 0.7299  data: 0.4678  max mem: 3308
Epoch: [4]  [10/33]  eta: 0:00:07  lr: 0.000100  loss: -0.0439 (0.1165)  SEG_Loss: -0.0439 (0.1165)  time: 0.3336  data: 0.0427  max mem: 3308
Epoch: [4]  [20/33]  eta: 0:00:03  lr: 0.000100  loss: 0.0000 (0.1233)  SEG_Loss: 0.0000 (0.1233)  time: 0.2645  data: 0.0001  max mem: 3308
Epoch: [4]  [30/33]  eta: 0:00:00  lr: 0.000100  loss: 0.0944 (0.1781)  SEG_Loss: 0.0944 (0.1781)  time: 0.1864  data: 0.0001  max mem: 3308
Epoch: [4]  [32/33]  eta: 0:00:00  lr: 0.000100  loss: 0.0944 (0.1588)  SEG_Loss: 0.0944 (0.1588)  time: 0.1774  data: 0.0001  max mem: 3308
Epoch: [4] Total time: 0:00:07 (0.2371 s / it)
Averaged train_stats:  {'lr': 0.0001, 'loss': 0.1588362, 'SEG_Loss': 0.1588362}
Valid:  [ 0/70]  eta: 0:00:25  lo

Valid:  [60/70]  eta: 0:00:00  loss: 0.0000 (-0.0612)  SEG_Loss: 0.0000 (-0.0612)  dice: 1.3465 (1.4099)  time: 0.0155  data: 0.0001  max mem: 3308
Valid:  [69/70]  eta: 0:00:00  loss: 0.0000 (-0.0658)  SEG_Loss: 0.0000 (-0.0658)  dice: 1.3092 (1.3601)  time: 0.0147  data: 0.0001  max mem: 3308
Valid: Total time: 0:00:01 (0.0233 s / it)
Averaged valid_stats:  {'loss': -0.0657751, 'SEG_Loss': -0.0657751, 'dice': 1.360118}
Epoch: [8]  [ 0/33]  eta: 0:00:21  lr: 0.000100  loss: -0.5948 (-0.5948)  SEG_Loss: -0.5948 (-0.5948)  time: 0.6465  data: 0.4413  max mem: 3308
Epoch: [8]  [10/33]  eta: 0:00:06  lr: 0.000100  loss: -0.0006 (0.0241)  SEG_Loss: -0.0006 (0.0241)  time: 0.2621  data: 0.0403  max mem: 3308
Epoch: [8]  [20/33]  eta: 0:00:03  lr: 0.000100  loss: -0.0006 (0.0503)  SEG_Loss: -0.0006 (0.0503)  time: 0.2258  data: 0.0002  max mem: 3308
Epoch: [8]  [30/33]  eta: 0:00:00  lr: 0.000100  loss: 0.0437 (0.1148)  SEG_Loss: 0.0437 (0.1148)  time: 0.2571  data: 0.0001  max mem: 3308
Epo

# Log check

In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt

def read_log(path):
    log_list = []
    lines = open(path, 'r').read().splitlines() 
    for i in range(len(lines)):
        exec('log_list.append('+lines[i] + ')')
    return  log_list

In [None]:
51, 36, 35

In [None]:
log_list = read_log(path = '/workspace/sunggu/6.Kakao/checkpoints/[Baseline]Face_Net_DAC_fold0_upsample/log.txt')
print(log_list[0].keys())
result_dict = {}

for key in log_list[0].keys():
    exec( "result_dict['"+str(key)+"']" + " = [ log_list[i]['"+str(key)+"'] for i in range(len(log_list)) ]")

for key in result_dict.keys():
    plt.plot(result_dict[key])
    plt.title(key)
    print("###########################################################")
    print("Argsort = ", np.argsort(result_dict[key])[:5])
    print("Value   = ", [result_dict[key][i] for i in np.argsort(result_dict[key])[:5]])
    plt.show()
    
    if key == 'valid_loss':
        print("Valid_Loss = ", np.argsort(result_dict[key])[:3])

In [None]:
# Log check

import glob
import numpy as np
import matplotlib.pyplot as plt

def read_log(path):
    log_list = []
    lines = open(path, 'r').read().splitlines() 
    for i in range(len(lines)):
        exec('log_list.append('+lines[i] + ')')
    return  log_list

51, 36, 35

log_list = read_log(path = '/workspace/sunggu/6.Kakao/checkpoints/[Baseline]Face_Net_DAC_fold0_upsample/log.txt')
print(log_list[0].keys())
result_dict = {}

for key in log_list[0].keys():
    exec( "result_dict['"+str(key)+"']" + " = [ log_list[i]['"+str(key)+"'] for i in range(len(log_list)) ]")

for key in result_dict.keys():
    plt.plot(result_dict[key])
    plt.title(key)
    print("###########################################################")
    print("Argsort = ", np.argsort(result_dict[key])[:5])
    print("Value   = ", [result_dict[key][i] for i in np.argsort(result_dict[key])[:5]])
    plt.show()
    
    if key == 'valid_loss':
        print("Valid_Loss = ", np.argsort(result_dict[key])[:3])

# TEST

In [None]:
# Resume

print("Loading... Resume")
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])        
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])        
args.start_epoch = checkpoint['epoch'] + 1  
try:
    log_path = os.path.dirname(args.resume)+'/log.txt'
    lines    = open(log_path,'r').readlines()
    val_loss_list = []
    for l in lines:
        exec('log_dict='+l.replace('NaN', '0'))
        val_loss_list.append(log_dict['valid_loss'])
    print("Epoch: ", np.argmin(val_loss_list), " Minimum Val Loss ==> ", np.min(val_loss_list))