# Setup

In [24]:
%matplotlib inline

import os
import torch
import torch.nn as nn
import numpy as np
from imgaug import augmenters as iaa
import torchvision
import random
import PIL.Image as Image
import cv2
import math

# Setting Parameters

- 학습 시 Augmentation 용

In [25]:
# Agumentation setting of training dataset 
train_seq = iaa.Sequential([
    iaa.GaussianBlur(sigma=(0, 0.3)),
    iaa.Dropout((0.01, 0.15), per_channel=0.5),
    iaa.Affine(
#         scale={"x": (0.95, 1.05), "y": (0.95, 1.05)},
        translate_percent={"x": (-0.15, 0.2), "y": (-0.15, 0.2)},
        order=[0, 1],
        cval=1
    )
], random_order=True) 


In [26]:
class FaceDataset(torch.utils.data.Dataset):
    def __init__(self, meta, degree):
        with open(meta, 'r') as fin:
            self.x = [x for x in fin]

        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.RandomRotation(degree),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))
        ])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        entry = self.x[idx].split('|')
        
        # for jiwon server setting

        file = Image.fromarray(cv2.imread(entry[0]))
        file = self.transform(file)

        n_id = entry[1]

        return file, np.array(n_id, dtype=np.int64)


In [27]:
from torchvision.models import resnet18
from pretrainedmodels import xception



class FaceNet(torch.nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        
        print(num_classes)
        self.basenet = resnet18(num_classes=num_classes) # feature 차원 갯수 -> 이후 작업 

    def forward(self, x):
        x = self.basenet(x)
        x = torch.nn.functional.normalize(x)
        return x
    
class FaceNetX(torch.nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        
        print(num_classes)
        self.pretrained_model = xception(pretrained='imagenet')
        self.fc_input = 2048
        self.activation = nn.ReLU()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.)
        self.last_linear = nn.Linear(self.fc_input, num_classes)

    def forward(self, x):
        x = self.pretrained_model.features(x)
        x = self.activation(x)
        x = self.pooling(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.last_linear(x)
        x = torch.nn.functional.normalize(x)
        return x
    
class FaceCentroids(torch.nn.Module): # 무게중심 
    def __init__(self, n_ids, n_dim=16):
        super().__init__()

        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (n_ids, n_dim)))

    def forward(self, x):
        out = torch.nn.functional.linear(
            torch.nn.functional.normalize(x), 
            torch.nn.functional.normalize(self.weight))

        return out
    
    
class ArcFace(torch.nn.Module):
    def __init__(self, sp=64.0, sn=64.0, m=0.5, **kwargs):
        super(ArcFace, self).__init__()
        self.sp = sp / sn  # sn will be multiplied again
        self.sn = sn
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

    def forward(self, cosine: torch.Tensor, label):
        cosine = cosine.clamp(-1, 1)  # for numerical stability
        index = torch.where(label != -1)[0]
       
        cos_theta = cosine[index]
        target_logit = cos_theta[torch.arange(0, cos_theta.size(0)), label[index]].view(
            -1, 1
        )
        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = (target_logit * self.cos_m - sin_theta * self.sin_m).to(
            cosine.dtype
        )  # cos(target+margin)

        cosine[index] = cosine[index].scatter(
            1, label[index, None], cos_theta_m * self.sp
        )
        cosine.mul_(self.sn)

        return cosine


# Trainining the model

- Pretrained 모델로 전이학습을 수행하기 때문에 아래 모델을 직접 학습하지 않고, Synthetic Data로 미리 학습된 모델을 전이학습으로 학습시킴

- 아래의 주석은 모델의 정의를 위해 남겨둔 코드

In [28]:

c = 2

# network selecting - ST/Resnet18
net = FaceNet(16*c)
net = FaceNetX(16*c)

net = net.to('cuda:0')

fc = FaceCentroids(1002, 16*c)
fc = fc.to('cuda:0')

loss_arcface = ArcFace(m=0.2)

optim = torch.optim.Adam(net.parameters(), lr=0.001)
batch = 32
degree = 20
    
train_data = 'fvc'

if train_data == 'fvc':
    trainset = FaceDataset('real_meta.txt', 0)
    testset = FaceDataset('real_eval_meta.txt', 0)
    
dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch, shuffle=True, drop_last=True, num_workers=5)
dataloader_val = torch.utils.data.DataLoader(testset, batch_size=batch, shuffle=False, drop_last=False, num_workers=5)

for epoch in range(100):
    net = net.train()
    loss_accum = []

    for idx, (img, lbl) in enumerate(dataloader):
        img = img.to('cuda:0')
        lbl = lbl.to('cuda:0')

        feat = net(img)
        logit = fc(feat)
        # print(logit)
        logit = loss_arcface(logit, lbl)
        # print(logit)
        loss = torch.nn.functional.cross_entropy(logit, lbl)
        loss_accum.append(loss.item())

        loss.backward()
        optim.step()
        optim.zero_grad()
        
        if idx % 4 == 0:
            print('.', end='')

    loss_accum = torch.tensor(loss_accum)
    print(f'epoch: {epoch} | loss: {loss_accum.mean().item():.04f}')

    net = net.eval()
    
    eval_results = []
    lbls = []

    with torch.no_grad():
        for idx, (img, lbl) in enumerate(dataloader_val):
            img = img.to('cuda:0')
            feat = net(img)
         
            eval_results.append(feat.to('cpu'))
            lbls.append(lbl)
           

        eval_results = torch.cat(eval_results)
        mat_similarity = eval_results.matmul(eval_results.T)

        lbls = torch.cat(lbls)
        lbls = lbls.view(-1, lbls.size(0)) == lbls.view(lbls.size(0), -1)

        accuracy = []

        total_comp = torch.ones_like(mat_similarity).triu(1)
        total_comp = total_comp.sum().item()

        for threshold in [0.0, 0.2, 0.4, 0.6, 0.8]:
            threshed = mat_similarity > threshold
            
            #remove diagonal
            correct = (threshed == lbls).triu(1).sum()

            accuracy.append(correct / total_comp)
            
        
        print(f'Accuracy: {" | ".join(f"{acc:.03f}" for acc in accuracy)}')

    print('=' * 20)

32
32
.......................................................................epoch: 0 | loss: 32.2711
Accuracy: 0.163 | 0.353 | 0.584 | 0.807 | 0.958
.......................................................................epoch: 1 | loss: 17.9724
Accuracy: 0.430 | 0.769 | 0.947 | 0.994 | 1.000
.......................................................................epoch: 2 | loss: 3.7085
Accuracy: 0.469 | 0.828 | 0.975 | 0.999 | 1.000
.......................................................................epoch: 3 | loss: 0.7608
Accuracy: 0.462 | 0.828 | 0.977 | 0.999 | 1.000
......................................

KeyboardInterrupt: 

In [29]:
for idx, (img, lbl) in enumerate(dataloader):
    img = img.to('cuda:0')
    lbl = lbl.to('cuda:0')

    feat = net(img)
    print(feat.shape)
    logit = fc(feat)
    print(logit.shape)
    logit = loss_arcface(logit, lbl)
    print(logit.shape)
    loss = torch.nn.functional.cross_entropy(logit, lbl)
    break

torch.Size([32, 32])
torch.Size([32, 1002])
torch.Size([32, 1002])


In [31]:
result_path = os.path.join("/media/data2/jiwon/DeepfakeAD")
os.makedirs(result_path, exist_ok=True)

snapshot = {
    "model_dict": net.state_dict(),
}
torch.save(snapshot, os.path.join(result_path, f'recog_model.pt'))