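# Training a ConvCRF (GaussCRF) on PASCAL VOC 2012

This notebook trains the CRF on the `traincrf` split of PASCAL VOC 2012, using synthetic unaries generated from the ground-truth labels (`augment_label`). After every epoch it validates on the `val` split and checkpoints the model with the best validation mean IoU.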

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Automatically reload imported modules (e.g. convcrf, utils) before each cell
# runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import timeit
import time
import warnings
import logging as logger
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from dotmap import DotMap

import torch
import torch.nn as nn

from convcrf.convcrf import GaussCRF, default_conf
from utils.synthetic import augment_label
from utils.metrics import Metrics, Averages
from demo import do_crf_inference

# Log to stdout with timestamps so log lines interleave with print() output in the notebook.
logger.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logger.INFO,
                    stream=sys.stdout)

# NOTE(review): this silences *all* warnings (including deprecation warnings);
# consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

# Seed the global numpy and torch RNGs so DataLoader shuffling and any
# np.random-based randomness (e.g. in augment_label, if it uses it) repeat
# across kernel restarts. GPU (cuDNN) kernels may still be nondeterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info('Device is {}'.format(device))

2018-11-11 16:30:27,371 INFO Device is cuda


## Load Data

In [3]:
from torch.utils.data import DataLoader, WeightedRandomSampler
from utils.pascal_loader import PascalDatasetLoader

# Dataset root. Set the PASCAL_VOC_ROOT environment variable to override it
# instead of editing this absolute path.
path = os.environ.get('PASCAL_VOC_ROOT',
                      '/home/jupyter/projects/ConvCRF/datasets/pascal/VOCdevkit/VOC2012')
NUM_WORKERS = 8  # background DataLoader worker processes per loader

traincrf_dataset = PascalDatasetLoader(path, split='traincrf')
val_dataset = PascalDatasetLoader(path, split='val')

num_classes = traincrf_dataset.num_classes

# batch_size=1 is the DataLoader default. It is spelled out because the
# training/validation code strips the batch dimension with `labels[0]`.
traincrf_loader = DataLoader(traincrf_dataset, batch_size=1, shuffle=True,
                             num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=NUM_WORKERS)

## Define the model

In [4]:
# Build the CRF for a fixed (500, 500) spatial shape and move it to `device`.
# NOTE(review): (500, 500) is presumably the size the loader pads/crops images
# to — confirm. `config` is an alias of `default_conf`, not a copy.
config = default_conf
model = GaussCRF(conf=config, shape=(500, 500), nclasses=num_classes).to(device)
model

GaussCRF(
  (CRF): ConvCRF()
)

## Define the loss function and optimizer

In [5]:
import torch.optim as optim

# Per-pixel classification loss on the CRF output logits.
# NOTE(review): CrossEntropyLoss only ignores target -100 by default. If the
# loader encodes VOC "void" pixels as 255, they are not ignored and fall
# outside the class range — confirm.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=5e-5)

## Train the network

In [7]:
# Run options held in a DotMap. None of these fields are read in the
# training cells shown here.
args = DotMap()
args.pyinn, args.nospeed, args.output = False, False, None

# Accumulators: segmentation metrics (reported as Mean IoU / accuracies in the
# validation summary) plus running means of train loss, validation loss and
# per-iteration wall-clock time.
running_metrics = Metrics(num_classes)
train_loss_avg, val_loss_avg, time_avg = Averages(), Averages(), Averages()

# Sentinel below any real IoU, so the first validation always checkpoints.
best_iou = -100.0

num_epochs = 10

In [8]:
# Train for `num_epochs` epochs. Each epoch is one pass over `traincrf_loader`,
# followed by a full validation pass over `val_loader`. Whenever the validation
# Mean IoU ties or beats `best_iou`, model and optimizer state are checkpointed.
LOG_EVERY = 20        # training iterations between progress lines
VAL_LOG_EVERY = 200   # validation iterations between progress lines
MEAN_IOU_KEY = "Mean IoU : \t"  # mean-IoU key in the dict returned by Metrics.get_scores()
CHECKPOINT_PATH = os.path.join("/home/jupyter/projects/ConvCRF/datasets",
                               "best_model.pkl")


def _labels_to_unary(label_map, n_classes):
    """Return the synthetic unary for one label map as a (1, n_classes, H, W) float tensor on `device`.

    NOTE(review): `augment_label` is a project helper. The transpose below
    implies it returns an (H, W, n_classes) numpy array — confirm.
    """
    unary_np = augment_label(label_map, num_classes=n_classes)
    height, width = unary_np.shape[0], unary_np.shape[1]
    unary_np = unary_np.transpose(2, 0, 1).reshape([1, n_classes, height, width])
    return torch.from_numpy(unary_np).float().to(device)


def _flatten_logits(logits, n_classes):
    """Reshape (N, n_classes, H, W) logits to (N*H*W, n_classes) for per-pixel CrossEntropyLoss."""
    return logits.permute(0, 2, 3, 1).contiguous().view(-1, n_classes)


def _run_validation(epoch_idx):
    """Evaluate the model on `val_loader`, print a summary, and return the score dict.

    `val_loss_avg` is reset first, so the reported loss covers only this pass.
    `running_metrics` is reset afterwards, so its counts do not leak into the
    next epoch.
    """
    logger.info('Doing validation at Epoch: {}'.format(epoch_idx + 1))
    model.eval()
    val_loss_avg.reset()
    val_len = len(val_loader)

    with torch.no_grad():
        for i_val, (images_val, labels_val) in enumerate(val_loader):
            labels_val = labels_val[0]  # remove batch dimension (batch_size is 1)
            unary = _labels_to_unary(labels_val, num_classes)
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            predictions = model(unary=unary, img=images_val)
            val_loss = criterion(_flatten_logits(predictions, num_classes),
                                 labels_val.view(-1))

            # Argmax over the class dimension; [0] drops the batch dimension.
            preds_np = predictions.max(1)[1].cpu().numpy()[0]
            labels_np = labels_val.cpu().numpy()

            running_metrics.update(labels_np, preds_np)
            val_loss_avg.update(val_loss.item())

            if i_val > 0 and i_val % VAL_LOG_EVERY == 0:
                print("{}/{} Loss: {:.4f}".format(i_val, val_len, val_loss_avg.avg))

    logger.info("Epoch %d Loss: %.4f" % (epoch_idx + 1, val_loss_avg.avg))
    score, _class_iou = running_metrics.get_scores()

    print('\nEpoch: {} Validation Summary'.format(epoch_idx + 1))
    for k, v in score.items():
        print(k, v)

    running_metrics.reset()
    return score


for epoch in range(num_epochs):
    # Validation (below) switches the model to eval mode, so re-enable
    # training mode at the start of every epoch.
    model.train()

    for i, (images, labels) in enumerate(traincrf_loader):
        start_ts = time.time()

        labels = labels[0]  # remove batch dimension (batch_size is 1)
        unary = _labels_to_unary(labels, num_classes)
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(unary=unary, img=images)
        loss = criterion(_flatten_logits(outputs, num_classes), labels.view(-1))
        loss.backward()
        optimizer.step()

        train_loss_avg.update(loss.item())
        time_avg.update(time.time() - start_ts)

        if i % LOG_EVERY == 0:
            print('[{:d}, {:d}] Average loss: {:.4f} Average Time: {:.4f} '
                  .format(epoch + 1, i, train_loss_avg.avg, time_avg.avg))
            train_loss_avg.reset()
            time_avg.reset()

    score = _run_validation(epoch)

    if score[MEAN_IOU_KEY] >= best_iou:
        best_iou = score[MEAN_IOU_KEY]
        logger.info('Found new best_iou: {}'.format(best_iou))
        state = {
            "epoch": epoch + 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_iou": best_iou,
        }
        logger.info(CHECKPOINT_PATH)
        torch.save(state, CHECKPOINT_PATH)

[1, 0] Average loss: 1.0660 Average Time: 1.7075 
[1, 20] Average loss: 1.1590 Average Time: 1.5371 
[1, 40] Average loss: 1.0457 Average Time: 1.5706 
[1, 60] Average loss: 1.2082 Average Time: 1.5735 
[1, 80] Average loss: 0.9830 Average Time: 1.5658 
[1, 100] Average loss: 1.1298 Average Time: 1.5721 
[1, 120] Average loss: 1.2845 Average Time: 1.5681 
[1, 140] Average loss: 1.1175 Average Time: 1.5803 
[1, 160] Average loss: 0.8125 Average Time: 1.5774 
[1, 180] Average loss: 1.0220 Average Time: 1.5830 
2018-11-11 16:36:42,817 INFO Doing validation at Epoch: 1
2018-11-11 17:00:34,100 INFO Epoch 1 Loss: 1.1842

Epoch: 1 Validation Suammry
Mean IoU : 	 0.8853168076810695
FreqW Acc : 	 0.9580514563346489
Overall Acc: 	 0.9781673512767426
Mean Acc : 	 0.9374194802572818
2018-11-11 17:00:34,107 INFO Found new best_iou: 0.8853168076810695
2018-11-11 17:00:34,108 INFO /home/jupyter/projects/ConvCRF/datasets/best_model.pkl
[2, 0] Average loss: 1.4249 Average Time: 1.5881 
[2, 20] Average 

Process Process-84:
Process Process-88:
Process Process-85:
Process Process-86:
Process Process-87:
Process Process-82:
Process Process-83:
Process Process-81:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    s

KeyboardInterrupt: 