In [None]:
# ---- Data ----
path_to_data = "/content/drive/MyDrive/Interpretable Classifier Data/SemSeg Data"
val_size_ratio = 0.18   # fraction of samples held out for validation (taken from the end of the sorted file list)

# ---- Optimisation ----
batch_size = 10
num_epochs = 300
lr = 5e-4
weight_decay = 2e-4

# ---- Checkpoints ----
checkpoint_path = "./checkpoint/current_checkpoint.pt"   # rewritten after every epoch
best_model_path = "./best_model/best_model.pt"           # epoch with the lowest validation loss

# Importing Libraries

In [None]:
import cv2
import glob
import random
import numpy as np
import os
import shutil
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from model import ENet
from utils import decode_segmap, to_device, IOU, save_ckp, load_ckp
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

# Loading data


In [None]:
# Input-image transform: ToTensor turns an HxWxC uint8 array into a CxHxW float tensor in [0, 1];
# Normalize with mean = std = 0.5 for each of the 3 channels then maps every channel to [-1, 1].
transform_img = transforms.Compose(
  [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
  ]
)

In [None]:
class dataset(Dataset):
  """Pairs pre-loaded input images with their label masks.

  ``transform`` is applied to the image on every access; the label is returned
  exactly as stored.
  """

  def __init__(self, images_X, images_Y, transform_img):
    # keep references to the caller's sequences; nothing is copied here
    self.data, self.labels, self.transform = images_X, images_Y, transform_img

  def __len__(self):
    # the number of samples is the number of input images
    return len(self.data)

  def __getitem__(self, index):
    image, target = self.data[index], self.labels[index]
    return self.transform(image), target

In [None]:
# Load every input image (OpenCV returns BGR uint8 arrays) in sorted filename order.
# The label cell below uses the same sorted order, so image i must correspond to label i.
images_dir = os.path.join(path_to_data, 'images')
image_list_X = []

for imagename in sorted(os.listdir(images_dir)):
  image_path = os.path.join(images_dir, imagename)
  im = cv2.imread(image_path)
  # cv2.imread signals failure by returning None instead of raising; fail loudly here
  # rather than letting a None slip into np.array and break much later.
  if im is None:
    raise IOError("cv2.imread could not read image file {!r}".format(image_path))
  image_list_X.append(im)

images_X = np.array(image_list_X)

In [None]:
# Load every label mask as a single-channel (grayscale) array, in the same sorted
# filename order used for the input images, so mask i must correspond to image i.
labels_dir = os.path.join(path_to_data, 'labels')
image_list_Y = []

for imagename in sorted(os.listdir(labels_dir)):
  label_path = os.path.join(labels_dir, imagename)
  im = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
  # cv2.imread returns None (it does not raise) when a file cannot be read/decoded
  if im is None:
    raise IOError("cv2.imread could not read label file {!r}".format(label_path))
  image_list_Y.append(im)

images_Y = np.array(image_list_Y)

In [None]:
# Deterministic split: with shuffle=False the last `val_size_ratio` fraction of the
# (sorted) samples becomes the validation set.
(images_X_train, images_X_val,
 images_Y_train, images_Y_val) = train_test_split(
    images_X,
    images_Y,
    test_size=val_size_ratio,
    shuffle=False,
)

In [None]:
# Training batches are reshuffled every epoch so the optimiser does not see one fixed
# batch sequence (the split above uses shuffle=False, so samples arrive in file order).
train_data = dataset( images_X = images_X_train, images_Y = images_Y_train, transform_img = transform_img )
train_loader = DataLoader( train_data, batch_size = batch_size, shuffle = True )

In [None]:
# Validation data: one image per batch and file order preserved (no shuffling), so the
# final evaluation cell can pair batch k with images_X_val[k].
val_data = dataset(images_X_val, images_Y_val, transform_img)
val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False)

# Utility Functions

In [None]:
# Create the directories that hold the configured checkpoint files. Deriving them from
# checkpoint_path / best_model_path keeps them in sync with the config cell, and
# exist_ok=True makes this cell safe to re-run (a plain `mkdir` errors if they exist).
os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
os.makedirs(os.path.dirname(best_model_path) or ".", exist_ok=True)

