THIS is a jupyter notebook named inference.ipynb that \
a. loads at least one image/sample from the test set \ 
b. loads trained parameters from the best model you trained \
c. runs inference (i.e. applies the model) on one image from the test set \
d. displays the predicJons for this image

In [1]:
from src.data.CanopyDataset import CanopyDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch

In [2]:
%matplotlib inline

# * To get an overview of training set and to visualize all bands + label

# plot individual samples from train set
val_df = CanopyDataset(split='validation')
#train_df = CanopyDataset(split='train', transforms=None)

from ipywidgets import interact
@interact(train_idx=range(len(val_df)))
def plot_sample(train_idx=0):
    train_img, train_label = val_df[train_idx]

    # this is only to plot
    if torch.is_tensor(train_img):
        train_img = train_img.numpy()
        train_img = np.transpose(train_img, (1, 2, 0))
    print(train_img.shape)

    f, axs = plt.subplots(2,6, figsize=(14,4), constrained_layout=True)
    axs = axs.flatten()

    for i in range(12):
        sel = np.zeros(12, dtype=bool)
        sel[i] = True
        img = axs[i].imshow(train_img.compress(sel, axis=2))
        axs[i].set_title(f"Band {i}, index {train_idx}")
        # img = ax[0].imshow(train_img.cpu().detach())      # conversion to be able to plot
        plt.colorbar(img)

    f, ax = plt.subplots(1,1, figsize=(3,3))
    img = ax.imshow(train_label)
    plt.colorbar(img)
    ax.set_title("Label image")

    # f, ax = plt.subplots(1,1, figsize=(3,3))
    # sel = np.zeros(12, dtype=bool)
    # sel[1:4] = True
    # #print(train_img.compress(sel, axis=2))
    # img = ax.imshow(train_img.compress(sel, axis=2))
    # plt.colorbar(img)
    # ax.set_title("Test RGB")

