In [None]:
import torch
import torch.nn as nn
import torchvision

# %matplotlib nbagg
import numpy as np
import matplotlib.pyplot as plt

import time
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from dataset import *
from transforms import *
from criteria import *
from torch.utils.data import DataLoader

In [None]:
# BSDS 500 -- training data pipeline

# NOTE: keep PadTo2Power's k in step with the gamma-net depth whenever the depth changes.
# NOTE: batches stay on the CPU until right before the forward pass
#       (hence no memory pinning etc. on this loader).

# Input images: channel-first tensors produced with div=True.
x_transform = torchvision.transforms.Compose([
    CenterCrop((256, 256)),
    ToTensor(make_channel_first=True, div=True),
    PadTo2Power(axes=(1, 2), k=5, mode='constant'),
    AssertWidthMajor(),
])

# Labels: same crop, but no channel axis is added and float_out=False.
y_transform = torchvision.transforms.Compose([
    CenterCrop((256, 256)),
    ToTensor(make_channel_first=False, float_out=False),
    PadTo2Power(axes=(0, 1), k=5, mode='constant'),
    AssertWidthMajor(),
])

ds_train = SimpleDataset(
    "bsds500/x_train.txt",
    "bsds500/y_c_train.txt",
    x_transform=x_transform,
    y_transform=y_transform,
)
dl_train = DataLoader(ds_train, batch_size=4, shuffle=True, num_workers=0)

In [None]:
# ACDC -- training and validation data pipelines

# NOTE: this cell rebinds x_transform / y_transform / ds_train / dl_train, so it
#       replaces whatever the BSDS 500 cell defined under those names.
# NOTE: unlike the BSDS pipeline there is no PadTo2Power step here; images are
#       forced to a fixed 224x224 size by PadOrCenterCrop.
#       NOTE(review): presumably that size has to stay compatible with the
#       network depth -- confirm before changing it.
# NOTE: both loaders request pin_memory=True.

# Input images.
x_transform = torchvision.transforms.Compose([
    GaussianSmooth(3, 1),
    CLAHE(clipLimit=2.0, tileGridSize=(8, 8)),
    PadOrCenterCrop(size=(224, 224)),
    ToTensor(make_channel_first=True, div=True),
    ExpandDims(dim=0),
])

# Label maps: only resized and converted (float_out=False).
y_transform = torchvision.transforms.Compose([
    PadOrCenterCrop(size=(224, 224)),
    ToTensor(make_channel_first=False, float_out=False),
])

# Training split.
ds_train = SimpleDataset(
    "ACDC_Dataset_p/training/x_train.txt",
    "ACDC_Dataset_p/training/y_train.txt",
    x_transform=x_transform,
    y_transform=y_transform,
    use_cache=True,
)
dl_train = DataLoader(ds_train, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)

# Validation split. NOTE(review): shuffle=True means `next(iter(dl_val))`
# yields a different batch on every run.
ds_val = SimpleDataset(
    "ACDC_Dataset_p/training/x_val.txt",
    "ACDC_Dataset_p/training/y_val.txt",
    x_transform=x_transform,
    y_transform=y_transform,
    use_cache=True,
)
dl_val = DataLoader(ds_val, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)

# U-Net

In [None]:
from pytorch_unet.unet import UNet

# Single-channel input, four output classes.
model = UNet(
    in_channels=1,
    n_classes=4,
    depth=5,
    wf=3,
    padding=True,
    batch_norm=True,
    up_mode='upsample',
)

# float32 on the GPU when one is available, float64 on the CPU otherwise.
model = model.cuda().float() if torch.cuda.is_available() else model.double()

In [None]:
load_model = True
model_file = "models_unet_3/model_unet_e400.pkl"
if load_model:
    # Without CUDA the stored tensors are remapped onto the CPU; with CUDA the
    # default mapping (map_location=None) is used.
    model.load_state_dict(torch.load(
        model_file,
        map_location=None if torch.cuda.is_available() else torch.device('cpu'),
    ))

In [None]:
# Wrap the U-Net in DataParallel; later cells reach the wrapped network through
# `unet.module`. The name `model` is rebound by the Gamma-Net cells further down.
unet = nn.DataParallel(model)

# Gamma-Net

In [None]:
from gammanet import GammaNet

# Hyper-parameters handed to GammaNet as a plain dict.
gammanet_config = dict(
    in_channels=1,
    return_sequences=False,
    num_filters=[8, 16, 32, 64, 128],
    conv_kernel_size=[3, 3, 3, 3, 3],
    conv_blocksize=[1, 1, 1, 1, 1],
    conv_normtype='instancenorm',
    conv_dropout_p=0.2,
    conv_residual=False,
    fgru_hidden_size=[8, 16, 32, 64, 128],
    fgru_kernel_size=[9, 7, 5, 3, 1, 1, 1, 1, 1],
    fgru_timesteps=4,
    fgru_normtype='instancenorm',
    fgru_channel_sym=True,
    upsample_mode='bilinear',
    upsample_all2all=True,
)