In [None]:
def evaluate(model,val_loader,epoch,criterion):
  """Evaluate ``model`` on ``val_loader`` and return aggregate validation metrics.

  Args:
    model: segmentation network; batches are moved to the device of its parameters.
    val_loader: iterable of ``(images, labels)`` batches.
    epoch: current epoch number; accepted for the caller's convenience, not used.
    criterion: loss applied to ``(logits, labels)``; its batch mean is reported as "Loss".
      (The caller's criterion is used -- it is no longer replaced by an unweighted one.)

  Returns:
    dict with keys
      "Class IOU": per-class IoU from ``IOU`` averaged over all samples, in percent,
      "Mean IOU":  mean of "Class IOU",
      "Accuracy":  mean per-sample pixel accuracy, in percent,
      "Loss":      mean of ``criterion`` over batches.

  Raises:
    ValueError: if ``val_loader`` yields no batches.
  """
  model.eval()
  device = next(model.parameters()).device

  num_batches = 0
  num_samples = 0
  acc = 0.0
  epoch_loss = 0.0
  class_iou_sum = 0.0  # becomes an array on the first IOU result

  # Inference only: no_grad avoids building an autograd graph for every validation batch.
  with torch.no_grad():
    for img, label in val_loader:
      xb = img.to(device)
      label = label.to(device).long()

      yb = model(xb)
      epoch_loss += criterion(yb, label).item()
      num_batches += 1

      # argmax over the class channel -> (batch, H, W) predicted class ids
      preds = torch.argmax(yb, dim=1).cpu()
      label = label.cpu()

      # pixel accuracy of each image, normalised by that image's real pixel count
      acc += (preds == label).flatten(1).float().mean(dim=1).sum().item()

      # score every image in the batch, not just the first one
      for i in range(preds.shape[0]):
        class_iou_sum = class_iou_sum + np.asarray(IOU(np.array(preds[i]), np.array(label[i])), dtype=float)
      num_samples += preds.shape[0]

  if num_batches == 0:
    raise ValueError("val_loader yielded no batches")

  class_iou = class_iou_sum * 100 / num_samples
  accuracy = acc * 100 / num_samples

  print("Accuracy = ",accuracy)
  print("------------")

  return {"Class IOU":class_iou,"Mean IOU":class_iou.mean(),"Accuracy":accuracy,"Loss":epoch_loss/num_batches}

In [None]:
def fit(epochs, lr, model, train_loader, val_loader,criterion,opt_func=torch.optim.Adam,checkpoint_path="/content/checkpoint/current_checkpoint.pt", best_model_path="/content/best_model/best_model.pt"):
    """Train ``model`` from scratch, validating and checkpointing after every epoch.

    Args:
        epochs: number of training epochs.
        lr: learning rate handed to ``opt_func``.
        model: network to train; batches are moved to the device of its parameters.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: loader passed to ``evaluate`` after each epoch.
        criterion: loss computed on ``(model(inputs), labels)``.
        opt_func: optimizer class, built as ``opt_func(params, lr=lr, weight_decay=weight_decay)``.
            NOTE: ``weight_decay`` is the notebook-level value from the config cell.
        checkpoint_path: passed to ``save_ckp`` as the regular checkpoint location.
        best_model_path: passed to ``save_ckp`` as the best-model location.

    Returns:
        dict mapping the 1-based epoch number to the metrics dict from ``evaluate``.
    """
    device = next(model.parameters()).device
    epoch_data = {}
    optim = opt_func( model.parameters(), lr = lr, weight_decay = weight_decay )

    # np.Inf was removed in NumPy 2.0; the builtin float infinity works with every version.
    valid_loss_min = float("inf")

    for epoch in range(epochs):

        # Training phase
        model.train()
        epoch_loss = 0.0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            # class-index targets for CrossEntropyLoss must be int64
            labels = labels.to(device).long()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optim.zero_grad()
            loss.backward()
            optim.step()

            epoch_loss += loss.item()

        # Validation phase
        epoch_data[epoch+1] = evaluate(model,val_loader,epoch+1,criterion)
        val_loss = epoch_data[epoch+1]["Loss"]

        print("Epoch number:",epoch+1," - Training Loss = " , epoch_loss / len(train_loader) , " Valiadtion Loss = ", val_loss )

        is_best = val_loss <= valid_loss_min
        if is_best:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,val_loss))
            valid_loss_min = val_loss

        # Build the checkpoint *after* updating the running minimum so the stored
        # 'valid_loss_min' is the best loss seen so far (including this epoch) --
        # the value a resumed run has to compare against.
        checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': valid_loss_min,
            'state_dict': model.state_dict(),
            'optimizer': optim.state_dict(),
        }

        save_ckp(checkpoint, False, checkpoint_path, best_model_path)
        if is_best:
            # save checkpoint as best model
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)

        print("------------")

    return epoch_data

