Copyright (c) 2024 Microsoft Corporation.

Licensed under the MIT License

Perform TARDIS on FTW dataset
- Define ID and WILD dataloaders
- Form and apply the TARDIS pipeline on the pretrained FTW model
- Plot f, g distribution on map
- Plot the activation space of g
- Plot f, g predictions

In [None]:
import os
from datetime import datetime
from pathlib import Path

import tqdm

import numpy as np
import torch
import yaml

from scipy.stats import skew
from torch.utils.data import DataLoader

import sys

sys.path.append("..")

from torchgeo.trainers import SemanticSegmentationTask
from src.tardis_ftw.data_utils import (
    FTWDataModuleOOD,
    WILDDataset,
    process_ID_dataloader,
    process_WILD_dataloader,
)
from src.tardis_ftw.tardis_wrapper import TARDISWrapper, get_model_layers
from src.tardis_ftw.utils import (
    calculate_metrics,
    extract_data,
    load_config,
    percentile_stretch,
    plot_f_and_g_preds_probab,
    plot_g_prob_distribution_w_skewness,
    plot_histograms_for_countries,
    plot_ID_surrID_surrOOD,
    plot_ID_WILD,
    plot_tsne,
    plot_ftwtest_f1_skew_r2,
)

%load_ext autoreload
%autoreload 2

In [None]:
# Read all run settings from the project's YAML configuration file.
config = load_config(config_path="../src/tardis_ftw/config.yaml")

# Every path below is expressed relative to this base directory.
base_dir = Path(config["base_dir"])

# --- File locations ---------------------------------------------------------
paths_cfg = config["paths"]
model_checkpoint = base_dir / paths_cfg["model_checkpoint"]
root_folder_torchgeo = base_dir / paths_cfg["root_folder_torchgeo"]
path_wild_patches = base_dir / paths_cfg["path_wild_patches"]

# --- Dataloader settings ----------------------------------------------------
dataloader_cfg = config["dataloader"]
batch_size = dataloader_cfg["batch_size"]
num_workers = dataloader_cfg["num_workers"]
sample_N_from_each_country = dataloader_cfg["sample_N_from_each_country"]
all_countries = dataloader_cfg["all_countries"]
val_countries = dataloader_cfg["val_countries"]
test_countries = dataloader_cfg["test_countries"]
target = dataloader_cfg["target"]

# --- OOD detector settings --------------------------------------------------
ood_cfg = config["ood_detector"]
chosen_layer = ood_cfg["chosen_layer"]
resize_factor = ood_cfg["resize_factor"]
id_fraction_thr = ood_cfg["id_fraction_thr"]
n_batches_to_process = ood_cfg["n_batches_to_process"]
random_state = ood_cfg["random_state"]
estimators = ood_cfg["estimators"]
test_size = ood_cfg["test_size"]
use_optuna = ood_cfg["use_optuna"]
patch_size = ood_cfg["patch_size"]

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Print the resolved absolute paths so a wrong base_dir is easy to spot.
for description, location in (
    ("Model checkpoint path", model_checkpoint),
    ("Root folder for TorchGeo", root_folder_torchgeo),
    ("Path to wild patches", path_wild_patches),
):
    print(f"{description}: {location.resolve()}")

In [None]:
# Open the raw checkpoint on the CPU first; only its stored hyper-parameters
# are needed here, the weights are restored by load_from_checkpoint below.
checkpoint = torch.load(model_checkpoint, map_location="cpu")
params = checkpoint["hyper_parameters"]

# A checkpoint may carry class_weights as a plain list; convert it to a float
# tensor so it is compatible with the task definition.
if isinstance(params.get("class_weights"), list):
    params["class_weights"] = torch.tensor(
        params["class_weights"], dtype=torch.float32
    )

# Rebuild the segmentation task, passing the (possibly adjusted)
# hyper-parameters explicitly.
task = SemanticSegmentationTask.load_from_checkpoint(
    checkpoint_path=model_checkpoint,
    map_location=device,
    **params,
)

# Inference only: freeze the task and put the underlying model in eval mode
# on the selected device.
task.freeze()
model = task.model.eval().to(device)

