# Indian Pines

1. Import dependencies

In [None]:
import random
import torch
import multiprocessing
import numpy as np
import torch.utils.data as data

from torch import nn

from sklearn.model_selection import train_test_split
from lightning import Trainer
from torch.optim.lr_scheduler import StepLR

from src.util.torch import resolve_torch_device
from src.util.hsi import (
    extract_patches,
    reduce_hsi_dim,
    train_test_band_patch_split,
    preprocess_hsi,
    to_bin_labels_mask,
    to_bin_labels,
    to_pu_labels,
    multi_to_bi_class_pu_ds,
    PreProcessType,
    DimReductionType,
)
from src.data.indian_pines import load_indian_pines
from src.model.hsic import HyperSpectralImageClassifier
from src.model.lenet import FullyConvolutionalLeNet, PuLeNet
from src.visualization.plot import (
    plot_segmentation_comparison,
    plot_numbers_distribution,
    plot_epoch_generic_comparison,
    plot_epoch_generic,
)
from src.data.dataset_decorator import UnlabeledDatasetDecorator
from src.util.reporting import (
    create_model_name,
    report_run,
    read_report_to_show,
    lightning_metrics,
)
from src.util.list_ext import smooth_moving_average
from src.util.loss import PULoss

2. Prepare env

In [None]:
# Training schedule.
num_epochs = 30
learning_rate = 1e-3

# StepLR parameters; they are only consumed by the scheduler option of the
# model cell, which is currently commented out there.
# NOTE(review): a step size equal to the epoch count means a per-epoch StepLR
# would not decay the learning rate until training is over — confirm intent.
scheduler_gamma = 0.9
scheduler_step_size = num_epochs

In [None]:
# Data pipeline settings.
pre_process_type = PreProcessType.STANDARTIZATION
dim_reduction_type = DimReductionType.PCA
target_dim = 75  # requested dimension handed to reduce_hsi_dim
patch_size = 9  # handed to extract_patches as patch_size
batch_size = 64  # batch size of every DataLoader

# Training-set selection.
target_class = 1  # label value treated as the positive class
examples_per_class = []  # placeholder; replaced once the class count is known

# Window handed to smooth_moving_average when plotting metric curves.
smoth_window = 2

In [None]:
random_seed = 42

# Seed Python's, PyTorch's and NumPy's generators with the same value.
for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    seed_rng(random_seed)

# Device chosen by the project helper; also passed to reduce_hsi_dim later.
device = resolve_torch_device()

In [None]:
# Release unused cached GPU memory held by PyTorch before the run starts.
torch.cuda.empty_cache()

In [None]:
# Select the "medium" float32 matmul precision, which lets PyTorch trade some
# accuracy for speed where the backend supports it.
torch.set_float32_matmul_precision("medium")

In [None]:
# Cell output: show which device `resolve_torch_device()` picked.
f"Device is {device}"

3. Load dataset

In [None]:
# Load the image together with its labels.
image, labels = load_indian_pines()

# `image.shape` must unpack into exactly three values (named height, width and
# channels here); height and width are reused later to reshape flat per-pixel
# arrays back into image form.
image_h, image_w, image_c = image.shape

In [None]:
# Apply the configured pre-processing; the first return value is not used here.
_, image = preprocess_hsi(image, pre_process_type)

In [None]:
# Run the configured dimensionality reduction. `target_dim` is rebound to the
# value the helper returns, so later cells (model construction, report) use
# that value rather than the one requested in the configuration cell.
_, target_dim, image = reduce_hsi_dim(
    image, target_dim, dim_reduction_type, device, random_seed
)

In [None]:
# Cut the image into patches; `y` holds the matching labels.
# Presumably one patch per pixel, since `y` is later reshaped to
# (image_h, image_w) — TODO confirm against extract_patches.
x, y = extract_patches(image, labels, patch_size=patch_size)

