# This notebook contains the code for training the CRF

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Load the autoreload extension and use mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [19]:
import sys
import timeit
import warnings
import logging as logger
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn

from convcrf.convcrf import GaussCRF, default_conf
from utils.synthetic import augment_label
from utils.metrics import Metrics, Averages
from demo import do_crf_inference

# Log records go to stdout (so they show up as cell output), at INFO level and
# above, each prefixed with a timestamp and the level name.
logger.basicConfig(
    stream=sys.stdout,
    level=logger.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)

# Silence every Python warning for the rest of the session.
warnings.filterwarnings(action='ignore')

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logger.info('Device is {}'.format(device))

2018-11-11 11:11:50,046 INFO Device is cpu


## Load Data

In [8]:
import os

from torch.utils.data import DataLoader, WeightedRandomSampler
from utils.pascal_loader import PascalDatasetLoader

# Root directory of the PASCAL VOC 2012 devkit. Set the PASCAL_VOC_ROOT
# environment variable to run the notebook on another machine; when it is not
# set, the original author's local path is used so existing runs are unchanged.
path = os.environ.get(
    'PASCAL_VOC_ROOT',
    '/home/jimiolaniyan/Documents/Research/DataSets/PASCAL/VOCdevkit/VOC2012',
)

# One dataset per split: 'traincrf' is used to fit the CRF, 'val' to evaluate it.
traincrf_dataset = PascalDatasetLoader(path, split='traincrf')
val_dataset = PascalDatasetLoader(path, split='val')

# The class count is taken from the dataset rather than hard-coded.
num_classes = traincrf_dataset.num_classes

# No batch_size is given, so both loaders use DataLoader's default.
# Only the training loader shuffles its samples.
traincrf_loader = DataLoader(traincrf_dataset, num_workers=8, shuffle=True)
val_loader = DataLoader(val_dataset, num_workers=8)

# Displayed as the cell output: number of validation batches.
len(val_loader)

1449

## Define the model

In [None]:
# Build the CRF from the library's default configuration, passing (500, 500)
# as its spatial shape and one class channel per dataset class, then move it
# to the selected device.
config = default_conf
model = GaussCRF(nclasses=num_classes, shape=(500, 500), conf=config).to(device)

# Displayed as the cell output.
model.shape

## Define the loss function and optimizer

In [5]:
import torch.optim as optim

# Cross-entropy between predicted class scores and target class indices.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 5e-5.
optimizer = optim.Adam(params=model.parameters(), lr=5e-5)

## Train the network

In [18]:
# Bookkeeping for the training run.
# Starts below any value an IoU score can take.
best_iou = -100.0

# Running averages for timing and validation loss.
time_avg = Averages()
val_loss_avg = Averages()

# Segmentation metrics accumulated over the dataset's classes.
running_metrics = Metrics(num_classes)

In [None]:
# Train the CRF for a fixed number of epochs and validate after each epoch.
#
# For every sample the ground-truth label map is turned into a unary input by
# `augment_label`, the CRF is run on (unary, image), and the per-pixel
# cross-entropy against the ground truth is minimised.
num_epochs = 10
log_every = 20  # report the mean training loss every `log_every` batches


def _make_unary(label):
    """Build the CRF's unary input tensor from a single label map.

    The array returned by `augment_label` is treated as (H, W, num_classes):
    it is moved to channel-first layout, given a leading batch axis of size 1,
    converted to a float tensor and placed on `device`.
    """
    unary = augment_label(label, num_classes=num_classes)
    height, width = unary.shape[0], unary.shape[1]
    unary = unary.transpose(2, 0, 1).reshape([1, num_classes, height, width])
    return torch.from_numpy(unary).float().to(device)


def _flatten_scores(scores):
    """Reshape [N, C, H, W] class scores into [N*H*W, C] rows for the loss."""
    return scores.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)


for epoch in range(num_epochs):
    # Validation below switches to eval mode, so re-enable training mode
    # at the start of every epoch.
    model.train()
    running_loss = 0.0

    for i, (images, labels) in enumerate(traincrf_loader):
        images = images.to(device)
        labels = labels[0]  # remove batch dimension

        optimizer.zero_grad()

        outputs = model(unary=_make_unary(labels), img=images)

        loss = criterion(_flatten_scores(outputs), labels.to(device).view(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report after every full window of `log_every` batches so the
        # printed value really is the mean over that window.
        if (i + 1) % log_every == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0

    # ---- validation: evaluate the model being trained on the 'val' split ----
    model.eval()
    val_loss_total = 0.0
    val_batches = 0
    with torch.no_grad():
        for images_val, labels_val in val_loader:
            images_val = images_val.to(device)
            labels_val = labels_val[0]  # remove batch dimension

            outputs_val = model(unary=_make_unary(labels_val), img=images_val)

            val_loss = criterion(_flatten_scores(outputs_val),
                                 labels_val.to(device).view(-1))
            val_loss_total += val_loss.item()
            val_batches += 1

            # Most likely class per pixel, kept with a batch axis: [1, H, W].
            pred = outputs_val.argmax(dim=1).cpu().numpy()
            gt = labels_val.unsqueeze(0).cpu().numpy()

            # NOTE(review): assumes Metrics.update(ground_truth, prediction)
            # accepts batched numpy arrays -- confirm against utils.metrics.
            running_metrics.update(gt, pred)

    if val_batches:
        print('[%d] validation loss: %.3f' %
              (epoch + 1, val_loss_total / val_batches))

    # TODO: once the Metrics API is confirmed, read the scores here, reset
    # `running_metrics` between epochs and track `best_iou`.
                    
        

[1,    20] loss: 1.961
[1,    40] loss: 1.286
[1,    60] loss: 1.256
[1,    80] loss: 1.456
[1,   100] loss: 1.817
