# SWED Evaluation 
Evaluate different segmentation approaches on the SWED test and fine-tuning datasets. These include the deterministic superpixel algorithm, the original U-Net model, a U-Net pretrained on superpixel output, and a fine-tuned U-Net. The accuracy, precision, recall, F1 and FOM metrics are calculated for each approach.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import train
import network
import torch

import glob

import importlib
# NOTE(review): aliasing the local `evaluation` module as `eval` shadows the
# builtin eval() for the rest of the notebook.
import evaluation as eval
import utils
# Re-import the local helper modules so edits to them are picked up without
# restarting the kernel.
importlib.reload(utils)
importlib.reload(eval)

# Root directory of the datasets, relative to the notebook's working directory
base_path = '../../data/'

# SWED

In [None]:
# Model checkpoint file names
swed_original = "SWED_UNET_12JUL2024.pth" #Model from previous paper (will not give SOTA results due to different architecture)
swed_superpixel = "SWED_SUPERPIXELS_12JUL2024.pth" #Pretrained model trained on superpixel output
swed_finetune = "SWED-FINETUNE-26JUL24.pth" #Fine-tuned model 

In [None]:
# Load SWED data
incl_bands = [0,1,2,3,4,5,6,7,8,9,10,11]  # channel indices used as model input
satellite = 'sentinel'
target_pos = -1  # channel index of the target mask (the last channel)


def _load_split(split_dir, bands, target_channel):
    """Load every ``.npy`` sample in *split_dir*.

    Each file is read once and indexed as ``[:, :, channel]``: the
    ``target_channel`` slice is the target mask and the ``bands`` slices are
    the model input. Paths are sorted because ``glob`` returns them in an
    arbitrary, filesystem-dependent order, which would make runs
    non-reproducible.

    Returns:
        (paths, targets, inputs): three lists aligned by index.

    Raises:
        FileNotFoundError: if *split_dir* contains no ``.npy`` files.
    """
    paths = sorted(glob.glob(os.path.join(split_dir, '*.npy')))
    if not paths:
        raise FileNotFoundError(f"No .npy files found in {split_dir!r}")
    targets, inputs = [], []
    for path in paths:
        sample = np.load(path)
        # Basic slicing returns a view that keeps the whole multi-band array
        # alive; copy so only the 2-D target is retained.
        targets.append(sample[:, :, target_channel].copy())
        # Fancy indexing with a list already returns a copy.
        inputs.append(sample[:, :, bands])
    return paths, targets, inputs


# Test data
test_file = base_path + 'SWED/test/'
test_paths, test_targets, test_input = _load_split(test_file, incl_bands, target_pos)

print("Test dimensions:")
print(np.shape(test_targets))
print(np.shape(test_input))

# Finetune data
finetune_file = base_path + 'SWED/finetune/'
finetune_paths, finetune_targets, finetune_input = _load_split(finetune_file, incl_bands, target_pos)

print("\nFinetune dimensions:")
print(np.shape(finetune_targets))
print(np.shape(finetune_input))

In [None]:
# Sense check the data: for the test split and then the finetune split, show a
# randomly chosen RGB composite next to its target mask.
for split_paths, split_inputs, split_targets in (
    (test_paths, test_input, test_targets),
    (finetune_paths, finetune_input, finetune_targets),
):
    i = np.random.randint(0, len(split_paths))
    rgb = utils.get_rgb(split_inputs[i], satellite=satellite, contrast=0.2)
    target = split_targets[i]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(rgb)
    ax[1].imshow(target)
    for a in ax:
        a.axis('off')

## Superpixel algorithm

In [None]:
# Initialise the per-split metric dictionaries (filled in, keyed by method name, by the cells below)
test_metrics, finetune_metrics = {}, {}

In [None]:
def get_sp_predictions(paths,satellite,rgb_bands,index_name,threshold = -1,method='slic', **kwargs):
    """Compute a superpixel-based mask for every ``.npy`` file in *paths*.

    Each file is loaded with ``np.load``, and the full array (all bands) is
    passed to ``utils.get_mask_from_bands`` together with the options below.

    Args:
        paths: iterable of ``.npy`` file paths.
        satellite: satellite name forwarded to ``get_mask_from_bands``.
        rgb_bands: band names forwarded to ``get_mask_from_bands``.
        index_name: spectral index name (e.g. ``"NDWI"``) forwarded to
            ``get_mask_from_bands``.
        threshold: threshold forwarded to ``get_mask_from_bands``.
        method: superpixel method name forwarded to ``get_mask_from_bands``.
        **kwargs: extra options for the superpixel method (e.g. ``min_size``).

    Returns:
        list of masks, in the same order as *paths*.
    """
    return [
        utils.get_mask_from_bands(np.load(path),
                                  satellite=satellite,
                                  rgb_bands=rgb_bands,
                                  threshold=threshold,
                                  index_name=index_name,
                                  method=method, **kwargs)
        for path in paths
    ]


