### GLOBAL IMPORTS AND PARAMETERS ###


In [1]:
# !pip3 install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html


In [2]:
import os
import torch 
from models_scripts import i3_res50, i3_res50_nl, disable_bn, enable_bn
from utilities_scripts import SAM, LR_Scheduler, get_criterion, LoadingBar, Log, initialize, RandAugment
from dataset_scripts import CTDataset
import json

from torch.utils.data import DataLoader
import torchvision


# ---- Training hyper-parameters ----
batch_size = 2
learning_rate = 0.001
momentum = 0.9
weight_decay = 0.005
rho = 0.05  # passed to SAM(rho=...); presumably the SAM neighbourhood radius
warmup_epochs = 5  # LR warm-up length handed to LR_Scheduler (linear ramp, per the logged l.r. column)
epochs = 150
n_class = 2  # number of output classes of the classifier head
clip_len = 128  # passed to CTDataset(clip_len=...); presumably slices per volume - TODO confirm
seed = 0  # fixed RNG seed so weight init and DataLoader shuffling are repeatable

# ---- Data / fold configuration ----
fold_id = "1"  # the current fold running
# Derived from fold_id so that switching folds cannot leave `root` pointing at a stale fold.
# NOTE: machine-specific absolute paths; adjust per environment.
root = "/home/sentic/storage2/iccv_madu/fold_{}".format(fold_id)
num_workers = 2  # workers for the DataLoaders
fold_train_path = "./train_folding.json"
fold_valid_path = "./valid_folding.json"
checkpoint_dir = "/home/sentic/storage2/iccv_madu/checkpoints/"
# checkpoint_dir = "/home/sentic/Documents/data/storage2/LEUKEMIA/C-NMC_Leukemia/checkpoints/"
prepath = ""
# replacer = "/home/sentic/Documents/data/storage2/LEUKEMIA/C-NMC_Leukemia"
replacer = ""

# ---- Device & reproducibility ----
cuda_device_index = 1
# Falls back to the CPU when CUDA is unavailable.
device = torch.device("cuda:" + str(cuda_device_index) if torch.cuda.is_available() else "cpu")
# torch.manual_seed seeds the CPU and all CUDA generators.
torch.manual_seed(seed)

### MODEL SETUP ###
#### I) ResNet50_3D_NL ####


In [3]:
# Optional path to pretrained weights; leave as None to train from scratch.
pretrained = None

# Presumably the 3D ResNet-50 with non-local blocks (per the "ResNet50_3D_NL" header),
# built with an n_class-way output.
model = i3_res50_nl(n_class)

if pretrained is not None:
    # Map stored tensors onto the configured device rather than a hard-coded
    # 'cuda:<index>' string, which fails on CPU-only machines.
    model.load_state_dict(torch.load(pretrained, map_location=device))

model.to(device)


