THIS is a jupyter notebook named inference.ipynb that \
a. loads at least one image/sample from the test set \ 
b. loads trained parameters from the best model you trained \
c. runs inference (i.e. applies the model) on one image from the test set \
d. displays the predicJons for this image

In [1]:
from src.data.CanopyDataset import CanopyDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch

In [3]:
%matplotlib inline

# * To get an overview of training set and to visualize all bands + label

# plot individual samples from train set
val_df = CanopyDataset(split='validation')
#train_df = CanopyDataset(split='train', transforms=None)

from ipywidgets import interact
@interact(train_idx=range(len(val_df)))
def plot_sample(train_idx=0):
    train_img, train_label = val_df[train_idx]
    if torch.is_tensor(train_img):
        train_img = train_img.numpy()
        train_img = np.transpose(train_img, (1, 2, 0))
    print(train_img.shape)

    f, axs = plt.subplots(2,6, figsize=(14,4), constrained_layout=True)
    axs = axs.flatten()

    for i in range(12):
        sel = np.zeros(12, dtype=bool)
        sel[i] = True
        axs[i].imshow(train_img.compress(sel, axis=2))
        axs[i].set_title(f"Band {i}, index {train_idx}")

    f, ax = plt.subplots(1,1, figsize=(3,3))
    img = ax.imshow(train_label)
    plt.colorbar(img)
    ax.set_title("Label image")

