In [1]:
import os
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.models import resnet50

from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
DIR_TRAIN = "data/train/"
DIR_VAL = "data/val/"

In [3]:
train_imgs = os.listdir(DIR_TRAIN) 
val_imgs = os.listdir(DIR_VAL)

In [4]:
print(train_imgs[:5])
print(val_imgs[:5])

['F2F.0.png', 'F2F.1.png', 'F2F.2.png', 'F2F.3.png', 'F2F.4.png']
['F2F.0.png', 'F2F.1.png', 'F2F.2.png', 'F2F.3.png', 'F2F.4.png']


In [5]:
def get_train_transform():
    return T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(15),
        T.RandomCrop(204),
        T.ToTensor(),
        T.Normalize((0, 0, 0),(1, 1, 1))
    ])
    
def get_val_transform():
    return T.Compose([
        T.ToTensor(),
        T.Normalize((0, 0, 0),(1, 1, 1))
    ])

In [6]:
class DeepFakeDataset(Dataset):
    def __init__(self, imgs, mode = "train", transforms = None):
        super().__init__()
        self.imgs = imgs
        self.mode = mode
        self.transforms = transforms
    
    def __getitem__(self, idx):
        image_name = self.imgs[idx]
        
        
        # Training
        if self.mode == "train":
            img = Image.open(DIR_TRAIN + image_name)
            img = img.resize((224, 224))
            
            # Prepare class label
            if (image_name.split(".")[0] == "original"):
                label = 1 # It is original
            else:
                label = 0 # It is fake
            
            label = torch.tensor(label, dtype = torch.float32)
            
            ### Apply Transforms on image
            img = self.transforms(img)
            
            return img, label
                
        # Validation
        elif self.mode == "val":
            img = Image.open(DIR_VAL + image_name)
            img = img.resize((224, 224))
            
            # Prepare class label
            if (image_name.split(".")[0] == "original"):
                label = 1 # It is original
            else:
                label = 0 # It is fake
                
            label = torch.tensor(label, dtype = torch.float32)
            
            ### Apply Transforms on image
            img = self.transforms(img)
            
            return img, label
        
        elif self.mode == "test":
            
            ### Apply Transforms on image
            img = self.transforms(img)
            
            return img
    
    def __len__(self):
        return len(self.imgs)

In [7]:
train_dataset = DeepFakeDataset(train_imgs, mode = "train", transforms = get_train_transform())
val_dataset = DeepFakeDataset(val_imgs, mode = "val", transforms = get_val_transform())

train_data_loader = DataLoader(
    dataset = train_dataset,
    # num_workers = 4,
    batch_size = 16,
    shuffle = True
)

val_data_loader = DataLoader(
    dataset = val_dataset,
    # num_workers = 4,
    batch_size = 16,
    shuffle = True
)

In [8]:
device = torch.device("cuda")

In [9]:
torch.cuda.is_available()

True

In [10]:
def accuracy(preds, trues):
    
    ### Converting preds to 0 or 1
    preds = [1 if preds[i] >= 0.5 else 0 for i in range(len(preds))]
    
    ### Calculating accuracy by comparing predictions with true labels
    acc = [1 if preds[i] == trues[i] else 0 for i in range(len(preds))]
    
    ### Summing over all correct predictions
    acc = np.sum(acc) / len(preds)
    
    return (acc * 100)

In [11]:
def train_one_epoch(train_data_loader):
    
    ### Local Parameters
    epoch_loss = []
    epoch_acc = []
    start_time = time.time()
    
    ###Iterating over data loader
    for images, labels in train_data_loader:
        
        #Loading images and labels to device
        images = images.to(device)
        labels = labels.to(device)
        labels = labels.reshape((labels.shape[0], 1)) # [N, 1] - to match with preds shape
        
        #Reseting Gradients
        optimizer.zero_grad()
        
        #Forward
        preds = model(images)
        
        #Calculating Loss
        _loss = criterion(preds, labels)
        loss = _loss.item()
        epoch_loss.append(loss)
        
        #Calculating Accuracy
        acc = accuracy(preds, labels)
        epoch_acc.append(acc)
        
        #Backward
        _loss.backward()
        optimizer.step()
    
    ###Overall Epoch Results
    end_time = time.time()
    total_time = end_time - start_time
    
    ###Acc and Loss
    epoch_loss = np.mean(epoch_loss)
    epoch_acc = np.mean(epoch_acc)
    
    ###Storing results to logs
    train_logs["loss"].append(epoch_loss)
    train_logs["accuracy"].append(epoch_acc)
    train_logs["time"].append(total_time)
        
    return epoch_loss, epoch_acc, total_time

