# Context
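This notebook trains and evaluates DistilBERT classifiers on the two-class subsets, for each data unit, with the PH2 flag off and on, and for each difficulty level, then saves the resulting metrics to a CSV file.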

In [1]:
# Set project's environment variables
import os
import sys
from dotenv import load_dotenv

# Load variables from the project's env file (expected to provide PYTHONPATH,
# which is read below).
load_dotenv(dotenv_path="../../../project.env")

# PYTHONPATH may hold several entries separated by os.pathsep. Register each
# one individually, and skip entries that are already on sys.path so that
# re-running this cell does not keep appending duplicates.
for project_path in os.environ["PYTHONPATH"].split(os.pathsep):
    if project_path and project_path not in sys.path:
        sys.path.append(project_path)

In [2]:
# Import the project-wide helpers (superheader) and the BERT training architecture
import superheader as sup
import TRAIN.architecture.BERT.bert as bert

# Models

## Setup

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

import gc

In [4]:
# Tag of the class subset being trained; it selects the "<tag>-subsets.json" file.
TRAIN_CLASSES = "two-classes"
num_classes = 2

# Batch size handed to the data configuration.
exploring_batch_size = 1024

# Epoch budget per run: int(base_num_epochs + n * rate_num_epochs) for n classes.
base_num_epochs = 2
rate_num_epochs = 3.5

# Even class counts from 2 up to num_classes (inclusive).
num_class_candidates = [count for count in range(2, num_classes + 1, 2)]

In [5]:
import json

# JSON text is UTF-8; pass the encoding explicitly instead of relying on the
# platform's default text encoding.
subsets_path = os.path.join(sup.DATA_ROOT, f"{TRAIN_CLASSES}-subsets.json")
with open(subsets_path, "r", encoding="utf-8") as f:
    loaded = json.load(f)

# JSON object keys are always strings, so turn the outer keys back into ints.
# The inner difficulty -> class-list mappings are copied unchanged: the class
# lists stay plain lists (they are NOT converted to tuples).
subsets = {int(k): dict(v_dict) for k, v_dict in loaded.items()}
subsets


{2: {'easy': [0, 1], 'average': [0, 1], 'hard': [0, 1]}}

In [6]:
# Data-side settings shared by every run. The training loop adds the per-run
# keys ("data_unit", "PH2", "n", "difficulty", "class_numeric_list").
base_data_config = {
    "PH3": False,
    "reducer": '',
    "kernel": '',
    "label_col": sup.class_numeric_column,
    "batch_size": exploring_batch_size,
    "class_list": 'specified',
}

# Training-side settings shared by every run; "num_epochs" is added per run.
# The optimizer and the loss function are given as classes, not instances.
base_train_config = {
    "device": bert.device,
    "arch": sup.TRAIN_BERT_CODE,
    "loadable": bert.DISTILBERT,
    "optimizer": optim.AdamW,
    "lr": 1e-5,
    "weight_decay": 0,
    "loss_fn": nn.CrossEntropyLoss,
}



## Train

In [7]:
# Collects one metrics dict per training run.
metric_tracker = []

In [8]:
def _release_memory():
    """Run the garbage collector and, if the MPS backend is available, empty its cache."""
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


# Train, test and score one model for every combination of data unit, PH2
# flag, class count and difficulty, appending one metrics row per run.
for data_unit in [sup.DATA_S_PF, sup.DATA_S_PV]:
    for PH2 in [False, True]:
        for n in num_class_candidates:
            s = subsets[n]
            # The epoch budget grows linearly with the number of classes.
            num_epochs = int(base_num_epochs + n * rate_num_epochs)

            for difficulty in ['easy', 'average', 'hard']:
                # Build fresh config dicts for every run instead of mutating
                # shared ones, so no run can observe a previous run's values.
                # Key order matches the original so the printed dicts are identical.
                data_config = {
                    **base_data_config,
                    "data_unit": data_unit,
                    "PH2": PH2,
                    # "n" is 75 with PH2 and 72 without; the meaning of these
                    # values is defined by the data pipeline — TODO confirm.
                    "n": 75 if PH2 else 72,
                    "difficulty": difficulty,
                    "class_numeric_list": s[difficulty],
                }
                train_config = {**base_train_config, "num_epochs": num_epochs}

                print(data_config)
                print(train_config)
                model = bert.BERT(data_config=data_config, df=None,
                                  train_config=train_config)

                try:
                    model.fit(verbose=True)
                    model.test()
                    model.full_score()

                    metrics = {
                        "data_unit": data_unit,
                        "PH2": PH2,
                        "num_classes": n,
                        "difficulty": difficulty,
                        "accuracy": model.accuracy,
                        "top2accuracy": model.top2accuracy,
                        "macro_f1": model.macro_f1,
                        "macro_precision": model.macro_precision,
                        "macro_recall": model.macro_recall,
                    }
                    # Echo the scores in the same order as before.
                    for score_name in ("accuracy", "top2accuracy", "macro_f1",
                                       "macro_precision", "macro_recall"):
                        print(metrics[score_name])

                    # Project helpers; presumably they store this run's confusion
                    # matrix and loss history — TODO confirm.
                    model.keep_confusion_matrix()
                    model.keep_loss()

                    # A new dict is built for every run, so no copy is needed.
                    metric_tracker.append(metrics)
                finally:
                    # Always dispose of the model and release memory right after
                    # each run, even when the run raised, so a failed or finished
                    # model does not stay resident while the next one is built.
                    bert.clean_bert(model)
                    _release_memory()
  



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': False, 'n': 72, 'difficulty': 'easy', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                              

0.626953125
1
0.5848795500371431
0.7387288708586883
0.6347633587786259
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': False, 'n': 72, 'difficulty': 'average', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                              

0.6328125
1
0.5885088919288646
0.7616562677254679
0.640854961832061
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': False, 'n': 72, 'difficulty': 'hard', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                              

0.8515625
1
0.8515602349889372
0.8522727272727273
0.8521221374045802
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': True, 'n': 75, 'difficulty': 'easy', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                              

0.609375
1
0.5609747731988819
0.7226851851851852
0.6174961832061069
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': True, 'n': 75, 'difficulty': 'average', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                              

0.619140625
1
0.5700536141076158
0.7491416629347664
0.6274045801526718
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spf', 'PH2': True, 'n': 75, 'difficulty': 'hard', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                              

0.591796875
1
0.5320111084384772
0.7207949018950193
0.6005038167938931
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': False, 'n': 72, 'difficulty': 'easy', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                               

0.6976744186046512
1
0.6721407624633431
0.8088235294117647
0.7045454545454546
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': False, 'n': 72, 'difficulty': 'average', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                               

0.6744186046511628
1
0.6416666666666666
0.8
0.6818181818181819
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': False, 'n': 72, 'difficulty': 'hard', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                               

0.7906976744186046
1
0.7831932773109244
0.85
0.7954545454545454
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': True, 'n': 75, 'difficulty': 'easy', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


0.5116279069767442
1
0.3384615384615385
0.2558139534883721
0.5
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': True, 'n': 75, 'difficulty': 'average', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


0.5116279069767442
1
0.3384615384615385
0.2558139534883721
0.5
{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'batch_size': 1024, 'class_list': 'specified', 'data_unit': 'Spv', 'PH2': True, 'n': 75, 'difficulty': 'hard', 'class_numeric_list': [0, 1]}
{'device': device(type='mps'), 'arch': 'BERT', 'loadable': 'distilbert-base-uncased', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'num_epochs': 9}


                                                                               

0.5581395348837209
1
0.4361628709454796
0.7682926829268293
0.5476190476190477


# Keep metrics

In [9]:
import pandas as pd

# One row per training run, in the order the runs were executed.
metrics_df = pd.DataFrame(metric_tracker)

# Display the best-scoring runs first. The sorted frame is only shown here;
# metrics_df itself keeps the execution order.
metrics_df.sort_values("accuracy", ascending=False)

Unnamed: 0,data_unit,PH2,num_classes,difficulty,accuracy,top2accuracy,macro_f1,macro_precision,macro_recall
2,Spf,False,2,hard,0.851562,1,0.85156,0.852273,0.852122
8,Spv,False,2,hard,0.790698,1,0.783193,0.85,0.795455
6,Spv,False,2,easy,0.697674,1,0.672141,0.808824,0.704545
7,Spv,False,2,average,0.674419,1,0.641667,0.8,0.681818
1,Spf,False,2,average,0.632812,1,0.588509,0.761656,0.640855
0,Spf,False,2,easy,0.626953,1,0.58488,0.738729,0.634763
4,Spf,True,2,average,0.619141,1,0.570054,0.749142,0.627405
3,Spf,True,2,easy,0.609375,1,0.560975,0.722685,0.617496
5,Spf,True,2,hard,0.591797,1,0.532011,0.720795,0.600504
11,Spv,True,2,hard,0.55814,1,0.436163,0.768293,0.547619


In [10]:
# Directory that receives the BERT scores of the "specified" class-list runs.
scores_dir = os.path.join(sup.SCORES_ROOT, "specified", sup.TRAIN_BERT_CODE)
sup.create_dir_if_not_exists(scores_dir)

# Write the metrics table without the DataFrame index column.
scores_file = os.path.join(scores_dir, f"PREPanalysis{num_classes}.csv")
metrics_df.to_csv(scores_file, index=False)