In [3]:
!git clone https://github.com/kiyoshi2000/automathon-2024-B.git

Cloning into 'automathon-2024-B'...
remote: Enumerating objects: 425, done.[K
remote: Counting objects: 100% (169/169), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 425 (delta 163), reused 160 (delta 160), pack-reused 256[K
Receiving objects: 100% (425/425), 2.19 MiB | 23.57 MiB/s, done.
Resolving deltas: 100% (238/238), done.


In [8]:
!pip install av

Collecting av
  Downloading av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.6 kB)
Downloading av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (33.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.8/33.8 MB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hInstalling collected packages: av
Successfully installed av-12.0.0


In [2]:
!pip install -r automathon-2024-B/requirements.txt

Collecting appnope==0.1.4 (from -r automathon-2024-B/requirements.txt (line 2))
  Downloading appnope-0.1.4-py2.py3-none-any.whl.metadata (908 bytes)
Collecting comm==0.2.2 (from -r automathon-2024-B/requirements.txt (line 7))
  Downloading comm-0.2.2-py3-none-any.whl.metadata (3.7 kB)
Collecting contourpy==1.2.1 (from -r automathon-2024-B/requirements.txt (line 8))
  Downloading contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.8 kB)
Collecting debugpy==1.8.1 (from -r automathon-2024-B/requirements.txt (line 10))
  Downloading debugpy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Collecting exceptiongroup==1.2.1 (from -r automathon-2024-B/requirements.txt (line 13))
  Downloading exceptiongroup-1.2.1-py3-none-any.whl.metadata (6.6 kB)
Collecting filelock==3.13.4 (from -r automathon-2024-B/requirements.txt (line 15))
  Downloading filelock-3.13.4-py3-none-any.whl.metadata (2.8 kB)
Collecting fonttools==4.51.0 (fr

In [4]:
import csv
import json
import os
from math import ceil, floor

import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.io as io
import torchvision.models.detection.keypoint_rcnn as keypoint_rcnn
import torchvision.transforms as transforms
import wandb
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
from tqdm import tqdm

In [9]:
class VideoDataset(Dataset):
    """
    Dataset that yields a fixed number of frames sampled from each ``.mp4`` video.

    Every item is built from ``nb_frames`` frames taken at a constant stride
    from the decoded clip, permuted with ``(0, 3, 1, 2)`` (channels-first,
    assuming ``read_video`` returns frames as [T, H, W, C]) and divided by 255.
    Frames are NOT resized here: the spatial size is whatever the video and the
    optional ``trans`` callable produce.

    Args:
        root_dir: folder containing ``dataset.csv`` and the ``dataset/`` tree.
        dataset_choice: ``"train"``, ``"test"`` or ``"experimental"``.
        nb_frames: number of frames sampled from each video.
        trans: optional callable applied to the sampled frame stack.
        device: stored on the instance; the dataset itself does not use it.

    Raises:
        ValueError: if ``dataset_choice`` is not one of the supported values.
    """

    # Sub-folder (relative to ``root_dir``) holding the videos of each split.
    _SUBDIRS = {
        "train": "dataset/train_dataset",
        "test": "dataset/test_dataset",
        "experimental": "dataset/experimental_dataset",
    }

    def __init__(
        self, root_dir, dataset_choice="train", nb_frames=10, trans=None, device="cpu"
    ):
        super().__init__()
        if dataset_choice not in self._SUBDIRS:
            raise ValueError("choice must be 'train', 'test' or 'experimental'")

        self.device = device
        self.dataset_choice = dataset_choice
        self.transforms = trans
        # Kept on the instance so __getitem__ does not depend on a notebook global.
        self.nb_frames = nb_frames
        self.root_dir = os.path.join(root_dir, self._SUBDIRS[dataset_choice])

        with open(os.path.join(root_dir, "dataset.csv"), 'r', newline='') as file:
            # Map the second CSV column to the first one. __getitem__ looks this
            # dict up by video file name, so the second column presumably holds
            # file names and the first one the ids -- TODO confirm against the CSV.
            # Rows with fewer than two cells (e.g. blank lines) are skipped.
            self.ids = {row[1]: row[0] for row in csv.reader(file) if len(row) >= 2}

        if self.dataset_choice == "test":
            # No metadata.json is read for the test split.
            self.data = None
        else:
            with open(os.path.join(self.root_dir, "metadata.json"), 'r') as file:
                raw_labels = json.load(file)
            # 'FAKE' -> 1.0, anything else -> 0.0, as 0-d float tensors.
            self.data = {
                name: torch.tensor(1.0 if tag == 'FAKE' else 0.0)
                for name, tag in raw_labels.items()
            }

        # Sorted so that a given index always refers to the same video;
        # os.listdir gives no ordering guarantee.
        self.video_files = sorted(
            f for f in os.listdir(self.root_dir) if f.endswith('.mp4')
        )

    def __len__(self):
        """Return the number of ``.mp4`` files found in the split folder."""
        return len(self.video_files)

    def __getitem__(self, idx):
        """
        Load, sub-sample and transform the video at position ``idx``.

        Returns:
            ``(video, ID)`` for the test split, ``(video, label, ID)`` otherwise.

        Raises:
            ValueError: if no frame could be decoded from the file.
            RuntimeError: if the configured transforms fail; the original
                exception is attached as ``__cause__``.
        """
        name = self.video_files[idx]
        video_path = os.path.join(self.root_dir, name)
        video, _audio, _info = io.read_video(video_path, pts_unit='sec')

        video = video.permute(0, 3, 1, 2)
        length = video.shape[0]
        if length == 0 and self.nb_frames > 0:
            raise ValueError(f"no frames could be decoded from {video_path!r}")

        # Constant-stride sampling starting at frame 0. When the clip has fewer
        # than nb_frames frames the stride is 0 and frame 0 is repeated.
        frame_ids = [i * (length // self.nb_frames) for i in range(self.nb_frames)]
        video = video[frame_ids]

        # Scale pixel values by 1/255 (no resizing is performed).
        video = video / 255

        try:
            video = self._apply_transforms(video)
        except Exception as exc:
            # Fail loudly and say which file broke the pipeline, keeping the
            # root cause chained, instead of a context-free assertion error.
            raise RuntimeError(
                f"transforms failed for video {video_path!r}"
            ) from exc

        ID = self.ids[name]
        if self.dataset_choice == "test":
            return video, ID
        label = self.data[name]
        return video, label, ID

    def _apply_transforms(self, stack):
        """Apply the configured transforms to the stack of frames (no-op if None)."""
        if self.transforms is None:
            return stack
        return self.transforms(stack)
    


In [10]:
class Trainer:
    """
    Minimal training loop around a model, a dataloader, a loss and an optimizer.

    Args:
        model: module called as ``model(vid)``.
        train_dataloader: iterable yielding ``(vid, label, id)`` triples.
        loss_fn: callable ``loss_fn(prediction, label)`` returning a scalar tensor.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        scheduler: optional object with ``step()``, called once per epoch;
            pass ``None`` to disable.
        device: device the batches are moved to before the forward pass.
    """

    def __init__(self, model, train_dataloader, loss_fn, optimizer, scheduler, device):
        self.model = model
        self.dataloader = train_dataloader
        self.loss_fn = loss_fn
        self.optim = optimizer
        self.sched = scheduler
        self.device = device

    def train_one_epoch(self):
        """Run one pass over the dataloader and return the summed batch losses."""
        # Make sure layers such as dropout / batch-norm are in training mode.
        self.model.train()

        total_loss = 0.0
        for i, (vid, label, _) in enumerate(self.dataloader):
            # Use the device given to the constructor, not a notebook global.
            vid = vid.to(self.device)
            # Labels are cast to integer class indices before the loss call.
            label = label.long().to(self.device)

            y_pred = self.model(vid)
            loss = self.loss_fn(y_pred, label)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            print(f"idx: {i} - loss:{batch_loss}")
        return total_loss

    def train(self, epochs):
        """Train for ``epochs`` epochs, stepping the scheduler after each one."""
        for i in range(epochs):
            epoch_loss = self.train_one_epoch()
            print(f"epoch: #{i} | loss: {epoch_loss}")
            if self.sched is not None:
                self.sched.step()

                
class CropFaces():
    """
    Transform that crops a square window around the face in every frame.

    A torchvision Keypoint R-CNN is run on the frame stack; for each frame the
    first returned detection is used and a window of ``outshape`` pixels per
    side is centred on its first five keypoints.

    Args:
        outshape: side length, in pixels, of the square crop.
        device: device the detection model is placed on.
    """

    def __init__(self, outshape=315, device="cpu"):
        self.model = keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True)
        self.model.to(device)
        self.model.eval()

        self._outshape = outshape
        self.device = device

    def __call__(self, stack):
        """Crop every frame of ``stack``; see ``_crop_stack``."""
        return self._crop_stack(stack)

    def _crop_frame(self, frame, keypoints):
        """
        Cut an ``outshape`` x ``outshape`` window around the face keypoints.

        Args:
            frame: tensor indexed as [channel, y, x].
            keypoints: sequence whose items expose x at index 0 and y at index 1.

        Returns:
            The cropped view ``frame[:, top:bottom, left:right]``.

        Raises:
            ValueError: if the keypoints span more than ``outshape`` pixels, or
                if the frame is smaller than the requested window.
        """
        # Indices corresponding to face keypoints (presumably nose, eyes and
        # ears in the COCO ordering used by the model -- TODO confirm).
        face_keypoints_indices = [0, 1, 2, 3, 4]

        x_coords = [keypoints[i][0] for i in face_keypoints_indices]
        y_coords = [keypoints[i][1] for i in face_keypoints_indices]

        # Tight bounding box of the face keypoints.
        xmin = int(min(x_coords))
        xmax = int(max(x_coords))
        ymin = int(min(y_coords))
        ymax = int(max(y_coords))

        # Padding needed on each axis to grow the box to the output size.
        xpad = (self._outshape - (xmax - xmin)) / 2
        ypad = (self._outshape - (ymax - ymin)) / 2
        if xpad < 0 or ypad < 0:
            raise ValueError(
                f"face keypoints span more than the {self._outshape}px crop window"
            )

        top = ymin - floor(ypad)
        bottom = ymax + ceil(ypad)
        left = xmin - floor(xpad)
        right = xmax + ceil(xpad)

        # A window that sticks out of the frame would produce negative or
        # overflowing slice bounds, i.e. an empty or undersized crop that later
        # breaks torch.stack. Slide it back inside the frame instead.
        height, width = frame.shape[-2], frame.shape[-1]
        top, bottom = self._fit_window(top, bottom, height)
        left, right = self._fit_window(left, right, width)
        return frame[:, top:bottom, left:right]

    @staticmethod
    def _fit_window(start, stop, limit):
        """Shift ``[start, stop)`` to lie within ``[0, limit)`` without resizing it."""
        if start < 0:
            start, stop = 0, stop - start
        if stop > limit:
            start, stop = start - (stop - limit), limit
        if start < 0:
            raise ValueError("frame is smaller than the requested crop window")
        return start, stop

    def _crop_stack(self, stack):
        """
        Detect keypoints on every frame of ``stack`` and crop each frame.

        Raises:
            ValueError: if the detector returns no detection for a frame, or a
                frame cannot be cropped (see ``_crop_frame``).
        """
        with torch.no_grad():
            # The detector may live on another device than the incoming frames.
            batch_outputs = self.model(stack.to(self.device))
        out = []
        for i, outputs in enumerate(batch_outputs):
            detections = outputs['keypoints']
            if len(detections) == 0:
                raise ValueError(f"no detection in frame {i}")
            # Only the first detection of each frame is used.
            out.append(self._crop_frame(stack[i], detections[0]))
        return torch.stack(out)
            

In [11]:
# --- Configuration ---------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset_dir = "/kaggle/input/automathon-deepfake"
nb_frames   = 10

# Seed torch so shuffling, random flips and the new classifier head are
# repeatable across Restart & Run All.
torch.manual_seed(0)

# --- Data --------------------------------------------------------------------
# Per-video pipeline: grayscale -> random horizontal flip -> face crop ->
# squeeze, which drops the size-1 channel axis left by Grayscale so the
# frames can act as the model's input channels.
trans = transforms.Compose([
        transforms.Grayscale(),
        transforms.RandomHorizontalFlip(),
        CropFaces(),
        torch.squeeze
    ])

experimental_dataset = VideoDataset(
    dataset_dir,
    dataset_choice="experimental",
    nb_frames=nb_frames,
    trans=trans
)

train_dataloader = DataLoader(
    experimental_dataset,
    batch_size=1,
    shuffle=True,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory=(device == "cuda"),
)

# --- Model & optimisation ------------------------------------------------
# The sampled frames are fed as input channels, hence in_chans=nb_frames.
model = timm.create_model("resnet18", pretrained=True, num_classes=2, in_chans=nb_frames)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

trainer = Trainer(model, train_dataloader, loss_fn, optimizer, None, device)

# --- Train -------------------------------------------------------------------
# (A leftover breakpoint() used to sit here; it blocks unattended runs.)
trainer.train(1)

ImportError: PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
