In [1]:
import collections
import ignite
import numpy as np
import os
import sys
import torch

from torch import FloatTensor, LongTensor
from torchvision import transforms as T
from typing import Callable, Iterable, List, Tuple

# Record library versions so the saved output documents the environment.
print("torch: {}, ignite: {}".format(torch.__version__, ignite.__version__))

torch: 1.3.1, ignite: 0.3.0


In [2]:
# Expose only the GPU with index 1 to this process. CUDA reads this variable
# when it is first initialised, so it must be set before any CUDA call.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [3]:
# Project paths. The defaults point at the original author's machine; set
# DFDC_BASE_DIR / DFDC_HDF5_DIR to run the notebook elsewhere without editing it.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
SRC_DIR = os.path.join(BASE_DIR, 'src')
# Root directory of the HDF5 data read by HDF5Dataset in create_dataloader.
HDF5_DIR = os.environ.get('DFDC_HDF5_DIR', '/media/dmitry/other/dfdc-crops/hdf5')

In [4]:
# Make the project's `src` package and the vendored Pytorch_Retinaface code
# importable. Both go to the front of sys.path, SRC_DIR first.
sys.path[0:0] = [SRC_DIR, os.path.join(BASE_DIR, 'vendors/Pytorch_Retinaface')]

In [5]:
# Prepend ./utils (resolved against the kernel's working directory) to the
# import path; presumably this is where the `visualise` module comes from.
sys.path[0:0] = ['./utils']

In [6]:
from dataset.hdf5 import HDF5Dataset
from dataset.sample import FrameSampler, BalancedSampler
from model.detector import basic_detector_256
from model.loss import combined_loss

In [7]:
from visualise import show_images

In [8]:
def create_dataloader(bs: int, num_frames: int, real_fake_ratio: float, 
                      p_sparse_frames: float, chunks: Iterable[int]
                     ) -> torch.utils.data.DataLoader:
    """Build a DataLoader over a subset of the ``dfdc_train_part_*`` HDF5 dirs.

    Args:
        bs: Batch size. The trailing incomplete batch is dropped.
        num_frames: Frames per sample. Forwarded to ``FrameSampler`` and used
            as the first element of the dataset's ``size``.
        real_fake_ratio: Forwarded to ``FrameSampler``.
        p_sparse_frames: Forwarded to ``FrameSampler`` as ``p_sparse``.
        chunks: Part indices ``i``; each selects the sub-directory
            ``dfdc_train_part_{i}`` of ``HDF5_DIR``.

    Returns:
        A DataLoader whose batches are drawn by a ``BatchSampler`` wrapped
        around ``BalancedSampler`` over the dataset.

    Side effect: prints the number of samples in the dataset.
    """
    sub_dirs = ['dfdc_train_part_{}'.format(idx) for idx in chunks]
    frame_sampler = FrameSampler(num_frames,
                                 real_fake_ratio=real_fake_ratio,
                                 p_sparse=p_sparse_frames)
    dataset = HDF5Dataset(HDF5_DIR,
                          size=(num_frames, 256),
                          sampler=frame_sampler,
                          x_tfms=T.Compose([T.ToTensor()]),
                          sub_dirs=sub_dirs)
    print(f'Num samples: {len(dataset)}')

    balanced_batches = torch.utils.data.BatchSampler(
        BalancedSampler(dataset), batch_size=bs, drop_last=True)
    return torch.utils.data.DataLoader(dataset, batch_sampler=balanced_batches)

In [9]:
# Training split: HDF5 parts 5-29 (parts 0-4 are used for validation).
train_dl = create_dataloader(
    chunks=range(5, 30),
    bs=12,
    num_frames=10,
    real_fake_ratio=100 / 30,
    p_sparse_frames=0.75,
)

Num samples: 61779


In [10]:
# Validation split: HDF5 parts 0-4, with p_sparse_frames fixed at 1.0.
valid_dl = create_dataloader(
    chunks=range(0, 5),
    bs=12,
    num_frames=10,
    real_fake_ratio=100 / 30,
    p_sparse_frames=1.,
)

Num samples: 7937


In [11]:
# Instantiate the detector.
# NOTE(review): nothing in this notebook moves the model (or the batches) to a
# CUDA device even though CUDA_VISIBLE_DEVICES is set; confirm whether
# basic_detector_256() handles device placement itself.
model = basic_detector_256()

Using Conv3D pooling: 2 layers


In [12]:
# Adam at lr=1e-3. MultiStepLR multiplies the lr by 0.3 once its step count
# reaches the milestone 9.
# NOTE(review): the schedule only takes effect if scheduler.step() is called
# (typically once per epoch).
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[9], gamma=0.3)

In [13]:
from ignite.engine import Engine, Events

In [14]:
Batch = Tuple[FloatTensor, LongTensor]