I3Res50(
  (conv1): Conv3d(3, 64, kernel_size=(5, 7, 7), stride=(2, 2, 2), padding=(2, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool1): MaxPool3d(kernel_size=(2, 3, 3), stride=(2, 2, 2), padding=(0, 0, 0), dilation=1, ceil_mode=False)
  (maxpool2): MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0), dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv3d(64, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn3): BatchNorm3d(256, e

### DATASETS AND DATALOADERS ###

In [4]:
# Fold assignments for the train / validation splits (JSON files set in the config cell).
with open(fold_train_path) as fhandle:
    fold_splitter_train = json.load(fhandle)

with open(fold_valid_path) as fhandle:
    fold_splitter_valid = json.load(fhandle)

# `replacer` / `prepath` come from the config cell instead of being hard-coded here,
# so changing them in one place affects both splits.
dataset_train = CTDataset(root=root,
                          fold_id=fold_id,
                          fold_splitter=fold_splitter_train,
                          transforms=None,
                          replacer=replacer,
                          prepath=prepath,
                          clip_len=clip_len,
                          split="train")

dataset_valid = CTDataset(root=root,
                          fold_id=fold_id,
                          fold_splitter=fold_splitter_valid,
                          transforms=None,
                          replacer=replacer,
                          prepath=prepath,
                          clip_len=clip_len,
                          split="val")

dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# Evaluation does not need shuffling; a fixed order keeps validation deterministic.
dataloader_valid = DataLoader(dataset_valid, batch_size=1, shuffle=False, num_workers=num_workers)

### RESUMING FROM A CHECKPOINT ###

In [5]:
### CHECKPOINTING ###
# Set to a file written by the training loop to resume, e.g.
# "/home/sentic/storage2/iccv_madu/checkpoints/checkpoint_model1_1_17.pth"
checkpoint = None

epoch_checkpoint = 0          # first epoch the training loop will run
net_state_dict = None         # model weights restored from the checkpoint, if any
optimizer_state_dict = None   # restored into the optimizer in the next cell

if checkpoint is not None:
    # map_location lets a checkpoint written on another GPU index (or only a CPU
    # now being available) be loaded onto the configured device.
    dict_checkpoint = torch.load(checkpoint, map_location=device)
    # The stored epoch was already completed, so resume from the following one.
    epoch_checkpoint = dict_checkpoint['epoch'] + 1
    net_state_dict = dict_checkpoint['model_state_dict']
    optimizer_state_dict = dict_checkpoint['optimizer_state_dict']
    print("Initializing from checkpoint")

# Ensure every parameter is trainable (no-op if they already are).
for param in model.parameters():
    param.requires_grad = True

if net_state_dict is not None:
    model.load_state_dict(net_state_dict)
    print("Loading model weights from checkpoint")

# NOTE: warmup_epochs is deliberately left untouched. The training loop passes the
# ABSOLUTE epoch index to the scheduler, so (assuming LR_Scheduler positions itself
# from epoch * iters_per_epoch + ix, which the logged l.r. column is consistent with)
# an already-completed warmup is skipped automatically. Shrinking it here shifted the
# schedule and, by mutating a config global, made this cell non-idempotent.
    


### OPTIMIZER, LR SCHEDULER, LOSS AND LOGGER ###

In [6]:
# SAM wraps a base optimizer: SGD with momentum and weight decay here.
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=rho, lr=learning_rate,
                momentum=momentum, weight_decay=weight_decay)

if optimizer_state_dict is not None:
    optimizer.load_state_dict(optimizer_state_dict)

# The training loop calls scheduler(optimizer, ix, epoch, best_pred) with the ABSOLUTE
# epoch index (it starts at epoch_checkpoint when resuming), so the schedule horizon must
# be the full `epochs`. The previous `epochs - epoch_checkpoint` shortened the cosine after
# a resume, and later epochs then ran past its end (the LR would rise again) -
# assuming LR_Scheduler computes its position from epoch * iters_per_epoch + ix.
scheduler = LR_Scheduler('cos',
                         base_lr=learning_rate,
                         num_epochs=epochs,
                         iters_per_epoch=len(dataloader_train),
                         warmup_epochs=warmup_epochs)

# smooth=0.1 presumably means label smoothing; the loop calls .mean() on its output,
# so it appears to return per-sample losses.
criterion = get_criterion(smooth=0.1)
log = Log(log_each=10)




### TRAINING LOOP (CHECKPOINTS MODEL AND OPTIMIZER STATE EVERY EPOCH) ###

In [7]:

# Checkpoint after every epoch (narrow this list to save disk space).
saving_epochs = list(range(epochs))

# Only forwarded to the scheduler; never updated in this loop.
best_pred = 0

# torch.save does not create missing directories; make sure the target exists.
os.makedirs(checkpoint_dir, exist_ok=True)

print("Starting from epoch {}".format(epoch_checkpoint))
for epoch in range(epoch_checkpoint, epochs):
    model.train()
    log.train(len_dataset=len(dataloader_train))

    for ix, batch in enumerate(dataloader_train):
        # Per-iteration LR update, driven by the absolute epoch index.
        scheduler(optimizer, ix, epoch, best_pred)
        inputs, targets = (b.to(device) for b in batch)

        # SAM pass 1: gradient at the current weights, then first_step (presumably
        # moves the weights to the perturbed point).
        predictions = model(inputs)
        loss = criterion(predictions, targets)
        loss.mean().backward()
        optimizer.first_step(zero_grad=True)

        # SAM pass 2: gradient at the perturbed weights; second_step applies the update.
        # NOTE(review): disable_bn/enable_bn are imported but left commented out, so
        # BatchNorm running statistics are updated by both forward passes - confirm intended.
        # disable_bn(model)
        criterion(model(inputs), targets).mean().backward()
        # enable_bn(model)
        optimizer.second_step(zero_grad=True)

        with torch.no_grad():
            # Accuracy is computed from the first-pass predictions.
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu(), optimizer.param_groups[0]["lr"])

    model.eval()
    log.eval(len_dataset=len(dataloader_valid))

    with torch.no_grad():
        for batch in dataloader_valid:
            inputs, targets = (b.to(device) for b in batch)

            predictions = model(inputs)
            loss = criterion(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu())

    if epoch in saving_epochs:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Loss of the last validation batch; .mean() keeps .item() valid for batch sizes > 1.
            'loss': loss.mean().item(),
            }, os.path.join(checkpoint_dir, "checkpoint_model1_150_" + str(fold_id) + "_" + str(epoch) + ".pth")
        )

