In [1]:
# Set project's environment variables
import os
import sys
from dotenv import load_dotenv
# Load project settings (including PYTHONPATH) from the shared env file.
# The path is resolved relative to the notebook's current working directory;
# by default python-dotenv does not override variables already set in the
# environment.
load_dotenv(dotenv_path="../project.env")

# PYTHONPATH may list several directories separated by os.pathsep, so add each
# entry individually rather than appending the whole string as one path.
# Entries already on sys.path are skipped so re-running this cell does not keep
# growing sys.path.
if "PYTHONPATH" not in os.environ:
    raise KeyError(
        "PYTHONPATH is not set; define it in ../project.env or in the environment"
    )
for _path_entry in os.environ["PYTHONPATH"].split(os.pathsep):
    if _path_entry and _path_entry not in sys.path:
        sys.path.append(_path_entry)

import warnings
# Silence every UserWarning for the rest of the session.
# NOTE(review): this is a blanket filter and can also hide genuine problems.
warnings.filterwarnings(
    "ignore",
    message=".*",
    category=UserWarning
)

In [2]:
import superheader as sup
import os
import torch
import torch.nn as nn
import torch.optim as optim
import TRAIN.architecture.BERT.bert as bert
import live

I0000 00:00:1750840908.009244 16207612 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M3 Pro
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1750840908.015255 16207740 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1750840908.019241 16207740 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
I0000 00:00:1750840908.021404 16207612 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M3 Pro
W0000 00:00:1750840908.082103 16207751 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1750840908.088691 16207756 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support 

In [3]:
# Evaluate the saved BERT-mini checkpoint for every (data unit, PH2, PH3)
# combination and report its test metrics.
TRAIN_classes = 'all-classes'

# Value of the "n" setting when the PH3 PCA reducer is enabled.
PH3_N_COMPONENTS = 15
# Value of the "n" setting when no reducer is used; it depends only on PH2.
# NOTE(review): presumably the raw input feature count with/without PH2 — confirm.
RAW_N_COMPONENTS = {False: 72, True: 75}

# Metric attributes read from each model after `full_score()`, in print order.
METRIC_NAMES = ("accuracy", "top2accuracy", "macro_f1", "macro_precision", "macro_recall")

# One record per evaluated configuration, so the results can be compared or
# tabulated after the loop instead of only being printed.
eval_results = []

for data_unit in [sup.DATA_S_PF, sup.DATA_S_PV]:
    for PH2 in [False, True]:
        for PH3 in [False, True]:

            # --- Phase-dependent config ---
            if PH3:
                reducer_name = sup.PH3_REDUCER_NAME_PCA
                n_components = PH3_N_COMPONENTS
            else:
                reducer_name = ''
                n_components = RAW_N_COMPONENTS[PH2]

            # --- Build model ---
            data_config = {
                "data_unit" : data_unit,
                "label_col" : sup.class_numeric_column,
                "class_list" : TRAIN_classes,
                "batch_size" : 1024,
                "PH2" : PH2,
                "PH3" : PH3,
                "kernel" : '',
                "reducer" : reducer_name,
                "n" : n_components,
            }

            # This cell only evaluates; the training settings are passed because
            # the bert.BERT constructor takes a train_config.
            # NOTE(review): confirm none of these affect model.test()/full_score().
            # A fresh dict is built each iteration in case the constructor mutates it.
            train_config = {
                "arch" : sup.TRAIN_BERT_CODE,
                "device" : bert.device,
                "loadable" : bert.BERT_MINI,
                "optimizer" : optim.AdamW,
                "lr" : 1e-5,
                "weight_decay" : 0,
                "loss_fn" : nn.CrossEntropyLoss,
                "num_epochs" : 1
            }

            model = bert.BERT(data_config=data_config, df=None, train_config=train_config)

            # --- Load the saved weights for this configuration ---
            model_path_file = live.model_paths[TRAIN_classes][data_unit][PH2][PH3]
            full_model_path = os.path.join(live.model_path_root, data_unit, live.bertmini_path, model_path_file)
            # Tensors are loaded onto the CPU first; load_state_dict then copies
            # them into the model's existing parameters, wherever those live.
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints (weights_only=True is safer where the torch version supports it).
            model.me.load_state_dict(torch.load(full_model_path, map_location='cpu'))
            model.me.eval()

            # --- Evaluate ---
            model.test()
            model.full_score()

            metrics = {name: getattr(model, name) for name in METRIC_NAMES}
            eval_results.append({"data_unit": data_unit, "PH2": PH2, "PH3": PH3, **metrics})

            # Label each value so the printed report is unambiguous.
            print(f"\n\n{data_unit}; {PH2}; {PH3}")
            for metric_name, metric_value in metrics.items():
                print(f"{metric_name}: {metric_value}")



Spf; False; False
0.9939122076257609
0.9980775392502403
0.9938941642078613
0.9939043757561687
0.9939184332235567


Spf; False; True
0.991455730001068
0.9972231122503471
0.9914424390054666
0.9915179935305322
0.9914029179716116


Spf; True; False
0.9848339207518958
0.9931645840008544
0.9848631405948567
0.9849600123312862
0.98486023520075


Spf; True; True
0.9756488305030439
0.9895332692513084
0.9757332722684726
0.9758342841998079
0.9757342510040086


Spv; False; False
0.9795134443021767
0.9859154929577465
0.9796815629876463
0.9807412720456199
0.9794304557462454


Spv; False; True
0.9590268886043534
0.9743918053777209
0.958241004438198
0.9602779571601082
0.958981708981709


Spv; True; False
0.9282970550576184
0.9564660691421255
0.928594928186584
0.9310171862689024
0.9286611286611285


Spv; True; True
0.9218950064020487
0.9462227912932138
0.9220301959849777
0.9274350660398583
0.9221237984395879
