In [1]:
import collections
import datetime as dt
import ignite
import numpy as np
import os
import sys
import time
import torch

from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Tuple

# Print the installed torch and ignite versions.
print(f"torch: {torch.__version__}, ignite: {ignite.__version__}")

torch: 1.3.1, ignite: 0.3.0


In [2]:
# Set CUDA_VISIBLE_DEVICES to '1' for this process (and anything it spawns).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [3]:
# Filesystem layout for the experiment. Each location can be overridden through
# an environment variable so the notebook is not tied to a single machine; when
# a variable is unset the original hard-coded path is used.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
SRC_DIR = os.path.join(BASE_DIR, 'src')
HDF5_DIR = os.environ.get('DFDC_HDF5_DIR', '/media/dmitry/other/dfdc-crops/hdf5')
IMG_DIR = os.environ.get('DFDC_IMG_DIR', '/media/dmitry/other/dfdc-crops/webp_lossy')

In [4]:
# Put the vendors/Pytorch_Retinaface directory and the project's src directory
# on the import path. Both are inserted at index 0, so SRC_DIR (inserted last)
# ends up ahead of the vendored directory.
sys.path.insert(0, os.path.join(BASE_DIR, 'vendors/Pytorch_Retinaface'))
sys.path.insert(0, SRC_DIR)

In [5]:
# Add ./utils to the front of the import path; the path is relative, so it is
# resolved against the current working directory.
sys.path.insert(0, './utils')

In [6]:
from dataset.hdf5 import HDF5Dataset
from dataset.images import ImagesDataset
from dataset.sample import FrameSampler, BalancedSampler
from model.detector import basic_detector_256, DetectorOut
from model.loss import combined_loss

In [7]:
from visualise import show_images

In [8]:
# Type alias for a batch: a pair that the helpers below unpack as (x, y).
Batch = Tuple[FloatTensor, LongTensor]

In [9]:
def create_dataloader(bs: int, num_frames: int, real_fake_ratio: float,
                      p_sparse_frames: float, chunks: Iterable[int]
                      ) -> torch.utils.data.DataLoader:
    """Build a DataLoader over the HDF5 data stored in the selected chunks.

    Args:
        bs: Batch size handed to the batch sampler; a trailing incomplete
            batch is dropped.
        num_frames: Passed to ``FrameSampler`` and used as the first element
            of the dataset's ``size`` tuple (the second element is 256).
        real_fake_ratio: Forwarded to ``FrameSampler`` as ``real_fake_ratio``.
        p_sparse_frames: Forwarded to ``FrameSampler`` as ``p_sparse``.
        chunks: Indices ``i`` selecting the ``dfdc_train_part_{i}``
            sub-directories of ``HDF5_DIR``.

    Returns:
        A DataLoader whose batches are produced by a ``BatchSampler``
        wrapped around ``BalancedSampler``.
    """
    chunk_dirs = [f'dfdc_train_part_{i}' for i in chunks]

    frame_sampler = FrameSampler(
        num_frames,
        real_fake_ratio=real_fake_ratio,
        p_sparse=p_sparse_frames,
    )
    dataset = HDF5Dataset(
        HDF5_DIR,
        size=(num_frames, 256),
        sampler=frame_sampler,
        transforms=T.Compose([T.ToTensor()]),
        sub_dirs=chunk_dirs,
    )
    print('Num samples: {}'.format(len(dataset)))

    balanced_batches = torch.utils.data.BatchSampler(
        BalancedSampler(dataset), batch_size=bs, drop_last=True)
    return torch.utils.data.DataLoader(dataset, batch_sampler=balanced_batches)


def prepare_batch(batch: Batch) -> Batch:
    """Move both elements of a batch to the module-level ``device``.

    Args:
        batch: Pair unpacked as ``(x, y)``.

    Returns:
        The same pair, with each element passed through ``.to(device)``.
    """
    inputs, labels = batch
    return inputs.to(device), labels.to(device)

In [10]:
# Training loader: batches of 12 built from chunks 5..29
# (dfdc_train_part_5 .. dfdc_train_part_29).
train_dl = create_dataloader(
    bs=12, 
    num_frames=10, 
    real_fake_ratio=100/30, 
    p_sparse_frames=0.75, 
    chunks=range(5,30)
)

Num samples: 61779


In [11]:
# Validation loader: same batch size, frame count and real/fake ratio as the
# training loader, but built from chunks 0..4 (disjoint from the training
# chunks) and with p_sparse_frames set to 1.
valid_dl = create_dataloader(
    bs=12, 
    num_frames=10, 
    real_fake_ratio=100/30, 
    p_sparse_frames=1., 
    chunks=range(0,5)
)

Num samples: 7937


In [12]:
# Build the detector using the factory imported from model.detector.
model = basic_detector_256()

Using Conv3D pooling: 2 layers


