In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from tqdm import tqdm

from data_loader import get_dataloader
from models.multitask_models import MultitaskModel, MobileNetV2Backbone
from metrics import compute_metrics

In [None]:
# --- Experiment configuration --------------------------------------------
# Dataset selection (passed straight through to get_dataloader).
INTRA = True
YEAR = 'foreground_2015'
SENSOR = 'CrossMatch'
# Set the LIVDET_PATH environment variable to point at the dataset on other machines.
DATASET_PATH = os.environ.get('LIVDET_PATH', '/home/hmb1604/datasets/LivDet')
BINARY_CLASS = False

# Data loading
BATCH_SIZE = 8
NUM_WORKERS = 4

# Optimization
LR = 1e-3
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 1

# Loss weighting for the two heads and the decision threshold for the spoof head.
SPOOF_WEIGHT = 3.6
MATERIAL_WEIGHT = 1.0
THRESHOLD = 0.5

# Reproducibility: seed torch (all devices) and NumPy so augmentation, shuffling
# and weight init are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

MODEL_SAVE_PATH = './ckpts/model.pth'
# Derive the checkpoint directory from MODEL_SAVE_PATH so changing one cannot desync the other.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or '.', exist_ok=True)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Shared preprocessing constants: every image is resized to IMAGE_SIZE and
# normalized with the ImageNet channel statistics.
IMAGE_SIZE = (224, 224)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Random geometric augmentation used only for training: rotation, shift, shear
# and zoom, resampled with NEAREST interpolation; exposed border pixels get 0.
train_affine = T.RandomAffine(
    degrees=(-20, 20),
    translate=(0.2, 0.2),
    shear=(-20, 20),
    scale=(0.8, 1.2),
    interpolation=T.InterpolationMode.NEAREST,
    fill=0,
)

# Split name -> preprocessing pipeline ('Train' adds flips + affine augmentation).
transform = {
    'Train': T.Compose([
        T.Resize(IMAGE_SIZE),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        train_affine,
        T.ToTensor(),
        T.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]),
    'Test': T.Compose([
        T.Resize(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]),
}

In [None]:
# get_dataloader arguments shared by both calls; only the `train` flag differs.
common_loader_args = dict(
    intra=INTRA,
    year=YEAR,
    sensor=SENSOR,
    dataset_path=DATASET_PATH,
    binary_class=BINARY_CLASS,
    transform=transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
# train=True yields (train, val, label_map); train=False yields (test, label_map).
train_loader, val_loader, train_label_map = get_dataloader(train=True, **common_loader_args)
test_loader, test_label_map = get_dataloader(train=False, **common_loader_args)

In [None]:
backbone = MobileNetV2Backbone()
# Label 0 is the bona fide class (the training code treats `labels > 0` as spoof and
# shifts materials by -1), so the material head has one output per remaining label.
num_material_classes = len(train_label_map) - 1
multitask_model = MultitaskModel(feature_extractor=backbone, num_material_classes=num_material_classes)
model = torch.nn.DataParallel(multitask_model.to(device))

In [None]:
# Binary live/spoof head is trained on raw logits; the material head is multi-class.
spoof_criterion = torch.nn.BCEWithLogitsLoss()
material_criterion = torch.nn.CrossEntropyLoss()

# Adam with L2 weight decay; the LR follows a cosine curve from LR down to 1e-6
# over NUM_EPOCHS scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

In [None]:
# Per-epoch metric history: '<split>_<metric>' for train/val, plus the learning rate.
_EPOCH_METRICS = ('total_loss', 'spoof_loss', 'material_loss', 'spoof_acc', 'material_acc')
history = {f'{split}_{metric}': [] for split in ('train', 'val') for metric in _EPOCH_METRICS}
history['lr'] = []

# Best-checkpoint tracking: lowest validation total loss so far and its weights.
best_val_loss = float('inf')
best_model_state = None

In [None]:
# Training loop
def _run_epoch(loader, train):
    """Run one full pass over `loader` and return epoch-level metrics.

    With train=True the model is in training mode and the optimizer steps after
    every batch; with train=False the model is evaluated with gradients disabled.

    Returns a dict with 'total_loss', 'spoof_loss', 'material_loss' (mean of the
    per-batch losses) and 'spoof_acc', 'material_acc' (percentages). Material
    accuracy is computed over spoof samples only and is 0.0 if there are none.
    """
    model.train(train)
    loss_sums = {'total_loss': 0.0, 'spoof_loss': 0.0, 'material_loss': 0.0}
    correct_spoof = 0
    correct_material = 0
    num_samples = 0
    num_spoof_samples = 0

    with torch.set_grad_enabled(train):
        for imgs, labels in tqdm(loader, desc='train' if train else 'val'):
            imgs = imgs.to(device)
            labels = labels.to(device, dtype=torch.float)

            # Label 0 = bona fide; labels >= 1 are spoof materials, shifted to 0-based
            # class ids (bona fide rows become -1 but are masked out of the material loss).
            spoof_labels = (labels > 0).float().unsqueeze(1)
            material_labels = (labels - 1).long()

            if train:
                optimizer.zero_grad()

            spoof_logits, material_logits = model(imgs)
            spoof_loss = spoof_criterion(spoof_logits, spoof_labels)

            # squeeze(1), not squeeze(): keeps the mask 1-D even for a batch of size 1.
            spoof_mask = spoof_labels.squeeze(1) == 1
            batch_spoof_count = int(spoof_mask.sum().item())
            if batch_spoof_count > 0:
                spoof_material_logits = material_logits[spoof_mask]
                spoof_material_labels = material_labels[spoof_mask]
                material_loss = material_criterion(spoof_material_logits, spoof_material_labels)
            else:
                # No spoof samples in this batch: the material head contributes nothing.
                material_loss = torch.zeros((), device=device)

            total_loss = SPOOF_WEIGHT * spoof_loss + MATERIAL_WEIGHT * material_loss

            if train:
                total_loss.backward()
                optimizer.step()

            loss_sums['spoof_loss'] += spoof_loss.item()
            loss_sums['material_loss'] += material_loss.item()
            loss_sums['total_loss'] += total_loss.item()

            # The spoof head outputs logits (BCEWithLogitsLoss), so convert to
            # probabilities before applying THRESHOLD, matching the test cell.
            spoof_preds = (torch.sigmoid(spoof_logits) >= THRESHOLD).float()
            correct_spoof += (spoof_preds == spoof_labels).sum().item()

            if batch_spoof_count > 0:
                spoof_material_preds = spoof_material_logits.argmax(dim=1)
                correct_material += (spoof_material_preds == spoof_material_labels).sum().item()
                num_spoof_samples += batch_spoof_count

            num_samples += imgs.size(0)

    num_batches = len(loader)
    stats = {name: total / num_batches for name, total in loss_sums.items()}
    stats['spoof_acc'] = (correct_spoof / num_samples) * 100.0
    stats['material_acc'] = (correct_material / num_spoof_samples) * 100.0 if num_spoof_samples > 0 else 0.0
    return stats


for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 36)

    # Training phase
    train_stats = _run_epoch(train_loader, train=True)
    print(f"Train Loss: Total=[{train_stats['total_loss']:.4f}] Spoof=[{train_stats['spoof_loss']:.4f}] Material=[{train_stats['material_loss']:.4f}]")
    print(f"Train Acc: Spoof=[{train_stats['spoof_acc']:.2f}] Material=[{train_stats['material_acc']:.2f}]")

    # Validation phase
    val_stats = _run_epoch(val_loader, train=False)
    print(f"Val Loss: Total=[{val_stats['total_loss']:.4f}] Spoof=[{val_stats['spoof_loss']:.4f}] Material=[{val_stats['material_loss']:.4f}]")
    print(f"Val Acc: Spoof=[{val_stats['spoof_acc']:.2f}] Material=[{val_stats['material_acc']:.2f}]")
    print()

    for split, stats in (('train', train_stats), ('val', val_stats)):
        for metric, value in stats.items():
            history[f'{split}_{metric}'].append(value)
    # LR that was used during this epoch (the scheduler steps at the end of the loop body).
    history['lr'].append(optimizer.param_groups[0]['lr'])

    if val_stats['total_loss'] < best_val_loss:
        best_val_loss = val_stats['total_loss']
        # state_dict() returns references to the live tensors, and a shallow dict
        # .copy() still shares them, so later optimizer steps would silently
        # overwrite the "best" weights. Deep-copy to take a real snapshot.
        best_model_state = copy.deepcopy(model.state_dict())
        print(f"New best model found! Saving to {MODEL_SAVE_PATH}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': best_model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'history': history
        }, MODEL_SAVE_PATH)

    scheduler.step()