In [None]:
# Number of distinct label values present in `y`.
distinct_labels = np.unique(y)
num_classes = distinct_labels.size

f"Number of classes {num_classes}"

In [None]:
# Request 20 training examples for each of the `num_classes` classes.
examples_per_class = [20] * num_classes

In [None]:
if not examples_per_class:
    # No per-class budget configured: stratified random 80/20 split.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=random_seed, stratify=y
    )
else:
    # Per-class budget configured: use the project split helper, which also
    # returns `y_masked` for the side-by-side plot below.
    x_train, y_train, x_test, y_test, y_masked = train_test_band_patch_split(
        x, y, examples_per_class, "indian_pines"
    )

    full_label_map = y.reshape(image_h, image_w)
    masked_label_map = y_masked.reshape(image_h, image_w)
    plot_segmentation_comparison(full_label_map, masked_label_map, title2="Masked")

# Derive the label sets that later cells index by `target_class`.
x_train, i_y_train, x_test, i_y_test = multi_to_bi_class_pu_ds(
    x_train, y_train, x_test, y_test
)

In [None]:
# Tensors covering every extracted patch. permute(0, 3, 1, 2) moves the last
# axis of `x` to position 1 (presumably channels-first for the conv model).
x_tensor = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2)
# NOTE(review): these are the labels exactly as found in `y`; unlike the
# train/test label tensors they are not selected through `target_class`.
y_tensor = torch.tensor(y, dtype=torch.long)

In [None]:
def _patches_to_tensor(patches):
    """Return `patches` as a float32 tensor with its last axis moved to position 1."""
    return torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)


def _target_labels_to_tensor(label_sets):
    """Return the `target_class` entry of `label_sets` as an int64 tensor."""
    return torch.tensor(label_sets[target_class], dtype=torch.long)


x_train_tensor = _patches_to_tensor(x_train)
y_train_tensor = _target_labels_to_tensor(i_y_train)
x_test_tensor = _patches_to_tensor(x_test)
y_test_tensor = _target_labels_to_tensor(i_y_test)

In [None]:
# Pair each patch tensor with its label tensor so loaders yield (x, y) samples.
train_dataset = data.TensorDataset(x_train_tensor, y_train_tensor)
test_dataset = data.TensorDataset(x_test_tensor, y_test_tensor)

In [None]:
# Number of CPUs in the system; used as `num_workers` for every DataLoader.
cpu_count = multiprocessing.cpu_count()

f"Setting num_workers to {cpu_count}"

In [None]:
full_dataset = data.TensorDataset(x_tensor, y_tensor)

# Settings shared by every loader in this notebook.
# NOTE(review): each loader keeps its own pool of `cpu_count` persistent
# workers once it has been iterated — confirm four such pools is intended.
common_loader_args = {
    "batch_size": batch_size,
    "num_workers": cpu_count,
    "persistent_workers": True,
}

# Only the training loader shuffles its samples.
train_loader = data.DataLoader(train_dataset, shuffle=True, **common_loader_args)
test_loader = data.DataLoader(test_dataset, shuffle=False, **common_loader_args)
full_loader = data.DataLoader(full_dataset, shuffle=False, **common_loader_args)

# Prediction iterates the full dataset wrapped in the unlabeled decorator.
predict_loader = data.DataLoader(
    UnlabeledDatasetDecorator(full_dataset), **common_loader_args
)

4. Train model

In [None]:
# Number of samples per distinct label value in the full label tensor.
counts = torch.unique(y_tensor, return_counts=True)[1]

# Inverse-frequency weights, normalised so that they add up to one.
inverse_frequency = 1.0 / counts
class_weights = inverse_frequency / inverse_frequency.sum()

class_weights

In [None]:
# Distinct label values of the full label tensor and how often each occurs.
y_label, y_counts = torch.unique(y_tensor, return_counts=True)
# Total is taken from this cell's own counts, so the cell gives the right
# result regardless of which other cells have been run.
y_counts_sum = y_counts.sum()

