# Notebook for creating figures of the results

## Preliminaries

In [None]:
import os
import yaml
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from Dataset.dataset import TranslationDataset, normalize, esawc_to_image
from Models.classifiers import *

# Machine-specific settings live in local_config.yml; only "data_root" is read here.
config = yaml.safe_load(Path("local_config.yml").read_text())
data_root = Path(config["data_root"])

# Samples shown in the results figure, one (country, region, idx) row each.
# Later cells combine them into file names of the form
# <country>/<region><idx>.pt below data_root.
samples = [
    ("United_Kingdom", "England", 0),
    ("Portugal", "Portugal", 229),
    ("Czech_Republic", "Czech_Republic", 11),
    ("Sweden", "Stockholms_Laen", 0),
]

# For the motivation figure use this single sample instead:
# samples = [("Croatia", "Croatia", 12)]

# Parallel lists consumed by the plotting cells (same index = same sample).
countries, regions, idxs = (list(column) for column in zip(*samples))

# Normalisation bounds stored on the dataset class; handed to normalize()
# when the raw images are displayed.
mins, maxs = TranslationDataset.mins, TranslationDataset.maxs

# Channel indices (in display order) picked from an image tensor before
# plotting — presumably red, green, blue; confirm against the band layout.
visualization = [2, 1, 0]

## Dataset samples

In [None]:

def _show_image(image, **imshow_kwargs):
    """Draw ``image`` in its own 6x6 inch figure with the axes hidden.

    Extra keyword arguments are forwarded unchanged to ``plt.imshow``.
    """
    plt.figure(figsize=(6, 6))
    plt.imshow(image, **imshow_kwargs)
    plt.axis('off')
    plt.show()


# Root folder holding one sub-folder per sensor.
images_dir = Path(data_root) / "Images"

# Walk the three parallel sample lists in lockstep instead of indexing them.
for country, region, idx in zip(countries, regions, idxs):

    tile_name = f"{region}{idx}.pt"

    # Sentinel-2: normalise, keep the display channels, move channels last.
    s2 = torch.load(images_dir / "Sentinel-2" / "2018" / country / tile_name)
    s2 = normalize(s2, mins, maxs)
    s2 = s2[visualization].permute(1, 2, 0).numpy()
    # The 1.5 gain brightens the image; clipping explicitly to [0, 1] is what
    # imshow does to out-of-range float RGB anyway, so the picture is the same
    # but no longer relies on (and warns about) implicit clipping.
    _show_image((s2 * 1.5).clip(0, 1))

    # Landsat 8: same preparation, shown without the brightness gain.
    l8 = torch.load(images_dir / "Landsat8" / "2018" / country / tile_name)
    l8 = normalize(l8, mins, maxs)
    l8 = l8[visualization].permute(1, 2, 0).numpy()
    _show_image(l8)

    # Landsat 8 panchromatic: squeeze() drops size-1 dimensions so imshow
    # receives a 2-D array (assumes a single-band tensor — confirm).
    l8_pan = torch.load(images_dir / "Landsat8-Panchro" / "2018" / country / tile_name)
    l8_pan = l8_pan.squeeze()
    _show_image(l8_pan, cmap="binary_r")

## Translation results

In [None]:
translated_dir = Path(data_root) / "Images" / "Landsat8-Translated"

# Output folders whose tensors are averaged per sample (presumably repeated
# runs of the same model — confirm). The mean divides by the length of this
# list, so adding or removing an entry needs no other change.
translation_runs = [
    "Diffusion_EMA_2025-10-05_23-31-54_0.4",
    "Diffusion_EMA_2025-10-05_23-31-54_0.4_v2",
    "Diffusion_EMA_2025-10-05_23-31-54_0.4_v3",
    # Previously disabled; uncomment to include them in the average:
    # "Diffusion_EMA_2025-10-05_23-31-54_0.4_v4",
    # "Diffusion_EMA_2025-10-05_23-31-54_0.4_v5",
]

# Walk the three parallel sample lists in lockstep instead of indexing them.
for country, region, idx in zip(countries, regions, idxs):

    members = [
        torch.load(translated_dir / run / "2018" / country / f"{region}{idx}.pt")
        for run in translation_runs
    ]
    # Element-wise mean over all loaded tensors.
    img = sum(members) / len(members)

    # Keep the display channels and move channels last for imshow.
    img = img[visualization].permute(1, 2, 0).numpy()
    plt.figure(figsize=(6, 6))
    # The 1.5 gain brightens the image; clipping to [0, 1] mirrors what imshow
    # does to out-of-range float RGB, without relying on implicit clipping.
    plt.imshow((img * 1.5).clip(0, 1))
    plt.axis('off')
    plt.show()

