# 1. Baseline

## Libraries

In [None]:
import os
from pathlib import Path
from tqdm import tqdm
from easydict import EasyDict as edict

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import transforms as trans

from data.ms1m import get_train_loader
from data.lfw import LFW

from backbone.arcfacenet import SEResNet_IR
from margin.ArcMarginProduct import ArcMarginProduct 
from util.utils import save_checkpoint, test

## Configuration

In [None]:
# Single EasyDict holding every path and hyper-parameter used by this notebook.
conf = edict()

# --- data locations ---
conf.train_root = './dataset/MS1M'
conf.lfw_test_root = './dataset/lfw_aligned_112'
conf.lfw_file_list = './dataset/lfw_pair.txt'

# --- model / margin head ---
conf.mode = 'se_ir'  # alternative value: 'ir'
conf.depth = 50
conf.margin_type = 'ArcFace'
conf.feature_dim = 512  # output size of the backbone == input size of the margin head
conf.scale_size = 32.0

# --- optimisation schedule ---
conf.batch_size = 16
conf.lr = 0.01
conf.milestones = [8, 10, 12]  # epochs at which the training loop divides the LR by 10
conf.total_epoch = 14

# --- output / runtime ---
conf.save_folder = './saved'
# Per-architecture checkpoint directory, e.g. ./saved/se_ir_50
conf.save_dir = os.path.join(conf.save_folder, f'{conf.mode}_{conf.depth}')
conf.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
conf.num_workers = 4
conf.pin_memory = True

In [None]:
# Create the checkpoint directory (and any missing parents); an existing directory is not an error.
Path(conf.save_dir).mkdir(parents=True, exist_ok=True)

## Data Loader

In [None]:
# Preprocessing applied to LFW evaluation images.
transform = trans.Compose(
    [
        # image with values in [0, 255] -> float tensor in [0.0, 1.0]
        trans.ToTensor(),
        # per-channel (x - 0.5) / 0.5, i.e. [0.0, 1.0] -> [-1.0, 1.0]
        trans.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Build the MS1M training loader; the second return value is used below as the
# number of identity classes for the margin head.
(trainloader, class_num) = get_train_loader(conf)

In [None]:
# Report how many identities the training set contains.
print('number of id: {}'.format(class_num))

In [None]:
# Show the underlying training dataset's own summary.
print(str(trainloader.dataset))

In [None]:
# LFW verification set, evaluated after every epoch by test() in the training loop.
lfwdataset = LFW(conf.lfw_test_root, conf.lfw_file_list, transform=transform)
lfwloader = torch.utils.data.DataLoader(
    lfwdataset,
    batch_size=128,
    shuffle=False,  # DataLoader default, stated explicitly: keep pair order stable
    num_workers=conf.num_workers,
)

## Model

In [None]:
# Confirm which device training will run on.
print(str(conf.device))

In [None]:
# Backbone producing conf.feature_dim-sized features, moved to the training device.
net = SEResNet_IR(conf.depth, feature_dim=conf.feature_dim, mode=conf.mode)
net = net.to(conf.device)

# Margin head taking those features and producing class_num outputs for the loss.
# NOTE(review): conf.scale_size / conf.margin_type are not passed here; confirm the
# head's default scale matches the configured value.
margin = ArcMarginProduct(conf.feature_dim, class_num)
margin = margin.to(conf.device)

In [None]:
# Dump the backbone architecture.
print(str(net))

In [None]:
# Cross-entropy over the margin head's outputs, averaged over the batch (the default reduction).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# SGD with Nesterov momentum over two parameter groups (backbone, then margin head),
# each with the same L2 weight decay.
optimizer = optim.SGD(
    [{'params': module.parameters(), 'weight_decay': 5e-4} for module in (net, margin)],
    lr=conf.lr,
    momentum=0.9,
    nesterov=True,
)

In [None]:
# Show the optimizer's parameter-group settings.
print(str(optimizer))

In [None]:
def schedule_lr(opt=None, factor=10):
    """Divide the learning rate of every parameter group by ``factor``, then print the optimizer.

    Args:
        opt: optimizer whose ``param_groups`` are updated in place. Defaults to the
            notebook-level ``optimizer`` so existing ``schedule_lr()`` calls are unchanged.
        factor: divisor applied to each group's ``'lr'`` (default 10, a step decay).
    """
    if opt is None:
        opt = optimizer
    for group in opt.param_groups:
        group['lr'] /= factor
    print(opt)

## Train

In [None]:
# Main training loop: one pass over MS1M per epoch, then LFW evaluation and a checkpoint.
best_acc = 0

for epoch in range(1, conf.total_epoch + 1):
    net.train()
    margin.train()  # keep the margin head's mode in sync with the backbone

    print(f'epoch {epoch}/{conf.total_epoch}', flush = True)

    # Step decay: divide the LR once for every milestone equal to this epoch.
    # Works for any number of milestones (the old [0]/[1]/[2] indexing required exactly 3+).
    for _ in range(conf.milestones.count(epoch)):
        schedule_lr()

    running_loss = 0.0
    num_batches = 0
    for data in tqdm(trainloader):
        img, label = data[0].to(conf.device), data[1].to(conf.device)
        optimizer.zero_grad()

        features = net(img)                 # backbone output fed to the margin head
        output = margin(features, label)    # the head also receives the labels
        total_loss = criterion(output, label)
        total_loss.backward()
        optimizer.step()

        # detach() keeps the accumulator off the autograd graph; summing on-device
        # avoids a host sync every step (.item() is called once per epoch below).
        running_loss += total_loss.detach()
        num_batches += 1

    # Mean loss over the epoch (the old code printed only the last batch's loss and
    # raised NameError when the loader yielded nothing).
    train_loss = (running_loss / num_batches).item() if num_batches else float('nan')

    # LFW evaluation; assumes test() is inference-only, so no autograd graph is needed.
    net.eval()
    margin.eval()
    with torch.no_grad():
        lfw_acc = test(conf, net, lfwdataset, lfwloader)

    print(f'\nLFW : {lfw_acc} | train_loss (epoch mean) : {train_loss} \n')

    is_best = lfw_acc > best_acc
    best_acc = max(lfw_acc, best_acc)

    # Save this epoch's weights; is_best is forwarded so save_checkpoint can mark the
    # best LFW model so far (exact handling lives in util.utils — confirm there).
    save_checkpoint({
        'epoch': epoch,
        'net_state_dict': net.state_dict(),
        'margin_state_dict': margin.state_dict(),
        'best_acc': best_acc,
    }, is_best, checkpoint=conf.save_dir)

In [None]:
'''
SOTA : The state of the art

1. MS1M dataset'inin tamamının indirilmesi; alternatif olarak CASIA kullanılabilir. Ardından aşağıdaki parametreler değiştirilebilir:
2. conf.mode = 'ir'
3. conf.depth = 100
4. conf.total_epoch = 20
5. conf.milestones = [12,16,18]

lfw = 99.83%

# 2 adet V100 (32 GB) ile eğitim ~5 gün sürüyor

Eğitilen model bir cihaz üzerinde (mobil ya da kamera gibi) çalıştırılmak istenirse: MobileFaceNet araştırması yap
'''