# Maps label value -> share of all samples carrying that label (0-dim tensor).
positive_probs = {
    label.item(): count / y_counts_sum for label, count in zip(y_label, y_counts)
}

positive_probs

In [None]:
def pred_extractor(logits):
    """Turn raw logits into 0/1 predictions by thresholding the sigmoid at 0.5."""
    # Alternative for per-class logits: torch.argmax(logits, dim=1)
    return (torch.sigmoid(logits) > 0.5).int()


# PU loss parameterised with the empirical share of the target class.
loss_fun = PULoss(prior=positive_probs[target_class])

backbone = PuLeNet(band=target_dim, classes=2)

model = HyperSpectralImageClassifier(
    backbone,
    2,
    lr=learning_rate,
    loss_fun=loss_fun,
    # Optional learning-rate decay (disabled):
    # scheduler=lambda opt: StepLR(
    #     opt, step_size=scheduler_step_size, gamma=scheduler_gamma
    # ),
    pred_extractor=pred_extractor,
)

trainer = Trainer(accelerator="auto", devices=1, max_epochs=num_epochs)

In [None]:
# The test loader doubles as the validation loader during fitting.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

In [None]:
# Recorded losses, moved to the CPU before smoothing and plotting.
train_losses = [entry.loss.cpu() for entry in model.train_metrics]
eval_losses = [entry.loss.cpu() for entry in model.val_metrics]

smothed_train = smooth_moving_average(train_losses, smoth_window)
smothed_eval = smooth_moving_average(eval_losses, smoth_window)

plot_epoch_generic_comparison(smothed_train, smothed_eval)

In [None]:
# Recorded validation F1 values, moved to the CPU before smoothing.
val_f1_scores = [entry.f1.cpu() for entry in model.val_metrics]
smooth_f1 = smooth_moving_average(val_f1_scores, smoth_window)

plot_epoch_generic(smooth_f1, desc="f1")

In [None]:
# Evaluate the trained model on the loader that covers every extracted patch.
# NOTE(review): full_loader serves labels from y_tensor, which is built
# straight from `y`, whereas training used the labels selected through
# `target_class` — confirm the reported metrics compare like with like.
validation_result = trainer.validate(model, full_loader)

validation_result

5. Display prediction

In [None]:
# One output tensor per batch: join them, threshold, and lay out as an image.
batch_outputs = trainer.predict(model, predict_loader)
all_logits = torch.cat(batch_outputs, dim=0)
y_pred = pred_extractor(all_logits).reshape(image_h, image_w)

# Binary reference map: 1 where the label equals the target class, else 0.
y_true = np.zeros_like(y)
y_true[y == target_class] = 1
y_true_map = y_true.reshape(image_h, image_w)

plot_segmentation_comparison(y_true_map, y_pred.numpy())

6. Write report

In [None]:
model_category = "bin_lenet"
model_name = create_model_name(f"indian_pines_bin_{target_class}", examples_per_class)

# Notebook-level settings recorded with the run.
run_params = {
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "patch_size": patch_size,
    "target_dim": target_dim,
    "pre_process_type": str(pre_process_type),
    "dim_reduction_type": str(dim_reduction_type),
}

# Merge in the model's own parameters; on a key clash the model's value wins
# because it is the right-hand operand of `|`.
all_params = run_params | model.get_params()
metrics = lightning_metrics(validation_result)

report_run(
    model_name=model_name,
    model_category=model_category,
    run_desc="Default run",
    run_params=all_params,
    run_metrics=metrics,
)

In [None]:
# Show the stored report for this model, sorted by the "f1" metric.
read_report_to_show(model_name, sort_by_metric="f1")

In [None]:
# Same report restricted to this model category, sorted by the "OA" metric.
read_report_to_show(model_name, sort_by_metric="OA", model_category=model_category)