# Training

In [None]:
# Fresh ENet configured for 3 segmentation classes, moved to the GPU for training.
model = ENet(num_classes=3)
model = model.to('cuda')

In [None]:
def get_class_weights(loader, num_classes, c=1.02):
    """Compute ENet-style class weights ``1 / ln(c + p_class)`` from label pixel frequencies.

    Pixel counts are accumulated over *every* batch yielded by ``loader``, so the
    frequencies describe the whole dataset rather than only its first batch.

    Args:
        loader: iterable of ``(inputs, labels)`` batches; labels hold non-negative
            integer class ids (numpy arrays or CPU tensors).
        num_classes: minimum length of the returned weight array.
        c: smoothing constant; since ``c > 1`` the log stays positive and the weight of
            an absent class (``p_class == 0``) stays bounded at ``1 / ln(c)``.

    Returns:
        numpy float array of per-class weights; more frequent classes get smaller weights.

    Raises:
        ValueError: if the loader yields no label pixels.
    """
    counts = np.zeros(num_classes, dtype=np.int64)

    for _, y in loader:
        batch_counts = np.bincount(np.asarray(y).ravel(), minlength=num_classes)
        # a label id >= num_classes lengthens bincount's output; grow the accumulator to match
        if batch_counts.size > counts.size:
            counts = np.pad(counts, (0, batch_counts.size - counts.size))
        counts[:batch_counts.size] += batch_counts

    total = counts.sum()
    if total == 0:
        raise ValueError("loader yielded no labels to compute class weights from")

    p_class = counts / total
    class_weights = 1 / np.log(c + p_class)

    return class_weights

In [None]:
# Weight each class according to how often its label occurs in the label pixels that
# get_class_weights inspects: frequent labels receive a smaller weight and therefore
# contribute less to back-propagation, while rare labels are emphasised.
class_weights = (
    torch.from_numpy(get_class_weights(train_loader, 3))
    .float()
    .to('cuda')
)
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# Train from scratch; the result maps each epoch number to the metrics from evaluate().
epoch_data = fit(
    epochs=num_epochs,
    lr=lr,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    checkpoint_path=checkpoint_path,
    best_model_path=best_model_path,
)

# Loading Model

In [None]:
# New ENet instance (3 classes, on the GPU) whose weights are restored from a checkpoint below.
model = ENet(num_classes=3)
model = model.to('cuda')

In [None]:
# The checkpoint to resume from is the best model saved during training.
ckp_path = best_model_path

# Optimizer built with the same settings fit() uses by default (Adam, lr, weight_decay);
# load_ckp below presumably restores its saved state -- TODO confirm against utils.load_ckp.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Restore state from ckp_path; load_ckp hands back the model and optimizer plus (presumably)
# the epoch stored in the checkpoint and its recorded best validation loss.
(model, optimizer,
 start_epoch, valid_loss_min) = load_ckp(ckp_path, model, optimizer)

In [None]:
# Report where the resumed run starts and the validation loss it has to beat.
for caption, value in (("start_epoch = ", start_epoch), ("valid_loss_min = ", valid_loss_min)):
    print(caption, value)

# Resume Training from the Checkpoint

