# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn
import torch
import cv2
import os
import random
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchvision import models
from torchvision.datasets import ImageFolder
import torchvision.utils as vutils
from tqdm import tqdm

import kagglehub

# Download the latest version of the Celeb-DF v2 dataset through kagglehub and
# keep the returned path; the rest of the script reads its files from there.
celebs_path = kagglehub.dataset_download(
    "reubensuju/celeb-df-v2",
)

print("Path to dataset files:", celebs_path)

def extract_frames(video_path, filename, output_folder):
    """Write every frame of a video into ``output_folder`` as a JPEG file.

    Frames are named ``frame_<index>_<filename>.jpg`` where ``<index>`` is the
    zero-padded (five digit) position of the frame in read order.

    Args:
        video_path: Path of the video file opened with ``cv2.VideoCapture``.
        filename: Tag embedded in each frame's file name.
        output_folder: Directory receiving the frames; created when missing.

    Returns:
        None. Progress and failures are reported with ``print``.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_folder, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}")
            return

        frame_index = 0
        frames_written = 0
        while True:
            ret, frame = cap.read()

            # A falsy ``ret`` means no further frame could be read.
            if not ret:
                break

            frame_filename = os.path.join(output_folder, f"frame_{frame_index:05d}_{filename}.jpg")
            # imwrite reports failure through its return value instead of
            # raising, so check it to keep the final count honest.
            if cv2.imwrite(frame_filename, frame):
                frames_written += 1
            else:
                print(f"Warning: Could not write frame {frame_filename}")

            # The index advances even on a failed write so that file names
            # keep matching frame positions in the video.
            frame_index += 1

        print(f"Extracted {frames_written} frames to {output_folder}")
    finally:
        # Release the capture even if reading or writing raised.
        cap.release()

# Extract frames for every video named in the testing list. Each entry is
# "<label> <relative video path>"; frames are written to frames/<label>/ so the
# label doubles as the class directory read by ImageFolder further down.
with open(os.path.join(celebs_path, 'List_of_testing_videos.txt')) as f:
    for line in f:
        # split() without a separator tolerates blank lines, repeated spaces
        # and trailing "\r"/"\n", which would otherwise raise IndexError or
        # leak into the video path.
        label_video = line.split(None, 1)
        if len(label_video) != 2:
            continue
        label = label_video[0]
        video = label_video[1].strip()

        video_path = os.path.join(celebs_path, video)
        output_path = os.path.join("frames", label)
        # splitext drops the extension whatever its length; "/" is replaced so
        # the video's relative path becomes a flat, unique file-name tag.
        video_tag = os.path.splitext(video)[0].replace("/", "_")
        extract_frames(video_path, video_tag, output_path)


# Every frame ends up as an image_size x image_size image (64x64).
image_size = 64

# Preprocessing applied to each frame when it is read from disk: resize,
# center-crop to a square, convert to a tensor, then normalize each of the
# three channels with mean 0.5 and std 0.5.
frame_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Class labels come from the sub-directory names under "frames".
celebs_image_dataset = ImageFolder(root="frames", transform=frame_transform)

# Materialize the whole dataset in memory: one list of image tensors and one
# list of integer class indices, in dataset order.
features_extract = []
labels_extract = []
for idx in tqdm(range(len(celebs_image_dataset))):
    image, class_index = celebs_image_dataset[idx]
    features_extract.append(image)
    labels_extract.append(class_index)

X_unshuffled = torch.stack(features_extract)
y_unshuffled = torch.Tensor(labels_extract).long()

# Convention used from here on: real = 0, fake = 1.
# NOTE(review): this swap assumes the dataset's class index 1 means "real" and
# 0 means "fake" -- confirm against the label column of the video list.
# Both masks are computed before any write so the second assignment does not
# undo the first.
real_mask = y_unshuffled.eq(1)
fake_mask = y_unshuffled.eq(0)

y_unshuffled.masked_fill_(real_mask, 0)
y_unshuffled.masked_fill_(fake_mask, 1)


# Shuffle features and labels with one shared permutation; the fixed seed
# makes the permutation reproducible.
random.seed(0)
shuffle_mask = np.arange(len(X_unshuffled))
random.shuffle(shuffle_mask)

X_shuffled = X_unshuffled[shuffle_mask]
y_shuffled = y_unshuffled[shuffle_mask]

# Upper bound on the number of images kept per class.
n_real = 25000
n_fake = 25000

# flatten() always yields a 1-D index tensor. squeeze() would collapse a
# single match into a 0-d tensor, which cannot be sliced below.
real_indices = torch.argwhere(y_shuffled == 0).flatten()
fake_indices = torch.argwhere(y_shuffled == 1).flatten()

# Real indices first, then fake ones: the saved sample is ordered by class.
sample_indices = torch.cat([real_indices[:n_real], fake_indices[:n_fake]], dim=0)

torch.save(X_shuffled[sample_indices], "X_celebs_df_2.pt")
torch.save(y_shuffled[sample_indices], "y_celebs_df_2.pt")
# These are the shapes of the full shuffled set, not of the saved sample.
print("Feature size:", X_shuffled.shape)
print("Label size:", y_shuffled.shape)

# Reload the balanced sample written to disk by the sampling step.
X_sample = torch.load("X_celebs_df_2.pt")
y_sample = torch.load("y_celebs_df_2.pt")

# The saved sample was built by concatenating real indices and then fake
# indices, so shuffle again (same fixed seed) to mix the two classes.
random.seed(0)
shuffle_mask = np.arange(len(X_sample))
random.shuffle(shuffle_mask)

X_sample, y_sample = X_sample[shuffle_mask], y_sample[shuffle_mask]

for loaded in (X_sample, y_sample):
    print(loaded.shape)


class ResNetClassifier:
    """Trainer/evaluator for a ``k``-way classifier on top of a ResNet backbone.

    The backbone is expected to emit 1000 features per sample, because the
    classification head is ``nn.Linear(1000, k)``. ``self.model`` produces raw
    logits; use ``predict`` for class indices and ``predict_proba`` for class
    probabilities.
    """

    def __init__(self, resnet_model: nn.Module = None, k:int=2, lr:float=1e-3, epochs: int=10, batch_size: int=64,
                 optimizer_func: torch.optim.Optimizer = torch.optim.Adam):
        """Assemble the network and store the training hyper-parameters.

        Args:
            resnet_model: Backbone module; required despite the ``None`` default.
            k: Number of output classes.
            lr: Learning rate handed to the optimizer.
            epochs: Number of passes over the training split in ``fit``.
            batch_size: Mini-batch size used in ``fit``.
            optimizer_func: Optimizer class, called as
                ``optimizer_func(parameters, lr)``.
        """
        assert resnet_model is not None, "resnet_model is required"
        # No Softmax layer here: nn.CrossEntropyLoss applies log-softmax itself,
        # so feeding it probabilities would apply softmax twice and flatten
        # the gradients.
        self.model = nn.Sequential(resnet_model, nn.Linear(1000, k))
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer_func = optimizer_func

    def eval(self, train_X: torch.Tensor, train_y: torch.Tensor, test_X: torch.Tensor, test_y: torch.Tensor):
        """Compute accuracy and cross-entropy loss on a train and a test split.

        The model is switched to eval mode and left there.

        Returns:
            ``(acc_train, acc_test, loss_train, loss_test)``; the losses are
            0-d tensors without gradient history.
        """
        self.model.eval()
        criterion = nn.CrossEntropyLoss()

        # One forward pass per split serves both the predictions and the loss;
        # no_grad avoids building an autograd graph over the whole split.
        with torch.no_grad():
            train_logits = self.model(train_X)
            test_logits = self.model(test_X)
            loss_train = criterion(train_logits, train_y)
            loss_test = criterion(test_logits, test_y)

        _, train_pred_y = torch.max(train_logits, dim=1)
        _, test_pred_y = torch.max(test_logits, dim=1)

        acc_train = accuracy_score(train_y, train_pred_y)
        acc_test = accuracy_score(test_y, test_pred_y)

        return acc_train, acc_test, loss_train.detach(), loss_test.detach()

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """Train the model on an 80/20 train/test split of ``(X, y)``.

        After every mini-batch the model is evaluated on the full train and
        test splits, and the four metrics are appended to ``train_losses``,
        ``test_losses``, ``train_acces`` and ``test_acces``.
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = self.optimizer_func(self.model.parameters(), self.lr)

        train_X, test_X, train_y, test_y = train_test_split(X, y, test_size= 0.2)
        train_tensor = TensorDataset(train_X, train_y)
        train_loader = DataLoader(train_tensor, batch_size=self.batch_size, shuffle=True)

        self.train_losses = []
        self.test_losses = []
        self.train_acces = []
        self.test_acces = []
        for i in range(self.epochs):
            for batch in train_loader:
                # self.eval() below leaves the model in eval mode, so training
                # mode must be re-enabled for every mini-batch. Otherwise
                # BatchNorm would stop updating after the first batch.
                self.model.train()

                # extract features (X) and labels (y)
                X_tensor, y_tensor = batch

                optimizer.zero_grad()

                # raw class scores (logits) for this mini-batch
                h_X = self.model(X_tensor)

                # compute cross entropy loss
                loss = criterion(h_X, y_tensor)

                # update the parameters with backpropagation and one
                # gradient-descent step
                loss.backward()
                optimizer.step()

                # Evaluating on the full splits after every mini-batch is
                # expensive but gives one curve point per "mini epoch".
                acc_train, acc_test, loss_train, loss_test = self.eval(train_X, train_y, test_X, test_y)

                # Store plain floats so the history is cheap to keep and plot.
                self.train_losses.append(loss_train.item())
                self.test_losses.append(loss_test.item())
                self.train_acces.append(acc_train)
                self.test_acces.append(acc_test)

            print("Epoch {} Complete: Train Loss = {}, Test Loss = {}, Train Accuracy = {}%, Test Accuracy = {}%".format(i, loss_train, loss_test, acc_train*100, acc_test*100))

    def plot_results(self):
        """Plot the loss and accuracy histories recorded by ``fit``.

        The x axis has one point per recorded mini-batch ("mini epoch").
        """
        n_epochs = len(self.train_losses)
        epochs = np.arange(n_epochs)
        plt.plot(epochs, self.train_losses, label="Train Loss", c='r')
        plt.plot(epochs, self.test_losses, label="Test Loss", c='b')
        plt.xlabel("Mini Epochs")
        plt.ylabel("Binary Cross Entropy Loss")
        plt.legend()
        plt.show()

        plt.plot(epochs, self.train_acces, label="Train Accuracy", c='r')
        plt.plot(epochs, self.test_acces, label="Test Accuracy", c='b')
        plt.xlabel("Mini Epochs")
        plt.ylabel("Accuracy")
        plt.legend()
        plt.show()

    def predict_proba(self, X):
        """Return the class probabilities (softmax of the logits) for ``X``."""
        self.model.eval()
        with torch.no_grad():
            return F.softmax(self.model(X), dim=1)

    def predict(self, X):
        """Return the index of the highest-scoring class for each row of ``X``."""
        self.model.eval()
        with torch.no_grad():
            h_X = self.model(X)
        # the arg-max of the logits equals the arg-max of the probabilities
        _, preds = torch.max(h_X, dim=1)
        return preds.detach()

