# LICS Evaluation 
Evaluate different segmentation approaches on the LICS test and finetune datasets. These include the deterministic superpixel algorithm, the U-Net from the original LICS paper, a U-Net pretrained on superpixel output, and a fine-tuned U-Net. The accuracy, precision, recall, F1 and FOM metrics are calculated for each approach.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import train
import network
import torch

import glob

import importlib
import evaluation as eval
import utils
# Re-import the local helper modules so edits to them are picked up without
# restarting the kernel (utils first, then evaluation).
for _module in (utils, eval):
    importlib.reload(_module)

# Root data directory; the dataset split folders are resolved relative to it.
base_path = '../../data/'

# LICS

In [None]:
# Checkpoint file names (loaded from ../../models/) for the three U-Net variants.
lics_original, lics_superpixel, lics_finetune = (
    # Previous SOTA approach: the U-Net from the original LICS paper.
    # Note: results will differ slightly from the LICS paper due to the random seed.
    "LICS_UNET_12JUL2024.pth",
    # Model pretrained on the superpixel algorithm's output.
    "LICS_SUPERPIXELS_26JUL2024.pth",
    # Fine-tuned model.
    "LICS_FINETUNE_26JUL24.pth",
)

In [None]:
# Load LICS data
incl_bands = [0,1,2,3,4,5,6]  # channel indices (last axis) used as model input
satellite = 'landsat'


def _load_split(paths, target_pos, bands):
    """Load each .npy tile once and split it into its target and input channels.

    Args:
        paths: list of .npy file paths.
        target_pos: channel index (last axis) holding the target mask.
        bands: channel indices (last axis) to keep as model input.

    Returns:
        (targets, inputs): two lists aligned with ``paths``.
    """
    targets, inputs = [], []
    for path in paths:
        tile = np.load(path)  # read from disk once, slice twice
        targets.append(tile[:, :, target_pos])
        inputs.append(tile[:, :, bands])
    return targets, inputs


# NOTE: glob.glob returns files in arbitrary, filesystem-dependent order, so
# image indices below depend on the machine this runs on.

# Test data (target in the second-to-last channel)
target_pos = -2

lics_test_file = base_path + 'LICS/test/'
lics_test_paths = glob.glob(lics_test_file + '*.npy')
lics_test_targets, lics_test_input = _load_split(lics_test_paths, target_pos, incl_bands)

print("Test dimensions:")
print(np.shape(lics_test_targets))
print(np.shape(lics_test_input))

# Finetune data (target in the last channel)
target_pos = -1

lics_finetune_file = base_path + 'LICS/finetune/'
lics_finetune_paths = glob.glob(lics_finetune_file + '*.npy')
lics_finetune_targets, lics_finetune_input = _load_split(lics_finetune_paths, target_pos, incl_bands)

print("\nFinetune dimensions:")
print(np.shape(lics_finetune_targets))
print(np.shape(lics_finetune_input))

In [None]:
# Sense check the data: for each split (test, then finetune), show the RGB
# composite of one random tile next to its target mask. Both masks use the
# same gray colormap so the two figures are directly comparable.
for split_inputs, split_targets in ((lics_test_input, lics_test_targets),
                                    (lics_finetune_input, lics_finetune_targets)):
    i = np.random.randint(0, len(split_inputs))
    rgb = utils.get_rgb(split_inputs[i], satellite=satellite, contrast=0.2)
    target = split_targets[i]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(rgb)
    ax[1].imshow(target, cmap='gray')

    for a in ax:
        a.axis('off')

## Superpixel algorithm

In [None]:
# Initialize metric tables for each split; keys are method names, values are
# the metrics returned by eval.eval_metrics.
lics_test_metrics, lics_finetune_metrics = {}, {}

