In [4]:
# -----------------------------
# 1) Importing Libraries
# -----------------------------

import os,csv,random

import numpy as np
import matplotlib.pyplot as plt

import torch
from tqdm import tqdm
import timm
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torchvision import transforms
from safetensors.torch import load_file

from utils.dataset import PCamDataset
from models.DFA_for_MLP import LinearDFA,set_dfa_error


In [None]:
# -----------------------------
#  Global Variables
# -----------------------------
def set_seed(seed: int = 42):  
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


log_dir: str = "results"
data_root: str = "data/pcam"

limit_samples = None
batch_size = 128
num_workers = 4

# Preparing Dataset

In [None]:
# # -----------------------------
# # Redefine transforms
# # -----------------------------
# train_tfms = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.5,)*3, std=(0.5,)*3),
# ])

# val_tfms = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.5,)*3, std=(0.5,)*3),
# ])

# # ---------------------------------------------------------------------
# # Dataset & DataLoader objects
# # ---------------------------------------------------------------------
# train_ds_full = PCamDataset(data_root, split="train", transform=train_tfms)
# val_ds_full = PCamDataset(data_root, split="val", transform=val_tfms)

# # Optionally reduce dataset size for faster experiments
# if limit_samples is not None:
#     from torch.utils.data import Subset

#     train_len = min(limit_samples, len(train_ds_full))
#     val_len = min(limit_samples, len(val_ds_full))

#     train_ds = Subset(train_ds_full, list(range(train_len)))  # type: ignore[arg-type]
#     val_ds = Subset(val_ds_full, list(range(val_len)))        # type: ignore[arg-type]
# else:
#     train_ds = train_ds_full
#     val_ds = val_ds_full

# train_loader: DataLoader = DataLoader(
#     train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

# val_loader: DataLoader = DataLoader(
#     val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

In [6]:
def compute_mean_std(dataset, batch_size=512, num_workers=4, device="cuda"):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    mean = 0.0
    sq_mean = 0.0
    n_images = 0

    for images, _ in loader:
        images = images.to(device)  
        B, C, H, W = images.shape
        n_pixels = B * H * W

        # sum over pixels per channel
        mean += images.sum(dim=[0, 2, 3])  # [C]
        sq_mean += (images ** 2).sum(dim=[0, 2, 3])  # [C]
        n_images += n_pixels

    mean /= n_images
    sq_mean /= n_images
    std = torch.sqrt(sq_mean - mean ** 2)

    return mean.tolist(), std.tolist()


tmp_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()  
])

tmp_dataset = PCamDataset(data_root, split="train", transform=tmp_tfms)

device = "cuda" if torch.cuda.is_available() else "cpu"
mean, std = compute_mean_std(tmp_dataset, batch_size=512, num_workers=num_workers, device=device)

print(f"Mean: {mean}, Std: {std}")


Mean: [0.7007558345794678, 0.538357675075531, 0.6916202306747437], Std: [0.21814213693141937, 0.2626423239707947, 0.19506362080574036]


In [None]:
# mean = [0.22222, 0.22222, 0.2222] 
# std = [0.41574, 0.41574, 0.41574]

In [8]:
# -----------------------------
# Redefine transforms
# -----------------------------
train_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(20),
    transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.1
            ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

val_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean= mean, std= mean),
])
print(f"Mean: {mean}, Std: {std}")

# ---------------------------------------------------------------------
# Dataset & DataLoader objects
# ---------------------------------------------------------------------
train_ds_full = PCamDataset(data_root, split="train", transform=train_tfms)
val_ds_full = PCamDataset(data_root, split="val", transform=val_tfms)

# Optionally reduce dataset size for faster experiments
if limit_samples is not None:
    from torch.utils.data import Subset

    train_len = min(limit_samples, len(train_ds_full))
    val_len = min(limit_samples, len(val_ds_full))

    train_ds = Subset(train_ds_full, list(range(train_len)))  # type: ignore[arg-type]
    val_ds = Subset(val_ds_full, list(range(val_len)))        # type: ignore[arg-type]
else:
    train_ds = train_ds_full
    val_ds = val_ds_full

train_loader: DataLoader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

