# Indian Pines

1. Import dependencies

In [None]:
import random
import torch
import numpy as np

import torch.utils.data as data

from torch import nn

from src.util.image import (
    scale_image
)
from src.util.patches import extract_patches
from src.util.torch_device import resolve_torch_device
from src.data.indian_pines import load_indian_pines
from src.model.lenet import FullyConvolutionalLeNet
from src.visualization.plot import plot_segmentation_comparison, plot_epoch_generic
from src.data.dataset_decorator import UnlabeledDatasetDecorator
from src.model.autoencoder import SpatialAutoEncoder
from src.trainer.autoencoder_trainer import AutoEncoderTrainer
from src.trainer.co_trainer import BiCoTrainer
from src.model.ensemble import Ensemble
from src.trainer.base_trainer import AdamOptimizedModule
from src.trainer.classification_trainer import ClassificationTrainer
from src.util.hsi import train_test_band_patch_split, reduce_depth_with_patched_autoencoder

2. Prepare environment

In [None]:
# --- Experiment configuration -------------------------------------------
random_seed = 42
batch_size = 32
patch_size = 9
fraction_of_examples = 0.1
examples_per_class = 20
reduced_dim_size = 200

# --- Reproducibility: seed the Python, PyTorch and NumPy global RNGs -----
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

device = resolve_torch_device()

# Dedicated generator handed to the data loaders and the co-trainer below;
# Generator.manual_seed returns the generator itself, so the call can be chained.
generator = torch.Generator().manual_seed(random_seed)

In [None]:
# Show which device resolve_torch_device() selected (rendered as the cell output).
f"Device is {device}"

3. Load dataset

In [None]:
# Load the Indian Pines image together with its label map.
image, labels = load_indian_pines()

# The image is three-dimensional; keep its extents for reshaping predictions
# later and for sizing the autoencoder input.
image_h, image_w, image_c = image.shape

# scale_image returns a pair; only its second element is used from here on.
# NOTE(review): presumably the first element is the fitted scaler — confirm.
_, x = scale_image(image)

In [None]:
# Count the distinct label values present in the label map.
# np.unique returns a flat array, so its size equals its length.
num_classes = np.unique(labels).size

f"Number of classes {num_classes}"

4. Reduce dimensions

In [None]:
# Optional spectral reduction: runs only when the configured size differs from
# the image's channel count.
# NOTE(review): reduced_dim_size is overwritten with 50 inside the branch, so
# the value it holds on entry only acts as an on/off switch — confirm intent.
if image_c != reduced_dim_size:
    auto_encoder_epochs, auto_encoder_lr = 100, 1e-3
    reduced_dim_size = 50

    autoencoder = AdamOptimizedModule(
        SpatialAutoEncoder(
            input_channels=image_c,
            embedding_size=reduced_dim_size,
        ),
        lr=auto_encoder_lr,
    )
    trainer = AutoEncoderTrainer(nn.MSELoss(), auto_encoder_epochs, device)

    # Replace x with the output of the autoencoder-based depth reduction.
    x = reduce_depth_with_patched_autoencoder(
        x, patch_size, autoencoder, trainer, device
    )

5. Prepare dataset

In [None]:
# Cut the image into patches of side patch_size with their labels.
x, y = extract_patches(x, labels, patch_size=patch_size)

# Same labelled-sample budget for every class.
examples_per_class_arr = np.full(num_classes, examples_per_class)

# Split into the small labelled training set and the remaining test set;
# y_masked is plotted against the full label map in the next cell.
(
    x_train,
    y_train,
    x_test,
    y_test,
    y_masked,
) = train_test_band_patch_split(x, y, examples_per_class_arr)

In [None]:
# Compare the full label map with y_masked reshaped to the same shape.
# NOTE(review): y_masked presumably keeps only the labels selected for training — confirm.
plot_segmentation_comparison(labels, y_masked.reshape(labels.shape), title2="Downsampled")

In [None]:
# Inspect the shape of the training patches (still in the layout returned by the split).
x_train.shape

In [None]:
def _as_channels_first(patches):
    """Copy *patches* to a float32 tensor on `device`, moving axis 3 to position 1."""
    return torch.tensor(patches, dtype=torch.float32, device=device).permute(0, 3, 1, 2)


def _as_class_indices(targets):
    """Copy *targets* to an int64 tensor on `device`."""
    return torch.tensor(targets, dtype=torch.long, device=device)


# Convert every split with the same two conversions.
x_all, x_train, x_test = (
    _as_channels_first(patches) for patches in (x, x_train, x_test)
)
y_all, y_train, y_test = (
    _as_class_indices(targets) for targets in (y, y_train, y_test)
)

In [None]:
# Pair each feature tensor with its labels: training split, test split,
# and the complete set of patches.
train_dataset, test_dataset, full_dataset = (
    data.TensorDataset(features, targets)
    for features, targets in (
        (x_train, y_train),
        (x_test, y_test),
        (x_all, y_all),
    )
)

