In [None]:
import torch
import os
import random
import numpy as np
import torch
import torch.nn as nn
import wandb
import torchmetrics
from collections import defaultdict, OrderedDict
import math
import torchvision.models as models

def set_seed(seed_value):
    """Seed every random generator used in this notebook so runs are repeatable.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and, when a
    GPU is present, every CUDA device), and forces cuDNN onto deterministic kernels.

    Args:
        seed_value (int): the seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed_value)
        # Trade cuDNN's autotuned algorithm selection for bit-for-bit repeatable convolutions.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Random seed set to {seed_value} for reproducibility.")


MY_SEED = 42
set_seed(MY_SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Weight of the ridge-count term in the combined objective used by run_epoch:
# loss = (1 - REG_IMPORTANCE) * classification_loss + REG_IMPORTANCE * regression_loss
REG_IMPORTANCE = 0.5

Random seed set to 42 for reproducibility.


In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class FingerprintDataset(Dataset):
    """Fingerprint images paired with a pattern-class label and a ridge-count target.

    Expects ``root_dir/<CLASS_NAME>/<image file>``, where ``<CLASS_NAME>`` is a key of
    ``label_to_idx``. Each item is ``(image, (label, ridge_count))``.

    The ridge count is parsed from the last underscore-separated token of the file
    stem after dropping any ``_aug...`` suffix (``f12_7_aug3.png`` -> 7). Classes in
    ``zero_ridge_classes`` always get 0, and names that do not parse fall back to 0.
    """

    # Extensions (matched case-insensitively) that are indexed; other files are skipped.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to the dataset split (Train, Validation, or Test).
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.label_to_idx = {
            'AC': 0, 'Ambiguous': 1, 'TA': 2, 'SA': 3, 'TAUL_TARL': 4, 'UL_RL': 5,
            'SW': 6, 'TCW': 7, 'UPE_RPE': 8, 'DL': 9, 'ESW': 10, 'ECW': 11
        }

        self.idx_to_label = {v: k for k, v in self.label_to_idx.items()}
        # Class names ordered by label index, so class_names[i] names label i.
        self.class_names = [self.idx_to_label[i] for i in sorted(self.idx_to_label)]

        # Classes whose ridge-count target is always 0.
        self.zero_ridge_classes = {'AC', 'Ambiguous', 'TA', 'SA'}
        self._prepare_dataset()

    def _prepare_dataset(self):
        """Index every image under ``root_dir`` as ``(path, label, ridge_count)``.

        Directory and file names are sorted because ``os.listdir`` order is arbitrary
        and filesystem-dependent; sorting makes the sample order, and therefore a
        seeded DataLoader shuffle, identical on every machine. A sub-directory whose
        name is not in ``label_to_idx`` raises ``KeyError``.
        """
        for class_name in sorted(os.listdir(self.root_dir)):
            class_path = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            label = self.label_to_idx[class_name]
            for fname in sorted(os.listdir(class_path)):
                if not fname.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                img_path = os.path.join(class_path, fname)
                ridge_count = self._get_ridge_count(fname, class_name)
                self.data.append((img_path, label, ridge_count))

    def _get_ridge_count(self, filename, class_name):
        """Return the ridge count encoded in ``filename`` (0 for zero-ridge classes or on parse failure)."""
        if class_name in self.zero_ridge_classes:
            return 0
        stem, _ = os.path.splitext(filename)
        # str.split returns the whole string when '_aug' is absent, so no membership test is needed.
        base = stem.split('_aug')[0]
        try:
            return int(base.split('_')[-1])
        except ValueError:
            return 0  # fallback if parsing fails

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a greyscale PIL image (then ``transform``) with ``(label, ridge_count)``."""
        img_path, label, ridge_count = self.data[idx]
        # The context manager closes the underlying file; without it PIL can keep the
        # handle open (e.g. for multi-frame TIFFs) and leak descriptors across epochs.
        with Image.open(img_path) as img:
            image = img.convert('L')

        if self.transform:
            image = self.transform(image)

        return image, (label, ridge_count)

In [None]:
from torch.utils.data import DataLoader

# Root of the augmented fingerprint dataset; every split is a sub-directory of it.
DATA_ROOT = '/kaggle/input/dataset-augmented/content/drive/MyDrive/Fingerprint wharehouse/Dataset-augmented'
BATCH_SIZE = 32

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # Replicate the single grey channel into the 3-channel input the pretrained backbone takes.
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    # A single mean/std broadcasts over all 3 (identical) channels, mapping [0, 1] to [-1, 1].
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

