# Context
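This notebook drives the training process for different models: it trains BERT-tiny (`prajjwal1/bert-tiny`) and DistilBERT (`distilbert-base-uncased`) classifiers on the ten-class task for each data unit (Spf, Spv), with PH2 both disabled and enabled, then reports the best score per data unit and saves the tracked scores to CSV.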

In [1]:
# Set project's environment variables
import os
import sys
from dotenv import load_dotenv

# Load project-wide settings (including PYTHONPATH) from the shared env file.
load_dotenv(dotenv_path="../../../project.env")

# PYTHONPATH may contain several entries separated by os.pathsep; add each one
# individually and skip entries already on sys.path so that re-running this
# cell does not keep appending duplicates.
for _path_entry in os.environ["PYTHONPATH"].split(os.pathsep):
  if _path_entry and _path_entry not in sys.path:
    sys.path.append(_path_entry)

In [2]:
# Import project-wide and PH2 specific variables and functions
import superheader as sup
from TRAIN.architecture.archeader import Arch, print_best
import TRAIN.architecture.BERT.bert as bert

In [3]:
import pandas as pd
from datetime import datetime

# Models

## Setup

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

import gc

In [5]:
TRAIN_classes = 'ten-classes'

# Data-side settings shared by every run; the training sweep copies this dict
# and adds "data_unit", "PH2" and "n" per combination.
base_data_config = dict(
  PH3=False,
  reducer='',
  kernel='',
  label_col=sup.class_numeric_column,
  class_list=TRAIN_classes,
  batch_size=1024,
)

# Training-side settings shared by every run; the sweep adds "loadable"
# (which pretrained checkpoint to use) and "num_epochs" per combination.
base_train_config = dict(
  device=bert.device,
  arch=sup.TRAIN_BERT_CODE,
  optimizer=optim.AdamW,
  lr=1e-5,
  weight_decay=0,
  loss_fn=nn.CrossEntropyLoss,
)

# Reference epoch count; the sweep multiplies it per data unit and per model.
base_num_epochs = 37

## Train

In [6]:
# Start this session with an empty score tracker; a later cell turns
# sup.bert_score_tracker into a DataFrame and writes it to CSV.
# NOTE(review): this rebinds the module attribute instead of clearing the list
# in place — any module that imported the list by name would keep the old
# object. Presumably bert.train_one_model appends via `sup.bert_score_tracker`;
# confirm.
sup.bert_score_tracker = []

In [None]:
# Sweep every (data unit, PH2, model) combination. The data frame produced by
# Arch for a (data unit, PH2) pair is reused for both models and then released.

# Epoch budget scaling relative to base_num_epochs: Spv runs get 4x and
# BERT-tiny runs get 5x; every other combination uses a factor of 1.
DATA_UNIT_EPOCH_MULTIPLIER = {sup.DATA_S_PV: 4}
LOADABLE_EPOCH_MULTIPLIER = {bert.BERT_TINY: 5}

for data_unit in [sup.DATA_S_PF, sup.DATA_S_PV]:
  for PH2 in [False, True]:
    data_config = base_data_config.copy()
    data_config["data_unit"] = data_unit
    data_config["PH2"] = PH2
    # NOTE(review): "n" is 75 with PH2 enabled and 72 otherwise; its meaning
    # is defined by the project (presumably a feature count) — confirm.
    data_config["n"] = 75 if PH2 else 72

    # Arch is built with df=None and a generic train config; its .df is then
    # shared by every model trained for this data configuration.
    arch = Arch(data_config=data_config, df=None,
                train_config={"arch" : "generic"})
    save_df = arch.df
    try:
      for loadable in [bert.BERT_TINY, bert.DISTILBERT]:
        train_config = base_train_config.copy()
        train_config["loadable"] = loadable
        train_config["num_epochs"] = (base_num_epochs
                                      * DATA_UNIT_EPOCH_MULTIPLIER.get(data_unit, 1)
                                      * LOADABLE_EPOCH_MULTIPLIER.get(loadable, 1))

        bert.train_one_model((data_config, train_config, save_df))
    finally:
      # Release data and model memory even if training raises or the cell is
      # interrupted, so the next run does not start with a full MPS cache.
      del save_df
      bert.clean_bert(arch)
      gc.collect()
      if torch.backends.mps.is_available():
        torch.mps.empty_cache()

{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spf', 'PH2': False, 'n': 72}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'prajjwal1/bert-tiny', 'num_epochs': 185}


                                                                                    

0.7360285374554102
updating best... at 2025-06-14_17:31:48



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spf', 'PH2': False, 'n': 72}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'distilbert-base-uncased', 'num_epochs': 37}


                                                                                 

0.8783194609591756
updating best... at 2025-06-14_17:33:15



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spf', 'PH2': True, 'n': 75}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'prajjwal1/bert-tiny', 'num_epochs': 185}


                                                                                    

0.43162901307966706
not best... at 2025-06-14_17:33:51



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spf', 'PH2': True, 'n': 75}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'distilbert-base-uncased', 'num_epochs': 37}


                                                                                 

0.5307173999207293
not best... at 2025-06-14_17:35:16



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spv', 'PH2': False, 'n': 72}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'prajjwal1/bert-tiny', 'num_epochs': 740}


                                                                                    

0.7962085308056872
updating best... at 2025-06-14_17:35:44



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spv', 'PH2': False, 'n': 72}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'distilbert-base-uncased', 'num_epochs': 148}


                                                                                     

0.8957345971563981
updating best... at 2025-06-14_17:40:59



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spv', 'PH2': True, 'n': 75}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'prajjwal1/bert-tiny', 'num_epochs': 740}


                                                                                    

0.41706161137440756
not best... at 2025-06-14_17:41:28



{'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spv', 'PH2': True, 'n': 75}
{'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'distilbert-base-uncased', 'num_epochs': 148}


                                                                                     

0.5734597156398105
not best... at 2025-06-14_17:46:40





## Keep scores

In [8]:
# Report the best BERT result recorded for each data unit trained above.
for scored_unit in (sup.DATA_S_PF, sup.DATA_S_PV):
  print_best(sup.TRAIN_BERT_CODE, scored_unit)

Data Unit: Spf
Best score: 0.8783194609591756
Best data config: {'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spf', 'PH2': False, 'n': 72}
Best train config: {'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'distilbert-base-uncased', 'num_epochs': 37}
Data Unit: Spv
Best score: 0.8957345971563981
Best data config: {'PH3': False, 'reducer': '', 'kernel': '', 'label_col': 'class_numeric', 'class_list': 'ten-classes', 'batch_size': 1024, 'data_unit': 'Spv', 'PH2': False, 'n': 72}
Best train config: {'device': device(type='mps'), 'arch': 'BERT', 'optimizer': <class 'torch.optim.adamw.AdamW'>, 'lr': 1e-05, 'weight_decay': 0, 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'loadable': 'distilbert-base-uncased', 'num_epochs': 148}


In [9]:
# Persist the scores collected in sup.bert_score_tracker during this session to
# <TRAIN_SCORES_ROOT>/<classes>/<arch>/best/<timestamp>.csv.
bert_scores_df = pd.DataFrame(sup.bert_score_tracker, columns=sup.bert_scores_columns)

# Same timestamp format as the training log lines ("... at 2025-06-14_17:31:48").
# NOTE(review): ':' is not a valid filename character on Windows.
now = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Build the output directory once so the mkdir and the write cannot drift apart.
scores_dir = os.path.join(sup.TRAIN_SCORES_ROOT, TRAIN_classes,
                          sup.TRAIN_BERT_CODE, "best")
sup.create_dir_if_not_exists(scores_dir)
bert_scores_df.to_csv(os.path.join(scores_dir, f"{now}.csv"), index=False)