val_loader: DataLoader = DataLoader(
    val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

Mean: [0.7007558345794678, 0.538357675075531, 0.6916202306747437], Std: [0.21814213693141937, 0.2626423239707947, 0.19506362080574036]


# Load Model

In [9]:
# Set dfa = False for normal MLP
dfa = True

In [10]:
# Custom MLP Block with DFA

class MLPBlockWithDFA(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, drop=0.05):
        super().__init__()
        self.fc1 = LinearDFA(in_features = 192, out_features = 768,num_classes=2) 
        self.act = nn.GELU()
        self.dropout1 = nn.Dropout(drop)

        self.fc2 = LinearDFA(in_features = 768, out_features = 192,num_classes=2)
        self.dropout2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.dropout2(x)
        return x


In [11]:
def load_vit_tiny(model_path = "./vit_tiny_patch16_224.augreg_in21k/model.safetensors" ,dfa = False):
    # Create ViT-Tiny model with 2 output classes
    model = timm.create_model('vit_tiny_patch16_224.augreg_in21k', pretrained=False,
                             num_classes=2,drop_rate=0.1,drop_path_rate=0.1)

    # Load weights from safetensors file
    state_dict = load_file(model_path)

    # Remove 'model.' prefix if present
    if all(k.startswith("model.") for k in state_dict.keys()):
        state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}

    # Remove head weights (21843 → 2 mismatch)
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith("head.")}

    # Load backbone weights (ignore mismatches)
    model.load_state_dict(state_dict, strict=False)

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze the classification head , final normalization layer and last transformer blocks
    for param in model.head.parameters():
        param.requires_grad = True

    for param in model.norm.parameters():
        param.requires_grad = True

    for param in model.blocks[-1].parameters():
        param.requires_grad = True


    if(dfa == True):
        print("Using DFA")
        # --- Replace last block's MLP with DFA ---
        last_block = model.blocks[-1]
        old_mlp = last_block.mlp
        last_block.mlp = MLPBlockWithDFA(in_features=old_mlp.fc1.in_features, hidden_features=old_mlp.fc1.out_features,
            out_features=old_mlp.fc2.out_features)
        

    model.eval()
    return model

In [12]:
set_seed()
# Load model and move to device
model = load_vit_tiny(dfa=dfa)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable: {name}")


# ✅ Check trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable} / {total}")

Using DFA
Trainable: blocks.11.norm1.weight
Trainable: blocks.11.norm1.bias
Trainable: blocks.11.attn.qkv.weight
Trainable: blocks.11.attn.qkv.bias
Trainable: blocks.11.attn.proj.weight
Trainable: blocks.11.attn.proj.bias
Trainable: blocks.11.norm2.weight
Trainable: blocks.11.norm2.bias
Trainable: blocks.11.mlp.fc1.weight
Trainable: blocks.11.mlp.fc1.bias
Trainable: blocks.11.mlp.fc2.weight
Trainable: blocks.11.mlp.fc2.bias
Trainable: norm.weight
Trainable: norm.bias
Trainable: head.weight
Trainable: head.bias
Trainable parameters: 445634 / 5526722


In [13]:
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)


In [14]:
# -----------------------------
#  Logging Data
# -----------------------------

tl_arr,ta_arr,vl_arr,va_arr = [],[],[],[]
# Create CSV log file
os.makedirs(log_dir, exist_ok=True)
csv_path = os.path.join(log_dir, "training_log.csv")
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["epoch", "train_loss", "train_acc", "val_loss", "val_acc"])


In [15]:
# -----------------------------
#  Hyperparameters
# -----------------------------
lr = 1e-4
weight_decay = 1e-3 
epochs = 9

In [None]:
# ----------------------------
# 4) Device, criterion, optimizer
# -----------------------------
# optimizer = optim.AdamW(
#     filter(lambda p: p.requires_grad, model.parameters()),lr=lr,weight_decay=weight_decay)


# warmup_steps = 1000  # total warm-up steps
# warmup_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, step / warmup_steps))

# # Cosine decay scheduler after warmup
# total_steps = epochs * len(train_loader)
# cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)

# scheduler = SequentialLR(
#     optimizer,
#     schedulers=[warmup_scheduler, cosine_scheduler],
#     milestones=[warmup_steps]
# )

# criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # try 0.05


In [16]:
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),lr=lr,weight_decay=weight_decay)

scheduler = CosineAnnealingWarmRestarts(optimizer=optimizer,T_0=3,T_mult=2)

In [17]:
# -----------------------------
# 6) Training & Validation Loops
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer, scheduler,epoch,dfa=False):
    model.train()
    total_loss = 0
    total_acc  = 0
    max_grad_norm = 1.0
    
    for imgs, labels in tqdm(loader, desc=f"Epoch {epoch}/{epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)

        if(dfa == True):
            # Compute global error and set for DFA
            error = torch.autograd.grad(loss, logits, retain_graph=True, create_graph=True)[0]
            set_dfa_error(error)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total_acc  += (logits.argmax(dim=1) == labels).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    avg_acc  = total_acc  / len(loader.dataset)
    return avg_loss, avg_acc

def validate(model, loader, criterion):
    model.eval()
    total_loss = 0
    total_acc  = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)

            total_loss += loss.item() * imgs.size(0)
            total_acc  += (logits.argmax(dim=1) == labels).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    avg_acc  = total_acc  / len(loader.dataset)
    return avg_loss, avg_acc

