# Indian Pines

1. Import dependencies

In [None]:
import random
import torch
import numpy as np

import torch.utils.data as data

from torch import nn

from src.util.hsi import (
    extract_patches,
    DimReductionType,
    PreProcessType,
    preprocess_hsi,
    extract_band_patches,
    train_test_band_patch_split,
    reduce_hsi_dim,
)
from src.model.spectral_former import SpectralFormer
from src.util.torch import resolve_torch_device
from src.data.indian_pines import load_indian_pines
from src.visualization.plot import plot_segmentation_comparison, plot_epoch_generic
from src.data.dataset_decorator import UnlabeledDatasetDecorator
from src.trainer.co_trainer import BiCoTrainer
from src.model.ensemble import Ensemble
from src.trainer.base_trainer import AdamOptimizedModule
from src.trainer.classification_trainer import ClassificationTrainer
from src.util.reporting import (
    classification_trainer,
    create_model_name,
    report_run,
    read_report_to_show,
)

2. Prepare env

In [None]:
# SpectralFormer architecture settings (forwarded to the model constructor).
dim, depth, heads, mlp_dim = 32, 5, 4, 8
dropout, emb_dropout = 0.3, 0.1

# Optimiser settings and training length.
num_epochs = 300
learning_rate, weight_decay = 5e-4, 5e-3

# Learning-rate scheduler settings; the step size is one tenth of the run.
scheduler_gamma = 0.9
scheduler_step_size = num_epochs // 10

In [None]:
# Pre-processing and dimensionality-reduction strategy for the image.
pre_process_type = PreProcessType.STANDARTIZATION
dim_reduction_type = DimReductionType.PCA
target_dim = 75  # dimensionality requested from reduce_hsi_dim

# Patch extraction settings.
patch_size = 7  # passed to extract_patches and to the model as image_size
band_patch = 3  # passed to extract_band_patches and to the model as near_band

# Mini-batch size shared by every DataLoader and by the co-trainer.
batch_size = 64

In [None]:
# One seed value drives the Python, PyTorch and NumPy global RNGs.
random_seed = 42

for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(random_seed)

# Compute device chosen by the project helper.
device = resolve_torch_device()

# Dedicated generator, seeded with the same value as the global RNGs; it is
# later handed to every DataLoader and to the co-trainer.
generator = torch.Generator()
generator.manual_seed(random_seed)

In [None]:
# Bare expression: in a notebook its value is rendered as the cell output.
f"Device is {device}"

3. Load dataset

In [None]:
# Load the Indian Pines image and its label map.
image, labels = load_indian_pines()

# Remember the original 3-D shape; image_h and image_w are used later to
# reshape the flat predictions back into an image.
image_h, image_w, image_c = image.shape

In [None]:
# Apply the configured pre-processing; only the second returned value (the
# processed image) is kept, the first is discarded.
_, image = preprocess_hsi(image, pre_process_type)

In [None]:
# Number of distinct values present in the label map.
# NOTE(review): if the label map reserves a value for background/unlabeled
# pixels, that value is counted as a class here -- confirm this is intended.
num_classes = len(np.unique(labels))

# Bare expression: in a notebook its value is rendered as the cell output.
f"Number of classes {num_classes}"

In [None]:
# One entry per class, each requesting 20 examples; consumed by
# train_test_band_patch_split and create_model_name.
examples_per_class = [20] * num_classes

4. Reduce dimensions

In [None]:
# Run the configured dimensionality reduction. The first returned value is
# discarded; target_dim is rebound to the value returned by reduce_hsi_dim
# (which may differ from the requested one -- TODO confirm) and image is
# replaced by the reduced image.
_, target_dim, image = reduce_hsi_dim(
    image, target_dim, dim_reduction_type, device, random_seed
)

5. Prepare dataset

In [None]:
# Shape of the image after pre-processing and dimensionality reduction.
print(f"Image shape: {image.shape}")

# Cut the image into patches of size patch_size, paired with their labels.
x, y = extract_patches(image, labels, patch_size)

print(f"Patched image shape: {x.shape}")
print(f"Patched labels shape: {y.shape}")

# Re-group the patches band-wise using band_patch (the value is also passed
# to the model as near_band).
x = extract_band_patches(x, band_patch)

print(f"Bandwise patched image shape: {x.shape}")

In [None]:
# Split the patches into train/test parts using the per-class example
# counts. y_masked is reshaped to the label-map shape for plotting below.
x_train, y_train, x_test, y_test, y_masked = train_test_band_patch_split(
    x, y, examples_per_class, "indian_pines"
)

In [None]:
# Show the full label map next to y_masked (reshaped to the same 2-D shape).
plot_segmentation_comparison(labels, y_masked.reshape(labels.shape), title2="Downsampled")

In [None]:
# Bare expression: in a notebook the training array shape is rendered as output.
x_train.shape

In [None]:
def _to_features(array):
    """Copy *array* onto ``device`` as float32 and swap its last two axes."""
    return torch.tensor(array, dtype=torch.float32, device=device).permute(0, 2, 1)


def _to_targets(array):
    """Copy *array* onto ``device`` as an int64 label tensor."""
    return torch.tensor(array, dtype=torch.long, device=device)


# Convert the complete set and both splits, replacing the original arrays.
x_all, y_all = _to_features(x), _to_targets(y)
x_train, y_train = _to_features(x_train), _to_targets(y_train)
x_test, y_test = _to_features(x_test), _to_targets(y_test)