def update_model(trainer: Engine, batch: Batch) -> float:
    """Ignite process function: run one optimisation step on ``batch``.

    Uses the notebook-level ``model``, ``optim`` and ``combined_loss``.

    Args:
        trainer: The engine running this function (unused; required by
            ignite's process-function signature).
        batch: ``(x, y)`` pair produced by the training DataLoader.

    Returns:
        The batch loss as a Python float (``Tensor.item()``). Ignite stores it
        in ``trainer.state.output``, where the logging handlers read it.
    """
    model.train()
    optim.zero_grad()
    x, y = batch
    # Both the model and the loss receive the labels; the loss also sees x.
    out = model(x, y)
    loss = combined_loss(out, x, y)
    loss.backward()
    optim.step()
    return loss.item()


def validate(trainer: Engine, batch: Batch) -> float:
    """Ignite process function: compute the loss on ``batch`` without training.

    Uses the notebook-level ``model`` and ``combined_loss``.

    Args:
        trainer: The engine running this function. Despite the name, this is
            the evaluator engine (``Engine(validate)``); it is unused.
        batch: ``(x, y)`` pair produced by the validation DataLoader.

    Returns:
        The batch loss as a Python float (``Tensor.item()``).
    """
    model.eval()
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        x, y = batch
        out = model(x, y)
        loss = combined_loss(out, x, y)
    return loss.item()

In [15]:
# Ignite engines: `trainer` runs update_model on training batches,
# `evaluator` runs validate (no gradient updates) on validation batches.
trainer = Engine(update_model)
evaluator = Engine(validate)

# Advance the LR schedule once per training epoch (after that epoch's
# optimizer steps). Nothing else in this notebook calls scheduler.step(),
# so without this handler the MultiStepLR milestone is never reached.
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: scheduler.step())

In [16]:
def log_iter(engine: Engine, title: str) -> None:
    """Print a one-line progress record for the engine's latest iteration.

    Reads ``epoch``, ``iteration`` and ``output`` from ``engine.state``.
    ``output`` must be a number; here it is the scalar loss returned by the
    process function, printed with 5 decimal places.
    """
    state = engine.state
    print(f"{title} | ep: {state.epoch}, it: {state.iteration}, loss: {state.output:.5f}")
    
    
def log_epoch(engine: Engine, title: str) -> None:
    """Print a summary line with the engine's epoch and accuracy.

    The accuracy is taken from ``engine.state.metrics['accuracy']``, which is
    populated when an ignite metric is attached to the engine under the name
    ``'accuracy'``. When no such metric exists it falls back to ``0.0``, so
    the printed value is only a placeholder until a metric is attached.
    """
    epoch = engine.state.epoch
    metrics = getattr(engine.state, 'metrics', None) or {}
    accuracy = metrics.get('accuracy', 0.0)
    print("{} | ep: {}, accuracy: {:.2f}".format(title, epoch, accuracy))
    

@trainer.on(Events.ITERATION_COMPLETED(every=1))
def log_train_iter(engine: Engine) -> None:
    """Print the loss after every training iteration, tagged 'train'."""
    log_iter(engine, title='train')
    
    
@evaluator.on(Events.ITERATION_COMPLETED(every=1))
def log_val_iter(engine: Engine) -> None:
    """Print the loss after every validation iteration, tagged 'val'."""
    log_iter(engine, title='val')

In [17]:
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    """Run a full validation pass over ``valid_dl`` after each training epoch.

    ``log_epoch`` is attached to the evaluator's COMPLETED event only for the
    duration of this run. The handle returned by ``add_event_handler`` is used
    as a context manager, which removes the handler on exit.
    """
    epoch_summary = evaluator.add_event_handler(Events.COMPLETED, log_epoch, 'val')
    with epoch_summary:
        evaluator.run(valid_dl)

In [18]:
# NOTE(review): epoch_length=5 runs only 5 training batches per "epoch" (see
# the logged output); this looks like a smoke-test setting. Raise or drop it
# for a real training run.
trainer.run(train_dl, epoch_length=5, max_epochs=3)

train | ep: 1, it: 1, loss: 1.73834
train | ep: 1, it: 2, loss: 1.85836
train | ep: 1, it: 3, loss: 1.71780
train | ep: 1, it: 4, loss: 1.71548
train | ep: 1, it: 5, loss: 1.68621
val | ep: 1, it: 1, loss: 1.75351
val | ep: 1, it: 2, loss: 1.76389
val | ep: 1, it: 3, loss: 1.82034
val | ep: 1, it: 4, loss: 1.84660
val | ep: 1, it: 5, loss: 1.69976
val | ep: 1, it: 6, loss: 1.65810
val | ep: 1, it: 7, loss: 1.87485


Current run is terminating due to exception: .
Engine run is terminating due to exception: .
Engine run is terminating due to exception: .


KeyboardInterrupt: 

In [None]:
# data = iter(valid_dl)

In [None]:
# batch = next(data)

In [None]:
# images = batch[0].permute(0, 1, 3, 4, 2).numpy()

In [None]:
# show_images(images[:,1], cols=3)