In [1]:
import os

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Subset, ConcatDataset
import numpy as np

from tqdm import tqdm
import random

In [2]:
import model
import dataset
import augmentation as aug

In [3]:
import matplotlib.pyplot as plt

# helper function for data visualization
def visualize(**images):
    """Display the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword, with underscores
    turned into spaces and title-cased, is the panel title, and the value is
    drawn with the ``'gray'`` colormap. Axis ticks are hidden on every panel.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture, 'gray')
    plt.show()

In [4]:
def train_epoch(model, optimizer, dataloader, device):
    """Run one training epoch and return the batch-averaged losses.

    Every batch is unpacked as ``(image, mask, extra)``; ``extra`` is ignored.
    The model must return ``(prediction, kl_loss, dl_loss)``. The objective
    optimised for each batch is ``dice + mean(kl_loss) + mean(dl_loss)``.

    Args:
        model: network returning ``(prediction, kl_loss, dl_loss)``.
        optimizer: optimizer stepping ``model``'s parameters.
        dataloader: iterable yielding ``(image, mask, extra)`` batches.
        device: device the image and mask tensors are moved to.

    Returns:
        ``(total_loss, dice_loss, kl_loss, dl_loss)``: each is the mean of the
        corresponding per-batch scalar over the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches (the means are undefined).
    """
    model.train()
    dice_loss = smp.utils.losses.DiceLoss()

    total_sum = dice_sum = kl_sum = dl_sum = 0.0
    n_batches = 0

    for img, msk, _ in tqdm(dataloader):
        optimizer.zero_grad()

        img = img.to(device)
        msk = msk.to(device, dtype=torch.float)

        pr, kl_loss, dl_loss = model(img)

        # Give the prediction exactly the mask's shape (e.g. drop the singleton
        # channel of a (B, 1, H, W) prediction when masks are (B, H, W)). A
        # mismatch such as (B, H, W) vs (B, 1, H, W) would broadcast to
        # (B, B, H, W) in elementwise ops and pair up different samples.
        # Assumes a single-channel prediction; otherwise reshape_as raises.
        pr = pr.reshape_as(msk)

        dice = dice_loss(pr, msk)
        kl = torch.mean(kl_loss)
        dl = torch.mean(dl_loss)

        loss = dice + kl + dl
        loss.backward()
        optimizer.step()

        total_sum += loss.item()
        dice_sum += dice.item()
        kl_sum += kl.item()
        dl_sum += dl.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_epoch: dataloader yielded no batches")

    return (total_sum / n_batches, dice_sum / n_batches,
            kl_sum / n_batches, dl_sum / n_batches)

In [5]:
@torch.no_grad()
def eval_epoch(model, dataloader, device):
    """Return the mean per-batch IoU of ``model`` over ``dataloader``.

    Every batch is unpacked as ``(image, mask, extra)``; ``extra`` is ignored.
    Only the first element of the model's 3-tuple output (the prediction) is
    scored. Runs under ``torch.no_grad()`` with the model in eval mode.

    Args:
        model: network returning ``(prediction, _, _)``.
        dataloader: iterable yielding ``(image, mask, extra)`` batches.
        device: device the image and mask tensors are moved to.

    Returns:
        Mean of the per-batch IoU values (each batch weighted equally).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    metric_iou = smp.utils.metrics.IoU()

    iou_score = []
    for img, msk, _ in tqdm(dataloader):
        img = img.to(device)
        msk = msk.to(device)

        pr, _, _ = model(img)

        # Align the prediction with the mask's shape before scoring. Without
        # this, a (B, 1, H, W) prediction against (B, H, W) masks broadcasts to
        # (B, B, H, W) in the elementwise intersection and inflates the IoU
        # whenever the batch size is > 1.
        iou = metric_iou(pr.reshape_as(msk), msk)
        iou_score.append(iou.item())

    if not iou_score:
        raise ValueError("eval_epoch: dataloader yielded no batches")

    return sum(iou_score) / len(iou_score)

In [6]:
@torch.no_grad()
def test_epoch(model, dataset, device):

    import math
    from torch.utils.data import DataLoader

    model.eval()

    metric_iou = smp.utils.metrics.IoU()
    
    imgs = []
    predict = []
    msks = []
    iou_score = []

    dataloader = DataLoader(dataset, batch_size=1,
                            shuffle=False, num_workers=2)

    for index, data in tqdm(enumerate(dataloader)):

        img, msk, cpy = data

        img = img.to(device)
        msk = msk.to(device)

        pr, _, _ = model(img)

        iou = metric_iou(pr, msk)

        pr = torch.squeeze(pr, dim=0).detach().cpu().numpy()
        msk = torch.squeeze(msk, dim=0).detach().cpu().numpy()
        cpy = torch.squeeze(cpy, dim=0).detach().cpu().numpy()

        predict.append(pr.transpose(1, 2, 0))
        imgs.append(cpy)
        msks.append(msk)
        iou_score.append(iou.item())


    return imgs, predict, msks, iou_score

