## Imports

In [None]:
import itertools
import torch
import os
import random
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets.folder import make_dataset
from torchvision import transforms as t
import torch.nn as nn
import torchvision
from torchvision.models import ResNet50_Weights
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import random_split

## Utils Class

In [None]:
class TrainingUtilities:
    """Static helpers for saving and restoring training state with ``torch.save`` / ``torch.load``."""

    @staticmethod
    def save_model(model, model_descriptor, save_folder, verbose=0):
        """Save ``model.state_dict()`` as ``<model_descriptor>b1_model.pth`` inside ``save_folder``.

        Args:
            model: Module whose ``state_dict()`` is written.
            model_descriptor: Prefix placed in front of ``b1_model.pth``.
            save_folder: Destination directory.
            verbose: When greater than 0, print the path that was written.
        """
        # os.path.join avoids the accidental "/<file>" (filesystem root) that plain
        # concatenation produces when save_folder is an empty string.
        model_path = os.path.join(save_folder, f"{model_descriptor}b1_model.pth")
        torch.save(model.state_dict(), model_path)
        if verbose > 0:
            # Report the real file name, including the descriptor prefix.
            print(f"Saved model to {model_path}")

    @staticmethod
    def save_checkpoint(epoch, model_state_dict, optimizer_state_dict, scheduler_state_dict=None, save_folder='', verbose=0):
        """Save a resumable checkpoint as ``checkpoint-epoch<epoch>.pth`` inside ``save_folder``.

        Args:
            epoch: Epoch number stored in the checkpoint and used in the file name.
            model_state_dict: State dict of the model.
            optimizer_state_dict: State dict of the optimizer.
            scheduler_state_dict: Optional state dict of a learning-rate scheduler.
            save_folder: Destination directory; an empty string means the current
                working directory.
            verbose: When greater than 0, print the path that was written.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'scheduler_state_dict': scheduler_state_dict
        }
        checkpoint_path = os.path.join(save_folder, f'checkpoint-epoch{epoch}.pth')
        torch.save(checkpoint, checkpoint_path)
        if verbose > 0:
            print(f'Saved checkpoint to {checkpoint_path}')

    @staticmethod
    def load_checkpoint(model, optimizer, checkpoint_path, scheduled, verbose=0, scheduler=None):
        """Restore model/optimizer (and optionally scheduler) state from a checkpoint file.

        Args:
            model: Module that receives ``checkpoint['model_state_dict']`` in place.
            optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']`` in place.
            checkpoint_path: File written by :meth:`save_checkpoint`.
            scheduled: Whether a scheduler state should be restored as well.
            verbose: When greater than 0, print progress messages.
            scheduler: Optional scheduler object; its state is restored only when
                ``scheduled`` is truthy and the checkpoint holds a scheduler state.

        Returns:
            Tuple ``(epoch, model, optimizer)`` where ``model`` and ``optimizer`` are
            the same objects that were passed in, now carrying the restored state.
        """
        if verbose > 0:
            print(f"Loading checkpoint from {checkpoint_path}")

        # Load onto the CPU so a checkpoint written on a GPU machine can be read on a
        # CPU-only one; load_state_dict copies the values into the existing
        # parameters, so the model's own device placement is unaffected.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # load_state_dict works in place. Its return value is a key report for
        # modules and None for optimizers, so it must not replace the objects.
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if scheduled:
            scheduler_state_dict = checkpoint.get('scheduler_state_dict')
            if scheduler is not None and scheduler_state_dict is not None:
                scheduler.load_state_dict(scheduler_state_dict)
            elif verbose > 0:
                print("Scheduler state not restored: no scheduler object given or no scheduler state saved")

        return checkpoint['epoch'], model, optimizer

## Dataset

In [None]:

def get_samples(root, extensions=(".mp4", ".avi")):
    """Collect ``(file_path, label)`` pairs for the videos found under ``root``.

    Two class folders are scanned: ``DFD_original_sequences`` (label 0) directly
    under ``root``, and ``DFD_manipulated_sequences`` (label 1), whose files are
    looked up one level deeper in a sub-folder of the same name. A class folder
    that does not exist is skipped.

    Args:
        root: Dataset root directory.
        extensions: File-name suffixes to accept; any iterable of strings.

    Returns:
        List of ``(file_path, label)`` tuples, ordered by class and then by file
        name so that the result is reproducible across runs and machines.
    """
    # str.endswith only accepts a str or a tuple, so normalise lists/sets.
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)

    # Define class labels
    class_to_idx = {
        "DFD_original_sequences": 0,  # Real videos
        "DFD_manipulated_sequences": 1  # Deepfake videos
    }

    samples = []
    for class_name, label in class_to_idx.items():
        class_dir = os.path.join(root, class_name)
        if class_name == 'DFD_manipulated_sequences':
            # The manipulated videos sit in a nested folder with the same name.
            class_dir = os.path.join(class_dir, class_name)
        print(class_dir)
        if not os.path.exists(class_dir):
            continue

        # os.listdir returns entries in arbitrary order; sort for determinism.
        samples.extend(
            (os.path.join(class_dir, filename), label)
            for filename in sorted(os.listdir(class_dir))
            if filename.endswith(suffixes)
        )

    return samples


class RandomDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields randomly positioned clips from randomly chosen videos.

    Each iteration step picks one ``(path, label)`` entry from ``self.samples`` at
    random, seeks to a random start time and reads up to ``clip_len`` consecutive
    frames. One pass over the dataset yields ``epoch_size`` clips.

    Args:
        root: Dataset root passed to ``get_samples``.
        epoch_size: Number of clips yielded per pass; defaults to the number of
            videos found.
        frame_transform: Optional callable applied to every decoded frame.
        video_transform: Optional callable applied to the stacked clip tensor.
        clip_len: Maximum number of frames per clip.

    Yields:
        Dict with ``'video'`` (frames stacked along dimension 0) and ``'target'``
        (the integer class label).

    Note:
        A video with fewer than ``clip_len`` frames after the seek point yields a
        shorter clip; clips of different lengths cannot be batched by the default
        collate function.
    """

    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        # Zero-argument super(): the one-argument form used previously creates an
        # unbound super object and never reaches the parent initializer.
        super().__init__()

        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __len__(self):
        """Return the number of clips yielded by one pass (single-process loading)."""
        return self.epoch_size

    def __iter__(self):
        for _ in range(self.epoch_size):
            # Get random sample
            path, target = random.choice(self.samples)
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            stream_info = vid.get_metadata()["video"]

            # Latest start time that still leaves room for a full clip; clamp at 0
            # so videos shorter than the clip do not produce a negative seek range.
            clip_seconds = self.clip_len / stream_info['fps'][0]
            max_seek = max(0.0, stream_info['duration'][0] - clip_seconds)
            start = random.uniform(0., max_seek)

            video_frames = []  # video frame buffer
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                data = frame['data']
                # frame_transform is optional, mirroring video_transform below.
                if self.frame_transform is not None:
                    data = self.frame_transform(data)
                video_frames.append(data)

            # Stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            yield {
                'video': video,
                'target': target
            }