interactive(children=(Dropdown(description='train_idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,…

In [3]:
# # training model
# from torch.utils.data import DataLoader

# train_dataset = train_df = CanopyDataset(split='train')

# # TODO create a training data dataloader with the specifications above
# train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=1)    	# or num_workers = 2?
# for image, label in train_dl:
#   print(image.shape, label.shape)

In [46]:
### 
# * to be implemented in a seperate file

from torch.optim import SGD
from torch import nn

def setup_optimiser(model, learning_rate, weight_decay):
  return SGD(
    model.parameters(),
    learning_rate,
    weight_decay
  )

from tqdm.notebook import trange      # pretty progress bar

# criterion = nn.CrossEntropyLoss()   # ! need to change
criterion = nn.MSELoss()

def train_epoch(data_loader, model, optimiser, device):

  # set model to training mode. This is important because some layers behave differently during training and testing
  model.train(True)
  model.to(device)

  # stats
  loss_total = 0.0
  oa_total = 0.0

  # iterate over dataset
  pBar = trange(len(data_loader))
  for idx, (data, target) in enumerate(data_loader):
    # put data and target onto correct device
    data, target = data.to(device), target.to(device)
    # ! change to dataloader or dataset
    data = data.to(torch.float32)   # to match weights of model
    target = target.to(torch.float32) # to match data of model

    # reset gradients
    optimiser.zero_grad()

    # forward pass
    pred = model(data)

    # loss
    loss = criterion(pred, target)

    # backward pass
    loss.backward()

    # parameter update
    optimiser.step()

    # stats update
    loss_total += loss.item()
    # ! probably need to change
    acc = torch.mean(torch.abs(torch.sub(pred, target))).item()
    oa_total += acc

    # format progress bar
    pBar.set_description('Loss: {:.2f}, OA: {:.2f}'.format(
      loss_total/(idx+1),
      100 * oa_total/(idx+1)
    ))
    pBar.update(1)
  
  pBar.close()

  # normalise stats
  loss_total /= len(data_loader)
  oa_total /= len(data_loader)

  return model, loss_total, oa_total

In [5]:
def validate_epoch(data_loader, model, device):       # note: no optimiser needed

  # set model to evaluation mode
  model.train(False)
  model.to(device)

  # stats
  loss_total = 0.0
  oa_total = 0.0

  # iterate over dataset
  pBar = trange(len(data_loader))
  for idx, (data, target) in enumerate(data_loader):
    with torch.no_grad():

      #TODO: likewise, implement the validation routine. This is very similar, but not identical, to the training steps.

      # put data and target onto correct device
      data, target = data.to(device), target.to(device)
      # ! change to dataloader or dataset
      data = data.to(torch.float32)   # to match weights of model
      target = target.to(torch.float32) # to match data of model

      # forward pass
      pred = model(data)

      # loss
      loss = criterion(pred, target)

      # stats update
      loss_total += loss.item()
      acc = torch.mean(torch.abs(torch.sub(pred, target))).item()
      oa_total += acc

      # format progress bar
      pBar.set_description('Loss: {:.2f}, OA: {:.2f}'.format(
        loss_total/(idx+1),
        100 * oa_total/(idx+1)
      ))
      pBar.update(1)

  pBar.close()

  # normalise stats
  loss_total /= len(data_loader)
  oa_total /= len(data_loader)

  return loss_total, oa_total

In [42]:
#
# ! --------------------------
# model to test, copy paste back when working

import torch
from torch import nn

def residual(in_chan, out_channel):
    residual = nn.Sequential(
            nn.BatchNorm2d(num_features=in_chan),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_chan, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
        )
    return residual

# def residual_maxpool(in_chan, out_channel):
#     res_max = nn.Sequential(
#         residual(in_chan, out_channel),
#         nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
#     )

#     return res_max

def residual_decode(in_chan, out_channel):
    residual = nn.Sequential(
            # nn.MaxUnpool2d(kernel_size=2, stride=1),
            nn.BatchNorm2d(num_features=in_chan),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_chan, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
        )
    return residual

class BasicBlock(nn.Module):

    def __init__(self, in_channel, out_channel, resblock):
        super(BasicBlock, self).__init__()

        # first block changes channel size, then keeps the same for the 2 other
        self.sub1 = resblock(in_channel, out_channel)
        self.sub23 = resblock(out_channel, out_channel)
        
        if in_channel < out_channel:
            self.skip = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)
        elif in_channel > out_channel:
            self.skip = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=1, stride=1)
        else:
            raise ValueError("Basis block: in_channel and out_channel should not be equal")

    def forward(self, x):
        out = self.sub1(x)  # first block changes channel size
        out = self.sub23(out)
        out = self.sub23(out)

        out = torch.add(out, self.skip(x))

        return out


class SIDE(nn.Module):

    def __init__(self):
        super().__init__()  # super(self, SIDE).__init__() for backward compatiility

        self.residualAdapt = BasicBlock(12, 64, residual)
        # seperate as we need the result for the forward pass
        
        self.residual1 = BasicBlock(64, 128, residual)
        
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # ! skip other ones, decode from 8 by 8


        # * Upsampling

        # unpool needs additionnal arg (indices), so seperate from residual block
        # nn.Sequential doesn't allow for additional params
        # https://stackoverflow.com/questions/59912850/autoencoder-maxunpool2d-missing-indices-argument
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.up1 = BasicBlock(128, 64, residual_decode)
        # maxunpool, 
        #self.unpool2 = nn.MaxUnpool2d(kernel_size=2, stride=1)
        # need element wise sum in forward before last residual block
        self.final = BasicBlock(64, 1, residual_decode)
    

    def forward(self, x):
        x = self.residualAdapt(x)     # get from 12 to 64 channels
        #print("shape after resadapt : " + str(x.shape))
        out, ind1 = self.maxpool(x)    # first maxpool to 16x16x128
        #print("shape after first_maxpool : " + str(x1.shape))
        out = self.residual1(out)     
        #print("shape after res1 : " + str(x1.shape))   
        out, ind2 = self.maxpool(out)    # second maxpool to 8x8x256
        #print("shape after maxpool 2 : " + str(x1.shape)) 
        out = self.unpool(out, ind2)    # ind 2 as they are the last ones
        #print("shape after unpool 1 : " + str(x1.shape))
        out = self.up1(out)
        #print("shape after up1 : " + str(out.shape))
        out = self.unpool(out, ind1)
        print("shape after unpool2 : " + str(out.shape))

        out = torch.add(x, out)
        print("shape out after add : " + str(out.shape))
        out = self.final(out)
        print("shape out before squeeze : " + str(out.shape))
        # need to get 32x32 tensor to compare to label
        return out.squeeze()

In [7]:
from torch.utils.data import DataLoader

# we also create a function for the data loader here (see Section 2.6 in Exercise 6)
def load_dataloader(batch_size, dataset, split='train'):
  return DataLoader(
      dataset,
      batch_size=batch_size,
      shuffle=(split=='train'),       # we shuffle the image order for the training dataset
      num_workers=2                   # perform data loading with two CPU threads
  )

In [8]:
#
# ! ------------------ model saving/loading

import glob
import os
#from src.models.SIDE_code_decode import SIDE

os.makedirs('cnn_states/SIDE', exist_ok=True)

def load_model(epoch='latest'):
  model = SIDE()
  modelStates = glob.glob('cnn_states/SIDE/*.pth')
  if len(modelStates) and (epoch == 'latest' or epoch > 0):
    modelStates = [int(m.replace('cnn_states/SIDE/','').replace('.pth', '')) for m in modelStates]
    if epoch == 'latest':
      epoch = max(modelStates)
    stateDict = torch.load(open(f'cnn_states/SIDE/{epoch}.pth', 'rb'), map_location='cpu')  # selects wieghts from epoch
    model.load_state_dict(stateDict)
  else:
    # fresh model
    epoch = 0       # no loaded weights
  return model, epoch


def save_model(model, epoch):
  torch.save(model.state_dict(), open(f'cnn_states/SIDE/{epoch}.pth', 'wb'))

In [47]:
#from src.models.SIDE_code_decode import SIDE

# define hyperparameters
device = 'cuda'
start_epoch = 0        # set to 0 to start from scratch again or to 'latest' to continue training from saved checkpoint
batch_size = 2
learning_rate = 0.0001
weight_decay = 0.001
num_epochs = 1

# * create all the needed variables
train_test_df = CanopyDataset(split='train')
val_test_df = CanopyDataset(split='validation')

# dataloader
dl_train_test = load_dataloader(batch_size, train_test_df)
dl_val_test = load_dataloader(batch_size, val_test_df)

# model
model_test = SIDE()

# optimizer

optim_test = setup_optimiser(model_test, learning_rate, weight_decay)



In [None]:
# only one step
model = model_test
data_loader = dl_train_test
optimiser = setup_optimiser(model, learning_rate, weight_decay)

model.train(True)
model.to(device)

# stats
loss_total = 0.0
oa_total = 0.0

num_im = 3

for idx, (data, target) in enumerate(data_loader):
    # put data and target onto correct device
    data, target = data.to(device), target.to(device)
    # ! change to dataloader or dataset
    data = data.to(torch.float32)   # to match weights of model
    target = target.to(torch.float32) # to match data of model

    # reset gradients
    optimiser.zero_grad()

    # forward pass
    pred = model(data)

    # loss
    loss = criterion(pred, target)
    #print(str(type(loss)) + '  :  ' + str(loss.dtype)+ '  :  ' + str(loss.shape))
    #print(loss)

    # backward pass
    loss.backward()

    # parameter update
    optimiser.step()

    # stats update
    loss_total += loss.item()
    # mean of absolute per pixel height differences from predicted height and GT
    acc = torch.mean(torch.abs(torch.sub(pred, target))).item()
    oa_total += acc
    print('OA : ' + str(acc))

    #to do only 1 to test
    if idx > num_im:
        break
    

# normalise stats
loss_total /= num_im
oa_total /= num_im
print('totals:')
print(loss_total)
print(oa_total)


shape after unpool2 : torch.Size([2, 64, 32, 32])
shape out after add : torch.Size([2, 64, 32, 32])
shape out before squeeze : torch.Size([2, 1, 32, 32])
OA : 5389.33447265625
shape after unpool2 : torch.Size([2, 64, 32, 32])
shape out after add : torch.Size([2, 64, 32, 32])
shape out before squeeze : torch.Size([2, 1, 32, 32])
OA : 1301163563548672.0
shape after unpool2 : torch.Size([2, 64, 32, 32])
shape out after add : torch.Size([2, 64, 32, 32])
shape out before squeeze : torch.Size([2, 1, 32, 32])
OA : nan
shape after unpool2 : torch.Size([2, 64, 32, 32])
shape out after add : torch.Size([2, 64, 32, 32])
shape out before squeeze : torch.Size([2, 1, 32, 32])
OA : nan
shape after unpool2 : torch.Size([2, 64, 32, 32])
shape out after add : torch.Size([2, 64, 32, 32])
shape out before squeeze : torch.Size([2, 1, 32, 32])
OA : nan
totals:
nan
nan


In [None]:
# do epochs
while start_epoch < num_epochs:

  # training
  model, loss_train, oa_train = train_epoch(dl_train_test, model_test, optim_test, device)

  # validation
  loss_val, oa_val = validate_epoch(dl_val_test, model, device)

  # print stats
  print('[Ep. {}/{}] Loss train: {:.2f}, val: {:.2f}; OA train: {:.2f}, val: {:.2f}'.format(
      start_epoch+1, num_epochs,
      loss_train, loss_val,
      100*oa_train, 100*oa_val
  ))

  # save model
  start_epoch += 1
  save_model(model, start_epoch)

In [41]:

model_plot = model

val_df = CanopyDataset(split='validation')
#train_df = CanopyDataset(split='train', transforms=None)

from ipywidgets import interact
@interact(idx_val=range(len(val_df)))
def plot_sample(idx_val=0):
    train_img, train_label = val_df[idx_val]
    train_img = train_img.to(torch.float32).to('cuda')
    

    #data = data.to(torch.float32)   # to match weights of model
    #target = target.to(torch.float32) # to match data of model

    model_plot.train(False)
    # as model expects batch number
    train_img = model_plot(train_img.unsqueeze(0))
    #model(image_valid.unsqueeze(0))

    f, ax = plt.subplots(1,2, figsize=(6,6))
    ax = ax.flatten()
    img = ax[0].imshow(train_img.cpu().detach())      # conversion to be able to plot
    plt.colorbar(img)
    ax[0].set_title("Train image")

    img = ax[1].imshow(train_label)
    plt.colorbar(img)
    ax[1].set_title("Train label")
    plt.tight_layout()


interactive(children=(Dropdown(description='idx_val', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1…