<a href="https://colab.research.google.com/github/goromal/FANet_Evaluation/blob/main/FANet_trainer_WITH_TIME.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Dayne Howard, David Elatov, Andrew Torgesen, MIT, 6.862, Spring 2021
# Parts of this code were based on code developed by 6.036 staff for student use
# The FANet model here was modified slightly to include time effects.
# Original FANet Paper: arXiv:2007.03815v2 [cs.CV] 9 Jul 2020

In [None]:
# Show the GPU assigned to this runtime; on a CPU-only runtime this prints an error instead.
! nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [None]:
import torch
# True when PyTorch can use a CUDA GPU; the saved output (False) comes from a CPU-only runtime.
torch.cuda.is_available()

False

### Repository

Mount Google Drive and change the working directory to the `TartanAir` data folder, which holds the `X.npy` and `Y.npy` files loaded below. **Run ONCE per computing session.**

In [None]:
import os
from google.colab import drive
import numpy as np
import h5py
drive.mount('/content/gdrive')
%cd gdrive/MyDrive
#Select the folder where the data is
%cd TartanAir

Mounted at /content/gdrive
/content/gdrive/MyDrive
/content/gdrive/.shortcut-targets-by-id/13SoHYEacjxfjCnzL_FlI8YHqm3zFzMva/TartanAir


In [None]:
! pip install oyaml
! pip install torchstat

Collecting oyaml
  Downloading https://files.pythonhosted.org/packages/37/aa/111610d8bf5b1bb7a295a048fc648cec346347a8b0be5881defd2d1b4a52/oyaml-1.0-py2.py3-none-any.whl
Installing collected packages: oyaml
Successfully installed oyaml-1.0
Collecting torchstat
  Downloading https://files.pythonhosted.org/packages/bc/fe/f483b907ca80c90f189cd892bb2ce7b2c256010b30314bbec4fc17d1b5f1/torchstat-0.0.7-py3-none-any.whl
Installing collected packages: torchstat
Successfully installed torchstat-0.0.7


In [None]:
from PIL import Image
import numpy as np
from matplotlib.image import imread
import os
from random import sample
# We'll use the PyTorch Framework
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchstat import stat
import torch.utils.model_zoo as model_zoo

import sys
sys.path.insert(0, '/content/gdrive/MyDrive/FANet_Evaluation/evaluation') # so that the evaluation pipeline's internal imports work
import oyaml as yaml
from PIL import Image # for image resizing

# For displaying images later
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Set a random seed for predictable behavior
torch.manual_seed(6036)

<torch._C.Generator at 0x7fd30eb23550>

In [None]:
# Argument class to instantiate a model
class ModelTrainArgs:
    '''Plain container for the model and training settings shared by later cells.'''

    def __init__(self):
        # ---- model / input ----
        self.img_width = 640        # input frame width in pixels
        self.img_height = 480       # input frame height in pixels
        self.num_channels = 3       # colour (RGB) input
        # Folder expected to hold pre-processed weights.npy, X_train.npy,
        # Y_train.npy, X_val.npy and Y_val.npy ("My Drive/full_cityscapes_res").
        # NOTE(review): the load cell reads X.npy / Y.npy from the working
        # directory instead, so this path looks like a Cityscapes leftover.
        self.data_dir = '/content/gdrive/MyDrive/full_cityscapes_res/'
        self.weighted_loss = True
        self.batch_size = 16
        self.learning_rate = 1e-4

        # ---- training ----
        self.num_classes = 11       # number of AirSim segmentation classes


# One shared instance; other cells read settings such as args.num_classes.
args = ModelTrainArgs()

In [None]:
# Load the whole dataset as numpy arrays: X is the network input, Y the target.
# Training, validation and test data all come from these two arrays; they are
# divided into windows and splits in the next cell.
#   N = number of frames, C = channels, H = height, W = width
#   X: (N, C, H, W)   -- e.g. (8688, 3, 480, 640)
#   Y: (N, H/8, W/8)  -- e.g. (8688, 60, 80), one class label per pixel
X = np.load('X.npy')
Y = np.load('Y.npy')

print(f'Loaded X with size {X.shape}')
print(f'Loaded Y with size {Y.shape}')