In [18]:
# -----------------------------
# 7) Run Fine-Tuning with TQDM
# -----------------------------
best_val_acc = 0.0

# Create CSV log file
os.makedirs(log_dir, exist_ok=True)
csv_path = os.path.join(log_dir, "training_log.csv")
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["epoch", "train_loss", "train_acc", "val_loss", "val_acc"])

# Epoch progress bar
for epoch in range(1, epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch, dfa)
    val_loss, val_acc = validate(model, val_loader, criterion)

    scheduler.step()

    tl_arr.append(train_loss)
    ta_arr.append(train_acc)
    vl_arr.append(val_loss)
    va_arr.append(val_acc)

    # Log epoch metrics
    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([epoch, train_loss, train_acc, val_loss, val_acc])

    tqdm.write(
        f"[Epoch {epoch:2d}/{epochs}] "
        f"Train: loss={train_loss:.4f}, acc={train_acc:.4f} | "
        f"Val:   loss={val_loss:.4f}, acc={val_acc:.4f}"
    )
    if val_acc > best_val_acc:  
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_vit96_dfa.pth")
        tqdm.write(f"→ New best val_acc={val_acc:.4f}, model saved.")

tqdm.write(f"Training complete. Best Val Acc: {best_val_acc:.4f}")


Epoch 1/9: 100%|██████████| 2048/2048 [10:55<00:00,  3.12it/s]


[Epoch  1/9] Train: loss=0.5670, acc=0.7333 | Val:   loss=0.6285, acc=0.6701
→ New best val_acc=0.6701, model saved.


Epoch 2/9: 100%|██████████| 2048/2048 [11:05<00:00,  3.08it/s]


[Epoch  2/9] Train: loss=0.5578, acc=0.7781 | Val:   loss=0.6344, acc=0.7234
→ New best val_acc=0.7234, model saved.


Epoch 3/9: 100%|██████████| 2048/2048 [10:25<00:00,  3.27it/s]


[Epoch  3/9] Train: loss=0.5518, acc=0.7869 | Val:   loss=0.6110, acc=0.7264
→ New best val_acc=0.7264, model saved.


Epoch 4/9: 100%|██████████| 2048/2048 [10:39<00:00,  3.20it/s]


[Epoch  4/9] Train: loss=0.5507, acc=0.7877 | Val:   loss=0.6070, acc=0.7303
→ New best val_acc=0.7303, model saved.


Epoch 5/9: 100%|██████████| 2048/2048 [10:42<00:00,  3.19it/s]


[Epoch  5/9] Train: loss=0.5483, acc=0.7896 | Val:   loss=0.6053, acc=0.7321
→ New best val_acc=0.7321, model saved.


Epoch 6/9: 100%|██████████| 2048/2048 [10:38<00:00,  3.21it/s]


[Epoch  6/9] Train: loss=0.5465, acc=0.7912 | Val:   loss=0.5984, acc=0.7363
→ New best val_acc=0.7363, model saved.


Epoch 7/9: 100%|██████████| 2048/2048 [10:28<00:00,  3.26it/s]


[Epoch  7/9] Train: loss=0.5432, acc=0.7933 | Val:   loss=0.5974, acc=0.7495
→ New best val_acc=0.7495, model saved.


Epoch 8/9: 100%|██████████| 2048/2048 [10:31<00:00,  3.24it/s]


[Epoch  8/9] Train: loss=0.5417, acc=0.7946 | Val:   loss=0.6029, acc=0.7382


Epoch 9/9: 100%|██████████| 2048/2048 [10:22<00:00,  3.29it/s]


[Epoch  9/9] Train: loss=0.5412, acc=0.7945 | Val:   loss=0.6058, acc=0.7403
Training complete. Best Val Acc: 0.7495


In [19]:
# -----------------------------
# 9) Plot training curves
# -----------------------------
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
epochs_range = range(1, epochs + 1)

# Loss curves
axs[0].plot(epochs_range, tl_arr, label="Train Loss", color="tab:blue", linestyle="-")
axs[0].plot(epochs_range, vl_arr, label="Val Loss", color="tab:blue", linestyle="--")
axs[0].set_title("Loss vs Epoch")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
axs[0].grid(True)
axs[0].legend()

