# Context
This notebook drives the training process for different models.

In [None]:
# Set project's environment variables
import os
import sys
from dotenv import load_dotenv

# Load the project-level variables from the env file. The path is relative,
# so it is resolved against the current working directory.
load_dotenv(dotenv_path="../../../project.env")

# PYTHONPATH may hold several locations joined by os.pathsep, so register each
# one separately. Entries already on sys.path are skipped so that re-running
# this cell does not pile up duplicates.
for project_path in os.environ["PYTHONPATH"].split(os.pathsep):
    if project_path and project_path not in sys.path:
        sys.path.append(project_path)

In [None]:
# Import project-wide and PH2 specific variables and functions
import superheader as sup
import TRAIN.architecture.BERT.bert as bert

# Models

## Setup

In [None]:
import torch
import gc

In [None]:
# Name of the class set; it is also used to build file and directory names.
TRAIN_classes = 'alpha-classes'
num_classes = 28

# Class counts to sweep: 2, 6, 10, ... in steps of 4, never exceeding
# num_classes.
num_class_candidates = [*range(2, num_classes + 1, 4)]

# Difficulty levels to train on for every class count.
diff_candidates = ['easy', 'average', 'hard']

# Candidate reducer kernels and the BERT variant handed to the search.
exploring_kernel_candidates = [sup.PH3_REDUCER_KERNEL_NAME_COS]
exploringBERT = bert.BERT_TINY

# Epoch budget is derived per class count n as int(base + n * rate).
exploring_base_num_epochs = 0
exploring_rate_num_epochs = 0.6

exploring_batch_size = 1024

In [None]:
import json

# Read the stored class subsets for the selected class set. The file is opened
# explicitly as UTF-8 so decoding does not depend on the platform's locale.
subsets_path = os.path.join(sup.DATA_ROOT, f"{TRAIN_classes}-subsets.json")
with open(subsets_path, "r", encoding="utf-8") as f:
    loaded = json.load(f)

# JSON object keys are always strings, so convert the outer keys back to int.
# Each inner mapping is shallow-copied; its values are kept exactly as
# json.load produced them (no list-to-tuple conversion is performed).
subsets = {
    int(k): dict(v_dict)
    for k, v_dict in loaded.items()
}
subsets


## Train

In [None]:
# Start from an empty tracker so the scores saved at the end of the notebook
# only reflect this run. Presumably bert.find_best appends its results to this
# list -- TODO confirm against the bert module.
sup.bert_score_tracker = []

In [None]:
def _release_memory():
    """Run the garbage collector and, when the MPS backend is available, empty its cache."""
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


# Sweep every combination of data unit, class count and difficulty, running
# the BERT search once per combination.
for data_unit in [sup.DATA_S_PF, sup.DATA_S_PV]:
    for n in num_class_candidates:
        # Epoch budget grows linearly with the class count; int() truncates.
        base_num_epochs = int(exploring_base_num_epochs
                              + n * exploring_rate_num_epochs)
        s = subsets[n]
        for difficulty in diff_candidates:
            bert.find_best(data_unit=data_unit,
                           label_col=sup.class_numeric_column,
                           class_list='specified',
                           class_numeric_list=s[difficulty],
                           num_classes=n,
                           difficulty=difficulty,
                           KERNEL_CANDIDATES=exploring_kernel_candidates,
                           batch_size=exploring_batch_size,
                           base_num_epochs=base_num_epochs,
                           LOADABLE_CANDIDATES=[exploringBERT])

            # Free memory after each run so it is released before the next
            # run starts.
            _release_memory()

# Final cleanup once the whole sweep is done; this also covers the case where
# a candidate list was empty and no run (hence no cleanup) happened above.
_release_memory()

# Keep metrics

In [None]:
import pandas as pd
from datetime import datetime

In [None]:
# Gather the tracked scores into a table and write it to a timestamped CSV.
PREP_scores_df = pd.DataFrame(data=sup.bert_score_tracker,
                              columns=sup.bert_scores_columns)

# NOTE(review): the timestamp contains ':' characters, which are not allowed
# in file names on Windows -- confirm the platforms this notebook must run on.
now = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Build the output directory path once and reuse it for both the directory
# creation and the CSV path.
scores_dir = os.path.join(sup.TRAIN_SCORES_ROOT, "specified",
                          sup.TRAIN_BERT_CODE, TRAIN_classes)
sup.create_dir_if_not_exists(scores_dir)

scores_path = os.path.join(scores_dir, f"PREPanalysis-{now}.csv")
PREP_scores_df.to_csv(scores_path, index=False)