train_dataset = FingerprintDataset(os.path.join(DATA_ROOT, 'Train'), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation metrics are order-independent, so validation/test are not shuffled: this keeps
# the raw prediction lists in a fixed order and avoids drawing from the global RNG, which
# would otherwise perturb the training shuffles of later epochs.
val_dataset = FingerprintDataset(os.path.join(DATA_ROOT, 'Validation'), transform=transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_dataset = FingerprintDataset(os.path.join(DATA_ROOT, 'Test'), transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

CLASS_NAMES = train_dataset.class_names
NUM_CLASSES = len(CLASS_NAMES)

In [None]:
import torchmetrics
from torch import nn

# Default losses for the two heads: cross-entropy on the class logits and MSE on the
# ridge-count prediction (the training cell further down builds its own equivalent pair).
criterion_cls, criterion_reg = nn.CrossEntropyLoss(), nn.MSELoss()


In [None]:
import numpy as np
import torchmetrics

def _make_classification_metrics(num_classes, device):
    """Create fresh torchmetrics objects for one pass, keyed by the result name run_epoch reports.

    Per-class and confusion-matrix metrics compute to tensors with ndim >= 1; the macro
    and global (micro) metrics compute to 0-d tensors.
    """
    common = dict(task="multiclass", num_classes=num_classes)
    # zero_division=0: value reported when a ratio's denominator is zero (e.g. a class never predicted).
    safe = dict(common, zero_division=0)
    metrics = {
        "accuracy_per_class": torchmetrics.Accuracy(average="none", **common),
        "accuracy_macro": torchmetrics.Accuracy(average="macro", **common),
        "accuracy_global": torchmetrics.Accuracy(average="micro", **common),
        "precision_per_class": torchmetrics.Precision(average="none", **safe),
        "recall_per_class": torchmetrics.Recall(average="none", **safe),
        "f1_per_class": torchmetrics.F1Score(average="none", **safe),
        "specificity_per_class": torchmetrics.Specificity(average="none", **safe),
        "precision_macro": torchmetrics.Precision(average="macro", **safe),
        "recall_macro": torchmetrics.Recall(average="macro", **safe),
        "f1_macro": torchmetrics.F1Score(average="macro", **safe),
        "specificity_macro": torchmetrics.Specificity(average="macro", **safe),
        "confusion_matrix": torchmetrics.ConfusionMatrix(**common),
    }
    return {name: metric.to(device) for name, metric in metrics.items()}


def _regression_metrics_by_class(preds_by_class, targets_by_class, num_classes):
    """Per-class MSE/MAE/R2 of the ridge-count head, grouped by TRUE class, plus macro averages.

    Classes with no samples keep NaN in the per-class arrays and are left out of the
    macro averages. When a class's targets are constant, R2 is undefined and is set to
    0.0 as a placeholder; this does not mean the predictions were perfect. For example,
    FingerprintDataset's zero-ridge classes always have all-zero targets.

    Returns:
        tuple: (mse_per_class, mae_per_class, r2_per_class, mse_macro, mae_macro, r2_macro).
    """
    mse = np.full(num_classes, np.nan)
    mae = np.full(num_classes, np.nan)
    r2 = np.full(num_classes, np.nan)
    present = np.zeros(num_classes, dtype=bool)

    for c in range(num_classes):
        class_preds = preds_by_class.get(c)
        if not class_preds:
            continue
        preds_c = torch.as_tensor(np.asarray(class_preds, dtype=np.float32))
        targets_c = torch.as_tensor(np.asarray(targets_by_class[c], dtype=np.float32))
        present[c] = True

        mse[c] = torchmetrics.functional.mean_squared_error(preds_c, targets_c).item()
        mae[c] = torchmetrics.functional.mean_absolute_error(preds_c, targets_c).item()
        # R2 needs at least two distinct target values.
        if len(torch.unique(targets_c)) > 1:
            r2[c] = torchmetrics.functional.r2_score(preds_c, targets_c).item()
        else:
            r2[c] = 0.0

    def _macro(values):
        return np.mean(values[present]) if present.any() else np.nan

    return mse, mae, r2, _macro(mse), _macro(mae), _macro(r2)


def run_epoch(model, loader, train=True, optimizer=None, criterion_cls=None, criterion_reg=None, DEVICE=None, NUM_CLASSES=None, REG_IMPORTANCE=None):
    """Run one pass over ``loader``: a training pass when ``train`` is True, evaluation otherwise.

    Per batch the loss is ``(1 - REG_IMPORTANCE) * criterion_cls + REG_IMPORTANCE * criterion_reg``.

    Args:
        model: returns ``(class_logits, ridge_preds)`` for a batch of images.
        loader: yields ``(images, (class_labels, ridge_counts))`` batches.
        train (bool): if True, use train mode, enable gradients and step ``optimizer`` per batch.
        optimizer: required when ``train`` is True; unused otherwise.
        criterion_cls: classification loss on ``(class_logits, class_labels)``.
        criterion_reg: regression loss on ``(ridge_preds, ridge_counts)``.
        DEVICE: device the batches and the metric objects are moved to.
        NUM_CLASSES (int): number of classes.
        REG_IMPORTANCE (float): weight of the regression term in the combined loss.

    Returns:
        dict with:
          - ``loss``: the sample-weighted mean loss (NaN for an empty loader).
          - Classification metrics: per-class metrics as NumPy arrays, macro and global
            metrics as floats, and the confusion matrix.
          - Raw predicted and true class indices, as lists of ints.
          - Per-class and macro MSE/MAE/R2 of the ridge head.
    """
    cls_metrics = _make_classification_metrics(NUM_CLASSES, DEVICE)

    all_classification_preds = []
    all_classification_targets = []

    # Ridge predictions/targets grouped by true class label (python int).
    preds_per_class_reg = defaultdict(list)
    targets_per_class_reg = defaultdict(list)

    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_samples = 0

    with torch.set_grad_enabled(train):
        for imgs, (class_labels, ridge_counts) in loader:
            imgs = imgs.to(DEVICE)
            class_labels = class_labels.to(DEVICE)
            ridge_counts = ridge_counts.to(DEVICE).float()

            class_logits, ridge_preds = model(imgs)
            # reshape(-1), not squeeze(): for a batch of one, squeeze() yields a 0-d tensor,
            # which mismatches the (1,) target in the MSE and makes the per-sample zip below
            # raise "iteration over a 0-d array".
            ridge_preds = ridge_preds.reshape(-1)

            loss_cls = criterion_cls(class_logits, class_labels)
            loss_reg = criterion_reg(ridge_preds, ridge_counts)
            loss = (1 - REG_IMPORTANCE) * loss_cls + REG_IMPORTANCE * loss_reg

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = imgs.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            pred_classes = class_logits.argmax(dim=1)
            for metric in cls_metrics.values():
                metric.update(pred_classes, class_labels)

            all_classification_preds.extend(pred_classes.cpu().tolist())
            all_classification_targets.extend(class_labels.cpu().tolist())

            for c, p, t in zip(class_labels.cpu().tolist(),
                               ridge_preds.detach().cpu().numpy(),
                               ridge_counts.cpu().numpy()):
                preds_per_class_reg[c].append(p)
                targets_per_class_reg[c].append(t)

    avg_loss = total_loss / total_samples if total_samples else float("nan")

    # Per-class / matrix results become NumPy arrays; scalar results become Python floats.
    computed = {}
    for name, metric in cls_metrics.items():
        value = metric.compute()
        computed[name] = value.cpu().numpy() if value.ndim else value.item()

    (mse_per_class_reg, mae_per_class_reg, r2_per_class_reg,
     mse_macro_val, mae_macro_val, r2_macro_val) = _regression_metrics_by_class(
        preds_per_class_reg, targets_per_class_reg, NUM_CLASSES)

    return {
        "loss": avg_loss,
        # Classification metrics
        "accuracy_per_class": computed["accuracy_per_class"],
        "accuracy_macro": computed["accuracy_macro"],
        "accuracy_global": computed["accuracy_global"],
        "precision_per_class": computed["precision_per_class"],
        "recall_per_class": computed["recall_per_class"],
        "f1_per_class": computed["f1_per_class"],
        "specificity_per_class": computed["specificity_per_class"],
        "precision_macro": computed["precision_macro"],
        "recall_macro": computed["recall_macro"],
        "f1_macro": computed["f1_macro"],
        "specificity_macro": computed["specificity_macro"],
        "confusion_matrix": computed["confusion_matrix"],
        "raw_classification_preds": all_classification_preds,
        "raw_classification_targets": all_classification_targets,
        # Regression metrics (per-class arrays are NaN for classes with no samples)
        "mse_per_class": mse_per_class_reg,
        "mae_per_class": mae_per_class_reg,
        "r2_per_class": r2_per_class_reg,
        "mse_macro": mse_macro_val,
        "mae_macro": mae_macro_val,
        "r2_macro": r2_macro_val,
    }

# Set up monitoring (Weights & Biases)

In [None]:
! pip install -q wandb

In [None]:
import os
import wandb
from kaggle_secrets import UserSecretsClient

# Pull the W&B API key from Kaggle's secret store so it is never written into the notebook,
# and expose it through the WANDB_API_KEY environment variable before logging in.
user_secrets = UserSecretsClient()
os.environ.update(WANDB_API_KEY=user_secrets.get_secret("wandb_api"))
wandb.login()


[34m[1mwandb[0m: Currently logged in as: [33mossamaoutmani[0m ([33mossamaoutmani-nexos[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Close any W&B run still active in this kernel (e.g. left over from an earlier
# execution) before a new one is started.
if wandb.run is not None:
    wandb.finish()

# ONNX Setup

In [None]:
!pip install -q onnx
!pip install -q onnxruntime

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch

def save_onnx(model, model_path, output_names=None):
    """Export ``model`` to ONNX at ``model_path`` with a dynamic batch axis on every tensor.

    The export traces the model in eval mode on a ``1x3x224x224`` input. The module's
    previous train/eval mode is restored afterwards.

    Args:
        model: the network to export; an ``nn.DataParallel`` wrapper is unwrapped first.
        model_path (str): destination ``.onnx`` file.
        output_names (list[str] | None): names for the graph outputs. When None, the first
            output is named ``"output"`` and any further outputs (such as the ridge-count
            head of a multi-task model) ``"output_1"``, ``"output_2"``, ...
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    was_training = net.training
    net.eval()
    try:
        device = next(net.parameters()).device
        # Zeros instead of randn: values are irrelevant for tracing, and drawing from the
        # global RNG here would shift every later random draw (e.g. DataLoader shuffles)
        # depending on whether a checkpoint happened to be exported in a given epoch.
        dummy_input = torch.zeros(1, 3, 224, 224, device=device)

        if output_names is None:
            with torch.no_grad():
                sample = net(dummy_input)
            n_outputs = len(sample) if isinstance(sample, (tuple, list)) else 1
            output_names = ["output"] + [f"output_{i}" for i in range(1, n_outputs)]
        output_names = list(output_names)

        # Name and mark EVERY output as batch-dynamic; naming only the first one left the
        # second head with an auto-generated name and a static batch size of 1.
        dynamic_axes = {name: {0: "batch"} for name in ["input", *output_names]}

        torch.onnx.export(
            net,
            dummy_input,
            model_path,
            export_params=True,        # store weights
            opset_version=17,          # use a recent ONNX opset
            do_constant_folding=True,  # pre-fold constants for speed
            input_names=["input"],
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
    finally:
        if was_training:
            net.train()

    print(model_path, "saved successfully")


# RepVGG model

In [None]:
# Close any W&B run still active in this kernel so the RepVGG experiment below
# is logged to its own run.
if wandb.run is not None:
    wandb.finish()

In [None]:
import wandb

# Open the W&B run that receives the RepVGG experiment's per-epoch metrics and model artifacts.
wandb.init(
    entity="elharkaouimeriem-ensa",
    project="Fingerprint-classification-ridgecount",
    name="repvgg_model-run",
    reinit=False,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
import torchvision.models as models
import torch
import torch.nn as nn
import timm

class RepVGGMultiTask(nn.Module):
    """timm RepVGG-A2 backbone with two heads: class logits and a ridge-count regressor.

    Args:
        num_classes (int): number of classification outputs.
        dropout_rate (float): dropout applied to the pooled features before both heads.
        pretrained (bool): load timm's pretrained backbone weights. Pass False when the
            weights are about to be replaced from a checkpoint (skips the download).
    """

    def __init__(self, num_classes=12, dropout_rate=0.5, pretrained=True):
        super().__init__()
        self.base_model = timm.create_model('repvgg_a2', pretrained=pretrained)

        in_features = self.base_model.num_features

        # forward() calls forward_features() and pools itself, so the backbone's own
        # classifier head is unused; replacing it keeps its weights out of the model.
        self.base_model.head = nn.Identity()

        self.dropout = nn.Dropout(p=dropout_rate)

        # Classification head
        self.class_head = nn.Linear(in_features, num_classes)

        # Ridge count regression head
        self.ridge_head = nn.Linear(in_features, 1)

    def forward(self, x):
        """Return ``(class_logits, ridge_output)`` with shapes ``(B, num_classes)`` and ``(B,)``.

        Assumes ``forward_features`` returns a ``(B, C, H, W)`` feature map — TODO confirm
        if the backbone is ever changed.
        """
        features = self.base_model.forward_features(x)

        # Use nn.functional (available through this cell's `nn` import) rather than an
        # `F` alias that this cell never defines.
        features = nn.functional.adaptive_avg_pool2d(features, (1, 1))
        features = torch.flatten(features, 1)
        features = self.dropout(features)

        class_logits = self.class_head(features)
        ridge_output = self.ridge_head(features)
        return class_logits, ridge_output.squeeze(1)

In [None]:
import torch.nn.functional as F

# Build the multi-task network directly on the target device.
repvgg_model = RepVGGMultiTask(num_classes=NUM_CLASSES, dropout_rate=0.5).to(DEVICE)

# Report the compute setup; with several visible GPUs, wrap the model in DataParallel
# so each batch is split across them.
if not torch.cuda.is_available():
    print("Using CPU.")
elif torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for DataParallel.")
    repvgg_model = nn.DataParallel(repvgg_model).to(DEVICE)
else:
    print("Using single GPU.")

optimizer = torch.optim.Adam(repvgg_model.parameters(), lr=1e-4)

# Losses for the classification head and the ridge-count head.
classification_loss_fn, regression_loss_fn = nn.CrossEntropyLoss(), nn.MSELoss()


model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Using single GPU.


In [None]:
from collections import defaultdict

# --- Training configuration ---
EPOCHS = 2                    # maximum number of epochs
EARLY_STOPPING_PATIENCE = 25  # epochs without validation improvement before stopping

# --- Model-selection state (read by the evaluation cells that follow) ---
best_val_loss = float("inf")
best_val_f1_macro = 0.0  # best model = highest val macro-F1, ties broken by lower val loss
epochs_without_improvement = 0

best_metrics, best_train_metrics = {}, {}
last_metrics, last_train_metrics = {}, {}

best_epoch_num = 0
last_epoch_num = 0  # finalised after the loop

BEST_PATH = "/kaggle/working/repvgg_best.onnx"
LAST_PATH = "/kaggle/working/repvgg_last.onnx"
BEST_PTH = "/kaggle/working/repvgg_best.pth"
LAST_PTH = "/kaggle/working/repvgg_last.pth"

current_model = repvgg_model


def _log_phase_metrics(log, phase, stats):
    """Add one phase's macro, per-class and confusion-matrix entries to the W&B ``log`` dict."""
    for key in ("loss", "accuracy_macro", "accuracy_global", "precision_macro", "recall_macro",
                "f1_macro", "specificity_macro", "mse_macro", "mae_macro", "r2_macro"):
        log[f"{phase}/{key}"] = stats[key]

    for cls_idx, class_name in enumerate(CLASS_NAMES):
        for metric in ("accuracy", "precision", "recall", "f1", "specificity", "mse", "mae", "r2"):
            # .item() turns the NumPy scalar into a plain Python float for W&B.
            log[f"{phase}/{metric}_{class_name}"] = stats[f"{metric}_per_class"][cls_idx].item()

    log[f"{phase}/confusion_matrix"] = wandb.plot.confusion_matrix(
        preds=stats["raw_classification_preds"],
        y_true=stats["raw_classification_targets"],
        class_names=CLASS_NAMES,
    )


def _print_phase_report(phase_name, stats):
    """Print one phase's macro and per-class classification and regression metrics."""
    print(
        f"  {phase_name} | "
        f"Loss {stats['loss']:.4f}, Acc_macro {stats['accuracy_macro']:.4f}, "
        f"Acc_global {stats['accuracy_global']:.4f}, "
        f"F1_macro {stats['f1_macro']:.4f}, "
        f"Prec_macro {stats['precision_macro']:.4f}, "
        f"Rec_macro {stats['recall_macro']:.4f}, "
        f"Spec_macro {stats['specificity_macro']:.4f}"
    )
    for cls_idx, class_name in enumerate(CLASS_NAMES):
        acc, prec, rec, f1, spec = (
            stats[f"{m}_per_class"][cls_idx]
            for m in ("accuracy", "precision", "recall", "f1", "specificity")
        )
        print(f"    Class {class_name:<12}: Acc {acc:.4f}, Prec {prec:.4f}, Rec {rec:.4f}, F1 {f1:.4f}, Spec {spec:.4f}")

    print(
        f"  {phase_name} Regression | MSE_macro {stats['mse_macro']:.4f}, "
        f"MAE_macro {stats['mae_macro']:.4f}, R2_macro {stats['r2_macro']:.4f}"
    )
    for cls_idx, class_name in enumerate(CLASS_NAMES):
        mse, mae, r2 = (stats[f"{m}_per_class"][cls_idx] for m in ("mse", "mae", "r2"))
        print(f"    Class {class_name:<12} Regression: MSE {mse:.4f}, MAE {mae:.4f}, R2 {r2:.4f}")


for epoch in range(1, EPOCHS + 1):
    train_stats = run_epoch(current_model, train_loader, train=True, optimizer=optimizer, criterion_cls=classification_loss_fn, criterion_reg=regression_loss_fn, DEVICE=DEVICE, NUM_CLASSES=NUM_CLASSES, REG_IMPORTANCE=REG_IMPORTANCE)
    val_stats = run_epoch(current_model, val_loader, train=False, optimizer=optimizer, criterion_cls=classification_loss_fn, criterion_reg=regression_loss_fn, DEVICE=DEVICE, NUM_CLASSES=NUM_CLASSES, REG_IMPORTANCE=REG_IMPORTANCE)

    # Track every training/validation metric epoch-by-epoch in W&B.
    log_dict = {"epoch": epoch}
    _log_phase_metrics(log_dict, "train", train_stats)
    _log_phase_metrics(log_dict, "val", val_stats)
    wandb.log(log_dict)

    print(f"Epoch {epoch:02d}")
    _print_phase_report("Train", train_stats)
    _print_phase_report("Val ", val_stats)

    current_val_f1_macro = val_stats["f1_macro"]
    current_val_loss = val_stats["loss"]

    # --- Early stopping / best-model selection ---
    improved = current_val_f1_macro > best_val_f1_macro or (
        current_val_f1_macro == best_val_f1_macro and current_val_loss < best_val_loss
    )
    if not improved:
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epochs.")
        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping triggered after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
            last_epoch_num = epoch  # The last epoch before stopping
            break
        continue

    best_val_f1_macro = current_val_f1_macro
    best_val_loss = current_val_loss
    epochs_without_improvement = 0
    best_metrics = val_stats.copy()
    best_train_metrics = train_stats.copy()
    best_epoch_num = epoch

    save_onnx(current_model, BEST_PATH)
    torch.save(current_model.state_dict(), BEST_PTH)
    print(f"*** New best model saved at epoch {epoch} with Val Macro F1: {best_val_f1_macro:.4f} ***")

# Whether the loop finished (epoch == EPOCHS) or stopped early (already recorded above),
# the last trained epoch is `epoch`.
last_epoch_num = epoch

# Save last-epoch metrics and the final model (the "last" artifact).
last_metrics = val_stats.copy()
last_train_metrics = train_stats.copy()

save_onnx(current_model, LAST_PATH)
torch.save(current_model.state_dict(), LAST_PTH)
print(f"Last model saved. Val Macro F1: {last_metrics['f1_macro']:.4f}")

Epoch 01
  Train | Loss 7.2342, Acc_macro 0.5060, Acc_global 0.5072, F1_macro 0.5090, Prec_macro 0.5183, Rec_macro 0.5060, Spec_macro 0.9552
    Class AC          : Acc 0.6831, Prec 0.8428, Rec 0.6831, F1 0.7546, Spec 0.9881
    Class Ambiguous   : Acc 0.5395, Prec 0.6208, Rec 0.5395, F1 0.5773, Spec 0.9699
    Class TA          : Acc 0.5729, Prec 0.6463, Rec 0.5729, F1 0.6074, Spec 0.9726
    Class SA          : Acc 0.5289, Prec 0.6359, Rec 0.5289, F1 0.5775, Spec 0.9769
    Class TAUL_TARL   : Acc 0.6277, Prec 0.5296, Rec 0.6277, F1 0.5745, Spec 0.9524
    Class UL_RL       : Acc 0.6791, Prec 0.5526, Rec 0.6791, F1 0.6094, Spec 0.9429
    Class SW          : Acc 0.2748, Prec 0.2698, Rec 0.2748, F1 0.2723, Spec 0.9332
    Class TCW         : Acc 0.4770, Prec 0.3925, Rec 0.4770, F1 0.4306, Spec 0.9318
    Class UPE_RPE     : Acc 0.4643, Prec 0.4945, Rec 0.4643, F1 0.4789, Spec 0.9563
    Class DL          : Acc 0.5095, Prec 0.4836, Rec 0.5095, F1 0.4962, Spec 0.9474
    Class ESW      

In [None]:
import math


def _macro_log(prefix, stats):
    """Return ``{"<prefix>/<metric>": value}`` for the loss and macro/global metrics in ``stats``."""
    keys = ("loss", "accuracy_macro", "accuracy_global", "precision_macro", "recall_macro",
            "f1_macro", "specificity_macro", "mse_macro", "mae_macro", "r2_macro")
    return {f"{prefix}/{key}": stats[key] for key in keys}


def _per_class_log(stats, nan_fill=0.0):
    """Return ``{class_name: {metric: float}}`` built from the per-class arrays in ``stats``.

    NaN entries are replaced by ``nan_fill``; they occur, for example, in the regression
    metrics of a class with no samples in the split. This presumably keeps the artifact
    metadata JSON-friendly. Note that the 0.0 default is a placeholder, not a perfect score.
    """
    fields = (
        ("accuracy", "accuracy_per_class"),
        ("precision", "precision_per_class"),
        ("recall", "recall_per_class"),
        ("f1_score", "f1_per_class"),
        ("specificity", "specificity_per_class"),
        ("mse", "mse_per_class"),
        ("mae", "mae_per_class"),
        ("r2", "r2_per_class"),
    )
    per_class = {}
    for cls_idx, class_name in enumerate(CLASS_NAMES):
        row = {}
        for out_key, stats_key in fields:
            value = stats[stats_key][cls_idx].item()
            row[out_key] = nan_fill if math.isnan(value) else value
        per_class[class_name] = row
    return per_class


# --- Evaluate the LAST-epoch model on the test split ---
current_model.eval()
test_last_stats = run_epoch(current_model, test_loader, train=False, optimizer=None, criterion_cls=classification_loss_fn, criterion_reg=regression_loss_fn, DEVICE=DEVICE, NUM_CLASSES=NUM_CLASSES, REG_IMPORTANCE=REG_IMPORTANCE)

# Artifact metadata: test metrics of the last model ...
test_last_log = _macro_log("test", test_last_stats)
test_last_log["epochs_trained"] = last_epoch_num  # Log epoch number for last model
test_last_per_class_metrics = _per_class_log(test_last_stats)
test_last_log["test/per_class_metrics"] = test_last_per_class_metrics

# ... plus its last-epoch validation metrics ...
test_last_log.update(_macro_log("val_last", last_metrics))
val_last_per_class_metrics = _per_class_log(last_metrics)
test_last_log["val_last/per_class_metrics"] = val_last_per_class_metrics

# ... and its last-epoch training metrics.
test_last_log.update(_macro_log("train_last", last_train_metrics))
train_last_per_class_metrics = _per_class_log(last_train_metrics)
test_last_log["train_last/per_class_metrics"] = train_last_per_class_metrics


print(f"\n{'='*40}")
print("=== TEST RESULTS LAST MODEL ===")
print(
    f"Test | Loss {test_last_stats['loss']:.4f}, "
    f"Acc_macro {test_last_stats['accuracy_macro']:.4f}, "
    f"Acc_global {test_last_stats['accuracy_global']:.4f}, "
    f"Prec_macro {test_last_stats['precision_macro']:.4f}, "
    f"Rec_macro {test_last_stats['recall_macro']:.4f}, "
    f"F1_macro {test_last_stats['f1_macro']:.4f}, "
    f"Spec_macro {test_last_stats['specificity_macro']:.4f}"
)

print("\nPer-Class Classification Metrics:")
for cls_idx, class_name in enumerate(CLASS_NAMES):
    acc = test_last_stats["accuracy_per_class"][cls_idx]
    p = test_last_stats["precision_per_class"][cls_idx]
    r = test_last_stats["recall_per_class"][cls_idx]
    f = test_last_stats["f1_per_class"][cls_idx]
    s = test_last_stats["specificity_per_class"][cls_idx]
    print(f"  Class {class_name:<12}: Acc {acc:.4f}, Prec {p:.4f}, Rec {r:.4f}, F1 {f:.4f}, Spec {s:.4f}")

print("\nMacro Regression Metrics:")
print(
    f"  MSE_macro {test_last_stats['mse_macro']:.4f}, "
    f"MAE_macro {test_last_stats['mae_macro']:.4f}, "
    f"R2_macro {test_last_stats['r2_macro']:.4f}"
)

print("\nPer-Class Regression Metrics (MSE, MAE, R2):")
for cls_idx, class_name in enumerate(CLASS_NAMES):
    mse = test_last_stats["mse_per_class"][cls_idx]
    mae = test_last_stats["mae_per_class"][cls_idx]
    r2 = test_last_stats["r2_per_class"][cls_idx]
    # run_epoch reports R2 as exactly 0.0 when a class's targets are constant.
    if r2 == 0.0 and mse > 0.0001:
        print(f"  Class {class_name:<12}: MSE {mse:.4f}, MAE {mae:.4f}, R2 {r2:.4f} (R2=0 likely due to zero target variance)")
    else:
        print(f"  Class {class_name:<12}: MSE {mse:.4f}, MAE {mae:.4f}, R2 {r2:.4f}")

wandb.log({
    "test/confusion_matrix_last_model": wandb.plot.confusion_matrix(
        preds=test_last_stats["raw_classification_preds"],
        y_true=test_last_stats["raw_classification_targets"],
        class_names=CLASS_NAMES
    )
})

# --- W&B artifact: last model (ONNX + state dict) with the metadata built above ---
last_art = wandb.Artifact("repvgg-last", type="model")
last_art.add_file(LAST_PATH)
last_art.add_file(LAST_PTH)
last_art.metadata.update(test_last_log)
last_art.metadata.update({"opset": 17})
logged_last = wandb.log_artifact(last_art, aliases=["last"])
logged_last.wait()
print(f"Model saved and pushed to W&B : {wandb.run.get_url()}")


=== TEST RESULTS LAST MODEL ===
Test | Loss 3.6584, Acc_macro 0.7306, Acc_global 0.8427, Prec_macro 0.6698, Rec_macro 0.7306, F1_macro 0.6851, Spec_macro 0.9844

Per-Class Classification Metrics:
  Class AC          : Acc 0.7500, Prec 0.5455, Rec 0.7500, F1 0.6316, Spec 0.9954
  Class Ambiguous   : Acc 0.8969, Prec 0.9607, Rec 0.8969, F1 0.9277, Spec 0.9602
  Class TA          : Acc 0.7292, Prec 0.8333, Rec 0.7292, F1 0.7778, Spec 0.9861
  Class SA          : Acc 0.9524, Prec 0.8000, Rec 0.9524, F1 0.8696, Spec 0.9692
  Class TAUL_TARL   : Acc 0.9091, Prec 0.6579, Rec 0.9091, F1 0.7634, Spec 0.9751
  Class UL_RL       : Acc 0.7232, Prec 0.7570, Rec 0.7232, F1 0.7397, Spec 0.9737
  Class SW          : Acc 0.6538, Prec 0.5312, Rec 0.6538, F1 0.5862, Spec 0.9860
  Class TCW         : Acc 0.8333, Prec 0.3333, Rec 0.8333, F1 0.4762, Spec 0.9909
  Class UPE_RPE     : Acc 0.7857, Prec 0.7674, Rec 0.7857, F1 0.7765, Spec 0.9905
  Class DL          : Acc 0.4783, Prec 0.6471, Rec 0.4783, F1 0.5

In [None]:
import gc

import torch

# torch.cuda.empty_cache() can only return cached blocks that no live tensor
# still references. Run the garbage collector first so tensors kept alive only
# by unreachable reference cycles (e.g. from the previous evaluation) are freed
# and their memory can actually be released before the best model is built.
gc.collect()
torch.cuda.empty_cache()

In [None]:
from collections import OrderedDict

repvgg_best = RepVGGMultiTask(num_classes=NUM_CLASSES, dropout_rate=0.5).to(DEVICE)

# The checkpoint is only used as a {parameter_name: tensor} state dict, so
# restrict unpickling to tensors and plain containers. weights_only=True blocks
# arbitrary code execution from a tampered .pth file and silences the warning
# torch.load prints when the argument is omitted (visible in this cell's output).
checkpoint = torch.load(BEST_PTH, map_location=DEVICE, weights_only=True)

# A state dict saved from an nn.DataParallel-wrapped model prefixes every key
# with 'module.'; strip it so the keys match the bare RepVGGMultiTask.
_DP_PREFIX = "module."
new_state_dict = OrderedDict(
    (key[len(_DP_PREFIX):] if key.startswith(_DP_PREFIX) else key, tensor)
    for key, tensor in checkpoint.items()
)

repvgg_best.load_state_dict(new_state_dict)

# Spread evaluation over every visible GPU when more than one is present.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs for best model evaluation.")
    repvgg_best = nn.DataParallel(repvgg_best)

# Place on the target device and switch to inference mode (disables dropout).
repvgg_best.to(DEVICE)
repvgg_best.eval()

# One evaluation pass over the test split; no optimizer because train=False.
test_best_stats = run_epoch(
    repvgg_best,
    test_loader,
    train=False,
    optimizer=None,
    criterion_cls=classification_loss_fn,
    criterion_reg=regression_loss_fn,
    DEVICE=DEVICE,
    NUM_CLASSES=NUM_CLASSES,
    REG_IMPORTANCE=REG_IMPORTANCE,
)

# Aggregate metric names shared by the test / val_best / train_best summaries.
# Order matters only for the readability of the resulting metadata dict.
_MACRO_METRIC_KEYS = (
    "loss",
    "accuracy_macro",
    "accuracy_global",
    "precision_macro",
    "recall_macro",
    "f1_macro",
    "specificity_macro",
    "mse_macro",
    "mae_macro",
    "r2_macro",
)

# (key in a run_epoch stats dict, key written into the per-class metadata)
_PER_CLASS_METRIC_KEYS = (
    ("accuracy_per_class", "accuracy"),
    ("precision_per_class", "precision"),
    ("recall_per_class", "recall"),
    ("f1_per_class", "f1_score"),
    ("specificity_per_class", "specificity"),
    ("mse_per_class", "mse"),
    ("mae_per_class", "mae"),
    ("r2_per_class", "r2"),
)


def _nan_to_zero(value):
    """Return ``value`` unchanged, or 0.0 if it is NaN (keeps metadata JSON-friendly)."""
    return 0.0 if math.isnan(value) else value


def _macro_metrics(stats, prefix):
    """Return ``{"<prefix>/<metric>": stats[metric]}`` for every aggregate metric.

    Args:
        stats: metrics dict as produced by ``run_epoch`` (or the tracked best
            train/val metrics) containing every key in ``_MACRO_METRIC_KEYS``.
        prefix: namespace for the keys, e.g. ``"test"`` or ``"val_best"``.
    """
    return {f"{prefix}/{key}": stats[key] for key in _MACRO_METRIC_KEYS}


def _per_class_metrics(stats, class_names):
    """Return ``{class_name: {metric: float}}`` with NaN values replaced by 0.0.

    Each ``stats[<metric>_per_class]`` is indexed by class position and the
    element's ``.item()`` is taken, matching how the per-class values were
    extracted before this refactor.
    """
    per_class = {}
    for cls_idx, class_name in enumerate(class_names):
        per_class[class_name] = {
            out_key: _nan_to_zero(stats[src_key][cls_idx].item())
            for src_key, out_key in _PER_CLASS_METRIC_KEYS
        }
    return per_class


# Compose best_test_log for W&B artifact metadata (Macro/aggregate metrics)
best_test_log = _macro_metrics(test_best_stats, "test")
best_test_log["epochs_trained"] = best_epoch_num  # Log epoch number for best model

# Per-class metrics for the best model (TEST)
test_best_per_class_metrics = _per_class_metrics(test_best_stats, CLASS_NAMES)
best_test_log["test/per_class_metrics"] = test_best_per_class_metrics

# Include BEST validation metrics (macro + per-class)
best_test_log.update(_macro_metrics(best_metrics, "val_best"))
val_best_per_class_metrics = _per_class_metrics(best_metrics, CLASS_NAMES)
best_test_log["val_best/per_class_metrics"] = val_best_per_class_metrics

# Include BEST training metrics (macro + per-class)
best_test_log.update(_macro_metrics(best_train_metrics, "train_best"))
train_best_per_class_metrics = _per_class_metrics(best_train_metrics, CLASS_NAMES)
best_test_log["train_best/per_class_metrics"] = train_best_per_class_metrics


# ---- Console summary of the best model on the test split ----
summary_stats = test_best_stats

print(f"\n{'='*40}")
print("=== TEST RESULTS BEST MODEL ===")
print(
    f"Test | Loss {summary_stats['loss']:.4f}, "
    f"Acc_macro {summary_stats['accuracy_macro']:.4f}, "
    f"Acc_global {summary_stats['accuracy_global']:.4f}, "
    f"Prec_macro {summary_stats['precision_macro']:.4f}, "
    f"Rec_macro {summary_stats['recall_macro']:.4f}, "
    f"F1_macro {summary_stats['f1_macro']:.4f}, "
    f"Spec_macro {summary_stats['specificity_macro']:.4f}"
)

# One line of classification metrics per class, in CLASS_NAMES order.
print("\nPer-Class Classification Metrics:")
for idx, name in enumerate(CLASS_NAMES):
    acc_v = summary_stats["accuracy_per_class"][idx]
    prec_v = summary_stats["precision_per_class"][idx]
    rec_v = summary_stats["recall_per_class"][idx]
    f1_v = summary_stats["f1_per_class"][idx]
    spec_v = summary_stats["specificity_per_class"][idx]
    print(f"  Class {name:<12}: Acc {acc_v:.4f}, Prec {prec_v:.4f}, Rec {rec_v:.4f}, F1 {f1_v:.4f}, Spec {spec_v:.4f}")

print("\nMacro Regression Metrics:")
print(
    f"  MSE_macro {summary_stats['mse_macro']:.4f}, "
    f"MAE_macro {summary_stats['mae_macro']:.4f}, "
    f"R2_macro {summary_stats['r2_macro']:.4f}"
)

# Per-class regression metrics; an R2 of exactly 0 with non-trivial error is
# flagged because it usually means the class's targets have no variance.
print("\nPer-Class Regression Metrics (MSE, MAE, R2):")
for idx, name in enumerate(CLASS_NAMES):
    mse_v = summary_stats["mse_per_class"][idx]
    mae_v = summary_stats["mae_per_class"][idx]
    r2_v = summary_stats["r2_per_class"][idx]
    line = f"  Class {name:<12}: MSE {mse_v:.4f}, MAE {mae_v:.4f}, R2 {r2_v:.4f}"
    if r2_v == 0.0 and mse_v > 0.0001:
        line += " (R2=0 likely due to zero target variance)"
    print(line)

# Confusion matrix of the best model on the test split, built first and then
# sent to W&B under its own key.
best_cm_plot = wandb.plot.confusion_matrix(
    preds=test_best_stats["raw_classification_preds"],
    y_true=test_best_stats["raw_classification_targets"],
    class_names=CLASS_NAMES,
)
wandb.log({"test/confusion_matrix_best_model": best_cm_plot})

# --- W&B Artifact Logging (updated metadata) ---
# Package both saved files of the best model into a single "model" artifact.
best_art = wandb.Artifact("repvgg-best", type="model")
for artifact_file in (BEST_PATH, BEST_PTH):
    best_art.add_file(artifact_file)

# Attach the test/val/train summary plus the export opset as artifact metadata.
# NOTE(review): opset 17 presumably refers to an ONNX export in BEST_PATH — confirm.
best_art.metadata.update(best_test_log)
best_art.metadata.update({"opset": 17})

logged_best = wandb.log_artifact(best_art, aliases=["best"])
logged_best.wait()  # wait for W&B to finish logging the artifact before reporting
print(f"Model saved and pushed to W&B : {wandb.run.get_url()}")


  checkpoint = torch.load(BEST_PTH, map_location=DEVICE)



=== TEST RESULTS BEST MODEL ===
Test | Loss 3.6584, Acc_macro 0.7306, Acc_global 0.8427, Prec_macro 0.6698, Rec_macro 0.7306, F1_macro 0.6851, Spec_macro 0.9844

Per-Class Classification Metrics:
  Class AC          : Acc 0.7500, Prec 0.5455, Rec 0.7500, F1 0.6316, Spec 0.9954
  Class Ambiguous   : Acc 0.8969, Prec 0.9607, Rec 0.8969, F1 0.9277, Spec 0.9602
  Class TA          : Acc 0.7292, Prec 0.8333, Rec 0.7292, F1 0.7778, Spec 0.9861
  Class SA          : Acc 0.9524, Prec 0.8000, Rec 0.9524, F1 0.8696, Spec 0.9692
  Class TAUL_TARL   : Acc 0.9091, Prec 0.6579, Rec 0.9091, F1 0.7634, Spec 0.9751
  Class UL_RL       : Acc 0.7232, Prec 0.7570, Rec 0.7232, F1 0.7397, Spec 0.9737
  Class SW          : Acc 0.6538, Prec 0.5312, Rec 0.6538, F1 0.5862, Spec 0.9860
  Class TCW         : Acc 0.8333, Prec 0.3333, Rec 0.8333, F1 0.4762, Spec 0.9909
  Class UPE_RPE     : Acc 0.7857, Prec 0.7674, Rec 0.7857, F1 0.7765, Spec 0.9905
  Class DL          : Acc 0.4783, Prec 0.6471, Rec 0.4783, F1 0.5

In [None]:
# Close the W&B run only when one is actually active in this kernel.
active_run = wandb.run
if active_run is not None:
    wandb.finish()

0,1
epoch,▁█
train/accuracy_AC,▁█
train/accuracy_Ambiguous,▁█
train/accuracy_DL,▁█
train/accuracy_ECW,▁█
train/accuracy_ESW,▁█
train/accuracy_SA,▁█
train/accuracy_SW,▁█
train/accuracy_TA,▁█
train/accuracy_TAUL_TARL,▁█

0,1
epoch,2.0
train/accuracy_AC,0.96924
train/accuracy_Ambiguous,0.8305
train/accuracy_DL,0.70608
train/accuracy_ECW,0.57658
train/accuracy_ESW,0.50051
train/accuracy_SA,0.79693
train/accuracy_SW,0.44439
train/accuracy_TA,0.74323
train/accuracy_TAUL_TARL,0.82819