# Layers of the model as reported by the project helper (the hook layer used
# later is taken from the configuration).
layer_names = get_model_layers(model)

ID Dataloader

In [None]:
# Build the in-distribution (ID) data module from the configured countries.
FTWDatamoduleOOD = FTWDataModuleOOD(
    root_folder_torchgeo,
    batch_size=batch_size,
    num_workers=num_workers,
    train_countries=all_countries,
    val_countries=val_countries,
    test_countries=test_countries,
    download=True,
    sample_N_from_each_country=sample_N_from_each_country,
    target=target,
)

# The "fit" stage is set up before requesting the train/val loaders, and the
# "test" stage before requesting the test loader.
FTWDatamoduleOOD.setup(stage="fit")
id_train_dataloader = FTWDatamoduleOOD.train_dataloader()
val_dataloader = FTWDatamoduleOOD.val_dataloader()
FTWDatamoduleOOD.setup(stage="test")
test_dataloader = FTWDatamoduleOOD.test_dataloader()

# Report the size of every split.
for split_name, split_loader in (
    ("train", id_train_dataloader),
    ("val", val_dataloader),
    ("test", test_dataloader),
):
    print(f"len {split_name} samples", len(split_loader.dataset))

# Pull one training batch to show the image tensor shape.
first_train_batch = next(iter(id_train_dataloader))
print(first_train_batch["image"].shape)
print("DataModule setup complete.")

WILD Dataloader

In [None]:
# Build the WILD (unlabelled, possibly out-of-distribution) dataset from the
# patch directory given in the configuration.
# The class imported from src.tardis_ftw.data_utils at the top of this notebook
# is `WILDDataset`; the previously used name `WILDGeoTIFFDataset` is not
# imported anywhere and raised a NameError.
# NOTE(review): assumes WILDDataset accepts `directory=` and exposes `.coords`
# and `.valid_pairs` like the class this cell was written for - confirm.
wild_dataset = WILDDataset(directory=path_wild_patches)

# Keep the original sample order (shuffle=False).
wild_data_loader = DataLoader(wild_dataset, batch_size=batch_size, shuffle=False)

print("len(dataset):", len(wild_data_loader.dataset))
# Pull one batch to show the image tensor shape.
print(next(iter(wild_data_loader))["image"].shape)
# Displayed as the cell output: number of coordinates and of valid pairs.
len(wild_dataset.coords), len(wild_dataset.valid_pairs)

Form the Pipeline

In [None]:
# Number of clusters handed to TARDISWrapper: 30% of the ID training samples.
num_clusters = int(0.3 * len(id_train_dataloader.dataset))

# Keyword arguments for the TARDIS wrapper, collected in one place.
# NOTE(review): `use_optuna` is read from the configuration in the config cell
# but is not used here; the flag is hard-coded to False. Confirm whether the
# configured value should be passed instead.
tardis_settings = dict(
    base_model=model,
    hook_layer_name=chosen_layer,
    main_loader=FTWDatamoduleOOD,
    id_loader=id_train_dataloader,
    wild_loader=wild_data_loader,
    n_batches_to_process=n_batches_to_process,
    test_size=test_size,
    use_optuna=False,
    num_clusters=num_clusters,
    M=id_fraction_thr,
    random_state=random_state,
    n_estimators=estimators,
    resize_factor=resize_factor,
    patch_size=patch_size,
    device=device,
    classifier_save_path="ood_classifier.pkl",
)
ood_model = TARDISWrapper(**tardis_settings)

# Summarise the setup that was just built.
for summary_label, summary_value in (
    ("number of ID samples", len(id_train_dataloader.dataset)),
    ("number of WILD samples", len(wild_data_loader.dataset)),
    ("Chosen layer:", chosen_layer),
    ("Number of clusters:", num_clusters),
):
    print(summary_label, summary_value)

Apply TARDIS:
- Compute features
- Apply surrogate label assignment 
- Train the classifier g

In [None]:
# The classifier g is treated as already available when it carries a
# `classes_` attribute; only otherwise is the full pipeline executed.
g_is_ready = hasattr(ood_model.ood_classifier, "classes_")

