In [None]:
import torch
import torch.nn as nn
import torchvision

import numpy as np
import matplotlib.pyplot as plt

# BSDS500

In [None]:
# Prepare the dataset.
# NOTE: remember to adjust PadTo2Power's k value according to the gamma-net depth.
# NOTE: data stays on the CPU until right before the forward pass (so memory pinning, etc. is not needed).

from dataset import *
from transforms import *
from torch.utils.data import DataLoader

# Input-image pipeline. ToTensor / PadTo2Power / AssertWidthMajor come from the
# wildcard imports above; their exact semantics are not visible in this notebook.
x_transform = torchvision.transforms.Compose(transforms=[
    ToTensor(make_channel_first=True, div=True),
    # k has to be adjusted to the GammaNet depth; presumably the padded size is
    # tied to 2**k -- TODO confirm in transforms.py.
    PadTo2Power(axes=(1, 2), k=5, mode='constant'),
    AssertWidthMajor(),
])

# Label pipeline: same steps as the image pipeline, but the tensor is not made
# channel-first, float_out is disabled, and the padding is applied to axes 0 and 1.
y_transform = torchvision.transforms.Compose(transforms=[
    ToTensor(make_channel_first=False, float_out=False),
    PadTo2Power(axes=(0, 1), k=5, mode='constant'),
    AssertWidthMajor(),
])

# Training set built from two list files (inputs, then targets) --
# presumably one sample path per line; verify against dataset.py.
ds_train = SimpleDataset(
    "bsds500/x_train.txt",
    "bsds500/y_c_train.txt",
    x_transform=x_transform,
    y_transform=y_transform,
)

# One sample per batch, reshuffled every epoch, loaded by 8 worker processes.
dl_train = DataLoader(
    dataset=ds_train,
    batch_size=1,
    shuffle=True,
    num_workers=8,
)

In [None]:
from gammanet import GammaNet

# Hyper-parameters handed to GammaNet as a single dict. The meaning of each key
# is defined by the project's gammanet module and is not visible here.
gammanet_config = {
    # input
    'in_channels': 24,
    'return_sequences': False,
    # convolutional blocks (five entries per list)
    'num_filters': [24, 28, 36, 48, 64],
    'conv_kernel_size': [3, 3, 3, 3, 3],
    'conv_blocksize': [1, 1, 1, 1, 1],
    'conv_normtype': 'instancenorm',
    'conv_dropout_p': 0.2,
    'conv_residual': False,
    # fGRU units; note the kernel-size list has nine entries, not five --
    # presumably one per recurrent unit across both passes (TODO confirm).
    'fgru_hidden_size': [24, 28, 36, 48, 64],
    'fgru_kernel_size': [9, 7, 5, 3, 1, 1, 1, 1, 1],
    'fgru_timesteps': 2,
    'fgru_normtype': 'instancenorm',
    'fgru_channel_sym': True,
    # upsampling
    'upsample_mode': 'bilinear',
    'upsample_all2all': True,
}

# Network: two 3x3 convolutions lift the 3-channel input to 24 channels, GammaNet
# processes the 24-channel features, and a 5x5 convolution emits 2 output channels.
# The whole stack is wrapped in DataParallel.
model = nn.DataParallel(
    nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=24, kernel_size=3, padding=1),
        nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, padding=1),
        GammaNet(gammanet_config),
        nn.Conv2d(in_channels=24, out_channels=2, kernel_size=5, padding=2),
    )
)

# With CUDA the model moves to the GPU; without it the parameters are cast to
# float64 instead (the training loop casts CPU inputs to double as well).
model = model.cuda() if torch.cuda.is_available() else model.double()

In [None]:
import torch.optim as optim

# Loss: cross-entropy between the model outputs and the labels, moved to the
# GPU when CUDA is available.
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    criterion = criterion.cuda()

# Optimizer: Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
import time
# Training loop: runs `epochs` passes over dl_train, taking one optimizer step
# per batch and printing a single self-overwriting progress line with the
# estimated time remaining.
start_time = time.time()

epochs = 2
use_cuda = torch.cuda.is_available()   # loop-invariant, so queried once
steps_per_epoch = len(dl_train)
total_steps = epochs * steps_per_epoch

for epoch in range(epochs):  # loop over the dataset multiple times

    # Wall-clock start of this epoch (recorded only; not used inside the loop).
    start_time_epoch = time.time()

    for i, data in enumerate(dl_train):
        # data is a list of [inputs, labels]. Batches stay on the CPU until
        # here; on a CPU-only run the inputs are cast to float64 to match the
        # model, which was converted with .double().
        inputs = data[0].cuda() if use_cuda else data[0].double()
        labels = data[1].cuda() if use_cuda else data[1]

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # Call the module itself rather than .forward() so that any registered
        # forward hooks run as well.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        steps_done = i + 1 + epoch * steps_per_epoch
        progress = steps_done / total_steps
        time_elapsed = time.time() - start_time
        # elapsed / progress is the projected *total* run time; subtracting the
        # time already spent gives the time still remaining, which is what the
        # "ETA" label promises.
        time_to_completion = time_elapsed / progress - time_elapsed
        # Formatted by hand instead of time.gmtime/strftime, whose hour field
        # wraps around after 24 hours.
        eta_hours, eta_rest = divmod(int(time_to_completion), 3600)
        eta_minutes, eta_seconds = divmod(eta_rest, 60)
        print("{}, {}/{}, ETA: {:02d}:{:02d}:{:02d}".format(
                  epoch+1, i+1, steps_per_epoch,
                  eta_hours, eta_minutes, eta_seconds),
              end='\r')

print('Finished Training')

In [None]:
# Write the trained weights to disk. `model` is the nn.DataParallel wrapper, so
# the state dict saved here is the wrapper's (its keys are expected to carry a
# "module." prefix -- keep that in mind when loading into an unwrapped model).
torch.save(obj=model.state_dict(), f="model_e2.pkl")