In [None]:
# Pair features with labels for the train split, the test split and the
# complete set of patches.
train_dataset = data.TensorDataset(x_train, y_train)
test_dataset = data.TensorDataset(x_test, y_test)
full_dataset = data.TensorDataset(x_all, y_all)

In [None]:
# Settings common to every loader: same batch size, same seeded generator.
_loader_kwargs = {"batch_size": batch_size, "generator": generator}

# Only the training loader shuffles; the others iterate in dataset order.
train_loader = data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
full_loader = data.DataLoader(full_dataset, shuffle=False, **_loader_kwargs)

# Same samples as full_loader, wrapped in UnlabeledDatasetDecorator for
# the prediction pass.
predict_loader = data.DataLoader(
    UnlabeledDatasetDecorator(full_dataset), shuffle=False, **_loader_kwargs
)

In [None]:
# Bare expression: in a notebook the split sizes are rendered as the cell output.
f"Training samples: {len(train_dataset)}, Testing samples: {len(test_dataset)}"

6. Train semi-supervised model

In [None]:
def _make_spectral_former_module():
    """Build one SpectralFormer wrapped in AdamOptimizedModule.

    Every setting comes from the notebook-level hyperparameters, so each
    call produces a separately constructed model with the same configuration.
    """
    backbone = SpectralFormer(
        image_size=patch_size,
        near_band=band_patch,
        num_patches=target_dim,
        num_classes=num_classes,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
        emb_dropout=emb_dropout,
    )
    return AdamOptimizedModule(
        backbone,
        lr=learning_rate,
        weight_decay=weight_decay,
        scheduler_step_size=scheduler_step_size,
        scheduler_gamma=scheduler_gamma,
    )


# The two co-training participants, constructed one after the other.
model_1 = _make_spectral_former_module()
model_2 = _make_spectral_former_module()

loss = nn.CrossEntropyLoss()

# Trainer handed to the co-trainer and reused later for validation and
# prediction.
trainer = ClassificationTrainer(
    num_epochs=num_epochs,
    num_classes=num_classes,
    criterion=loss,
    device=device,
    record_history=False,
    gradient_accumulation_steps=2,
)

# Co-training driver; it shares the seeded generator with the DataLoaders.
co_trainer = BiCoTrainer(
    batch_size=batch_size,
    confidence_threshold=0.9,
    generator=generator,
    trainer=trainer,
)

In [None]:
# Co-train both models: train_dataset is the labeled set, test_dataset is
# passed as the unlabeled pool, and test_loader is used for evaluation.
# NOTE(review): the unlabeled pool and the evaluation loader are both built
# from test_dataset -- confirm this transductive setup is intended.
feedback = co_trainer.fit(
    models=(model_1, model_2),
    labeled=train_dataset,
    unlabeled=test_dataset,
    eval_dl=test_loader,
)

In [None]:
# Bare expression: evaluation results of the last entry in the co-training history.
feedback.history[-1].eval

In [None]:
# Plot the "eval_f1" value of every entry in the co-training history.
plot_epoch_generic(
    [it.eval["eval_f1"] for it in feedback.history], desc="F1"
)

In [None]:
# Combine the two co-trained models into a single Ensemble used for
# validation and prediction below.
co_trained = Ensemble([model_1, model_2])

In [None]:
# Validate the ensemble on full_loader (built from x_all / y_all).
# NOTE(review): x_all presumably contains the training patches as well, so
# these metrics would not be held-out metrics -- confirm this is intended.
validation_result = trainer.validate(co_trained, full_loader)

# Bare expression: in a notebook the result is rendered as the cell output.
validation_result

In [None]:
# Run the ensemble over every patch; only the second returned value (a
# sequence of per-batch output tensors) is used.
_, y_pred = trainer.predict(co_trained, predict_loader)

# Join the batches, take the arg-max over dim 1 and fold the flat result
# back into the original image height and width.
y_pred = torch.cat(y_pred, dim=0).argmax(dim=1).reshape(image_h, image_w)

# Show the reference label map next to the predicted one.
plot_segmentation_comparison(labels, y_pred.cpu().numpy())

7. Write report

In [None]:
# Name under which this run is reported, derived from the dataset prefix
# and the per-class example counts.
model_name = create_model_name("indian_pines_", examples_per_class)
# NOTE(review): "specteal" looks like a typo for "spectral". The value is a
# runtime key, so check how existing reports are stored before renaming it.
model_category = "specteal_former_co_training"

# Record the hyperparameters of this run together with the metrics derived
# from validation_result.
# NOTE(review): band_patch is not part of run_params -- confirm whether it
# should be recorded too.
report_run(
    model_name=model_name,
    model_category=model_category,
    run_desc="Default run",
    run_params={
        "hidden_dim": dim,
        "num_layers": depth,
        "num_heads": heads,
        "mlp_dim": mlp_dim,
        "dropout": dropout,
        "emb_dropout": emb_dropout,
        "scheduler_step_size": scheduler_step_size,
        "scheduler_gamma": scheduler_gamma,
        "weight_decay": weight_decay,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "patch_size": patch_size,
        "target_dim": target_dim,
        "pre_process_type": str(pre_process_type),
        "dim_reduction_type": str(dim_reduction_type),
    },
    run_metrics=classification_trainer(validation_result),
)

In [None]:
# Show the report for this model name, sorted by the "f1" metric (no category filter).
read_report_to_show(model_name, sort_by_metric="f1")

In [None]:
# Same report, restricted to this notebook's model_category.
read_report_to_show(model_name, sort_by_metric="f1", model_category=model_category)