In [1]:
!git clone https://github.com/deepinsight/insightface.git

Cloning into 'insightface'...
remote: Enumerating objects: 12592, done.[K
remote: Counting objects: 100% (148/148), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 12592 (delta 104), reused 89 (delta 89), pack-reused 12444 (from 3)[K
Receiving objects: 100% (12592/12592), 58.40 MiB | 41.79 MiB/s, done.
Resolving deltas: 100% (6542/6542), done.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import math
from tqdm import tqdm
import numpy as np
import itertools
from collections import defaultdict
import os
from PIL import Image
from pathlib import Path

from insightface.recognition.arcface_torch.backbones.mobilefacenet import get_mbf
import sys
sys.path.append("/kaggle/input/scores")
from torch.cuda.amp import autocast, GradScaler
import arc_scores

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Side length, in pixels, of the square images fed to the network.
IMAGE_SIZE = 112
# Samples per mini-batch for both the training and the validation DataLoader.
BATCH_SIZE = 256
# Upper bound of the epoch loop in train(); early stopping may end it sooner.
NUM_EPOCHS = 25
# Embedding width; passed to the margin head as its `in_features`.
FEATURE_DIM = 512

# Un-decayed learning rates of the two optimizer parameter groups
# (group 0 = backbone, group 1 = margin head); step_lr() scales both.
base_lr_backbone = 0.1
base_lr_margin = 0.5
# Weight decay intended for the SGD optimizer.
weight_decay = 5e-4

# Zero-based epoch indices from which step_lr() applies one more factor of
# step_gamma to both learning rates.
step_milestones = [5, 13, 19]
step_gamma = 0.1

In [5]:
class FastImageFolder(Dataset):
    """ImageFolder-style dataset over a ``root/<class_name>/<image>`` tree.

    Every immediate sub-directory of ``root`` is one class; class indices follow
    the alphabetical order of the directory names. Only the file *names* are
    inspected while indexing (no per-file ``stat``), so building the index stays
    cheap on very large folders.

    Args:
        root: Directory containing one sub-directory per class.
        transform: Optional callable applied to the decoded RGB ``PIL.Image``.
        extensions: File-name endings that count as images. Matching is
            case-insensitive, so ``photo.JPG`` is picked up by ``'.jpg'``.

    Attributes:
        classes: Sorted list of class (directory) names.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(path, label)`` tuples, ordered by class and then by
            path, so the order does not depend on the filesystem.
    """

    def __init__(self, root, transform=None, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
        self.root = Path(root)
        self.transform = transform
        self.extensions = extensions

        # str.endswith accepts a tuple; lower-casing both sides makes the
        # comparison case-insensitive.
        wanted_endings = tuple(ext.lower() for ext in extensions)

        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Scan once and cache the paths.
        self.samples = []
        print(f"Scanning {root}...")
        for class_name in self.classes:
            class_idx = self.class_to_idx[class_name]
            # A single directory pass per class (rather than one glob per
            # extension); sorting pins down a reproducible sample order.
            image_paths = sorted(
                str(entry)
                for entry in (self.root / class_name).iterdir()
                if entry.name.lower().endswith(wanted_endings)
            )
            self.samples.extend((path, class_idx) for path in image_paths)

        print(f"Found {len(self.samples)} images in {len(self.classes)} classes")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded to RGB and passed through ``transform`` when one
        was supplied.
        """
        path, target = self.samples[idx]
        # Image.open() is lazy and keeps the file open; convert() forces the
        # decode into a new image, after which the context manager closes the
        # handle instead of leaving that to the garbage collector.
        with Image.open(path) as raw:
            image = raw.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, target

In [6]:
# Both pipelines resize to the configured IMAGE_SIZE (instead of repeating the
# literal) and finish with the same ToTensor + Normalize(mean=0.5, std=0.5)
# steps on every channel.

# Training: adds a random horizontal flip and mild colour jitter as augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Validation: deterministic, no augmentation.
test_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [7]:
# Dataset roots: each holds one sub-directory per class.
train_folder = '/kaggle/input/train-ds/train'
test_folder = '/kaggle/input/val-ds/val'

train_dataset = FastImageFolder(train_folder, transform=train_transforms)
test_dataset = FastImageFolder(test_folder, transform=test_transforms)

# Loader settings shared by both splits; only the shuffling differs.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# The classification head needs one output per training class.
NUM_CLASSES = len(train_dataset.classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

# Summary banner of the run configuration.
_rule = "=" * 60
print(_rule)
print("MULTI-GPU TRAINING SETUP")
print(_rule)
print(f"GPUs available: {n_gpus}")
for gpu_index in range(n_gpus):
    print(f"  GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
print(f"Dataset: {NUM_CLASSES} classes, {len(train_dataset)} images")
print(f"Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE})")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Steps per epoch: {len(train_loader)}")
print(_rule)

Scanning /kaggle/input/train-ds/train...
Found 1119807 images in 5115 classes
Scanning /kaggle/input/val-ds/val...
Found 114964 images in 555 classes
MULTI-GPU TRAINING SETUP
GPUs available: 2
  GPU 0: Tesla T4
  GPU 1: Tesla T4
Dataset: 5115 classes, 1119807 images
Batch size: 256 (effective: 256)
Epochs: 25
Steps per epoch: 4375


In [8]:
class SubCenterArcFace(nn.Module):
    """
    Sub-Center ArcFace Loss (CVPR 2020)
    Paper: https://arxiv.org/abs/2005.10671

    Each class owns ``k`` learnable sub-centers. For every sample the logit of
    a class is the largest cosine similarity between the embedding and that
    class's sub-centers; the additive angular margin ``m`` is applied to the
    ground-truth class only, and all logits are multiplied by ``s``.

    Args:
        in_features: Embedding dimension.
        out_features: Number of classes.
        s: Scale applied to the final logits.
        m: Additive angular margin, in radians.
        k: Number of sub-centers per class.
    """
    def __init__(self, in_features, out_features, s=64.0, m=0.5, k=3):
        super(SubCenterArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.k = k  # number of sub-centers per class

        # Weight: [num_classes * k, embedding_dim]; rows c*k .. c*k+k-1 are the
        # sub-centers of class c (see the view() in forward).
        # torch.empty replaces the legacy uninitialised FloatTensor constructor;
        # the values are set by the Xavier initialisation right below.
        self.weight = nn.Parameter(torch.empty(out_features * k, in_features))
        nn.init.xavier_uniform_(self.weight)

        # Pre-compute trigonometric values
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # For cosines at or below `th` (theta + m would pass pi) the linear
        # fallback `cosine - mm` is used instead of cos(theta + m).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, embeddings, labels):
        """Return scaled margin logits of shape ``[batch, out_features]``.

        Args:
            embeddings: ``[batch, in_features]`` tensor (normalised here again,
                so un-normalised input is fine).
            labels: Integer class indices, one per sample.
        """
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        weight_norm = F.normalize(self.weight, p=2, dim=1)

        # Cosine similarity to every sub-center: [batch, out_features * k].
        cosine_all = F.linear(embeddings_norm, weight_norm)

        # Under autocast the matmul may return reduced precision; do the margin
        # arithmetic in float32 explicitly rather than relying on per-op
        # autocast promotion.
        if cosine_all.dtype in (torch.float16, torch.bfloat16):
            cosine_all = cosine_all.float()

        # Max-pool over the k sub-centers first. The margin function below is
        # non-decreasing in the cosine, so this yields the same logits as
        # applying the margin to every sub-center and pooling afterwards, at
        # 1/k of the memory and compute.
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        # Rounding can push a cosine marginally outside [-1, 1].
        cosine = cosine.clamp(-1.0, 1.0)

        sine = torch.sqrt(torch.clamp(1.0 - cosine ** 2, 1e-9, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # True exactly at each sample's ground-truth class column.
        class_ids = torch.arange(self.out_features, device=cosine.device)
        is_target = class_ids.unsqueeze(0) == labels.view(-1, 1)

        # Margin logit for the target class, plain cosine elsewhere; then scale.
        # (Out-of-place multiply: no in-place edit of an autograd result.)
        return torch.where(is_target, phi, cosine) * self.s

In [9]:
def generate_balanced_pairs(labels, max_per_class=None, random_state=42):
    """Build a shuffled list of same-class and different-class index pairs.

    All within-class index combinations form the positive pairs (optionally
    after sub-sampling each class to ``max_per_class`` items). The same number
    of negative pairs is then drawn at random: a random pair of distinct
    classes, and one random item from each. Negative pairs are drawn with
    replacement, so duplicates are possible.

    Args:
        labels: Sequence of hashable class labels, one per sample.
        max_per_class: If truthy, cap on the number of samples per class that
            take part in positive pairs.
        random_state: Seed for the private ``numpy.random.RandomState``; the
            result is fully determined by ``labels`` and this seed.

    Returns:
        List of ``(i, j, target)`` tuples, with ``target`` 1 for a same-class
        pair and 0 otherwise. Empty when no class has two or more samples.

    Raises:
        ValueError: If positive pairs exist but there is only one class, so no
            negative pair can be drawn.
    """
    rng = np.random.RandomState(random_state)

    # label -> indices of the samples carrying it (insertion-ordered).
    label2idx = defaultdict(list)
    for i, lb in enumerate(labels):
        label2idx[lb].append(i)

    pos_pairs = []
    for idxs in label2idx.values():
        if len(idxs) < 2:
            continue

        idxs = np.array(idxs)
        if max_per_class and len(idxs) > max_per_class:
            idxs = rng.choice(idxs, max_per_class, replace=False)

        pos_pairs.extend(itertools.combinations(idxs, 2))

    n_pos = len(pos_pairs)
    labels_unique = list(label2idx)

    if n_pos and len(labels_unique) < 2:
        # Previously this surfaced as an opaque "low >= high" from randint().
        raise ValueError(
            "generate_balanced_pairs needs at least two distinct classes "
            "to draw negative pairs"
        )

    # NOTE: materialises all C*(C-1)/2 class pairs; kept so that the random
    # stream (and therefore the generated pairs) stays unchanged for a given
    # seed.
    class_pairs = list(itertools.combinations(labels_unique, 2))

    neg_pairs = []
    for _ in range(n_pos):
        lb1, lb2 = class_pairs[rng.randint(len(class_pairs))]
        i = rng.choice(label2idx[lb1])
        j = rng.choice(label2idx[lb2])
        neg_pairs.append((i, j))

    pairs = [(i, j, 1) for (i, j) in pos_pairs] + \
            [(i, j, 0) for (i, j) in neg_pairs]

    rng.shuffle(pairs)
    return pairs

In [10]:
def evaluate(embs, labels, max_per_class=50, n_linspace=1000, epsilon=1e-6, random_state=42):
    """Score embeddings on a balanced pair-verification task.

    Pairs come from ``generate_balanced_pairs``; each pair is scored with the
    dot product of its two embeddings (cosine similarity when the caller passes
    L2-normalised embeddings, as ``train()`` does).

    Args:
        embs: List of ``[batch, dim]`` embedding tensors.
        labels: List of ``[batch]`` label tensors aligned with ``embs``.
        max_per_class: Per-class sample cap forwarded to the pair generator.
        n_linspace: Number of evenly spaced thresholds tried for the accuracy.
        epsilon: Padding added on both sides of the score range so the sweep
            includes "accept everything" and "reject everything".
        random_state: Seed forwarded to the pair generator.

    Returns:
        Dict with ``accuracy`` (best over the threshold sweep), ``threshold``
        (where it is reached), ``roc_auc``, ``tar_far1``, ``tar_far2`` and the
        numbers of positive / negative pairs.

    Raises:
        ValueError: If no pair can be formed (no class has two samples).
    """
    embs = torch.cat(embs).cpu()
    labels = torch.cat(labels).cpu().numpy()

    # random_state used to be accepted but never reached the generator.
    pairs = generate_balanced_pairs(labels, max_per_class, random_state=random_state)
    if not pairs:
        raise ValueError("evaluate() needs at least one class with two or more samples")
    pairs = np.array(pairs)

    idx_a = pairs[:, 0].astype(int)
    idx_b = pairs[:, 1].astype(int)
    similarity_scores = torch.sum(embs[idx_a] * embs[idx_b], dim=1).numpy()

    targets = pairs[:, 2].astype(int)

    # Best accuracy
    thresholds = np.linspace(
        similarity_scores.min() - epsilon,
        similarity_scores.max() + epsilon,
        n_linspace
    )
    # A pair is predicted "same" when score >= threshold. Counting through
    # sorted scores gives exactly the same accuracies as comparing every score
    # with every threshold, without the n_linspace x n_pairs boolean matrix.
    pos_scores = np.sort(similarity_scores[targets == 1])
    neg_scores = np.sort(similarity_scores[targets == 0])
    # side="left" -> number of scores strictly below each threshold.
    true_pos = pos_scores.size - np.searchsorted(pos_scores, thresholds, side="left")
    true_neg = np.searchsorted(neg_scores, thresholds, side="left")
    accs = (true_pos + true_neg) / targets.size
    best_acc = accs.max()
    best_th = thresholds[accs.argmax()]

    # ROC & TAR
    # NOTE(review): tar_at_far's default FAR is not visible here; train()
    # reports the first value as TAR@FAR=1e-3 - confirm in arc_scores.
    roc_auc = arc_scores.compute_roc_auc(similarity_scores, targets)["auc"]
    tar_far1 = arc_scores.tar_at_far(similarity_scores, targets)
    tar_far2 = arc_scores.tar_at_far(similarity_scores, targets, 1e-4)

    return {
        "accuracy": float(best_acc),
        "roc_auc": float(roc_auc),
        "tar_far1": float(tar_far1),
        "tar_far2": float(tar_far2),
        "threshold": float(best_th),
        "pos_samples": int(pos_scores.size),
        "neg_samples": int(neg_scores.size)
    }

In [11]:
def step_lr(optimizer, base_lr_backbone, base_lr_margin, epoch,
            milestones=[10, 15], gamma=0.1):
    """Apply a step learning-rate schedule to the first two parameter groups.

    Every milestone that ``epoch`` has reached (``epoch >= milestone``)
    contributes one factor of ``gamma``. The decayed backbone rate is written
    to ``optimizer.param_groups[0]`` and the decayed margin rate to
    ``optimizer.param_groups[1]``.

    Returns:
        Tuple ``(lr_backbone, lr_margin)`` with the rates that were just set.
    """
    # One multiplication per milestone already passed.
    decay = 1.0
    for _ in (m for m in milestones if epoch >= m):
        decay *= gamma

    new_rates = (base_lr_backbone * decay, base_lr_margin * decay)

    backbone_group = optimizer.param_groups[0]
    backbone_group["lr"] = new_rates[0]
    margin_group = optimizer.param_groups[1]
    margin_group["lr"] = new_rates[1]

    return new_rates

In [12]:
# =============================================================================
# EARLY STOPPING
# =============================================================================
class EarlyStopping:
    """Track the best validation score, checkpoint on improvement, flag a stop.

    Args:
        patience: Number of consecutive non-improving ``step`` calls after
            which ``should_stop`` becomes True.
        epsilon: Minimum gain over the best score that counts as improvement.
        save_path: File the best checkpoint is written to.

    Attributes:
        best_acc: Best score seen so far (-1 before the first step).
        counter: Consecutive non-improving steps since the last improvement.
        should_stop: Set to True once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience=5, epsilon=0.001, save_path="best.pt"):
        self.patience = patience
        self.save_path = save_path
        self.epsilon = epsilon
        self.best_acc = -1
        self.counter = 0
        self.should_stop = False

    def step(self, val_acc, model, margin):
        """Record one validation score.

        When ``val_acc`` beats the best score by more than ``epsilon`` the
        weights of ``model`` and ``margin`` are saved to ``save_path``;
        otherwise the patience counter advances.
        """
        if val_acc > self.best_acc + self.epsilon:
            self.best_acc = val_acc
            self.counter = 0

            # Unwrap DataParallel so the saved keys carry no "module." prefix
            # and load into a bare model, the same way load_checkpoint() and
            # the per-epoch checkpoints handle it. (The unwrapped modules were
            # computed before but the wrapped ones were saved.)
            model_to_save = model.module if hasattr(model, 'module') else model
            margin_to_save = margin.module if hasattr(margin, 'module') else margin

            checkpoint = {
                'model_state_dict': model_to_save.state_dict(),
                'margin_state_dict': margin_to_save.state_dict(),
                'best_tar_far': self.best_acc
            }
            torch.save(checkpoint, self.save_path)
            print(f"✓ Saved best model: TAR@FAR={val_acc:.4f}")
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
                print("⚠ Early stopping triggered!")

In [13]:
# =============================================================================
# MODEL & OPTIMIZER
# =============================================================================
# Backbone and margin head share the embedding width, so both take it from
# FEATURE_DIM instead of repeating the literal.
model = get_mbf(fp16=False, num_features=FEATURE_DIM).to(device)
margin = SubCenterArcFace(
    in_features=FEATURE_DIM,
    out_features=NUM_CLASSES,
    s=64.0,
    m=0.4
).to(device)

# Both modules already live on `device`; wrapping them afterwards needs no
# second .to(device).
if n_gpus > 1:
    print(f"Using DataParallel with {n_gpus} GPUs")
    model = nn.DataParallel(model)
    margin = nn.DataParallel(margin)

criterion = nn.CrossEntropyLoss()
# Group order matters: step_lr() writes the backbone rate to group 0 and the
# margin rate to group 1.
optimizer = optim.SGD([
    {"params": model.parameters(), "lr": base_lr_backbone},
    {"params": margin.parameters(), "lr": base_lr_margin}
], momentum=0.9, weight_decay=weight_decay)

steps_per_epoch = len(train_loader)

# Loss scaler for mixed-precision training.
scaler = GradScaler()

Using DataParallel with 2 GPUs


  scaler = GradScaler()


In [14]:
def load_checkpoint(path):
    """Restore model, margin head and optimizer state from ``path``.

    The state dicts are loaded into the unwrapped modules (``.module`` when
    wrapped by DataParallel), so the checkpoint must hold un-prefixed keys
    under ``model_state_dict`` / ``margin_state_dict``, together with
    ``optimizer_state_dict`` and ``epoch``.

    Args:
        path: Checkpoint file to resume from.

    Returns:
        The zero-based epoch index to continue with: the saved ``epoch`` plus
        one, or 0 when ``path`` does not exist.
    """
    if not os.path.exists(path):
        # Starting over is allowed, but say so: a mistyped path would otherwise
        # restart a long run from scratch without any hint.
        print(f"==> No checkpoint found at {path}; starting from epoch 0")
        return 0

    print(f"==> Loading checkpoint from {path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load checkpoint files you created yourself or otherwise trust.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    model_to_load = model.module if hasattr(model, 'module') else model
    margin_to_load = margin.module if hasattr(margin, 'module') else margin

    model_to_load.load_state_dict(checkpoint['model_state_dict'])
    margin_to_load.load_state_dict(checkpoint['margin_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    start_epoch = checkpoint['epoch'] + 1
    print(f"==> Resuming from epoch {start_epoch}")

    return start_epoch

In [15]:
def _train_one_epoch(progress):
    """Run one optimisation pass over ``progress``; return the summed batch loss."""
    # autocast is switched on only when training on CUDA. The module-level
    # GradScaler turns itself into a pass-through when CUDA is unavailable, so
    # this single code path also covers plain FP32 training.
    use_amp = device.type == "cuda"
    running_loss = 0.0

    for inputs, targets in progress:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            outputs = F.normalize(outputs, p=2, dim=1)
            logits = margin(outputs, targets)
            loss = criterion(logits, targets)

        # No catch-all fallback here: a failing step must surface. Silently
        # re-running the batch in FP32 after a partial backward would add a
        # second set of gradients on top of the first (and a bare `except`
        # also swallowed KeyboardInterrupt).
        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        torch.nn.utils.clip_grad_norm_(margin.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss


def _embed_validation_set():
    """Return ``(embs, labels_list)``: per-batch L2-normalised embeddings and labels."""
    embs = []
    labels_list = []

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Evaluating"):
            inputs = inputs.to(device, non_blocking=True)

            outputs = model(inputs)
            outputs = F.normalize(outputs, p=2, dim=1)

            embs.append(outputs.cpu())
            # Labels are only consumed on the CPU by evaluate(); no need to
            # move them to the device and back.
            labels_list.append(targets.cpu())

    return embs, labels_list


def _save_epoch_checkpoint(epoch, avg_train_loss, tar_far1, tar_far2):
    """Write the resumable checkpoint for ``epoch`` (DataParallel unwrapped)."""
    model_to_save = model.module if hasattr(model, 'module') else model
    margin_to_save = margin.module if hasattr(margin, 'module') else margin

    torch.save({
        'epoch': epoch,
        'model_state_dict': model_to_save.state_dict(),
        'margin_state_dict': margin_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'tar_far1e-3': tar_far1,
        'tar_far1e-4': tar_far2,
        'n_gpus': n_gpus
    }, f'checkpoint_epoch_{epoch+1}.pt')
    print(f"✓ Saved checkpoint at epoch {epoch+1}")


def train():
    """Train (or resume) the model, validating and checkpointing every epoch.

    Resumes from the hard-coded checkpoint path when that file exists. Each
    epoch sets the step learning rates, runs one training pass, evaluates pair
    verification on the validation set, feeds TAR@FAR=1e-3 to early stopping
    and saves a resumable checkpoint.

    Returns:
        Tuple ``(train_losses, accs, rocs, tfs)`` of lists with one entry per
        epoch run in this call: mean training loss, best verification
        accuracy, ROC AUC and TAR@FAR=1e-3.
    """
    train_losses = []
    accs = []
    rocs = []
    tfs = []

    early = EarlyStopping(patience=5, epsilon=0.001)

    start_epoch = load_checkpoint("//kaggle/input/checkpoint12/checkpoint_epoch_12.pt")

    for epoch in range(start_epoch, NUM_EPOCHS):
        model.train()
        margin.train()

        # Set this epoch's learning rates (also shown in the progress bar).
        lr_backbone, _ = step_lr(
            optimizer, base_lr_backbone, base_lr_margin,
            epoch, milestones=step_milestones, gamma=step_gamma
        )
        pbar = tqdm(
            train_loader,
            total=len(train_loader),
            desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [LR: {lr_backbone:.6f}]"
        )

        train_loss = _train_one_epoch(pbar)
        avg_train_loss = train_loss / len(train_loader)

        # ==================== EVALUATION ====================
        model.eval()
        margin.eval()

        embs, labels_list = _embed_validation_set()

        eval_res = evaluate(embs, labels_list, max_per_class=50, n_linspace=1000)
        tar_far1 = eval_res["tar_far1"]
        tar_far2 = eval_res["tar_far2"]

        # Append
        train_losses.append(avg_train_loss)
        accs.append(eval_res['accuracy'])
        rocs.append(eval_res['roc_auc'])
        tfs.append(tar_far1)

        # Get current learning rates
        current_lr_backbone = optimizer.param_groups[0]['lr']
        current_lr_margin = optimizer.param_groups[1]['lr']

        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        print(f"Learning Rate: Backbone={current_lr_backbone:.6f}, Margin={current_lr_margin:.6f}")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Eval Metrics:")
        print(f"  - Accuracy: {eval_res['accuracy']:.4f}")
        print(f"  - ROC AUC: {eval_res['roc_auc']:.4f}")
        print(f"  - TAR@FAR1e-3: {eval_res['tar_far1']:.4f}")
        print(f"  - TAR@FAR1e-4: {eval_res['tar_far2']:.4f}")
        print(f"  - Threshold: {eval_res['threshold']:.4f}")
        print(f"{'='*60}\n")

        # Early stopping is driven by TAR@FAR=1e-3.
        early.step(tar_far1, model, margin)

        # Save a resumable checkpoint after every epoch.
        _save_epoch_checkpoint(epoch, avg_train_loss, tar_far1, tar_far2)

        if early.should_stop:
            print("⚠ Training stopped early.")
            break

    print(f"\n✓ Training completed! Best TAR@FAR: {early.best_acc:.4f}")

    return train_losses, accs, rocs, tfs

train()

==> Loading checkpoint from //kaggle/input/checkpoint12/checkpoint_epoch_12.pt
==> Resuming from epoch 12


  train._scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 13/25 [LR: 0.010000]: 100%|██████████| 4375/4375 [44:28<00:00,  1.64it/s, loss=4.7729]
Evaluating: 100%|██████████| 450/450 [04:35<00:00,  1.63it/s]



Epoch 13/25
Learning Rate: Backbone=0.010000, Margin=0.050000
Train Loss: 5.8582
Eval Metrics:
  - Accuracy: 0.9542
  - ROC AUC: 0.9869
  - TAR@FAR1e-3: 0.8162
  - TAR@FAR1e-4: 0.6785
  - Threshold: 0.2218

✓ Saved best model: TAR@FAR=0.8162
✓ Saved checkpoint at epoch 13


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 14/25 [LR: 0.001000]: 100%|██████████| 4375/4375 [32:46<00:00,  2.23it/s, loss=3.5445]
Evaluating: 100%|██████████| 450/450 [01:41<00:00,  4.44it/s]



Epoch 14/25
Learning Rate: Backbone=0.001000, Margin=0.005000
Train Loss: 3.9486
Eval Metrics:
  - Accuracy: 0.9606
  - ROC AUC: 0.9887
  - TAR@FAR1e-3: 0.8572
  - TAR@FAR1e-4: 0.7498
  - Threshold: 0.2208

✓ Saved best model: TAR@FAR=0.8572
✓ Saved checkpoint at epoch 14


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 15/25 [LR: 0.001000]: 100%|██████████| 4375/4375 [32:46<00:00,  2.22it/s, loss=4.6483]
Evaluating: 100%|██████████| 450/450 [02:00<00:00,  3.72it/s]



Epoch 15/25
Learning Rate: Backbone=0.001000, Margin=0.005000
Train Loss: 3.3805
Eval Metrics:
  - Accuracy: 0.9608
  - ROC AUC: 0.9887
  - TAR@FAR1e-3: 0.8578
  - TAR@FAR1e-4: 0.7616
  - Threshold: 0.2184

✓ Saved checkpoint at epoch 15


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 16/25 [LR: 0.001000]: 100%|██████████| 4375/4375 [32:45<00:00,  2.23it/s, loss=3.3048]
Evaluating: 100%|██████████| 450/450 [02:09<00:00,  3.49it/s]



Epoch 16/25
Learning Rate: Backbone=0.001000, Margin=0.005000
Train Loss: 3.2487
Eval Metrics:
  - Accuracy: 0.9610
  - ROC AUC: 0.9887
  - TAR@FAR1e-3: 0.8601
  - TAR@FAR1e-4: 0.7598
  - Threshold: 0.2201

✓ Saved best model: TAR@FAR=0.8601
✓ Saved checkpoint at epoch 16


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 17/25 [LR: 0.001000]: 100%|██████████| 4375/4375 [32:43<00:00,  2.23it/s, loss=4.5213]
Evaluating: 100%|██████████| 450/450 [01:50<00:00,  4.08it/s]



Epoch 17/25
Learning Rate: Backbone=0.001000, Margin=0.005000
Train Loss: 3.1971
Eval Metrics:
  - Accuracy: 0.9609
  - ROC AUC: 0.9887
  - TAR@FAR1e-3: 0.8613
  - TAR@FAR1e-4: 0.7613
  - Threshold: 0.2185

✓ Saved best model: TAR@FAR=0.8613
✓ Saved checkpoint at epoch 17


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 18/25 [LR: 0.001000]: 100%|██████████| 4375/4375 [33:42<00:00,  2.16it/s, loss=3.9364]
Evaluating: 100%|██████████| 450/450 [02:03<00:00,  3.65it/s]



Epoch 18/25
Learning Rate: Backbone=0.001000, Margin=0.005000
Train Loss: 3.1717
Eval Metrics:
  - Accuracy: 0.9610
  - ROC AUC: 0.9887
  - TAR@FAR1e-3: 0.8614
  - TAR@FAR1e-4: 0.7657
  - Threshold: 0.2172

✓ Saved checkpoint at epoch 18


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 19/25 [LR: 0.001000]: 100%|██████████| 4375/4375 [32:44<00:00,  2.23it/s, loss=6.1081]
Evaluating: 100%|██████████| 450/450 [01:59<00:00,  3.77it/s]



Epoch 19/25
Learning Rate: Backbone=0.001000, Margin=0.005000
Train Loss: 3.1541
Eval Metrics:
  - Accuracy: 0.9609
  - ROC AUC: 0.9886
  - TAR@FAR1e-3: 0.8620
  - TAR@FAR1e-4: 0.7603
  - Threshold: 0.2167

✓ Saved checkpoint at epoch 19


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 20/25 [LR: 0.000100]: 100%|██████████| 4375/4375 [32:44<00:00,  2.23it/s, loss=3.2649]
Evaluating: 100%|██████████| 450/450 [01:50<00:00,  4.06it/s]



Epoch 20/25
Learning Rate: Backbone=0.000100, Margin=0.000500
Train Loss: 2.8148
Eval Metrics:
  - Accuracy: 0.9610
  - ROC AUC: 0.9886
  - TAR@FAR1e-3: 0.8634
  - TAR@FAR1e-4: 0.7688
  - Threshold: 0.2189

✓ Saved best model: TAR@FAR=0.8634
✓ Saved checkpoint at epoch 20


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 21/25 [LR: 0.000100]: 100%|██████████| 4375/4375 [32:47<00:00,  2.22it/s, loss=2.4869]
Evaluating: 100%|██████████| 450/450 [02:09<00:00,  3.47it/s]



Epoch 21/25
Learning Rate: Backbone=0.000100, Margin=0.000500
Train Loss: 2.7411
Eval Metrics:
  - Accuracy: 0.9612
  - ROC AUC: 0.9886
  - TAR@FAR1e-3: 0.8645
  - TAR@FAR1e-4: 0.7675
  - Threshold: 0.2184

✓ Saved best model: TAR@FAR=0.8645
✓ Saved checkpoint at epoch 21


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 22/25 [LR: 0.000100]: 100%|██████████| 4375/4375 [32:48<00:00,  2.22it/s, loss=2.4036]
Evaluating: 100%|██████████| 450/450 [01:52<00:00,  4.01it/s]



Epoch 22/25
Learning Rate: Backbone=0.000100, Margin=0.000500
Train Loss: 2.7059
Eval Metrics:
  - Accuracy: 0.9611
  - ROC AUC: 0.9886
  - TAR@FAR1e-3: 0.8651
  - TAR@FAR1e-4: 0.7657
  - Threshold: 0.2179

✓ Saved checkpoint at epoch 22


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 23/25 [LR: 0.000100]: 100%|██████████| 4375/4375 [32:45<00:00,  2.23it/s, loss=3.5952]
Evaluating: 100%|██████████| 450/450 [02:18<00:00,  3.24it/s]



Epoch 23/25
Learning Rate: Backbone=0.000100, Margin=0.000500
Train Loss: 2.6866
Eval Metrics:
  - Accuracy: 0.9611
  - ROC AUC: 0.9886
  - TAR@FAR1e-3: 0.8644
  - TAR@FAR1e-4: 0.7688
  - Threshold: 0.2170

✓ Saved checkpoint at epoch 23


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 24/25 [LR: 0.000100]: 100%|██████████| 4375/4375 [32:53<00:00,  2.22it/s, loss=3.1970]
Evaluating: 100%|██████████| 450/450 [01:56<00:00,  3.87it/s]



Epoch 24/25
Learning Rate: Backbone=0.000100, Margin=0.000500
Train Loss: 2.6705
Eval Metrics:
  - Accuracy: 0.9612
  - ROC AUC: 0.9886
  - TAR@FAR1e-3: 0.8639
  - TAR@FAR1e-4: 0.7652
  - Threshold: 0.2168

✓ Saved checkpoint at epoch 24


  with autocast():
  with torch.cuda.amp.autocast(self.fp16):
Epoch 25/25 [LR: 0.000100]: 100%|██████████| 4375/4375 [33:12<00:00,  2.20it/s, loss=4.1574]
Evaluating: 100%|██████████| 450/450 [02:12<00:00,  3.38it/s]



Epoch 25/25
Learning Rate: Backbone=0.000100, Margin=0.000500
Train Loss: 2.6607
Eval Metrics:
  - Accuracy: 0.9611
  - ROC AUC: 0.9885
  - TAR@FAR1e-3: 0.8645
  - TAR@FAR1e-4: 0.7623
  - Threshold: 0.2168

✓ Saved checkpoint at epoch 25

✓ Training completed! Best TAR@FAR: 0.8645


([5.858192835780552,
  3.9486011795043945,
  3.3805037422725133,
  3.2486515484401157,
  3.1970557082584925,
  3.171720915930612,
  3.1540656592777796,
  2.814766678128924,
  2.741131239782061,
  2.7058794523784093,
  2.686638185828073,
  2.6704825529643466,
  2.660726494925363],
 [0.9542366021509824,
  0.9605557221418833,
  0.9607728588035622,
  0.960950851781333,
  0.9608577932120421,
  0.960950851781333,
  0.9608718258534431,
  0.9609981196260523,
  0.961238151650017,
  0.961119982038219,
  0.9611495244411685,
  0.9611731583635281,
  0.9611347532396938],
 [0.9869045477417863,
  0.9887044691870086,
  0.9886916414637344,
  0.9887142933921392,
  0.9886622791840307,
  0.9887055752319741,
  0.9886359686595606,
  0.9885553221013806,
  0.9886280377569368,
  0.9885501707541513,
  0.9886482643555302,
  0.9885761871050297,
  0.9885100781241934],
 [0.816184214607832,
  0.857236337746476,
  0.8577873035614844,
  0.8601122906736111,
  0.8613412546363108,
  0.861366365678818,
  0.8619645993385456