# Context
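This notebook trains and evaluates DistilBERT classifiers on pre-defined class subsets (easy / average / hard) for each data unit, with and without PH2, and saves the resulting metrics.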

In [1]:
# Set project's environment variables
import os
import sys
from dotenv import load_dotenv
# Load the project's .env file (the path is resolved against the notebook's working directory).
load_dotenv(dotenv_path="../../../project.env")

# Fail with an actionable message when the variable is missing. This is still a
# KeyError, exactly like a bare os.environ[...] lookup would raise.
project_pythonpath = os.environ.get("PYTHONPATH")
if project_pythonpath is None:
    raise KeyError("PYTHONPATH is not set; check that ../../../project.env exists and defines it")

# Re-running this cell must not pile up duplicate sys.path entries.
if project_pythonpath not in sys.path:
    sys.path.append(project_pythonpath)

In [2]:
# Import project-wide variables/functions and the BERT training architecture
import superheader as sup
import TRAIN.architecture.BERT.bert as bert

# Models

## Setup

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

import gc

In [4]:
# Name of the class collection; it selects the "<TRAIN_CLASSES>-subsets.json" file.
TRAIN_CLASSES = "ten-classes"

# Upper bound (inclusive) on the subset sizes explored in this run.
num_classes = 2

# Batch size forwarded to the data configuration.
exploring_batch_size = 1024

# Epoch budget per run: base_num_epochs + n * rate_num_epochs for an n-class subset.
base_num_epochs = 2
rate_num_epochs = 2

# Even subset sizes from 2 up to num_classes.
num_class_candidates = [*range(2, num_classes + 1, 2)]

In [5]:
import json

# Load the pre-computed class subsets for the selected class collection.
subsets_path = os.path.join(sup.DATA_ROOT, f"{TRAIN_CLASSES}-subsets.json")
with open(subsets_path, "r", encoding="utf-8") as f:
    loaded = json.load(f)

# JSON object keys are always strings, so restore the integer subset sizes.
# The class-id lists stay lists; each inner dict is only shallow-copied.
subsets = {int(k): dict(v_dict) for k, v_dict in loaded.items()}
subsets


{2: {'easy': [5, 7], 'average': [2, 4], 'hard': [2, 3]},
 3: {'easy': [0, 3, 5], 'average': [0, 1, 2], 'hard': [0, 1, 4]},
 4: {'easy': [1, 6, 7, 9], 'average': [1, 4, 7, 8], 'hard': [0, 1, 5, 9]},
 5: {'easy': [2, 5, 6, 7, 9],
  'average': [2, 4, 5, 7, 9],
  'hard': [0, 1, 2, 4, 8]},
 6: {'easy': [1, 3, 5, 6, 7, 9],
  'average': [1, 2, 5, 6, 7, 8],
  'hard': [1, 4, 5, 6, 8, 9]},
 7: {'easy': [0, 3, 4, 5, 6, 7, 9],
  'average': [1, 2, 3, 5, 6, 7, 8],
  'hard': [1, 2, 3, 4, 5, 6, 8]},
 8: {'easy': [0, 2, 3, 4, 5, 6, 7, 8],
  'average': [1, 3, 4, 5, 6, 7, 8, 9],
  'hard': [0, 1, 2, 4, 5, 7, 8, 9]},
 9: {'easy': [0, 1, 2, 3, 4, 6, 7, 8, 9],
  'average': [0, 1, 2, 3, 4, 5, 6, 7, 8],
  'hard': [0, 1, 2, 3, 4, 5, 7, 8, 9]},
 10: {'easy': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  'average': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  'hard': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}}

In [6]:
# Data-side settings shared by every run; run-specific keys are added in the training loop.
base_data_config = {
    "PH3": False,
    "reducer": "",
    "kernel": "",
    "label_col": sup.class_numeric_column,
    "batch_size": exploring_batch_size,
    "class_list": "specified",
}

# Training-side settings shared by every run.
# The optimizer and the loss are passed as classes, not as instances.
base_train_config = {
    "device": bert.device,
    "arch": sup.TRAIN_BERT_CODE,
    "loadable": bert.DISTILBERT,
    "optimizer": optim.AdamW,
    "lr": 1e-5,
    "weight_decay": 0,
    "loss_fn": nn.CrossEntropyLoss,
}



## Train

In [7]:
# Collects one metrics dict per trained model.
metric_tracker = []

In [8]:
def _release_accelerator_memory():
  """Run the garbage collector and release cached accelerator memory (MPS and/or CUDA)."""
  gc.collect()
  if torch.backends.mps.is_available():
    torch.mps.empty_cache()
  if torch.cuda.is_available():
    torch.cuda.empty_cache()


# Train and evaluate one model per (data unit, PH2 flag, subset size, difficulty)
# combination and append its metrics to metric_tracker.
for data_unit in [sup.DATA_S_PF, sup.DATA_S_PV]:
  for PH2 in [False, True]:
    for n in num_class_candidates:
      for difficulty in ['easy', 'average', 'hard']:
        # Build fresh config dicts for every run so that no model ever shares a
        # config object that a later iteration mutates.
        data_config = {
          **base_data_config,
          "data_unit": data_unit,
          "PH2": PH2,
          # NOTE(review): 75 with PH2, 72 without -- the meaning of "n" is not
          # visible in this notebook; confirm against the data loader.
          "n": 75 if PH2 else 72,
          "difficulty": difficulty,
          "class_numeric_list": subsets[n][difficulty],
        }
        train_config = {
          **base_train_config,
          # The epoch budget grows linearly with the number of classes.
          "num_epochs": base_num_epochs + n * rate_num_epochs,
        }

        print(data_config)
        print(train_config)
        model = bert.BERT(data_config=data_config, df=None,
                          train_config=train_config)

        try:
          model.fit(verbose=True)

          model.test()

          model.full_score()
          print(model.accuracy)
          # NOTE: top-2 accuracy is trivially 1 for two-class subsets.
          print(model.top2accuracy)
          print(model.macro_f1)
          print(model.macro_precision)
          print(model.macro_recall)

          metrics = {"data_unit" : data_unit,
                     "PH2" : PH2,
                     "num_classes" : n,
                     "difficulty" : difficulty,
                     "accuracy" : model.accuracy,
                     "top2accuracy" : model.top2accuracy,
                     "macro_f1" : model.macro_f1,
                     "macro_precision" : model.macro_precision,
                     "macro_recall" : model.macro_recall}

          model.keep_confusion_matrix()
          model.keep_loss()

          # The dict is created anew on every iteration, so no copy is needed.
          metric_tracker.append(metrics)
        finally:
          # Always release the model, even when training or scoring fails, so a
          # failed run does not leave accelerator memory held in the live kernel.
          bert.clean_bert(model)
          _release_accelerator_memory()
  



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': False, 'n': 72, 'difficulty': 'easy', 'class_numeric_list': [5, 7]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                              

0.8530020703933747
1
0.8505393449410534
0.8831525762511104
0.8550420168067228
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': False, 'n': 72, 'difficulty': 'average', 'class_numeric_list': [2, 4]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


0.5038910505836576
1
0.3350582147477361
0.2519455252918288
0.5
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': False, 'n': 72, 'difficulty': 'hard', 'class_numeric_list': [2, 3]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                              

0.6051587301587301
1
0.5496099907044416
0.7179385309728887
0.6079842519685039
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': True, 'n': 75, 'difficulty': 'easy', 'class_numeric_list': [5, 7]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.9751552795031055
1
0.975155173005041
0.975455387465267
0.975390156062425
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': True, 'n': 75, 'difficulty': 'average', 'class_numeric_list': [2, 4]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                              

0.5019455252918288
1
0.3443485540031491
0.7495107632093934
0.5057915057915058
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': True, 'n': 75, 'difficulty': 'hard', 'class_numeric_list': [2, 3]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                              

0.5119047619047619
1
0.37182579089232515
0.6686991869918699
0.5156850393700787
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': False, 'n': 72, 'difficulty': 'easy', 'class_numeric_list': [5, 7]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.926829268292683
1
0.926654740608229
0.9282296650717703
0.9261904761904762
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': False, 'n': 72, 'difficulty': 'average', 'class_numeric_list': [2, 4]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.46511627906976744
1
0.31746031746031744
0.23809523809523808
0.47619047619047616
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': False, 'n': 72, 'difficulty': 'hard', 'class_numeric_list': [2, 3]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.5952380952380952
1
0.5159322033898305
0.7763157894736843
0.5952380952380952
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': True, 'n': 75, 'difficulty': 'easy', 'class_numeric_list': [5, 7]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.9512195121951219
1
0.9511904761904761
0.9511904761904761
0.9511904761904761
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': True, 'n': 75, 'difficulty': 'average', 'class_numeric_list': [2, 4]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.5348837209302325
1
0.38920454545454547
0.7619047619047619
0.5238095238095238
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': True, 'n': 75, 'difficulty': 'hard', 'class_numeric_list': [2, 3]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 6}


                                                                               

0.5238095238095238
1
0.5227272727272727
0.5240274599542334
0.5238095238095237


# Keep metrics

In [9]:
import pandas as pd

# One row per trained model; the columns follow the keys of the per-run metrics dicts.
metrics_df = pd.DataFrame(data=metric_tracker)

# Show the best runs first (display only: metrics_df itself stays in run order).
metrics_df.sort_values("accuracy", ascending=False)

Unnamed: 0,data_unit,PH2,num_classes,difficulty,accuracy,top2accuracy,macro_f1,macro_precision,macro_recall
3,Spf,True,2,easy,0.975155,1,0.975155,0.975455,0.97539
9,Spv,True,2,easy,0.95122,1,0.95119,0.95119,0.95119
6,Spv,False,2,easy,0.926829,1,0.926655,0.92823,0.92619
0,Spf,False,2,easy,0.853002,1,0.850539,0.883153,0.855042
2,Spf,False,2,hard,0.605159,1,0.54961,0.717939,0.607984
8,Spv,False,2,hard,0.595238,1,0.515932,0.776316,0.595238
10,Spv,True,2,average,0.534884,1,0.389205,0.761905,0.52381
11,Spv,True,2,hard,0.52381,1,0.522727,0.524027,0.52381
5,Spf,True,2,hard,0.511905,1,0.371826,0.668699,0.515685
1,Spf,False,2,average,0.503891,1,0.335058,0.251946,0.5


In [10]:
# Scores for this experiment are written under <SCORES_ROOT>/specified/<BERT code>/.
scores_dir = os.path.join(sup.SCORES_ROOT, "specified", sup.TRAIN_BERT_CODE)
sup.create_dir_if_not_exists(scores_dir)

# The file name carries num_classes (the largest subset size explored in this run).
metrics_path = os.path.join(scores_dir, f"PREPanalysis{num_classes}.csv")
metrics_df.to_csv(metrics_path, index=False)