In [1]:
import os

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Subset, ConcatDataset
import numpy as np

from tqdm import tqdm
import random

In [2]:
import model
import dataset
import augmentation as aug

In [3]:
import matplotlib.pyplot as plt

# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image, 'gray')
    plt.show()

In [4]:
def train_epoch(model, optimizer, dataloader, device):


    model.train()

    total_loss = 0
    
    dice_total = 0
    kl_total = 0
    dl_total = 0
    
    dice_loss = smp.utils.losses.DiceLoss()

    for index, data in tqdm(enumerate(dataloader)):

        optimizer.zero_grad()

        img, msk, _ = data

        img = img.to(device)
        msk = msk.to(device)

        pr, kl_loss, dl_loss = model(img)

        ### Predicted mask loss
        pr = pr.squeeze(1)
        
#         print(torch.max(pr), torch.min(pr))

        dice = dice_loss(pr, msk)
        kl = torch.mean(kl_loss)
        dl = torch.mean(dl_loss)
        
#         print(dice, kl, dl)

        loss = dice + kl + dl
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        dice_total += dice.item()
        kl_total += kl.item()
        dl_total += dl.item()

    total_loss = total_loss/(index+1)
    dice_total = dice_total/(index+1)
    kl_total = kl_total/(index+1)
    dl_total = dl_total/(index+1)

    return total_loss, dice_total, dl_total, dl_total

In [5]:
@torch.no_grad()
def eval_epoch(model, dataloader, device):

    import math
    from torch.utils.data import DataLoader

    model.eval()

    iou_score = []
    
    metric_iou = smp.utils.metrics.IoU()

    for index, data in tqdm(enumerate(dataloader)):

        img, msk, _ = data

        img = img.to(device)
        msk = msk.to(device)

        pr, _, _ = model(img)
        iou = metric_iou(pr, msk)

        iou_score.append(iou.item())

    return sum(iou_score)/len(iou_score)

In [6]:
@torch.no_grad()
def test_epoch(model, dataset, device):

    import math
    from torch.utils.data import DataLoader

    model.eval()

    imgs = []
    predict = []
    msks = []

    dataloader = DataLoader(dataset, batch_size=1,
                            shuffle=False, num_workers=2)

    for index, data in tqdm(enumerate(dataloader)):

        img, msk, cpy = data

        img = img.to(device)
        msk = msk.to(device)

        pr = model(img)

        pr = torch.squeeze(pr, dim=0).detach().cpu().numpy()
        msk = torch.squeeze(msk, dim=0).detach().cpu().numpy()
        cpy = torch.squeeze(cpy, dim=0).detach().cpu().numpy()

        predict.append(pr.transpose(1, 2, 0))
        imgs.append(cpy.transpose(1, 2, 0))
        msks.append(msk)


    return imgs, predict, msks

In [7]:
batch = 4
n_channels = 3
n_classes = 1
epochs = 1000

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [8]:
ENCODER = 'densenet161'
ENCODER_WEIGHTS = 'imagenet'

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [9]:
unet = smp.Unet(encoder_name=ENCODER, 
                 encoder_weights=ENCODER_WEIGHTS,
                decoder_attention_type=None,
                 in_channels=3, classes=1, activation="sigmoid", aux_params=None)

In [10]:
encoder = unet.encoder

In [11]:
decoder = model.SCGDecoder(None, None, torch.nn.Sigmoid())

In [12]:
scg_net = model.SCGNet(encoder=encoder, 
               decoder=decoder,).to(device)

optimizer = torch.optim.SGD(scg_net.parameters(), lr=1e-4, momentum=0.9)

In [13]:
trainset = dataset.JSRTset(root=os.path.join(os.getcwd(), "data", "trainset"),
                          augmentation=aug.get_training_augmentation(), 
                           preprocessing=aug.get_preprocessing(preprocessing_fn),)
valset = dataset.JSRTset(root=os.path.join(os.getcwd(), "data", "valset"),
                          augmentation=aug.get_validation_augmentation(), 
                           preprocessing=aug.get_preprocessing(preprocessing_fn),)
testset = dataset.JSRTset(root=os.path.join(os.getcwd(), "data", "testset"),
                          augmentation=aug.get_validation_augmentation(), 
                           preprocessing=aug.get_preprocessing(preprocessing_fn),)

In [14]:
trainloader = DataLoader(trainset, batch_size=batch, shuffle=True, num_workers=2)
validloader = DataLoader(valset, batch_size=batch, shuffle=False, num_workers=2)
testloader = DataLoader(testset, batch_size=batch, shuffle=False, num_workers=2)

In [15]:
epoch_logs = {
    "diceloss": [],
    "kl divergence": [],
    "diagonal loss": [],
    "iou-train": [],
    "iou-valid": []
}

In [None]:
iou_valid = 0.0

for epoch in range(epochs):
    
    loss = train_epoch(scg_net, optimizer, trainloader, device)
    eval_train = eval_epoch(scg_net, trainloader, device)
    eval_valid = eval_epoch(scg_net, validloader, device)
    
    print("Epoch: {}, total loss={:.5f}, dice loss={:.5f}, kl loss={:.5f}, dl loss={:.5f}".format(epoch, 
                                                                                                  loss[0],
                                                                                                 loss[1],
                                                                                                 loss[2],
                                                                                                 loss[3]))
    print("Valid-IoU: {:.5f}, Train-IoU: {:.5f}".format(eval_valid, eval_train))
    
    epoch_logs['diceloss'].append(loss[1])
    epoch_logs['kl divergence'].append(loss[2])
    epoch_logs['diagonal loss'].append(loss[3])
    epoch_logs['iou-train'].append(eval_train)
    epoch_logs['iou-valid'].append(eval_valid)
   
    if epoch == int(epochs*0.5):
        optimizer.param_groups[0]['lr'] = 1e-4
        print('Decrease learning rate to 1e-4!')
    elif epoch == int(epochs*0.75):
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')
        
    if eval_valid > iou_valid:
        iou_valid = eval_valid
        checkpoint = {
            'model_stat': unet.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        torch.save(checkpoint, os.path.join(os.getcwd(), "{:04d}_{:04d}_{:04d}.pth".format(int(eval_valid*1000),
                                                                                   int(eval_train*1000),
                                                                                   int(loss[0]*1000))))
        print("Model Saved")
    

41it [00:09,  4.30it/s]
41it [00:07,  5.47it/s]
8it [00:01,  5.16it/s]


Epoch: 0, total loss=2.78479, dice loss=0.99716, kl loss=0.00067, dl loss=0.00067
Valid-IoU: 0.00424, Train-IoU: 0.00529
Model Saved


41it [00:09,  4.37it/s]
41it [00:07,  5.53it/s]
8it [00:01,  5.08it/s]

Epoch: 1, total loss=1.44083, dice loss=0.99708, kl loss=0.00066, dl loss=0.00066
Valid-IoU: 0.00411, Train-IoU: 0.00540



41it [00:09,  4.39it/s]
41it [00:07,  5.52it/s]
8it [00:01,  5.14it/s]

Epoch: 2, total loss=1.26611, dice loss=0.99738, kl loss=0.00067, dl loss=0.00067
Valid-IoU: 0.00409, Train-IoU: 0.00588



41it [00:09,  4.40it/s]
41it [00:07,  5.51it/s]
8it [00:01,  5.24it/s]

Epoch: 3, total loss=1.19314, dice loss=0.99700, kl loss=0.00067, dl loss=0.00067
Valid-IoU: 0.00414, Train-IoU: 0.00579



41it [00:09,  4.41it/s]
41it [00:07,  5.46it/s]