if not g_is_ready:
    print("'g' is not loaded.")
    # 1) Compute features.
    X, y = ood_model.compute_features()
    # 2) Feature space clustering.
    y_clustered = ood_model.feature_space_clustering(X, y)
    # 3) Train the classifier g on the clustered labels.
    metrics = ood_model.g_classification(X, y_clustered)
    # Report the metrics returned by the training step.
    for metric_name in ("accuracy", "classification_report", "fpr95", "roc_auc"):
        print(metrics[metric_name])
else:
    # X, y and y_clustered are not created on this path; later cells guard
    # against that with `except NameError`.
    print("'g' is already loaded.")

Collect ID 

In [None]:
# Run the ID training dataloader through the TARDIS model and collect the
# requested outputs into `ID_all`. Requested here: the input batches, the f and
# g predictions, the coordinates and the masks; the thresholded g prediction is
# not requested.
# NOTE(review): max_batches=None presumably means "process every batch" -
# confirm against process_ID_dataloader.
ID_all = process_ID_dataloader(
    id_train_dataloader,
    ood_model,
    return_batch=True,
    return_f_pred=True,
    return_g_pred=True,
    return_thresholded_g_pred=False,
    return_coords=True,
    return_masks=True,
    upsample=True,
    max_batches=None,
)

Collect WILD

In [None]:
# Run the WILD dataloader through the TARDIS model and collect the requested
# outputs into `WILD_all`. Later cells read the "g_pred_probs", "coords",
# "batch" and "f_preds" entries of this result.
# NOTE(review): max_batches=None presumably means "process every batch" -
# confirm against process_WILD_dataloader.
WILD_all = process_WILD_dataloader(
    wild_data_loader,
    ood_model,
    return_batch=True,
    return_f_pred=True,
    return_g_pred=True,
    return_thresholded_g_pred=False,
    return_coords=True,
    upsample=True,
    max_batches=None,
)

Sanity check

In [None]:
# Summarise the labels produced by the training cell. X, y and y_clustered
# only exist when the classifier was trained in this session, so a missing
# name is reported instead of raising.
try:
    if y is not None:
        # Distribution of the original labels.
        y_uniq, y_counts = np.unique(y, return_counts=True)
        print(
            f"There are {len(y_uniq)} unique clusters "
            f"with counts {dict(zip(y_uniq, y_counts))}"
        )
        print("X .shape:", X.shape)
        # Distribution of the labels after feature space clustering.
        y_clustered_uniq, y_clustered_counts = np.unique(
            y_clustered, return_counts=True
        )
        print(
            f"There are {len(y_clustered_uniq)} unique clusters "
            f"with counts {dict(zip(y_clustered_uniq, y_clustered_counts))}"
        )
except NameError:
    print("Variable 'y' or 'y_clustered' is not defined.")

### Plot the activation space: "X" in 2D, labelled with y (ID/WILD labels) and y_clustered (Surrogate ID and Surrogate OOD labels)

In [None]:
# Plot the features X in 2D, labelled with y and y_clustered. These variables
# exist only when the classifier was trained in this session; otherwise the
# NameError is caught and a message is printed instead.
try:
    plot_tsne(X, y, y_clustered)

except NameError:
    print("Variable 'y' or 'y_clustered' is not defined.")

### Plot on map: ID and WILD

In [None]:
# Plot the collected ID and WILD results on a map and save the figure as a
# PNG (dpi=100) named "on_map_id_wild" under ./plots.
plot_ID_WILD(
    ID_all,
    WILD_all,
    save=True,
    file_format="png",
    dpi=100,
    filepath="./plots",
    filename="on_map_id_wild",
)

### Plot on map: ID and WILD breakdown into surrogate ID and surrogate OOD

In [None]:
# Plot ID together with the WILD samples broken down into surrogate ID and
# surrogate OOD, and save the figure as a PNG (dpi=100) named
# "on_map_id_surrID_surrOOD" under ./plots.
# NOTE(review): the positional arguments are passed as (WILD_all, ID_all),
# the reverse of the plot_ID_WILD(ID_all, WILD_all) call in this notebook -
# confirm this matches the signature of plot_ID_surrID_surrOOD.
plot_ID_surrID_surrOOD(
    WILD_all,
    ID_all,
    save=True,
    file_format="png",
    dpi=100,
    filepath="./plots",
    filename="on_map_id_surrID_surrOOD",
)

