In [83]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import cv2
import os

In [7]:
train_csv_path = "../../dfdc_dataset/train.csv"
train_data_path = "../../dfdc_dataset/archive/train/"

test_csv_path = "../../dfdc_dataset/test.csv"
test_data_path = "../../dfdc_dataset/archive/validation/"

In [45]:
class DFDCDataset(torch.utils.data.Dataset):
    
    def __init__(self, data_csv, data_folder):
        self.data = pd.read_csv(data_csv)
        self.data_path = data_folder
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, i):
        img_path = self.data["path"][i]
        label = self.data["label"][i]
        label = 0 if label == "real" else 1 
        
        img_path = os.path.join(self.data_path, img_path)
        img = cv2.imread(img_path) / 255.
        
        return img, label

In [46]:
train_dataset = DFDCDataset(train_csv_path, train_data_path)
test_dataset = DFDCDataset(test_csv_path, test_data_path)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

MesoNet

In [95]:
class Meso(torch.nn.Module):
    def __init__(self):
        super(Meso, self).__init__()
        self.conv1 = nn.Conv2d(3,8,3,padding=1)
        self.conv2 = nn.Conv2d(8,8,5,padding=2)
        self.conv3 = nn.Conv2d(8,16,5,padding=2)
        self.conv4 = nn.Conv2d(16,16,5,padding=2)
        self.fc1 = nn.Linear(1600, 16)
        self.fc2 = nn.Linear(16,2 )
        
    def forward(self, data):
    
        x = self.conv1(data)
        x = F.batch_norm(F.relu(x),torch.zeros(8),torch.ones(8))
        x = F.max_pool2d(x, kernel_size = 2)
            
        x = self.conv2(x)
        x = F.batch_norm(F.relu(x),torch.zeros(8),torch.ones(8))
        x = F.max_pool2d(x, kernel_size = 2)
        
        x = self.conv3(x)
        x = F.batch_norm(F.relu(x),torch.zeros(16),torch.ones(16))
        x = F.max_pool2d(x, kernel_size = 2)
        
        x = self.conv4(x)
        x = F.batch_norm(F.relu(x),torch.zeros(16),torch.ones(16))
        x = F.max_pool2d(x, kernel_size = 2)
        
        x = torch.flatten(x, 1)
        
        x = F.dropout2d(x, p=0.5)
        x = F.relu(self.fc1(x))
        
        x = F.dropout2d(x, p=0.5)
        x = self.fc2(x)
        
        return x 

In [93]:
model = Meso()

In [94]:
model

Meso(
  (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=1600, out_features=16, bias=True)
  (fc2): Linear(in_features=16, out_features=2, bias=True)
)

In [99]:
optimizer = torch.optim.Adam(meso.parameters(), lr = 3e-4)
loss_fn = nn.CrossEntropyLoss()

In [100]:
for e in range(10):
    
    train_loop = tqdm(train_dataloader, leave=False, position=0)
    total_corr = 0
    total = 0
    total_loss = 0.
    model.train()
    
    train_loop.set_description(f"Epoch {e+1}")
    
    for batch_num, (imgs, labels) in enumerate(train_loop):
        optimizer.zero_grad()
        
        imgs = imgs.float().transpose(1, 3)
        
        preds = model(imgs)
        
        loss = loss_fn(preds, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        corr_preds = (torch.argmax(preds, dim=-1) == labels).sum().item()
        
        total += len(imgs)
        total_corr += corr_preds
        
        train_loop.set_postfix(loss = loss.item(), acc = corr_preds / len(imgs) * 100)
    
    print(f"TRAIN EPOCH {e + 1} LOSS = {total_loss / total} ACC = {total_corr / total * 100.}")
    
    
    test_loop = tqdm(test_dataloader, leave=False, position=0)
    total_corr = 0
    total = 0
    total_loss = 0.
    model.eval()
    
    test_loop.set_description(f"Epoch {e+1}")
    
    for batch_num, (imgs, labels) in enumerate(test_loop):
        
        imgs = imgs.float().transpose(1, 3)
        
        preds = model(imgs)
        
        loss = loss_fn(preds, labels)
        
        total_loss += loss.item()
        corr_preds = (torch.argmax(preds, dim=-1) == labels).sum().item()
        
        total += len(imgs)
        total_corr += corr_preds
        
        test_loop.set_postfix(loss = loss.item(), acc = corr_preds / len(imgs) * 100)
    
    print(f"TEST EPOCH {e + 1} LOSS = {total_loss / total} ACC = {total_corr / total * 100.}")
          

                                                                                                                       

KeyboardInterrupt: 