In [None]:
# Test split: superpixel baseline using the NDWI index with Felzenszwalb segmentation
sp_settings = {
    'satellite': 'sentinel',
    'rgb_bands': ["nir", "green", "blue"],
    'index_name': "NDWI",
    'threshold': 0,
    'method': 'felzenszwalb',
    'min_size': 60,
}
preds = get_sp_predictions(test_paths, **sp_settings)

# Score the masks against the ground truth and keep the result for the summary table
metrics, arr = eval.eval_metrics(test_targets, preds)
test_metrics['superpixels'] = metrics
eval.display_metrics(metrics, arr)

In [None]:
# Finetune split: superpixel baseline using the NDWI index with Felzenszwalb segmentation
sp_settings = {
    'satellite': 'sentinel',
    'rgb_bands': ["nir", "green", "blue"],
    'index_name': "NDWI",
    'threshold': 0,
    'method': 'felzenszwalb',
    'min_size': 60,
}
preds = get_sp_predictions(finetune_paths, **sp_settings)

# Score the masks against the ground truth and keep the result for the summary table
metrics, arr = eval.eval_metrics(finetune_targets, preds)
finetune_metrics['superpixels'] = metrics
eval.display_metrics(metrics, arr)

# Original model

In [None]:
# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA,
# otherwise CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# network.U_Net(12, 2): presumably 12 input bands (matching incl_bands) and
# 2 output classes — confirm against network.py.
model = network.U_Net(12,2).to(device)

# Load the saved weights. Tensors are mapped to CPU first; load_state_dict
# copies them into the model's (device-resident) parameters.
state_dict = torch.load(f'../../models/{swed_original}', map_location=torch.device('cpu') )
model.load_state_dict(state_dict)
model.eval()
model.to(device)

In [None]:
# NOTE(review): this cell rebinds `model`, replacing the swed_original U-Net
# loaded in the previous cell, so the "original" metrics below come from
# UNET-SCALE-13MAR23.pth. Confirm which checkpoint is intended and run only
# one of the two cells.
from train_unet import * #load dataset and model classes (presumably needed to unpickle the saved model)

# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA,
# otherwise CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# Load a whole pickled model object (not just a state_dict).
# NOTE: unpickling can execute arbitrary code — only load trusted checkpoints.
# Recent PyTorch releases default torch.load to weights_only=True, which may
# reject a fully pickled module; TODO confirm for the installed version.
model = torch.load('../../models/UNET-SCALE-13MAR23.pth', map_location=torch.device('cpu') )
model.eval()
model.to(device)

In [None]:
# Test split: run the currently loaded model over every test file.
# The returned `targets` are not used; scoring uses test_targets.
targets, preds = eval.get_preds(
    model,
    test_paths,
    target_pos=-1,
    incl_bands=incl_bands,
    satellite=satellite,
    batch_size=10,
)
print(len(preds))

metrics, arr = eval.eval_metrics(test_targets, preds)
test_metrics['original'] = metrics
eval.display_metrics(metrics, arr)

In [None]:
# Finetune split: run the currently loaded model over every finetune file.
# The returned `targets` are not used; scoring uses finetune_targets.
targets, preds = eval.get_preds(
    model,
    finetune_paths,
    target_pos=-1,
    incl_bands=incl_bands,
    satellite=satellite,
    batch_size=10,
)
print(len(preds))

metrics, arr = eval.eval_metrics(finetune_targets, preds)
finetune_metrics['original'] = metrics
eval.display_metrics(metrics, arr)

## Rough model (pretrained on superpixel output)

In [None]:
# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA,
# otherwise CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# network.U_Net(12, 2): presumably 12 input bands (matching incl_bands) and
# 2 output classes — confirm against network.py.
model = network.U_Net(12,2).to(device)

# Load the weights of the model trained on superpixel output. Tensors are
# mapped to CPU first; load_state_dict copies them into the model's parameters.
state_dict = torch.load(f'../../models/{swed_superpixel}', map_location=torch.device('cpu') )
model.load_state_dict(state_dict)
model.eval()
model.to(device)

In [None]:
# Test split: run the rough (superpixel-trained) model over every test file.
# The returned `targets` are not used; scoring uses test_targets.
targets, preds = eval.get_preds(
    model,
    test_paths,
    target_pos=-1,
    incl_bands=incl_bands,
    satellite=satellite,
    batch_size=10,
)
print(len(preds))

