# Indian Pines

1. Import dependencies

In [None]:
import random
import torch
import numpy as np

import torch.utils.data as data

from torch import nn

from src.util.hsi import (
    extract_patches,
    DimReductionType,
    PreProcessType,
    preprocess_hsi,
    reduce_hsi_dim,
    read_fixed_labels_mask,
    train_test_split_by_mask,
)
from src.util.torch import resolve_torch_device
from src.data.indian_pines import load_indian_pines
from src.visualization.plot import (
    plot_segmentation_comparison,
    plot_epoch_generic,
    plot_masked_segmentation_comparison,
)
from src.data.dataset_decorator import UnlabeledDatasetDecorator, BinaryDatasetDecorator
from src.model.lenet import PuLeNet
from src.trainer.multiview_trainer import MultiViewTrainer
from src.model.ensemble import MultiViewEnsemble
from src.trainer.base_trainer import AdamOptimizedModule
from src.trainer.classification_trainer import ClassificationTrainer
from src.util.reporting import (
    classification_trainer,
    create_model_name,
    report_run,
    read_report_to_show,
)

2. Prepare env

In [None]:
# Optimiser settings; passed as lr / weight_decay to each AdamOptimizedModule below.
learning_rate = 1e-3
weight_decay = 0

# Number of epochs given to every per-class ClassificationTrainer.
num_epochs = 15

# Scheduler settings; the step size is tied to the epoch count.
# NOTE(review): neither value is referenced by the visible cells — confirm they are still needed.
scheduler_gamma = 0.9
scheduler_step_size = num_epochs

In [None]:
# Strategies handed to preprocess_hsi and reduce_hsi_dim respectively.
pre_process_type = PreProcessType.STANDARTIZATION
dim_reduction_type = DimReductionType.PCA

# Requested number of components for the dimensionality reduction
# (overwritten later by the value reduce_hsi_dim returns).
target_dim = 75

# Patch size passed to extract_patches and batch size used by every DataLoader.
patch_size = 9
batch_size = 32

In [None]:
random_seed = 42

# Seed the Python, PyTorch and NumPy global RNGs with the same value, in that order.
for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    seed_rng(random_seed)

device = resolve_torch_device()

# Dedicated, explicitly seeded generator that is handed to every DataLoader.
generator = torch.Generator()
generator.manual_seed(random_seed)

In [None]:
# Sole expression of the cell, so the notebook displays which device was resolved.
f"Device is {device}"

3. Load dataset

In [None]:
# Load the Indian Pines image and its label array through the project loader.
image, labels = load_indian_pines()

# The image shape has three axes; the names suggest (height, width, channels)
# — TODO confirm the axis order against load_indian_pines.
image_h, image_w, image_c = image.shape

In [None]:
# Apply the configured pre-processing; the first returned value is discarded and
# `image` is replaced by the processed result.
_, image = preprocess_hsi(image, pre_process_type)

In [None]:
# Count the distinct label values present in the label array
# (every value counts, so a background label, if present, is included).
num_classes = np.unique(labels).size

f"Number of classes {num_classes}"

In [None]:
from src.util.dict_ext import arrange_and_repeat


# NOTE(review): presumably a per-class example budget of 20 for each of the
# num_classes labels — confirm against arrange_and_repeat. In the visible cells it
# is only used to build the report's model name.
examples_per_class = arrange_and_repeat(num_classes, 20)

4. Reduce dimensions

In [None]:
# Reduce the image with the configured method. The call returns three values: the
# first is discarded, the second overwrites `target_dim` (presumably the dimension
# actually produced — TODO confirm), and the third replaces `image`.
_, target_dim, image = reduce_hsi_dim(
    image, target_dim, dim_reduction_type, device, random_seed
)

5. Prepare dataset

In [None]:
# Cut the image into patches of side `patch_size`; `x` holds the patches and `y` the
# matching labels. `y` is later reshaped to (image_h, image_w), so this presumably
# yields one sample per pixel — TODO confirm against extract_patches.
x, y = extract_patches(image, labels, patch_size=patch_size)

In [None]:
# Load the stored fixed-label mask that decides which samples are used for training.
mask = read_fixed_labels_mask("indian-pines-v1.npy")

# Split the patches and labels into train and test parts according to the mask.
x_train, y_train, x_test, y_test = train_test_split_by_mask(x, y, mask)

# Plot the label map (reshaped to the image grid) together with the mask; the return
# value is assigned to `_` so the notebook does not echo it.
_ = plot_masked_segmentation_comparison(y.reshape(image_h, image_w), mask)

In [None]:
# Convert every patch array to float32 and move the last axis to position 1
# (presumably channels-last -> channels-first — TODO confirm the patch layout).
x_all, x_train, x_test = (
    torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)
    for patches in (x, x_train, x_test)
)

# Convert every label array to an int32 tensor.
y_all, y_train, y_test = (
    torch.tensor(targets, dtype=torch.int32) for targets in (y, y_train, y_test)
)

In [None]:
# Pair each feature tensor with its label tensor: the train split, the test split,
# and the complete set of samples.
train_dataset = data.TensorDataset(x_train, y_train)
test_dataset = data.TensorDataset(x_test, y_test)
full_dataset = data.TensorDataset(x_all, y_all)