In [7]:
# Training hyper-parameters.
batch = 4
n_channels = 3
n_classes = 1
epochs = 1000

# Seed the Python, NumPy and PyTorch RNGs so shuffling and weight
# initialisation are repeatable between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the second GPU ("cuda:1") as before, but fall back to the first GPU
# when only one is visible (an invalid ordinal would fail later), and to the
# CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [8]:
# Encoder backbone and the pretrained weights it is initialised from.
ENCODER, ENCODER_WEIGHTS = 'densenet161', 'imagenet'

# Input preprocessing function that smp provides for this encoder/weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [9]:
# smp U-Net whose encoder is reused as the SCG-Net backbone. Channel and class
# counts come from the config cell instead of repeating the literals here.
unet = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    decoder_attention_type=None,
    in_channels=n_channels,
    classes=n_classes,
    activation="sigmoid",
    aux_params=None,
)

In [10]:
# Reuse the U-Net's encoder module as the backbone passed to SCGraphUnet below.
encoder = unet.encoder

In [11]:
# Graph decoder from the local `model` module, placed on the training device.
# NOTE(review): the meaning of the three positional `None` arguments is not
# visible from this notebook; confirm against SCGraphUnetDecoder's signature
# and consider passing them by keyword.
decoder = model.SCGraphUnetDecoder(None, None, None, device=device)

In [12]:
# Assemble the SCG-Net from the pretrained encoder and the graph decoder, then
# move it to the training device.
scg_net = model.SCGraphUnet(encoder=encoder, decoder=decoder)
scg_net = scg_net.to(device)

# SGD with momentum over all SCG-Net parameters (starting LR 1e-3).
optimizer = torch.optim.SGD(scg_net.parameters(), lr=1e-3, momentum=0.9)

In [13]:
# All JSRT splits live in sub-folders of ./data relative to the working directory.
DATA_ROOT = os.path.join(os.getcwd(), "data")


def _jsrt_split(split_name, augmentation):
    """Build the JSRT dataset for one split folder with the shared preprocessing."""
    return dataset.JSRTset(
        root=os.path.join(DATA_ROOT, split_name),
        augmentation=augmentation,
        preprocessing=aug.get_preprocessing(preprocessing_fn),
    )


trainset = _jsrt_split("trainset", aug.get_training_augmentation())
valset = _jsrt_split("valset", aug.get_validation_augmentation())
testset = _jsrt_split("testset", aug.get_validation_augmentation())

In [14]:
# Shared DataLoader settings; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch, num_workers=2)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)
validloader = DataLoader(valset, shuffle=False, **loader_kwargs)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)

In [15]:
# Per-epoch history filled by the training loop and plotted at the end.
epoch_logs = {
    metric_name: []
    for metric_name in (
        "diceloss",
        "kl divergence",
        "diagonal loss",
        "iou-train",
        "iou-valid",
    )
}

In [16]:
# Main training loop: train, score train/valid IoU, log, decay the LR at 50% and
# 75% of the epochs, and checkpoint whenever validation IoU improves.
iou_valid = 0.0  # best validation IoU seen so far