Loaded X with size (8688, 3, 480, 640)
Loaded Y with size (8688, 60, 80)


In [None]:
# Specify how many previous frames, T, you want to include
# in the temporal aggregation context.
# Change it in the FastAttModule to match self.T
T = 4

# Draw the split from a seeded generator so that it is identical in every
# session. The previous version used random.sample, which torch.manual_seed does
# not seed, so the split (and every reported metric) changed from run to run.
SPLIT_SEED = 6036
split_rng = np.random.default_rng(SPLIT_SEED)

num_data_points = X.shape[0]

# Anchor frames are spaced T+1 apart, so each window X[i-T:i+1] holds T+1
# frames and no two windows overlap.
# NOTE(review): if several trajectories are concatenated in X, a window may
# straddle a trajectory boundary — confirm.
anchors = split_rng.permutation(np.arange(T, num_data_points, T + 1))

# About 10% of the anchors each go to validation and to testing; the rest train.
n_holdout = (num_data_points // (T + 1)) // 10
val_anchors = anchors[:n_holdout]
test_anchors = anchors[n_holdout:2 * n_holdout]
train_anchors = anchors[2 * n_holdout:]

def _split_batches(batch_anchors, batch_size):
    '''Split anchors into len // batch_size near-equal batches (at least one).

    np.array_split spreads the remainder, so a batch holds between batch_size
    and 2*batch_size-1 anchors; a split smaller than batch_size becomes a
    single (smaller) batch instead of raising.
    '''
    return np.array_split(batch_anchors, max(1, len(batch_anchors) // batch_size))

# Batches are formed once here and reused in the same order every epoch.
train_ind = _split_batches(train_anchors, args.batch_size)
val_ind = _split_batches(val_anchors, args.batch_size)
test_ind = _split_batches(test_anchors, args.batch_size)

In [None]:
# Number of batches per split, one per line: validation, test, training.
print(len(val_ind), len(test_ind), len(train_ind), sep='\n')

13
13
108


In [None]:
# A function which creates a "one_hot" tensor from a tensor with class labels
# starting at 0
def to_one_hot(tensor, device, nClasses=None):
    '''Convert integer class labels to a float one-hot encoding.

    tensor: int64 tensor of shape (n, h, w) with labels in [0, nClasses).
    device: device on which the result is created ("cpu", "cuda", ...).
    nClasses: number of classes; None (the default) means args.num_classes.
    Returns: float32 tensor of shape (n, nClasses, h, w) on `device`, holding a
        1 in the channel of each pixel's label and 0 elsewhere.
    '''
    if nClasses is None:
        nClasses = args.num_classes
    n, h, w = tensor.size()
    # scatter_ needs the index on the result's device; reshape (unlike view)
    # also accepts non-contiguous label tensors.
    index = tensor.to(device).reshape(n, 1, h, w)
    # Allocate directly on `device` rather than on the CPU followed by a copy.
    return torch.zeros(n, nClasses, h, w, device=device).scatter_(1, index, 1)

# Mean Intersection Over Union
# This variant of mIoU gives even weight per pixel, not even weight per class
class mIoULoss(nn.Module):
    '''Pixel-weighted IoU *score* (higher is better) of hard predictions.

    Despite the name this is an evaluation metric, not a differentiable loss:
    predictions are arg-maxed, so no gradient flows through it. For each sample
    it computes sum(pred * target) / sum(pred + target - pred * target) over
    all classes and pixels. When target_oneHot is a true one-hot encoding this
    equals correct / (correct + 2 * wrong) pixels. The per-sample values are
    averaged with np.nanmean (0/0 entries are ignored).

    Args:
        weight, size_average: unused; kept for interface compatibility.
        n_classes: number of classes; None (the default) means args.num_classes.
    '''
    def __init__(self, weight=None, size_average=True, n_classes=None):
        super(mIoULoss, self).__init__()
        self.classes = args.num_classes if n_classes is None else n_classes

    def forward(self, inputs, target_oneHot):
        '''inputs: N x Classes x H x W logits; target_oneHot: N x Classes x H x W.

        Returns the batch-mean score as a numpy scalar.
        '''
        # Hard prediction as one-hot, N x Classes x H x W. Uses self.classes
        # (the original ignored it and always used args.num_classes).
        pred = torch.nn.functional.one_hot(inputs.argmax(1), num_classes=self.classes).permute(0, 3, 1, 2)

        overlap = pred * target_oneHot
        # Sum over classes and pixels: N x C x H x W => N
        inter = torch.sum(overlap, (1, 2, 3))
        union = torch.sum(pred + target_oneHot - overlap, (1, 2, 3))

        score = (inter / union).detach().cpu().numpy()
        # Average over the batch
        return np.nanmean(score)

In [None]:
def train(model, device, X, Y,indices, optimizer, T):
  '''
  Run one training epoch and return (mean loss, mean mIoU).

  model: the network to train; it is called on one (T+1)-frame window at a time
  device: either "cpu" or "cuda", depending on if you're running with GPU support
  X: all input frames as a numpy array (divided by 255 here; presumably raw
     0-255 pixel values of shape (N, C, H, W))
  Y: all target label maps as a numpy array (N, H', W') of integer class ids
  indices: list of numpy arrays of anchor frame indices; each array is one
     batch. Every anchor i must satisfy i >= T, since X[i-T:i+1] is used.
  optimizer: optimizer stepped once per batch
  T: how many previous frames are aggregated with each anchor frame

  Gradients of every window in a batch are summed before a single optimizer
  step. Loss and mIoU are averaged per batch, then over batches.
  '''
  model.train()
  optimizer.zero_grad()

  # Stateless helpers: build them once rather than once per window.
  loss_function = nn.CrossEntropyLoss()
  mIoU_function = mIoULoss()

  total_loss = 0
  total_mIoU = 0
  numBatches = len(indices)

  for batch in indices:
    batch_size = len(batch)
    loss_value = 0
    mIoU = 0

    for i in batch:
      # Normalize one window at a time; normalizing all of X at once would use
      # too much memory. The window holds the T previous frames plus frame i.
      input = torch.from_numpy(X[i-T:i+1] / 255).to(device, dtype=torch.float)
      target = torch.from_numpy(Y[i-T:i+1]).to(device, dtype=torch.long)

      # Call the module (not .forward) so any registered hooks run.
      output = model(input)

      # Only frames from index T on (i.e. the anchor frame) contribute.
      loss = loss_function(output[T:], target[T:])
      loss_value += loss.item()

      # The metric needs no autograd graph.
      with torch.no_grad():
        target_onehot = to_one_hot(target, device)
        mIoU += mIoU_function(output[T:], target_onehot[T:]).item()

      # Gradients accumulate across all windows of the batch.
      loss.backward()

    # After the whole batch, take one step and reset the gradients.
    optimizer.step()
    optimizer.zero_grad()

    total_loss += loss_value / batch_size
    total_mIoU += mIoU / batch_size

  total_loss /= numBatches
  total_mIoU /= numBatches
  return (total_loss, total_mIoU)

def test(model, device, T, X, Y, indices):
  '''
  Evaluate the model once on every window in `indices`; return (mean loss, mean mIoU).

  model: an instance of our model (put into eval mode here)
  device: either "cpu" or "cuda:0", depending on if you're running with GPU support
  T: how many previous frames are aggregated with each anchor frame
  X: all input frames as a numpy array (divided by 255 here)
  Y: all target label maps as a numpy array of integer class ids
  indices: list of numpy arrays of anchor frame indices; each array is one
     batch. Every anchor i must satisfy i >= T, since X[i-T:i+1] is used.

  Loss and mIoU are averaged per batch, then over batches. No gradients are computed.
  '''
  model.eval()

  # Stateless helpers: build them once rather than once per window.
  loss_function = nn.CrossEntropyLoss()
  mIoU_function = mIoULoss()

  total_loss = 0
  total_mIoU = 0
  numBatches = len(indices)

  with torch.no_grad():
    for batch in indices:
      batch_size = len(batch)
      loss_value = 0
      mIoU = 0

      for i in batch:
        # Normalize one window at a time to keep memory use low.
        input = torch.from_numpy(X[i-T:i+1] / 255).to(device, dtype=torch.float)
        target = torch.from_numpy(Y[i-T:i+1]).to(device, dtype=torch.long)

        # Call the module (not .forward) so any registered hooks run.
        output = model(input)

        # Score only frames from index T on (the anchor frame).
        loss_value += loss_function(output[T:], target[T:]).item()
        target_onehot = to_one_hot(target, device)
        mIoU += mIoU_function(output[T:], target_onehot[T:]).item()

      total_loss += loss_value / batch_size
      total_mIoU += mIoU / batch_size

  total_loss /= numBatches
  total_mIoU /= numBatches
  return (total_loss, total_mIoU)

In [None]:
#####################  FANet Model  ######################################
class BatchNorm2d(nn.BatchNorm2d):
    '''nn.BatchNorm2d followed by an optional activation.

    Args:
        num_features: number of channels C of the (N, C, H, W) input.
        activation: 'leaky_relu' applies nn.LeakyReLU() after normalization;
            'none' applies no activation.

    Raises:
        Exception: for any other activation string.
    '''
    def __init__(self, num_features, activation='none'):
        super(BatchNorm2d, self).__init__(num_features=num_features)
        if activation == 'leaky_relu':
            self.activation = nn.LeakyReLU()
        elif activation == 'none':
            # nn.Identity (unlike a lambda) keeps the module picklable; it has
            # no parameters, so state_dict keys are unchanged.
            self.activation = nn.Identity()
        else:
            raise Exception("Accepted activation: ['leaky_relu', 'none']")

    def forward(self, x):
        # The previous forward returned self.activation(x) and skipped the
        # normalization entirely, so the BN affine parameters and running
        # statistics (including pretrained ones loaded into the backbone)
        # were never applied.
        return self.activation(super(BatchNorm2d, self).forward(x))

# Interpolation settings shared by every up-sampling step (bilinear, corner-aligned);
# FANet and FastAttModule store this dict as self._up_kwargs and pass it to F.interpolate.
up_kwargs = {'mode': 'bilinear', 'align_corners': True}


class FANet(nn.Module):
    '''FANet segmentation network with temporal fast-attention modules.

    A ResNet backbone yields four feature maps (feat4, feat8, feat16, feat32).
    One FastAttModule per level refines them top-down; the feat16-level and
    feat4-level smooth outputs are concatenated and FPNOutput maps the result
    to `nclass` logits at feat4's spatial size.

    Args:
        nclass: number of output classes (defaults to args.num_classes).
        backbone: 'resnet18', 'resnet34', 'resnet50', 'resnet101' or
            'resnet152'; anything else raises RuntimeError.
            NOTE(review): only Resnet18 is defined in this notebook; the other
            choices need their builders defined elsewhere.
        norm_layer: normalization factory handed to every sub-module.
    '''
    def __init__(self,
                 nclass=args.num_classes,
                 backbone='resnet18',
                 norm_layer=BatchNorm2d):
        super(FANet, self).__init__()

        self.norm_layer = norm_layer
        self._up_kwargs = up_kwargs
        self.nclass = nclass
        self.backbone = backbone

        # backbone name -> (channel expansion, builder). The builders are
        # lambdas so a backbone class is only looked up when it is selected.
        choices = {
            'resnet18':  (1, lambda: Resnet18(norm_layer=norm_layer)),
            'resnet34':  (1, lambda: Resnet34(norm_layer=norm_layer)),
            'resnet50':  (4, lambda: Resnet50(norm_layer=norm_layer)),
            'resnet101': (4, lambda: Resnet101(norm_layer=norm_layer)),
            'resnet152': (4, lambda: Resnet152(norm_layer=norm_layer)),
        }
        if backbone not in choices:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.expansion, build_backbone = choices[backbone]
        self.resnet = build_backbone()

        scale = self.expansion
        self.fam_32 = FastAttModule(512 * scale, 256, 128, norm_layer=norm_layer)
        self.fam_16 = FastAttModule(256 * scale, 256, 128, norm_layer=norm_layer)
        self.fam_8  = FastAttModule(128 * scale, 256, 128, norm_layer=norm_layer)
        self.fam_4  = FastAttModule(64 * scale, 256, 128, norm_layer=norm_layer)

        # Input = two concatenated 128-channel smooth outputs.
        self.clslayer = FPNOutput(256, 256, nclass, norm_layer=norm_layer)

    def forward(self, x, lbl=None):
        '''Return per-pixel class logits for the (N, C, H, W) batch x; lbl is ignored.'''
        # The unpack also rejects inputs that are not 4-D.
        _, _, h, w = x.size()

        feat4, feat8, feat16, feat32 = self.resnet(x)

        # Top-down pass: each level receives the "up" output of the coarser one.
        up32, sm32 = self.fam_32(feat32, None, True, True)
        up16, sm16 = self.fam_16(feat16, up32, True, True)
        up8 = self.fam_8(feat8, up16, True, False)
        sm4 = self.fam_4(feat4, up8, False, True)

        return self.clslayer(self._upsample_cat(sm16, sm4))

    def _upsample_cat(self, x1, x2):
        '''Resize x1 to x2's spatial size and concatenate the two along channels.'''
        _, _, height, width = x2.size()
        resized = F.interpolate(x1, (height, width), **self._up_kwargs)
        return torch.cat([resized, x2], dim=1)


class ConvBNReLU(nn.Module):
    '''Bias-free Conv2d followed by norm_layer(out_chan, activation=activation).

    When norm_layer is None the second step is an identity, so the module is a
    plain convolution. Extra *args / **kwargs are accepted and ignored.
    '''
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, norm_layer=None,
                 activation='leaky_relu', *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride,
                              padding=padding, bias=False)
        self.norm_layer = norm_layer
        self.bn = (lambda x: x) if norm_layer is None else norm_layer(out_chan, activation=activation)

    def forward(self, x):
        return self.bn(self.conv(x))


class FPNOutput(nn.Module):
    '''Classifier head: a 3x3 ConvBNReLU, then a bias-free 1x1 conv to n_classes logits.

    Spatial size is preserved (3x3 with padding 1, then 1x1). Extra
    *args / **kwargs are accepted and ignored.
    '''
    def __init__(self, in_chan, mid_chan, n_classes, norm_layer=None, *args, **kwargs):
        super(FPNOutput, self).__init__()
        self.norm_layer = norm_layer
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1,
                               norm_layer=norm_layer)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)

    def forward(self, x):
        hidden = self.conv(x)
        return self.conv_out(hidden)


