# Context
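This notebook drives the training process for different models. In its current form it trains and evaluates k-NN models over two data units (`sup.DATA_S_PF`, `sup.DATA_S_PV`), with and without PH2, for class subsets of 2 to 38 classes at three difficulty levels (easy, average, hard), and saves the resulting metrics to a CSV file.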

In [None]:
# Set project's environment variables
import os
import sys
from dotenv import load_dotenv

# Relative path: resolved against the current working directory, which is
# normally the notebook's own folder.
load_dotenv(dotenv_path="../../../project.env")

# Make project modules (e.g. `superheader`, `TRAIN`) importable.
# PYTHONPATH may list several entries separated by os.pathsep, so add each one,
# and skip entries already present so re-running this cell stays idempotent.
for project_path in os.environ["PYTHONPATH"].split(os.pathsep):
    if project_path not in sys.path:
        sys.path.append(project_path)

In [None]:
# Import project-wide and PH2 specific variables and functions
import superheader as sup
from TRAIN.architecture.archeader import knn

# Models

## Setup

In [None]:
import gc

In [None]:
import json

# Pre-computed class subsets, keyed by number of classes, then by difficulty.
with open(os.path.join(sup.DATA_ROOT, "all-classes-subsets.json"), "r",
          encoding="utf-8") as f:
    loaded = json.load(f)

# JSON object keys are always strings, so convert the top-level keys (number of
# classes) back to int. Each inner {difficulty: value} dict is shallow-copied
# unchanged; its values stay whatever json.load produced (JSON arrays load as
# lists, since JSON has no tuple type).
subsets = {
    int(k): dict(v_dict)
    for k, v_dict in loaded.items()
}
subsets


In [None]:
# Data options shared by every run. The training loop adds "data_unit", "PH2",
# "difficulty" and "class_numeric_list" for each individual run.
base_data_config = {
  "PH3" : False,
  "reducer" : '',
  "kernel" : '',
  "n" : -1,
  "label_col" : sup.class_numeric_column,
  "class_list" : 'specified'
}

# Model options shared by every run.
base_train_config = {
  "arch" : sup.TRAIN_KNN_CODE,
  "k" : 1
}

# Class-subset sizes to sweep (both bounds inclusive). Every size must be a key
# of `subsets`, because the training loop looks each one up with subsets[n].
MIN_NUM_CLASSES = 2
MAX_NUM_CLASSES = 38
num_class_candidates = list(range(MIN_NUM_CLASSES, MAX_NUM_CLASSES + 1))

## Train

In [None]:
# One dict of scores per trained model; collected into a DataFrame below.
metric_tracker = []

In [None]:
def run_knn_experiment(data_config, train_config):
  """Fit, test and score one k-NN model and return its summary metrics.

  Prints both configs, then each metric in the order of the returned dict.
  Calls ``keep_confusion_matrix()`` on the model, then releases the model
  and runs the garbage collector before returning.

  Args:
    data_config: dict forwarded to ``knn.KNN(data_config=...)``.
    train_config: dict forwarded to ``knn.KNN(train_config=...)``.

  Returns:
    dict with keys "accuracy", "top2accuracy", "macro_f1", "macro_precision"
    and "macro_recall", read from the model after ``full_score()``.
  """
  print(data_config)
  print(train_config)
  model = knn.KNN(data_config=data_config, df=None,
                  train_config=train_config)

  model.fit()
  model.test()
  model.full_score()

  scores = {
    "accuracy" : model.accuracy,
    "top2accuracy" : model.top2accuracy,
    "macro_f1" : model.macro_f1,
    "macro_precision" : model.macro_precision,
    "macro_recall" : model.macro_recall,
  }
  for value in scores.values():
    print(value)

  model.keep_confusion_matrix()
  # NOTE(review): model.keep_loss() was left disabled here; presumably it does
  # not apply to k-NN — confirm before re-enabling.

  print("clearing memory...")
  del model
  gc.collect()
  return scores


for data_unit in [sup.DATA_S_PF, sup.DATA_S_PV]:
  for PH2 in [False, True]:
    for n in num_class_candidates:
      subset_by_difficulty = subsets[n]
      for difficulty in ['easy', 'average', 'hard']:
        class_numeric_list = subset_by_difficulty[difficulty]

        # Build fresh config dicts for every run so a model never holds a dict
        # that later iterations mutate (and nothing a model writes into its
        # config can leak into the next run).
        data_config = {
          **base_data_config,
          "data_unit" : data_unit,
          "PH2" : PH2,
          "difficulty" : difficulty,
          "class_numeric_list" : class_numeric_list,
        }
        train_config = base_train_config.copy()

        scores = run_knn_experiment(data_config, train_config)
        metric_tracker.append({
          "data_unit" : data_unit,
          "PH2" : PH2,
          "num_classes" : n,
          "difficulty" : difficulty,
          **scores,
        })

# Keep metrics

In [None]:
import pandas as pd

# One row per trained model, one column per recorded field.
metrics_df = pd.DataFrame(data=metric_tracker)
metrics_df

In [None]:
# Persist the metrics table under <SCORES_ROOT>/specified/<KNN code>/.
knn_scores_dir = os.path.join(sup.SCORES_ROOT, "specified", sup.TRAIN_KNN_CODE)
sup.create_dir_if_not_exists(knn_scores_dir)

metrics_csv_path = os.path.join(knn_scores_dir, "PREPanalysis.csv")
metrics_df.to_csv(metrics_csv_path, index=False)