interactive(children=(Dropdown(description='train_idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,…

In [3]:
# # training model
# from torch.utils.data import DataLoader

# train_dataset = train_df = CanopyDataset(split='train')

# # TODO create a training data dataloader with the specifications above
# train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=1)    	# or num_workers = 2?
# for image, label in train_dl:
#   print(image.shape, label.shape)

In [4]:
### 
# * to be implemented in a seperate file

from torch.optim import SGD
from torch import nn

def setup_optimiser(model, learning_rate, weight_decay):
  return SGD(
    model.parameters(),
    learning_rate,
    weight_decay
  )

from tqdm.notebook import trange      # pretty progress bar

# criterion = nn.CrossEntropyLoss()   # ! need to change
criterion = nn.MSELoss()

def train_epoch(data_loader, model, optimiser, device):

  # set model to training mode. This is important because some layers behave differently during training and testing
  model.train(True)
  model.to(device)

  # stats
  loss_total = 0.0
  oa_total = 0.0

  # iterate over dataset
  pBar = trange(len(data_loader))
  for idx, (data, target) in enumerate(data_loader):
    # put data and target onto correct device
    data, target = data.to(device), target.to(device)
    # ! change to dataloader or dataset
    data = data.to(torch.float32)   # to match weights of model
    target = target.to(torch.float32) # to match data of model

    # reset gradients
    optimiser.zero_grad()

    # forward pass
    pred = model(data)

    # loss
    target = target.squeeze()  # to match shape of pred
    loss = criterion(pred, target)

    # backward pass
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # parameter update
    optimiser.step()

    # stats update
    loss_total += loss.item()
    # ! probably need to change
    acc = torch.mean(torch.abs(torch.sub(pred, target))).item()
    oa_total += acc

    # format progress bar
    pBar.set_description('Loss: {:.2f}, OA: {:.2f}'.format(
      loss_total/(idx+1),
      oa_total/(idx+1)
    ))
    pBar.update(1)
  
  pBar.close()

  # normalise stats
  loss_total /= len(data_loader)
  oa_total /= len(data_loader)

  return model, loss_total, oa_total

In [5]:
def validate_epoch(data_loader, model, device):       # note: no optimiser needed

  # set model to evaluation mode
  model.train(False)
  model.to(device)

  # stats
  loss_total = 0.0
  oa_total = 0.0

  # iterate over dataset
  pBar = trange(len(data_loader))
  for idx, (data, target) in enumerate(data_loader):
    with torch.no_grad():

      #TODO: likewise, implement the validation routine. This is very similar, but not identical, to the training steps.

      # put data and target onto correct device
      data, target = data.to(device), target.to(device)
      # ! change to dataloader or dataset
      data = data.to(torch.float32)   # to match weights of model
      target = target.to(torch.float32) # to match data of model

      # forward pass
      pred = model(data)

      # loss
      target = target.squeeze()  # to match shape of pred
      loss = criterion(pred, target)

      # stats update
      loss_total += loss.item()
      acc = torch.mean(torch.abs(torch.sub(pred, target))).item()
      oa_total += acc

      # format progress bar
      pBar.set_description('Loss: {:.2f}, OA: {:.2f}'.format(
        loss_total/(idx+1),
        oa_total/(idx+1)
      ))
      pBar.update(1)

  pBar.close()

  # normalise stats
  loss_total /= len(data_loader)
  oa_total /= len(data_loader)

  return loss_total, oa_total

In [6]:
# get model from file and not from notebook implementation
from src.models.SIDE_code_decode import SIDE

In [7]:
from torch.utils.data import DataLoader

# we also create a function for the data loader here (see Section 2.6 in Exercise 6)
def load_dataloader(batch_size, dataset, split='train'):
  return DataLoader(
      dataset,
      batch_size=batch_size,
      shuffle=(split=='train'),       # we shuffle the image order for the training dataset
      num_workers=2                   # perform data loading with two CPU threads
  )

In [8]:
#
# ! ------------------ model saving/loading

import glob
import os
#from src.models.SIDE_code_decode import SIDE

os.makedirs('cnn_states/SIDE', exist_ok=True)

def load_model(epoch='latest'):
  model = SIDE()
  modelStates = glob.glob('cnn_states/SIDE/*.pth')
  if len(modelStates) and (epoch == 'latest' or epoch > 0):
    modelStates = [int(m.replace('cnn_states/SIDE/','').replace('.pth', '')) for m in modelStates]
    if epoch == 'latest':
      epoch = max(modelStates)
    stateDict = torch.load(open(f'cnn_states/SIDE/{epoch}.pth', 'rb'), map_location='cpu')  # selects wieghts from epoch
    model.load_state_dict(stateDict)
  else:
    # fresh model
    epoch = 0       # no loaded weights
  return model, epoch


def save_model(model, epoch):
  torch.save(model.state_dict(), open(f'cnn_states/SIDE/{epoch}.pth', 'wb'))

In [9]:
# define hyperparameters
device = 'cuda'
start_epoch = 0        # set to 0 to start from scratch again or to 'latest' to continue training from saved checkpoint
batch_size = 1
learning_rate = 0.001
weight_decay = 0.001
num_epochs = 10

# * create all the needed variables
train_test_df = CanopyDataset(split='train')
val_test_df = CanopyDataset(split='validation')

# dataloader
dl_train_test = load_dataloader(batch_size, train_test_df)
dl_val_test = load_dataloader(batch_size, val_test_df)

# model
model_test = SIDE()

# optimizer

optim_test = setup_optimiser(model_test, learning_rate, weight_decay)



In [10]:
model_plot = []
# do epochs
while start_epoch < num_epochs:

  # training
  model_plot, loss_train, oa_train = train_epoch(dl_train_test, model_test, optim_test, device)

  # validation
  loss_val, oa_val = validate_epoch(dl_val_test, model_plot, device)

  # print stats
  print('[Ep. {}/{}] Loss train: {:.2f}, val: {:.2f}; OA train: {:.2f}, val: {:.2f}'.format(
      start_epoch+1, num_epochs,
      loss_train, loss_val,
      oa_train, oa_val
  ))

  # save model
  start_epoch += 1
  save_model(model_plot, start_epoch)

  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 1/10] Loss train: 1462.20, val: 14834.87; OA train: 13.58, val: 62.85


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 2/10] Loss train: 1376.49, val: 13995.63; OA train: 11.55, val: 61.46


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 3/10] Loss train: 1373.59, val: 13944.91; OA train: 11.44, val: 60.37


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 4/10] Loss train: 1371.42, val: 14027.08; OA train: 11.37, val: 64.84


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 5/10] Loss train: 1370.63, val: 13626.12; OA train: 11.33, val: 67.73


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 6/10] Loss train: 1369.89, val: 13802.18; OA train: 11.29, val: 68.11


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 7/10] Loss train: 1369.06, val: 14056.69; OA train: 11.27, val: 68.86


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 8/10] Loss train: 1368.77, val: 14046.99; OA train: 11.24, val: 70.09


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 9/10] Loss train: 1368.43, val: 14437.69; OA train: 11.22, val: 71.04


  0%|          | 0/5910 [00:00<?, ?it/s]

  0%|          | 0/1971 [00:00<?, ?it/s]

[Ep. 10/10] Loss train: 1366.96, val: 14882.13; OA train: 11.21, val: 72.47


In [12]:
val_df = CanopyDataset(split='validation')
#train_df = CanopyDataset(split='train', transforms=None)

from ipywidgets import interact
@interact(idx_val=range(len(val_df)))
def plot_sample(idx_val=0):
    train_img, train_label = val_df[idx_val]
    train_img = train_img.to(torch.float32).to('cuda')
    

    #data = data.to(torch.float32)   # to match weights of model
    #target = target.to(torch.float32) # to match data of model

    model_plot.train(False)
    # as model expects batch number
    train_img = model_plot(train_img.unsqueeze(0))
    #model(image_valid.unsqueeze(0))

    f, ax = plt.subplots(1,2, figsize=(6,6))
    ax = ax.flatten()
    img = ax[0].imshow(train_img.cpu().detach())      # conversion to be able to plot
    plt.colorbar(img)
    ax[0].set_title("Train image")

    img = ax[1].imshow(train_label)
    plt.colorbar(img)
    ax[1].set_title("Train label")
    plt.tight_layout()


interactive(children=(Dropdown(description='idx_val', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1…