In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode

from models import DualModel
from utils import IntraSensorMulticlassDataset, get_train_validation_split

In [None]:
# Data split / loader settings
validation_split = 0.1
batch_size = 32
num_workers = 4
pin_memory = True

# Optimisation settings
lr = 0.001
num_epochs = 1  # quick-run value; presumably 50 for a full training run

# Seed the Python, NumPy and torch RNGs so Restart & Run All is repeatable
# (torch.manual_seed also seeds CUDA; cuDNN kernels may still be nondeterministic).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Paths
train_dir = './LivDet/2013/Training/CrossMatchTrain'
test_dir = './LivDet/2013/Testing/CrossMatchTest'
model_path = './ckpts/model.pth'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Random geometric jitter applied to training images after tensor conversion.
_affine_jitter = T.RandomAffine(
    degrees=(-20, 20),        # rotation range
    translate=(0.2, 0.2),     # max horizontal / vertical shift (fraction of size)
    scale=(0.8, 1.2),         # zoom range
    shear=(-20, 20),          # shear range
    interpolation=InterpolationMode.NEAREST,
    fill=0,
)

# Training pipeline: resize -> tensor -> affine jitter -> random flips.
train_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    _affine_jitter,
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
])

# Evaluation pipeline: same resize and tensor conversion, no augmentation.
test_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

In [None]:
# Build the training dataset and carve a validation subset out of it; the
# test directory is loaded separately with the non-augmenting transform.
# NOTE(review): val_set is derived from train_dataset, which was built with
# train_transform - if get_train_validation_split returns subsets that share
# the underlying dataset, validation images are augmented too. TODO confirm.
train_dataset = IntraSensorMulticlassDataset(data_dir=train_dir, transform=train_transform)
train_set, val_set = get_train_validation_split(dataset=train_dataset, validation_split=validation_split)
test_set = IntraSensorMulticlassDataset(data_dir=test_dir, transform=test_transform)

# Settings shared by every loader; only shuffling differs (training only).
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
train_loader = DataLoader(dataset=train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_set, shuffle=False, **loader_kwargs)

In [None]:
# Classifier sized to the training label set, moved to `device`, then wrapped in
# nn.DataParallel (its state_dict keys therefore carry a "module." prefix).
model = nn.DataParallel(DualModel(num_classes=train_dataset.num_classes).to(device))

In [None]:
# Adam over every parameter of the (wrapped) model.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# CrossEntropyLoss applies log-softmax itself, so model outputs are presumably
# raw logits - confirm DualModel has no final softmax layer.
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-epoch average losses, filled by the training loop and plotted afterwards.
train_losses = []
val_losses = []
# Start at +inf so the first epoch always counts as an improvement and a
# checkpoint is guaranteed to exist; a fixed threshold (e.g. 10.0) would skip
# saving whenever validation loss never drops below it, breaking the later load.
best_val_loss = float('inf')

In [None]:
print('Training begin')

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save below.
os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)

num_train_batches = len(train_loader)
# Print a progress line on the first, middle and last batch of each epoch.
log_batch_indices = {0, round(num_train_batches / 2), num_train_batches - 1}

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_running_loss = 0.0

    for batch_idx, (data, targets) in enumerate(train_loader):
        # CrossEntropyLoss takes integer class indices, so cast targets to long directly.
        data, targets = data.to(device), targets.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_running_loss += batch_loss

        if batch_idx in log_batch_indices:
            print(f'Epoch [{epoch+1}/{num_epochs}] Batch Index [{batch_idx+1}/{num_train_batches}]   Loss: {batch_loss:.4f}')

    avg_train_loss = train_running_loss / num_train_batches
    train_losses.append(avg_train_loss)

    # Validation phase: eval mode, no gradient tracking.
    model.eval()
    val_running_loss = 0.0

    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device, dtype=torch.long)
            outputs = model(data)
            loss = criterion(outputs, targets)
            val_running_loss += loss.item()

    avg_val_loss = val_running_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}]   Train Loss: {avg_train_loss:.4f}   Val Loss: {avg_val_loss:.4f}')

    # Checkpoint only when validation loss improves. `model` is the
    # nn.DataParallel wrapper, so the saved keys carry a "module." prefix and
    # must be loaded back into a wrapped model.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), model_path)
        print(f'Best val loss: {best_val_loss:.4f}. Model saved!')

print('Training done')

In [None]:
# Plot per-epoch train/validation loss curves. Epochs are 1-based to match the
# printed log, and markers keep a single-epoch run (num_epochs = 1) visible.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Train Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, marker='o', label='Validation Loss')
ax.set_title('Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Restore the best checkpoint written during training. map_location remaps the
# stored tensors onto the current device, so a GPU-saved file also loads on a
# CPU-only kernel.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# Testing phase: collect ground-truth labels and predicted class indices for
# every test sample.
model.eval()

all_targets = []
all_predictions = []

with torch.no_grad():
    for data, batch_targets in test_loader:
        data = data.to(device)  # targets stay on CPU; they are only collected
        outputs = model(data)
        batch_predictions = torch.argmax(outputs, dim=1)

        all_targets.extend(batch_targets.cpu().numpy())
        all_predictions.extend(batch_predictions.cpu().numpy())

# Collapse the multiclass labels to binary: class 0 -> 0, any other class -> 1.
# Convert to arrays first - comparing a Python list with `> 0` raises TypeError.
targets = (np.asarray(all_targets) > 0).astype(int)
predictions = (np.asarray(all_predictions) > 0).astype(int)

In [None]:
# Presentation-attack-detection metrics. `targets` / `predictions` are expected
# to be 0/1 integer arrays where 0 = live and 1 = spoof.
live_mask = (targets == 0)
spoof_mask = (targets == 1)
live_count = np.sum(live_mask)
spoof_count = np.sum(spoof_mask)

live_predictions = predictions[live_mask]
spoof_predictions = predictions[spoof_mask]

# APCER: share of spoof samples predicted as live (prediction 0).
if spoof_count > 0:
    apcer = np.sum(1 - spoof_predictions) / spoof_count
else:
    apcer = 0

# BPCER: share of live samples predicted as spoof (prediction 1).
if live_count > 0:
    bpcer = np.sum(live_predictions) / live_count
else:
    bpcer = 0

# ACER is the mean of the two error rates; ACC* is its complement.
acer = (apcer + bpcer) / 2
acc = 1 - acer

# Plain sample-level accuracy, already scaled to percent.
accuracy = 100 * accuracy_score(targets, predictions)

print(f"APCER: {apcer*100:.2f}%")
print(f"BPCER: {bpcer*100:.2f}%")
print(f"ACER:  {acer*100:.2f}%")
print(f"ACC*:  {acc*100:.2f}%")
print(f"ACC:   {accuracy:.2f}%")