# Imports and installations
- PyTorch
- Dataset download directly from Kaggle

Data source: https://www.kaggle.com/xhlulu/140k-real-and-fake-faces

In [None]:
!pip install timm
!pip install kaggle
# -p: create ~/.kaggle if it is missing and do not fail when it already exists
# (re-running the cell). The Kaggle CLI reads its credentials from this folder.
!mkdir -p ~/.kaggle
# Assumes Google Drive is mounted at /content/drive and holds kaggle.json.
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/kaggle.json
# Restrict the credentials file to its owner, as the Kaggle CLI recommends.
!chmod 600 ~/.kaggle/kaggle.json


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
mkdir: cannot create directory ‘/root/.kaggle}’: File exists


In [None]:
# Download the dataset archive with the Kaggle CLI (authenticates via ~/.kaggle/kaggle.json).
! kaggle datasets download xhlulu/140k-real-and-fake-faces
# Extract into the current working directory. NOTE(review): the main script reads
# /content/real_vs_fake/real-vs-fake, which presumes the cwd is /content — confirm.
! unzip '140k-real-and-fake-faces.zip';

In [None]:
from json import load
import torch
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor,Compose, Resize, Normalize
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
#from vision_transformer import VisionTransformer
from torch import nn
import timm
import torch.nn.functional as F

# Early Stopping
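Implements early stopping based on the validation loss: training stops once the validation loss has not improved for `patience` consecutive epochs, and the best weights seen so far are checkpointed. This helps prevent the model from overfitting the training set.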

Source: https://github.com/Bjarten/early-stopping-pytorch


In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Each call compares the new validation loss with the best one seen so far.
    A loss that improves on the best by at least ``delta`` counts as an
    improvement: the model's ``state_dict`` is saved to ``path`` and the
    patience counter is reset. Otherwise the counter is incremented and, once
    it reaches ``patience``, ``early_stop`` is set to True. The caller is
    expected to check that flag and stop training.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalized alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record the epoch's validation loss and checkpoint ``model`` on improvement.

        Args:
            val_loss (float): validation loss of the epoch just finished.
            model (torch.nn.Module): model whose ``state_dict`` is saved when
                the loss improves.
        """
        # Track a "higher is better" score so the comparison reads naturally.
        score = -val_loss

        # Written as ">=" rather than as the negation of "<" so that a NaN
        # loss (every comparison with NaN is False) counts as "no improvement"
        # instead of overwriting the best checkpoint and poisoning best_score.
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Auxiliary functions
- load_data: Loads the training, test and validation sets.
- accuracy: Computes the accuracy of a batch of logits against the labels.
- validation_epoch_end: Averages the per-batch validation loss and accuracy over the epoch.
- epoch_end: Prints the metrics after each training epoch.
- get_default_device: Returns the GPU (CUDA) device if one is available, otherwise the CPU.
- to_device / DeviceDataLoader: Move tensors, and every batch yielded by a DataLoader, to the chosen device.


In [None]:
def load_data(path, train_rate, batch_size, test=False):
    """Build DataLoaders for the ``train``, ``test`` and ``valid`` sub-folders of ``path``.

    Every image is resized to 224x224, converted to a tensor and normalized
    per channel with mean 0.5 / std 0.5, which maps [0, 1] to [-1, 1]. Each
    split folder is read with ``ImageFolder``, one class per sub-directory.

    Args:
        path: root directory containing ``train``, ``test`` and ``valid``.
        train_rate: currently unused; the splits come from the folders.
        batch_size: batch size shared by all three loaders.
        test: currently unused.

    Returns:
        ``(train_dl, test_dl, valid_dl)``. Only the training loader shuffles.
    """
    preprocess = Compose([
        Resize(size=(224, 224)),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def make_loader(split, shuffle=False):
        # One ImageFolder per split directory, all sharing the same transform.
        folder = ImageFolder(path + "/" + split, transform=preprocess)
        return DataLoader(folder, batch_size, shuffle=shuffle)

    return (
        make_loader("train", shuffle=True),
        make_loader("test"),
        make_loader("valid"),
    )

def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class matches the label.

    ``outputs`` holds one row of class scores per sample; the prediction is
    the argmax over dim 1. The result is a new 0-dim float tensor.
    """
    predicted = outputs.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))
    
def validation_epoch_end(outputs):
    """Reduce per-batch validation metrics to epoch-level floats.

    ``outputs`` is a list of dicts, each holding 0-dim tensors under
    ``'val_loss'`` and ``'val_acc'``. Each metric is the unweighted mean over
    batches, so a smaller final batch counts as much as a full one.
    """
    mean_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    mean_acc = torch.stack([o['val_acc'] for o in outputs]).mean()
    return {"val_loss": mean_loss.item(), 'val_acc': mean_acc.item()}
def epoch_end(epoch, result):
    """Print a one-line summary of an epoch's training and validation metrics.

    ``result`` must provide ``'train_loss'``, ``'val_loss'`` and ``'val_acc'``.
    Each is printed with four decimal places.
    """
    template = "Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
    metrics = (result['train_loss'], result['val_loss'], result['val_acc'])
    print(template.format(epoch, *metrics))

In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_type)

def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are handled recursively and always come back as lists.
    Any other object must provide ``.to(device, non_blocking=...)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader:
    """Wrap a DataLoader so each batch it yields is moved to ``device``.

    Batches are transferred lazily, one at a time, via ``to_device`` as the
    wrapper is iterated. ``len()`` is forwarded to the wrapped loader.
    """

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        return len(self.dl)

# Training


## Model Loading
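  - DeiT-Base, distilled, patch size 16, 224×224 input (`deit_base_distilled_patch16_224`), with the backbone frozen and both classification heads replaced by 2-class layers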
  


Source:
https://github.com/facebookresearch/deit

In [None]:

# Pretrained distilled DeiT-Base (patch 16, 224x224 input) from the DeiT torch.hub repo.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_224', pretrained=True)

# Freeze the pretrained backbone; only the new heads created below are trained.
for param in model.parameters():
    param.requires_grad = False

# Replace both classifiers (class-token head and distillation head) with fresh
# 2-way layers. Newly constructed nn.Linear parameters have requires_grad=True.
# Reading in_features instead of hard-coding 768 keeps this working for other
# DeiT sizes.
num_classes = 2
model.head = nn.Linear(model.head.in_features, num_classes)
model.head_dist = nn.Linear(model.head_dist.in_features, num_classes)

model



Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


DistilledVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNor

## Training and Evaluation Functions
- evaluate: Evaluates the model with the validation set
- fit: trains the model for up to a specified number of epochs, stopping early when the validation loss stops improving.

In [None]:
@torch.no_grad()
def evaluate(model, val_loader):
    """Compute the mean validation loss and accuracy of ``model`` over ``val_loader``.

    The model is switched to eval mode for the pass, so layers such as dropout
    behave as at inference time. Its previous train/eval mode is restored
    afterwards, even if an exception is raised.

    Args:
        model: classifier to evaluate. It may return either a single logits
            tensor or a tuple/list whose first element is the main head's
            logits.
        val_loader: iterable of ``(images, labels)`` batches already on the
            model's device.

    Returns:
        dict with float ``'val_loss'`` and ``'val_acc'``, each the unweighted
        mean of the per-batch values.
    """
    was_training = model.training
    model.eval()
    try:
        outputs = []
        for images, labels in val_loader:
            out = model(images)
            # NOTE(review): the distilled DeiT used here returns (head, head_dist)
            # in train mode but a single averaged tensor in eval mode. Indexing
            # [0] on a plain tensor would silently pick the first sample's row,
            # so handle both forms explicitly.
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = F.cross_entropy(logits, labels)
            acc = accuracy(logits, labels)
            outputs.append({'val_loss': loss.detach(), "val_acc": acc})
    finally:
        model.train(was_training)
    return validation_epoch_end(outputs)
        


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, path_to_save=''):
    """Train a two-output (head + distillation head) model with early stopping.

    Each step computes cross-entropy for both model outputs ``out[0]`` and
    ``out[1]`` against the same labels and back-propagates their sum. After
    every epoch the model is evaluated on ``val_loader`` and the metrics are
    printed. ``EarlyStopping`` (patience 4) then checkpoints the best weights
    and decides whether to stop.

    Args:
        epochs: maximum number of epochs.
        lr: learning rate, passed positionally to ``opt_func``.
        model: model returning a pair of logits tensors in train mode.
        train_loader: iterable of ``(images, labels)`` batches on the model's device.
        val_loader: iterable of ``(images, labels)`` batches on the model's device.
        opt_func: optimizer class, called as ``opt_func(model.parameters(), lr)``.
        path_to_save: checkpoint file for the best weights. An empty string
            falls back to ``'checkpoint.pt'`` because ``torch.save`` cannot
            write to ``''``.

    Returns:
        List of per-epoch result dicts (``train_loss``, ``val_loss``,
        ``val_acc``), including the epoch on which early stopping fired. The
        model keeps the weights of the last epoch run; the best weights are
        only in the checkpoint file.
    """
    history = []
    optimizer = opt_func(model.parameters(), lr)
    early_stopping = EarlyStopping(patience=4, verbose=True,
                                   path=path_to_save or 'checkpoint.pt')
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for images, labels in train_loader:
            # Clear grads before the step so stale grads from before fit() are never applied.
            optimizer.zero_grad()
            out = model(images)
            loss_head = F.cross_entropy(out[0], labels)
            loss_dist = F.cross_entropy(out[1], labels)
            # A single backward over the sum gives the same gradients as two
            # separate backward() calls. Unlike two calls, it still works when
            # the backbone is unfrozen, since the heads then share one graph.
            (loss_head + loss_dist).backward()
            optimizer.step()
            # Detach so the stored losses don't keep autograd history alive.
            train_losses.append(loss_head.detach())

        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        epoch_end(epoch, result)
        # Record before the early-stop check so the final epoch is kept too.
        history.append(result)
        early_stopping(result['val_loss'], model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    return history

# Main Script
- Dataset loading
- Moving the model and the training and validation loaders to the default device (GPU if available)
- Training the model

In [None]:
# Build the loaders (batch size 256) and pick the device. Then move the
# training/validation loaders and the model onto it. The test loader stays a
# plain DataLoader.
DATA_ROOT = '/content/real_vs_fake/real-vs-fake'
train, test, valid = load_data(DATA_ROOT, train_rate=0.8, batch_size=256)
device = get_default_device()

train, valid = (DeviceDataLoader(loader, device) for loader in (train, valid))
# nn.Module.to() moves the parameters and returns the same module object.
# Assigning the result (instead of a trailing ';') also suppresses the notebook echo.
model = model.to(device)

In [None]:
num_epochs = 50
opt_func = torch.optim.Adam
lr = 0.001
# EarlyStopping saves the best weights here. An empty path would make
# torch.save fail on the first improvement, so use a real file name.
path_to_save = 'checkpoint.pt'
history = fit(num_epochs, lr, model, train, valid,
              opt_func=opt_func, path_to_save=path_to_save)

Epoch [0], train_loss: 0.2079, val_loss: 0.2014, val_acc: 0.9219
Validation loss decreased (inf --> 0.201378).  Saving model ...
Epoch [1], train_loss: 0.1925, val_loss: 0.1909, val_acc: 0.9271
Validation loss decreased (0.201378 --> 0.190869).  Saving model ...
Epoch [2], train_loss: 0.1851, val_loss: 0.1851, val_acc: 0.9288
Validation loss decreased (0.190869 --> 0.185087).  Saving model ...
Epoch [3], train_loss: 0.1797, val_loss: 0.1824, val_acc: 0.9291
Validation loss decreased (0.185087 --> 0.182419).  Saving model ...
Epoch [4], train_loss: 0.1770, val_loss: 0.1795, val_acc: 0.9298
Validation loss decreased (0.182419 --> 0.179498).  Saving model ...
Epoch [5], train_loss: 0.1748, val_loss: 0.1770, val_acc: 0.9304
Validation loss decreased (0.179498 --> 0.176960).  Saving model ...
Epoch [6], train_loss: 0.1737, val_loss: 0.1776, val_acc: 0.9296
EarlyStopping counter: 1 out of 4
Epoch [7], train_loss: 0.1722, val_loss: 0.1759, val_acc: 0.9314
Validation loss decreased (0.176960 -