# Grid Search hyper parameters for SRST

1. Import dependencies

In [None]:
import random
import torch
import numpy as np
import pandas as pd
import time
import glob
import os

from src.util.torch import resolve_torch_device
from src.data.indian_pines import load_indian_pines
from src.util.hsi import sample_from_segmentation_matrix
from src.definitions import GREED_SEARCH_FOLDER
from src.model.grid_search import GridSearch
from src.model.spatial_regulated_self_training_grid_search import (
    SpatialRegulatedSelfTrainingPipelineGridSearchAdapter,
)

2. Prepare env

In [14]:
random_seed = 42

random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

device = resolve_torch_device()

generator = torch.Generator()
generator.manual_seed(random_seed)

<torch._C.Generator at 0x2c6b800f0>

In [15]:
f"Device is {device}"

'Device is mps'

# Indian pines (Cluster exponential decay)

0. Set params

In [16]:
examples_per_class = 15
# epoch_seconds = int(time.time())
epoch_seconds = 1745249041
run_name = f"indian-pines-cluster-exponential-decay-{epoch_seconds}"

In [17]:
cpu_count = 4

f"Setting num_workers to {cpu_count}"

'Setting num_workers to 4'

1. Load dataset

In [18]:
image, labels = load_indian_pines()

In [19]:
num_classes = len(np.unique(labels))

f"Number of classes {num_classes}"

'Number of classes 17'

In [20]:
masked_labels = sample_from_segmentation_matrix(labels, examples_per_class)

2. Train model

In [21]:
params = {
    "splits": [4],
    "learning_rate": [1e-3, 1e-4, 1e-5],
    "patch_size": [9],
    "num_epochs": [11],
    "feature_extractor_epochs": [1, 5, 9, 11],
    "semantic_threshold": [0.5, 0.6, 0.7, 0.8],
    "lambda_v": [0.07, 0.09, 0.2, 0.3, 0.4, 0.49],
    "k_star": [num_classes * 2, num_classes * 3],
    "batch_size": [64],
}

adapter = SpatialRegulatedSelfTrainingPipelineGridSearchAdapter(
    params, image, masked_labels, labels, num_classes, device, random_seed, generator
)

In [22]:
log_dir = GREED_SEARCH_FOLDER / run_name

In [23]:
search = GridSearch(
    adapter=adapter,
    optimize_metric="kappa_score",
    log_dir=log_dir,
    num_workers=cpu_count,
)

In [None]:
_, best_params, best_score = search.run()

In [None]:
print("Best Params:", best_params)
print("Best Score:", best_score)

3. Training results

In [None]:
csv_files = glob.glob(os.path.join(log_dir, "*.csv"))

report = pd.concat([pd.read_csv(f) for f in csv_files])

report.head()

Unnamed: 0,splits,learning_rate,patch_size,num_epochs,feature_extractor_epochs,semantic_threshold,lambda_v,k_star,batch_size,overall_accuracy,average_accuracy,kappa_score,f1_score,best_iteration,best_kappa_score
0,4,0.001,9,11,11,0.5,0.07,34,64,0.09617,0.471391,0.049954,0.34612,3,0.06644
1,4,0.001,9,11,11,0.5,0.07,51,64,0.082465,0.495981,0.018962,0.349963,3,0.041227
0,4,0.001,9,11,1,0.5,0.07,34,64,0.037424,0.081474,-0.023413,0.086824,6,0.070588
1,4,0.001,9,11,1,0.5,0.07,51,64,0.069304,0.429489,0.082754,0.351052,3,0.12701
2,4,0.001,9,11,1,0.5,0.09,34,64,0.065251,0.25874,0.020772,0.248643,5,0.223197


In [27]:
len(report)

16

In [28]:
report.sort_values("kappa_score", ascending=False).head()

Unnamed: 0,splits,learning_rate,patch_size,num_epochs,feature_extractor_epochs,semantic_threshold,lambda_v,k_star,batch_size,overall_accuracy,average_accuracy,kappa_score,f1_score,best_iteration,best_kappa_score
6,4,0.001,9,11,1,0.5,0.3,34,64,0.08646,0.561332,0.172896,0.427189,11,0.172896
1,4,0.001,9,11,1,0.5,0.07,51,64,0.069304,0.429489,0.082754,0.351052,3,0.12701
0,4,0.001,9,11,11,0.5,0.07,34,64,0.09617,0.471391,0.049954,0.34612,3,0.06644
5,4,0.001,9,11,1,0.5,0.2,51,64,0.076939,0.522283,0.049296,0.374899,6,0.143055
0,4,1e-05,9,11,5,0.5,0.07,34,64,0.076566,0.525517,0.048435,0.375794,4,0.061346
