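# Indian Pines

Train a fully convolutional LeNet classifier on Indian Pines hyperspectral patches, visualise its per-pixel predictions, and log the validation metrics to the run report.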

1. Import dependencies

In [None]:
import random
import torch
import multiprocessing
import numpy as np

import torch.utils.data as data

from sklearn.model_selection import train_test_split

from lightning import Trainer

from src.util.torch import resolve_torch_device
from src.util.hsi import (
    extract_patches,
    reduce_hsi_dim,
    train_test_band_patch_split,
    preprocess_hsi,
    PreProcessType,
    DimReductionType,
)
from src.data.indian_pines import load_indian_pines
from src.model.hsic import HyperSpectralImageClassifier
from src.model.lenet import FullyConvolutionalLeNet
from src.visualization.plot import (
    plot_segmentation_comparison,
    plot_numbers_distribution,
)
from src.data.dataset_decorator import UnlabeledDatasetDecorator
from src.util.reporting import create_model_name, report_run, read_report_to_show

2. Prepare environment

In [None]:
# Optimiser step size (passed to the classifier as `lr`) and the Trainer's epoch budget.
learning_rate, num_epochs = 0.001, 12

In [None]:
# --- Data / preprocessing configuration ---
# Side length of the spatial window handed to extract_patches.
patch_size = 9
# Mini-batch size shared by every DataLoader in this notebook.
batch_size = 32
# Presumably per-class training counts; an empty list selects the random stratified 80/20 split instead.
examples_per_class = []
# Requested spectral dimension. NOTE(review): reduce_hsi_dim returns an effective value that overwrites this.
target_dim = 75

pre_process_type, dim_reduction_type = (
    PreProcessType.STANDARTIZATION,
    DimReductionType.NOPE,
)

In [None]:
random_seed = 42

# Seed each independent RNG source the notebook relies on so runs are repeatable.
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(random_seed)

# Device handed to reduce_hsi_dim; the Trainer chooses its own accelerator ("auto").
device = resolve_torch_device()

In [None]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Let float32 matmuls use faster reduced-precision internal paths where the hardware supports them.
torch.set_float32_matmul_precision(precision="medium")

In [None]:
"Device is {}".format(device)

3. Load dataset

In [None]:
# Raw hyperspectral cube and its ground-truth labels.
image, labels = load_indian_pines()

# The cube is rank 3; its axes are unpacked as height, width and spectral channels.
(image_h, image_w, image_c) = image.shape

In [None]:
# Apply the configured preprocessing; the first return value is discarded.
# NOTE: the result replaces `image`, so re-running this cell re-processes already-processed data.
(_, image) = preprocess_hsi(
    image,
    pre_process_type,
)

In [None]:
# Optionally shrink the spectral dimension. The returned dimension replaces `target_dim`, and the
# model is later built with it. NOTE(review): with DimReductionType.NOPE this is presumably the
# original band count — confirm in reduce_hsi_dim.
(_, target_dim, image) = reduce_hsi_dim(
    image,
    target_dim,
    dim_reduction_type,
    device,
    random_seed,
)

In [None]:
# Cut spatial patches around pixels; `y` is later reshaped to (image_h, image_w), so presumably
# there is one patch/label per pixel in row-major order.
x, y = extract_patches(
    image,
    labels,
    patch_size=patch_size,
)

In [None]:
if not examples_per_class:
    # No per-class budget: random 80/20 split, stratified on the labels so class ratios are kept.
    x_train, x_test, y_train, y_test = train_test_split(
        x,
        y,
        test_size=0.2,
        random_state=random_seed,
        stratify=y,
    )
else:
    # Per-class budget given: delegate to the project splitter, which also returns a masked label map.
    x_train, y_train, x_test, y_test, y_masked = train_test_band_patch_split(
        x, y, examples_per_class, "indian_pines"
    )

    # Ground truth next to the masked map (presumably the pixels selected for training).
    plot_segmentation_comparison(
        y.reshape(image_h, image_w), y_masked.reshape(image_h, image_w)
    )

# plot_numbers_distribution(y_train, desc="Train class distribution")

In [None]:
# Patches become float32 tensors with axes reordered (0, 3, 1, 2) — presumably channels-last
# (N, H, W, C) to the channels-first (N, C, H, W) layout conv layers expect.
x_tensor, x_train_tensor, x_test_tensor = (
    torch.tensor(arr, dtype=torch.float32).permute(0, 3, 1, 2)
    for arr in (x, x_train, x_test)
)
# Labels become int64 class-index tensors.
y_tensor, y_train_tensor, y_test_tensor = (
    torch.tensor(arr, dtype=torch.long) for arr in (y, y_train, y_test)
)

In [None]:
# Every DataLoader below uses one worker process per CPU core.
cpu_count = multiprocessing.cpu_count()

"Setting num_workers to {}".format(cpu_count)

In [None]:
train_dataset = data.TensorDataset(x_train_tensor, y_train_tensor)
test_dataset = data.TensorDataset(x_test_tensor, y_test_tensor)
full_dataset = data.TensorDataset(x_tensor, y_tensor)


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch size and worker settings."""
    return data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=cpu_count,
        persistent_workers=True,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
test_loader = _make_loader(test_dataset, shuffle=False)
# Must stay unshuffled: predictions are reshaped back into the image grid in dataset order.
predict_loader = _make_loader(UnlabeledDatasetDecorator(full_dataset), shuffle=False)

In [None]:
"Training samples: {}, Testing samples: {}".format(len(train_dataset), len(test_dataset))

In [None]:
# Distinct label values across the full (unsplit) label vector.
class_labels = np.unique(y)
num_classes = len(class_labels)

# Predictions are decoded with argmax, giving indices 0..K-1 (K presumably num_classes, the width
# the model is built with). Those indices are compared directly against `y`, so the labels must be
# exactly 0..num_classes-1 with no gaps. Fail fast here instead of producing a silently
# mis-coloured map or an out-of-range target error deep inside training.
assert np.array_equal(class_labels, np.arange(num_classes)), (
    f"Expected labels to be contiguous 0..{num_classes - 1}, got {class_labels.tolist()}"
)

f"Number of classes {num_classes}"

4. Train model

In [None]:
# Backbone receives the effective spectral dimension (presumably its input channels) and the class count.
backbone = FullyConvolutionalLeNet(target_dim, num_classes)
model = HyperSpectralImageClassifier(backbone, num_classes, lr=learning_rate)

# Single-device Trainer; Lightning picks the accelerator automatically.
trainer = Trainer(max_epochs=num_epochs, accelerator="auto", devices=1)

In [None]:
# The held-out split doubles as the per-epoch validation set.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)

In [None]:
# Final metrics on the held-out split; a list with one metrics dict per dataloader.
validation_result = trainer.validate(model=model, dataloaders=test_loader)

validation_result

5. Display predictions

In [None]:
# Per-batch model outputs for every pixel, in dataset order.
y_pred = trainer.predict(model, dataloaders=predict_loader)

In [None]:
# Concatenate the batches, pick the highest-scoring class per sample, and fold the flat
# predictions back into the image grid. NOTE: rebinds `y_pred`, so run this cell only once per predict.
y_pred = torch.cat(y_pred, dim=0).argmax(dim=1).reshape(image_h, image_w)

In [None]:
# Ground-truth label map vs. predicted map.
plot_segmentation_comparison(
    y.reshape(image_h, image_w),
    y_pred.numpy(),
    num_classes,
)

6. Write report

In [26]:
model_name = create_model_name("indian_pines_", examples_per_class)
model_category = "lenet"

# Hyperparameters recorded with this run; key order is kept stable for the report.
run_params = {
    "learning_rate": learning_rate,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "patch_size": patch_size,
    "target_dim": target_dim,
    "pre_process_type": str(pre_process_type),
    "dim_reduction_type": str(dim_reduction_type),
}

# Metrics come from the last validation dataloader's result dict.
report_run(
    model_name=model_name,
    model_category=model_category,
    run_desc="Default lenet",
    run_params=run_params,
    run_metrics=validation_result[-1],
)

PosixPath('/home/melal/Workspace/spatial-regulated-self-training/reports/runs/indian_pines__.csv')

In [27]:
# All recorded runs for this model, ranked by validation F1.
read_report_to_show(
    model_name,
    sort_by_metric="val_f1",
)

Unnamed: 0,timestamp,model_category,run_desc,params,val_f1
0,2025-05-24T16:13:48.380674+00:00,lenet,Default lenet,"{""learning_rate"": 0.001, ""num_epochs"": 12, ""ba...",0.929728


In [28]:
# Same ranking, restricted to runs in this model category.
read_report_to_show(
    model_name,
    sort_by_metric="val_f1",
    model_category=model_category,
)

Unnamed: 0,timestamp,model_category,run_desc,params,val_f1
0,2025-05-24T16:13:48.380674+00:00,lenet,Default lenet,"{""learning_rate"": 0.001, ""num_epochs"": 12, ""ba...",0.929728
