In [None]:
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import torch
from torchvision import transforms

In [None]:
from src.path import ProjPaths
from src.data.band3_binary_mask_data import Band3BinaryMaskDataset, RandomCropImgAndLabels, ToTensorImgAndLabels
from src.models.unet_ptl import UNet
from src.metrics import sample_logits_and_labels, logits_to_prediction, classification_cases, prediction_metrics, compute_true_false_classifications_for_sample_and_model
from src.visualization.visualize import show_image_and_true_false_classifications

In [None]:
# Per-image evaluation metrics of the unet_ptl_v5 model.
metrics_fname = 'unet_ptl_v5.csv'
metrics_fpath = ProjPaths.metrics_path / metrics_fname

In [None]:
model_metrics = pd.read_csv(metrics_fpath)

# Columns the analysis below relies on. Fail early with a clear message if the
# metrics file comes from a different version of the evaluation script.
REQUIRED_METRIC_COLS = {'building_cover', 'accuracy', 'jaccard', 'dice', 'true_pos', 'false_pos', 'n_pixels'}
missing_metric_cols = REQUIRED_METRIC_COLS - set(model_metrics.columns)
assert not missing_metric_cols, f'metrics file is missing columns: {sorted(missing_metric_cols)}'

model_metrics.head(3)

In [None]:
test_path = ProjPaths.interim_sn1_data_path / "test"

# RandomCropImgAndLabels presumably draws a new random crop whenever a sample is
# accessed, so every image shown below changes between runs unless the RNGs are seeded.
# Which RNG the transform uses is not visible here, so seed Python, NumPy and PyTorch.
# NOTE(review): confirm the RNG source in src.data.band3_binary_mask_data.
# Seeding makes a Restart & Run All reproducible. Repeated access to the same index
# within one run can still return different crops.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

CROP_SIZE = 384  # side length (pixels) of the crop fed to the model

test_dataset = Band3BinaryMaskDataset(test_path, transform=transforms.Compose([
    RandomCropImgAndLabels(CROP_SIZE),
    ToTensorImgAndLabels(),
]))

## EDA of building cover

In [None]:
# Distribution of the per-image building cover.
fig, ax = plt.subplots()
sns.histplot(data=model_metrics['building_cover'], ax=ax)
ax.set_title('Frequency of building land cover')
plt.show()

In [None]:
# Percentage (rounded to two decimals) of images whose building cover is exactly zero.
n_without_buildings = int(model_metrics['building_cover'].eq(0).sum())
images_without_buildings = np.round(n_without_buildings / len(model_metrics) * 100, 2)
print(f'{images_without_buildings} % of images do not have any building pixels')

## Find patterns in metrics

The less land in an image is covered by buildings, the higher the accuracy. The intuitive edge case is an image with almost no buildings to detect. There, even a naive model that always predicts "no building" already achieves high accuracy, because accuracy also rewards the many correctly classified background pixels.

In [None]:
# One point per image: pixel accuracy against the fraction of building pixels.
fig, ax = plt.subplots()
sns.scatterplot(data=model_metrics, x='building_cover', y='accuracy', ax=ax)
ax.set_title('Accuracy as a function of fractional building cover per image')
plt.show()

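This effect is much less visible for the Jaccard index (intersection over union, $J = TP / (TP + FP + FN)$). The Jaccard index ignores true negatives: it only compares predicted and actual building pixels, so a large number of easy background pixels does not inflate it.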

In [None]:
# One point per image: Jaccard index against the fraction of building pixels.
fig, ax = plt.subplots()
sns.scatterplot(data=model_metrics, x='building_cover', y='jaccard', ax=ax)
ax.set_title('Jaccard index as a function of fractional building cover per image')
plt.show()

An alternative to the Jaccard index is the Dice coefficient. For a single image the two are monotonically related, $D = 2J / (1 + J)$, so they rank images identically:

In [None]:
fig, ax = plt.subplots()
sns.scatterplot(x='jaccard', y='dice', data=model_metrics, ax=ax)

# When both metrics come from the same TP/FP/FN counts of one image, Dice is a
# monotonic function of Jaccard: D = 2J / (1 + J). Overlay this curve. Points off the
# curve indicate inconsistently computed metrics, e.g. smoothing terms or different
# conventions for tiles without buildings.
jaccard_grid = np.linspace(0, 1, 101)
ax.plot(jaccard_grid, 2 * jaccard_grid / (1 + jaccard_grid),
        color='tab:red', lw=1, label='D = 2J / (1 + J)')
ax.legend()
ax.set_title('Comparison of dice index and jaccard index')
plt.show()

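Inspect model bias: does the model generally predict more or fewer building pixels than there actually are? Compare the average fraction of pixels predicted as building with the average building cover.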

In [None]:
# Mean per-image fraction of pixels predicted as building (TP + FP), to be compared
# with the mean per-image building cover.
avg_predicted_positives = (
    model_metrics['true_pos']
    .add(model_metrics['false_pos'])
    .div(model_metrics['n_pixels'])
    .mean()
)
avg_building_cover = model_metrics['building_cover'].mean()

In [None]:
# Mean fraction of pixels the model predicts as building; compare with avg_building_cover.
avg_predicted_positives

In [None]:
# Mean per-image building cover (presumably computed from the ground-truth masks).
avg_building_cover

## Evaluate model for individual samples

Here we use our UNet PyTorch Lightning model.

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the checkpoint saved as 'best_model' (epoch 15, val_loss 0.09) of the
# unet_ptl_v5 run, switch it to inference mode and move it to DEVICE.
chkpt_path = (
    ProjPaths.model_path
    / 'unet'
    / 'unet_ptl_v5'
    / 'checkpoints'
    / 'best_model-unet-epoch=15-val_loss=0.09.ckpt'
)
model = UNet.load_from_checkpoint(chkpt_path).eval().to(DEVICE)

In [None]:
def compute_true_false_classifications(this_sample_id, dataset=None, net=None, device=None):
    """Fetch one sample and compute its true/false classification result.

    Parameters
    ----------
    this_sample_id : int
        Index into ``dataset``.
    dataset : optional
        Dataset to index. Defaults to the notebook-level ``test_dataset``.
    net : optional
        Model passed to ``compute_true_false_classifications_for_sample_and_model``.
        Defaults to the notebook-level ``model``.
    device : torch.device, optional
        Device passed alongside the model. Defaults to the notebook-level ``DEVICE``.

    Returns
    -------
    tuple
        ``(sample, classes_masked)``: the dataset item and whatever
        ``compute_true_false_classifications_for_sample_and_model`` returns for it.
    """
    # Resolve the notebook-level fallbacks at call time, not at definition time, so
    # re-running the dataset/model cells takes effect without redefining this function.
    if dataset is None:
        dataset = test_dataset
    if net is None:
        net = model
    if device is None:
        device = DEVICE

    sample = dataset[this_sample_id]
    classes_masked = compute_true_false_classifications_for_sample_and_model(sample, net, device)

    return sample, classes_masked

In [None]:
# Sample to inspect in detail (alternative example used previously: 132).
this_sample_id = 866

sample, classes_masked = compute_true_false_classifications(this_sample_id)

### Visualize

In [None]:
# Plot the selected sample together with its true/false classification result
# (rendering is done by src.visualization.visualize).
show_image_and_true_false_classifications(sample, classes_masked)

In [None]:
# Recover the binary prediction: pixels in class 1 or 3 are the ones the model
# predicted as building (true_pos and false_pos; confirm the code mapping in src.metrics).
class_codes = classes_masked.data
predicted_building = (class_codes == 1) | (class_codes == 3)
pred = predicted_building * 1

In [None]:
# Show the recovered binary prediction. Pin the colour limits to [0, 1]. With
# auto-scaling, a constant mask (e.g. every pixel predicted as building) would be
# drawn in the colour for 0 and look like an empty prediction.
fig, ax = plt.subplots()
ax.imshow(pred, cmap='gray', vmin=0, vmax=1)
ax.set_title(f'Predicted building mask (sample {this_sample_id})')
ax.set_axis_off()
plt.show()

## Find example cases

In [None]:
# Poorly segmented images (Jaccard < 0.4) with little building cover (< 0.2).
model_metrics.query('jaccard < 0.4 and building_cover < 0.2')

In [None]:
# Poorly segmented images (Jaccard < 0.4) with substantial building cover. Use >= so
# that, together with the `< 0.2` table above, every low-Jaccard image appears in
# exactly one of the two tables.
model_metrics.query('jaccard < 0.4 and building_cover >= 0.2')

In [None]:
# Example samples to visualise (presumably picked from the query results above).
sample_ids = [237, 512, 634, 866]

# Use loop-local names so the `sample`, `classes_masked` and `this_sample_id` inspected
# in the cells above are not silently replaced by the last example of this loop.
for example_id in sample_ids:
    example_sample, example_classes = compute_true_false_classifications(example_id)
    show_image_and_true_false_classifications(example_sample, example_classes)

**TODO:** add precision, recall and F1 score.

In [None]:
# Map the dataset indices of the example samples to their image identifiers.
dataset_image_ids = test_dataset.image_ids
image_ids = [dataset_image_ids[idx] for idx in sample_ids]

In [None]:
# Image identifiers of the example samples.
image_ids

## Multi-spectral images

In [None]:
# The 8-band GeoTIFFs sit in an '8band' subfolder of the same test split used above.
# Derive the path from test_path so both stay in sync.
path_8band = test_path / '8band'
fpaths = [path_8band / f'8band_{this_image_id}.tif' for this_image_id in image_ids]

In [None]:
# 8-band GeoTIFF paths of the example images.
fpaths

In [None]:
import rioxarray

In [None]:
# Open the 8-band tile of one example sample. Select it by sample id rather than by a
# bare list position, so the choice survives edits to `sample_ids`.
EXAMPLE_SAMPLE_ID = 866
img_arr = rioxarray.open_rasterio(str(fpaths[sample_ids.index(EXAMPLE_SAMPLE_ID)]))
img_arr

In [None]:
# rioxarray returns (band, y, x); presumably 8 bands for these tiles.
img_arr.shape

In [None]:
# Coordinate reference system of the raster, read via the rioxarray accessor. Query
# the DataArray itself rather than a band coordinate of a single-band slice.
img_arr.rio.crs

In [None]:
# One grey-scale panel per spectral band, three panels per row.
band_grid_kwargs = {"col": "band", "col_wrap": 3, "cmap": "Greys_r"}
img_arr.plot.imshow(**band_grid_kwargs)
plt.show()

In [None]:
def scale_values(in_arr, out_max=255):
    """Linearly rescale an array so its values span ``[0, out_max]``.

    The minimum of ``in_arr`` maps to 0 and the maximum to ``out_max``. NaNs are
    ignored when finding the extremes and stay NaN in the output.

    Parameters
    ----------
    in_arr : numpy.ndarray (or any array type with ``astype``)
        Values to rescale, e.g. one band of a raster. Not modified.
    out_max : float, optional
        Upper end of the output range. Defaults to 255 (8-bit display range).

    Returns
    -------
    Array of float64 values
        Rescaled copy of ``in_arr``. A constant input maps to all zeros instead of
        the NaNs a ``0 / 0`` division would produce.
    """
    # Work in float64. Subtracting the minimum or computing the range in a signed
    # integer dtype (e.g. int8) can overflow, and the result is fractional anyway.
    # astype also returns a copy, so the caller's array is never modified.
    arr = in_arr.astype(np.float64)

    min_val = np.nanmin(arr)
    value_range = np.nanmax(arr) - min_val

    shifted = arr - min_val
    if value_range == 0:
        # Constant input: every (non-NaN) value is already at the minimum.
        return shifted

    return shifted * (out_max / value_range)

In [None]:
# Band positions (0-based, along the `band` axis) shown as the red, green and blue
# display channels.
# NOTE(review): assumes the SpaceNet 1 8-band tiles use the WorldView-2 band order
# (coastal, blue, green, yellow, red, red edge, NIR1, NIR2). Under that order the
# previous choice [4, 3, 2] displayed red/yellow/green; true colour is [4, 2, 1].
# Confirm against the dataset documentation.
RGB_BAND_IDX = [4, 2, 1]

# Min-max scale each channel separately into a NEW float array. Writing the float
# results back into the raster's native array would silently truncate them for an
# integer raster. Dividing by 255 gives the [0, 1] range plt.imshow expects for float
# RGB data.
vals = np.stack([scale_values(band) for band in img_arr[RGB_BAND_IDX].values]) / 255

In [None]:
# Reorder (channel, row, col) to (row, col, channel), the RGB layout plt.imshow expects.
vals_disp = vals.transpose(1, 2, 0)

In [None]:
# Expect (rows, cols, 3): the three selected bands with the channel axis moved last.
vals_disp.shape

In [None]:
# Three-band colour composite of the multispectral tile; give the figure a title so it
# stands on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.imshow(vals_disp)
ax.set_title('Three-band composite of the 8-band tile (per-channel min-max scaled)')
ax.set_axis_off()
plt.show()