In [None]:
import torch
import torchvision.transforms as T
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from data_loader import get_dataloader
from models.dual_model import DualModel
from metrics import find_optimal_threshold

In [None]:
# --- Experiment configuration ------------------------------------------------
# Dataset selection; these values are forwarded to get_dataloader() below.
INTRA = True
YEAR = 'foreground_2015'
SENSOR = 'CrossMatch'
# Dataset root. It can be overridden through the LIVDET_DATASET_PATH environment
# variable so the notebook runs on another machine without editing this cell;
# the previously hard-coded location is kept as the default.
DATASET_PATH = os.environ.get('LIVDET_DATASET_PATH', '/home/hmb1604/datasets/LivDet')
BINARY_CLASS = True

# DataLoader settings.
BATCH_SIZE = 8
NUM_WORKERS = 4

# Optimisation settings.
LR = 1e-3
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 1

# Checkpoint file written by the training loop. The directory is derived from
# the path itself so that changing MODEL_SAVED_PATH cannot leave its parent
# directory uncreated; `or '.'` covers a bare file name with no directory part.
MODEL_SAVED_PATH = './ckpts/model.pth'
os.makedirs(os.path.dirname(MODEL_SAVED_PATH) or '.', exist_ok=True)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Settings shared by the train and test pipelines.
IMAGE_SIZE = (224, 224)
# Per-channel normalisation statistics (these are the widely used ImageNet values).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Random geometric augmentation applied to training images only.
train_affine = T.RandomAffine(
    degrees=(-20, 20),          # Rotation
    translate=(0.2, 0.2),       # Horizontal/vertical shift
    scale=(0.8, 1.2),           # Zoom
    shear=(-20, 20),            # Shear
    interpolation=T.InterpolationMode.NEAREST,
    fill=0,
)

train_steps = [
    T.Resize(IMAGE_SIZE),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    train_affine,
    T.ToTensor(),
    T.Normalize(mean=NORM_MEAN, std=NORM_STD),
]

# The test pipeline is deterministic: resize, convert to tensor, normalise.
test_steps = [
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=NORM_MEAN, std=NORM_STD),
]

# Keyed by split name; the whole dict is handed to get_dataloader().
transform = {
    'Train': T.Compose(train_steps),
    'Test': T.Compose(test_steps),
}

In [None]:
# Keyword arguments common to both get_dataloader() calls; only `train` differs.
loader_kwargs = dict(
    intra=INTRA,
    year=YEAR,
    sensor=SENSOR,
    dataset_path=DATASET_PATH,
    binary_class=BINARY_CLASS,
    transform=transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# train=True yields a (train loader, validation loader, label map) triple;
# train=False yields a (test loader, label map) pair.
train_loader, val_loader, train_label_map = get_dataloader(train=True, **loader_kwargs)
test_loader, test_label_map = get_dataloader(train=False, **loader_kwargs)

In [None]:
# Build the network with a single output unit and place it on the chosen device.
base_model = DualModel(num_classes=1)
base_model = base_model.to(device=device)

# Wrap in DataParallel; every later cell works with this wrapped `model`.
model = torch.nn.DataParallel(base_model)

In [None]:
# Binary cross-entropy computed directly on the raw model outputs (logits).
criterion = torch.nn.BCEWithLogitsLoss()

# Adam over all model parameters, with weight decay from the config cell.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Cosine learning-rate schedule spanning the whole run, floored at 1e-6;
# the training loop steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-6,
)

In [None]:
# Per-epoch bookkeeping filled in by the training loop.
history = {
    'train_losses': [],   # mean training loss, one entry per epoch
    'val_losses': [],     # mean validation loss, one entry per epoch
    # Start at +inf so the first epoch always counts as an improvement and a
    # checkpoint is always written. A finite sentinel (previously 1e3) would
    # skip saving whenever the initial validation loss exceeded it, leaving
    # the later checkpoint load with a missing or stale file.
    'best_val_loss': float('inf'),
}

In [None]:
print('Begin training...')

