In [1]:
import collections
import ignite
import numpy as np
import os
import sys
import torch

from torch import FloatTensor, LongTensor
from torchvision import transforms as T
from typing import Callable, Iterable, List, Tuple

# Print the installed torch and ignite versions next to the notebook output.
print(f"torch: {torch.__version__}, ignite: {ignite.__version__}")

torch: 1.3.1, ignite: 0.3.0


In [2]:
# Overwrite CUDA_VISIBLE_DEVICES for this process with '1'.
# NOTE(review): this runs after `import torch`; presumably fine as long as no
# CUDA call has happened yet - confirm.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [3]:
# Machine-specific locations. Each root can be overridden through an
# environment variable (DFDC_BASE_DIR / DFDC_HDF5_DIR) so the notebook runs on
# another host without editing this cell; the defaults are the original paths.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
SRC_DIR = os.path.join(BASE_DIR, 'src')
HDF5_DIR = os.environ.get('DFDC_HDF5_DIR', '/media/dmitry/other/dfdc-crops/hdf5')

In [4]:
# Put the project sources and the vendored Pytorch_Retinaface checkout at the
# front of the import path; SRC_DIR ends up first, the vendor directory second.
sys.path[:0] = [
    SRC_DIR,
    os.path.join(BASE_DIR, 'vendors/Pytorch_Retinaface'),
]

In [5]:
# Make the local ./utils directory importable (it is searched first).
sys.path[:0] = ['./utils']

In [6]:
from dataset.hdf5 import HDF5Dataset
from dataset.sample import FrameSampler, BalancedSampler
from model.detector import basic_detector_256
from model.loss import combined_loss

In [7]:
from visualise import show_images

In [8]:
def create_dataloader(bs: int, num_frames: int, real_fake_ratio: float,
                      p_sparse_frames: float, chunks: Iterable[int]
                     ) -> torch.utils.data.DataLoader:
    """Build a DataLoader over the HDF5 crops stored under ``HDF5_DIR``.

    Args:
        bs: batch size handed to the ``BatchSampler``.
        num_frames: forwarded to ``FrameSampler`` and used as the first
            element of the dataset ``size`` tuple (the second is fixed at 256).
        real_fake_ratio: forwarded to ``FrameSampler`` as ``real_fake_ratio``.
        p_sparse_frames: forwarded to ``FrameSampler`` as ``p_sparse``.
        chunks: indices ``i`` selecting the ``dfdc_train_part_{i}``
            sub-directories to read.

    Returns:
        A DataLoader whose batches are produced by a ``BatchSampler`` wrapped
        around ``BalancedSampler``; an incomplete final batch is dropped.

    Side effects:
        Prints the number of samples in the dataset.
    """
    sub_dirs = [f'dfdc_train_part_{i}' for i in chunks]
    frame_sampler = FrameSampler(
        num_frames,
        real_fake_ratio=real_fake_ratio,
        p_sparse=p_sparse_frames,
    )
    # The only input transform is the conversion to a tensor.
    to_tensor = T.Compose([T.ToTensor()])
    dataset = HDF5Dataset(
        HDF5_DIR,
        size=(num_frames, 256),
        sampler=frame_sampler,
        x_tfms=to_tensor,
        sub_dirs=sub_dirs,
    )
    print('Num samples: {}'.format(len(dataset)))

    batches = torch.utils.data.BatchSampler(
        BalancedSampler(dataset), batch_size=bs, drop_last=True)
    return torch.utils.data.DataLoader(dataset, batch_sampler=batches)

In [9]:
# Training loader: parts 5..29, sparse-frame probability 0.75.
train_dl = create_dataloader(
    bs=12, num_frames=10,
    real_fake_ratio=100 / 30,
    p_sparse_frames=0.75,
    chunks=range(5, 30),
)

Num samples: 61779


In [10]:
# Validation loader: parts 0..4 (disjoint from the training parts),
# sparse-frame probability 1.
valid_dl = create_dataloader(
    bs=12, num_frames=10,
    real_fake_ratio=100 / 30,
    p_sparse_frames=1.,
    chunks=range(0, 5),
)

