In [1]:
!git clone https://github.com/nkt780426/concat.git

Cloning into 'concat'...
remote: Enumerating objects: 519, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 519 (delta 20), reused 64 (delta 16), pack-reused 445 (from 1)[K
Receiving objects: 100% (519/519), 77.81 MiB | 42.16 MiB/s, done.
Resolving deltas: 100% (35/35), done.


In [2]:
!gdown --quiet --folder "https://drive.google.com/drive/folders/111XLjyrsaLIAu4BsW8ZhLa68vnWLDLnH?usp=sharing"

In [3]:
!mv concat/going_modular .

In [4]:
!ls -la

total 52
drwxr-xr-x 5 root root  4096 Dec 21 18:03 .
drwxr-xr-x 6 root root  4096 Dec 21 18:02 ..
drwxr-xr-x 3 root root  4096 Dec 21 18:02 checkpoint
drwxr-xr-x 6 root root  4096 Dec 21 18:03 concat
drwxr-xr-x 8 root root  4096 Dec 21 18:02 going_modular
---------- 1 root root 32036 Dec 21 18:03 __notebook__.ipynb


In [2]:
import warnings
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import albumentations as A


from going_modular.model.TripletFaceRecognition import EmbeddingNetConcatV2, TripletNetConcatV2
from going_modular.dataloader.triplet import TripletDatasetConcatV2
from going_modular.loss.TripletLoss import TripletLoss
from going_modular.dataloader.triplet import CustomExrDatasetConCatV2
from going_modular.utils.MultiMetricEarlyStopping import MultiMetricEarlyStopping
from going_modular.utils.ModelCheckPoint import ModelCheckpoint
from going_modular.utils.transforms import RandomResizedCropRect, GaussianNoise
from going_modular.train_eval.triplet.train import fit

import os

import warnings
warnings.filterwarnings("ignore")

device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Đặt seed toàn cục
seed = 42
torch.manual_seed(seed)

CONFIGURATION = {
    # Thư mục
    'type': 'concat_2/0',
    'data_dir_1': './Dataset/Normal_Map',
    'checkpoint_1': './checkpoint/new/normalmap/iresnet18/models/checkpoint.pth',
    'type_1': 'normalmap',
    'backbone1': 'iresnet18',
    'data_dir_2': './Dataset/Albedo',
    'checkpoint_2': './checkpoint/new/albedo/iresnet18/models/checkpoint.pth',
    'type_2': 'albedo',
    'backbone2': 'iresnet18',
    'checkpoint_dir': './checkpoint/new/',

    # Cấu hình train
    'epochs': 298,
    'num_workers': 4,
    'batch_size': 16,
    'image_size': 256,
    'embedding_size': 512,
    
    'start_lr': 1e-4,
    'weight_decay': 5e-4,
    'momentum': 0.9,
    'alpha': 0.9,
    
    # triplet
    'margin': 1.,
}


train_transform = A.Compose([
    RandomResizedCropRect(256),
    GaussianNoise(),
], additional_targets={
    'image_1': 'image',
})


test_transform = A.Compose([
    A.Resize(height=CONFIGURATION['image_size'], width=CONFIGURATION['image_size'])
], additional_targets={
    'image_1': 'image',
})

triplet_concat_train_dataset = TripletDatasetConcatV2(data_dir_1=CONFIGURATION['data_dir_1'], type_1=CONFIGURATION['type_1'], data_dir_2=CONFIGURATION['data_dir_2'], type_2=CONFIGURATION['type_2'], transform=train_transform, train=True)
triplet_concat_test_dataset = TripletDatasetConcatV2(data_dir_1=CONFIGURATION['data_dir_1'], type_1=CONFIGURATION['type_1'], data_dir_2=CONFIGURATION['data_dir_2'], type_2=CONFIGURATION['type_2'], transform=test_transform, train=False)


triplet_concat_train_loader = DataLoader(
    triplet_concat_train_dataset, 
    batch_size=CONFIGURATION['batch_size'], 
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)


triplet_concat_test_loader = DataLoader(
    triplet_concat_test_dataset, 
    batch_size=CONFIGURATION['batch_size'], 
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)


roc_concat_train_dataset = CustomExrDatasetConCatV2(data_dir_1=CONFIGURATION['data_dir_1'], type_1=CONFIGURATION['type_1'], data_dir_2=CONFIGURATION['data_dir_2'], type_2=CONFIGURATION['type_2'], transform=train_transform, train=True)
roc_concat_train_loader = DataLoader(
    roc_concat_train_dataset,
    num_workers=4,
    batch_size=CONFIGURATION['batch_size'],
)

roc_concat_test_dataset = CustomExrDatasetConCatV2(data_dir_1=CONFIGURATION['data_dir_1'], type_1=CONFIGURATION['type_1'], data_dir_2=CONFIGURATION['data_dir_2'], type_2=CONFIGURATION['type_2'], transform=test_transform, train=False)
roc_test_loader = DataLoader(
    roc_concat_test_dataset,
    num_workers=4,
    batch_size=CONFIGURATION['batch_size'],
)


checkpoint_1 = torch.load(CONFIGURATION['checkpoint_1'])
new_checkpoint_1_state_dict = {}
for key, value in checkpoint_1['model_state_dict'].items():
    new_key = key[14:]  # Cắt bỏ 14 ký tự đầu tiên
    new_checkpoint_1_state_dict[new_key] = value
    
checkpoint_2 = torch.load(CONFIGURATION['checkpoint_2'])
new_checkpoint_2_state_dict = {}
for key, value in checkpoint_2['model_state_dict'].items():
    new_key = key[14:]  # Cắt bỏ 14 ký tự đầu tiên
    new_checkpoint_2_state_dict[new_key] = value
    
embedding_net = EmbeddingNetConcatV2(conf=CONFIGURATION)

embedding_net.resnet1.load_state_dict(new_checkpoint_1_state_dict)

embedding_net.resnet2.load_state_dict(new_checkpoint_2_state_dict)

model = TripletNetConcatV2(embedding_net).to(device)

# Freeze 3 mạng lại
for param in model.embedding_net.resnet1.parameters():
    param.requires_grad = False
    
for param in model.embedding_net.resnet2.parameters():
    param.requires_grad = False

    
criterion = TripletLoss(CONFIGURATION['margin'])
optimizer = Adam(model.parameters(), lr=CONFIGURATION['start_lr'])
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=1, eta_min=1e-6)