# Accuracy curves
axs[1].plot(epochs_range, ta_arr, label="Train Acc", color="tab:red", linestyle="-")
axs[1].plot(epochs_range, va_arr, label="Val Acc", color="tab:red", linestyle="--")
axs[1].set_title("Accuracy vs Epoch")
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Accuracy")
axs[1].grid(True)
axs[1].legend()

plt.suptitle("Metrics per Epoch", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave space for suptitle
plot_path = os.path.join(log_dir, "training_curves.png")
plt.savefig(plot_path, dpi=300)
plt.close()
print(f"Training curves saved to {plot_path}")

Training curves saved to logs\training_curves.png


In [20]:
def analyze_model_confidence(model, dataloader, device):
    model.eval()
    all_probs = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())

    # Combine all probability batches
    probs_tensor = torch.cat(all_probs, dim=0)
    
    # Compute confidence as the max probability per sample
    confidence = probs_tensor.max(dim=1).values.numpy()

    # Summary statistics
    mean_conf = np.mean(confidence)
    std_conf = np.std(confidence)
    min_conf = np.min(confidence)
    max_conf = np.max(confidence)

    print("🔍 Model Confidence Statistics:")
    print(f"  Mean Confidence     : {mean_conf:.4f}")
    print(f"  Std Deviation       : {std_conf:.4f}")
    print(f"  Min Confidence      : {min_conf:.4f}")
    print(f"  Max Confidence      : {max_conf:.4f}")
    # Histogram plot
    plt.figure(figsize=(10, 6))
    n, bins, patches = plt.hist(
        confidence,
        bins=20,
        color='#4BA3C3',           # Softer blue
        edgecolor='black',
        alpha=0.85,
        linewidth=1.2
    )

    # Optional: add grid lines on y-axis only for clarity
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Titles and labels
    plt.title("Model Confidence Distribution on Validation Set", fontsize=14, fontweight='bold')
    plt.xlabel("Max Softmax Probability (Confidence)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)

    # Tick styling
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)

    # Tight layout for spacing
    plt.tight_layout()

    # Save the figure
    plt.savefig("analy_confid.png", dpi=300)
    plt.close()

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = timm.create_model("vit_tiny_patch16_224.augreg_in21k",pretrained = False,num_classes = 2)
# model.load_state_dict(torch.load("best_vit96_dfa.pth",map_location="cpu"))

model = model.to(device)

analyze_model_confidence(model, val_loader, device)

🔍 Model Confidence Statistics:
  Mean Confidence     : 0.7725
  Std Deviation       : 0.0036
  Min Confidence      : 0.5869
  Max Confidence      : 0.7754


# Evaluate Model

In [22]:
test_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,)*3, std=(0.5,)*3),
])

test_ds = PCamDataset(data_root, split="test", transform=test_tfms)

test_loader = DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [23]:
# test_model = timm.create_model("vit_tiny_patch16_224.augreg_in21k",pretrained = False,num_classes = 2)

# checkpoint = torch.load("best_model.pth", map_location="cpu")
# model.load_state_dict(checkpoint["model_state_dict"])
# model.to(device)
model.eval()


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)


In [24]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)

all_preds = []
all_labels = []
all_probs = []  # for ROC-AUC

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)  # shape: [batch, num_classes]

        probs = torch.softmax(outputs, dim=1)[:, 1]  # probability of class 1
        _, predicted = torch.max(outputs, 1)

        all_probs.extend(probs.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Convert to tensors/arrays
y_true = torch.tensor(all_labels)
y_pred = torch.tensor(all_preds)

# Compute metrics (PyTorch where possible)
accuracy = (y_true == y_pred).float().mean().item()
precision = ( ( (y_true * y_pred).sum().float() ) / (y_pred.sum().float() + 1e-8) ).item()
recall = ( ( (y_true * y_pred).sum().float() ) / (y_true.sum().float() + 1e-8) ).item()
f1 = (2 * precision * recall / (precision + recall + 1e-8))

# Sklearn for ROC-AUC & confusion matrix
roc_auc = roc_auc_score(all_labels, all_probs)
cm = confusion_matrix(all_labels, all_preds)

# Print results
print(f"Test Accuracy : {accuracy*100:.2f}%")
print(f"Precision     : {precision:.4f}")
print(f"Recall        : {recall:.4f}")
print(f"F1 Score      : {f1:.4f}")
print(f"ROC-AUC       : {roc_auc:.4f}")
print("Confusion Matrix:\n", cm)

Evaluating: 100%|██████████| 256/256 [01:07<00:00,  3.82it/s]


Test Accuracy : 77.04%
Precision     : 0.8079
Recall        : 0.7092
F1 Score      : 0.7553
ROC-AUC       : 0.8175
Confusion Matrix:
 [[13629  2762]
 [ 4763 11614]]
