# This notebook contains the code for training the CRF

In [1]:
# Render matplotlib figures inline in the notebook output area.
%matplotlib inline

# Load the autoreload extension and reload imported modules before each cell
# executes (mode 2), so edits to local packages are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import timeit
import warnings
import logging as logger
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn

from convcrf.convcrf import GaussCRF, default_conf
from utils.synthetic import augment_label

# Route log records to stdout (so they appear in the cell output) with a
# timestamp / level / message layout, at INFO verbosity.
logger.basicConfig(
    stream=sys.stdout,
    level=logger.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)

# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from torch/numpy.
warnings.filterwarnings('ignore')

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device_message = 'Device is {}'.format(device)
logger.info(device_message)

2018-11-06 05:32:17,553 INFO Device is cuda


## Load Data

In [3]:
import os

from torch.utils.data import DataLoader, WeightedRandomSampler
from utils.pascal_loader import PascalDatasetLoader

# Root directory of the Pascal VOC 2012 dataset. Set the PASCAL_VOC_ROOT
# environment variable to point somewhere else; when it is unset the original
# location is used, so existing setups keep working unchanged.
path = os.environ.get(
    'PASCAL_VOC_ROOT',
    '/home/jupyter/projects/ConvCRF/datasets/pascal/VOCdevkit/VOC2012')
pascal_dataset = PascalDatasetLoader(path)
num_classes = pascal_dataset.num_classes

# Disabled experiment: draw a fixed-size uniform subset through a sampler
# instead of shuffling the whole dataset. Kept for reference only.
# num_samples = 200
# weights = (1/num_samples)* np.ones(num_samples)

# sampler = WeightedRandomSampler(weights, num_samples, False)

# No batch_size is passed, so the DataLoader default applies; the training
# loop below indexes `labels[0]` to take the single label map of each batch.
data_loader = DataLoader(pascal_dataset, num_workers=8, shuffle=True)
len(data_loader)

1449

## Define the model

In [4]:
# Pull one batch to inspect the spatial size of a sample. Use the built-in
# next() on the iterator: the `.next()` method is a Python 2 spelling that
# DataLoader iterators no longer provide in current PyTorch releases.
image, _ = next(iter(data_loader))
shape = image.shape[2:4]  # informational only; not used to build the model

config = default_conf
# The CRF is constructed for a fixed (500, 375) spatial size rather than the
# size of the sampled batch; the training loop only feeds images of that size.
model = GaussCRF(conf=config, shape=(500, 375), nclasses=num_classes)
model.to(device)
model.shape

(500, 375)

## Define the loss function and optimizer

In [5]:
import torch.optim as optim

# Per-pixel classification loss over the class scores produced by the model.
criterion = nn.CrossEntropyLoss()

# Adam over every trainable parameter of the model, with a small step size.
optimizer = optim.Adam(params=model.parameters(), lr=5e-5)

## Train the network

In [None]:
# Train the CRF: for every usable sample, build unary scores from the label
# map, run the model on (unary, image), and minimise the per-pixel
# cross-entropy against the same label map.
NUM_EPOCHS = 10   # full passes over the data loader
LOG_EVERY = 20    # report the mean loss once per this many used samples

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0

    # Counts only the samples that pass the size filter below, so the
    # progress number printed is "samples trained on", not "samples seen".
    counter = 0
    for images, labels in data_loader:
        shape = images.shape[2:4]

        # The model works on one fixed spatial size; skip every sample whose
        # size differs from the size the model was constructed with.
        if tuple(shape) != tuple(model.shape):
            continue

        counter += 1
        images = images.to(device)

        optimizer.zero_grad()

        # Take the first (only) label map of the batch.
        labels = labels[0]

        # NOTE(review): augment_label presumably turns the ground-truth label
        # map into synthetic (H, W, num_classes) unary scores -- confirm in
        # utils.synthetic. The result is handled as a NumPy array here.
        unary = augment_label(labels, num_classes=num_classes)
        height, width = unary.shape[0], unary.shape[1]
        # Move the class axis first and add a leading batch axis of size 1.
        unary = unary.transpose(2, 0, 1).reshape([1, num_classes, height, width])

        unary = torch.from_numpy(unary).float().to(device)

        outputs = model(unary=unary, img=images)

        labels = labels.to(device)

        # Move the class axis last and flatten to one row of class scores per
        # pixel. Use num_classes (not a literal 21) so the loop stays
        # consistent with the dataset and the model.
        outputs = outputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, num_classes)
        labels = labels.view(-1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if counter % LOG_EVERY == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, counter, running_loss / LOG_EVERY))
            running_loss = 0.0
        

[1,    20] loss: 1.961
[1,    40] loss: 1.286
[1,    60] loss: 1.256
[1,    80] loss: 1.456
[1,   100] loss: 1.817
