In [1]:
import torch
import torch.nn as nn
import torchvision

import numpy as np
import matplotlib.pyplot as plt

import os
# Let more than one copy of the OpenMP runtime be loaded into this process
# (the usual workaround for the "OMP: Error #15" abort).
# NOTE(review): presumably only needed on some local setups; this is documented
# as an unsafe workaround — confirm it is still required.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Datasets

In [None]:
# BSDS 500

# NOTE: remember to adjust PadTo2Power's k value according to gamma-net depth
# NOTE: data stays in CPU until right before forward pass (i.e. no need for memory pinning, etc)

from dataset import *
from transforms import *
from torch.utils.data import DataLoader

# Images and labels go through the same 256x256 CenterCrop so they stay aligned.
crop_size = (256, 256)

# Image pipeline: channel-first tensor (div=True), padded along axes 1 and 2.
x_transform = torchvision.transforms.Compose([
    CenterCrop(crop_size),
    ToTensor(make_channel_first=True, div=True),
    PadTo2Power(axes=(1, 2), k=5, mode='constant'),
    AssertWidthMajor(),
])

# Label pipeline: no channel axis, non-float output, padded along axes 0 and 1.
y_transform = torchvision.transforms.Compose([
    CenterCrop(crop_size),
    ToTensor(make_channel_first=False, float_out=False),
    PadTo2Power(axes=(0, 1), k=5, mode='constant'),
    AssertWidthMajor(),
])

# The two text files are handed to SimpleDataset as the x and y sources.
ds_train = SimpleDataset(
    "bsds500/x_train.txt",
    "bsds500/y_c_train.txt",
    x_transform=x_transform,
    y_transform=y_transform,
)
dl_train = DataLoader(dataset=ds_train, batch_size=4, shuffle=True, num_workers=0)

In [2]:
# ACDC

# NOTE: remember to adjust PadTo2Power's k value according to gamma-net depth
# NOTE: data stays in CPU until right before forward pass (i.e. no need for memory pinning, etc)

from dataset import *
from transforms import *
from torch.utils.data import DataLoader
import torchvision

# Images and labels share the same 224x224 PadOrCenterCrop target size.
target_size = (224, 224)

# Image pipeline: channel-first tensor (div=True) followed by ExpandDims(dim=0).
x_transform = torchvision.transforms.Compose([
    PadOrCenterCrop(size=target_size),
    ToTensor(make_channel_first=True, div=True),
    ExpandDims(dim=0),
])

# Label pipeline: non-float tensor without a channel axis.
y_transform = torchvision.transforms.Compose([
    PadOrCenterCrop(size=target_size),
    ToTensor(make_channel_first=False, float_out=False),
])

ds_train = SimpleDataset(
    "ACDC_Dataset_p/training/x_train.txt",
    "ACDC_Dataset_p/training/y_train.txt",
    x_transform=x_transform,
    y_transform=y_transform,
)
dl_train = DataLoader(dataset=ds_train, batch_size=12, shuffle=True, num_workers=0)

# Gamma-Net

In [3]:
from gammanet import GammaNet

# Channel widths, one entry per level; used for both the conv blocks and the
# fGRU hidden states (each key gets its own list object).
level_widths = [24, 28, 36, 48, 64]
num_levels = len(level_widths)

# Configuration dictionary passed to GammaNet.
gammanet_config = dict(
    in_channels=24,
    return_sequences=False,
    num_filters=list(level_widths),
    conv_kernel_size=[3] * num_levels,
    conv_blocksize=[1] * num_levels,
    conv_normtype='instancenorm',
    conv_dropout_p=0.2,  # 0.2
    conv_residual=False,
    fgru_hidden_size=list(level_widths),
    fgru_kernel_size=[9, 7, 5, 3, 1, 1, 1, 1, 1],
    fgru_timesteps=4,
    fgru_normtype='instancenorm',
    fgru_channel_sym=True,
    upsample_mode='bilinear',
    upsample_all2all=True,
)

image_channels = 1  # Change the number of input channels!
num_classes = 4     # Change the expected number of output classes!
stem_width = 24     # width of the two stem convolutions and of the head's input

# Two 3x3 convolutions in front of GammaNet, and a 5x5 convolution after it that
# produces one output channel per class. Modules are created in the same order
# as their Sequential indices (0..3).
stem_layers = [
    nn.Conv2d(image_channels, stem_width, 3, padding=1),
    nn.Conv2d(stem_width, stem_width, 3, padding=1),
]
model = nn.Sequential(
    *stem_layers,
    GammaNet(gammanet_config),
    nn.Conv2d(stem_width, num_classes, 5, padding=2),
)

# With CUDA the model is moved to the GPU in float32; without it the model stays
# on the CPU and is converted to float64.
model = model.cuda().float() if torch.cuda.is_available() else model.double()

In [4]:
# Set to True to resume from the checkpoint below instead of starting fresh.
load_model = False
checkpoint_path = "model_e38.pkl"