In [12]:
def val_one_epoch(val_data_loader, best_val_acc):
    
    ### Local Parameters
    epoch_loss = []
    epoch_acc = []
    start_time = time.time()
    
    ###Iterating over data loader
    for images, labels in val_data_loader:
        
        #Loading images and labels to device
        images = images.to(device)
        labels = labels.to(device)
        labels = labels.reshape((labels.shape[0], 1)) # [N, 1] - to match with preds shape
        
        #Forward
        preds = model(images)
        
        #Calculating Loss
        _loss = criterion(preds, labels)
        loss = _loss.item()
        epoch_loss.append(loss)
        
        #Calculating Accuracy
        acc = accuracy(preds, labels)
        epoch_acc.append(acc)
    
    ###Overall Epoch Results
    end_time = time.time()
    total_time = end_time - start_time
    
    ###Acc and Loss
    epoch_loss = np.mean(epoch_loss)
    epoch_acc = np.mean(epoch_acc)
    
    ###Storing results to logs
    val_logs["loss"].append(epoch_loss)
    val_logs["accuracy"].append(epoch_acc)
    val_logs["time"].append(total_time)
    
    ###Saving best model
    if epoch_acc > best_val_acc:
        best_val_acc = epoch_acc
        torch.save(model.state_dict(),"resnet50_best.pth")
        
    return epoch_loss, epoch_acc, total_time, best_val_acc

In [13]:
model = resnet50(pretrained = True)

# Modifying Head - classifier

model.fc = nn.Sequential(
    nn.Linear(2048, 1, bias = True),
    nn.Sigmoid()
)

In [14]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

# Learning Rate Scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.5)

#Loss Function
criterion = nn.BCELoss()

# Logs - Helpful for plotting after training finishes
train_logs = {"loss" : [], "accuracy" : [], "time" : []}
val_logs = {"loss" : [], "accuracy" : [], "time" : []}

# Loading model to device
model.to(device)

# No of epochs 
epochs = 40

In [None]:
import time
best_val_acc = 0
for epoch in range(epochs):
    
    ###Training
    loss, acc, _time = train_one_epoch(train_data_loader)
    
    #Print Epoch Details
    print("\nTraining")
    print("Epoch {}".format(epoch+1))
    print("Loss : {}".format(round(loss, 4)))
    print("Acc : {}".format(round(acc, 4)))
    print("Time : {}".format(round(_time, 4)))
    
    ###Validation
    loss, acc, _time, best_val_acc = val_one_epoch(val_data_loader, best_val_acc)
    
    #Print Epoch Details
    print("\nValidating")
    print("Epoch {}".format(epoch+1))
    print("Loss : {}".format(round(loss, 4)))
    print("Acc : {}".format(round(acc, 4)))
    print("Time : {}".format(round(_time, 4)))
    


Training
Epoch 1
Loss : 0.5536
Acc : 69.7713
Time : 154.2331

Validating
Epoch 1
Loss : 0.4479
Acc : 77.992
Time : 22.141

Training
Epoch 2
Loss : 0.409
Acc : 80.2842
Time : 137.3524

Validating
Epoch 2
Loss : 0.3542
Acc : 84.109
Time : 19.0882

Training
Epoch 3
Loss : 0.3391
Acc : 84.325
Time : 145.8545

Validating
Epoch 3
Loss : 0.3337
Acc : 85.4388
Time : 24.1803

Training
Epoch 4
Loss : 0.2936
Acc : 87.0004
Time : 139.4376

Validating
Epoch 4
Loss : 0.2638
Acc : 88.7633
Time : 19.1897

Training
Epoch 5
Loss : 0.2632
Acc : 88.6323
Time : 133.058

Validating
Epoch 5
Loss : 0.2928
Acc : 86.3032
Time : 19.5307

Training
Epoch 6
Loss : 0.2344
Acc : 90.2087
Time : 134.5653

Validating
Epoch 6
Loss : 0.237
Acc : 89.0293
Time : 19.12

Training
Epoch 7
Loss : 0.2173
Acc : 91.008
Time : 132.4732

Validating
Epoch 7
Loss : 0.2106
Acc : 91.8883
Time : 19.0179

Training
Epoch 8
Loss : 0.1933
Acc : 92.1292
Time : 133.2204

Validating
Epoch 8
Loss : 0.2491
Acc : 90.492
Time : 19.055

Training
Ep

In [None]:
### Plotting Results

#Loss
plt.title("Loss")
plt.plot(np.arange(1, 41, 1), train_logs["loss"], color = 'blue')
plt.plot(np.arange(1, 41, 1), val_logs["loss"], color = 'yellow')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

#Accuracy
plt.title("Accuracy")
plt.plot(np.arange(1, 41, 1), train_logs["accuracy"], color = 'blue')
plt.plot(np.arange(1, 41, 1), val_logs["accuracy"], color = 'yellow')
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.show()