for epoch in range(epochs):

    # train_epoch returns (total, dice, kl, dl) batch-averaged losses.
    loss = train_epoch(scg_net, optimizer, trainloader, device)
    eval_train = eval_epoch(scg_net, trainloader, device)
    eval_valid = eval_epoch(scg_net, validloader, device)

    print("Epoch: {}, total loss={:.5f}, dice loss={:.5f}, kl loss={:.5f}, dl loss={:.5f}".format(
        epoch, *loss))
    print("Valid-IoU: {:.5f}, Train-IoU: {:.5f}".format(eval_valid, eval_train))

    epoch_logs['diceloss'].append(loss[1])
    epoch_logs['kl divergence'].append(loss[2])
    epoch_logs['diagonal loss'].append(loss[3])
    epoch_logs['iou-train'].append(eval_train)
    epoch_logs['iou-valid'].append(eval_valid)

    # Halve the LR of every parameter group once at each milestone, and report
    # the value actually in effect (it is halved, not divided by 10).
    if epoch in (int(epochs * 0.5), int(epochs * 0.75)):
        for group in optimizer.param_groups:
            group['lr'] *= 0.5
        print('Decrease learning rate to {:g}!'.format(optimizer.param_groups[0]['lr']))

    if eval_valid > iou_valid:
        iou_valid = eval_valid
        # Save scg_net, the network being optimised. `unet` only contributes its
        # encoder; its own decoder is never trained.
        checkpoint = {
            'model_stat': scg_net.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        torch.save(checkpoint, os.path.join(os.getcwd(), "{:04d}_{:04d}_{:04d}.pth".format(
            int(eval_valid * 1000),
            int(eval_train * 1000),
            int(loss[0] * 1000))))
        print("Model Saved")
    

0it [00:00, ?it/s]

tensor([[[[ 0.1807,  0.1807,  0.1807,  ...,  0.2932,  0.2932,  0.2932],
          [ 0.1807,  0.1807,  0.1807,  ...,  0.2932,  0.2932,  0.2932],
          [ 0.1807,  0.1807,  0.1807,  ...,  0.2932,  0.2932,  0.2932],
          ...,
          [ 0.1987,  0.1987,  0.1987,  ...,  0.1507,  0.1507,  0.1507],
          [ 0.1987,  0.1987,  0.1987,  ...,  0.1507,  0.1507,  0.1507],
          [ 0.1987,  0.1987,  0.1987,  ...,  0.1507,  0.1507,  0.1507]],

         [[ 0.0050,  0.0050,  0.0050,  ...,  0.0354,  0.0354,  0.0354],
          [ 0.0050,  0.0050,  0.0050,  ...,  0.0354,  0.0354,  0.0354],
          [ 0.0050,  0.0050,  0.0050,  ...,  0.0354,  0.0354,  0.0354],
          ...,
          [ 0.0268,  0.0268,  0.0268,  ...,  0.0318,  0.0318,  0.0318],
          [ 0.0268,  0.0268,  0.0268,  ...,  0.0318,  0.0318,  0.0318],
          [ 0.0268,  0.0268,  0.0268,  ...,  0.0318,  0.0318,  0.0318]],

         [[ 0.0304,  0.0304,  0.0304,  ...,  0.3121,  0.3121,  0.3121],
          [ 0.0304,  0.0304,  




TypeError: 'tuple' object is not callable

In [None]:
# Evaluate on the test split (test_epoch uses batch size 1).
# NOTE(review): this scores the weights left after the final epoch, not the
# best-validation checkpoint written during training; reload that checkpoint
# first if the best model is what should be reported.
imgs, predict, msks, iou_score = test_epoch(scg_net, testset, device)

In [None]:
# Mean IoU over all test images.
mean_test_iou = sum(iou_score) / len(iou_score)
print(mean_test_iou)

In [None]:
# Show every test image next to its ground-truth mask and prediction, with its IoU.
for position, (img, pr, msk, iou) in enumerate(zip(imgs, predict, msks, iou_score), start=1):
    print("\n Index:{}, IoU={:.5f}".format(position, iou))
    visualize(img=img, groundtruth=msk, prediction=pr)

In [None]:
# Learning curves: IoU on top (valid and train on separate y-axes), losses below.
fig, axs = plt.subplots(2, 1, figsize=(20, 20))

axs[0].plot(epoch_logs['iou-valid'], color="orange", label="valid iou")
axs[0].set_xlabel("epoch", fontsize=14)
axs[0].set_ylabel("valid-iou", color="orange", fontsize=14)

# Twin y-axis so train IoU gets its own scale on the shared epoch axis.
ax2 = axs[0].twinx()
ax2.plot(epoch_logs['iou-train'], color="blue", label="train iou")
ax2.set_ylabel("train-iou", color="blue", fontsize=14)

# The two IoU lines live on different Axes, so collect both for one legend.
iou_handles = [*axs[0].get_lines(), *ax2.get_lines()]
axs[0].legend(iou_handles, [h.get_label() for h in iou_handles])
axs[0].set_title("IoU per epoch", fontsize=16)

axs[1].plot(epoch_logs['diceloss'], label="diceloss")
axs[1].plot(epoch_logs['kl divergence'], label="kl divergence")
axs[1].plot(epoch_logs['diagonal loss'], label="diagonal loss")
axs[1].set_xlabel("epoch", fontsize=14)
axs[1].set_ylabel("loss", color="blue", fontsize=14)
axs[1].legend()
axs[1].set_title("Training losses per epoch", fontsize=16)

# Save before plt.show(): with some (non-inline) backends the figure may be
# closed or blank after show(), so a later savefig would write an empty image.
fig.savefig(os.path.join(os.getcwd(), 'Ex6.png'),
            bbox_inches='tight',
            facecolor='white')
plt.show()