In [1]:
import os
import torch
import torch.nn as nn


from config import get_config
from data.loader import get_dataloaders
from models.model import create_model
from train_utils.resume import init_resume_state
from train_utils.train_epoch import train_one_epoch
from train_utils.evaluate import evaluate
from train_utils.train_metrics_logger import update_train_logs
from train_utils.train_metrics_logger import update_val_logs
from train_utils.checkpoint_saver import save_epoch_checkpoint
from train_utils.train_metrics_logger import record_and_save_epoch
from train_utils.early_stopping import check_early_stopping
from train_utils.training_summary import finalize_training_summary
from train_utils.training_summary import print_best_model_summary
from train_utils.plot_metrics import plot_train_val_metrics
from train_utils.plot_metrics import plot_loss_accuracy


In [None]:
# cfg=get_config(config_path="config/convnext_bs512_ep50_lr1e-04_ds1000_g5.yml")
cfg=get_config()
print(cfg)

[INFO] Config Path: config/convnext_bs512_ep50_lr1e-04_ds1000_g5.yml
[INFO] Detected native Ubuntu host: DS044955
[INFO] Using dataset root: /home/arsalan/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_1000_balanced_unshuffled
[INFO] Detected dataset size: 1000
namespace(model_tag='ConvNeXt_g5', backbone='convnext', batch_size=512, epochs=50, learning_rate=0.0001, patience=5, input_shape=(1, 32, 32), global_max=121.79151153564453, dataset_root_dir='/home/arsalan/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_1000_balanced_unshuffled', train_csv='/home/arsalan/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_1000_balanced_unshuffled/file_labels_aggregated_g5_train.csv', val_csv='/home/arsalan/Projects/110_JetscapeML/hm_jetscapeml_source/data

In [3]:
os.makedirs(cfg.output_dir, exist_ok=True)
print(f"[INFO] Saving all outputs to: {cfg.output_dir}")

[INFO] Saving all outputs to: training_output/ConvNeXt_g5_bs512_ep50_lr1e-04_ds1000_g5


In [4]:
# Set seed, device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
# torch.backends.cudnn.benchmark = True

[INFO] Using device: cuda


In [5]:
# Data
train_loader, val_loader, test_loader = get_dataloaders(cfg)

[INFO] Training samples: 153
[INFO] Validation samples: 19
[INFO] Test samples: 20
[INFO] Length of training dataloader: 1
[INFO] Length of validation dataloader: 1
[INFO] Length of test dataloader: 1


In [6]:
# Model and optimizer
model, optimizer = create_model(cfg.backbone, cfg.input_shape, cfg.learning_rate)
model.to(device)

[INFO] Using ConvNeXt backbone with input shape: (1, 32, 32)
[INFO] ConvNeXt backbone initialized with input shape: (1, 32, 32)


MultiHeadClassifier(
  (backbone): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=96, out_features=384, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=384, out_features=96, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (conv

In [7]:
criterion = {
    # 'energy_loss_output': nn.BCELoss(),
    'energy_loss_output': nn.BCEWithLogitsLoss(),
    'alpha_output': nn.CrossEntropyLoss(),
    'q0_output': nn.CrossEntropyLoss()
}
print(f"[INFO] Loss functions:{criterion}")

[INFO] Loss functions:{'energy_loss_output': BCEWithLogitsLoss(), 'alpha_output': CrossEntropyLoss(), 'q0_output': CrossEntropyLoss()}


In [8]:
print(f"[INFO] Init Training Trackers")
train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list, train_loss_list = [], [], [],[]
train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list, train_acc_list = [], [], [], []

print(f"[INFO] Init Validation Trackers")
val_loss_energy_list, val_loss_alpha_list,val_loss_q0_list,val_loss_list = [], [], [], []
val_acc_energy_list, val_acc_alpha_list,val_acc_q0_list ,val_acc_list = [],[],[],[]

[INFO] Init Training Trackers
[INFO] Init Validation Trackers


In [9]:
model, optimizer, start_epoch, best_acc, early_stop_counter, best_epoch, best_metrics, training_summary, all_epoch_metrics = init_resume_state( model, optimizer, device,cfg)

[INFO] Init Resume/Training Parameters
[INFO] Starting fresh training run by initializing training summary
[INFO] 📄 Training summary saved to: training_output/ConvNeXt_g5_bs512_ep50_lr1e-04_ds1000_g5/training_summary.json


In [None]:
# train_metrics = train_one_epoch(train_loader, model, criterion, optimizer, device)
# print(f"[INFO] Training metrics: {train_metrics}")

                                                       

[INFO] Training metrics: {'loss': 3.417614459991455, 'loss_energy': 0.6740033626556396, 'loss_alpha': 1.225578784942627, 'loss_q0': 1.5180323123931885, 'accuracy': 0.0392156862745098, 'accuracy_energy': 0.39215686274509803, 'accuracy_alpha': 0.37254901960784315, 'accuracy_q0': 0.2679738562091503}




In [11]:
for epoch in range(start_epoch, cfg.epochs):
    print(f"[INFO] Epoch {epoch+1}/{cfg.epochs}")
    train_metrics={}
    train_metrics = train_one_epoch(train_loader, model, criterion, optimizer, device)
    (train_loss_list,
    train_loss_energy_list,
    train_loss_alpha_list,
    train_loss_q0_list,
    train_acc_list,
    train_acc_energy_list,
    train_acc_alpha_list,
    train_acc_q0_list
    ) = update_train_logs(
        train_metrics,
        train_loss_list,
        train_loss_energy_list,
        train_loss_alpha_list,
        train_loss_q0_list,
        train_acc_list,
        train_acc_energy_list,
        train_acc_alpha_list,
        train_acc_q0_list
    )
    val_metrics = evaluate(val_loader, model, criterion, device)
    (val_loss_list,
    val_loss_energy_list,
    val_loss_alpha_list,
    val_loss_q0_list,
    val_acc_list,
    val_acc_energy_list,
    val_acc_alpha_list,
    val_acc_q0_list,
    ) = update_val_logs(
        val_metrics,
        val_loss_list,
        val_loss_energy_list,
        val_loss_alpha_list,
        val_loss_q0_list,
        val_acc_list,
        val_acc_energy_list,
        val_acc_alpha_list,
        val_acc_q0_list,
    )
    print(f"[INFO] Epoch {epoch+1}: Energy Acc ={val_metrics['energy']['accuracy']:.4f}, αs Acc = {val_metrics['alpha']['accuracy']:.4f}, Q0 Acc = {val_metrics['q0']['accuracy']:.4f}, Total Acc = {val_metrics['accuracy']:.4f}")
    print(f"[INFO] Epoch {epoch+1}: Energy Loss ={val_metrics['loss_energy']:.4f}, αs Loss = {val_metrics['loss_alpha']:.4f}, Q0 Loss = {val_metrics['loss_q0']:.4f}, Total Loss = {val_metrics['loss']:.4f}")
    
    all_epoch_metrics=record_and_save_epoch(epoch, train_metrics, val_metrics, all_epoch_metrics, cfg.output_dir)
    
    # save_epoch_checkpoint(
    #     epoch=epoch,
    #     model=model,
    #     optimizer=optimizer,
    #     metrics=val_metrics,
    #     output_dir=cfg.output_dir
    # )

    best_acc, best_metrics, best_epoch, early_stop_counter, should_stop = check_early_stopping(
        best_acc=best_acc,
        best_metrics=best_metrics,
        early_stop_counter=early_stop_counter,
        best_epoch=best_epoch,
        model=model,
        optimizer=optimizer,
        val_metrics=val_metrics,
        output_dir=cfg.output_dir,
        patience=cfg.patience,
        epoch=epoch
    )
    
    if should_stop:
        break
    
    print("="*150)
    

[INFO] Epoch 1/50


                                                       

[INFO] Epoch 1: Energy Acc =0.7895, αs Acc = 0.2105, Q0 Acc = 0.3158, Total Acc = 0.0526
[INFO] Epoch 1: Energy Loss =0.5642, αs Loss = 1.2530, Q0 Loss = 1.3782, Total Loss = 3.1954
[INFO] Epoch 1: Saving metrics to disk




✅ Best model saved at epoch 1 with total accuracy: 0.0526
[INFO] Epoch 2/50


                                                       

[INFO] Epoch 2: Energy Acc =0.7895, αs Acc = 0.2632, Q0 Acc = 0.2632, Total Acc = 0.0526
[INFO] Epoch 2: Energy Loss =0.5596, αs Loss = 1.3122, Q0 Loss = 1.4905, Total Loss = 3.3622
[INFO] Epoch 2: Saving metrics to disk
⏳ No improvement. Early stop counter: 1/5
[INFO] Epoch 3/50


                                                       

[INFO] Epoch 3: Energy Acc =0.7895, αs Acc = 0.3684, Q0 Acc = 0.2632, Total Acc = 0.1053
[INFO] Epoch 3: Energy Loss =0.5659, αs Loss = 1.1003, Q0 Loss = 1.4260, Total Loss = 3.0921
[INFO] Epoch 3: Saving metrics to disk
✅ Best model saved at epoch 3 with total accuracy: 0.1053
[INFO] Epoch 4/50


                                                       

[INFO] Epoch 4: Energy Acc =0.7895, αs Acc = 0.3684, Q0 Acc = 0.4211, Total Acc = 0.1579
[INFO] Epoch 4: Energy Loss =0.5548, αs Loss = 1.0690, Q0 Loss = 1.3619, Total Loss = 2.9857
[INFO] Epoch 4: Saving metrics to disk




✅ Best model saved at epoch 4 with total accuracy: 0.1579
[INFO] Epoch 5/50


                                                       

[INFO] Epoch 5: Energy Acc =0.7895, αs Acc = 0.3684, Q0 Acc = 0.2632, Total Acc = 0.0526
[INFO] Epoch 5: Energy Loss =0.5639, αs Loss = 1.0800, Q0 Loss = 1.3971, Total Loss = 3.0410
[INFO] Epoch 5: Saving metrics to disk
⏳ No improvement. Early stop counter: 1/5
[INFO] Epoch 6/50


                                                       

[INFO] Epoch 6: Energy Acc =0.7895, αs Acc = 0.3684, Q0 Acc = 0.3158, Total Acc = 0.0526
[INFO] Epoch 6: Energy Loss =0.6166, αs Loss = 1.0563, Q0 Loss = 1.4144, Total Loss = 3.0873
[INFO] Epoch 6: Saving metrics to disk
⏳ No improvement. Early stop counter: 2/5
[INFO] Epoch 7/50


                                                       

[INFO] Epoch 7: Energy Acc =0.7895, αs Acc = 0.4737, Q0 Acc = 0.4211, Total Acc = 0.1053
[INFO] Epoch 7: Energy Loss =0.6245, αs Loss = 1.0615, Q0 Loss = 1.3807, Total Loss = 3.0666
[INFO] Epoch 7: Saving metrics to disk
⏳ No improvement. Early stop counter: 3/5
[INFO] Epoch 8/50


                                                       

[INFO] Epoch 8: Energy Acc =0.7895, αs Acc = 0.3158, Q0 Acc = 0.4211, Total Acc = 0.1053
[INFO] Epoch 8: Energy Loss =0.5767, αs Loss = 1.0869, Q0 Loss = 1.3521, Total Loss = 3.0157
[INFO] Epoch 8: Saving metrics to disk
⏳ No improvement. Early stop counter: 4/5
[INFO] Epoch 9/50


                                                       

[INFO] Epoch 9: Energy Acc =0.7895, αs Acc = 0.3158, Q0 Acc = 0.2632, Total Acc = 0.0526
[INFO] Epoch 9: Energy Loss =0.5454, αs Loss = 1.0850, Q0 Loss = 1.3624, Total Loss = 2.9927
[INFO] Epoch 9: Saving metrics to disk
⏳ No improvement. Early stop counter: 5/5
🛑 Early stopping triggered at epoch 9. Best was at epoch 4.




In [None]:
finalize_training_summary(
    summary=training_summary,
    best_epoch=best_epoch,
    best_acc=best_acc,
    best_metrics=best_metrics,
    output_dir=cfg.output_dir
)
print_best_model_summary(
    best_epoch=best_epoch,
    best_acc=best_acc,
    best_metrics=best_metrics
)

In [None]:
plot_train_val_metrics(train_loss_list, val_loss_list, train_acc_list, val_acc_list, cfg.output_dir)
plot_loss_accuracy(train_loss_list,
                    train_loss_energy_list,
                    train_loss_alpha_list,
                    train_loss_q0_list,
                    train_acc_list,
                    train_acc_energy_list,
                    train_acc_alpha_list,
                    train_acc_q0_list,
                    cfg.output_dir,
                    title="Train Loss and Accuracy per Epoch")
plot_loss_accuracy(val_loss_list,
                    val_loss_energy_list,
                    val_loss_alpha_list,
                    val_loss_q0_list,
                    val_acc_list,
                    val_acc_energy_list,
                    val_acc_alpha_list,
                    val_acc_q0_list,
                    cfg.output_dir,
                    title="Validation Loss and Accuracy per Epoch")