# Grid Search hyper parameters for improved SRST

1. Import dependencies

In [1]:
import random
import torch
import numpy as np
import pandas as pd
import time
import glob
import os

from src.util.torch_device import resolve_torch_device
from src.data.indian_pines import load_indian_pines
from src.util.semi_guided import sample_fraction_from_segmentation
from src.definitions import GREED_SEARCH_FOLDER
from src.model.grid_search import GridSearch
from src.model.improved_spatial_regulated_self_training_grid_search import (
    ImprovedSpatialRegulatedSelfTrainingPipelineGridSearchAdapter,
)

2. Prepare env

In [2]:
random_seed = 42

random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

device = resolve_torch_device()

generator = torch.Generator()
generator.manual_seed(random_seed)

<torch._C.Generator at 0x7fe720bd4e10>

In [3]:
f"Device is {device}"

'Device is cuda'

# Indian pines (Cluster exponential decay)

0. Set params

In [4]:
examples_per_class = 0.7
epoch_seconds = int(time.time())
run_name = f"improved-indian-pines-cluster-exponential-decay-{epoch_seconds}"

In [5]:
cpu_count = 4

f"Setting num_workers to {cpu_count}"

'Setting num_workers to 4'

1. Load dataset

In [6]:
image, labels = load_indian_pines()

In [7]:
num_classes = len(np.unique(labels))

f"Number of classes {num_classes}"

'Number of classes 17'

In [8]:
masked_labels = sample_fraction_from_segmentation(labels, examples_per_class)

2. Train model

In [None]:
params = {
    "input_channels": [50, 125, 200],
    "learning_rate": [1e-3, 1e-4],
    "patch_size": [9],
    "num_epochs": [11],
    "feature_extractor_epochs": [1, 9, 11],
    "semantic_threshold": [0.5, 0.6, 0.8],
    "lambda_v": [0.07, 0.3, 0.49],
    "k_star": [num_classes * 3],
    "batch_size": [64],
}

adapter = ImprovedSpatialRegulatedSelfTrainingPipelineGridSearchAdapter(
    params,
    image,
    masked_labels,
    labels,
    num_classes,
    device,
    random_seed,
    generator,
)

In [10]:
log_dir = GREED_SEARCH_FOLDER / run_name

In [11]:
search = GridSearch(
    adapter=adapter,
    optimize_metric="kappa_score",
    log_dir=log_dir,
    num_workers=cpu_count,
)

In [12]:
_, best_params, best_score = search.run()

100%|██████████| 1/1 [00:20<00:00, 20.97s/it, best_score=0.0475, split=0]
100%|██████████| 1/1 [00:38<00:00, 38.99s/it, best_score=0.00086, split=1]
100%|██████████| 1/1 [01:21<00:00, 81.06s/it, best_score=-0.00687, split=2]
100%|██████████| 1/1 [01:32<00:00, 92.51s/it, best_score=-0.00726, split=3]


[(<src.pipeline.improved_regulated_self_training_pipeline.ImprovedSpatialRegulatedSelfTrainingPipeline object at 0x7fe58e9c5d00>, {'input_channels': 50, 'learning_rate': 0.001, 'patch_size': 9, 'num_epochs': 11, 'feature_extractor_epochs': 1, 'semantic_threshold': 0.5, 'lambda_v': 0.07, 'k_star': 51, 'batch_size': 64}, {'overall_accuracy': 0.07200118899345398, 'average_accuracy': 0.49907252192497253, 'kappa_score': 0.04750770330429077, 'f1_score': 0.36189278960227966}), (<src.pipeline.improved_regulated_self_training_pipeline.ImprovedSpatialRegulatedSelfTrainingPipeline object at 0x7fe58e9c7590>, {'input_channels': 50, 'learning_rate': 0.001, 'patch_size': 9, 'num_epochs': 11, 'feature_extractor_epochs': 3, 'semantic_threshold': 0.5, 'lambda_v': 0.07, 'k_star': 51, 'batch_size': 64}, {'overall_accuracy': 0.0613577701151371, 'average_accuracy': 0.09840665757656097, 'kappa_score': 0.0008600950241088867, 'f1_score': 0.06525576859712601}), (<src.pipeline.improved_regulated_self_training_pi

In [13]:
print("Best Params:", best_params)
print("Best Score:", best_score)

Best Params: {'input_channels': 50, 'learning_rate': 0.001, 'patch_size': 9, 'num_epochs': 11, 'feature_extractor_epochs': 1, 'semantic_threshold': 0.5, 'lambda_v': 0.07, 'k_star': 51, 'batch_size': 64}
Best Score: {'overall_accuracy': 0.07200118899345398, 'average_accuracy': 0.49907252192497253, 'kappa_score': 0.04750770330429077, 'f1_score': 0.36189278960227966}


3. Training results

In [14]:
csv_files = glob.glob(os.path.join(log_dir, "*.csv"))

report = pd.concat([pd.read_csv(f) for f in csv_files])

report.head()

Unnamed: 0,input_channels,learning_rate,patch_size,num_epochs,feature_extractor_epochs,semantic_threshold,lambda_v,k_star,batch_size,overall_accuracy,average_accuracy,kappa_score,f1_score,best_iteration,best_kappa_score
0,50,0.001,9,11,11,0.5,0.07,51,64,0.06001,0.092271,-0.007261,0.055608,1,0.041108
0,50,0.001,9,11,9,0.5,0.07,51,64,0.058559,0.146254,-0.006875,0.141001,1,0.053109
0,50,0.001,9,11,1,0.5,0.07,51,64,0.072001,0.499073,0.047508,0.361893,2,0.049551
0,50,0.001,9,11,3,0.5,0.07,51,64,0.061358,0.098407,0.00086,0.065256,5,0.040056


In [15]:
len(report)

4

In [16]:
report.sort_values("kappa_score", ascending=False).head()

Unnamed: 0,input_channels,learning_rate,patch_size,num_epochs,feature_extractor_epochs,semantic_threshold,lambda_v,k_star,batch_size,overall_accuracy,average_accuracy,kappa_score,f1_score,best_iteration,best_kappa_score
0,50,0.001,9,11,1,0.5,0.07,51,64,0.072001,0.499073,0.047508,0.361893,2,0.049551
0,50,0.001,9,11,3,0.5,0.07,51,64,0.061358,0.098407,0.00086,0.065256,5,0.040056
0,50,0.001,9,11,9,0.5,0.07,51,64,0.058559,0.146254,-0.006875,0.141001,1,0.053109
0,50,0.001,9,11,11,0.5,0.07,51,64,0.06001,0.092271,-0.007261,0.055608,1,0.041108
