In [1]:
# Imports

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.transforms as tr
from model.change_detction_dataset import ChangeDetectionDataset, RandomFlip, RandomRot

# Models
from model.unet import Unet
from model.siamunet_conc import SiamUnet_conc
from model.siamunet_diff import SiamUnet_diff
from model.fresunet import FresUNet
from model.model import Model, ModelConfig

# Other
import numpy as np
from skimage import io

%matplotlib inline
from tqdm import tqdm as tqdm
import time
import warnings
from pprint import pprint


In [2]:
# Global variables
PATH_TO_DATASET = './dataset-lite/'
MODEL_TYPE = 1  # 0-FC-EF | 1-FC-Siam-diff | 2-FC-Siam-conc | 3-FresUNet
GPU_ENABLED = torch.cuda.is_available()
LOAD_TRAINED = False  # True: load 'net_final.pth.tar' instead of training


# Configuration
DATA_AUG = True  # apply RandomFlip/RandomRot transforms to the training set
BATCH_SIZE = 32
PATCH_SIDE = 96  # presumably the training patch side in pixels; here it only sets TRAIN_STRIDE
N_EPOCHS = 50
NORMALISE_IMGS = True  # NOTE(review): not referenced anywhere in this notebook — confirm it is used
# Just under half a patch (47 for PATCH_SIDE=96); presumably the step between
# extracted patches, giving >50% overlap between neighbours — TODO confirm in dataset.
TRAIN_STRIDE = PATCH_SIDE // 2 - 1

In [3]:
# Dataset
# Augmentation (random flips / rotations) is only attached to the training set;
# the test set below is built without a transform.
if DATA_AUG:
    data_transform = tr.Compose([RandomFlip(), RandomRot()])
else:
    data_transform = None


train_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train=True, stride=TRAIN_STRIDE, transform=data_transform)
# Per-class weights from the dataset; passed as `weight=` to the NLL loss criterion.
weights = torch.FloatTensor(train_dataset.weights)
if GPU_ENABLED:
    print('GPU ENABLED')
    weights = weights.cuda()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train=False, stride=TRAIN_STRIDE)
# Evaluation needs no shuffling; a fixed order keeps test passes reproducible.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)



4999it [00:35, 140.84it/s]


GPU ENABLED


499it [00:03, 136.40it/s]


In [4]:
# Build the network selected by MODEL_TYPE:
# 0-FC-EF | 1-FC-Siam-diff | 2-FC-Siam-conc | 3-FresUNet
# FC-EF and FresUNet are built with 2*3 input channels, the Siamese variants
# with 3 (presumably one RGB image per branch — TODO confirm in model code).
if MODEL_TYPE == 0:
    net, net_name = Unet(2*3, 2), 'FC-EF'
elif MODEL_TYPE == 1:
    net, net_name = SiamUnet_diff(3, 2), 'FC-Siam-diff'
elif MODEL_TYPE == 2:
    net, net_name = SiamUnet_conc(3, 2), 'FC-Siam-conc'
elif MODEL_TYPE == 3:
    net, net_name = FresUNet(2*3, 2), 'FresUNet'
else:
    # Fail here instead of with a NameError later, or — worse — silently
    # reusing a `net` left over from a previous run of this cell.
    raise ValueError(f'Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 0, 1, 2 or 3')

if GPU_ENABLED:
    net.cuda()

criterion = nn.NLLLoss(weight=weights)  # expects log-probabilities (log-softmax output)

In [5]:
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are excluded.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

n_trainable = count_parameters(net)
print('Number of trainable parameters:', n_trainable)

Number of trainable parameters: 1350146


In [6]:
# Wrap the network, its loss and the data in the project's Model helper,
# which drives training and evaluation below.
model_config = ModelConfig(gpu_enabled=GPU_ENABLED, n_epochs=N_EPOCHS)
model = Model(
    model=net,
    model_name=net_name,
    criterion=criterion,
    config=model_config,
    train_dataset=train_dataset,
    train_loader=train_loader,
    test_dataset=test_dataset,
)

In [None]:
# Either restore previously saved weights or train from scratch.
if LOAD_TRAINED:
    # map_location='cpu' lets a checkpoint written from a GPU run load on a
    # CPU-only kernel; load_state_dict then copies each tensor onto the device
    # of the corresponding parameter of `net`.
    net.load_state_dict(torch.load('net_final.pth.tar', map_location='cpu'))
    print('LOAD OK')
else:
    # perf_counter is monotonic, so it is the right clock for elapsed time.
    t_start = time.perf_counter()
    out_dic = model.train()
    t_end = time.perf_counter()
    print(out_dic)
    print('Elapsed time:')
    print(t_end - t_start)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 1 of 50


  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 2 of 50


  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 3 of 50


  0%|          | 0/2500 [00:00<?, ?it/s]

In [None]:
# Persist the weights only when this session produced them by training.
if not LOAD_TRAINED:
    final_state = model.model.state_dict()
    torch.save(final_state, 'net_final.pth.tar')
    print('SAVE OK')

In [None]:
# Inference on test images
def save_test_results(dset):
    """Run ``net`` on every image pair in ``dset`` and save one overlay image per pair.

    Each saved image stacks ``255*cm`` in channels 0 and 2 and the predicted
    change map (argmax over dim 1 of the network output) in channel 1. It is
    written to the working directory as ``f'{net_name}-{name}'``.

    Relies on the notebook globals ``net``, ``net_name`` and ``GPU_ENABLED``.

    Parameters
    ----------
    dset : ChangeDetectionDataset
        Must expose ``names`` (iterated with ``iterrows``; the first field of
        each row is the image name) and ``get_img(name) -> (I1, I2, cm)``.
    """
    was_training = net.training
    # After training, the network is presumably left in train mode. Inference
    # must use eval mode (BatchNorm running stats, no Dropout) and needs no
    # autograd graph.
    net.eval()
    try:
        with torch.no_grad():
            for _, row in tqdm(dset.names.iterrows()):
                # Positional access; integer-key fallback on Series[...] is deprecated.
                name = row.iloc[0]
                # NOTE(review): no filter is installed, so this block does not
                # silence any warnings; it only scopes filter changes.
                with warnings.catch_warnings():
                    I1, I2, cm = dset.get_img(name)
                    I1 = torch.unsqueeze(I1, 0).float()
                    I2 = torch.unsqueeze(I2, 0).float()

                    if GPU_ENABLED:
                        I1 = I1.cuda()
                        I2 = I2.cuda()

                    out = net(I1, I2)
                    _, predicted = torch.max(out, 1)
                    pred_map = np.squeeze(predicted.cpu().numpy())
                    overlay = np.stack((255 * cm, 255 * pred_map, 255 * cm), 2).astype(np.uint8)
                    io.imsave(f'{net_name}-{name}', overlay, check_contrast=False)
    finally:
        # Restore whatever mode the network was in before this call.
        net.train(was_training)

t_start = time.perf_counter()
save_test_results(test_dataset)
t_end = time.perf_counter()
print('Elapsed time: {}'.format(t_end - t_start))


In [None]:
# Score the trained network on both splits (train first, then test).
results_train, results_test = (model.evaluate(split) for split in (train_dataset, test_dataset))

In [None]:
# Report the metrics returned by Model.evaluate for each split.
for split_label, split_results in (('Train', results_train), ('Test', results_test)):
    print(f'{split_label} Performance:')
    print(split_results)