In [5]:
import os

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from going_modular.dataloader.magface import create_concat_magface_dataloader
from going_modular.model.MagFaceRecognition import MagFaceConcatRecognition
from going_modular.train_eval.train import fit
from going_modular.loss.MagLoss import MagLoss
from going_modular.utils.MultiMetricEarlyStopping import MultiMetricEarlyStopping
from going_modular.utils.ModelCheckPoint import ModelCheckpoint

import albumentations as A

device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Đặt seed toàn cục
seed = 42
torch.manual_seed(seed)

CONFIGURATION = {
    # Thư mục
    'type': 'concat',
    'train_dir': './Dataset/Albedo/train',
    'test_dir': './Dataset/Albedo/test',
    'dataset_dir': './Dataset',
    'albedo_checkpoint': './checkpoint/magface/albedo/models/best_val_cosine_accuracy_270.pth',
    'depthmap_checkpoint': './checkpoint/magface/depthmap/models/best_val_cosine_accuracy_215.pth',
    'normalmap_checkpoint': './checkpoint/magface/normalmap/models/best_val_cosine_accuracy_195.pth',
    
    # Cấu hình train
    'backbone': 'iresnet18',
    'epochs': 2000,
    'num_workers': 4,
    'batch_size': 16,
    'image_size': 224,
    'num_class': len(os.listdir('./Dataset/Albedo/train')),
    'embedding_size': 512,
    'num_cpus': 10,
    
    'learning_rate': 0.1,
    'weight_decay': 5e-4,
    'momentum': 0.9,
    'alpha': 0.9,
    
    # Hàm m(ai) giúp thay đổi ai từ 0.25 đến 1.6
    'scale': 64,
    'lambda_g': 20,
    'l_margin': 0.45, 
    'u_margin': 0.8,
    'l_a': 10, 
    'u_a': 110,
}

train_transform = A.Compose([
    A.RandomCrop(CONFIGURATION['image_size'],CONFIGURATION['image_size'])
], additional_targets={
    'depthmap': 'image',
    'normalmap': 'image'   
})

test_transform = A.Compose([
    A.Resize(CONFIGURATION['image_size'], CONFIGURATION['image_size'], always_apply=True)
], additional_targets={
    'depthmap': 'image',
    'normalmap': 'image'
})

train_dataloader, val_dataloader = create_concat_magface_dataloader(CONFIGURATION, train_transform, test_transform)

In [6]:
model = MagFaceConcatRecognition(CONFIGURATION).to(device)
criterion = MagLoss(conf = CONFIGURATION)

optimizer = torch.optim.Adam(model.parameters(), lr=CONFIGURATION['learning_rate'])
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6)
checkpoint_path = os.path.abspath('checkpoint/magface/'+ CONFIGURATION['type'] + '/models/checkpoint.pth')
modle_checkpoint = ModelCheckpoint(filepath=checkpoint_path, verbose=1)
earlystop_dir = os.path.abspath('checkpoint/magface/'+ CONFIGURATION['type'] + '/models')
early_stopping = MultiMetricEarlyStopping(
    monitor_keys=['val_euclidean_accuracy', 'val_cosine_accuracy', 'val_auc_euclidean', 'val_auc_cosine'],
    patience=40,
    mode='max',
    verbose=1,
    save_dir=earlystop_dir,
    start_from_epoch=0
)

29s: Generate finish

In [7]:
fit(CONFIGURATION, 0, model, device, train_dataloader, val_dataloader, criterion, optimizer, scheduler, early_stopping, modle_checkpoint)

Epoch 1:
	train: loss 3.022 | loss id   3.00 | top_1_acc 0.4454 | top_5_acc 0.5747 | acc_eu: 0.967 | acc_cos: 0.968 | auc_eu: 0.994 | auc_cos: 0.997
	val: acc_eu: 0.806 | acc_cos: 0.769 | auc_eu: 0.821 | auc_cos: 0.876
[36m	Saving model and optimizer state to /media/vohoang/WorkSpace/ubuntu/projects/in-process/Bachelor-s-Project/checkpoint/magface/concat/models/checkpoint.pth[0m
Epoch 2:
	train: loss 1.836 | loss id   1.81 | top_1_acc 0.8759 | top_5_acc 0.9537 | acc_eu: 0.983 | acc_cos: 0.993 | auc_eu: 0.998 | auc_cos: 1.000
	val: acc_eu: 0.798 | acc_cos: 0.835 | auc_eu: 0.832 | auc_cos: 0.909
[36m	Saving model and optimizer state to /media/vohoang/WorkSpace/ubuntu/projects/in-process/Bachelor-s-Project/checkpoint/magface/concat/models/checkpoint.pth[0m
Epoch 3:
	train: loss 1.287 | loss id   1.26 | top_1_acc 0.9671 | top_5_acc 0.9880 | acc_eu: 0.985 | acc_cos: 0.995 | auc_eu: 0.998 | auc_cos: 1.000
	val: acc_eu: 0.804 | acc_cos: 0.804 | auc_eu: 0.833 | auc_cos: 0.918
[36m	Saving 

KeyboardInterrupt: 