## IGARSS Code Workshop

We have a [labeled dataset of urban satellite scenes in Zurich](https://sites.google.com/site/michelevolpiresearch/data/zurich-dataset):
- 20 satellite scenes from QuickBird, 4-band (RGB + NIR) at 0.62m
- matching ground truth labels for 8 classes: Roads, Buildings, Trees, Grass, Bare Soil, Water, Railways and Swimming pools

Let's use `fastai` to build a segmentation algorithm from this data

## Zurich Summer Data

In [None]:
%matplotlib inline

In [None]:
# set up necessary dependencies

!curl -s https://course.fast.ai/setup/colab | bash
!pip install -q --upgrade wandb rasterio imgaug==0.2.5

In [None]:
# mount our google drive to access data
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
# base folder of the mounted drive; later cells join dataset paths such as
# 'zurich/images_tif' and 'zurich/groundtruth' onto it
root_dir = "/content/gdrive/My Drive/"

In [None]:
# optional login to Weights and Biases for metric tracking: https://www.wandb.com/
# !wandb login [APIKEY]

In [None]:
import os
import math
from functools import partial

from fastai.vision import *
import rasterio
import torch
from torch import nn
from torchvision.models import resnet34
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score
from PIL import Image as PImage

import wandb
from wandb.fastai import WandbCallback

### Inspection of our data and naive first attempts

In [None]:
# open an input image
input_image = os.path.join(root_dir, 'zurich/images_tif/zh1.tif')
try:
    # context manager so the file handle is released if pillow does manage to open it
    with PImage.open(input_image):
        pass
except FileNotFoundError:
    # FileNotFoundError is an OSError too: a missing file (e.g. drive not mounted)
    # is a setup problem and must not be reported as a file-format problem
    raise
except OSError as e:
    print('turns out pillow (and thus fast.ai) cannot open 16-bit tif files')
    # show what pillow actually complained about
    print(f'(pillow said: {e})')

In [None]:
# pillow failed, so read the file with rasterio instead; it also gives us the fourth band
# (in visualizations that band will appear as the alpha channel)
with rasterio.open(input_image) as src:
    img = src.read()

# quick look at the value range, dtype and shape of what we just read
print(img.min(), img.max(), img.dtype, img.shape)
    


In [None]:
# we can display this (awkwardly) with fastai
# scale by the largest pixel value so everything lands in [0, 1]
m = img.max()
scaled = torch.from_numpy(img.astype(np.float32))
scaled.div_(m)
Image(scaled)

In [None]:
# let's try the labels
PImage.open(os.path.join(root_dir, 'zurich/groundtruth/zh1_GT.tif'))

`fastai` provides nice methods for creating segmentation data bunches. Based on what we saw above, you can imagine that we will have to customize them a bit.

In [None]:
# lookup from a pixel's summed label colour (all channels added together) to its class index
classes =  {0:0, 125:1, 150:2, 230:3, 255:4, 300:5, 510:6, 555:7, 765:8}

# label list that opens the ground truth with rasterio
class SatelliteSegmentationLabelList(SegmentationLabelList):
    """Segmentation labels decoded from colour-coded ground truth rasters."""

    def open(self, fn):
        """Read the label raster `fn` and return it as an ImageSegment of class indices."""
        colour_to_class = np.vectorize(classes.get)
        with rasterio.open(fn) as dataset:
            colour_sum = dataset.read().sum(axis=0)  # collapse the channel axis
        # translate every summed colour through `classes`, then add a leading axis of length 1
        class_map = np.expand_dims(colour_to_class(colour_sum), axis=0)
        return ImageSegment(torch.from_numpy(class_map).float())
    

In [None]:
# item list that reads our four band imagery with rasterio
# the per-band divisors are rough constants chosen to normalize across bands
class SatelliteSegmentationItemList(SegmentationItemList):
    """Segmentation item list whose images are opened with rasterio."""
    _label_cls = SatelliteSegmentationLabelList

    def open(self, fn):
        """Read all bands of `fn` as float32, scale each band and wrap the result in an Image."""
        with rasterio.open(fn) as dataset:
            raw = dataset.read()
        bands = torch.from_numpy(raw.astype(np.float32))
        # one divisor per band, shaped (4, 1, 1) so it broadcasts over the two trailing axes
        band_scale = torch.tensor([500., 500., 700., 1000.]).reshape(4, 1, 1)
        bands.div_(band_scale)
        return Image(bands)

In [None]:
# try creating a databunch with our imagery + labels using our custom list + these methods
# data = (SatelliteSegmentationItemList    
#   .from_folder
#   .split_by_rand_pct
#   .label_from_func
#   .databunch
# )

In [None]:
# all of our images are different sizes so we can't make batches, let's fix that
# we have two options: read images in smaller, or use a transform to pull a random patch

# transform to provide random windowing into our large images
WINDOW_SIZE = (224, 224)
def _window_tfm(pxls, xrand:uniform=0.5, yrand:uniform=0.5):
    """Crop a WINDOW_SIZE patch out of `pxls` at a position chosen by xrand / yrand.

    pxls:  pixel tensor indexed as [channel, first spatial axis, second spatial axis]
    xrand: fraction in [0, 1] placing the window along the first spatial axis
           (0 = start of the image, 1 = flush with the far edge)
    yrand: same, along the second spatial axis

    Returns the cropped tensor. If an image side is shorter than the window the
    crop starts at 0 and simply covers that whole side.

    NOTE(review): fastai appears to use the `uniform` annotations to draw the
    random values, so no other annotations are added to this signature - confirm
    against the fastai transform docs before annotating `pxls` or the return value.
    """
    win_a, win_b = WINDOW_SIZE
    size_a, size_b = pxls.shape[-2:]
    # valid start offsets run from 0 to (size - window) inclusive; clamping at 0 keeps
    # the offset non-negative when the image is exactly (or less than) one window wide,
    # where a negative offset would otherwise slice out an empty tensor
    start_a = math.floor(xrand * max(size_a - win_a, 0))
    start_b = math.floor(yrand * max(size_b - win_b, 0))
    return pxls[:, start_a:start_a + win_a, start_b:start_b + win_b]

window_tfm = TfmPixel(_window_tfm, order=1)
tfm = window_tfm(xrand=(0, 1), yrand=(0, 1))
xtra_tfms=[tfm]

In [None]:
# add extra transforms if desired
# the random window crop is the only active entry; the commented-out entries
# below are optional zoom / rotation augmentations that can be enabled
tfm_list = [
    window_tfm(xrand=(0, 1), yrand=(0, 1)),
#     zoom(scale=(1, 1.2)),
#     rotate(degrees=(-30, 30))
]

In [None]:
# create a fastai DataBunch with our imagery + labels + transforms
# add .transform method to prior attempt
# data = (SatelliteSegmentationItemList    
#   .from_folder
#   .split_by_rand_pct
#   .label_from_func
#   .databunch
# )

In [None]:
data.show_batch(figsize=(8,8)) # nice(?)

In [None]:
# show the data structure
data

In [None]:
# unfortunately we still have very few items to iterate over, let's fake that we have more files
class SatelliteSegmentationItemList(SegmentationItemList):
    """Four-band item list that can repeat its items to simulate a larger dataset."""
    _label_cls = SatelliteSegmentationLabelList

    def open(self, fn):
        """Read all bands of `fn` with rasterio and return them as a band-scaled Image."""
        # one divisor per band, shaped (4, 1, 1) so it broadcasts over the two trailing axes
        band_scale = torch.tensor([500., 500., 700., 1000.]).view(4, 1, 1)
        with rasterio.open(fn) as dataset:
            bands = torch.from_numpy(dataset.read().astype(np.float32))
        return Image(bands.div_(band_scale))

    def duplicate_items(self, n):
        """Repeat every entry of `self.items` n times in place; returns self so calls can be chained."""
        self.items = np.repeat(self.items, n)
        return self

In [None]:
# now we are good to create our data (and add a bit of normalization at the end)
# NOTE(review): no split step is called between duplicate_items and label_from_func
# (see the placeholder comment inside the chain) - confirm the fastai data block API
# accepts this, or add a train/validation split there
data = (SatelliteSegmentationItemList
  .from_folder(os.path.join(root_dir, 'zurich/images_tif'))
  .duplicate_items(6)
  # we need a new way to split our data
  # each image is paired with the ground truth file named '<image stem>_GT<image suffix>'
  .label_from_func(lambda x: os.path.join(root_dir, f'zurich/groundtruth/{x.stem}_GT{x.suffix}'), classes=list(range(len(classes))))
  .transform((tfm_list, tfm_list), tfm_y=True)
  .databunch(bs=16)
  .normalize()
)

In [None]:
data

In [None]:
def IOU(input, target):
    """Fraction of labelled pixels whose predicted class equals the target class.

    input:  per-class scores with the class axis at dim 1
    target: class indices carrying a singleton axis at dim 1

    Pixels whose target is 0 are left out of the average. Despite the name, the
    value computed here is a masked pixel accuracy rather than an
    intersection-over-union.
    """
    labels = target.squeeze(1)
    predicted = input.argmax(dim=1)
    labelled = labels != 0  # ignore pixels with target class 0
    hits = predicted[labelled] == labels[labelled]
    return hits.float().mean()

In [None]:
# weight decay handed to the learner through its `wd` argument
WEIGHT_DECAY=1e-2
# optional Weights and Biases run (left disabled)
# wandb.init(project="igarss-zurich-test")
# U-Net learner built from our data bunch and resnet34, reporting the IOU metric defined above
learner = unet_learner(data, resnet34, metrics=[IOU], wd=WEIGHT_DECAY)

In [None]:
# let's try to train (just one epoch to start since it may not work)
learner.fit_one_cycle(1)

In [None]:
# the existing first convolution of the model (its weights are reused below)
unet_input_conv = learner.model[0][0]

# add a new input layer with a fourth channel and copy over the weights
new_input = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# copy under no_grad: autograd rejects in-place writes into a parameter that requires
# grad, and the copy should not be recorded in the graph anyway; the weight stays a
# regular trainable parameter afterwards
with torch.no_grad():
    # the first three input channels reuse the existing weights
    new_input.weight[:, :3] = unet_input_conv.weight[:, :3]
    # the fourth channel starts from a copy of the third channel's weights
    new_input.weight[:, 3] = unet_input_conv.weight[:, 2]

# swap the new input layer into the model and into the first layer group
learner.model[0][0] = new_input
learner.layer_groups[0][0] = learner.model[0][0]
# also replace the later layers so they accept the extra channel
# (100 channels - presumably one more than before because of the added band; confirm with learner.summary())
# NOTE(review): unlike the input layer, these replacement layers are not written into
# learner.layer_groups - confirm that the optimizer still picks them up
learner.model[10][0][0] = nn.Conv2d(100, 100, kernel_size=(3,3), stride=(1,1), padding=(1,1))
learner.model[10][1][0] = nn.Conv2d(100, 100, kernel_size=(3,3), stride=(1,1), padding=(1,1))
learner.model[11][0] = nn.Conv2d(100, len(classes), kernel_size=(1,1), stride=(1,1))
# the freshly created layers live on the CPU; move the whole model back to the GPU if there is one
if torch.cuda.is_available():
    learner.model.cuda()

In [None]:
learner.summary()

In [None]:
# train
learner.fit_one_cycle(30)

In [None]:
learner.show_results(rows=3)

In [None]:
# unfreeze the pretrained weights for fine-tuning
learner.unfreeze()

In [None]:
learner.fit_one_cycle(30)

In [None]:
learner.show_results(rows=3)

In [None]:
# get the predictions on our validation set
preds, y_true = learner.get_preds()
# index of the highest score along dim 1 (assumed to be the class axis - TODO confirm)
pred_class = preds.argmax(dim=1)

In [None]:
# flatten our tensors and use scikit-learn to create a confusion matrix
# reshape(-1) flattens whatever number of items and pixels we have, so nothing is hard-coded
flat_preds = pred_class.reshape(-1)
flat_truth = y_true.reshape(-1)
# scikit-learn expects (y_true, y_pred): rows are the true labels, columns the predictions
cm = confusion_matrix(flat_truth, flat_preds, labels=list(range(len(classes))))

In [None]:
# tick labels for the confusion matrix, one per class index
# NOTE(review): the order has to match the class indices produced by the `classes`
# lookup - confirm it (the IOU metric ignores index 0, while 'Background' is listed last here)
class_labels = ['Roads', 'Buildings', 'Trees', 'Grass', 'Bare Soil', 'Water', 'Railways', 'Swimming pools', 'Background']  

# slight modification from sklearn (not yet available for segmentation in fastai)
# normalise every row so it sums to 1; a class without any pixels has an all-zero row,
# which is left at 0 instead of being divided by zero (that would fill it with NaN)
row_totals = cm.sum(axis=1)[:, np.newaxis]
cm = np.divide(cm.astype('float'), row_totals,
               out=np.zeros(cm.shape, dtype='float'), where=row_totals != 0)
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
# We want to show all ticks...
ax.set(xticks=np.arange(cm.shape[1]),
       yticks=np.arange(cm.shape[0]),
       # ... and label them with the respective list entries
       xticklabels=class_labels, yticklabels=class_labels,
       title='Normalized Confusion Matrix',
       ylabel='True label',
       xlabel='Predicted label')

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
# the matrix is always normalised at this point, so always print two decimals
# (the sklearn example's `normalize` flag is not defined anywhere in this notebook)
fmt = '.2f'
# switch to white text on the darker half of the colour scale
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()