class FastAttModule(nn.Module):
    '''FANet fast-attention module extended with temporal history aggregation.

    The input batch is treated as a time-ordered sequence of frames. For each
    frame idx >= T, the key/value products of the T preceding frames
    (idx-T .. idx-1) are added to that frame's own key/value product before
    it is multiplied by the query. Frames idx < T use only their own product.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chn: IGNORED; it is overwritten with in_chan/2 (see __init__) and
            kept only for interface compatibility.
        out_chan: channels of the "smooth" output.
        norm_layer: normalization factory forwarded to ConvBNReLU.
        T: (keyword-only, default 4) number of previous frames aggregated.
            Must match the T used to slice the (T+1)-frame input windows.
    '''
    def __init__(self, in_chan, mid_chn=256, out_chan=128, norm_layer=None, *args, T=4, **kwargs):
        super(FastAttModule, self).__init__()
        self.norm_layer = norm_layer
        self._up_kwargs = up_kwargs
        # The "up" output is added to the next finer level, which FANet builds
        # with half as many channels, so mid_chn is forced to in_chan/2.
        mid_chn = int(in_chan/2)
        # 1x1 projections: 32-channel query/key (no activation), full-width value.
        self.w_qs = ConvBNReLU(in_chan, 32, ks=1, stride=1, padding=0, norm_layer=norm_layer, activation='none')

        self.w_ks = ConvBNReLU(in_chan, 32, ks=1, stride=1, padding=0, norm_layer=norm_layer, activation='none')

        self.w_vs = ConvBNReLU(in_chan, in_chan, ks=1, stride=1, padding=0, norm_layer=norm_layer)

        self.latlayer3 = ConvBNReLU(in_chan, in_chan, ks=1, stride=1, padding=0, norm_layer=norm_layer)

        # NOTE(review): a 1x1 conv with padding=1 grows the map by one pixel on
        # every side. This is kept unchanged (presumably inherited from the
        # upstream FANet code); the next level re-interpolates it in _upsample_add.
        self.up = ConvBNReLU(in_chan, mid_chn, ks=1, stride=1, padding=1, norm_layer=norm_layer)
        self.smooth = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1, norm_layer=norm_layer)

        # Modification to original FANet Github code: number of previous frames
        # aggregated. It must match the T declared at the top of this notebook.
        self.T = T

    def forward(self, feat, up_fea_in, up_flag, smf_flag):
        '''Apply temporal fast attention to feat and produce the requested outputs.

        feat: (N, C, H, W) features for N time-ordered frames, C == in_chan.
        up_fea_in: coarser-level feature resized and added to the result, or None.
        Returns (up_feat, smooth_feat) if both flags are set, up_feat if only
        up_flag is set, smooth_feat if only smf_flag is set, else None.
        '''
        query = self.w_qs(feat)
        key = self.w_ks(feat)
        value = self.w_vs(feat)

        N, C, H, W = feat.size()

        # query: (N, H*W, 32) and key: (N, 32, H*W), both L2-normalised over
        # the 32 projection channels at every spatial position.
        query = F.normalize(query.view(N, 32, -1).permute(0, 2, 1), p=2, dim=2, eps=1e-12)
        key = F.normalize(key.view(N, 32, -1), p=2, dim=1, eps=1e-12)
        # value: (N, H*W, C)
        value = value.view(N, C, -1).permute(0, 2, 1)

        # Key/value product per frame, (N, 32, C), computed before the query product.
        f = torch.matmul(key, value)
        # Modification to original FANet Github code: add the history frames.
        f = f + self._temporal_history(f)

        y = torch.matmul(query, f)  # (N, H*W, C)
        y = y.permute(0, 2, 1).contiguous().view(N, C, H, W)
        p_feat = self.latlayer3(y) + feat

        if up_fea_in is not None and (up_flag or smf_flag):
            p_feat = self._upsample_add(up_fea_in, p_feat)

        if up_flag and smf_flag:
            return self.up(p_feat), self.smooth(p_feat)
        if up_flag:
            return self.up(p_feat)
        if smf_flag:
            return self.smooth(p_feat)
        return None

    def _temporal_history(self, f):
        '''Return a tensor shaped like f holding, for each frame, its T predecessors' sum.

        Row idx (idx >= T) is f[idx-T] + ... + f[idx-1]; rows idx < T are zero
        because those frames lack a full history. The previous in-place loop
        (`f += key[t] @ value[t]`) broadcast every history term into all N rows.
        That was correct only for the last frame, and only when N == T + 1.
        '''
        history = torch.zeros_like(f)
        for idx in range(self.T, f.shape[0]):
            history[idx] = f[idx - self.T:idx].sum(dim=0)
        return history

    def _upsample_add(self, x, y):
        '''Upsample x to y's spatial size and add the two feature maps.'''
        _, _, H, W = y.size()
        return F.interpolate(x, (H, W), **self._up_kwargs) + y