if load_model:
    # Without CUDA every stored tensor is remapped onto the CPU; with CUDA
    # torch.load keeps its default placement (map_location=None).
    map_target = None if torch.cuda.is_available() else torch.device('cpu')
    model.load_state_dict(torch.load(checkpoint_path, map_location=map_target))

In [5]:
# Wrap the network in nn.DataParallel. The guard keeps this cell idempotent:
# running it twice would otherwise nest wrappers, so that `model.module` is
# itself a DataParallel and the state_dict keys gain an extra "module." prefix.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)

In [6]:
import torch.optim as optim

# Cross-entropy loss, moved to the GPU when one is available.
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    criterion = criterion.cuda()

# Adam over every model parameter with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# train() returns the module; binding it to `_` keeps the cell from echoing it.
_ = model.train()

In [None]:
from datetime import datetime, timedelta
# Training loop: optimises `model` on `dl_train` with `criterion`/`optimizer`,
# prints a running ETA, and writes a checkpoint every `save_period` epochs.
start_time = datetime.now()

save_model = True
save_period = 10  # checkpoint every `save_period` epochs
model_file_template = "model_e{}.pkl"

epochs = 1000

# Loop invariants, evaluated once instead of on every batch.
use_cuda = torch.cuda.is_available()
batches_per_epoch = len(dl_train)
total_batches = epochs * batches_per_epoch

for epoch in range(epochs):  # loop over the dataset multiple times

    for i, data in enumerate(dl_train):
        # get the inputs; data is a list of [inputs, labels]
        # Data stays on the CPU until here; inputs are cast to the model's
        # precision (float32 on GPU, float64 on CPU).
        if use_cuda:
            inputs = data[0].cuda().float()
            labels = data[1].cuda()
        else:
            inputs = data[0].double()
            labels = data[1]

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # Call the module itself rather than .forward() so hooks are honoured.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: ETA extrapolated from the fraction of batches done
        progress = (i + 1 + epoch * batches_per_epoch) / total_batches
        time_elapsed = datetime.now() - start_time
        time_to_completion = time_elapsed / progress - time_elapsed
        print("Epoch: {}, Batch {}/{}, ETA: ".format(epoch+1, i+1, batches_per_epoch) +
              str(time_to_completion), end='\r')

    # save model once per epoch, after its last batch.
    # The parentheses matter: `epoch+1 % save_period` parses as
    # `epoch + (1 % save_period)`, which is never 0, so nothing was ever saved.
    if save_model and (epoch + 1) % save_period == 0:
        # Save the unwrapped module so the keys carry no "module." prefix.
        net_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(net_to_save.state_dict(), model_file_template.format(epoch + 1))

print('Finished Training')

Epoch: 2, Batch 36/126, ETA: 1 day, 10:08:22.103916

# Tests

In [None]:
# Switch the model to evaluation mode before running the test cells below
# (nn.Module.eval() disables training-only behaviour such as dropout).
model.eval()

In [None]:
from dataset import *
from transforms import *
from torch.utils.data import DataLoader

# Test-image pipeline: Resize to 150x150, channel-first tensor, then PadTo2Power
# along axes 1 and 2.
# NOTE(review): the training pipelines pass div=True to ToTensor while this one
# passes div=False — confirm the intended input scaling.
test_steps = [
    Resize(size=(150, 150)),
    ToTensor(make_channel_first=True, div=False),
    PadTo2Power(axes=(1, 2), k=5, mode='constant'),
]
x_transform = torchvision.transforms.Compose(test_steps)

# Only an x list is given, so no y_transform is passed.
ds_test = SimpleDataset("test/test_images.txt", x_transform=x_transform)
dl_test = DataLoader(dataset=ds_test, batch_size=1, shuffle=False, num_workers=0)

In [None]:
from dataset import *
from transforms import *
from torch.utils.data import DataLoader

# BSDS500 test pipeline: Resize to 150x150, channel-first tensor, PadTo2Power
# along axes 1 and 2, then the AssertWidthMajor check.
# NOTE(review): the training pipelines pass div=True to ToTensor while this one
# passes div=False — confirm the intended input scaling.
test_steps = [
    Resize(size=(150, 150)),
    ToTensor(make_channel_first=True, div=False),
    PadTo2Power(axes=(1, 2), k=5, mode='constant'),
    AssertWidthMajor(),
]
x_transform = torchvision.transforms.Compose(test_steps)