earlystop_dir = os.path.abspath(CONFIGURATION['checkpoint_dir'] + CONFIGURATION['type'] + '/models')

early_stopping = MultiMetricEarlyStopping(
    monitor_keys=['cosine_auc', 'euclidean_auc'],
    patience=1000,
    mode='max',
    verbose=0,
    save_dir=earlystop_dir,
    start_from_epoch=0
)      

checkpoint_path = os.path.abspath(CONFIGURATION['checkpoint_dir'] + CONFIGURATION['type'] + '/models/checkpoint.pth')
modle_checkpoint = ModelCheckpoint(filepath=checkpoint_path, verbose=1)

In [3]:
fit(
    conf=CONFIGURATION,
    start_epoch=0,
    model=model,
    triplet_train_loader=triplet_concat_train_loader, 
    triplet_test_loader=triplet_concat_test_loader, 
    criterion=criterion,
    optimizer=optimizer, 
    scheduler=scheduler, 
    epochs=CONFIGURATION['epochs'], 
    device=device,
    roc_train_loader=roc_concat_train_loader, 
    roc_test_loader=roc_test_loader,
    early_stopping=early_stopping,
    model_checkpoint=modle_checkpoint
)

Epoch 1:
	train: loss: 0.1111 | auc_cos: 0.9956 | auc_eu: 0.9976
	test: loss: 0.9561 | auc_cos: 0.9614 | auc_eu: 0.9551
[36m	Saving model and optimizer state to /media/vohoang/WorkSpace/ubuntu/projects/in-process/concat/checkpoint/new/concat_2/0/models/checkpoint.pth[0m
Epoch 2:
	train: loss: 0.0741 | auc_cos: 0.9954 | auc_eu: 0.9979
	test: loss: 0.8646 | auc_cos: 0.9613 | auc_eu: 0.9547
[36m	Saving model and optimizer state to /media/vohoang/WorkSpace/ubuntu/projects/in-process/concat/checkpoint/new/concat_2/0/models/checkpoint.pth[0m
Epoch 3:
	train: loss: 0.0112 | auc_cos: 0.9949 | auc_eu: 0.9971
	test: loss: 0.9214 | auc_cos: 0.9612 | auc_eu: 0.9545
[36m	Saving model and optimizer state to /media/vohoang/WorkSpace/ubuntu/projects/in-process/concat/checkpoint/new/concat_2/0/models/checkpoint.pth[0m
Epoch 4:
	train: loss: 0.0090 | auc_cos: 0.9944 | auc_eu: 0.9973
	test: loss: 1.0737 | auc_cos: 0.9585 | auc_eu: 0.9505
[36m	Saving model and optimizer state to /media/vohoang/Work

In [7]:
!zip -r output.zip checkpoint/new/concat_2/0

  adding: checkpoint/new/concat/ (stored 0%)
  adding: checkpoint/new/concat/logs/ (stored 0%)
  adding: checkpoint/new/concat/logs/Cosine_auc_test/ (stored 0%)
  adding: checkpoint/new/concat/logs/Cosine_auc_test/events.out.tfevents.1734804322.3c7a2fe95029.21.4 (deflated 63%)
  adding: checkpoint/new/concat/logs/Loss_train/ (stored 0%)
  adding: checkpoint/new/concat/logs/Loss_train/events.out.tfevents.1734804322.3c7a2fe95029.21.1 (deflated 58%)
  adding: checkpoint/new/concat/logs/Cosine_auc_train/ (stored 0%)
  adding: checkpoint/new/concat/logs/Cosine_auc_train/events.out.tfevents.1734804322.3c7a2fe95029.21.3 (deflated 63%)
  adding: checkpoint/new/concat/logs/events.out.tfevents.1734804214.3c7a2fe95029.21.0 (deflated 9%)
  adding: checkpoint/new/concat/logs/Euclidean_auc_test/ (stored 0%)
  adding: checkpoint/new/concat/logs/Euclidean_auc_test/events.out.tfevents.1734804322.3c7a2fe95029.21.6 (deflated 64%)
  adding: checkpoint/new/concat/logs/Loss_test/ (stored 0%)
  a