In [None]:
# Download URL of the ResNet-18 checkpoint (a torchvision release, presumably
# ImageNet-pretrained) that Resnet18(pretrained=True) loads through model_zoo.
model_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth'}

def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 nn.Conv2d with padding 1 (keeps H and W at stride 1)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 nn.Conv2d without padding (a per-pixel channel mix)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False,
    )

class BasicBlock(nn.Module):
    '''ResNet basic residual block: two 3x3 convs plus an identity or projected shortcut.

    A 1x1 strided projection (conv + norm) replaces the identity shortcut
    whenever the channel count or the stride changes. norm_layer is called as
    norm_layer(C, activation=...).
    '''
    expansion = 1

    def __init__(self, in_chan, out_chan, stride=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        self.norm_layer = norm_layer
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = norm_layer(out_chan, activation='leaky_relu')
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = norm_layer(out_chan, activation='none')
        self.relu = nn.ReLU(inplace=True)

        needs_projection = (in_chan != out_chan) or (stride != 1)
        self.downsample = None
        if needs_projection:
            projection = nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False)
            self.downsample = nn.Sequential(projection, norm_layer(out_chan, activation='none'))

    def forward(self, x):
        residual = self.bn2(self.conv2(self.bn1(self.conv1(x))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)

class Bottleneck(nn.Module):
    '''ResNet bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand by `expansion`.

    The inner width is out_chan scaled by base_width / 64. A 1x1 strided
    projection shortcut is used whenever the channel count or stride changes.
    '''
    expansion = 4

    def __init__(self, in_chan, out_chan, stride=1, base_width=64, norm_layer=None):
        super(Bottleneck, self).__init__()
        width = int(out_chan * (base_width / 64.))
        expanded = out_chan * self.expansion
        self.norm_layer = norm_layer

        self.conv1 = conv1x1(in_chan, width)
        self.bn1 = norm_layer(width, activation='leaky_relu')
        self.conv2 = conv3x3(width, width, stride)
        self.bn2 = norm_layer(width, activation='leaky_relu')
        self.conv3 = conv1x1(width, expanded)
        self.bn3 = norm_layer(expanded, activation='none')
        self.relu = nn.ReLU(inplace=True)

        self.downsample = None
        if in_chan != expanded or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, expanded, kernel_size=1, stride=stride, bias=False),
                norm_layer(expanded, activation='none'),
            )

    def forward(self, x):
        residual = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            residual = norm(conv(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)

class ResNet(nn.Module):
    '''ResNet trunk that returns the outputs of all four residual stages.

    Args:
        block: residual block class with an `expansion` attribute and a
            (in_chan, out_chan, stride=..., norm_layer=...) constructor.
        layers: number of blocks in each of the four stages.
        strides: stride of the first block of each stage.
        norm_layer: normalization factory called as norm_layer(C, activation=...).

    With strides=[2, 2, 2, 2] the four maps are 1/8, 1/16, 1/32 and 1/64 of the
    input size, despite the feat4..feat32 names.
    '''
    def __init__(self, block, layers, strides, norm_layer=None):
        super(ResNet, self).__init__()
        self.norm_layer = norm_layer
        # Stem: 7x7/2 conv -> norm (activation='leaky_relu') -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(64, activation='leaky_relu')
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Channel count entering the next stage; updated by create_layer.
        self.inplanes = 64
        self.layer1 = self.create_layer(block, 64, bnum=layers[0], stride=strides[0], norm_layer=norm_layer)
        self.layer2 = self.create_layer(block, 128, bnum=layers[1], stride=strides[1], norm_layer=norm_layer)
        self.layer3 = self.create_layer(block, 256, bnum=layers[2], stride=strides[2], norm_layer=norm_layer)
        self.layer4 = self.create_layer(block, 512, bnum=layers[3], stride=strides[3], norm_layer=norm_layer)

    def create_layer(self, block, out_chan, bnum, stride=1, norm_layer=None):
        '''Build one stage of `bnum` blocks; only the first block is strided.'''
        head = block(self.inplanes, out_chan, stride=stride, norm_layer=norm_layer)
        self.inplanes = out_chan * block.expansion
        tail = [block(self.inplanes, out_chan, stride=1, norm_layer=norm_layer)
                for _ in range(bnum - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        x = self.maxpool(self.bn1(self.conv1(x)))
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_outputs.append(x)
        # (feat4, feat8, feat16, feat32)
        return tuple(stage_outputs)

    def init_weight(self, state_dict):
        '''Copy matching entries of `state_dict` into this model, skipping any key containing 'fc'.'''
        merged = self.state_dict()
        merged.update({name: tensor for name, tensor in state_dict.items() if 'fc' not in name})
        self.load_state_dict(merged, strict=True)

def Resnet18(pretrained=True, norm_layer=None, **kwargs):
    '''Build the ResNet-18 trunk used by FANet (two BasicBlocks per stage).

    Every stage starts with a stride-2 block. If `pretrained` is true, the
    checkpoint at model_urls['resnet18'] is downloaded and copied in via
    init_weight (its 'fc' entries are skipped). Extra **kwargs are ignored.
    '''
    net = ResNet(BasicBlock, [2, 2, 2, 2], [2, 2, 2, 2], norm_layer=norm_layer)
    if not pretrained:
        return net
    net.init_weight(model_zoo.load_url(model_urls['resnet18']))
    return net

In [None]:
def ten2img(my_tensor, filename, size=(1024, 512)):
  '''Call this function to save a segmentation image as a colour PNG.

  my_tensor: array of shape (labels, height, width); channel k marks the pixels
    of class k (labels start at 0). One colour per class is listed below, so at
    most 11 classes are supported.
  filename: output path without extension; '.png' is appended.
  size: (width, height) of the saved image. The default (1024, 512) is the
    previously hard-coded value. It has a 2:1 aspect ratio, so 4:3 label maps
    are stretched unless another size is passed.
  '''
  listofColors = ['black','red','lime','cyan','sandybrown','crimson',
                  'yellow', 'purple','maroon','fuchsia','olive']
  labels, h, w = my_tensor.shape
  my_image = np.zeros((h, w, 3))
  for i in range(labels):
    col = matplotlib.colors.to_rgb(listofColors[i])
    # Paint colour `col` wherever channel i is set.
    my_image += np.tensordot(my_tensor[i, :, :], col, 0)
  # NEAREST keeps every pixel an exact class colour. Pillow's default filter for
  # RGB resizing (BICUBIC since Pillow 7.0) blends neighbouring labels into
  # colours that belong to no class.
  Image.fromarray((my_image * 255).astype(np.uint8)).resize(size=size, resample=Image.NEAREST).save(filename + '.png', format='PNG')

def saveimg(network, X,Y,indices, filename):
  '''Save the predicted and ground-truth segmentation maps of one window as PNGs.

  Uses the first anchor i in `indices` and the global `T`. The window
  X[i-T:i+1] is passed through the network, and frame T (the anchor) is saved
  via ten2img as 'prediction_<filename>.png' and 'target_<filename>.png'.
  The network is left in eval mode.
  '''
  # Body was previously indented deeper than its docstring (an IndentationError).
  # Run on the device holding the network's parameters instead of relying on a
  # global `device` variable.
  device = next(network.parameters()).device
  network.eval()
  with torch.no_grad():
    i = indices[0]
    input = torch.from_numpy(X[i-T:i+1,:,:,:]/255).to(device, dtype=torch.float)
    target = torch.from_numpy(Y[i-T:i+1,:,:]).type(torch.LongTensor)

    # Logits of the anchor frame, (classes, H, W) -> one-hot of the argmax class.
    logits = network(input)[T].cpu()
    output_one_hot = np.asarray(torch.nn.functional.one_hot(logits.argmax(0), num_classes=args.num_classes).permute(2,0,1))

    target_one_hot = np.asarray(to_one_hot(target, "cpu")[T])

    ten2img(output_one_hot, 'prediction_' + filename)
    ten2img(target_one_hot, 'target_' + filename)

In [None]:
# Check if using CPU or GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# FANet on a ResNet-18 backbone, moved to the chosen device.
network = FANet(backbone='resnet18').to(device)

# Adam with its default hyper-parameters.
# NOTE(review): args.learning_rate (1e-4) is not passed here, so Adam runs at
# its default lr of 1e-3 — confirm which one the reported numbers should use.
optimizer = torch.optim.Adam(network.parameters())

epochs = 15
snapshot_epochs = {1, 5, 7, 10, 15}   # epochs after which example maps are saved
# NOTE(review): the file-name tag says "T3" while T is set to 4 in this notebook.
run_tag = 'T3_attemptA_mixedbatches'

# Train the CNN, validating after every epoch.
for epoch in range(1, epochs + 1):
    train_loss, train_mIoU = train(network, device, X, Y, train_ind, optimizer, T)
    val_loss, val_mIoU = test(network, device, T, X, Y, val_ind)
    print('Train Epoch: {:02d} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTraining mIoU: {:.6f} \tValidation mIoU: {:.6f}'.format(epoch, train_loss, val_loss, train_mIoU, val_mIoU))
    if epoch in snapshot_epochs:
        for split_name, split_batches in (('train', train_ind), ('val', val_ind)):
            saveimg(network, X, Y, split_batches[0], run_tag + '_' + split_name + '_epoch_' + str(epoch))

# Test the CNN on the held-out windows.
test_loss, test_mIoU = test(network, device, T, X, Y, test_ind)
print('Test Loss: {:.6f} \t Test mIoU: {:.6f}'.format(test_loss, test_mIoU))
saveimg(network, X, Y, test_ind[0], run_tag + '_test')