log.flush()

Starting from epoch 0
┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨

┃           0  ┃      0.3457  │     54.77 %  ┃   1.997e-04  │   13:04 min  ┃┈██████████████████████████▓┈┨      0.7023  │     55.59 %  ┃
┃           1  ┃      0.3415  │     57.10 %  ┃   3.997e-04  │   13:06 min  ┃┈██████████████████████████▓┈┨      0.6794  │     55.59 %  ┃
┃           2  ┃      0.3420  │     55.50 %  ┃   5.997e-04  │   13:06 min  ┃┈██████████████████████████▓┈┨      0.6811  │     55.59 %  ┃
┃           3  ┃      0.3431  │     55.37 %  ┃   7.997e-04  │   13:06 min  ┃┈██████████████████████████▓┈┨      0.6840  │     55.59 %  ┃
┃

┃          56  ┃      0.1666  │     92.51 %  ┃   7.149e-04  │   12:59 min  ┃┈██████████████████████████▓┈┨      0.4204  │     85.29 %  ┃
┃          57  ┃      0.1628  │     93.44 %  ┃   7.050e-04  │   12:59 min  ┃┈██████████████████████████▓┈┨      0.3878  │     86.92 %  ┃
┃          58  ┃      0.1719  │     91.64 %  ┃   6.951e-04  │   13:01 min  ┃┈██████████████████████████▓┈┨      0.4177  │     85.01 %  ┃
┃          59  ┃      0.1681  │     93.10 %  ┃   6.851e-04  │   13:00 min  ┃┈██████████████████████████▓┈┨      0.3944  │     87.47 %  ┃
┃          60  ┃      0.1693  │     93.24 %  ┃   6.750e-04  │   12:59 min  ┃┈██████████████████████████▓┈┨      0.4015  │     86.38 %  ┃
┃          61  ┃      0.1657  │     92.64 %  ┃   6.648e-04  │   12:59 min  ┃┈██████████████████████████▓┈┨      0.4025  │     86.65 %  ┃
┃          62  ┃      0.1739  │     91.51 %  ┃   6.545e-04  │   12:58 min  ┃┈██████████████████████████▓┈┨      0.4112  │     86.10 %  ┃
┃          63  ┃      0.1701  │     91.98

┃         116  ┃      0.1452  │     95.36 %  ┃   1.225e-04  │   13:05 min  ┃┈██████████████████████████▓┈┨      0.3608  │     90.46 %  ┃
┃         117  ┃      0.1439  │     95.82 %  ┃   1.154e-04  │   13:05 min  ┃┈██████████████████████████▓┈┨      0.3653  │     90.19 %  ┃
┃         118  ┃      0.1423  │     96.22 %  ┃   1.086e-04  │   13:09 min  ┃┈██████████████████████████▓┈┨      0.3634  │     89.92 %  ┃
┃         119  ┃      0.1414  │     95.95 %  ┃   1.020e-04  │   13:10 min  ┃┈██████████████████████████▓┈┨      0.3725  │     90.46 %  ┃
┃         120  ┃      0.1404  │     96.42 %  ┃   9.550e-05  │   13:09 min  ┃┈██████████████████████████▓┈┨      0.3611  │     90.46 %  ┃
┃         121  ┃      0.1404  │     96.42 %  ┃   8.923e-05  │   13:03 min  ┃┈██████████████████████████▓┈┨      0.3698  │     90.19 %  ┃
┃         122  ┃      0.1421  │     95.76 %  ┃   8.315e-05  │   13:02 min  ┃┈██████████████████████████▓┈┨      0.3666  │     89.65 %  ┃
┃         123  ┃      0.1401  │     95.95