In [None]:
# Options common to all three loaders.
_shared_loader_options = {"batch_size": batch_size, "generator": generator}

# Only the training loader shuffles; evaluation and prediction keep dataset order.
train_loader = data.DataLoader(train_dataset, shuffle=True, **_shared_loader_options)
test_loader = data.DataLoader(test_dataset, shuffle=False, **_shared_loader_options)

# Prediction runs over every patch, with the dataset wrapped in
# UnlabeledDatasetDecorator.
predict_loader = data.DataLoader(
    UnlabeledDatasetDecorator(full_dataset),
    shuffle=False,
    **_shared_loader_options,
)

In [None]:
# Report how many samples ended up in each split (rendered as the cell output).
f"Training samples: {len(train_dataset)}, Testing samples: {len(test_dataset)}"

6. Train model

In [None]:
# Hyperparameters shared by the supervised and the co-training runs below.
num_epochs, learning_rate = 30, 1e-3

In [None]:
# Supervised baseline: a fully convolutional LeNet wrapped in AdamOptimizedModule.
model = AdamOptimizedModule(
    FullyConvolutionalLeNet(reduced_dim_size, num_classes),
    lr=learning_rate,
)

loss = nn.CrossEntropyLoss()

# Keyword form, matching how the co-training trainer is configured later on.
trainer = ClassificationTrainer(
    num_epochs=num_epochs,
    num_classes=num_classes,
    criterion=loss,
    device=device,
    gradient_accumulation_steps=2,
)

In [None]:
# Train on the labelled split, passing the test loader for evaluation;
# the returned feedback carries the history plotted in the next cells.
feedback = trainer.fit(model, train_loader, test_loader)

In [None]:
# Training loss for every entry in the fit history.
plot_epoch_generic(
    [epoch.train["train_loss"] for epoch in feedback.history],
    desc="Train loss",
)

In [None]:
# Evaluation loss, skipping history entries that recorded no "eval_loss".
plot_epoch_generic(
    [
        epoch.eval["eval_loss"]
        for epoch in feedback.history
        if "eval_loss" in epoch.eval
    ],
    desc="Eval loss",
)

7. Display prediction

In [None]:
# Evaluate the trained model on the test split (result rendered as the cell output).
trainer.validate(model, test_loader)

In [None]:
# Predict over every patch; predict returns a pair and only the second element is kept.
_, y_pred = trainer.predict(model, predict_loader)

In [None]:
# Join the sequence of prediction tensors along the first dimension.
y_pred = torch.cat(y_pred, dim=0)

In [None]:
# Pick the highest-scoring index along dim 1, then lay the result out on the image grid.
y_pred = y_pred.argmax(dim=1).reshape(image_h, image_w)

In [None]:
# Compare the label map with the predicted map (moved to CPU as a NumPy array).
plot_segmentation_comparison(labels, y_pred.cpu().numpy())

8. Train semi-supervised model

In [None]:
# Two separately constructed models with the same architecture and learning rate;
# they are built one after the other, model_1 first.
model_1, model_2 = (
    AdamOptimizedModule(
        FullyConvolutionalLeNet(reduced_dim_size, num_classes),
        lr=learning_rate,
    )
    for _ in range(2)
)

loss = nn.CrossEntropyLoss()

# Co-training wrapper around a classification trainer configured like the
# supervised one, except that history recording is switched off.
co_trainer = BiCoTrainer(
    batch_size=batch_size,
    confidence_threshold=0.9,
    generator=generator,
    trainer=ClassificationTrainer(
        num_epochs=num_epochs,
        num_classes=num_classes,
        criterion=loss,
        device=device,
        record_history=False,
        gradient_accumulation_steps=2,
    ),
)

In [None]:
# Co-train both models: the labelled split is the starting training set and the
# test split is passed as the unlabeled pool.
# NOTE(review): the same test split is also what eval_dl iterates over, so the
# reported metrics are measured on the data the unlabeled pool is drawn from —
# confirm this is intended.
feedback = co_trainer.fit(
    models=(model_1, model_2),
    labeled=train_dataset,
    unlabeled=test_dataset,
    eval_dl=test_loader,
)

In [None]:
# Evaluation metrics of the last history entry (rendered as the cell output).
feedback.history[-1].eval

In [None]:
# "eval_f1" for every entry in the co-training history.
plot_epoch_generic(
    [step.eval["eval_f1"] for step in feedback.history],
    desc="F1",
)

In [None]:
# Combine the two co-trained models into one ensemble on the target device.
co_trained = Ensemble([model_1, model_2]).to(device)

# Same post-processing as for the supervised model: concatenate the prediction
# tensors, take the argmax along dim 1, and reshape to the image grid.
_, y_pred = trainer.predict(co_trained, predict_loader)
y_pred = torch.cat(y_pred, dim=0).argmax(dim=1).reshape(image_h, image_w)

plot_segmentation_comparison(labels, y_pred.cpu().numpy())