In [None]:
# Four non-shuffling loaders that share the batch size and the seeded generator:
#   test_loader      - the test split
#   unlabeled_loader - the test split wrapped in UnlabeledDatasetDecorator
#   full_loader      - every sample
#   predict_loader   - every sample wrapped in UnlabeledDatasetDecorator
test_loader, unlabeled_loader, full_loader, predict_loader = (
    data.DataLoader(
        source,
        batch_size=batch_size,
        shuffle=False,
        generator=generator,
    )
    for source in (
        test_dataset,
        UnlabeledDatasetDecorator(test_dataset),
        full_dataset,
        UnlabeledDatasetDecorator(full_dataset),
    )
)

6. Train semi-supervised model

In [None]:
from src.model.dbda import DBDA


# Parallel lists: entry k holds the model, trainer and labelled loader for label k + 1.
models = []
trainers = []
labeled = []

# One single-output (binary) model per label value 1 .. num_classes - 1.
# NOTE(review): label 0 is skipped — presumably the background class; confirm.
for i in range(1, num_classes):
    # DBDA network with a single output, wrapped with the Adam settings defined earlier.
    model = AdamOptimizedModule(
        DBDA(
            band=target_dim,
            classes=1,
            flatten_out=True
        ),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    # Each trainer gets its own BCE-with-logits criterion instance.
    loss = nn.BCEWithLogitsLoss()
    trainer = ClassificationTrainer(
        num_epochs=num_epochs,
        num_classes=num_classes,
        criterion=loss,
        device=device,
        record_history=False,
        dl_accumulation_steps=2
    )

    models.append(model)
    trainers.append(trainer)
    # Shuffled loader over the training set, viewed through BinaryDatasetDecorator for label i.
    labeled.append(
        data.DataLoader(
            BinaryDatasetDecorator(train_dataset, i),
            batch_size=batch_size,
            shuffle=True,
            generator=generator,
        )
    )

# Trainer that drives all per-class models together. The confidence threshold and
# max_epochs are hard-coded here and are not written into the run report's parameters.
# NOTE: `model` stays bound to the last model built above and is read by the report cell.
co_trainer = MultiViewTrainer(
    num_classes=num_classes,
    confidence_threshold=0.9,
    device=device,
    max_epochs=2
)

In [None]:
# Run the multi-view training. `feedback` carries the per-iteration history plotted
# below; `co_trained` is the resulting model object passed to validate/predict later.
# The test loader is used for evaluation during training and, wrapped as unlabeled
# data, as the unlabeled pool.
feedback, co_trained = co_trainer.fit(
    models=models,
    trainers=trainers,
    labeled=labeled,
    unlabeled=unlabeled_loader,
    ensemble_eval_dl=test_loader,
)

In [None]:
# Plot the "eval_f1" value recorded for each entry in the training history.
plot_epoch_generic([it.eval["eval_f1"] for it in feedback.history], desc="F1")

In [None]:
# Plot the "eval_kappa" value recorded for each entry in the training history.
plot_epoch_generic([it.eval["eval_kappa"] for it in feedback.history], desc="Kappa")

In [None]:
# Plot the "eval_accuracy_avg" value recorded for each entry in the training history.
plot_epoch_generic(
    [it.eval["eval_accuracy_avg"] for it in feedback.history], desc="AA"
)

In [None]:
# Plot the "eval_accuracy_overall" value recorded for each entry in the training history.
plot_epoch_generic(
    [it.eval["eval_accuracy_overall"] for it in feedback.history], desc="OA"
)

In [None]:
# Plot the "unlabeled_len" training value recorded for each entry in the history.
plot_epoch_generic(
    [it.train["unlabeled_len"] for it in feedback.history], desc="Unlabeled count"
)

In [None]:
# Evaluate the co-trained model on the loader covering every sample.
# NOTE(review): full_loader also contains the training samples — confirm that
# reporting metrics on the complete image is intended.
validation_result = co_trainer.validate(co_trained, full_loader)

# Trailing expression so the notebook displays the result.
validation_result

In [None]:
# Predict over every sample; the second returned value is a sequence of per-batch
# prediction tensors (the first is discarded).
_, batch_predictions = co_trainer.predict(co_trained, predict_loader)

# Join the batches and lay the predictions out on the image grid.
y_pred = torch.cat(batch_predictions, dim=0).reshape(image_h, image_w)

plot_segmentation_comparison(labels, y_pred.cpu().numpy())

7. Write report

In [None]:
# Name under which this run is reported, built from a dataset prefix and the
# examples-per-class setting.
model_name = create_model_name("indian_pines_", examples_per_class)
# NOTE(review): the category says "lenet", but the models built in this notebook are
# DBDA — confirm whether the category should be renamed. Renaming it would separate
# this run from earlier reports stored under the current value.
model_category = "lenet_multiview"

In [None]:
# Notebook-level settings recorded alongside the run.
run_params = {
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "patch_size": patch_size,
    "target_dim": target_dim,
    "pre_process_type": str(pre_process_type),
    "dim_reduction_type": str(dim_reduction_type),
}

# Record the run. `model` is the variable left over from the model-building loop
# (the last per-class model); its parameters are merged into run_params, overriding
# duplicate keys. The dict `|` operator requires Python 3.9+.
report_run(
    model_name=model_name,
    model_category=model_category,
    run_desc="New fixed mask, 2 epoch, new unlabled shrinking",
    run_params=run_params | model.get_params(),
    run_metrics=classification_trainer(validation_result),
)

In [None]:
# Display the stored reports for this model name, sorted by the "f1" metric.
read_report_to_show(model_name, sort_by_metric="f1")

In [None]:
# Same report view, restricted to this notebook's model category.
read_report_to_show(model_name, sort_by_metric="f1", model_category=model_category)