In [1]:
import collections
import datetime as dt
import ignite
import numpy as np
import os
import sys
import time
import torch
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Tuple

print(f"torch: {torch.__version__}, ignite: {ignite.__version__}")

torch: 1.3.1, ignite: 0.3.0


In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [3]:
BASE_DIR = '/home/dmitry/projects/dfdc'
SRC_DIR = os.path.join(BASE_DIR, 'src')
HDF5_DIR = '/media/dmitry/other/dfdc-crops/hdf5'
IMG_DIR = '/media/dmitry/other/dfdc-crops/webp_lossy'

In [4]:
# src
sys.path.insert(0, os.path.join(BASE_DIR, 'vendors/Pytorch_Retinaface'))
sys.path.insert(0, SRC_DIR)

In [5]:
sys.path.insert(0, './utils')

In [6]:
from dataset.hdf5 import HDF5Dataset
from dataset.images import ImagesDataset
from dataset.sample import FrameSampler, BalancedSampler
from model.detector import basic_detector_256, DetectorOut
from model.loss import combined_loss

In [7]:
from visualise import show_images

In [8]:
Batch = Tuple[FloatTensor, LongTensor]

In [9]:
def create_dataloader(bs: int, num_frames: int, real_fake_ratio: float, 
                      p_sparse_frames: float, chunks: Iterable[int]
                     ) -> torch.utils.data.DataLoader:
    dirs = [f'dfdc_train_part_{i}' for i in chunks]
    
    sampler = FrameSampler(num_frames, 
                           real_fake_ratio=real_fake_ratio, 
                           p_sparse=p_sparse_frames)
    tfms = T.Compose([
        T.ToTensor()
    ])
    ds = HDF5Dataset(HDF5_DIR, 
                     size=(num_frames, 256), 
                     sampler=sampler, 
                     transforms=tfms, 
                     sub_dirs=dirs)
    print('Num samples: {}'.format(len(ds)))
    
    batch_sampler = torch.utils.data.BatchSampler(
        BalancedSampler(ds),
        batch_size=bs, 
        drop_last=True
    )
    dl = torch.utils.data.DataLoader(ds, batch_sampler=batch_sampler)
    return dl


def prepare_batch(batch: Batch) -> Batch:
    x, y = batch
    x = x.to(device)
    y = y.to(device)
    return x, y

In [10]:
train_dl = create_dataloader(
    bs=12, 
    num_frames=10, 
    real_fake_ratio=100/30, 
    p_sparse_frames=0.75, 
    chunks=range(5,30)
)

Num samples: 61779


In [11]:
valid_dl = create_dataloader(
    bs=12, 
    num_frames=10, 
    real_fake_ratio=100/30, 
    p_sparse_frames=1., 
    chunks=range(0,5)
)

Num samples: 7937


In [12]:
model = basic_detector_256()

Using Conv3D pooling: 2 layers


In [13]:
device = torch.device('cuda:0')

In [14]:
model = model.to(device)

In [15]:
%%time

data = iter(train_dl)
data = list(map(lambda _: next(data), tqdm(range(1))))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


CPU times: user 1.6 s, sys: 111 ms, total: 1.71 s
Wall time: 781 ms


In [16]:
%%time

for batch in tqdm(data):
    batch = prepare_batch(batch)
    out = model(*batch)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


CPU times: user 117 ms, sys: 26 ms, total: 143 ms
Wall time: 109 ms


In [17]:
# 100 - 7.9s
# 100 - 8.36s <- 3rd ord derivative

In [18]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optim, milestones=[9], gamma=0.3
)

In [19]:
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss

In [20]:
def gather_outs(batch: Batch, model_out: DetectorOut, 
                loss: FloatTensor) -> Dict[str, Tensor]:
    y_pred = (model_out[-1] >= 0.5).flatten().float().detach().cpu()
    y_true = batch[-1].float().cpu()
    out = {
        'loss': loss.item(), 
        'y_pred': y_pred, 
        'y_true': y_true
    }
    return out


def train(trainer: Engine, batch: Batch) -> int:
    model.train()
    optim.zero_grad()
    x, y = prepare_batch(batch)
    out = model(x, y)
    loss = combined_loss(out, x, y)
    loss.backward()
    optim.step()
    return gather_outs(batch, out, loss)


