In [1]:
import warnings
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import albumentations as A


from going_modular.model.TripletFaceRecognition import EmbeddingNet_Concat, TripletNet_Concat
from going_modular.dataloader.triplet import TripletDatasetConcat
from going_modular.loss.TripletLoss import TripletLoss
from going_modular.dataloader.triplet import CustomExrDatasetConcat
from going_modular.utils.MultiMetricEarlyStopping import MultiMetricEarlyStopping
from going_modular.utils.ModelCheckPoint import ModelCheckpoint
from going_modular.utils.transforms import RandomResizedCropRect, GaussianNoise
from going_modular.train_eval.triplet.train import fit

import os

import warnings
warnings.filterwarnings("ignore")

device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Đặt seed toàn cục
seed = 42
torch.manual_seed(seed)

CONFIGURATION = {
    'type': 'concat',
    # Thư mục
    'data_dir': './Dataset',
    'checkpoint_dir': './checkpoint/new/',
    'normalmap_checkpoint': './checkpoint/samenetwork/triplet/normalmap/models/checkpoint.pth',
    'albedo_checkpoint': './checkpoint/samenetwork/triplet/albedo/models/checkpoint.pth',

    # Cấu hình train
    'epochs': 119,
    'num_workers': 4,
    'batch_size': 2,
    'image_size': 256,
    'embedding_size': 512,
    
    'start_lr': 1e-4,
    'weight_decay': 5e-4,
    'momentum': 0.9,
    'alpha': 0.9,
    
    # triplet
    'margin': 1.,
}


train_transform = A.Compose([
    RandomResizedCropRect(256),
    GaussianNoise(),
])

test_transform = A.Compose([
    A.Resize(height=CONFIGURATION['image_size'], width=CONFIGURATION['image_size'])
])

triplet_concat_train_dataset = TripletDatasetConcat(data_dir=CONFIGURATION['data_dir'], transform=train_transform, train=True)
triplet_concat_test_dataset = TripletDatasetConcat(data_dir=CONFIGURATION['data_dir'], transform=test_transform, train=False)


triplet_concat_train_loader = DataLoader(
    triplet_concat_train_dataset, 
    batch_size=CONFIGURATION['batch_size'], 
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)


triplet_concat_test_loader = DataLoader(
    triplet_concat_test_dataset, 
    batch_size=CONFIGURATION['batch_size'], 
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)


roc_concat_train_dataset = CustomExrDatasetConcat(data_dir=CONFIGURATION['data_dir'], transform=train_transform, train=True)
roc_concat_train_loader = DataLoader(
    roc_concat_train_dataset,
    num_workers=4,
    batch_size=CONFIGURATION['batch_size'],
)

roc_concat_test_dataset = CustomExrDatasetConcat(data_dir=CONFIGURATION['data_dir'], transform=test_transform, train=False)
roc_test_loader = DataLoader(
    roc_concat_test_dataset,
    num_workers=4,
    batch_size=CONFIGURATION['batch_size'],
)


# Load pre-trained state_dict
normal_checkpoint = torch.load(CONFIGURATION['normalmap_checkpoint'])
# Loại bỏ 14 ký tự đầu tiên từ các khóa trong state_dict
new_normalmap_state_dict = {}
for key, value in normal_checkpoint['model_state_dict'].items():
    new_key = key[14:]  # Cắt bỏ 14 ký tự đầu tiên
    new_normalmap_state_dict[new_key] = value

albedo_checkpoint = torch.load(CONFIGURATION['albedo_checkpoint'])
# Loại bỏ 14 ký tự đầu tiên từ các khóa trong state_dict
new_albedo_state_dict = {}
for key, value in normal_checkpoint['model_state_dict'].items():
    new_key = key[14:]  # Cắt bỏ 14 ký tự đầu tiên
    new_albedo_state_dict[new_key] = value
    
embedding_net = EmbeddingNet_Concat(conf=CONFIGURATION)

embedding_net.resnet1.load_state_dict(new_normalmap_state_dict, strict=False)

embedding_net.resnet2.load_state_dict(new_albedo_state_dict)

model = TripletNet_Concat(embedding_net).to(device)

# Freeze 2 mạng lại
for param in model.embedding_net.resnet1.parameters():
    param.requires_grad = False
    
for param in model.embedding_net.resnet2.parameters():
    param.requires_grad = False
    
criterion = TripletLoss(CONFIGURATION['margin'])
optimizer = Adam(model.parameters(), lr=CONFIGURATION['start_lr'])
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=1, eta_min=1e-6)

earlystop_dir = os.path.abspath(CONFIGURATION['checkpoint_dir'] + CONFIGURATION['type'] + '/models')

early_stopping = MultiMetricEarlyStopping(
    monitor_keys=['cosine_auc', 'euclidean_auc'],
    patience=50,
    mode='max',
    verbose=0,
    save_dir=earlystop_dir,
    start_from_epoch=80
)      

checkpoint_path = os.path.abspath(CONFIGURATION['checkpoint_dir'] + CONFIGURATION['type'] + '/models/checkpoint.pth')
modle_checkpoint = ModelCheckpoint(filepath=checkpoint_path, verbose=1)

2024-12-21 17:30:55.963288: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
fit(
    conf=CONFIGURATION,
    start_epoch=0,
    model=model,
    triplet_train_loader=triplet_concat_train_loader, 
    triplet_test_loader=triplet_concat_test_loader, 
    criterion=criterion,
    optimizer=optimizer, 
    scheduler=scheduler, 
    epochs=CONFIGURATION['epochs'], 
    device=device, 
    roc_train_loader=roc_concat_train_loader, 
    roc_test_loader=roc_test_loader,
    early_max_stopping=early_stopping,
    model_checkpoint=modle_checkpoint
)

Epoch 1:
	train: loss: 3.0478 | auc_cos: 0.9214 | auc_eu: 0.9212
	test: loss: 1.5250 | auc_cos: 0.9199 | auc_eu: 0.9170
[36m	Saving model and optimizer state to /media/vohoang/WorkSpace/ubuntu/projects/in-process/Bachelor-s-Project/checkpoint/new/concat/models/checkpoint.pth[0m


KeyboardInterrupt: 