### g(WILD) probabilities in a Histogram

In [None]:
# Skewness (scipy.stats.skew) of the g probabilities predicted on WILD data.
# NOTE(review): assumes WILD_all["g_pred_probs"] is a 1-D array so that the
# result is a single number - confirm.
skew_prob = skew(WILD_all["g_pred_probs"])
print("Skewness of g_pred over WILD data", skew_prob)

# Histogram of the same probabilities, with the skewness value passed along
# and the plot saved.
plot_g_prob_distribution_w_skewness(
    WILD_all["g_pred_probs"], suffix="WILD", skewness_value=skew_prob, save_plot=True
)

### g and f on the FTW test split of each country: f(FTW_{test}^country[i])

In [None]:
# Display the configured `target` value as the cell output.
target

In [None]:
# Evaluate f and g separately on the test split of every country in
# `all_countries`.
#
# This cell uses its own names (country_batch_size, country_datamodule,
# country_outputs, country_metrics) instead of rebinding the notebook-wide
# `batch_size`, `FTWDatamoduleOOD`, `test_dataloader`, `ID_all` and `metrics`.
# Rebinding them made earlier cells misbehave when re-run afterwards (e.g. the
# map plots received an `ID_all` collected without coordinates, and the ID
# dataloader cell silently picked up a batch size of 1).
dataloader_factory = {}
max_batches = None
country_batch_size = 1  # one sample per batch for the per-country evaluation

# Build one test dataloader per country.
for country in all_countries:
    country_datamodule = FTWDataModuleOOD(
        root_folder_torchgeo,
        batch_size=country_batch_size,
        num_workers=num_workers,
        train_countries=None,
        val_countries=None,
        test_countries=country,
        download=False,
        sample_N_from_each_country=sample_N_from_each_country,
        target=target,
    )
    country_datamodule.setup(stage="test")
    dataloader_factory[country] = country_datamodule.test_dataloader()

country_results = {}
metrics_dict = {}  # kept for compatibility; not populated in this cell

for country in tqdm.tqdm(all_countries):
    print("Country:", country)

    # f and g predictions plus ground-truth masks; coordinates and the
    # thresholded g prediction are not requested here.
    country_outputs = process_ID_dataloader(
        dataloader_factory[country],
        ood_model,
        return_f_pred=True,
        return_g_pred=True,
        return_thresholded_g_pred=False,
        return_coords=False,
        return_masks=True,
        upsample=True,
        max_batches=max_batches,
    )

    f_preds = country_outputs["f_preds"]
    true_masks = country_outputs["masks"]

    # Metrics of f against the ground-truth masks for this country.
    country_metrics = calculate_metrics(true_masks, f_preds)

    country_results[country] = {
        "f_preds_testsetid": f_preds,
        "g_pred_probs_testsetid": country_outputs["g_pred_probs"],
        "true_masks": true_masks,
        "metrics": country_metrics,
    }

    # Release cached GPU memory between countries. The stored results stay
    # referenced by `country_results`, so deleting local names frees nothing.
    torch.cuda.empty_cache()

Histograms for all countries

In [None]:
# Plot per-country histograms from `country_results`, using the "f1" metric.
plot_histograms_for_countries(country_results, metric="f1")

F1 score vs. OOD score probability Plot

In [None]:
# Flatten the per-country results into a table.
# NOTE(review): assumes extract_data returns a pandas DataFrame with a string
# "country" column, as the `.str` accessors below require - confirm.
df = extract_data(country_results)

# Normalise country names so the exclusion list matches regardless of
# surrounding whitespace or letter case.
df["country"] = df["country"].str.strip().str.lower()

# Countries left out of the F1-vs-skewness plot.
countries_to_remove = ["rwanda", "kenya", "india", "brazil"]
keep_rows = ~df["country"].isin(countries_to_remove)