In [None]:
def get_sp_predictions(paths, satellite, rgb_bands, index_name, threshold=-1, method='slic', **kwargs):
    """Run the deterministic superpixel segmentation on every file in ``paths``.

    Each file is loaded with ``np.load`` and its full band stack is passed to
    ``utils.get_mask_from_bands`` along with ``satellite``, ``rgb_bands``,
    ``threshold``, ``index_name`` and ``method``. Any extra keyword arguments
    (e.g. ``min_size``) are forwarded unchanged.

    Returns:
        list: one mask per path, in the same order as ``paths``.
    """
    return [
        utils.get_mask_from_bands(
            np.load(p),
            satellite=satellite,
            rgb_bands=rgb_bands,
            threshold=threshold,
            index_name=index_name,
            method=method,
            **kwargs,
        )
        for p in paths
    ]


In [None]:
# Test split: Felzenszwalb superpixels thresholded on NDWI
sp_settings = dict(
    satellite='landsat',
    rgb_bands=["nir", "green", "blue"],
    index_name="NDWI",
    threshold=-1,
    method='felzenszwalb',
    min_size=60,
)
preds_sp = get_sp_predictions(lics_test_paths, **sp_settings)

metrics, arr = eval.eval_metrics(lics_test_targets, preds_sp)
lics_test_metrics['superpixels'] = metrics
eval.display_metrics(metrics, arr)

In [None]:
# Finetune split: same superpixel settings as on the test split
sp_settings = dict(
    satellite='landsat',
    rgb_bands=["nir", "green", "blue"],
    index_name="NDWI",
    threshold=-1,
    method='felzenszwalb',
    min_size=60,
)
preds = get_sp_predictions(lics_finetune_paths, **sp_settings)

metrics, arr = eval.eval_metrics(lics_finetune_targets, preds)
lics_finetune_metrics['superpixels'] = metrics
eval.display_metrics(metrics, arr)

# Original model

In [None]:
# Set device: prefer Apple-silicon MPS, then CUDA, then CPU, so the notebook
# also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# U_Net(7, 2): presumably 7 input bands (matching incl_bands) and 2 output
# classes -- confirm against network.U_Net.
model = network.U_Net(7,2)

