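# Training the ConvCRF

This notebook fine-tunes a ConvCRF (`GaussCRF`) on PASCAL VOC 2012. It resumes from a saved checkpoint, validates periodically, and overwrites the checkpoint whenever validation mean IoU improves.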

In [1]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline

# Automatically reload edited project modules (e.g. convcrf, utils) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import timeit
import time
import warnings
import logging as logger
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from dotmap import DotMap

import torch
import torch.nn as nn

from convcrf.convcrf import GaussCRF, default_conf
from utils.synthetic import augment_label
from utils.metrics import Metrics, Averages
from demo import do_crf_inference

# `logger` is the stdlib `logging` module (imported under that alias), so this configures the root logger.
logger.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logger.INFO,
                    stream=sys.stdout)

# NOTE(review): this silences *all* warnings, including PyTorch deprecation notices; consider narrowing it.
warnings.filterwarnings('ignore')

# Seed the RNGs so runs are repeatable. torch drives DataLoader shuffling; numpy presumably
# drives augment_label's label perturbation (TODO confirm).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info('Device is {}'.format(device))

2018-11-15 12:54:11,686 INFO Device is cuda


## Load Data

In [3]:
from torch.utils.data import DataLoader, WeightedRandomSampler
from utils.pascal_loader import PascalDatasetLoader

# Project root. Set the CONVCRF_ROOT environment variable to relocate it instead of editing this cell.
PROJECT_ROOT = os.environ.get('CONVCRF_ROOT', '/home/jupyter/projects/ConvCRF')
path = os.path.join(PROJECT_ROOT, 'datasets', 'pascal', 'VOCdevkit', 'VOC2012')
traincrf_dataset = PascalDatasetLoader(path, split='traincrf')
val_dataset = PascalDatasetLoader(path, split='val')

num_classes = traincrf_dataset.num_classes

NUM_WORKERS = 8
# batch_size=1 (the DataLoader default, made explicit) is required: the training and
# validation loops strip the batch dimension with labels[0].
traincrf_loader = DataLoader(traincrf_dataset, batch_size=1, num_workers=NUM_WORKERS, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=NUM_WORKERS)

## Load stored model parameters

In [4]:
# Best-model checkpoint. The training loop below overwrites this same file whenever validation IoU improves.
# Set CONVCRF_ROOT to relocate the project instead of editing this path.
save_path = os.path.join(os.environ.get('CONVCRF_ROOT', '/home/jupyter/projects/ConvCRF'),
                         'datasets', 'best_model.pkl')
# map_location lets a checkpoint saved from CUDA tensors load on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects, so only load checkpoint files you trust.
saved_state = torch.load(save_path, map_location=device)

## Define the model

In [5]:
# Instantiate the CRF with the library's default configuration, restore the
# checkpointed weights, then move the module to the training device.
config = default_conf
crf_shape = (500, 500)  # NOTE(review): presumably the loader's padded image size — confirm.
model = GaussCRF(conf=config, shape=crf_shape, nclasses=num_classes)
model.load_state_dict(saved_state['model_state'])
model.to(device)  # last expression, so the cell displays the module repr

GaussCRF(
  (CRF): ConvCRF()
)

## Define the loss function and optimizer

In [6]:
import torch.optim as optim

# Per-pixel cross-entropy on the flattened CRF outputs, optimised with Adam.
# NOTE(review): checkpoints written by the training loop include 'optimizer_state', but it is
# not restored here, so Adam's moment estimates start fresh on every resume.
LEARNING_RATE = 5e-5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Train the network

In [7]:
# Option bag of CLI-style flags. No visible cell reads it (do_crf_inference is imported
# but not called here). NOTE(review): confirm it is still needed.
args = DotMap()
args.pyinn, args.nospeed, args.output = False, False, None

# Accumulators used by the training loop: segmentation metrics plus running means of
# training loss, validation loss and per-iteration wall time.
running_metrics = Metrics(num_classes)
train_loss_avg, val_loss_avg, time_avg = (Averages() for _ in range(3))

# Start from the checkpoint's best validation IoU so only real improvements overwrite it.
best_iou = saved_state['best_iou']
logger.info('Starting from iou: {}'.format(best_iou))

num_epochs = 50  # number of passes over traincrf_loader

2018-11-15 12:56:57,031 INFO Starting from iou: 0.8890707921190808


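## Learning-rate decay (currently disabled)

A polynomial `LambdaLR` schedule is sketched below but commented out, so training runs at the constant learning rate set on the optimizer above.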

In [12]:
# lambda_lr_decay = lambda epoch: ((1 - (epoch / num_epochs)) ** 0.9) ** 2
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr_decay)

In [None]:
# Training-loop settings (previously hard-coded inline).
LOG_EVERY_N_ITERS = 20          # print averaged training loss/time every N iterations
VAL_EVERY_N_EPOCHS = 10         # validate (and possibly checkpoint) every N epochs
VAL_LOG_EVERY_N_ITERS = 200     # print averaged validation loss every N validation images
MEAN_IOU_KEY = "Mean IoU : \t"  # key of the mean IoU in running_metrics.get_scores()


def _labels_to_unary(label_map):
    """Build the CRF unary input for one label map.

    ``augment_label`` returns an (H, W, num_classes) numpy array (as implied by the
    3-axis transpose and reshape below). It is converted to a float tensor of shape
    (1, num_classes, H, W) on ``device``.
    """
    unary_np = augment_label(label_map, num_classes=num_classes)
    height, width = unary_np.shape[0], unary_np.shape[1]
    unary_np = unary_np.transpose(2, 0, 1).reshape([1, num_classes, height, width])
    return torch.from_numpy(unary_np).float().to(device)


def _flatten_logits(logits):
    """Reshape (N, C, H, W) scores to (N*H*W, C) rows, the layout CrossEntropyLoss expects.

    The class count comes from ``logits`` itself instead of a hard-coded 21.
    """
    return logits.permute(0, 2, 3, 1).reshape(-1, logits.shape[1])


def _validate():
    """Evaluate ``model`` on ``val_loader`` and return ``(mean_val_loss, score)``.

    Leaves the model in eval mode; the caller switches back with ``model.train()``.
    ``running_metrics`` and ``val_loss_avg`` are reset before returning so the next
    validation pass starts from a clean state.
    """
    model.eval()
    val_len = len(val_loader)
    loss_sum, n_batches = 0.0, 0
    with torch.no_grad():
        for iter_val, (images_val, labels_val) in enumerate(val_loader, start=1):
            labels_val = labels_val[0]  # remove batch dimension (batch_size=1)
            unary = _labels_to_unary(labels_val)
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            predictions = model(unary=unary, img=images_val)
            val_loss = criterion(_flatten_logits(predictions), labels_val.view(-1)).item()

            # Arg-max over the class axis; [0] drops the batch dimension.
            preds_np = predictions.max(1)[1].cpu().numpy()[0]
            running_metrics.update(labels_val.cpu().numpy(), preds_np)

            # The windowed average drives the periodic print; the sums give the whole-pass mean.
            val_loss_avg.update(val_loss)
            loss_sum += val_loss
            n_batches += 1
            if iter_val % VAL_LOG_EVERY_N_ITERS == 0:
                print("{}/{} Loss: {}: ".format(iter_val, val_len, val_loss_avg.avg))
                val_loss_avg.reset()

    score, _class_iou = running_metrics.get_scores()
    running_metrics.reset()
    val_loss_avg.reset()
    return loss_sum / max(n_batches, 1), score


for epoch in range(num_epochs):
    actual_epoch = epoch + 1
    # If the commented-out LambdaLR scheduler is re-enabled, call scheduler.step() once per
    # epoch *after* the optimizer steps (PyTorch >= 1.1 ordering).
    model.train()

    for iteration, (images, labels) in enumerate(traincrf_loader, start=1):
        start_ts = time.time()

        label_map = labels[0]  # remove batch dimension (batch_size=1)
        unary = _labels_to_unary(label_map)
        images = images.to(device)
        targets = label_map.to(device).view(-1)

        optimizer.zero_grad()
        outputs = model(unary=unary, img=images)
        loss = criterion(_flatten_logits(outputs), targets)
        loss.backward()
        optimizer.step()

        train_loss_avg.update(loss.item())
        time_avg.update(time.time() - start_ts)

        if iteration % LOG_EVERY_N_ITERS == 0:
            print('[{:d}, {:d}] Average loss: {:.4f} Average Time: {:.4f} Learning rate: {}'
                  .format(actual_epoch, iteration, train_loss_avg.avg, time_avg.avg,
                          optimizer.param_groups[0]['lr']))
            train_loss_avg.reset()
            time_avg.reset()

    # Validate once the whole epoch has run. Validation then happens however many batches
    # traincrf_loader yields, and no training step ever runs in eval mode.
    if actual_epoch % VAL_EVERY_N_EPOCHS == 0:
        logger.info('Doing validation at Epoch: {}'.format(actual_epoch))
        epoch_val_loss, score = _validate()
        logger.info("Epoch %d Loss: %.4f" % (actual_epoch, epoch_val_loss))

        print('\nEpoch: {} Validation Summary'.format(actual_epoch))
        for k, v in score.items():
            print(k, v)

        # Checkpoint only when mean IoU matches or beats the best seen so far.
        if score[MEAN_IOU_KEY] >= best_iou:
            best_iou = score[MEAN_IOU_KEY]
            logger.info('Found new best_iou: {}'.format(best_iou))
            state = {
                "epoch": actual_epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }
            logger.info(save_path)
            torch.save(state, save_path)

[1, 20] Average loss: 0.7369 Average Time: 1.6061 Learning rate: 5e-05
[1, 40] Average loss: 0.8456 Average Time: 1.5885 Learning rate: 5e-05
[1, 60] Average loss: 0.6376 Average Time: 1.5626 Learning rate: 5e-05
[1, 80] Average loss: 0.6051 Average Time: 1.5647 Learning rate: 5e-05
[1, 100] Average loss: 0.6747 Average Time: 1.6116 Learning rate: 5e-05
[1, 120] Average loss: 0.6719 Average Time: 1.5565 Learning rate: 5e-05
[1, 140] Average loss: 0.5067 Average Time: 1.5628 Learning rate: 5e-05
[1, 160] Average loss: 0.7389 Average Time: 1.5850 Learning rate: 5e-05
[1, 180] Average loss: 0.4766 Average Time: 1.5571 Learning rate: 5e-05
[1, 200] Average loss: 0.5366 Average Time: 1.5416 Learning rate: 5e-05
[2, 20] Average loss: 0.8146 Average Time: 1.5246 Learning rate: 5e-05
[2, 40] Average loss: 0.5297 Average Time: 1.5575 Learning rate: 5e-05
[2, 60] Average loss: 0.5168 Average Time: 1.5311 Learning rate: 5e-05
[2, 80] Average loss: 0.5479 Average Time: 1.5591 Learning rate: 5e-05
