<a href="https://www.kaggle.com/code/junhyeonkwon/deepfake-detection-revalid-lightning?scriptVersionId=171820440" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# Install PyTorch Lightning and the Weights & Biases client (both imported below).
!pip install lightning
!pip install wandb

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient
# Log in to Weights & Biases with the API key stored in the Kaggle secret "wandb".
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")
wandb.login(key=secret_value_0)

In [None]:
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import os
import math
import random
from glob import glob
from datetime import datetime

import torch
import torch.nn as nn
from torchvision import datasets, disable_beta_transforms_warning
from torchvision.transforms import v2
from torchvision.transforms import RandomErasing
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryConfusionMatrix

from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger

# Silence torchvision's warning about using the beta transforms.v2 API.
disable_beta_transforms_warning()

# Global random seed; workers=True also configures DataLoader worker seeding
# (per Lightning's seed_everything documentation).
RANDOM_SEED = 42
seed_everything(seed=RANDOM_SEED,workers=True)

# Transforms ===============================================

class GuidedRandomErasing(RandomErasing):
    """RandomErasing whose rectangle centre follows a normal distribution.

    Behaves like ``torchvision.transforms.RandomErasing`` except that the
    centre of the erased rectangle is not uniform over the image. For an
    image of size (H, W) the centre (c_y, c_x), in pixels, is sampled as::

        c_y = m[0] * H + N(0, s[0]) * H / 10
        c_x = m[1] * W + N(0, s[1]) * W / 10

    i.e. ``m`` is the mean centre position as a fraction of the image size
    and ``s`` is the standard deviation in units of one tenth of the image
    size (equivalently ``x = (c - m * size) * 10 / size ~ N(0, s)``, so the
    image spans roughly x in [-10 * m, 10 * (1 - m)]).

    A rectangle that extends past the border is clipped to the image; if
    nothing is left after clipping, a new rectangle is sampled (at most 10
    attempts in total, as in torchvision).

    Args:
        p, scale, ratio, value, inplace: same as ``RandomErasing``.
        m: (mean_y, mean_x), fractions of (height, width).
        s: (std_y, std_x), in units of size / 10.
    """

    def __init__(self,
                 p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0,
                 m=(0.5, 0.5), s=(1, 1), inplace=False):
        super().__init__(p, scale, ratio, value, inplace)
        # Store as float tensors so integer tuples (e.g. the default
        # s=(1, 1)) are valid std/mean inputs for the normal sampling below.
        self.m = torch.tensor(m, dtype=torch.float32)
        self.s = torch.tensor(s, dtype=torch.float32)

    @staticmethod
    def get_erase_params(img, scale, ratio, mean, std, value=None):
        """Sample the ``(i, j, h, w, v)`` arguments for ``F.erase``.

        Args:
            img: image tensor with channels at dim -3 and (H, W) in the last
                two dims.
            scale: (min, max) erased-area fraction.
            ratio: (min, max) aspect ratio, sampled log-uniformly.
            mean: float tensor (mean_y, mean_x); see the class docstring.
            std: float tensor (std_y, std_x); see the class docstring.
            value: ``None`` for per-pixel N(0, 1) noise, otherwise a list of
                1 or C fill values.

        Returns:
            (i, j, h, w, v): top row, left column, height, width and fill
            value. If no valid rectangle is found in 10 attempts, returns
            ``(0, 0, H, W, img)``, which makes ``F.erase`` a no-op.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w
        img_hw = torch.tensor((img_h, img_w))

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            # Sample the rectangle centre, then shift to its top-left corner.
            i_j = torch.normal(mean=torch.zeros(2), std=std) * img_hw / 10
            i_j = i_j + mean * img_hw - torch.tensor((h, w)) / 2
            i, j = int(i_j[0].item()), int(i_j[1].item())

            # Clip the rectangle to the image. The erased slice is the
            # half-open range [i, i + h), so i + h == img_h is still in bounds.
            if i < 0:
                h += i
                i = 0
            elif i + h > img_h:
                h = img_h - i
            if j < 0:
                w += j
                j = 0
            elif j + w > img_w:
                w = img_w - j
            # Centre too far outside the image: nothing left, resample.
            if h <= 0 or w <= 0:
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]
            return i, j, h, w, v

        # Same fallback as torchvision: "erase" the whole image with itself.
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """Erase one guided random rectangle of ``img`` with probability ``p``.

        Args:
            img (Tensor): image tensor to be erased.

        Returns:
            Tensor: the erased image, or ``img`` itself when not applied.

        Raises:
            ValueError: if ``value`` is a sequence whose length is neither 1
                nor the number of channels.
        """
        if not torch.rand(1) < self.p:
            return img

        # Normalise self.value into what get_erase_params expects
        # (mirrors torchvision's RandomErasing.forward).
        if isinstance(self.value, (int, float)):
            value = [float(self.value)]
        elif isinstance(self.value, str):
            value = None  # 'random': per-pixel normal noise
        elif isinstance(self.value, (list, tuple)):
            value = [float(v) for v in self.value]
        else:
            value = self.value

        if value is not None and len(value) not in (1, img.shape[-3]):
            raise ValueError(
                "If value is a sequence, it should have either a single value or "
                f"{img.shape[-3]} (number of input channels)"
            )

        i, j, h, w, v = GuidedRandomErasing.get_erase_params(
            img, scale=self.scale, ratio=self.ratio,
            mean=self.m, std=self.s, value=value)
        return F.erase(img, i, j, h, w, v, self.inplace)

    def __repr__(self):
        base = super().__repr__()
        # Splice mean/std in before the closing parenthesis of the base repr.
        return f"{base[:-1]}, mean={self.m}, std={self.s})"
    
# transform loader
def get_valid_transform_RE(image_size, precision='32-true', random_erase=None):
    """Build the validation transform, optionally with random-erasing layers.

    Pipeline: PIL image -> tensor, resize to (image_size, image_size), each
    requested erasing layer in order, dtype conversion, then per-channel
    normalisation with mean 0.5 and std 0.5.

    Args:
        image_size (int): output height and width.
        precision (str): Lightning-style precision string such as '32-true',
            '16-mixed', 'bf16-mixed' or '64-true'; selects the output dtype.
        random_erase (list[dict] | None): each dict has a 'type'
            ('RandomErasing' or 'GuidedRandomErasing', case-insensitive) and
            'kwargs' forwarded to that transform's constructor. Unknown types
            are skipped with a printed message.

    Returns:
        v2.Compose: the composed transform.

    Raises:
        NotImplementedError: if ``precision`` names no supported dtype.
    """
    if random_erase is None:
        random_erase = []

    # Check 'bf16' before '16': 'bf16-mixed' / 'bf16-true' also contain '16'.
    if 'bf16' in precision:
        dtype = torch.bfloat16
    elif '16' in precision:
        dtype = torch.float16
    elif '32' in precision:
        dtype = torch.float32
    elif '64' in precision:
        dtype = torch.float64
    else:
        raise NotImplementedError(f'{precision} not implemented')

    tr_list = [
        v2.PILToTensor(),
        v2.Resize((image_size, image_size), antialias=True)]

    # NOTE(review): erasing runs on the uint8 tensor from PILToTensor, before
    # ConvertImageDtype/Normalize. With value='random' the float32 N(0, 1)
    # noise is therefore written into a uint8 image (presumably cast on
    # assignment) rather than acting as noise in normalised space. Confirm
    # this matches the training pipeline before reordering.
    eraser_types = {
        'randomerasing': v2.RandomErasing,
        'guidedrandomerasing': GuidedRandomErasing,
    }
    for re_layer in random_erase:
        eraser_cls = eraser_types.get(re_layer['type'].lower())
        if eraser_cls is None:
            print(f"{re_layer['type']} not implemented. Skipping {re_layer['type']}")
            continue
        tr_list.append(eraser_cls(**re_layer['kwargs']))

    valid_transform = v2.Compose([
        *tr_list,
        v2.ConvertImageDtype(dtype=dtype),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    return valid_transform

# Datasets =======================================================

class MyDataset(Dataset):
    """Image dataset backed by a DataFrame with 'path' and 'label' columns.

    Each item is ``(image, label)``. The image is opened with PIL from
    ``os.path.join(root, path)`` and passed through ``transform`` when one is
    given. The label is a 1-element tensor: 0 for 'FAKE', 1 for any other
    label value.
    """

    def __init__(self, root, df, transform):
        super().__init__()
        self.root = root
        self.paths = list(df['path'])
        self.labels = list(df['label'])
        self.transform = transform

    def __repr__(self):
        out = super().__repr__()
        return f"{out} of size {len(self)}\n{self.transform!r}"

    def __len__(self):
        return len(self.paths)

    def _label_dtype(self):
        """Return the dtype of the transform's last ConvertImageDtype, else float32.

        Labels get the same dtype as the converted images (presumably so the
        loss sees matching dtypes). Works for ``transform=None`` and for
        transforms that have no ``.transforms`` list.
        """
        dtype = torch.float32
        for t in getattr(self.transform, 'transforms', ()):
            if isinstance(t, v2.ConvertImageDtype):
                dtype = t.dtype
        return dtype

    def __getitem__(self, idx):
        path = self.paths[idx]
        dtype = self._label_dtype()
        if self.labels[idx] == 'FAKE':
            label = torch.zeros(1, dtype=dtype)
        else:
            label = torch.ones(1, dtype=dtype)

        img = Image.open(os.path.join(self.root, path))
        if self.transform:
            img = self.transform(img)

        return img, label

# helper function =========================================================

# model loader
def get_model_from_wandb(logger, artifact_name, device, save_dir_root='/kaggle/working'):
    """Download a model artifact from W&B and load it as a LitEfficientNet.

    The artifact is downloaded into ``{save_dir_root}/{name}``, where
    ``name`` is the third '/'-separated component of ``artifact_name``
    (for 'entity/project/model-xxx:v0' that is 'model-xxx:v0'). The
    lexicographically last ``*.ckpt`` file found there is loaded.

    Args:
        logger: object exposing ``download_artifact(artifact, save_dir=...,
            artifact_type=...)``, e.g. a ``WandbLogger``.
        artifact_name (str): W&B artifact path 'entity/project/name:version'.
        device: ``map_location`` used when loading the checkpoint.
        save_dir_root (str): directory under which the artifact is stored.

    Returns:
        LitEfficientNet: the loaded model.

    Raises:
        FileNotFoundError: if the downloaded artifact contains no ``.ckpt`` file.
    """
    run_id = artifact_name.split('/')[2]
    save_dir = os.path.join(save_dir_root, run_id)
    logger.download_artifact(artifact_name, save_dir=save_dir, artifact_type='model')
    ckpt_paths = sorted(glob(os.path.join(save_dir, '*.ckpt')))
    if not ckpt_paths:
        raise FileNotFoundError(
            f"No .ckpt file found in '{save_dir}' for artifact '{artifact_name}'")
    return LitEfficientNet.load_from_checkpoint(ckpt_paths[-1], map_location=device)

# Lightning Module =========================================================

class LitEfficientNet(L.LightningModule):
    """EfficientNet-V2-S single-logit binary classifier (validation only here).

    The torchvision EfficientNet-V2-S backbone keeps its pretrained weights;
    its last classifier layer is replaced by ``Linear(1280, 1)``, so the model
    outputs one raw logit per image. Targets are expected as float tensors
    shaped like the logits (MyDataset in this file produces 0 for FAKE and
    1 otherwise).
    """

    def __init__(self):
        super().__init__()

        # Pass an enum *member*: torchvision's WeightsEnum.verify raises
        # TypeError when given the EfficientNet_V2_S_Weights class itself.
        # DEFAULT is the ImageNet-1K weights. These are overwritten by the
        # checkpoint in load_from_checkpoint.
        self.model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)
        self.model.classifier[1] = nn.Linear(in_features=1280, out_features=1, bias=True)

        self.loss_fn = nn.BCEWithLogitsLoss()
        self.acc_fn = BinaryAccuracy()
        self.f1_score = BinaryF1Score()
        self.cm_fn = BinaryConfusionMatrix()
        # Per-epoch buffers of logits and targets, consumed in on_validation_epoch_end.
        self.val_step_preds = []
        self.val_step_ys = []

    def validation_step(self, batch, batch_idx):
        """Log per-batch loss, accuracy and F1; buffer logits/targets for epoch metrics."""
        x, y = batch
        pred = self.model(x)  # raw logits
        loss = self.loss_fn(pred, y)
        # torchmetrics only auto-applies sigmoid when some value lies outside
        # [0, 1]. Pass probabilities explicitly so a batch of small logits is
        # not mistaken for probabilities. Targets are cast to int, as the
        # metrics expect integer labels.
        prob = torch.sigmoid(pred)
        target = y.long()
        self.log('valid_loss', loss)
        self.log('valid_acc', self.acc_fn(prob, target))
        self.log('valid_f1', self.f1_score(prob, target))
        self.val_step_preds.append(pred)
        self.val_step_ys.append(y)
        if torch.isnan(loss):
            self.log("nan error", 1)
        return loss

    def on_validation_epoch_end(self):
        """Log F1, accuracy and confusion-matrix counts over the whole epoch."""
        if not self.val_step_preds:
            # No batches were run (e.g. empty loader); torch.cat([]) would raise.
            return
        prob = torch.sigmoid(torch.cat(self.val_step_preds)).flatten()
        target = torch.cat(self.val_step_ys).flatten().long()
        self.val_step_preds.clear()
        self.val_step_ys.clear()

        self.log('total_f1', self.f1_score(prob, target))
        self.log('total_acc', self.acc_fn(prob, target))

        # Binary confusion matrix layout: [[TN, FP], [FN, TP]].
        cm = self.cm_fn(prob, target)
        for key, val in zip(['TN', 'FP', 'FN', 'TP'], cm.flatten().tolist()):
            self.log(key, val)
            

In [None]:
# Load the face-crop metadata table.
FACE_DSET_META_PATH = '/kaggle/input/using-yunet/deepfake-detection-face-dataset.csv'
df = pd.read_csv(FACE_DSET_META_PATH)
df.head()
# Dataset summary:
#   columns: video, frame_id, path, label, split
#   73287 faces from 2448 videos (about 29 faces per video)
#   82% (60343) FAKE, 18% (12944) REAL
#   12 source splits: fake -> A1, A2, A3, ..., E2; real -> R1, R2


def _split_by_video(videos, frames, n_splits=4):
    """Partition ``frames`` into ``n_splits`` video-disjoint DataFrames.

    On step k a fraction 1 / (n_splits - k) of the remaining video names is
    sampled (seeded with RANDOM_SEED) and removed from ``videos`` in place;
    those videos' rows are left-merged from ``frames`` on 'video'. The last
    step therefore takes every remaining video.
    """
    groups = []
    for k in range(n_splits):
        picked = videos.sample(frac=1 / (n_splits - k), random_state=RANDOM_SEED)
        videos.drop(picked.index, inplace=True)
        groups.append(
            pd.merge(picked, frames, how='left', left_on='video', right_on='video'))
    return groups


# Four video-disjoint folds of REAL faces (all frames).
df_real = df[df['label'] == 'REAL']
df_real_vid = df_real['video'].drop_duplicates()
df_real_splits = _split_by_video(df_real_vid, df_real)

# Four video-disjoint folds of FAKE faces, keeping only frame_id < 8 per
# video (presumably to reduce the FAKE/REAL imbalance).
df_fake = df[(df['label'] == 'FAKE') & (df['frame_id'] < 8)]
df_fake_vid = df_fake['video'].drop_duplicates()
df_fake_splits = _split_by_video(df_fake_vid, df_fake)

# Fold k = FAKE fold k followed by REAL fold k.
df_splits = [
    pd.concat((fake_fold, real_fold))
    for fake_fold, real_fold in zip(df_fake_splits, df_real_splits)
]

In [None]:
# Expected erased-area fraction per image for the two erasing sets below
# (ignoring overlap and border clipping):
# - guided set: 4 layers, p=0.5 each, mean scale 0.02; k ~ Binomial(4, 0.5) fire:
#   0.5^4 * (1*0 + 4*0.02 + 6*0.04 + 4*0.06 + 1*0.08) = 0.04
# - original set: 2 layers, p=0.5 each, mean scale a = (0.02 + 0.06) / 2:
#   0.25 * (1*0 + 2*a + 1*2*a) = a = 0.04
# so both sets erase the same expected area.

# Runtime, logging and data settings for the validation runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger_config = {
    'logger_type' : 'wandblogger',
    'project' : 'Deepfake_Detection-lightning-RE',
    'log_model' : False,
}
data_config = {
    'root' : '/kaggle/input/using-yunet',
    'image_size' : 224,
    'batch_size' : 16,
    'num_workers' : 3
}
# Erasing layers for get_valid_transform_RE. The validation loop uses
# transform_layers[0:2] (plain RandomErasing) as set t0 and
# transform_layers[2:6] (GuidedRandomErasing) as set t1, so keep this order.
# For the guided layers, 'm' is the mean rectangle centre as a fraction of
# (height, width) and 's' its std in units of size / 10.
transform_layers = [
    { # original random erasing 1
        'type' : 'RandomErasing',
        'kwargs' : {
            'p' : 0.5,
            'scale' : (0.02,0.06),
            'ratio' : (0.3,3.0),
            'value' : 'random'
        }
    },
    { # original random erasing 2
        'type' : 'RandomErasing',
        'kwargs' : {
            'p' : 0.5,
            'scale' : (0.02,0.06),
            'ratio' : (0.3,3.0),
            'value' : 'random'
        }
    },
    { # left eye random erasing
        'type' : 'GuidedRandomErasing',
        'kwargs' : {
            'p' : 0.5,
            'scale' : (0.01,0.03),
            'ratio' : (0.5,1.3),
            'value' : 'random',
            'm' : (0.41,0.61),
            's' : (0.24,0.46)
        }
    },
    { # right eye random erasing
        'type' : 'GuidedRandomErasing',
        'kwargs' : {
            'p' : 0.5,
            'scale' : (0.01,0.03),
            'ratio' : (0.5,1.3),
            'value' : 'random',
            'm' : (0.41,0.38),
            's' : (0.24,0.43)
        }
    },
    { # nose random erasing
        'type' : 'GuidedRandomErasing',
        'kwargs' : {
            'p' : 0.5,
            'scale' : (0.01,0.03),
            'ratio' : (0.3,1.2),
            'value' : 'random',
            'm' : (0.54,0.5),
            's' : (0.24,0.7)
        }
    },
    { # lips random erasing
        'type' : 'GuidedRandomErasing',
        'kwargs' : {
            'p' : 0.5,
            'scale' : (0.01,0.03),
            'ratio' : (0.2,1),
            'value' : 'random',
            'm' : (0.66,0.5),
            's' : (0.22,0.48)
        }
    }
]
# W&B model artifacts to validate: 3 training groups x 4 folds. The order must
# match trained_group (the loop maps artifact index a_idx to group a_idx // 4);
# the '_fold_<k>' part of each name selects the held-out split.
artifacts = [
    # a. gb0.1 - needs to be re-run
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021253_fold_0:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021319_fold_1:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021345_fold_2:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021412_fold_3:v0',
    # b. gb0.1-re0.2:1
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240728_fold_0:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240757_fold_1:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240826_fold_2:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240854_fold_3:v0',
    # c. gb0.1-re0.2:2
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241057_fold_0:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241124_fold_1:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241150_fold_2:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241217_fold_3:v0'
]
# Group label for each consecutive block of 4 artifacts above (index a_idx // 4).
trained_group = [
    'gb0.1', 'gb0.1-re0.2:1', 'gb0.1-re0.2:2'
]

In [None]:
# iterate through given artifacts
for a_idx, artifact_id in enumerate(artifacts):
    # download artifact if not exist at /kaggle/working
    # init model
    run_id = artifact_id.split('/')[-1]
    fold_id = int(run_id.split(':')[0].split('_fold_')[-1])
    if not os.path.exists(f'/kaggle/working/artifacts/{run_id}'):
        model = get_model_from_wandb(
            WandbLogger(), artifact_id, device, save_dir_root = '/kaggle/working/artifacts')
    else:
        print(f"Artifact '{artifact_id}' already exists. Using existing file...")
        ckpt_path = sorted(glob(f"/kaggle/working/artifacts/{run_id}/*.ckpt"))[-1]
        model = LitEfficientNet.load_from_checkpoint(ckpt_path, map_location=device)
    
    # iterate through given transform_layers
    for t_id, t_layers in enumerate((transform_layers[0:2],transform_layers[2:6])):
        # prep validation loader
        valid_df = df_splits[fold_id]
        valid_dset = MyDataset(data_config['root'], valid_df,
                               get_valid_transform_RE(
                                   image_size = data_config['image_size'],
                                   random_erase = t_layers))
        valid_loader = DataLoader(dataset=valid_dset, 
                                  batch_size=data_config['batch_size'],
                                  num_workers=data_config['num_workers'],
                                  drop_last=True)
        print(valid_dset)
        
        # update logger exp config
        # init logger
        if logger_config['logger_type'].lower() == 'wandblogger':
            logger = WandbLogger(
                        project=logger_config['project'], 
                        name=f'fold_{fold_id}_t{t_id}',
                        id=f"{datetime.now().strftime('%y%m%d%H%M')}_f{fold_id}",
                        group=f"{trained_group[a_idx//4]}_t{t_id}", 
                        log_model=logger_config['log_model'])
        else:
            raise NotImplementedError('other logger types not implemented')
        logger.experiment.config.update({
            'Note' : artifact_id,
            'logger_config' : logger_config,
            'data_config' : data_config,
            'transform_layers' : t_layers
        })
        # model.validate? (log with wandb)
        trainer = Trainer(accelerator=device.type,logger=logger)
        trainer.validate(model=model, dataloaders=valid_loader)
        # wandb finish
        if logger_config['logger_type'].lower() == 'wandblogger':
            wandb.finish()