# Load the saved state_dict on CPU, then move the model to `device` once.
state_dict = torch.load(f'../../models/{lics_original}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
model.eval()

In [None]:
# Test split: original LICS U-Net. The targets returned by get_preds are not
# used here; metrics are scored against the pre-loaded lics_test_targets.
test_pred_kwargs = dict(target_pos=-2, incl_bands=incl_bands, satellite=satellite, batch_size=10)
targets, preds = eval.get_preds(model, lics_test_paths, **test_pred_kwargs)
print(len(preds))

metrics, arr = eval.eval_metrics(lics_test_targets, preds)
lics_test_metrics['original'] = metrics
eval.display_metrics(metrics, arr)

In [None]:
# Finetune split: original LICS U-Net (target in the last channel).
finetune_pred_kwargs = dict(target_pos=-1, incl_bands=incl_bands, satellite=satellite, batch_size=10)
targets, preds = eval.get_preds(model, lics_finetune_paths, **finetune_pred_kwargs)
print(len(preds))

metrics, arr = eval.eval_metrics(lics_finetune_targets, preds)
lics_finetune_metrics['original'] = metrics
eval.display_metrics(metrics, arr)

## Rough model (U-Net pretrained on superpixel output)

In [None]:
# Set device: prefer Apple-silicon MPS, then CUDA, then CPU, so the notebook
# also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# U_Net(7, 2): presumably 7 input bands (matching incl_bands) and 2 output
# classes -- confirm against network.U_Net.
model = network.U_Net(7,2)

# Load the superpixel-pretrained state_dict on CPU, then move the model to `device` once.
state_dict = torch.load(f'../../models/{lics_superpixel}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
model.eval()

In [None]:
# Test split: superpixel-pretrained ("rough") U-Net. Kept in preds_pretrained
# for the comparison figures further down.
targets, preds_pretrained = eval.get_preds(model,lics_test_paths,target_pos=-2,incl_bands=incl_bands,satellite=satellite,batch_size=10)
# Sanity check on this model's output: expected to equal the number of test files.
print(len(preds_pretrained))

metrics, arr = eval.eval_metrics(lics_test_targets ,preds_pretrained)
lics_test_metrics['rough_model'] = metrics
eval.display_metrics(metrics,arr)

In [None]:
# Finetune split: superpixel-pretrained ("rough") U-Net (target in the last channel).
finetune_pred_kwargs = dict(target_pos=-1, incl_bands=incl_bands, satellite=satellite, batch_size=10)
targets, preds = eval.get_preds(model, lics_finetune_paths, **finetune_pred_kwargs)
print(len(preds))

metrics, arr = eval.eval_metrics(lics_finetune_targets, preds)
lics_finetune_metrics['rough_model'] = metrics
eval.display_metrics(metrics, arr)


In [None]:
# Display the rough (superpixel-pretrained) model's prediction for a random
# test tile: RGB composite, ground truth, prediction.
i = np.random.randint(0,len(lics_test_paths))
rgb = utils.get_rgb(lics_test_input[i],satellite=satellite,contrast=0.2)
target = lics_test_targets[i]
# Must come from the test-split predictions so it matches `rgb` and `target`.
pred = preds_pretrained[i]

fig, ax = plt.subplots(1,3,figsize=(15,5))
ax[0].imshow(rgb)
ax[1].imshow(target)
ax[2].imshow(pred)

for a in ax:
    a.axis('off')

# Finetuned model

In [None]:
# Set device: prefer Apple-silicon MPS, then CUDA, then CPU, so the notebook
# also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))

# U_Net(7, 2): presumably 7 input bands (matching incl_bands) and 2 output
# classes -- confirm against network.U_Net.
model = network.U_Net(7,2)

# Load the fine-tuned state_dict on CPU, then move the model to `device` once.
state_dict = torch.load(f'../../models/{lics_finetune}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
model.eval()

In [None]:
# Test split: fine-tuned U-Net. Kept in preds_finetuned for the comparison
# figures further down.
targets, preds_finetuned = eval.get_preds(model,lics_test_paths,target_pos=-2,incl_bands=incl_bands,satellite=satellite,batch_size=10)
# Sanity check on this model's output: expected to equal the number of test files.
print(len(preds_finetuned))

metrics, arr = eval.eval_metrics(lics_test_targets ,preds_finetuned)
lics_test_metrics['finetune_model'] = metrics
eval.display_metrics(metrics,arr)

In [None]:
# Display the fine-tuned model's prediction for a random test tile:
# RGB composite, ground truth, prediction.
i = np.random.randint(0,len(lics_test_paths))
rgb = utils.get_rgb(lics_test_input[i],satellite=satellite,contrast=0.2)
target = lics_test_targets[i]
# Must come from this model's test-split predictions so it matches `rgb` and `target`.
pred = preds_finetuned[i]

fig, ax = plt.subplots(1,3,figsize=(15,5))
ax[0].imshow(rgb)
ax[1].imshow(target)
ax[2].imshow(pred)

for a in ax:
    a.axis('off')

In [None]:
# Finetune split: fine-tuned U-Net (target in the last channel).
finetune_pred_kwargs = dict(target_pos=-1, incl_bands=incl_bands, satellite=satellite, batch_size=10)
targets, preds = eval.get_preds(model, lics_finetune_paths, **finetune_pred_kwargs)
print(len(preds))

metrics, arr = eval.eval_metrics(lics_finetune_targets, preds)
lics_finetune_metrics['finetune_model'] = metrics
eval.display_metrics(metrics, arr)

# Final Metrics Table

In [None]:
import json

# Save metrics: write each split's method -> metrics table to its own JSON file
# (test first, then finetune).
for fname, table in (('lics_test_metrics.json', lics_test_metrics),
                     ('lics_finetune_metrics.json', lics_finetune_metrics)):
    with open(fname, 'w') as f:
        json.dump(table, f)

In [None]:
# Test metrics: reload the JSON written above (the metrics are saved as JSON,
# not CSV) and show the raw method -> metric table.
import json

with open('lics_test_metrics.json') as f:
    df_test_metrics = pd.DataFrame(json.load(f))
df_test_metrics

In [None]:
import json

# Test-split summary: one row per method (fixed order), one column per metric.
# The `with` block closes the file handle once it has been read.
with open('lics_test_metrics.json') as f:
    df_test_metrics = pd.DataFrame(json.load(f))

df_test_metrics = df_test_metrics[['original','superpixels','rough_model','finetune_model']]
df_test_metrics = df_test_metrics.transpose()
df_test_metrics = df_test_metrics[['accuracy','precision','recall','f1','fom']]
np.round(df_test_metrics,3)

In [None]:
# Finetune metrics: one row per method (fixed order), one column per metric.
# The `with` block closes the file handle once it has been read.
import json

with open('lics_finetune_metrics.json') as f:
    df_finetune_metrics = pd.DataFrame(json.load(f))

df_finetune_metrics = df_finetune_metrics[['original','superpixels','rough_model','finetune_model']]
df_finetune_metrics = df_finetune_metrics.transpose()
df_finetune_metrics = df_finetune_metrics[['accuracy','precision','recall','f1','fom']]

np.round(df_finetune_metrics,3)

In [None]:
# Side-by-side accuracy per method on the finetune and test splits.
# rename/assign build a new frame, avoiding chained assignment on a slice of
# df_finetune_metrics; the Test column is aligned on the method index.
accuracy = (
    df_finetune_metrics[['accuracy']]
    .rename(columns={'accuracy': 'Finetune'})
    .assign(Test=df_test_metrics['accuracy'])
)
round(accuracy,3)

# Visualisations

In [None]:
# Display one hand-picked test tile: RGB composite, ground truth, and each
# method's prediction titled with its per-image pixel accuracy.
i = 84
print(i)
rgb = utils.get_rgb(lics_test_input[i],satellite=satellite,contrast=0.2)
target = lics_test_targets[i]

panels = [
    ("Deterministic", preds_sp[i]),
    ("Pretrained", preds_pretrained[i]),
    ("Finetuned", preds_finetuned[i]),
]

fig, ax = plt.subplots(1, 2 + len(panels), figsize=(15,5))
ax[0].imshow(rgb)
ax[1].imshow(target, cmap='gray')
ax[1].set_title("Ground Truth")

for a, (label, pred) in zip(ax[2:], panels):
    a.imshow(pred, cmap='gray')
    # Pixel accuracy: fraction of pixels where the prediction equals the target.
    accuracy = np.sum(pred == target) / np.size(target)
    a.set_title("{} ({:.3f})".format(label, accuracy))

for a in ax:
    a.set_xticks([])
    a.set_yticks([])

# To save the figure: utils.save_fig(fig, 'inland_water_bodies')

In [None]:
# Display one hand-picked test tile: RGB composite, ground truth, and each
# method's prediction titled with its per-image pixel accuracy.
i = 16
rgb = utils.get_rgb(lics_test_input[i],satellite=satellite,contrast=0.2)
target = lics_test_targets[i]

panels = [
    ("Deterministic", preds_sp[i]),
    ("Pretrained", preds_pretrained[i]),
    ("Finetuned", preds_finetuned[i]),
]

fig, ax = plt.subplots(1, 2 + len(panels), figsize=(15,5))
ax[0].imshow(rgb)
ax[1].imshow(target, cmap='gray')
ax[1].set_title("Ground Truth")

for a, (label, pred) in zip(ax[2:], panels):
    a.imshow(pred, cmap='gray')
    # Pixel accuracy: fraction of pixels where the prediction equals the target.
    accuracy = np.sum(pred == target) / np.size(target)
    a.set_title("{} ({:.3f})".format(label, accuracy))

for a in ax:
    a.set_xticks([])
    a.set_yticks([])