<a href="https://www.kaggle.com/code/junhyeonkwon/deepfake-detection-lightning?scriptVersionId=171821718" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install lightning
!pip install wandb

Collecting lightning
  Downloading lightning-2.2.1-py3-none-any.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading lightning-2.2.1-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m37.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: lightning
Successfully installed lightning-2.2.1


In [2]:
import wandb
from kaggle_secrets import UserSecretsClient
# Fetch the W&B API key stored under the Kaggle secret named "wandb" and log in
# with it, so the key never appears in the notebook itself.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import os
import random
from glob import glob
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, OneCycleLR
from torchvision import datasets, disable_beta_transforms_warning
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import BinaryAccuracy

from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint

disable_beta_transforms_warning()

# Single seed for the whole notebook: passed to Lightning's seed_everything here
# and reused as `random_state` for DataFrame.sample when the CV folds are built.
RANDOM_SEED = 42
seed_everything(seed=RANDOM_SEED,workers=True)

# Transforms ===============================================

def get_train_transform(
    image_size, precision='32-true', 
    gaussian_blur = None, random_erase = None):
    """Build the training-time augmentation pipeline.

    Args:
        image_size: side length of the square the image is resized to.
        precision: Lightning-style precision string; the first of '16', '32',
            '64' found in it selects float16 / float32 / float64.
        gaussian_blur: optional kwargs dict for ``v2.GaussianBlur``.
        random_erase: optional kwargs dict for ``v2.RandomErasing``; an extra
            'stack_layer' key (default 1) says how many erasing layers to stack.

    Returns:
        A ``v2.Compose`` of the assembled transforms.

    Raises:
        NotImplementedError: if ``precision`` names none of the known widths.
    """
    # substring match, tried in the order 16 -> 32 -> 64
    for width, candidate in (('16', torch.float16),
                             ('32', torch.float32),
                             ('64', torch.float64)):
        if width in precision:
            dtype = candidate
            break
    else:
        raise NotImplementedError(f'{precision} not implemented')

    steps = [
        v2.PILToTensor(),
        v2.Resize((image_size, image_size), antialias=True),
    ]
    if gaussian_blur:
        steps.append(v2.GaussianBlur(**gaussian_blur))
    steps.append(v2.RandomAdjustSharpness(sharpness_factor=2, p=0.5))
    if random_erase:
        # work on a copy so the caller's config dict keeps its 'stack_layer' key
        erase_kwargs = random_erase.copy()
        n_layers = erase_kwargs.pop('stack_layer', 1)
        steps.extend(v2.RandomErasing(**erase_kwargs) for _ in range(n_layers))
    steps += [
        v2.RandomHorizontalFlip(p=0.5),
        v2.ConvertImageDtype(dtype=dtype),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return v2.Compose(steps)

def get_valid_transform(image_size, precision='32-true'):
    """Build the deterministic validation pipeline (no augmentation).

    Args:
        image_size: side length of the square the image is resized to.
        precision: Lightning-style precision string; the first of '16', '32',
            '64' found in it selects float16 / float32 / float64.

    Returns:
        ``v2.Compose`` of PILToTensor -> Resize -> ConvertImageDtype -> Normalize.

    Raises:
        NotImplementedError: if ``precision`` names none of the known widths.
    """
    width_to_dtype = (
        ('16', torch.float16),
        ('32', torch.float32),
        ('64', torch.float64),
    )
    # first substring hit wins, same priority order as the table above
    dtype = next((dt for tag, dt in width_to_dtype if tag in precision), None)
    if dtype is None:
        raise NotImplementedError(f'{precision} not implemented')

    resize = v2.Resize((image_size, image_size), antialias=True)
    to_float = v2.ConvertImageDtype(dtype=dtype)
    normalize = v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return v2.Compose([v2.PILToTensor(), resize, to_float, normalize])

# Datasets ===========================================================

class MyDataset(Dataset):
    """Image dataset driven by a metadata DataFrame.

    Each item is ``(image, label)``. The label is a 1-element tensor holding
    0 when the row's ``label`` column equals ``'FAKE'`` and 1 otherwise.

    Args:
        root: directory that the relative ``path`` entries are joined onto.
        df: DataFrame providing at least the ``path`` and ``label`` columns.
        transform: callable applied to the loaded PIL image, or ``None`` to
            return the PIL image untouched.
    """
    def __init__(self, root, df, transform):
        # Plain super(): `super(MyDataset)` creates an unbound super object and
        # never reaches Dataset.__init__.
        super().__init__()
        self.root = root
        self.paths = list(df['path'])
        self.labels = list(df['label'])
        self.transform = transform
        
    def __repr__(self):
        out = super().__repr__()
        return f"{out} of size {len(self)}\n{self.transform!r}"
    
    def __len__(self):
        return len(self.paths)

    def _label_dtype(self):
        """Return the dtype labels should use.

        Follows the last ``ConvertImageDtype`` step of a composed transform so
        labels match the image dtype; falls back to float32 when the transform
        is ``None`` or exposes no ``transforms`` list.
        """
        dtype = torch.float32
        for t in getattr(self.transform, 'transforms', None) or ():
            if isinstance(t, v2.ConvertImageDtype):
                dtype = t.dtype
        return dtype
    
    def __getitem__(self, idx):
        label_value = 0.0 if self.labels[idx] == 'FAKE' else 1.0
        label = torch.full((1,), label_value, dtype=self._label_dtype())

        # Load inside a context manager so the file handle is closed right
        # away, and force 3 channels because Normalize uses 3-channel stats.
        with Image.open(os.path.join(self.root, self.paths[idx])) as fh:
            img = fh.convert('RGB')
        
        if self.transform:
            img = self.transform(img)
        
        return img, label

#===============================================
# helper function
def get_ckpt_path(fold_id, wb_project=None):
    """Locate a checkpoint file for the given fold.

    Args:
        fold_id: index of the CV fold.
        wb_project: W&B project name when runs were logged with WandbLogger;
            leave as None for the CSVLogger directory layout.

    Returns:
        The last matching ``.ckpt`` path in lexicographic order, or None when
        nothing matches.
    """
    if wb_project:  # using wandb
        pattern = f"/kaggle/working/ckpts/{wb_project}/*_fold_{fold_id}/checkpoints/*.ckpt"
    else:  # using CSVLogger
        pattern = f"/kaggle/working/ckpts/fold_{fold_id}/*/checkpoints/*.ckpt"
    matches = sorted(glob(pattern))
    return matches[-1] if matches else None
    
#===============================================
# Lightning Module

class LitEfficientNet(L.LightningModule):
    """EfficientNetV2-S with a single-logit head, trained with BCE-with-logits.

    The training and validation steps log loss and accuracy under the keys
    'train_loss' / 'train_acc' and 'valid_loss' / 'valid_acc'. When the loss
    is NaN they additionally log 'nan error'.
    """
    def __init__(self, model_config):
        '''
        parameters:
        model_config: dict read by configure_optimizers; needs the keys
            'learning_rate', 'optim_config' and 'scheduler_config'
        '''
        super().__init__()
        self.save_hyperparameters()
        # Pass an explicit enum member. Handing over the enum class itself only
        # works through torchvision's deprecated "pretrained=<truthy>" fallback.
        self.model = efficientnet_v2_s(
            weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        # swap the ImageNet classifier for a single-logit binary head
        head = self.model.classifier[1]
        self.model.classifier[1] = \
            nn.Linear(in_features=head.in_features, out_features=1, bias=True)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.acc_fn = BinaryAccuracy() # TF threshold is 0.5

    def _shared_step(self, batch, stage):
        """Run forward + loss + accuracy and log them with the `stage` prefix."""
        x, y = batch
        pred = self.model(x)
        loss = self.loss_fn(pred, y)
        self.log(f"{stage}_loss", loss)
        acc = self.acc_fn(pred, y)
        self.log(f"{stage}_acc", acc)
        if torch.isnan(loss):
            # marker metric so NaN batches are visible in the logger
            self.log("nan error", 1)
        return loss
        
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")
    
    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")
    
    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        """Create the SGD optimizer and OneCycleLR scheduler from model_config.

        Raises:
            NotImplementedError: for any optimizer other than SGD or any
                scheduler other than OneCycleLR.
        """
        # get params from config dict
        lr = self.hparams.model_config['learning_rate']
        optim_cfg = self.hparams.model_config['optim_config']
        sched_cfg = self.hparams.model_config['scheduler_config']
        # set optimizer
        if optim_cfg['optimizer'].lower() == 'sgd':
            optimizer = optim.SGD(
                self.parameters(),
                lr=lr,
                momentum=optim_cfg['momentum'],
                # previously dropped on the floor; 0.0 keeps old configs working
                weight_decay=optim_cfg.get('weight_decay', 0.0))
        else:
            raise NotImplementedError(
                'optimizers other than SGD not implemented')
        # set scheduler 
        if sched_cfg['scheduler'].lower() == 'onecyclelr':
            scheduler = OneCycleLR(
                optimizer,
                max_lr=lr*sched_cfg['max_lr_coef'],
                epochs=sched_cfg['epochs'],
                steps_per_epoch=sched_cfg['steps_per_epoch'])
        else:
            raise NotImplementedError(
                'schedulers other than onecyclelr not implemented')
            
        return {
            'optimizer':optimizer,
            'lr_scheduler':{
                "scheduler": scheduler,
                "interval" : sched_cfg['interval']
            }
        }
        

INFO: Seed set to 42


In [5]:
# Read the per-face metadata table; every row describes one cropped face image.
FACE_DSET_META_PATH = '/kaggle/input/using-yunet/deepfake-detection-face-dataset.csv'
df = pd.read_csv(FACE_DSET_META_PATH)
df.head()
# About the dataset
# columns : video, frame_id, path, label, split
#           (the CSV also carries a saved index, which shows up as 'Unnamed: 0')
# 73287 faces, 2448 videos, avg 29 faces per video
# 82% (60343) fakes, 18% (12944) reals
# 12 splits : fake > A1, A2, A3, ... , E2 ; real > R1, R2

Unnamed: 0,video,frame_id,path,label,split
0,aodrcrvodk.mp4,0,deepfake-detection-face-dataset/aodrcrvodk/00.jpg,FAKE,A1
1,aodrcrvodk.mp4,1,deepfake-detection-face-dataset/aodrcrvodk/01.jpg,FAKE,A1
2,aodrcrvodk.mp4,2,deepfake-detection-face-dataset/aodrcrvodk/02.jpg,FAKE,A1
3,aodrcrvodk.mp4,3,deepfake-detection-face-dataset/aodrcrvodk/03.jpg,FAKE,A1
4,aodrcrvodk.mp4,4,deepfake-detection-face-dataset/aodrcrvodk/04.jpg,FAKE,A1


In [None]:
# Experiment 1
# 15K Fakes + 13K Reals --> 4 Fold CV
# NOTE(review): the header says "Experiment 1" while logger_config['group']
# below is 'exp2-gblur_s0.1' — confirm which experiment this cell represents.

# hyperparameters
limit_fold = 4  # how many of the 4 folds to actually train
# Keyword arguments forwarded verbatim to lightning's Trainer(...)
trainer_config = {
    'accelerator' : 'auto',
    # NOTE(review): '16-true' presumably runs weights and activations in pure
    # fp16; possibly related to the NaN losses described in the notes — confirm.
    'precision' : '16-true',
    'max_epochs' : 6,
    'limit_train_batches' : 1.0,
    'limit_val_batches' : 1.0,
    'log_every_n_steps' : 32,
    'profiler' : None,
    'num_sanity_val_steps' : -1,
    # with batch_size 16 below this gives 64 samples per optimizer step
    'accumulate_grad_batches' : 4,
    'detect_anomaly'  :  False,
    'check_val_every_n_epoch' : 1,
    'deterministic' : True,
    'enable_checkpointing' : True,
    'gradient_clip_val' : 2
}
# Keyword arguments for the callbacks; a missing or empty entry disables it
callbacks_config = {
    'lr_monitor' : {
        'logging_interval' : 'step'
    }, 
    'early_stopping' : {
        'monitor' : 'valid_loss',
        'min_delta' : 0.0,
        'patience' : 3
    }
}
# 'logger_type' selects WandbLogger; any other value falls back to CSVLogger
logger_config = {
    'logger_type' : 'wandblogger',
    'project' : 'Deepfake_Detection-lightning-4cv',
    'log_model' : True,
    'group' : 'exp2-gblur_s0.1'
}
data_config = {
    'root' : '/kaggle/input/using-yunet',
    'image_size' : 224,
    # kwargs for v2.GaussianBlur; None disables the blur
    'gaussian_blur' : {
        'kernel_size' : (5,5),
        'sigma' : 0.1
    },
    # kwargs for v2.RandomErasing (plus optional 'stack_layer'); None disables
    'random_erase' : None,
    'batch_size' : 16,
    'num_workers' : 3
}
model_config = {
    'learning_rate' : 1e-4,
    'optim_config' : {
        'optimizer' : 'SGD',
        'momentum' : .9,
        'weight_decay' : 1.2e-3
    },
    # 'epochs' and 'steps_per_epoch' are filled in inside the training loop
    'scheduler_config' : {
        'scheduler' : 'OneCycleLR',
        # peak LR of the cycle = learning_rate * max_lr_coef
        'max_lr_coef' : 50,
        'interval' : 'step'
    }
} 

# prepare dataframe in order to build dataset & dataloader
# Folds are drawn per *video*, so all frames of one video land in the same fold.
df_real = df[df['label'] == 'REAL']
df_real_vid = df_real['video'].drop_duplicates()
df_real_splits = []
for i in range(4):
    # take 1/4, then 1/3, 1/2 and finally all of the videos still left,
    # i.e. four roughly equal, disjoint groups
    drv = df_real_vid.sample(frac=1/(4-i), random_state=RANDOM_SEED)
    df_real_vid.drop(drv.index, inplace = True)
    # expand the sampled video names back into their face rows
    df_real_splits.append(
        pd.merge(drv,df_real,how='left',left_on='video',right_on='video'))
    
# fakes: keep only rows with frame_id < 8 of each video
df_fake = df[(df['label'] == 'FAKE') & (df['frame_id'] < 8)]
df_fake_vid = df_fake['video'].drop_duplicates()
df_fake_splits = []
for i in range(4):
    dfv = df_fake_vid.sample(frac=1/(4-i),random_state=RANDOM_SEED)
    df_fake_vid.drop(dfv.index, inplace=True)
    df_fake_splits.append(
        pd.merge(dfv,df_fake,how='left',left_on='video',right_on='video'))
    
# fold i = fake group i + real group i
df_splits = [
    pd.concat((df_fake_splits[i],df_real_splits[i])) \
    for i in range(4)
]

# 4 Fold CV Training ===========================
# One independent run per fold: fold `fold_id` is held out for validation and
# the remaining three folds are concatenated into the training set.
for fold_id in range(min(limit_fold,4)):
    # prepare train loader
    train_df = pd.concat([df_splits[i] for i in range(4) if i != fold_id])
    train_dset = MyDataset(data_config['root'],train_df,
                           get_train_transform(
                               image_size = data_config['image_size'],
                               precision = trainer_config['precision'],
                               gaussian_blur=data_config['gaussian_blur'],
                               random_erase=data_config['random_erase']))
    train_loader = DataLoader(dataset=train_dset, 
                              batch_size=data_config['batch_size'],
                              shuffle=True,
                              num_workers=data_config['num_workers'])
    # prepare validation loader
    valid_df = df_splits[fold_id]
    valid_dset = MyDataset(data_config['root'],valid_df,
                          get_valid_transform(
                              image_size = data_config['image_size'],
                              precision = trainer_config['precision']))
    # drop_last=True: a final incomplete validation batch is skipped
    valid_loader = DataLoader(dataset=valid_dset, 
                              batch_size=data_config['batch_size'], 
                              num_workers=data_config['num_workers'],
                              drop_last=True)
    print(train_dset, valid_dset)
    
    # init model, logger, monitor, trainer
    # OneCycleLR needs the schedule length up front, so derive it from the
    # loader size and the trainer settings and write it into model_config.
    if model_config['scheduler_config']['scheduler'].lower() == 'onecyclelr':
        model_config['scheduler_config']['epochs'] = \
            trainer_config['max_epochs']
        lim_tb = trainer_config['limit_train_batches']
        acc_gb = trainer_config['accumulate_grad_batches']
        # optimizer steps per epoch = batches / accumulation; the +1 rounds up
        # (and over-counts by one when the division is exact)
        model_config['scheduler_config']['steps_per_epoch'] = \
            int(len(train_loader)*lim_tb/acc_gb)+1
    
    model = LitEfficientNet(model_config = model_config)
    ## use wandb logger when specified
    if logger_config['logger_type'].lower() == 'wandblogger':
        # run id = minute-resolution timestamp + fold index
        logger = WandbLogger(
            project=logger_config['project'],
            name=f"fold_{fold_id}",
            id=f"{datetime.now().strftime('%y%m%d%H%M')}_fold_{fold_id}",
            group=logger_config['group'],
            log_model=logger_config['log_model'],
            save_dir=f'/kaggle/working/ckpts')
        # store every config dict with the run
        logger.experiment.config.update({
            'limit_fold' : limit_fold,
            'trainer_config': trainer_config,
            'callbacks_config':callbacks_config,
            'logger_config' : logger_config,
            'data_config' : data_config,
            'model_config' : model_config
        })
    else: # use CSV logger as default
        logger = CSVLogger(name=f"fold_{fold_id}",
                           save_dir=f'/kaggle/working/ckpts')
        logger.log_hyperparams({
            'limit_fold' : limit_fold,
            'trainer_config': trainer_config,
            'callbacks_config':callbacks_config,
            'logger_config' : logger_config,
            'data_config' : data_config,
            'model_config' : model_config
        })
    # init lr monitor
    callbacks = []
    if callbacks_config.get('lr_monitor',False):
        callbacks.append(
            LearningRateMonitor(**callbacks_config['lr_monitor']))
    # init early stopping callback
    if callbacks_config.get('early_stopping',False):
        callbacks.append(
            EarlyStopping(**callbacks_config['early_stopping']))
    # init trainer
    trainer = Trainer(**trainer_config, 
                      default_root_dir=f'/kaggle/working/ckpts',
                      logger=logger,
                      callbacks=callbacks)

    # run training
    trainer.fit(model=model, 
                train_dataloaders=train_loader, 
                val_dataloaders=valid_loader)
    # wandb finish: close the run so the next fold starts a fresh one
    if logger_config['logger_type'].lower() == 'wandblogger':
        wandb.finish()
    

<__main__.MyDataset object at 0x78a27aac64a0> of size 21482
Compose(
      PILToTensor()
      Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      GaussianBlur(kernel_size=(5, 5), sigma=[0.1, 0.1])
      RandomAdjustSharpness(p=0.5, sharpness_factor=2)
      RandomHorizontalFlip(p=0.5)
      ConvertImageDtype()
      Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
) <__main__.MyDataset object at 0x78a27aac65f0> of size 7098
Compose(
      PILToTensor()
      Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      ConvertImageDtype()
      Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
)


Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 158MB/s]
[34m[1mwandb[0m: Currently logged in as: [33mluanakwon[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: `Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
INFO: `Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name    | Type              | Params
----------------------------------------------
0 | model   | EfficientNet      | 20.2 M
1 | loss_fn | BCEWithLogitsLoss | 0     
2 | acc_fn  | BinaryAccuracy    | 0     
----------------------------------------------
20.2 M    Trainable params
0         Non-trainable params
20.2 M    Total params
80.715    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁
lr-SGD,▁▁▁▁▂▂▂▂▃▃▃▃▅▅▅▅████
train_acc,▁█▁▁▃
train_loss,█▁▇▆▆
trainer/global_step,▁▁▁▁▁▃▃▃▃▃▅▅▅▅▅▆▆▆▆▆█████

0,1
epoch,0.0
lr-SGD,0.00098
train_acc,0.625
train_loss,0.66846
trainer/global_step,159.0


<__main__.MyDataset object at 0x78a1ad45d300> of size 21405
Compose(
      PILToTensor()
      Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      GaussianBlur(kernel_size=(5, 5), sigma=[0.1, 0.1])
      RandomAdjustSharpness(p=0.5, sharpness_factor=2)
      RandomHorizontalFlip(p=0.5)
      ConvertImageDtype()
      Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
) <__main__.MyDataset object at 0x78a27aac6380> of size 7175
Compose(
      PILToTensor()
      Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      ConvertImageDtype()
      Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
)




INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: `Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
INFO: `Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name    | Type              | Params
----------------------------------------------
0 | model   | EfficientNet      | 20.2 M
1 | loss_fn | BCEWithLogitsLoss | 0     
2 | acc_fn  | BinaryAccuracy    | 0     
----------------------------------------------
20.2 M    Trainable params
0         Non-trainable params
20.2 M    Total params
80.715    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

In [None]:
# simple plot in the case CSVLogger
# Plots every column of the most recent fold_0 metrics.csv against 'step'.
if logger_config['logger_type'].lower() != 'wandblogger':
    log_path = glob('/kaggle/working/ckpts/fold_0/version_*/metrics.csv')
    log_path.sort()
    print(log_path)
    if not log_path:
        # nothing was logged yet; indexing log_path[-1] would raise IndexError
        print("No CSVLogger metrics found for fold_0; nothing to plot.")
    else:
        # Pick the newest file by modification time: a lexicographic sort would
        # rank 'version_9' after 'version_10'.
        log_df = pd.read_csv(max(log_path, key=os.path.getmtime))

        plt.figure(figsize=(5,10))
        for i, c in enumerate(log_df.columns):
            plt.subplot(len(log_df.columns),1,i+1)
            plt.ylabel(c)
            # a column may be entirely NaN; fall back to NaN for its label
            logged = log_df[c].dropna()
            last_value = logged.iloc[-1] if not logged.empty else float('nan')
            plt.plot(log_df['step'],log_df[c],'b.',
                     label=f"last {c} {last_value}")
            plt.legend()
        plt.show()

In [None]:
## visualize validation result
        
    
# dataset class with filepath included
class MyDatasetwithFilepath(MyDataset):    
    """MyDataset variant whose items also carry the joined image path."""

    def __getitem__(self, idx):
        # joined path the sample comes from (root + relative path)
        full_path = os.path.join(self.root, self.paths[idx])
        img, label = super().__getitem__(idx)
        return img, label, full_path
    
######## pick the worst `n_samples` validation images of every fold ############
# For each trained fold: reload its checkpoint, score the held-out fold with a
# per-sample BCE loss, and draw the n_samples highest-loss faces in one row.
n_samples = 6
fig = plt.figure(figsize=(n_samples,min(limit_fold,4)))

# loop-invariant helpers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sigmoid_fn = nn.Sigmoid()
loss_fn = nn.BCEWithLogitsLoss(reduction='none')  # keep one loss per sample

for fold_id in range(min(limit_fold,4)):
    # prepare validation loader
    valid_df = df_splits[fold_id]
    valid_dset = MyDatasetwithFilepath(data_config['root'],valid_df,
                          get_valid_transform(
                              image_size = data_config['image_size']))
    # drop_last=False: every sample must be scored, otherwise the tail of the
    # fold can never show up among the worst cases
    valid_loader = DataLoader(dataset=valid_dset, 
                              batch_size=data_config['batch_size'], 
                              num_workers=data_config['num_workers'],
                              drop_last=False)
    
    # locate the checkpoint for this fold
    if logger_config['logger_type'].lower() == 'wandblogger':
        ckpt_path = get_ckpt_path(fold_id, logger_config['project'])
    else:
        ckpt_path = get_ckpt_path(fold_id)
        
    if ckpt_path is None:
        print(f"No checkpoints found in fold_{fold_id}. skipping fold_{fold_id}.")
        continue
    model = LitEfficientNet.load_from_checkpoint(ckpt_path, 
                                                 map_location=device)
    validation_result = {'path':[],'pred':[], 'target':[], 'loss':[]}
    print(f"model loaded from {ckpt_path}")
    
    # run validation
    model.eval()
    with torch.no_grad():
        for x, y, paths in tqdm(valid_loader):
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            validation_result['path'].extend(paths)
            validation_result['pred'].append(sigmoid_fn(logits))
            validation_result['target'].append(y)
            validation_result['loss'].append(loss_fn(logits, y))
        
    # flatten the per-batch tensors into one 1-D array per column
    for key in ['pred','target','loss']:
        validation_result[key] = \
            torch.concat(validation_result[key]).cpu().flatten().numpy()
    validated_df = pd.DataFrame(validation_result)
    validated_df = validated_df.sort_values(by='loss',ascending=False).head(n_samples)
    for i, (_, row) in enumerate(validated_df.iterrows()):
        print(row['path'])
        image = Image.open(row['path'])
        ax = fig.add_subplot(min(limit_fold,4),n_samples,i+fold_id*n_samples+1)
        ax.imshow(image)
        ax.axis('off')
        # title = ground-truth label, predicted probability of REAL
        ax.set_title(f"{row['target']}, {row['pred']:.6f}",fontsize=7)

# save before show(): some backends release the figure once it has been shown
fig.savefig('/kaggle/working/bad_samples.png')
plt.show()

### Weird things
Training with SGD when transfer learning / fine-tuning is a convention, and I followed it. At first the loss did decrease, but the curve was not the best. I tried different LR schedulers and initial LRs, but the impact was hardly noticeable. Small LRs such as 1e-5 gave a stable curve, but convergence was too slow, and I have hardly ever seen anyone use an LR that small. Hence I tried an initial LR of 1e-4 with hard LR decay, and got some NaN loss values in the validation phase. The next approach I tried was Adam, hoping it would solve both the NaNs and the slow convergence, but Adam turned all of the model weights into NaN in a single step, even with a very small LR. I did not go any further with Adam, but while debugging I found that the pretrained weights by themselves do not output NaN, whereas during the first epoch, even with lr=5e-5 + SGD, some weights become NaN. In the following epoch this disappears, even with the same LR. My explanation, though I did not dig very deep, is that pretrained weights being fine-tuned are very vulnerable to exploding gradients; to avoid that, I tried OneCycleLR, which starts small, gets bigger, and ends very small.
The other thing to note is the impact of batch size. First, and most obviously, with the same number of epochs a bigger batch size slows down training: a bigger batch may point the gradient better towards the optimum, but with the same LR there are fewer steps per epoch, so progress is slower. The second effect I have not figured out yet (idk)...
The point is that batch size is not just about hardware limits or training speed; it has a significant impact on the final performance of the model. More experiments and explanation from another AI learner here (KR) -> https://inhovation97.tistory.com/32


2. Once the seed is set, `DataFrame.sample()` on the same dataframe returns the same rows, no matter how many times it is called.

In [None]:
import torch
import math
from torchvision.transforms import v2
from PIL import Image


from torchvision.transforms import functional as F

class GuidedRandomErasing(v2.RandomErasing):
    """Random erasing whose rectangle position follows a normal distribution.

    Size and aspect ratio are sampled as in plain random erasing (uniform
    area fraction from `scale`, log-uniform aspect ratio from `ratio`). The
    position is ``N(0, s) * (H, W) / 10 + m * (H, W) - (h, w)``: `m` is a
    fractional image position and one unit of `s` is a tenth of the image
    size. Rectangles that stick out of the image are clipped to it.

    m : (mean_y, mean_x), as fractions of image height / width
    s : (std_y, std_x), in units of a tenth of image height / width

    Raises:
        ValueError: if `m` or `s` does not hold exactly two values.
    """
    def __init__(self, 
                 p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, 
                 m=(0.5,0.5), s=(1,1), inplace=False):
        super().__init__(p,scale,ratio,value,inplace)
        # force float so integer tuples such as the default s=(1, 1) do not
        # become int64 tensors
        self.m = torch.as_tensor(m, dtype=torch.float32)
        self.s = torch.as_tensor(s, dtype=torch.float32)
        if self.m.numel() != 2 or self.s.numel() != 2:
            raise ValueError(
                "m and s must each hold exactly two values (y, x); "
                f"got m={m}, s={s}")
        
    @staticmethod 
    def get_erase_params(img, scale, ratio, mean, std, value=None):
        """Sample the erase rectangle and its fill value.

        Args:
            img: tensor image of shape (..., C, H, W).
            scale: (min, max) fraction of the image area to erase.
            ratio: (min, max) aspect ratio of the rectangle.
            mean: (y, x) mean position as fractions of (H, W).
            std: (y, x) standard deviation in tenths of (H, W).
            value: list of fill values, or None for normally distributed noise.

        Returns:
            ``(i, j, h, w, v)`` for ``F.erase``. After 10 failed attempts the
            result is ``(0, 0, H, W, img)``, which leaves the image unchanged.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w
        img_hw = torch.tensor((img_h, img_w))
        mean = torch.as_tensor(mean, dtype=torch.float32)
        std = torch.as_tensor(std, dtype=torch.float32)
        
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            # pick i, j
            # NOTE(review): subtracting the full (h, w) puts the sampled point
            # at the rectangle's bottom-right corner rather than its centre;
            # kept as-is because the configured m/s values were presumably
            # tuned against it — confirm.
            i_j = torch.normal(mean=torch.zeros(2), std=std)*img_hw/10
            i_j = i_j + mean*img_hw - torch.tensor((h,w))
            i, j = int(i_j[0].item()), int(i_j[1].item())
            # clip the rectangle to the image: shrink by the overhanging part
            if i < 0:
                h += i
                i = 0
            elif i+h >= img_h:
                h = img_h-i-1
            if j < 0:
                # was `w -= j`, which *grew* the width for negative j
                w += j
                j = 0
            elif j+w >= img_w:
                w = img_w-j-1
            # nothing left inside the image -> resample
            if h <= 0 or w <= 0:
                continue
                
            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]
                
            return i, j, h, w, v
        
        # Return original image
        return 0, 0, img_h, img_w, img
    
    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )
                
            # shapes of m and s were validated in __init__
            i, j, h, w, v = GuidedRandomErasing.get_erase_params(
                img, scale=self.scale, ratio=self.ratio,
                mean = self.m, std = self.s, value=value)
            return F.erase(img, i, j, h, w, v, self.inplace)
        return img

    def __repr__(self):
        # reuse the parent repr and append mean/std before its closing paren
        s0 = super().__repr__()
        s = (
            f"{s0[:-1]}, "
            f"mean={self.m} "
            f"std={self.s})")
        return s
    

In [None]:
# import matplotlib.pyplot as plt

# Visual sanity check of GuidedRandomErasing: erase 10 all-ones images and sum
# them up. `value` is left at its default 0, so erased pixels contribute 0 and
# darker areas in the plot are the ones erased more often.
GRE = GuidedRandomErasing(p=0.5,scale=(0.01,0.03),ratio=(0.3,1.3),m=(0.41,0.38),s=(0.24,0.48))

# running sum of the erased images
blank = torch.zeros((1, 224,224))

for _ in range(10):
    img = torch.ones((1, 224,224))
    img = GRE(img)
    #print(torch.min(img), torch.max(img))
    blank += img
    
# scale into [0, 1); the +1 keeps the maximum strictly below 1
blank /= torch.max(blank)+1
plt.imshow(blank.view(224,224))
plt.show()
    

In [None]:
# validation code
# transforms are
# 1. random erase * 2
# 2. guided random erase on eyes + nose + lips with uniform probability
#   - both transform have same expected_area_to_be_erased
# trained versions are
# a. gb0.1
# b. gb0.1-re0.2:1
# c. gb0.1-re0.2:2
# total 6 validations

# If possible: using the best version, run the eyes/nose/lips RE validations separately
# to see whether a certain part carries more critical information.

In [None]:
# transform loader
def get_valid_transform_RE(image_size, precision='32-true', random_erase=None):
    """Build a validation pipeline with optional erasing layers.

    Args:
        image_size: side length of the square the image is resized to.
        precision: Lightning-style precision string; the first of '16', '32',
            '64' found in it selects float16 / float32 / float64.
        random_erase: list of dicts with 'type' and 'kwargs' as keys. 'type'
            is 'RandomErasing' or 'GuidedRandomErasing' (case-insensitive);
            unknown types are reported and skipped. None or an empty list
            adds no erasing.

    Returns:
        ``v2.Compose`` of PILToTensor -> Resize -> [erasing layers] ->
        ConvertImageDtype -> Normalize.

    Raises:
        NotImplementedError: if ``precision`` names none of the known widths.
    """
    # precision
    if '16' in precision:
        dtype = torch.float16
    elif '32' in precision:
        dtype = torch.float32
    elif '64' in precision:
        dtype = torch.float64
    else:
        raise NotImplementedError(f'{precision} not implemented')
        
    tr_list = [
        v2.PILToTensor(),
        v2.Resize((image_size, image_size),antialias=True)]
    # random erase (None default instead of a shared mutable [] default)
    for re_layer in (random_erase or ()):
        layer_type = re_layer['type'].lower()
        if layer_type == 'randomerasing':
            tr_list.append(v2.RandomErasing(**re_layer['kwargs']))
        elif layer_type == 'guidedrandomerasing':
            tr_list.append(GuidedRandomErasing(**re_layer['kwargs']))
        else:
            # warning
            print(f"{re_layer['type']} not implemented. Skipping {re_layer['type']}")
    
    valid_transform = v2.Compose([
        *tr_list,
        v2.ConvertImageDtype(dtype=dtype),
        v2.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])
    ])
    return valid_transform
    

# model loader
def get_model_from_wandb(logger, artifact_name, device, save_dir='/kaggle/working'):
    """Download a W&B model artifact and load its checkpoint as LitEfficientNet.

    The artifact is downloaded into ``{save_dir}/artifacts/{run_id}``, where
    ``run_id`` is the last '/'-separated component of ``artifact_name``
    (e.g. 'model-<stamp>_fold_<k>:v0'). The download root is passed
    explicitly so that it is exactly the directory searched for checkpoints.

    Args:
        logger: Object exposing ``download_artifact(name, save_dir=...,
            artifact_type=...)`` (a WandbLogger in this notebook).
        artifact_name: Artifact path of the form 'entity/project/name:version'.
        device: Passed as ``map_location`` when loading the checkpoint.
        save_dir: Base directory under which 'artifacts/<run_id>' is created.

    Returns:
        The model returned by ``LitEfficientNet.load_from_checkpoint``.

    Raises:
        FileNotFoundError: If no '*.ckpt' file is found after the download.
    """
    run_id = artifact_name.rsplit('/', 1)[-1]
    artifact_dir = os.path.join(save_dir, 'artifacts', run_id)
    # Download straight into the directory we glob below; relying on the
    # logger to add an 'artifacts/<run_id>' sub-path would leave the glob empty.
    logger.download_artifact(artifact_name, save_dir=artifact_dir, artifact_type='model')
    ckpt_paths = sorted(glob(os.path.join(artifact_dir, '*.ckpt')))
    if not ckpt_paths:
        raise FileNotFoundError(
            f"No .ckpt file found in '{artifact_dir}' for artifact '{artifact_name}'")
    # several checkpoints: take the lexicographically last one
    return LitEfficientNet.load_from_checkpoint(ckpt_paths[-1], map_location=device)
    

In [None]:
# Expected erased area per image (presumably using the midpoint of each
# `scale` range as the per-layer area -- TODO confirm):
#   four guided layers, p=0.5, area 0.02 each:
#     0.5^4 * (1*0 + 4*0.02 + 6*0.04 + 4*0.06 + 1*0.08) = 0.04
#   two plain layers, p=0.5, area a each:
#     0.25 * (1*0 + 2*a + 1*2*a) = a = 0.04

# configs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger_config = dict(
    logger_type='wandblogger',
    project='Deepfake_Detection-lightning-4cv',
    log_model=False,
    group='re_validations',
)
# NOTE(review): the metadata CSV is read from /kaggle/input/using-yunet --
# confirm the images really live under /kaggle/working/input.
data_config = dict(
    root='/kaggle/working/input/using-yunet',
    image_size=224,
    batch_size=16,
    num_workers=3,
)


def _plain_erasing_layer():
    """Return a fresh layer spec for one plain RandomErasing step."""
    return {
        'type': 'RandomErasing',
        'kwargs': dict(p=0.5, scale=(0.02, 0.06), ratio=(0.3, 3.0), value='random'),
    }


def _guided_erasing_layer(ratio, m, s):
    """Return a layer spec for one GuidedRandomErasing step with the given ratio/m/s."""
    return {
        'type': 'GuidedRandomErasing',
        'kwargs': dict(p=0.5, scale=(0.01, 0.03), ratio=ratio, value='random', m=m, s=s),
    }


transform_layers = [
    _plain_erasing_layer(),  # original random erasing 1
    _plain_erasing_layer(),  # original random erasing 2
    _guided_erasing_layer(ratio=(0.5, 1.3), m=(0.41, 0.61), s=(0.24, 0.46)),  # left eye
    _guided_erasing_layer(ratio=(0.5, 1.3), m=(0.41, 0.38), s=(0.24, 0.43)),  # right eye
    _guided_erasing_layer(ratio=(0.3, 1.2), m=(0.54, 0.5), s=(0.24, 0.7)),    # nose
    _guided_erasing_layer(ratio=(0.2, 1), m=(0.66, 0.5), s=(0.22, 0.48)),     # lips
]

# TODO copy paste artifact api ids
# a. gb0.1 - this one needs to be re-run
_artifacts_gb = [
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021253_fold_0:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021319_fold_1:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021345_fold_2:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2404021412_fold_3:v0',
]
# b. gb0.1-re0.2:1
_artifacts_gb_re1 = [
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240728_fold_0:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240757_fold_1:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240826_fold_2:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403240854_fold_3:v0',
]
# c. gb0.1-re0.2:2
_artifacts_gb_re2 = [
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241057_fold_0:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241124_fold_1:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241150_fold_2:v0',
    'luanakwon/Deepfake_Detection-lightning-4cv/model-2403241217_fold_3:v0',
]
artifacts = [*_artifacts_gb, *_artifacts_gb_re1, *_artifacts_gb_re2]

In [None]:
seed_everything(seed=RANDOM_SEED, workers=True)
# re-initialize dataframe
FACE_DSET_META_PATH = '/kaggle/input/using-yunet/deepfake-detection-face-dataset.csv'
df = pd.read_csv(FACE_DSET_META_PATH)

# prepare dataframe in order to build dataset & dataloader
# Folds are drawn per video (not per frame), so all rows of one video end up
# in the same fold. Sampling frac=1/(4-i) of the remaining videos on each
# pass leaves four roughly equal parts.
df_real = df[df['label'] == 'REAL']
df_real_vid = df_real['video'].drop_duplicates()
df_real_splits = []
for i in range(4):
    drv = df_real_vid.sample(frac=1/(4-i), random_state=RANDOM_SEED)
    df_real_vid.drop(drv.index, inplace=True)
    df_real_splits.append(
        pd.merge(drv, df_real, how='left', left_on='video', right_on='video'))

# FAKE rows are restricted to frame_id < 8 before splitting.
df_fake = df[(df['label'] == 'FAKE') & (df['frame_id'] < 8)]
df_fake_vid = df_fake['video'].drop_duplicates()
df_fake_splits = []
for i in range(4):
    dfv = df_fake_vid.sample(frac=1/(4-i), random_state=RANDOM_SEED)
    df_fake_vid.drop(dfv.index, inplace=True)
    df_fake_splits.append(
        pd.merge(dfv, df_fake, how='left', left_on='video', right_on='video'))

# fold i = fake part i + real part i
df_splits = [
    pd.concat((df_fake_splits[i], df_real_splits[i]))
    for i in range(4)
]

# init logger
if logger_config['logger_type'].lower() == 'wandblogger':
    # Forward the configured project/group/log_model; a bare WandbLogger()
    # would ignore logger_config and log to the default project.
    logger = WandbLogger(
        project=logger_config['project'],
        log_model=logger_config['log_model'],
        group=logger_config['group'],
    )
else:
    raise NotImplementedError('other logger types not implemented')

# iterate through given artifacts
for artifact_id in artifacts:
    # download artifact if not exist at /kaggle/working
    # init model
    run_id = artifact_id.split('/')[-1]
    # fold index is encoded in the artifact name: '..._fold_<k>:<version>'
    fold_id = int(run_id.split(':')[0].split('_fold_')[-1])
    artifact_dir = f'/kaggle/working/artifacts/{run_id}'
    if not os.path.exists(artifact_dir):
        model = get_model_from_wandb(logger, artifact_id, device)
    else:
        # was `{artifact}`: an undefined name that raised NameError on a cache hit
        print(f"Artifact '{artifact_id}' already exists. Using existing file...")
        ckpt_path = sorted(glob(f"{artifact_dir}/*.ckpt"))[-1]
        model = LitEfficientNet.load_from_checkpoint(ckpt_path, map_location=device)

    # the validation fold depends only on the artifact, not on the transform set
    valid_df = df_splits[fold_id]

    # iterate through given transform_layers
    # [0:2] = both plain RandomErasing layers, [2:6] = the four guided layers.
    # The previous [0:1] applied a single plain layer, halving its expected
    # erased area relative to the guided set it is compared against.
    for t_layers in (transform_layers[0:2], transform_layers[2:6]):
        # prep validation loader
        valid_dset = MyDataset(data_config['root'], valid_df,
                               get_valid_transform_RE(
                                   image_size=data_config['image_size'],
                                   random_erase=t_layers))
        valid_loader = DataLoader(dataset=valid_dset,
                                  batch_size=data_config['batch_size'],
                                  num_workers=data_config['num_workers'],
                                  drop_last=True)
        print(valid_dset)

        # update logger exp config
        # The same keys are rewritten on every iteration, so value changes
        # must be explicitly allowed.
        logger.experiment.config.update({
            'Note': artifact_id,
            'logger_config': logger_config,
            'data_config': data_config,
            'transform_layers': t_layers
        }, allow_val_change=True)
        # NOTE(review): `model` and `valid_loader` are not consumed yet in
        # this cell -- the actual validation pass still has to be added here.

Zhong, Z., Zheng, L., Kang, G., Li, S., & Yang, Y. (2020). Random Erasing Data Augmentation. Proceedings of the AAAI Conference on Artificial Intelligence, 34(07), 13001-13008. https://doi.org/10.1609/aaai.v34i07.7000

Lewy, D., & Mańdziuk, J. (2021). An overview of mixing augmentation methods and augmentation strategies.

Haliassos, A., Vougioukas, K., Petridis, S., & Pantic, M. (2020). Lips Don't Lie: A Generalisable and Robust Approach to Face Forgery Detection. ArXiv. https://arxiv.org/abs/2012.07657