In [None]:
def train_again(start_epoch,epochs,valid_loss_min,model,optim,train_loader,val_loader,criterion,checkpoint_path="./checkpoint/current_checkpoint.pt", best_model_path="./best_model/best_model.pt"):
    """Continue training an already-initialised ``model`` / ``optim`` pair.

    Runs the loop indices ``start_epoch .. epochs - 1`` and reports/saves them with
    1-based epoch numbers ``start_epoch + 1 .. epochs``, the same numbering ``fit`` uses.

    Args:
        start_epoch: number of epochs already completed (presumably the 'epoch'
            value from the loaded checkpoint).
        epochs: total number of epochs to reach.
        valid_loss_min: best validation loss so far; an epoch is saved as best only
            if its validation loss is <= this value.
        model: network to train; batches are moved to the device of its parameters.
        optim: optimizer already bound to ``model``'s parameters.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: loader passed to ``evaluate`` after each epoch.
        criterion: loss computed on ``(model(inputs), labels)``.
        checkpoint_path: passed to ``save_ckp`` as the regular checkpoint location.
        best_model_path: passed to ``save_ckp`` as the best-model location.

    Returns:
        dict mapping the 1-based epoch number to the metrics dict from ``evaluate``.
    """
    device = next(model.parameters()).device
    epoch_data = {}

    for epoch in range(start_epoch,epochs):

        # Training phase
        model.train()
        epoch_loss = 0.0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            # class-index targets for CrossEntropyLoss must be int64
            labels = labels.to(device).long()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optim.zero_grad()
            loss.backward()
            optim.step()

            epoch_loss += loss.item()

        # Validation phase
        epoch_data[epoch+1] = evaluate(model,val_loader,epoch+1,criterion)
        val_loss = epoch_data[epoch+1]["Loss"]

        print("Epoch number:",epoch+1," - Training Loss = " , epoch_loss / len(train_loader) , " Valiadtion Loss = ", val_loss )

        is_best = val_loss <= valid_loss_min
        if is_best:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,val_loss))
            valid_loss_min = val_loss

        # Store the running best loss, not this epoch's loss: saving val_loss here would
        # let a later resume from a non-improving epoch accept worse models as "best".
        checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': valid_loss_min,
            'state_dict': model.state_dict(),
            'optimizer': optim.state_dict(),
        }

        save_ckp(checkpoint, False, checkpoint_path, best_model_path)
        if is_best:
            # save checkpoint as best model
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)

        print("------------")

    return epoch_data

In [None]:
# Resume from start_epoch up to the configured total. The config cell defines the total
# as `num_epochs`; the bare name `epochs` is never defined at notebook level (NameError).
epoch_data = train_again(start_epoch, num_epochs, valid_loss_min, model, optimizer, train_loader, val_loader, criterion,
                         checkpoint_path=checkpoint_path, best_model_path=best_model_path)

# Evaluating

In [None]:
# Qualitative + quantitative evaluation of the current `model` on the validation split.
# Assumes val_loader uses batch_size=1 without shuffling, so batch k is images_X_val[k]
# (only preds[0] / label[0] are scored and plotted).
acc = 0
a = np.zeros((3,), dtype=float)
k = 0

model.eval()
# Inference only: no_grad stops autograd from recording a graph for every image.
with torch.no_grad():
  for img,label in val_loader:

    xb = img.to('cuda')
    label = label.to('cuda').long()

    yb = model(xb)

    # softmax is monotonic, so argmax over the raw class scores yields the same labels
    preds = torch.argmax(yb, dim=1)

    preds = preds.cpu()
    label = label.cpu()

    # pixel accuracy of this batch, normalised by its real pixel count (not a fixed 256*512)
    acc += (preds == label).float().mean().item()

    a += IOU(np.array(preds[0]),np.array(label[0]))

    row, col = 1, 3
    fig, axs = plt.subplots(row, col, figsize=(21, 10))
    fig.tight_layout()

    axs[0].imshow(cv2.cvtColor(images_X_val[k], cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original')

    x = decode_segmap(preds,nc=3)
    x = x.squeeze(0)

    axs[1].imshow(cv2.cvtColor(x, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Output')

    # wherever the decoded mask is 0, show the original image pixel instead
    idx = x == 0
    x[idx] = images_X_val[k][idx]

    added_image = cv2.addWeighted(images_X_val[k],0.2,x,0.9,0)

    axs[2].imshow(cv2.cvtColor(added_image, cv2.COLOR_BGR2RGB))
    axs[2].set_title('Overlay')

    # Display now; the default inline backend closes figures once shown, so one figure
    # per validation image does not pile up in memory for the whole loop.
    plt.show()

    k = k + 1

# `a` holds per-class IoU summed over images; its mean over classes is the mean IoU.
print("Mean IoU",(a/len(val_loader)).mean(),"Accuracy",acc/len(val_loader))