Num samples: 7937


In [11]:
# Build the detector through the project's `basic_detector_256` factory
# (imported from model.detector).
model = basic_detector_256()

Using Conv3D pooling: 2 layers


In [12]:
# Prefer the GPU, but fall back to the CPU when no CUDA device is visible so
# the notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
# Move the model onto `device`, the same device `prepare_batch` sends the inputs to.
model = model.to(device)

In [14]:
# Adam over all model parameters with a base learning rate of 1e-3.
optim = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Multiplies the learning rate by 0.3 once the scheduler has been stepped 9 times.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optim,
    milestones=[9],
    gamma=0.3,
)

In [15]:
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy

In [16]:
# A batch is an (inputs, labels) pair.
Batch = Tuple[FloatTensor, LongTensor]


def prepare_batch(batch: Batch) -> Batch:
    """Move both tensors of ``batch`` to the module-level ``device``.

    Args:
        batch: an ``(x, y)`` pair as yielded by the data loader.

    Returns:
        The same pair, with each tensor transferred via ``.to(device)``.
    """
    inputs, labels = batch
    return inputs.to(device), labels.to(device)


def train(trainer: Engine, batch: Batch) -> Tuple[float, torch.Tensor, torch.Tensor]:
    """Run a single optimisation step; used as the trainer's process function.

    Args:
        trainer: the engine invoking this step (not used inside the function).
        batch: an ``(x, y)`` pair from the training loader.

    Returns:
        ``(loss, y_hat, y)``: the scalar loss as a Python float, the
        thresholded predictions and the labels moved to ``device``.
    """
    model.train()
    optim.zero_grad()
    x, y = prepare_batch(batch)
    # The model is called with the labels as well as the inputs.
    out = model(x, y)
    loss = combined_loss(out, x, y)
    loss.backward()
    optim.step()
    # Hard 0/1 predictions from the last model output.
    # NOTE(review): presumably out[-1] holds per-sample probabilities - confirm.
    y_hat = (out[-1] > 0.5).float()
    return loss.item(), y_hat, y


def validate(trainer: Engine, batch: Batch) -> Tuple[float, torch.Tensor, torch.Tensor]:
    """Score one batch without updating the model; the evaluator's process function.

    Args:
        trainer: the engine invoking this step (not used inside the function).
        batch: an ``(x, y)`` pair from the validation loader.

    Returns:
        ``(loss, y_hat, y)``: the scalar loss as a Python float, the
        thresholded predictions and the labels moved to ``device``.
    """
    model.eval()
    # No gradients are needed for the forward pass or the loss.
    with torch.no_grad():
        x, y = prepare_batch(batch)
        out = model(x, y)
        loss = combined_loss(out, x, y)
    # Hard 0/1 predictions from the last model output.
    # NOTE(review): presumably out[-1] holds per-sample probabilities - confirm.
    y_hat = (out[-1] > 0.5).float()
    return loss.item(), y_hat, y

In [17]:
trainer = Engine(train)
evaluator = Engine(validate)


@trainer.on(Events.EPOCH_COMPLETED)
def step_scheduler(engine: Engine) -> None:
    """Advance the LR scheduler once per completed training epoch.

    Without this handler `scheduler` is created but never stepped, so the
    configured MultiStepLR milestone would never change the learning rate.
    NOTE(review): per-epoch stepping is assumed - confirm the milestone is
    meant in epochs rather than iterations.
    """
    scheduler.step()

In [18]:
def _pred_and_target(output):
    """Pick ``(y_hat, y)`` out of the ``(loss, y_hat, y)`` step output."""
    _, y_hat, y = output
    return y_hat, y


# One metric object per engine. The evaluator runs inside a trainer handler,
# so a shared instance would have its accumulator reset and overwritten by
# validation while still attached to the trainer.
accuracy = Accuracy(output_transform=_pred_and_target)
accuracy.attach(trainer, 'acc')