# Only an x list is given, so no y_transform is passed.
ds_test = SimpleDataset("bsds500/x_test.txt", x_transform=x_transform)
dl_test = DataLoader(dataset=ds_test, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# Take the first test batch and cast it to the model's precision: float32 when
# CUDA is available, float64 otherwise (the same convention as the training
# loop). An unconditional .double() mismatches the float32 GPU model.
# The tensor deliberately stays on the CPU: DataParallel scatters it to the GPU,
# and the plotting cell calls x[0].numpy() directly.
x = next(iter(dl_test))
x = x.float() if torch.cuda.is_available() else x.double()

In [None]:
# Inference only: disabling autograd avoids building a graph that is never used
# (the result is only converted to numpy for plotting).
with torch.no_grad():
    res = model(x)

In [None]:
import matplotlib.pyplot as plt
# Show channel 0 of the first input image, then output channel 1 of the first
# prediction, each in its own 15x10 figure.
input_channel = x[0].numpy()[0]
prediction_channel = res.cpu().detach().numpy()[0, 1]
for image in (input_channel, prediction_channel):
    plt.figure(figsize=(15, 10))
    plt.imshow(image)
    plt.show()
# plt.figure(figsize=(15,10))
# plt.imshow(y[0])
# plt.show()

In [None]:
# Handle on the underlying nn.Sequential. Unwrap only when the model is actually
# wrapped in DataParallel (the same check the checkpoint-saving code uses), so
# this cell also works on an unwrapped model.
m = model.module if isinstance(model, nn.DataParallel) else model

In [None]:
# Set return_sequences=True on the GammaNet (index 2 of the Sequential) for the
# manual forward pass in the following cells.
# NOTE(review): this mutates the live model's config in place, so the flag stays
# set for any later call (including `model(x)`) until it is reset to False;
# presumably the output format changes with it — confirm against GammaNet.
m[2].config['return_sequences'] = True

In [None]:
# Run the first three stages by hand (two stem convolutions, then GammaNet).
# `m` is the unwrapped module, so nothing moves `x` to the model's device here:
# align the input with the parameters' device and dtype explicitly, otherwise
# a GPU model fails on the CPU tensor.
ref_param = next(m.parameters())
with torch.no_grad():  # inspection only, no gradients needed
    out = x.to(device=ref_param.device, dtype=ref_param.dtype)
    for stage in (m[0], m[1], m[2]):
        out = stage(out)

In [None]:
import matplotlib.pyplot as plt
# First input image, axes reordered from (0, 1, 2) to (1, 2, 0) for imshow.
input_image = x[0].cpu().detach().numpy().transpose((1, 2, 0))
plt.imshow(input_image)
plt.show()

# Pass every element of out[1] through the final convolution (m[3]) and show
# output channel 1 of the first sample.
# NOTE(review): out[1] is presumably the per-timestep sequence — confirm.
for step_features in out[1]:
    step_output = m[3](step_features)
    plt.imshow(step_output[0, 1].cpu().detach().numpy())
    plt.show()

In [None]:
# Inhibition kernels of the first down-block fGRU: one figure per index along
# axis 0, one subplot per index along axis 1 (4x6 grid).
inh_kernels = model.module.state_dict()['2.fgru_down.0.params.w_inh'].cpu().detach().numpy()
for row_idx, kernel_row in enumerate(inh_kernels):
    print(row_idx)
    plt.figure(figsize=(15, 10))
    for col_idx, kernel in enumerate(kernel_row):
        plt.subplot(4, 6, col_idx + 1)
        plt.imshow(kernel)
    plt.show()

In [None]:
# Excitation kernels of the first down-block fGRU: one figure per index along
# axis 0, one subplot per index along axis 1 (4x6 grid).
exc_kernels = model.module.state_dict()['2.fgru_down.0.params.w_exc'].cpu().detach().numpy()
for row_idx, kernel_row in enumerate(exc_kernels):
    print(row_idx)
    plt.figure(figsize=(15, 10))
    for col_idx, kernel in enumerate(kernel_row):
        plt.subplot(4, 6, col_idx + 1)
        plt.imshow(kernel)
    plt.show()

In [None]:
# Inhibition kernels of the second down-block fGRU: the [0, i] slices for every
# index i along axis 1, laid out on a 4x7 grid.
k = model.module.state_dict()['2.fgru_down.1.params.w_inh']
kernels = k.cpu().detach().numpy()  # convert once instead of once per subplot
plt.figure(figsize=(15,10))
# The loop indexes axis 1, so it must be bounded by shape[1] (not shape[0]).
for i in range(kernels.shape[1]):
    plt.subplot(4,7,i+1)
    plt.imshow(kernels[0,i])
plt.show()

In [None]:
# Excitation kernels of the second down-block fGRU: the [0, i] slices for every
# index i along axis 1, laid out on a 4x7 grid.
k = model.module.state_dict()['2.fgru_down.1.params.w_exc']
kernels = k.cpu().detach().numpy()  # convert once instead of once per subplot
plt.figure(figsize=(15,10))
# The loop indexes axis 1, so it must be bounded by shape[1] (not shape[0]).
for i in range(kernels.shape[1]):
    plt.subplot(4,7,i+1)
    plt.imshow(kernels[0,i])
plt.show()

In [None]:
# Sanity check on one training batch: tensor shapes, then the value range of
# the inputs and of the labels.
first_batch = next(iter(dl_train))
x, y = first_batch
print(x.shape, y.shape)
value_extremes = (x.max(), x.min(), y.max(), y.min())
print(*value_extremes)