# Restore the weights with the lowest validation loss before testing.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15))

# (axis, y-label, title, [(history key, legend label), ...]) for each panel.
panel_specs = [
    (ax1, 'Loss', 'Training and Validation Losses', [
        ('train_total_loss', 'Train Total Loss'),
        ('train_spoof_loss', 'Train Spoof Loss'),
        ('train_material_loss', 'Train Material Loss'),
        ('val_total_loss', 'Val Total Loss'),
        ('val_spoof_loss', 'Val Spoof Loss'),
        ('val_material_loss', 'Val Material Loss'),
    ]),
    (ax2, 'Accuracy', 'Training and Validation Accuracies', [
        ('train_spoof_acc', 'Train Spoof Acc'),
        ('train_material_acc', 'Train Material Acc'),
        ('val_spoof_acc', 'Val Spoof Acc'),
        ('val_material_acc', 'Val Material Acc'),
    ]),
    (ax3, 'Learning Rate', 'Learning Rate Schedule', [
        ('lr', 'Learning Rate'),
    ]),
]

for panel_ax, y_label, panel_title, series in panel_specs:
    for history_key, legend_label in series:
        panel_ax.plot(history[history_key], label=legend_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.legend()
    panel_ax.grid(True)
    panel_ax.set_title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# Testing phase: collect the ground-truth spoof flag and the predicted spoof
# probability for every test image.
model.eval()
all_labels = []
all_probabilities = []

with torch.no_grad():
    for batch_imgs, batch_labels in test_loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device, dtype=torch.float)

        # Only the spoof head is needed here; the material head output is discarded.
        spoof_logits, _ = model(batch_imgs)

        all_labels.extend((batch_labels > 0).float().cpu().numpy())
        all_probabilities.extend(torch.sigmoid(spoof_logits.squeeze(1)).cpu().numpy())

# 1 = spoof, 0 = bona fide, for both ground truth and thresholded predictions.
labels = np.array(all_labels).astype(int)
probabilities = np.array(all_probabilities)
predictions = (probabilities >= THRESHOLD).astype(int)

In [None]:
apcer, bpcer, ace, accuracy = compute_metrics(labels, predictions)

# Print each metric as a percentage with names left-aligned in a 12-character
# column; "Accuracy*" is reported as 1 - ACE.
for metric_name, metric_value in (
    ('APCER', apcer),
    ('BPCER', bpcer),
    ('ACE', ace),
    ('Accuracy', accuracy),
    ('Accuracy*', 1 - ace),
):
    print(f"{metric_name + ':':<12}{metric_value*100:.2f}%")