# GammaNet followed by a 5x5 convolution mapping 8 channels to the 4 classes.
# Layer positions inside this Sequential: 0 = GammaNet, 1 = Conv2d head.
model = nn.Sequential(
    # Optional stem, currently disabled (enabling it shifts the layer
    # positions and requires changing the number of input channels):
    #     nn.Conv2d(1, 24, 3, padding=1),
    #     nn.Conv2d(24, 24, 3, padding=1),
    GammaNet(gammanet_config),
    nn.Conv2d(8, 4, 5, padding=2),  # Change the expected number of output classes!
)

# float32 on the GPU when one is available, float64 on the CPU otherwise.
model = model.cuda().float() if torch.cuda.is_available() else model.double()

In [None]:
load_model = True
model_file = "models_gn_6/model_e400.pkl"
if load_model:
    # Without CUDA the stored tensors are remapped onto the CPU; with CUDA the
    # default mapping (map_location=None) is used.
    model.load_state_dict(torch.load(
        model_file,
        map_location=None if torch.cuda.is_available() else torch.device('cpu'),
    ))

In [None]:
# Wrap the Gamma-Net pipeline in DataParallel; later cells reach the wrapped
# nn.Sequential through `gnet.module`.
gnet = nn.DataParallel(model)

# Tests

In [None]:
def count_parameters(model):
    """Return the total number of elements over all trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Trainable parameter counts: U-Net on the first line, Gamma-Net on the second.
print(count_parameters(unet.module), count_parameters(gnet.module), sep='\n')

## Statistics

In [None]:
# Collect the raw outputs of both networks over the whole validation set.

# Switch to inference mode first: the U-Net is built with batch_norm=True and
# the Gamma-Net config sets conv_dropout_p=0.2, so leaving the networks in
# training mode would make these statistics non-deterministic.
unet.eval()
gnet.eval()

# Mirror the dtype/device convention used when the models were built:
# float32 on the GPU, float64 on the CPU.
use_cuda = torch.cuda.is_available()

res_unet_arr = []
res_gnet_arr = []
label_arr = []
with torch.no_grad():
    for i, (x, y) in enumerate(dl_val):
        x = x.cuda().float() if use_cuda else x.double()
        # Move each result to the CPU right away so GPU memory does not grow
        # with the size of the validation set.
        res_unet_arr.append(unet(x).cpu())
        res_gnet_arr.append(gnet(x).cpu())
        label_arr.append(y)
        print("{}/{}".format(i+1, len(dl_val)), end='\r')

# Stack the per-batch results along the batch dimension.
res_unet_arr = torch.cat(res_unet_arr, dim=0)
res_gnet_arr = torch.cat(res_gnet_arr, dim=0)
label_arr = torch.cat(label_arr, dim=0)

In [None]:
def cross_entropy(predictions, targets, epsilon=1e-12):
    """Summed cross-entropy between `predictions` and `targets`, averaged over dim 0.

    `predictions` is clamped to [epsilon, 1 - epsilon] and offset by 1e-9
    before the logarithm is taken. The element-wise product with `targets`
    is summed over every element and divided by ``predictions.shape[0]``.

    Returns a scalar tensor.
    """
    clipped = predictions.clamp(min=epsilon, max=1. - epsilon)
    batch_size = clipped.shape[0]
    log_probs = (clipped + 1e-9).log()
    return -(targets * log_probs).sum() / batch_size
def dice_coeff(pred, target, smooth = 1.):
    """Soft Dice coefficient between sigmoid(pred) and target.

    Args:
        pred: tensor of raw scores; a sigmoid is applied inside this function.
        target: tensor with the same number of elements as `pred`.
        smooth: constant added to numerator and denominator so the ratio is
            defined when both inputs sum to zero.

    Returns:
        Scalar tensor ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``
        computed over all elements of the batch.
    """
    pred = torch.sigmoid(pred)
    num = pred.size(0)
    # reshape (not view) so that non-contiguous inputs -- e.g. permuted or
    # sliced tensors -- are flattened instead of raising a RuntimeError.
    m1 = pred.reshape(num, -1).float()
    m2 = target.reshape(num, -1).float()
    intersection = (m1 * m2).sum().float()
    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

In [None]:
# Cross-entropy and multi-class Dice over the whole validation set, printed as
# U-Net (CE, Dice) followed by Gamma-Net (CE, Dice).
#
# The cross-entropy is computed with nn.CrossEntropyLoss on the full 4-class
# outputs and the integer label maps, exactly as in the per-batch evaluation
# cells below. The previous code passed only the raw class-1 channel to a
# hand-written cross_entropy that expects probabilities, and divided by the
# sample count twice.
ce_criterion = nn.CrossEntropyLoss()
labels_cpu = label_arr.cpu()

print(ce_criterion(res_unet_arr.cpu(), labels_cpu.long()))
print(dice_coeff_mc(res_unet_arr.cpu(), labels_cpu))
print(ce_criterion(res_gnet_arr.cpu(), labels_cpu.long()))
print(dice_coeff_mc(res_gnet_arr.cpu(), labels_cpu))

## Result visualizations

In [None]:
# Fetch one batch from the validation loader for the visualisations below.

x, y = next(iter(dl_val))

# Match the dtype the models were built with: float32 with CUDA, float64 without.
if torch.cuda.is_available():
    x = x.float()
else:
    x = x.double()

# Quick sanity check of shapes and value ranges.
print(x.shape, y.shape)
print(x.max(), x.min(), y.max(), y.min())

In [None]:
# Evaluate the U-Net on the batch loaded above.

model = unet
model.eval()
with torch.no_grad():
    res_unet = model(x)
    # First line: multi-class Dice; second line: cross-entropy against the integer labels.
    print(
        dice_coeff_mc(res_unet.cpu(), y),
        nn.CrossEntropyLoss()(res_unet.cpu(), y.long()),
        sep='\n',
    )

In [None]:
# Evaluate the Gamma-Net on the batch loaded above.
# After this cell `model` refers to `gnet`; the later cells rely on that.

model = gnet
model.eval()
with torch.no_grad():
    res_gnet = model(x)
    # First line: multi-class Dice; second line: cross-entropy against the integer labels.
    print(
        dice_coeff_mc(res_gnet.cpu(), y),
        nn.CrossEntropyLoss()(res_gnet.cpu(), y.long()),
        sep='\n',
    )

In [None]:
# visualize ACDC results
thres = 0.
ss = nn.Sigmoid()
# ss = nn.Softmax(dim=1)

# Apply the activation and convert to numpy once per network. Doing this inside
# the plotting loops would repeat the same whole-batch computation for every
# subplot.
acdc_maps = [
    ss(res_unet).cpu().detach().numpy(),  # first figure of each sample: U-Net
    ss(res_gnet).cpu().detach().numpy(),  # second figure of each sample: Gamma-Net
]

for i in range(x.shape[0]):
    print(i)
    for maps in acdc_maps:
        plt.figure(figsize=(15,10))
        # Panel 1: input image (first channel).
        plt.subplot(2,3,1)
        plt.imshow(x[i].numpy()[0], cmap='gray')
        # Panel 2: label map, when labels are available.
        if "y" in dir():
            plt.subplot(2,3,2)
            plt.imshow(y[i])
        # Panels 3 onwards: one activation map per output channel.
        for j in range(maps.shape[1]):
            plt.subplot(2,3,3+j)
            plt.imshow(maps[i,j])
#             plt.colorbar()
        plt.show()

## Result visualization per timestep

In [None]:
# request gammanet to return sequence

# Use gnet explicitly rather than whatever `model` currently points at.
m = gnet.module

# Find the GammaNet layer by type instead of by position: in the Sequential
# built in this notebook it sits at index 0, so a hard-coded index of 2 raises
# an IndexError.
gammanet_layers = [layer for layer in m if isinstance(layer, GammaNet)]
assert len(gammanet_layers) == 1, "expected exactly one GammaNet layer in the pipeline"
gammanet_layers[0].config['return_sequences'] = True

In [None]:
# get returned sequences from gammanet

# Feed the batch through the pipeline layer by layer and stop right after the
# GammaNet layer, wherever it sits. Hard-coded positions 0..2 do not exist in
# the two-layer Sequential built in this notebook.
with torch.no_grad():
    # `m` is the bare module (no DataParallel scatter), so the input has to be
    # placed on the device that holds the weights.
    out = x.to(next(m.parameters()).device)
    for layer in m:
        out = layer(out)
        if isinstance(layer, GammaNet):
            break

In [None]:
# visualize per-timestep outputs

idx = 3    # sample within the batch
label = 1  # output channel to display

# Reference figure: input image and, when available, its label map.
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.imshow(x[idx].cpu().detach().numpy()[0], cmap='gray')
if "y" in dir():
    plt.subplot(1,2,2)
    plt.imshow(y[idx])
plt.show()

# The last layer of the pipeline is the Conv2d read-out. Addressing it as
# m[-1] works regardless of how many layers precede it; the fixed index 3 is
# out of range for the two-layer Sequential built in this notebook.
head = m[-1]

# NOTE(review): out[1] is taken to be the per-timestep sequence returned by
# GammaNet when return_sequences is enabled -- confirm against gammanet.py.
plt.figure(figsize=(20,20))
with torch.no_grad():
    for i, o in enumerate(out[1]):
        o = head(o)
        plt.subplot(len(out[1])//4+1, 4, i+1)
        plt.imshow(o[idx,label].cpu().detach().numpy())
plt.show()

In [None]:
# Drop the reference to the per-timestep outputs so their memory can be reclaimed.
del out

## fGRU kernel visualization

In [None]:
# inhibition kernels in first down block fGRU

# Look the parameter up by suffix so the cell does not depend on where GammaNet
# sits inside the Sequential (the '2.' prefix does not exist for the two-layer
# pipeline built in this notebook).
gnet_state = gnet.module.state_dict()
kernel_keys = [name for name in gnet_state if name.endswith('fgru_down.0.params.w_inh')]
if len(kernel_keys) != 1:
    raise KeyError("expected exactly one 'fgru_down.0.params.w_inh' entry, found {}".format(kernel_keys))
k = gnet_state[kernel_keys[0]]

# Convert once instead of once per subplot.
k_np = k.cpu().detach().numpy()
for i in range(k.shape[0]):
    print(i)
    plt.figure(figsize=(15,10))
    for j in range(k.shape[1]):
        plt.subplot(4,6,j+1)
        plt.imshow(k_np[i,j])
    plt.show()

In [None]:
# excitation kernels in first down block fGRU

# Look the parameter up by suffix so the cell does not depend on where GammaNet
# sits inside the Sequential (the '2.' prefix does not exist for the two-layer
# pipeline built in this notebook).
gnet_state = gnet.module.state_dict()
kernel_keys = [name for name in gnet_state if name.endswith('fgru_down.0.params.w_exc')]
if len(kernel_keys) != 1:
    raise KeyError("expected exactly one 'fgru_down.0.params.w_exc' entry, found {}".format(kernel_keys))
k = gnet_state[kernel_keys[0]]

# Convert once instead of once per subplot.
k_np = k.cpu().detach().numpy()
for i in range(k.shape[0]):
    print(i)
    plt.figure(figsize=(15,10))
    for j in range(k.shape[1]):
        plt.subplot(4,6,j+1)
        plt.imshow(k_np[i,j])
    plt.show()

In [None]:
# inhibition kernels in second down block fGRU (first slice along axis 0 only)

# Look the parameter up by suffix so the cell does not depend on where GammaNet
# sits inside the Sequential (the '2.' prefix does not exist for the two-layer
# pipeline built in this notebook).
gnet_state = gnet.module.state_dict()
kernel_keys = [name for name in gnet_state if name.endswith('fgru_down.1.params.w_inh')]
if len(kernel_keys) != 1:
    raise KeyError("expected exactly one 'fgru_down.1.params.w_inh' entry, found {}".format(kernel_keys))
k = gnet_state[kernel_keys[0]]

# Convert once instead of once per subplot.
k_np = k.cpu().detach().numpy()
plt.figure(figsize=(15,10))
# The loop variable indexes axis 1, so it must range over shape[1]
# (ranging over shape[0] only works when the two axes have equal length).
for i in range(k.shape[1]):
    plt.subplot(4,7,i+1)
    plt.imshow(k_np[0,i])
plt.show()

In [None]:
# excitation kernels in second down block fGRU (first slice along axis 0 only)

# Look the parameter up by suffix so the cell does not depend on where GammaNet
# sits inside the Sequential (the '2.' prefix does not exist for the two-layer
# pipeline built in this notebook).
gnet_state = gnet.module.state_dict()
kernel_keys = [name for name in gnet_state if name.endswith('fgru_down.1.params.w_exc')]
if len(kernel_keys) != 1:
    raise KeyError("expected exactly one 'fgru_down.1.params.w_exc' entry, found {}".format(kernel_keys))
k = gnet_state[kernel_keys[0]]

# Convert once instead of once per subplot.
k_np = k.cpu().detach().numpy()
plt.figure(figsize=(15,10))
# The loop variable indexes axis 1, so it must range over shape[1]
# (ranging over shape[0] only works when the two axes have equal length).
for i in range(k.shape[1]):
    plt.subplot(4,7,i+1)
    plt.imshow(k_np[0,i])
plt.show()