<a href="https://colab.research.google.com/github/goromal/FANet_Evaluation/blob/main/fanet_eval_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FANet Evaluation - 6.862 Project



*   David Elatov
*   Dayne Howard
*   Andrew Torgesen



## Setup

Link Google Drive to the GitHub repo, install the required packages, and access the training and testing data.

### CUDA

1. Go to **Menu > Runtime > Change runtime type** and make sure that GPU is enabled.
2. Run the commands below to ensure that the GPU (and CUDA) is operational.

In [None]:
! nvidia-smi

Tue Apr  6 14:06:56 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import torch
torch.cuda.is_available()

True

### Repository

Mount drive, clone repo, navigate to repo, and change working directory to access repo files. **Run ONCE per computing session.**

In [None]:
import os
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/MyDrive/
if not os.path.exists('FANet_Evaluation'):
  print('Repo not present. Cloning...')
  ! git clone https://github.com/goromal/FANet_Evaluation.git
  %cd FANet_Evaluation/
else:
  print('Repo already present. Updating...')
  %cd FANet_Evaluation/
  ! git pull origin main
from model.test_model import *
test()

Mounted at /content/gdrive
/content/gdrive/MyDrive
Repo already present. Updating...
/content/gdrive/MyDrive/FANet_Evaluation
From https://github.com/goromal/FANet_Evaluation
 * branch            main       -> FETCH_HEAD
Already up to date.
SUCCESS


### Python Packages

In [None]:
! pip install oyaml
! pip install torchstat

Collecting oyaml
  Downloading https://files.pythonhosted.org/packages/37/aa/111610d8bf5b1bb7a295a048fc648cec346347a8b0be5881defd2d1b4a52/oyaml-1.0-py2.py3-none-any.whl
Installing collected packages: oyaml
Successfully installed oyaml-1.0
Collecting torchstat
  Downloading https://files.pythonhosted.org/packages/bc/fe/f483b907ca80c90f189cd892bb2ce7b2c256010b30314bbec4fc17d1b5f1/torchstat-0.0.7-py3-none-any.whl
Installing collected packages: torchstat
Successfully installed torchstat-0.0.7


## FANet-18 Initial FPS Testing

Ensure that the FANet implementation can be loaded and run on the GPU.

In [None]:
import torch
import oyaml as yaml
from torchstat import stat
import time,os

from model.fanet import FANet

In [None]:
network = FANet(backbone='resnet18')
network.cuda()
network.eval()
t_cnt = 0.0
with torch.no_grad():
  input = torch.rand((1,3,1024,2048)).cuda()
  
  # Warm-up passes (excluded from the timing)
  torch.cuda.synchronize()
  x = network(input)
  x = network(input)

  # Wait for the warm-up work to finish before starting the clock
  torch.cuda.synchronize()
  start_ts = time.time()

  for i in range(100):
    x = network(input)
  
  torch.cuda.synchronize()
  end_ts = time.time()

  t_cnt = end_ts-start_ts

print('FANet-18 Performance (FPS): %f' % (100.0/t_cnt))

FANet-18 Performance (FPS): 40.201780
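
The timing pattern used above (warm-up, synchronize, timed loop, synchronize) can be factored into a small helper. The sketch below is a minimal illustration and not part of the repository: `measure_fps` and its parameters are hypothetical names, and a plain Python workload stands in for `network(input)` so that it runs without a GPU. For the GPU case, one would presumably pass `torch.cuda.synchronize` as `sync`.

```python
import time


def measure_fps(run_once, n_warmup=2, n_iters=100, sync=None):
    """Return iterations per second for run_once().

    sync is an optional zero-argument callable that blocks until queued
    work is done (hypothetical hook; e.g. torch.cuda.synchronize on a GPU).
    """
    for _ in range(n_warmup):
        run_once()              # warm-up runs are excluded from the timing
    if sync is not None:
        sync()                  # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(n_iters):
        run_once()
    if sync is not None:
        sync()                  # wait for queued work before stopping the clock
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


# Stand-in workload so the sketch runs on any machine.
calls = []
fps = measure_fps(lambda: calls.append(sum(range(1000))), n_warmup=2, n_iters=50)
print(f"{fps:.1f} iterations/s over {len(calls)} total calls")
```

Synchronizing on both sides of the timed loop matters because GPU kernels are launched asynchronously; without it the clock would mostly measure launch overhead.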


# Evaluation Pipeline Training Decomposition

The cells below walk through how the evaluation pipeline trains a model on Cityscapes, from start to finish, using the TensorFlow 1.x graph/session API.

In [None]:
import sys
sys.path.insert(0, '/content/gdrive/MyDrive/FANet_Evaluation/evaluation') # so that the evaluation pipeline's internal imports work
import numpy as np # used below for data loading and batch sampling
import tensorflow as tf # TensorFlow 1.x API (tf.Session, tf.reset_default_graph)
from evaluation.utils.params import get_params
from evaluation.utils.dirs import create_exp_dirs
from evaluation.utils.misc import timeit
import scipy.misc as misc # for image resizing (misc.imresize was removed in SciPy 1.3, so an older SciPy is required)
from tqdm import tqdm # progress bar visualization
import time # for timing

In [None]:
# Usable Models
from evaluation.models.dilation_mobilenet import DilationMobileNet # << using this network as an example
# (roughly 15 other models are available in evaluation.models)

# Metrics for measuring performance (mIoU, etc.)
from evaluation.metrics.metrics import Metrics

In [None]:
# Argument class to instantiate a model
class ModelTrainArgs(object):
    def __init__(self):
        # MODEL ARGS
        self.img_width = 1024
        self.img_height = 512
        self.num_channels = 3 # 3 channels for color images
        # data dir contains pre-processed weights.npy, X_train.npy, Y_train.npy, X_val.npy, Y_val.npy
        self.data_dir = '/content/gdrive/MyDrive/full_cityscapes_res/' # DATA LOCATED IN "My Drive/full_cityscapes_res" (trailing slash needed: file names are appended directly)
        self.weighted_loss = True
        self.batch_size = 4
        self.learning_rate = 0.0001
        
        # TRAIN ARGS
        self.data_mode = "experiment"
        self.num_classes = 20 # for CityScapes
        self.test_every = 10 # validation performed every 10 training epochs
        

In [None]:
# Set parameters
args = ModelTrainArgs()

# Reset the graph
tf.reset_default_graph()

# Create the sess
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True))

# Create Model class and "build" it
with sess.as_default():
    with tf.variable_scope('network') as scope:
        model = DilationMobileNet(args)
        model.build()

# Instantiate training components
sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
train_data = None
train_data_len = None
val_data = None
val_data_len = None
num_iterations_training_per_epoch = None
num_iterations_validation_per_epoch = None
metrics = Metrics(args.num_classes)

# Training batch generator
def train_generator():
    global args, num_iterations_training_per_epoch, train_data_len, train_data
    start = 0
    idx = np.random.choice(train_data_len, num_iterations_training_per_epoch * args.batch_size, replace=True)
    while True:
        # select the mini_batches
        mask = idx[start:start + args.batch_size]
        x_batch = train_data['X'][mask]
        y_batch = train_data['Y'][mask]

        # update start idx
        start += args.batch_size

        yield x_batch, y_batch

        if start >= train_data_len:
            return

# Load training and validation data
print("Loading Training data..")
train_data_wrongsize = {'X': np.load(args.data_dir + "X_train.npy"), 'Y': np.load(args.data_dir + "Y_train.npy")}
X = []
Y = []
train_data = dict()
for i in range(train_data_wrongsize['X'].shape[0]):
    X.append(misc.imresize(train_data_wrongsize['X'][i, ...], (args.img_height, args.img_width)))
    Y.append(misc.imresize(train_data_wrongsize['Y'][i, ...], (args.img_height, args.img_width), 'nearest'))
train_data['X'] = np.asarray(X)
train_data['Y'] = np.asarray(Y)
train_data_len = train_data['X'].shape[0]
num_iterations_training_per_epoch = (train_data_len + args.batch_size - 1) // args.batch_size
print("Train-shape-x -- " + str(train_data['X'].shape) + " " + str(train_data_len))
print("Train-shape-y -- " + str(train_data['Y'].shape))
print("Num of iterations on training data in one epoch -- " + str(num_iterations_training_per_epoch))
print("Training data is loaded")

print("Loading Validation data..")
val_data = {'X': np.load(args.data_dir + "X_val.npy"), 'Y': np.load(args.data_dir + "Y_val.npy")}
val_data['Y_large'] = val_data['Y']
val_data_len = val_data['X'].shape[0] - val_data['X'].shape[0] % args.batch_size
num_iterations_validation_per_epoch = (val_data_len + args.batch_size - 1) // args.batch_size
print("Val-shape-x -- " + str(val_data['X'].shape) + " " + str(val_data_len))
print("Val-shape-y -- " + str(val_data['Y'].shape))
print("Num of iterations on validation data in one epoch -- " + str(num_iterations_validation_per_epoch))
print("Validation data is loaded")

# Train
print("Training mode will begin NOW ..")
# NOTE: args.num_epochs is not set in ModelTrainArgs above and must be defined before this loop runs
for cur_epoch in range(model.global_epoch_tensor.eval(sess) + 1, args.num_epochs + 1, 1):

    # init tqdm and get the epoch value
    tt = tqdm(train_generator(), total=num_iterations_training_per_epoch, desc="epoch-" + str(cur_epoch) + "-")

    # init the current iterations
    cur_iteration = 0

    # init acc and loss lists
    loss_list = []
    acc_list = []

    # loop by the number of iterations
    for x_batch, y_batch in tt:

        # get the cur_it for the summary
        cur_it = model.global_step_tensor.eval(sess)

        # Feed data into the network
        feed_dict = {model.x_pl: x_batch,
                     model.y_pl: y_batch,
                     model.is_training: True}

        # Regular iterations run one training step; the last iteration of the epoch also finalizes the epoch
        if cur_iteration < num_iterations_training_per_epoch - 1:

            # run the feed_forward
            _, loss, acc, summaries_merged = sess.run(
                        [model.train_op, model.loss, model.accuracy, model.merged_summaries],
                        feed_dict=feed_dict)
            # log loss and acc
            loss_list += [loss]
            acc_list += [acc]

        else:
            # run the feed_forward
            _, loss, acc, summaries_merged, segmented_imgs = sess.run(
                            [model.train_op, model.loss, model.accuracy,
                             model.merged_summaries, model.segmented_summary],
                             feed_dict=feed_dict)

            # log loss and acc
            loss_list += [loss]
            acc_list += [acc]
            total_loss = np.mean(loss_list)
            total_acc = np.mean(acc_list)

            # Update the Global step
            model.global_step_assign_op.eval(session=sess, feed_dict={model.global_step_input: cur_it + 1})

            # Update the current-epoch tensor.
            # This is done last so that an interrupted epoch is repeated on restart.
            model.global_epoch_assign_op.eval(session=sess, feed_dict={model.global_epoch_input: cur_epoch + 1})

            # print in console
            tt.close()
            print("epoch-" + str(cur_epoch) + "-" + "loss:" + str(total_loss) + "-" + " acc:" + str(total_acc)[:6])

            # Break the loop to finalize this epoch
            break

        # Update the Global step
        model.global_step_assign_op.eval(session=sess, feed_dict={model.global_step_input: cur_it + 1})

        # update the cur_iteration
        cur_iteration += 1

    # Test the model on validation set
    if cur_epoch % args.test_every == 0:
        step = model.global_step_tensor.eval(sess)
        epoch = model.global_epoch_tensor.eval(sess)
        print("Validation at step:" + str(step) + " at epoch:" + str(epoch) + " ..")

        # init tqdm and get the epoch value
        tt = tqdm(range(num_iterations_validation_per_epoch), total=num_iterations_validation_per_epoch,
                  desc="Val-epoch-" + str(epoch) + "-")

        # init acc and loss lists
        loss_list = []
        acc_list = []
        inf_list = []

        # idx of minibatch
        idx = 0

        # reset metrics
        metrics.reset()

        # get the maximum iou to compare with and save the best model
        max_iou = model.best_iou_tensor.eval(sess)

        # loop by the number of iterations
        for cur_iteration in tt:
            # load minibatches
            x_batch = val_data['X'][idx:idx + args.batch_size]
            y_batch = val_data['Y'][idx:idx + args.batch_size]

            # update idx of minibatch
            idx += args.batch_size

            # Feed these variables to the network
            feed_dict = {model.x_pl: x_batch,
                         model.y_pl: y_batch,
                         model.is_training: False}

            # Regular iterations run one forward pass; the last iteration also computes the final metrics
            if cur_iteration < num_iterations_validation_per_epoch - 1:

                start = time.time()
                # run the feed_forward

                out_argmax, loss, acc, summaries_merged = sess.run(
                    [model.out_argmax, model.loss, model.accuracy, model.merged_summaries],
                    feed_dict=feed_dict)

                end = time.time()
                # log loss and acc
                loss_list += [loss]
                acc_list += [acc]
                inf_list += [end - start]

                # log metrics
                metrics.update_metrics_batch(out_argmax, y_batch)

            else:
                start = time.time()
                # run the feed_forward
                out_argmax, acc, segmented_imgs = sess.run(
                        [model.out_argmax, model.accuracy, model.segmented_summary],
                        feed_dict=feed_dict)

                end = time.time()
                # log loss and acc
                acc_list += [acc]
                inf_list += [end - start]
                # log metrics
                metrics.update_metrics_batch(out_argmax, y_batch)
                # mean over batches
                total_acc = np.mean(acc_list)
                mean_iou = metrics.compute_final_metrics(num_iterations_validation_per_epoch)
                mean_iou_arr = metrics.iou
                mean_inference = str(np.mean(inf_list)) + '-seconds'

                # print in console
                tt.close()
                print("Val-epoch-" + str(epoch) + "-" +
                      "acc:" + str(total_acc)[:6] + "-mean_iou:" + str(mean_iou))
                print("Last_max_iou: " + str(max_iou))
                if mean_iou > max_iou:
                    print("This validation got a new best iou. so we will save this one")
                    # Set the new maximum
                    model.best_iou_assign_op.eval(session=sess, feed_dict={model.best_iou_input: mean_iou})
                else:
                    print("Hmm, not the best validation epoch :/..")
                # Break the loop to finalize this epoch
                break
        

# Finish session
sess.close()
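
The batch bookkeeping in the cell above is easy to get wrong, so here is a minimal standalone sketch of the arithmetic it uses: ceiling division for the number of training iterations, and truncation of the validation set to a whole number of batches. The function names and the sample counts (10 samples, batch size 4) are hypothetical and chosen only for illustration.

```python
def ceil_div(n, d):
    """Integer ceiling division, as in (n + d - 1) // d above."""
    return (n + d - 1) // d


def training_iterations(n_samples, batch_size):
    # Training rounds up, so a final partial batch still counts as one iteration.
    return ceil_div(n_samples, batch_size)


def validation_plan(n_samples, batch_size):
    # Validation first drops the remainder so every batch is full,
    # then counts how many full batches remain.
    usable = n_samples - n_samples % batch_size
    return usable, ceil_div(usable, batch_size)


print(training_iterations(10, 4))  # 3
print(validation_plan(10, 4))      # (8, 2)
```

Because `train_generator` draws `num_iterations_training_per_epoch * args.batch_size` indices with replacement, every training batch is full even when the dataset size is not a multiple of the batch size.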

# FANet Testing (Single-Frame Version)

In [None]:
# Dayne Howard, David Elatov, Andrew Torgesen, MIT, 6.862, Spring 2021
# Credit for much of this code goes to MIT;
# it was developed by staff for student use in course 6.036/6.862

# We'll use the PyTorch Framework
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchstat import stat
import torch.utils.model_zoo as model_zoo

import sys
sys.path.insert(0, '/content/gdrive/MyDrive/FANet_Evaluation/evaluation') # so that the evaluation pipeline's internal imports work
import oyaml as yaml
from PIL import Image # for image resizing

# Metrics for measuring performance (mIoU, etc.)
from evaluation.metrics.metrics import Metrics

# For displaying images later
import numpy as np
import matplotlib.pyplot as plt

# Set a random seed for predictable behavior
torch.manual_seed(6036)

In [None]:
# We'll create a Dataset class to use with PyTorch's Built-In Dataloaders
class FANetDataset(Dataset):
    '''
    A custom dataset class to use with PyTorch's built-in dataloaders.
    This will make feeding images to our models much easier downstream.

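    data: np.array of images, shape (N, C, H, W)
    labels: np.array of per-pixel class labels, shape (N, H, W)
    vectorize: if True, outputted image data will be (X,)
               if False, outputted image data will be (H,W)
               (currently unused: the reshape code below is commented out)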
    '''
    def __init__(self, data, labels, vectorize=False):
        self.data = data
        self.labels = labels
        self.vectorize = vectorize
    
    def __getitem__(self, idx):
        image_data = self.data[idx,:]
        # TODO: put in the channels, height and width of the images used
        # C = 3
        # H = 512
        # W = 1024
        # image_data = image_data.reshape((C,H,W))
        # if self.vectorize:
        #     image_data = image_data.reshape((H*W,))
        # NOTE: the line below might need [idx,:] or something similar
        image_label = self.labels[idx]
        return image_data, image_label

    def __len__(self):
        return self.data.shape[0]


# Argument class to instantiate a model
class ModelTrainArgs(object):
    def __init__(self):
        # MODEL ARGS
        self.img_width = 1024
        self.img_height = 512
        self.num_channels = 3 # 3 channels for color images
        # data dir contains pre-processed weights.npy, X_train.npy, Y_train.npy, X_val.npy, Y_val.npy
        self.data_dir = '/content/gdrive/MyDrive/full_cityscapes_res/' # DATA LOCATED IN "My Drive/full_cityscapes_res"
        self.weighted_loss = True
        self.batch_size = 10
        self.learning_rate = 0.0001
        
        # TRAIN ARGS
        self.data_mode = "experiment"
        self.num_classes = 20 # for CityScapes
        self.test_every = 10 # validation performed every 10 training epochs

# Set parameters
args = ModelTrainArgs()

In [None]:
########################## TO DO: Download Data Here
#Split into training, validation, and testing data
#train_data
#val_data
#test_data
#EXAMPLE:
# # Download MNIST Data
# (mnist_train, _), (mnist_test, _) = mnist.load_data()

# # Load data as Numpy arrays of size (#datapoints, 28*28=784)
# mnist_train = mnist_train.astype('float32') / 255.
# mnist_test = mnist_test.astype('float32') / 255.
# mnist_train = mnist_train.reshape((len(mnist_train), np.prod(mnist_train.shape[1:])))
# mnist_test = mnist_test.reshape((len(mnist_test), np.prod(mnist_test.shape[1:])))

# # Split test data into a test and validation set:
# val_data = mnist_test[:(mnist_test.shape[0]//2),:]
# test_data = mnist_test[(mnist_test.shape[0]//2):,:]
# train_data = mnist_train

# Download training data. Resize the targets to match FANet output
print("Loading Training data..")
###train_data_wrongsize = {'X': np.load(args.data_dir + "x_train_small_set.npy"), 'Y': np.load(args.data_dir + "y_train_small_set.npy")}
###X = []
Y = []
X = np.load(args.data_dir + "x_train_small_set.npy")
Y_wrongsize = np.load(args.data_dir + "y_train_small_set.npy")
###train_data = dict()
###train_labels = dict()
for i in range(Y_wrongsize.shape[0]):
    #On the next two lines, it looks like the dimensions are flipped, but that
    #is because of how PIL.Image defines their images (width,height) instead
    #of numpy's (rows,columns)
    ###X.append(np.array(Image.fromarray(train_data_wrongsize['X'][i, ...]).resize((1024,512),Image.NEAREST)))
    Y.append(np.array(Image.fromarray(Y_wrongsize[i, ...]).resize((128,64),Image.NEAREST)))

train_data = (np.asarray(np.transpose(X,axes=(0,3,1,2)))/255.0).astype(np.float64) # channels-first, float, RGB normalized to [0, 1] (np.float was removed in NumPy 1.24)
train_labels = np.asarray(Y)
print("Done loading training data")

Loading Training data..
Done loading training data
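
The label maps above are resized with `Image.NEAREST` rather than an interpolating filter because labels are class IDs, not intensities: averaging two IDs would produce a value that is not a valid class. The sketch below illustrates the idea in pure Python on a tiny label map. `resize_nearest` is a hypothetical helper with a simple floor-based index mapping; it is not PIL's exact sampling rule.

```python
def resize_nearest(label_map, out_h, out_w):
    """Downsample a 2D list of class IDs by picking one source pixel per output pixel."""
    in_h, in_w = len(label_map), len(label_map[0])
    return [
        [label_map[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


# Toy 4x4 label map with class IDs 0, 7 and 19 (values chosen for illustration).
labels = [
    [0, 0, 19, 19],
    [0, 0, 19, 19],
    [7, 7, 19, 19],
    [7, 7, 7, 7],
]
small = resize_nearest(labels, 2, 2)
print(small)  # [[0, 19], [7, 19]]
```

Every value in the output is one of the original class IDs, which is the property the loss computation relies on.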


In [None]:
# Download validation data. Resize the targets to match FANet output
print("Loading Validation data..")
###val_data_wrongsize = {'X': np.load(args.data_dir + "x_val_small_set.npy"), 'Y': np.load(args.data_dir + "y_val_small_set.npy")}
###X = []
Y = []
X = np.load(args.data_dir + "x_val_small_set.npy")
Y_wrongsize = np.load(args.data_dir + "y_val_small_set.npy")
###val_data = dict()
###val_labels = dict()
for i in range(Y_wrongsize.shape[0]):
    ###X.append(np.array(Image.fromarray(val_data_wrongsize['X'][i, ...]).resize((1024,512),Image.NEAREST)))
    ###Y.append(np.array(Image.fromarray(val_data_wrongsize['Y'][i, ...]).resize((128,64),Image.NEAREST)))
    Y.append(np.array(Image.fromarray(Y_wrongsize[i, ...]).resize((128,64),Image.NEAREST)))

val_data = (np.asarray(np.transpose(X,axes=(0,3,1,2)))/255.0).astype(np.float64) # channels-first, float, RGB normalized to [0, 1] (np.float was removed in NumPy 1.24)
val_labels = np.asarray(Y)
print("Done loading validation data")

Loading Validation data..
Done loading validation data


In [None]:
# Display dataset information
print("Downloaded the following data:")
print(f"train_data has shape {train_data.shape}")
print(f"train_labels has shape {train_labels.shape}")
print(f"val_data has shape {val_data.shape}")
print(f"val_labels has shape {val_labels.shape}")

# Create Dataset objects for each of our train/val/test sets
train_dataset = FANetDataset(train_data, train_labels)
val_dataset = FANetDataset(val_data, val_labels)
#test_dataset = FANetDataset(test_data, test_labels)

# Create a PyTorch dataloader for each train/val/test set
# Batch size: FANet used 16 in their paper; this notebook uses args.batch_size
train_loader = DataLoader(train_dataset, batch_size= args.batch_size)
val_loader = DataLoader(val_dataset, batch_size= args.batch_size)
#test_loader = DataLoader(test_dataset, batch_size= args.batch_size)

# Display dataloader info
print("Created the following Dataloaders:")
print(f"train_loader has {len(train_loader)} batches of training data")
print(f"val_loader has {len(val_loader)} batches of validation data")
#print(f"test_loader has {len(test_loader)} batches of testing data")

Downloaded the following data:
train_data has shape (200, 3, 512, 1024)
train_labels has shape (200, 64, 128)
val_data has shape (50, 3, 512, 1024)
val_labels has shape (50, 64, 128)
Created the following Dataloaders:
train_loader has 20 batches of training data
val_loader has 5 batches of validation data
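
For readers unfamiliar with the dataset protocol that `FANetDataset` implements, the following sketch shows the two methods a loader needs (`__getitem__` and `__len__`) and a bare-bones sequential batcher. `PairDataset` and `iter_batches` are hypothetical stand-ins written in plain Python; they are not PyTorch classes and omit shuffling, collation into tensors, and worker processes.

```python
class PairDataset:
    """Minimal map-style dataset: indexable and sized."""

    def __init__(self, data, labels):
        if len(data) != len(labels):
            raise ValueError("data and labels must have the same length")
        self.data = data
        self.labels = labels

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)


def iter_batches(dataset, batch_size):
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


toy = PairDataset(list(range(200)), list(range(200)))
batches = list(iter_batches(toy, 10))
print(len(batches))  # 20
```

With 200 items and a batch size of 10 this gives 20 batches, matching the `train_loader` count printed above.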


In [None]:
def to_one_hot(tensor,device,nClasses=20):
    n,h,w = tensor.size()
    one_hot = torch.zeros(n,nClasses,h,w).to(device).scatter_(1,tensor.view(n,1,h,w),1)
    return one_hot

class mIoULoss(nn.Module):
    '''
    Soft (differentiable) mean IoU between softmax probabilities and one-hot targets.
    Note: despite the name, forward() returns the mean IoU itself (higher is better),
    not 1 - IoU; it is used below as a metric.
    The weight and size_average arguments are accepted but unused.
    '''
    def __init__(self, weight=None, size_average=True, n_classes=20):
        super(mIoULoss, self).__init__()
        self.classes = n_classes

    def forward(self, inputs, target_oneHot):
        # inputs => N x Classes x H x W
        # target_oneHot => N x Classes x H x W

        N = inputs.size()[0]

        # predicted probabilities for each pixel along channel
        inputs = F.softmax(inputs,dim=1)

        # Numerator: soft intersection
        inter = inputs * target_oneHot
        ## Sum over all pixels N x C x H x W => N x C
        inter = inter.view(N,self.classes,-1).sum(2)

        # Denominator: soft union
        union = inputs + target_oneHot - (inputs*target_oneHot)
        ## Sum over all pixels N x C x H x W => N x C
        union = union.view(N,self.classes,-1).sum(2)

        iou = inter/union

        ## Return average IoU over classes and batch
        return iou.mean()
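
To make the soft IoU computed by `mIoULoss` concrete, here is a small worked example in plain Python. It assumes the per-pixel probabilities have already been passed through a softmax, and `soft_miou` is a hypothetical helper that mirrors the intersection and union formulas above for a single image.

```python
def soft_miou(probs_per_pixel, target_classes, n_classes):
    """Mean soft IoU over classes.

    probs_per_pixel: list of per-pixel probability lists (one entry per class)
    target_classes:  list of integer class labels, one per pixel
    Assumes every class has nonzero probability mass or a target pixel,
    so the union is never zero.
    """
    ious = []
    for k in range(n_classes):
        inter = 0.0
        union = 0.0
        for p, t in zip(probs_per_pixel, target_classes):
            y = 1.0 if t == k else 0.0      # one-hot target for class k
            inter += p[k] * y               # soft intersection
            union += p[k] + y - p[k] * y    # soft union
        ious.append(inter / union)
    return sum(ious) / n_classes


# One pixel, two classes: prediction (0.8, 0.2), true class 0.
# Class 0: inter = 0.8, union = 0.8 + 1 - 0.8 = 1.0 -> IoU 0.8
# Class 1: inter = 0.0, union = 0.2 + 0 - 0   = 0.2 -> IoU 0.0
print(f"{soft_miou([[0.8, 0.2]], [0], 2):.3f}")  # 0.400
```

Note that a class absent from the target contributes an IoU of zero whenever the model assigns it any probability, which pulls the mean down when many of the 20 classes do not appear in a batch.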

In [None]:
def train(model, device, train_loader, optimizer, val_loader=None):
    '''
    Function for training our networks. One call to train() performs a single
    epoch for training.
    model: an instance of our model
    device: either "cpu" or "cuda", depending on if you're running with GPU support
    train_loader: the dataloader for the training set
    optimizer: optimizer used for training (the optimizer implements SGD)
    val_loader: (optional) validation set to include 
    '''

    # Set the model to training mode.
    model.train()

    # We'll keep adding the loss of each batch to total_loss, so we can calculate
    # the average loss at the end of the epoch.
    total_loss = 0.0
    mIoU       = 0.0

    # The loss and metric modules are stateless, so create them once per epoch.
    loss_function = nn.CrossEntropyLoss()
    mIoU_function = mIoULoss()

    # We'll iterate through each batch. One call of train() trains for 1 epoch.
    # batch_idx: an integer representing which batch number we're on
    # input: a pytorch tensor representing a batch of input images.
    # target: a pytorch tensor of per-pixel class labels for that batch.
    for batch_idx, (input,target) in enumerate(train_loader):
        # These lines send data to the GPU if you're using a GPU
        input = input.to(device, dtype=torch.float)
        target = target.type(torch.LongTensor).to(device)

        # Clear the gradients left over from the previous batch
        optimizer.zero_grad()

        # feed our input through the network
        output = model(input)
        loss_value = loss_function(output,target)

        # Perform backprop
        loss_value.backward()
        optimizer.step()

        # Accumulate the loss as a Python float so that the autograd graph
        # of each batch can be freed
        total_loss += loss_value.item()

        # Calculate the mIoU (metric only, so no gradients are needed)
        with torch.no_grad():
            target_onehot = to_one_hot(target, device)
            mIoU += mIoU_function(output, target_onehot).item()

    return (total_loss/len(train_loader), mIoU/len(train_loader))

def test(model, device, val_loader):
    '''
    Function for testing our models. One call to test() runs through every
    datapoint in our dataset once.
    model: an instance of our model
    device: either "cpu" or "cuda:0", depending on if you're running with GPU support
    val_loader: the dataloader for the data to run the model on
    '''
    # set model to evaluation mode
    model.eval()

    # we'll keep track of total loss to calculate the average later
    test_loss = 0
    mIoU      = 0

    #don't perform backprop if testing
    with torch.no_grad():
        # iterate through each batch of test images
        for (input,target) in val_loader:

            # send input image to GPU if using GPU
            input = input.to(device, dtype=torch.float)
            target = target.type(torch.LongTensor).to(device)

            # run input through our model
            output = model(input)

            loss_function = nn.CrossEntropyLoss() 
            loss_value = loss_function(output,target) 
            test_loss += loss_value

            #Calculate the mIoU
            mIoU_function = mIoULoss()
            target_onehot = to_one_hot(target, device)
            mIoU += mIoU_function(output, target_onehot)

    return (test_loss.item()/len(val_loader), mIoU.item()/len(val_loader))
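
Since one call to `train()` covers a single epoch and `test()` covers one pass over the validation set, some outer loop has to call them repeatedly. The sketch below shows one possible shape for that loop; it is an assumption about usage, not code from the repository. `run_epochs` is a hypothetical helper that takes the two steps as zero-argument callables, and fake callables stand in for the real functions so that it runs without a model.

```python
def run_epochs(train_epoch, eval_epoch, n_epochs):
    """Alternate training and evaluation, tracking the best validation mIoU.

    train_epoch / eval_epoch: zero-argument callables returning (loss, mIoU).
    In the notebook these could be, for example,
    lambda: train(model, device, train_loader, optimizer) and
    lambda: test(model, device, val_loader).
    """
    history = []
    best_epoch, best_miou = None, float("-inf")
    for epoch in range(1, n_epochs + 1):
        train_loss, train_miou = train_epoch()
        val_loss, val_miou = eval_epoch()
        history.append((epoch, train_loss, train_miou, val_loss, val_miou))
        if val_miou > best_miou:
            best_epoch, best_miou = epoch, val_miou
    return history, best_epoch


# Stand-ins with made-up numbers, purely to exercise the loop.
fake_val = iter([(1.0, 0.2), (0.8, 0.5), (0.9, 0.4)])
history, best_epoch = run_epochs(lambda: (1.0, 0.1), lambda: next(fake_val), 3)
print(best_epoch)  # 2
```

Tracking the best epoch by validation mIoU mirrors the `best_iou` bookkeeping in the TensorFlow pipeline earlier in the notebook.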

In [None]:
#####################  FANet Model  ######################################
class BatchNorm2d(nn.BatchNorm2d):
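    '''BatchNorm2d subclass with an optional activation.
    Note: forward() below applies only the activation; it does not call
    nn.BatchNorm2d.forward, so no normalization is performed.'''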
    def __init__(self, num_features, activation='none'):
        super(BatchNorm2d, self).__init__(num_features=num_features)
        if activation == 'leaky_relu':
            self.activation = nn.LeakyReLU()
        elif activation == 'none':
            self.activation = lambda x:x
        else:
            raise Exception("Accepted activation: ['leaky_relu', 'none']")

    def forward(self, x):
        return self.activation(x)

up_kwargs = {'mode': 'bilinear', 'align_corners': True}


class FANet(nn.Module):
    def __init__(self,
                 nclass=20,
                 backbone='resnet18',
                 norm_layer=BatchNorm2d):
        super(FANet, self).__init__()

        self.norm_layer = norm_layer
        self._up_kwargs = up_kwargs
        self.nclass = nclass
        self.backbone = backbone
        if backbone == 'resnet18':
            self.expansion = 1
            self.resnet = Resnet18(norm_layer=norm_layer)
        elif backbone == 'resnet34':
            self.expansion = 1
            self.resnet = Resnet34(norm_layer=norm_layer)
        elif backbone == 'resnet50':
            self.expansion = 4
            self.resnet = Resnet50(norm_layer=norm_layer)
        elif backbone == 'resnet101':
            self.expansion = 4
            self.resnet = Resnet101(norm_layer=norm_layer)
        elif backbone == 'resnet152':
            self.expansion = 4
            self.resnet = Resnet152(norm_layer=norm_layer)
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))

        self.fam_32 = FastAttModule(512*self.expansion,256,128,norm_layer=norm_layer)
        self.fam_16 = FastAttModule(256*self.expansion,256,128,norm_layer=norm_layer)
        self.fam_8 = FastAttModule(128*self.expansion,256,128,norm_layer=norm_layer)
        self.fam_4 = FastAttModule(64*self.expansion,256,128,norm_layer=norm_layer)

        self.clslayer  = FPNOutput(256, 256, nclass,norm_layer=norm_layer)

    def forward(self, x, lbl=None):

        _, _, h, w = x.size()

        feat4, feat8, feat16, feat32 = self.resnet(x)

        upfeat_32, smfeat_32 = self.fam_32(feat32,None,True,True)
        upfeat_16, smfeat_16 = self.fam_16(feat16,upfeat_32,True,True)
        upfeat_8 = self.fam_8(feat8,upfeat_16,True,False)
        smfeat_4 = self.fam_4(feat4,upfeat_8,False,True)

        x = self._upsample_cat(smfeat_16, smfeat_4)

        outputs = self.clslayer(x)
        
        return outputs

    def _upsample_cat(self, x1, x2):
        '''Upsample and concatenate feature maps.
        '''
        _,_,H,W = x2.size()
        x1 = F.interpolate(x1, (H,W), **self._up_kwargs)
        x = torch.cat([x1,x2],dim=1)
        return x


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, norm_layer=None, activation='leaky_relu',*args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                out_chan,
                kernel_size = ks,
                stride = stride,
                padding = padding,
                bias = False)
        self.norm_layer = norm_layer
        if self.norm_layer is not None:
            self.bn = norm_layer(out_chan, activation=activation)
        else:
            self.bn =  lambda x:x

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class FPNOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, norm_layer=None, *args, **kwargs):
        super(FPNOutput, self).__init__()
        self.norm_layer = norm_layer
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1, norm_layer=norm_layer)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x


class FastAttModule(nn.Module):
    def __init__(self, in_chan, mid_chn=256, out_chan=128, norm_layer=None, *args, **kwargs):
        super(FastAttModule, self).__init__()
        self.norm_layer = norm_layer
        self._up_kwargs = up_kwargs
        mid_chn = int(in_chan/2)        
        self.w_qs = ConvBNReLU(in_chan, 32, ks=1, stride=1, padding=0, norm_layer=norm_layer, activation='none')

        self.w_ks = ConvBNReLU(in_chan, 32, ks=1, stride=1, padding=0, norm_layer=norm_layer, activation='none')

        self.w_vs = ConvBNReLU(in_chan, in_chan, ks=1, stride=1, padding=0, norm_layer=norm_layer)

        self.latlayer3 = ConvBNReLU(in_chan, in_chan, ks=1, stride=1, padding=0, norm_layer=norm_layer)

        self.up = ConvBNReLU(in_chan, mid_chn, ks=1, stride=1, padding=1, norm_layer=norm_layer)
        self.smooth = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1, norm_layer=norm_layer)

    def forward(self, feat, up_fea_in,up_flag, smf_flag):

        query = self.w_qs(feat)
        key   = self.w_ks(feat)
        value = self.w_vs(feat)

        N,C,H,W = feat.size()

        query_ = query.view(N,32,-1).permute(0, 2, 1)
        query = F.normalize(query_, p=2, dim=2, eps=1e-12)

        key_   = key.view(N,32,-1)
        key   = F.normalize(key_, p=2, dim=1, eps=1e-12)

        value = value.view(N,C,-1).permute(0, 2, 1)

        f = torch.matmul(key, value)
        y = torch.matmul(query, f)
        y = y.permute(0, 2, 1).contiguous()

        y = y.view(N, C, H, W)
        W_y = self.latlayer3(y)
        p_feat = W_y + feat

        # Nothing is returned when both flags are False.
        if not (up_flag or smf_flag):
            return None

        # Merge the feature map passed in from the previous stage, if any.
        if up_fea_in is not None:
            p_feat = self._upsample_add(up_fea_in, p_feat)

        if up_flag and smf_flag:
            up_feat = self.up(p_feat)
            smooth_feat = self.smooth(p_feat)
            return up_feat, smooth_feat
        if up_flag:
            return self.up(p_feat)
        return self.smooth(p_feat)

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.
        '''
        _,_,H,W = y.size()
        return F.interpolate(x, (H,W), **self._up_kwargs) + y
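
The `forward` method above multiplies `key` by `value` before applying `query`, so it never builds an attention matrix of size (H*W) x (H*W). The dependency-free sketch below illustrates why that ordering is cheaper: it checks on a tiny integer example that both orderings give the same result, then counts multiply-accumulate operations for each. The helper names (`matmul`, `transpose`, `attention_macs`), the 32 x 32 feature map, and the 512 value channels are assumptions chosen for illustration; only the 32 query/key channels come from the module above.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def attention_macs(n, c_qk, c_v):
    """Multiply-accumulate counts for the two evaluation orders.

    n is the number of spatial positions (H*W), c_qk the query/key
    channel count, and c_v the value channel count.
    """
    quadratic = n * n * c_qk + n * n * c_v    # (Q K^T) V
    linear = c_qk * n * c_v + n * c_qk * c_v  # Q (K^T V)
    return quadratic, linear


# Tiny integer example: n = 3 positions, c_qk = 2, c_v = 2.
Q = [[1, 2], [0, 1], [3, 1]]  # (n, c_qk)
K = [[1, 0], [2, 1], [0, 4]]  # (n, c_qk)
V = [[1, 1], [0, 2], [5, 0]]  # (n, c_v)

quadratic_order = matmul(matmul(Q, transpose(K)), V)  # builds an n x n matrix
linear_order = matmul(Q, matmul(transpose(K), V))     # builds only c_qk x c_v
same_result = quadratic_order == linear_order

# Assumed sizes: a 32 x 32 feature map with 512 value channels.
quad_macs, lin_macs = attention_macs(n=32 * 32, c_qk=32, c_v=512)
print(same_result, quad_macs, lin_macs)
```

The reordering is only valid because the module normalizes `query` and `key` with an L2 norm instead of applying a softmax to their product; a softmax over the (H*W) x (H*W) affinities would prevent regrouping the matrix products this way.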

In [None]:
model_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth'}

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                     padding=0, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_chan, out_chan, stride=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        self.norm_layer = norm_layer
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = norm_layer(out_chan, activation='leaky_relu')
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = norm_layer(out_chan, activation='none')
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(out_chan, activation='none'),
                )

    def forward(self, x):

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.bn2(out)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out_ = shortcut + out
        out_ = self.relu(out_)
        return out_

class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_chan, out_chan, stride=1, base_width=64, norm_layer=None):
        super(Bottleneck, self).__init__()
        width = int(out_chan * (base_width / 64.))
        self.norm_layer = norm_layer
        self.conv1 = conv1x1(in_chan, width)
        self.bn1 = norm_layer(width, activation='leaky_relu')
        self.conv2 = conv3x3(width, width, stride)
        self.bn2 = norm_layer(width, activation='leaky_relu')
        self.conv3 = conv1x1(width, out_chan * self.expansion)
        self.bn3 = norm_layer(out_chan * self.expansion, activation='none')
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan*self.expansion or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan*self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(out_chan*self.expansion, activation='none'),
                )

    def forward(self, x):

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.conv3(out)
        out = self.bn3(out)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out_ = shortcut + out
        out_ = self.relu(out_)

        return out_

class ResNet(nn.Module):
    def __init__(self, block, layers, strides, norm_layer=None):
        super(ResNet, self).__init__()
        self.norm_layer = norm_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(64, activation='leaky_relu')
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inplanes = 64
        self.layer1 = self.create_layer(block,   64, bnum=layers[0], stride=strides[0], norm_layer=norm_layer)
        self.layer2 = self.create_layer(block,  128, bnum=layers[1], stride=strides[1], norm_layer=norm_layer)
        self.layer3 = self.create_layer(block,  256, bnum=layers[2], stride=strides[2], norm_layer=norm_layer)
        self.layer4 = self.create_layer(block,  512, bnum=layers[3], stride=strides[3], norm_layer=norm_layer)

    def create_layer(self, block, out_chan, bnum, stride=1, norm_layer=None):
        # Only the first block of a stage applies the stride; the rest use stride 1.
        layers = [block(self.inplanes, out_chan, stride=stride, norm_layer=norm_layer)]
        self.inplanes = out_chan*block.expansion
        for _ in range(bnum-1):
            layers.append(block(self.inplanes, out_chan, stride=1, norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.maxpool(x)

        feat4 = self.layer1(x)
        feat8 = self.layer2(feat4) # 1/8
        feat16 = self.layer3(feat8) # 1/16
        feat32 = self.layer4(feat16) # 1/32
        return feat4, feat8, feat16, feat32

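    def init_weight(self, state_dict):
        # Copy pretrained tensors over this model's own state, skipping the
        # classifier ('fc') weights, which this backbone does not define.
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k:
                continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict, strict=True)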

def Resnet18(pretrained=True, norm_layer=None, **kwargs):
    model = ResNet(BasicBlock, layers=[2, 2, 2, 2], strides=[2, 2, 2, 2], norm_layer=norm_layer)
    if pretrained:
        model.init_weight(model_zoo.load_url(model_urls['resnet18']))
    return model
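
The spatial size of each feature map returned by `ResNet.forward` follows from the standard convolution output formula, floor((size + 2*padding - kernel) / stride) + 1, applied to `conv1`, `maxpool`, and the first block of each stage. The sketch below works through that arithmetic for a square input. The helper names (`conv_out`, `backbone_sizes`) and the 512-pixel input are assumptions made for illustration; the kernel sizes, paddings, and stride lists are read from the code above.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def backbone_sizes(size, strides):
    """Spatial size of each returned feature map for a square input."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    sizes = []
    for stride in strides:  # layer1 .. layer4
        # Only the first block of a stage is strided (3x3 conv, padding 1);
        # the remaining blocks leave the size unchanged.
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes.append(size)
    return sizes


# Stride list passed by Resnet18 as written in this notebook.
as_written = backbone_sizes(512, strides=[2, 2, 2, 2])
# Hypothetical variant with an unstrided first stage, for comparison.
unit_first_stage = backbone_sizes(512, strides=[1, 2, 2, 2])
print(as_written)
print(unit_first_stage)
```

With the stride list passed by `Resnet18` as written, the arithmetic gives maps at 1/8, 1/16, 1/32, and 1/64 of the input, one factor of two smaller than the names `feat4` through `feat32` suggest. An unstrided first stage would match the names. The notebook does not say which is intended, so treat this as an observation about the arithmetic, not a correction.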

In [None]:
# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a FANet instance with a ResNet-18 backbone
network = FANet(backbone='resnet18').to(device)

# Initialize the Adam optimizer with its default hyperparameters
optimizer = torch.optim.Adam(network.parameters())

epochs = 50

# Train the network, evaluating on the validation set after every epoch
for epoch in range(1, epochs+1):
    (train_loss, train_mIoU) = train(network, device, train_loader, optimizer)
    (val_loss, val_mIoU) = test(network, device, val_loader)
    print('Train Epoch: {:02d} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTraining mIoU: {:.6f} \tValidation mIoU: {:.6f}'.format(epoch, train_loss, val_loss, train_mIoU, val_mIoU))

# Test the CNN
#print('\nTest Loss:', test(network, device, test_loader))

Train Epoch: 01 	Training Loss: 5.193585 	Validation Loss: 1.815554 	Training mIoU: 0.024720 	Validation mIoU: 0.043057
Train Epoch: 02 	Training Loss: 1.658912 	Validation Loss: 1.495707 	Training mIoU: 0.054932 	Validation mIoU: 0.077109
Train Epoch: 03 	Training Loss: 1.303959 	Validation Loss: 1.134543 	Training mIoU: 0.083167 	Validation mIoU: 0.102064
Train Epoch: 04 	Training Loss: 1.147002 	Validation Loss: 1.114584 	Training mIoU: 0.104511 	Validation mIoU: 0.111412
Train Epoch: 05 	Training Loss: 1.081876 	Validation Loss: 1.129767 	Training mIoU: 0.110686 	Validation mIoU: 0.111982
Train Epoch: 06 	Training Loss: 1.034072 	Validation Loss: 1.026069 	Training mIoU: 0.118301 	Validation mIoU: 0.120688
Train Epoch: 07 	Training Loss: 1.006999 	Validation Loss: 1.020619 	Training mIoU: 0.124740 	Validation mIoU: 0.120670
Train Epoch: 08 	Training Loss: 0.986638 	Validation Loss: 0.939766 	Training mIoU: 0.126526 	Validation mIoU: 0.134797
Train Epoch: 09 	Training Loss: 0.937189