# Take an explicit copy of the filtered rows: assigning a column on a
# boolean-mask selection otherwise triggers pandas' SettingWithCopyWarning.
df = df.loc[keep_rows].copy()

# Human-readable names for the plot, e.g. "south_africa" -> "South Africa".
df["country"] = df["country"].str.replace("_", " ").str.title()

df = df.reset_index(drop=True)

plot_ftwtest_f1_skew_r2(df, save_plot=True)

### Inputs, g scores, and f predictions for three 10-percentile bands of the g score (lowest, middle, top): plot and save them all

In [None]:
# Sort the array and get the corresponding indexes
sorted_indexes = np.argsort(WILD_all["g_pred_probs"])

# Determine the percentile ranges
n = len(WILD_all["g_pred_probs"])
lowest_10_percentile_indexes = sorted_indexes[: int(0.1 * n)]
mid_10_percentile_start = int(0.45 * n)
mid_10_percentile_end = int(0.55 * n)
mid_10_percentile_indexes = sorted_indexes[
    mid_10_percentile_start:mid_10_percentile_end
]
top_10_percentile_indexes = sorted_indexes[-int(0.1 * n) :]

print("Indexes of the lowest 10 percentile values:", lowest_10_percentile_indexes)
print("Indexes of the middle 10 percentile values:", mid_10_percentile_indexes)
print("Indexes of the top 10 percentile values:", top_10_percentile_indexes)

In [None]:
# Preview one WILD sample drawn at random from the lowest-10% band.
idx = np.random.choice(lowest_10_percentile_indexes)
coord = WILD_all["coords"][idx]
g_pred_prob_wild = WILD_all["g_pred_probs"][idx]

# Two 3-channel images cut from the sample: channels 0-2, and channels 4 up to
# (but excluding) the last one. permute(1, 2, 0) moves the channel axis last
# before the percentile stretch.
# NOTE(review): presumably the RGB bands of two acquisition windows stored in
# one stacked tensor - confirm the channel layout.
window_a = percentile_stretch(WILD_all["batch"][idx, :3, :, :].permute(1, 2, 0))
window_b = percentile_stretch(WILD_all["batch"][idx, 4:-1, :, :].permute(1, 2, 0))

f_pred_permuted = WILD_all["f_preds"][idx]

# A timestamped file name is built, but it is not passed to the plot call
# below in this cell.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
fname = f"{g_pred_prob_wild}_WILD_{idx}_{timestamp}.png"
print("OOD Score", g_pred_prob_wild)
plot_f_and_g_preds_probab(window_a, window_b, g_pred_prob_wild, f_pred_permuted)

Save them all

In [None]:
# The three percentile bands whose samples are plotted and saved.
percentiles_dict = {
    "lowest_10": lowest_10_percentile_indexes,
    "mid_10": mid_10_percentile_indexes,
    "top_10": top_10_percentile_indexes,
}

# Cap every band at the size of the lowest band.
num_samples_to_process = len(percentiles_dict["lowest_10"])

for perc_name, percentile_idx in percentiles_dict.items():
    # Slicing takes at most `num_samples_to_process` leading indexes per band.
    for idx in percentile_idx[:num_samples_to_process]:
        coord = WILD_all["coords"][idx]
        g_pred_prob_wild = WILD_all["g_pred_probs"][idx]

        # Channels 0-2 and channels 4 up to (excluding) the last one, with the
        # channel axis moved last before stretching.
        first_window = WILD_all["batch"][idx, :3, :, :]
        window_a = percentile_stretch(first_window.permute(1, 2, 0))
        second_window = WILD_all["batch"][idx, 4:-1, :, :]
        window_b = percentile_stretch(second_window.permute(1, 2, 0))

        f_pred_permuted = WILD_all["f_preds"][idx]

        # File name: band, score, sample index and a timestamp.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fname = f"{perc_name}_{g_pred_prob_wild}_WILD_{idx}_{timestamp}.png"

        plot_f_and_g_preds_probab(
            window_a,
            window_b,
            g_pred_prob_wild,
            f_pred_permuted,
            "./preds_percentile",
            fname,
        )