## Model

In [None]:

class CNN_LSTM(nn.Module):
    """Frame-wise ResNet-50 encoder followed by an LSTM and a small classifier head.

    Args:
        hidden_dim: Hidden size of the LSTM and input width of the head.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
    """

    def __init__(self, hidden_dim=512, num_layers=3, num_classes=2):
        super().__init__()
        backbone = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Keep every child module of the backbone except the last one
        # (presumably the ImageNet classification layer).
        backbone_children = list(backbone.children())
        self.feature_extractor = nn.Sequential(*backbone_children[:-1])
        # Global average pooling down to a 1x1 spatial map.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.lstm = nn.LSTM(
            input_size=2048,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        head_layers = [
            nn.Linear(hidden_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Classify clips given as a (Batch, Frames, C, H, W) tensor; returns (Batch, num_classes) logits."""
        n_clips, n_frames, channels, height, width = x.shape

        # Fold the frame axis into the batch axis so the CNN sees single images.
        flat_frames = x.view(n_clips * n_frames, channels, height, width)
        pooled = self.pool(self.feature_extractor(flat_frames))
        per_frame = pooled.squeeze(-1).squeeze(-1)  # (Batch x Frames, Feature_Dim)

        # Restore the frame axis to form one feature sequence per clip.
        sequence = per_frame.view(n_clips, n_frames, -1)
        lstm_out, _ = self.lstm(sequence)

        # Only the output at the final time step feeds the classifier head.
        last_step = lstm_out[:, -1, :]
        return self.fc(last_step)

## Trainer

In [None]:
class ModelTrainer:
    """Run the train / validate / test cycle for a classification model.

    Every batch produced by the data loaders is expected to be a dict with a
    ``'video'`` tensor and a ``'target'`` tensor, as read by the loops below.
    """

    def __init__(self, model, optimizer, scheduled, criterion, epochs, dataloaders, device, save_folder,
                 is_continue=False, checkpoint=None):
        """Store the training configuration.

        Args:
            model: Module to train.
            optimizer: Optimizer updating ``model``'s parameters.
            scheduled: Flag forwarded to ``TrainingUtilities.load_checkpoint``.
            criterion: Loss function called as ``criterion(outputs, target)``.
            epochs: Total number of epochs to train for.
            dataloaders: Dict with ``'train'``, ``'val'`` and ``'test'`` loaders.
            device: Device that batches are moved to.
            save_folder: Directory where the final model is written.
            is_continue: Resume from ``checkpoint`` when true.
            checkpoint: Path of the checkpoint to resume from.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduled = scheduled
        self.criterion = criterion
        self.epochs = epochs
        self.dataloaders = dataloaders
        self.DEVICE = device
        self.save_folder = save_folder
        self.is_continue = is_continue
        self.checkpoint = checkpoint

    def train_model(self, verbose=0):
        """Train and validate for the configured epochs, then run the test pass.

        Args:
            verbose: When greater than 0, print resume information.
        """
        start_epoch = 0
        if self.is_continue:
            if verbose > 0:
                print(f"Continuing from checkpoint {self.checkpoint}")

            # The checkpoint is loaded into self.model / self.optimizer in place, so
            # only the stored epoch is taken from the return value.
            start_epoch = TrainingUtilities.load_checkpoint(
                self.model, self.optimizer, self.checkpoint, self.scheduled, verbose)[0]

        for training_epoch in range(start_epoch, self.epochs):
            avg_train_loss, train_accuracy = self._train_one_epoch(self.dataloaders['train'], training_epoch)
            avg_val_loss, val_accuracy = self.evaluate(self.dataloaders['val'], training_epoch)

            # Report the epoch that just finished against the configured total.
            print(
                f"Epoch [{training_epoch + 1}/{self.epochs}], Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

        # Test model
        self.test(self.dataloaders['test'])

    def _train_one_epoch(self, train_loader, epoch):
        """Run one optimisation pass over ``train_loader``; return (mean batch loss, accuracy)."""
        model, optimizer, criterion, device = self.model, self.optimizer, self.criterion, self.DEVICE
        model.train()

        loss_sum = 0.0
        correct = 0
        seen = 0
        num_batches = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{self.epochs} - Training"):
            video = batch['video'].to(device)  # (Batch, Frames, C, H, W)
            target = batch['target'].to(device)

            optimizer.zero_grad()
            outputs = model(video)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == target).sum().item()
            seen += target.size(0)
            num_batches += 1

        return self._summarise(loss_sum, num_batches, correct, seen)

    def _run_inference(self, loader, desc):
        """Evaluate the model on ``loader`` without gradients; return (mean batch loss, accuracy)."""
        model, criterion, device = self.model, self.criterion, self.DEVICE
        model.eval()

        loss_sum = 0.0
        correct = 0
        seen = 0
        num_batches = 0
        with torch.no_grad():
            for batch in tqdm(loader, desc=desc):
                video = batch['video'].to(device)
                target = batch['target'].to(device)

                outputs = model(video)
                loss_sum += criterion(outputs, target).item()

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == target).sum().item()
                seen += target.size(0)
                num_batches += 1

        return self._summarise(loss_sum, num_batches, correct, seen)

    @staticmethod
    def _summarise(loss_sum, num_batches, correct, seen):
        """Turn running totals into (mean batch loss, accuracy), tolerating empty loaders."""
        # Batches are counted while iterating instead of calling len(loader), which
        # is undefined for iterable datasets that do not implement __len__.
        avg_loss = loss_sum / num_batches if num_batches else 0.0
        accuracy = correct / seen if seen else 0.0
        return avg_loss, accuracy

    def test(self, test_loader):
        """Evaluate on ``test_loader``, print the metrics and save the final model."""
        avg_test_loss, test_accuracy = self._run_inference(test_loader, "Testing")
        print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
        TrainingUtilities.save_model(self.model, 'final_', self.save_folder)

    def evaluate(self, val_loader, epoch, verbose=0):
        """Evaluate on ``val_loader`` for the given zero-based epoch.

        Returns:
            Tuple ``(avg_val_loss, val_accuracy)``.
        """
        return self._run_inference(val_loader, f"Epoch {epoch + 1}/{self.epochs} - Validation")


## Code

In [None]:
dataset_root = "/kaggle/input/deep-fake-detection-dfd-entire-original-dataset"

# RandomDataset applies frame_transform to every decoded frame, so one has to be
# supplied. Frames are assumed to arrive as uint8 (C, H, W) tensors - TODO confirm
# against the installed torchvision. They are converted to float, resized to a
# fixed size so clips can be batched, and normalised with the ImageNet statistics
# matching the pretrained ResNet-50 backbone.
frame_transform = t.Compose([
    t.ConvertImageDtype(torch.float32),
    t.Resize((224, 224)),
    t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = RandomDataset(root=dataset_root, clip_len=16, frame_transform=frame_transform)

In [None]:
# Define split sizes from the number of discovered videos. The list of samples is
# counted directly because len() is not guaranteed for an iterable-style dataset.
total_videos = len(dataset.samples)
train_size = int(0.7 * total_videos)  # 70% for training
val_size = int(0.15 * total_videos)   # 15% for validation
test_size = total_videos - train_size - val_size  # 15% for testing

In [None]:
# Split dataset
# random_split needs a map-style dataset (len + indexing), which an IterableDataset
# is not. Instead the list of videos is shuffled once and partitioned, and each
# partition gets its own shallow copy of the dataset, so no video is shared
# between the train, validation and test splits.
import copy

shuffled_samples = random.sample(dataset.samples, len(dataset.samples))
val_end = train_size + val_size


def _make_split(split_samples):
    """Return a copy of ``dataset`` restricted to ``split_samples``."""
    split = copy.copy(dataset)
    split.samples = split_samples
    split.epoch_size = len(split_samples)
    return split


train_dataset = _make_split(shuffled_samples[:train_size])
val_dataset = _make_split(shuffled_samples[train_size:val_end])
test_dataset = _make_split(shuffled_samples[val_end:val_end + test_size])

# Create DataLoaders
# DataLoader rejects shuffle=True for iterable datasets; the dataset already draws
# its videos at random. drop_last avoids a final training batch of size 1, which
# BatchNorm1d cannot process in training mode.
train_loader = DataLoader(train_dataset, batch_size=4, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=4, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=4, num_workers=0)

dataloaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the network on the selected device, then the loss and the optimizer.
model = CNN_LSTM()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Directory that receives the saved model; the trainer is configured for a fresh
# 10-epoch run (scheduled=False, no checkpoint to resume from).
save_folder = '/kaggle/working/'
trainer = ModelTrainer(
    model=model,
    optimizer=optimizer,
    scheduled=False,
    criterion=criterion,
    epochs=10,
    dataloaders=dataloaders,
    device=device,
    save_folder=save_folder,
)

In [None]:
# Run the train/validation loop; when it finishes, train_model evaluates on the
# test loader, which also saves the final model weights.
trainer.train_model(verbose=1)