# Build a ResNet-18 without pretrained weights (weights=None) and wrap it in
# the classifier/trainer defined above.
resnet18 = models.resnet18(weights=None)
resnet18_deepfake_classifier = ResNetClassifier(
    resnet_model=resnet18,
    lr=1e-4,
    epochs=10,
    batch_size=32,
)

# Only the first 1000 samples are used for this training run.
train_subset = slice(0, 1000)
resnet18_deepfake_classifier.fit(X_sample[train_subset], y_sample[train_subset])

resnet18_deepfake_classifier.plot_results()


class ConvEncoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of a latent code.

    The input passes through ``len(hidden_channels) + 1`` stride-2 convolutions,
    is flattened, and is projected by two separate linear layers to
    ``latent_dim``-sized ``mu`` and ``logvar`` vectors.
    """

    def __init__(self, in_channels=3, hidden_channels = [32, 64, 128], out_channels=256, input_height=64, input_width=64, latent_dim=128):
        super().__init__()

        # Every stride-2 convolution halves the spatial size.
        downscale = 2 ** (len(hidden_channels) + 1)
        output_height = input_height // downscale
        output_width = input_width // downscale

        def down_conv(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        # The first stage has no batch norm; the last one has neither batch
        # norm nor activation.
        stages = [down_conv(in_channels, hidden_channels[0]), nn.ReLU(True)]
        for c_in, c_out in zip(hidden_channels[:-1], hidden_channels[1:]):
            stages += [down_conv(c_in, c_out), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        stages.append(down_conv(hidden_channels[-1], out_channels))
        self.net = nn.Sequential(*stages)

        self.flatten = nn.Flatten()
        flat_features = out_channels * output_height * output_width
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        flat = self.flatten(self.net(x))
        return self.fc_mu(flat), self.fc_logvar(flat)

class ConvDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code back to an image.

    A linear layer expands the code to an ``in_channels`` x h x w feature map,
    which ``len(hidden_channels) + 1`` stride-2 transposed convolutions then
    upsample to ``out_channels`` channels.
    """

    def __init__(self, in_channels=256, hidden_channels = [128, 64, 32], out_channels=3, output_height=64, output_width=64, latent_dim=128):
        super().__init__()

        # Size of the feature map the transposed convolutions start from:
        # the target size divided by 2 once per upsampling stage.
        upscale = 2 ** (len(hidden_channels) + 1)
        self.in_channels = in_channels
        self.input_height = output_height // upscale
        self.input_width = output_width // upscale

        self.fc = nn.Linear(latent_dim, self.in_channels * self.input_height * self.input_width)

        def up_conv(c_in, c_out):
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        # The first stage has no batch norm; the last one has neither batch
        # norm nor activation.
        stages = [up_conv(in_channels, hidden_channels[0]), nn.ReLU(True)]
        for c_in, c_out in zip(hidden_channels[:-1], hidden_channels[1:]):
            stages += [up_conv(c_in, c_out), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        stages.append(up_conv(hidden_channels[-1], out_channels))
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Decode the latent batch ``z`` into an image batch."""
        feature_map = self.fc(z).view(-1, self.in_channels, self.input_height, self.input_width)
        return self.net(feature_map)

class VAELoss(nn.Module):
    """VAE objective: summed squared reconstruction error plus KL divergence.

    Both terms are summed over the whole batch and the total is divided by the
    batch size ``x.size(0)``.
    """

    def forward(self, x, x_hat, mu, logvar):
        """Return the per-sample average loss for a batch.

        x: input images
        x_hat: reconstructions of ``x``
        mu, logvar: latent mean and log-variance produced by the encoder
        """
        batch_size = x.size(0)

        squared_error = F.mse_loss(x_hat, x, reduction='sum')

        # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_divergence = -0.5 * kl_terms.sum()

        return (squared_error + kl_divergence) / batch_size

class ConvVAE(nn.Module):
    """Convolutional variational auto-encoder built from ConvEncoder and ConvDecoder."""

    def __init__(self, in_channels=3, hidden_channels = [32, 64, 128], out_channels = 256, height=64, width=64, latent_dim=128):
        super().__init__()

        self.encoder = ConvEncoder(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            input_height=height,
            input_width=width,
            latent_dim=latent_dim,
        )

        # The decoder mirrors the encoder: channel widths are reversed and the
        # roles of in_channels / out_channels are swapped.
        self.decoder = ConvDecoder(
            in_channels=out_channels,
            hidden_channels=hidden_channels[::-1],
            out_channels=in_channels,
            output_height=height,
            output_width=width,
            latent_dim=latent_dim,
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it; return ``(x_hat, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar


class VAEWrapper:
    """Training, scoring and sampling helper around a ``ConvVAE``."""

    # NOTE(review): the hidden_channels default below is the reverse of
    # ConvVAE's own default ([32, 64, 128]) and is passed to the encoder
    # unchanged -- confirm the ordering is intentional.
    def __init__(self, in_channels=3, hidden_channels = [128, 64, 32], out_channels = 256, height=64, width=64, latent_dim=128, lr:float=1e-3,
                 epochs: int=10, batch_size: int=64, optimizer_func: torch.optim.Optimizer = torch.optim.Adam):
        """Build the VAE and store the training hyper-parameters.

        Args:
            in_channels, hidden_channels, out_channels, height, width,
            latent_dim: Forwarded to ``ConvVAE`` in this order.
            lr: Learning rate handed to the optimizer.
            epochs: Number of passes over the data in ``fit``.
            batch_size: Mini-batch size used in ``fit``.
            optimizer_func: Optimizer class, called as
                ``optimizer_func(parameters, lr)``.
        """
        self.latent_dim = latent_dim
        self.model = ConvVAE(in_channels, hidden_channels, out_channels, height, width, self.latent_dim)
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer_func = optimizer_func

    def fit(self, X):
        """Train the VAE on the image tensor ``X`` and print the mean batch loss per epoch."""
        criterion = VAELoss()
        optimizer = self.optimizer_func(self.model.parameters(), self.lr)

        train_tensor = TensorDataset(X)
        train_loader = DataLoader(train_tensor, batch_size=self.batch_size, shuffle=True)

        for i in range(self.epochs):
            total_loss = 0
            self.model.train()

            for batch in train_loader:
                # TensorDataset yields 1-tuples; the images are the only entry.
                imgs = batch[0]

                optimizer.zero_grad()
                x_hat, mu, logvar = self.model(imgs)
                loss = criterion(imgs, x_hat, mu, logvar)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print("Epoch {} Complete: VAE Custom Loss = {}".format(i, total_loss / len(train_loader)))

    def reconstruction_loss_weights(self, X):
        """Return one weight per image, larger for poorly reconstructed images.

        The weights are a softmax (over the batch) of twice the per-image mean
        squared reconstruction error, so they are positive and sum to 1.

        Returns:
            Tensor of shape ``(len(X),)`` without gradient history.
        """
        self.model.eval()

        with torch.no_grad():
            X_hat, _, _ = self.model(X)

            # Average over channel and spatial dims -> one error per image.
            per_image_mse = (X - X_hat).pow(2).mean(dim=[1, 2, 3])

            # The softmax must run on the per-image errors; running it on the
            # raw 4-D error tensor would produce per-pixel weights instead.
            weights = F.softmax(2.0 * per_image_mse, dim=0)

        return weights

    def generate_images(self, z):
        """Decode the latent codes ``z`` (last dim ``latent_dim``) into images."""
        assert z.shape[-1] == self.latent_dim

        self.model.eval()

        with torch.no_grad():
            samples = self.model.decoder(z)

        return samples

vae = VAEWrapper()

# NOTE(review): real_imgs is not defined anywhere in the visible part of this
# file; presumably it holds the real-class images (label 0) -- confirm where it
# is created before running this cell.
vae.fit(real_imgs[:20000])

# Draw the latent codes from a standard normal (randn), the prior assumed by
# the KL term in VAELoss. torch.rand would sample uniformly from [0, 1) and
# probe only a small corner of the latent space.
samples = vae.generate_images(torch.randn((64, vae.latent_dim)))

def reconstruction_weighted_BCEloss(x_hat, x, y_prob, y):
    '''
    Binary cross entropy in which each image is weighted by how badly the VAE
    reconstructed it.

    x_hat: vae reconstruction
    x: input image
    y_prob: CNN classifier output class probabilities (0-1)
    y: class label for input image (integer or float; cast to y_prob's dtype)

    Returns a scalar tensor: the sum over the batch of per-image BCE times a
    softmax weight computed from twice the per-image mean squared
    reconstruction error.

    NOTE(review): assumes y_prob and y both have one entry per image (shape
    (N,)) so the per-image weights line up -- confirm against the caller.
    '''

    criterion = nn.BCELoss(reduction='none')

    # BCELoss rejects integer targets, and class labels are commonly stored as
    # LongTensor, so cast them to the dtype of the probabilities.
    targets = y.to(dtype=y_prob.dtype)

    bce_per_image = criterion(y_prob, targets)

    # One reconstruction error per image (mean over channel and spatial dims).
    recon_error = ((x - x_hat)**2).mean(dim=[1,2,3])

    # Softmax over the batch: weights are positive and sum to 1.
    weights = F.softmax(2.0 * recon_error, dim=0)

    weighted_loss = (bce_per_image * weights).sum()

    return weighted_loss