## Land cover classification results

In [None]:
from torchvision.transforms import GaussianBlur

classifier = DeepLabV3_SMP(4, 9, "resnet34")
# The model is never moved off the CPU in this cell, so map the checkpoint to
# the CPU as well; otherwise a checkpoint saved from a GPU fails to load on a
# machine without CUDA.
classifier.load_state_dict(
    torch.load(
        "Checkpoints/ClassifierDeepLabV3_2025-10-13_22-08-36/checkpoint.pt",
        map_location="cpu",
    )
)
classifier.eval()

translated_dir = Path(data_root) / "Images" / "Landsat8-Translated"

# Output folders whose tensors are averaged per sample before classification.
# The mean divides by the length of this list, so the two cannot drift apart.
translation_runs = [
    "Diffusion_EMA_2025-10-05_23-31-54_0.4",
    "Diffusion_EMA_2025-10-05_23-31-54_0.4_v2",
    "Diffusion_EMA_2025-10-05_23-31-54_0.4_v3",
]

# Walk the three parallel sample lists in lockstep instead of indexing them.
for country, region, idx in zip(countries, regions, idxs):

    members = [
        torch.load(translated_dir / run / "2018" / country / f"{region}{idx}.pt")
        for run in translation_runs
    ]
    # Element-wise mean over all loaded tensors.
    img = sum(members) / len(members)

    with torch.inference_mode():
        # Add a leading batch dimension, run the classifier, drop size-1
        # dimensions again and take the arg-max over the first remaining
        # dimension (presumably the class scores — confirm output layout).
        scores = classifier(img[None, :, :, :]).squeeze()
        lc_hat = torch.argmax(scores, dim=0)
    lc_hat = lc_hat.numpy()

    plt.figure(figsize=(6, 6))
    # esawc_to_image converts the predicted label map into something imshow
    # can display (project helper; presumably a colour-coded image).
    plt.imshow(esawc_to_image(lc_hat))
    plt.axis('off')
    plt.show()

## SSIM/FID vs IoU

In [None]:
import matplotlib
import matplotlib.pyplot as plt
plt.style.use('ggplot')
matplotlib.rcParams.update({'font.size': 22, 'font.family': 'serif'})

# Hard-coded scores, one row per method:
# (legend label, marker colour, SSIM, FID, IoU in %).
# Rows are drawn in this order, which also fixes the legend order.
results = [
    # Pansharpening + Regression
    ("Bicubic + pansharpening", "palegreen", 0.761, 64.6, 32.0),
    ("+ scale-only regression", "limegreen", 0.768, 62.1, 33.8),
    ("+ linear regression", "green", 0.764, 69.2, 19.8),
    # DL methods
    ("UNet", "aquamarine", 0.826, 53.3, 35.9),
    ("+ SSIM loss", "darkturquoise", 0.846, 51.3, 44.5),
    ("AUNet", "dodgerblue", 0.821, 55.3, 35.1),
    ("+ SSIM loss", "blue", 0.832, 51.9, 36.7),
    ("UNet Ensemble", "mediumorchid", 0.827, 55.1, 35.1),
    # Generative
    ("Pix2Pix", "gold", 0.732, 29.7, 41.8),
    ("Palette", "orange", 0.723, 46.9, 45.6),
    ("L8-S2 Diffusion", "indianred", 0.806, 22.5, 54.1),
]


def _scatter_against_iou(metric_column, metric_name, with_legend=False):
    """Scatter one image-quality metric (x) against IoU (y) for every method.

    metric_column: index into a ``results`` row holding the x value
        (2 = SSIM, 3 = FID).
    metric_name: text used as the x-axis label.
    with_legend: when true, place the legend to the right of the axes.
    """
    plt.figure(figsize=(9, 9))
    for row in results:
        label, color, iou = row[0], row[1], row[4]
        plt.scatter(row[metric_column], iou, s=100, c=color, edgecolors="k", label=label)
    plt.ylim([15, 55])
    plt.xlabel(metric_name)
    plt.ylabel("IoU in %")
    if with_legend:
        # Anchor the legend's centre-left to the right edge of the axes so it
        # sits outside the plotting area.
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()


# Only the second figure carries the legend; both share labels and colours.
_scatter_against_iou(2, "SSIM")
_scatter_against_iou(3, "FID", with_legend=True)