metrics, arr = eval.eval_metrics(test_targets, preds)
test_metrics['rough_model'] = metrics
eval.display_metrics(metrics, arr)

In [None]:
# Show one random test sample: RGB composite, ground-truth mask, prediction
i = np.random.randint(0, len(test_paths))
rgb = utils.get_rgb(test_input[i], satellite=satellite, contrast=0.2)
target = test_targets[i]
pred = preds[i]

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for a, img in zip(ax, (rgb, target, pred)):
    a.imshow(img)
    a.axis('off')

In [None]:
# Finetune split: run the rough (superpixel-trained) model over every finetune file.
# The returned `targets` are not used; scoring uses finetune_targets.
targets, preds = eval.get_preds(
    model,
    finetune_paths,
    target_pos=-1,
    incl_bands=incl_bands,
    satellite=satellite,
    batch_size=10,
)
print(len(preds))

metrics, arr = eval.eval_metrics(finetune_targets, preds)
finetune_metrics['rough_model'] = metrics
eval.display_metrics(metrics, arr)


# Finetuned model

In [None]:
# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA,
# otherwise CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# network.U_Net(12, 2): presumably 12 input bands (matching incl_bands) and
# 2 output classes — confirm against network.py.
model = network.U_Net(12,2).to(device)

# Load the fine-tuned weights. Tensors are mapped to CPU first;
# load_state_dict copies them into the model's parameters.
state_dict = torch.load(f'../../models/{swed_finetune}', map_location=torch.device('cpu') )
model.load_state_dict(state_dict)
model.eval()
model.to(device)

In [None]:
# Test split: run the fine-tuned model over every test file.
# The returned `targets` are not used; scoring uses test_targets.
targets, preds = eval.get_preds(
    model,
    test_paths,
    target_pos=-1,
    incl_bands=incl_bands,
    satellite=satellite,
    batch_size=10,
)
print(len(preds))

metrics, arr = eval.eval_metrics(test_targets, preds)
test_metrics['finetune_model'] = metrics
eval.display_metrics(metrics, arr)

In [None]:
# Show one random test sample: RGB composite, ground-truth mask, prediction
i = np.random.randint(0, len(test_paths))
rgb = utils.get_rgb(test_input[i], satellite=satellite, contrast=0.2)
target = test_targets[i]
pred = preds[i]

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for a, img in zip(ax, (rgb, target, pred)):
    a.imshow(img)
    a.axis('off')

In [None]:
# Finetune split: run the fine-tuned model over every finetune file.
# The returned `targets` are not used; scoring uses finetune_targets.
targets, preds = eval.get_preds(
    model,
    finetune_paths,
    target_pos=-1,
    incl_bands=incl_bands,
    satellite=satellite,
    batch_size=10,
)
print(len(preds))

metrics, arr = eval.eval_metrics(finetune_targets, preds)
finetune_metrics['finetune_model'] = metrics
eval.display_metrics(metrics, arr)

# Final Metrics Table

In [None]:
 # Test metrics
df_test_metrics = pd.DataFrame(test_metrics)

df_test_metrics = df_test_metrics[['original','superpixels','rough_model','finetune_model']]
df_test_metrics = df_test_metrics.transpose()
df_test_metrics = df_test_metrics[['accuracy','precision','recall','f1','fom']]
np.round(df_test_metrics,3)

In [None]:
# Finetune metrics: one row per method, one column per metric, rounded to 3 d.p.
method_order = ['original', 'superpixels', 'rough_model', 'finetune_model']
metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fom']
df_finetune_metrics = pd.DataFrame(finetune_metrics)[method_order].T[metric_order]

np.round(df_finetune_metrics, 3)

# Archive

In [None]:
# Inspect finetune samples whose pixel accuracy falls below a threshold.
# NOTE(review): `preds` is whatever the most recently executed cell produced;
# when run in order that is the fine-tuned model on the finetune split, which
# is what finetune_targets must be paired with.
max_samples = 75      # only inspect the first few samples
min_accuracy = 0.6    # flag samples below this fraction of matching pixels

# Bound the loop by the data so short splits don't raise IndexError.
n_samples = min(max_samples, len(preds), len(finetune_targets))
for i in range(n_samples):
    pred = preds[i]
    target = finetune_targets[i]

    # Fraction of pixels where prediction and target agree
    accuracy = np.sum(pred == target) / np.size(pred)

    if accuracy < min_accuracy:
        print(f"{i}: {accuracy}")

        fig, ax = plt.subplots(1,2,figsize=(10,5))
        ax[0].imshow(pred)
        ax[1].imshow(target)
        