for epoch in range(NUM_EPOCHS):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    train_loss_sum = 0.0

    for imgs, labels in tqdm(train_loader, desc=f"train epoch [{epoch+1}/{NUM_EPOCHS}]"):
        imgs = imgs.to(device)
        # BCEWithLogitsLoss needs float targets.
        labels = labels.to(device, dtype=torch.float)

        optimizer.zero_grad()
        logits = model(imgs)
        # Drop dimension 1 of the output so it lines up with the label vector.
        batch_loss = criterion(logits.squeeze(1), labels)
        batch_loss.backward()
        optimizer.step()

        train_loss_sum += batch_loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = train_loss_sum / len(train_loader)
    history['train_losses'].append(avg_train_loss)

    # ---- Validation pass: no gradient tracking, no weight updates ----
    model.eval()
    val_loss_sum = 0.0

    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"val epoch [{epoch+1}/{NUM_EPOCHS}]"):
            imgs = imgs.to(device)
            labels = labels.to(device, dtype=torch.float)

            logits = model(imgs)
            val_loss_sum += criterion(logits.squeeze(1), labels).item()

    avg_val_loss = val_loss_sum / len(val_loader)
    history['val_losses'].append(avg_val_loss)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss [{avg_train_loss:.4f}] Val Loss: [{avg_val_loss:.4f}]')

    # Keep only the weights with the lowest validation loss seen so far.
    improved = avg_val_loss < history['best_val_loss']
    if improved:
        history['best_val_loss'] = avg_val_loss
        torch.save(model.state_dict(), MODEL_SAVED_PATH)
        print('Model saved!')

    print('=' * 63)

print('Finish training...')

In [None]:
# Draw the per-epoch training and validation loss curves on one set of axes.
fig, ax = plt.subplots()
for series_key, series_label in (('train_losses', 'Train Loss'), ('val_losses', 'Validation Loss')):
    ax.plot(history[series_key], label=series_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Restore the best checkpoint written by the training loop. map_location=device
# remaps the stored tensors onto the device chosen above, so a checkpoint saved
# from a CUDA run still loads when this cell is executed on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_SAVED_PATH, map_location=device))

# Testing phase: collect ground-truth labels and predicted probabilities.
model.eval()
all_labels = []
all_probabilities = []

with torch.no_grad():
    # Loop variables use a batch_ prefix so they are not confused with the
    # final `labels` / `probabilities` arrays built after the loop.
    for batch_imgs, batch_labels in test_loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device, dtype=torch.float)

        batch_logits = model(batch_imgs)
        # Sigmoid maps the raw logits to probabilities in [0, 1].
        batch_probabilities = torch.sigmoid(batch_logits.squeeze(1))

        all_labels.extend(batch_labels.cpu().numpy())
        all_probabilities.extend(batch_probabilities.cpu().numpy())

# Arrays consumed by the metrics cell: integer labels and float probabilities.
labels = np.array(all_labels).astype(int)
probabilities = np.array(all_probabilities)

In [None]:
# find_optimal_threshold returns six values. The original target list named
# `accuracy` twice, so the fourth value was silently overwritten by the sixth;
# the discard is now explicit and `accuracy` is still bound to the sixth value.
# NOTE(review): what the fourth return value represents is not visible from
# this notebook — check metrics.find_optimal_threshold and report it if needed.
# NOTE(review): `labels` / `probabilities` come from the test loader. If
# find_optimal_threshold tunes the threshold on them, the numbers below are
# optimistic; consider choosing the threshold on the validation split — TODO confirm.
threshold, apcer, bpcer, _fourth_value, ace, accuracy = find_optimal_threshold(labels, probabilities, based_on="ace")
print(f"APCER:      {apcer*100:.2f}%")
print(f"BPCER:      {bpcer*100:.2f}%")
print(f"ACE:        {ace*100:.2f}%")
print(f"Accuracy:   {accuracy*100:.2f}%")
# Accuracy* is derived from ACE as its complement (1 - ACE).
print(f"Accuracy*:  {(1-ace)*100:.2f}%")