valid_accuracy = Accuracy(output_transform=_pred_and_target)
valid_accuracy.attach(evaluator, 'acc')

In [19]:
def log_iter(engine: Engine, title: str) -> None:
    """Print one progress line for the iteration ``engine`` just finished.

    The epoch is read from the module-level ``trainer``, not from ``engine``;
    the iteration counter and the loss (first element of the step output)
    come from ``engine`` itself.
    """
    state = engine.state
    line = "{:>5} | ep: {}, it: {}, loss: {:.5f}".format(
        title, trainer.state.epoch, state.iteration, state.output[0])
    print(line)
    
    
def log_epoch(engine: Engine, title: str) -> None:
    """Print the ``'acc'`` metric of ``engine`` with the trainer's current epoch.

    The epoch number is read from the module-level ``trainer``; the metric
    comes from ``engine.state.metrics``. A blank line follows the summary.
    """
    template = "{:>5} | ep: {}, acc: {:.3f}\n"
    print(template.format(title, trainer.state.epoch, engine.state.metrics['acc']))
    

@trainer.on(Events.ITERATION_COMPLETED(every=1))
def log_train_iter(engine: Engine) -> None:
    """Log every completed training iteration under the 'train' title."""
    log_iter(engine, title='train')
    
    
@evaluator.on(Events.ITERATION_COMPLETED(every=1))
def log_val_iter(engine: Engine) -> None:
    """Log every completed validation iteration under the 'val' title."""
    log_iter(engine, title='val')

In [20]:
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    """Run a short validation pass after each training epoch.

    ``log_epoch`` is registered on the evaluator's COMPLETED event only for
    the duration of this run: the handle returned by ``add_event_handler``
    is used as a context manager. The pass covers 2 iterations of ``valid_dl``.
    """
    summary_handle = evaluator.add_event_handler(Events.COMPLETED, log_epoch, 'val')
    with summary_handle:
        evaluator.run(valid_dl, epoch_length=2)

In [21]:
# Short smoke run: 3 epochs of 5 iterations each. The call is the cell's last
# expression, so its result (the engine State) is displayed as cell output.
trainer.run(train_dl, max_epochs=3, epoch_length=5)

train | ep: 1, it: 1, loss: 1.72871
train | ep: 1, it: 2, loss: 1.85537
train | ep: 1, it: 3, loss: 1.81833
train | ep: 1, it: 4, loss: 1.71735
train | ep: 1, it: 5, loss: 1.74818
  val | ep: 1, it: 1, loss: 1.71932
  val | ep: 1, it: 2, loss: 1.66894
  val | ep: 1, acc: 0.417

train | ep: 2, it: 6, loss: 1.78472
train | ep: 2, it: 7, loss: 1.75002
train | ep: 2, it: 8, loss: 1.73071
train | ep: 2, it: 9, loss: 1.66770
train | ep: 2, it: 10, loss: 2.05835
  val | ep: 2, it: 1, loss: 1.73405
  val | ep: 2, it: 2, loss: 1.76036
  val | ep: 2, acc: 0.583

train | ep: 3, it: 11, loss: 1.77715
train | ep: 3, it: 12, loss: 1.75383
train | ep: 3, it: 13, loss: 1.68536
train | ep: 3, it: 14, loss: 1.71176
train | ep: 3, it: 15, loss: 1.73357
  val | ep: 3, it: 1, loss: 1.69722
  val | ep: 3, it: 2, loss: 1.73346
  val | ep: 3, acc: 0.417



State:
	iteration: 15
	epoch: 3
	epoch_length: 5
	max_epochs: 3
	output: <class 'tuple'>
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: 12

In [22]:
# Optional visual check (cells 22-25): uncomment to preview frames from one validation batch.
# data = iter(valid_dl)

In [23]:
# batch = next(data)

In [24]:
# images = batch[0].permute(0, 1, 3, 4, 2).numpy()

In [25]:
# show_images(images[:,1], cols=3)