In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
# Random stand-in tensors created on `device`: a float tensor of shape
# (12, 3, 10, 256, 256) and 12 integer values drawn from {0, 1}.
# NOTE(review): no later cell shown here reads these module-level `x` / `y`
# (the training functions use their own locals) — confirm they are still needed.
x = torch.rand(12, 3, 10, 256, 256, device=device)
y = torch.randint(0, 2, (12,), device=device)

In [15]:
# Move the model to the selected device.
model = model.to(device)

In [16]:
%%time

# Pre-fetch 100 batches so data loading is timed separately from the model.
# The iterator is kept under its own name so that `data` only ever refers to
# the resulting list of batches.
batch_iter = iter(train_dl)
data = [next(batch_iter) for _ in tqdm(range(100))]

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


CPU times: user 2min 29s, sys: 6.11 s, total: 2min 36s
Wall time: 1min 1s


In [17]:
%%time

# Time the model call alone over the pre-fetched batches; both elements of the
# batch are passed to the model.
# NOTE(review): no torch.no_grad() context is used around the forward call here.
for batch in tqdm(data):
    batch = prepare_batch(batch)
    out = model(*batch)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


CPU times: user 6.14 s, sys: 1.72 s, total: 7.86 s
Wall time: 7.84 s


In [18]:
# 100 batches through the model (forward call only): ~7.9 s wall time

In [19]:
# Adam over all model parameters with a base learning rate of 1e-3.
# The MultiStepLR schedule has a single milestone at 9 with gamma 0.3.
optim = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optim, milestones=[9], gamma=0.3
)

In [20]:
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy

In [21]:
def gather_outs(batch: Batch, model_out: DetectorOut,
                loss: FloatTensor) -> Dict[str, Tensor]:
    """Collect the values that logging and metrics read after a step.

    Args:
        batch: The batch the step was run on; its last element is reported
            as the target.
        model_out: Model output; its last element is thresholded at 0.5
            (presumably per-sample probabilities — TODO confirm).
        loss: Loss tensor for the step, reported as a Python number.

    Returns:
        Dict with keys ``'loss'``, ``'y_pred'`` and ``'y_true'``.
    """
    return dict(
        loss=loss.item(),
        y_pred=(model_out[-1] >= 0.5).float(),
        y_true=batch[-1],
    )


def train(trainer: Engine, batch: Batch) -> Dict[str, Tensor]:
    """Run one optimisation step; used as the trainer engine's process function.

    Args:
        trainer: Engine invoking this step (not used in the body).
        batch: Pair that ``prepare_batch`` unpacks as ``(x, y)``.

    Returns:
        The dict built by ``gather_outs`` (keys ``'loss'``, ``'y_pred'``,
        ``'y_true'``), not an ``int`` as previously annotated.
    """
    model.train()
    optim.zero_grad()
    x, y = prepare_batch(batch)
    # The model receives the targets as well as the inputs.
    out = model(x, y)
    loss = combined_loss(out, x, y)
    loss.backward()
    optim.step()
    return gather_outs(batch, out, loss)


def validate(trainer: Engine, batch: Batch) -> Dict[str, Tensor]:
    """Run one evaluation step; used as the evaluator engine's process function.

    Args:
        trainer: Engine invoking this step (not used in the body).
        batch: Pair that ``prepare_batch`` unpacks as ``(x, y)``.

    Returns:
        The dict built by ``gather_outs`` (keys ``'loss'``, ``'y_pred'``,
        ``'y_true'``), not an ``int`` as previously annotated.
    """
    model.eval()
    # Forward pass and loss are computed with gradient tracking disabled.
    with torch.no_grad():
        x, y = prepare_batch(batch)
        out = model(x, y)
        loss = combined_loss(out, x, y)
    return gather_outs(batch, out, loss)

In [22]:
# One engine per phase: `train` performs an optimisation step, `validate` a
# gradient-free evaluation step.
trainer = Engine(train)
evaluator = Engine(validate)


@trainer.on(Events.EPOCH_COMPLETED)
def step_scheduler(engine: Engine) -> None:
    """Advance the learning-rate schedule once per completed training epoch.

    Without this handler the MultiStepLR `scheduler` is created but never
    stepped, so its milestone would never take effect.
    """
    scheduler.step()

In [23]:
def _accuracy_inputs(out):
    """Pick the [prediction, target] pair out of a step's output dict."""
    return [out['y_pred'], out['y_true']]


# A metric accumulates state between reset and compute, and the evaluator is run
# from inside a trainer event handler. Each engine therefore gets its own
# Accuracy instance instead of sharing one accumulator, so correctness does not
# depend on the order in which handlers were registered.
accuracy = Accuracy(output_transform=_accuracy_inputs)
accuracy.attach(trainer, 'acc')

val_accuracy = Accuracy(output_transform=_accuracy_inputs)
val_accuracy.attach(evaluator, 'acc')