def validate(trainer: Engine, batch: Batch) -> int:
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch)
        out = model(x, y)
        loss = combined_loss(out, x, y)
    return gather_outs(batch, out, loss)

In [31]:
trainer = Engine(train)
evaluator = Engine(validate)

In [32]:
parse_y = lambda out: [out['y_pred'], out['y_true']]
accuracy = Accuracy(output_transform=parse_y)
log_loss = Loss(torch.nn.BCELoss(), output_transform=parse_y)

for metric, key in zip([accuracy, log_loss], ['acc', 'nll']):
    for engine in [trainer, evaluator]:
        metric.attach(engine, key)

In [33]:
@trainer.on(Events.EPOCH_STARTED)
@evaluator.on(Events.EPOCH_STARTED)
def start_epoch(engine: Engine):
    engine.state.t0 = time.time()

In [34]:
log_freq = 1


def humanize_time(time: int) -> str:
    return dt.datetime.fromtimestamp(time).strftime('%H:%M:%S')


@trainer.on(Events.ITERATION_COMPLETED(every=log_freq), 'train')
@evaluator.on(Events.ITERATION_COMPLETED(every=log_freq), 'val')
def log_iter(engine: Engine, title: str) -> None:
    epoch = trainer.state.epoch
    iteration = engine.state.iteration
    loss = engine.state.output['loss']
    t0 = engine.state.t0
    t1 = time.time()
    it_time = (t1 - t0) / log_freq
    cur_time = humanize_time(t1)
    print("[{}][{:.2f} s] {:>5} | ep: {:2d}, it: {:3d}, loss: {:.5f}".format(
        cur_time, it_time, title, epoch, iteration, loss))
    engine.state.t0 = t1
    

def log_epoch(engine: Engine, title: str) -> None:
    epoch = trainer.state.epoch
    metrics = engine.state.metrics
    t1 = time.time()
    cur_time = humanize_time(t1)
    print("\n[{}] {:>5} | ep: {}, acc: {:.3f}, nll: {:.3f}\n".format(
        cur_time, title, epoch, metrics['acc'], metrics['nll']))

In [35]:
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    with evaluator.add_event_handler(Events.COMPLETED, log_epoch, 'val'):
        evaluator.run(valid_dl, epoch_length=5)

In [36]:
trainer.run(train_dl, max_epochs=3, epoch_length=5)

[00:10:12][0.50 s] train | ep:  1, it:   1, loss: 1.71789
[00:10:12][0.46 s] train | ep:  1, it:   2, loss: 1.70496
[00:10:13][0.47 s] train | ep:  1, it:   3, loss: 1.70045
[00:10:13][0.47 s] train | ep:  1, it:   4, loss: 1.72319
[00:10:14][0.49 s] train | ep:  1, it:   5, loss: 1.73214
[00:10:14][0.26 s]   val | ep:  1, it:   1, loss: 1.70091
[00:10:14][0.27 s]   val | ep:  1, it:   2, loss: 1.70089
[00:10:14][0.32 s]   val | ep:  1, it:   3, loss: 1.64533
[00:10:15][0.26 s]   val | ep:  1, it:   4, loss: 1.65219
[00:10:15][0.26 s]   val | ep:  1, it:   5, loss: 1.77936

[00:10:15][     ]   val | ep: 1, acc: 0.400, nll: 16.579

[00:10:16][0.51 s] train | ep:  2, it:   6, loss: 1.70433
[00:10:16][0.49 s] train | ep:  2, it:   7, loss: 1.69832
[00:10:16][0.46 s] train | ep:  2, it:   8, loss: 1.67537
[00:10:17][0.48 s] train | ep:  2, it:   9, loss: 1.70126
[00:10:17][0.48 s] train | ep:  2, it:  10, loss: 1.75903
[00:10:18][0.26 s]   val | ep:  2, it:   1, loss: 1.69127
[00:10:18][0.

Current run is terminating due to exception: .
Engine run is terminating due to exception: .


[00:10:20][0.46 s] train | ep:  3, it:  12, loss: 1.74724


KeyboardInterrupt: 

In [27]:
# data = iter(valid_dl)

In [28]:
# batch = next(data)

In [29]:
# images = batch[0].permute(0, 2, 3, 4, 1).numpy()

In [30]:
# show_images(images[:,1], cols=3)