In [24]:
@trainer.on(Events.EPOCH_STARTED)
@evaluator.on(Events.EPOCH_STARTED)
def start_epoch(engine: Engine):
    """Reset the timing marks when either engine starts an epoch.

    Both marks are stored on the trainer's state regardless of which engine
    fired the event; ``engine`` itself is not read.
    """
    trainer.state.ep_time = trainer.state.time = time.time()

In [25]:
def log_iter(engine: Engine, title: str) -> None:
    """Print one progress line for ``engine`` and restart the interval timer.

    Args:
        engine: Engine whose iteration counter and last output are reported.
        title: Label printed in the line (e.g. 'train' or 'val').

    The epoch number and the interval timer are always read from the
    module-level ``trainer``, whichever engine is passed in.
    """
    epoch = trainer.state.epoch
    iteration = engine.state.iteration
    loss = engine.state.output['loss']
    t0 = trainer.state.time
    t1 = time.time()
    cur_time = dt.datetime.fromtimestamp(t1).strftime('%H:%M:%S')
    # The bracketed duration is the time elapsed since the previous log line
    # (or since the epoch started, for the first line of an epoch).
    print("[{}][{:.2f} s] {:>5} | ep: {:3d}, it: {:5d}, loss: {:.5f}".format(
        cur_time, t1 - t0, title, epoch, iteration, loss))
    trainer.state.time = t1
    
    
def log_epoch(engine: Engine, title: str) -> None:
    """Print the 'acc' metric of ``engine`` next to the trainer's epoch number.

    Args:
        engine: Engine whose ``state.metrics`` holds the 'acc' value.
        title: Label printed at the start of the line.
    """
    current_epoch = trainer.state.epoch
    acc = engine.state.metrics['acc']
    summary = "{:>5} | ep: {}, acc: {:.3f}\n".format(title, current_epoch, acc)
    print(summary)


@trainer.on(Events.ITERATION_COMPLETED(every=5))
def log_train_iter(engine: Engine) -> None:
    """Log every fifth completed trainer iteration under the 'train' label."""
    log_iter(engine, title='train')
    
    
@evaluator.on(Events.ITERATION_COMPLETED(every=5))
def log_val_iter(engine: Engine) -> None:
    """Log every fifth completed evaluator iteration under the 'val' label."""
    log_iter(engine, title='val')

In [26]:
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    """Run the evaluator on ``valid_dl`` after each completed training epoch.

    The ``log_epoch`` handler is registered on the evaluator's COMPLETED event
    and the returned handle is used as a context manager around the run.
    """
    val_summary_handle = evaluator.add_event_handler(
        Events.COMPLETED, log_epoch, 'val')
    with val_summary_handle:
        evaluator.run(valid_dl, epoch_length=5)

In [27]:
# Train for 3 epochs of 100 iterations each; the EPOCH_COMPLETED handler
# `evaluate` runs the evaluator after every epoch.
trainer.run(train_dl, max_epochs=3, epoch_length=100)

[00:36:19][2.27 s] train | ep:   1, it:     5, loss: 1.72960
[00:36:21][2.35 s] train | ep:   1, it:    10, loss: 1.94522
[00:36:24][3.12 s] train | ep:   1, it:    15, loss: 1.69148
[00:36:28][3.25 s] train | ep:   1, it:    20, loss: 1.71001
[00:36:31][3.24 s] train | ep:   1, it:    25, loss: 1.73655
[00:36:34][3.43 s] train | ep:   1, it:    30, loss: 1.72883
[00:36:38][4.27 s] train | ep:   1, it:    35, loss: 1.70893
[00:36:43][4.46 s] train | ep:   1, it:    40, loss: 1.70642
[00:36:48][4.62 s] train | ep:   1, it:    45, loss: 1.72912
[00:36:52][4.67 s] train | ep:   1, it:    50, loss: 1.70249
[00:36:57][4.79 s] train | ep:   1, it:    55, loss: 1.63429
[00:37:02][4.52 s] train | ep:   1, it:    60, loss: 1.67645
[00:37:06][4.51 s] train | ep:   1, it:    65, loss: 1.72357
[00:37:10][4.24 s] train | ep:   1, it:    70, loss: 1.66377
[00:37:15][4.32 s] train | ep:   1, it:    75, loss: 1.70160
[00:37:19][4.55 s] train | ep:   1, it:    80, loss: 1.82731
[00:37:23][4.28 s] train

State:
	iteration: 300
	epoch: 3
	epoch_length: 100
	max_epochs: 3
	output: <class 'dict'>
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: 12
	ep_time: 1584049237.6555922
	time: 1584049239.0117607

In [28]:
# data = iter(valid_dl)

In [29]:
# batch = next(data)

In [30]:
# images = batch[0].permute(0, 2, 3, 4, 1).numpy()

In [31]:
# show_images(images[:,1], cols=3)