In [1]:
import math
import os
import gc
import sys
import time

from typing import List

from numba import jit, njit
from pathlib import Path
from tqdm.notebook import tqdm

In [2]:
# Project root. Defaults to the original author's checkout; set the
# DFDC_BASE_DIR environment variable to run the notebook on another machine
# without editing this cell.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
# Project-local Python sources (added to sys.path by the imports cell).
SRC_DIR = os.path.join(BASE_DIR, 'src')
# Input videos, one sub-directory per chunk.
DATA_DIR = os.path.join(BASE_DIR, 'data/dfdc-videos')
# Output root for the face crops.
SAVE_DIR = os.path.join(BASE_DIR, 'data/dfdc-crops')

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2

import torch
import torchvision

from torchvision import ops

import nvidia.dali as dali
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# src
sys.path.insert(0, SRC_DIR)
from sample.reader import VideoReader
from dataset.utils import read_labels

# Pytorch_Retinaface
sys.path.insert(0, os.path.join(BASE_DIR, 'Pytorch_Retinaface'))
from data import cfg_mnet
from layers.functions.prior_box import PriorBox
from models.retinaface import RetinaFace
from detect_utils import detect, load_model, postproc_detections
from utils.nms.py_cpu_nms import py_cpu_nms

In [4]:
@njit
def calc_axis(c0, c1, pad, cmax):
    """Grow the interval [c0, c1) by ``pad`` on both sides, clamped to [0, cmax].

    Args:
        c0: Start coordinate of the interval.
        c1: End coordinate of the interval.
        pad: Amount added on each side before clamping.
        cmax: Upper bound for the end coordinate (e.g. image width or height).

    Returns:
        Tuple ``(start, stop, length)`` where ``length == stop - start``.
    """
    start = max(0, c0 - pad)
    stop = min(cmax, c1 + pad)
    length = stop - start
    return start, stop, length


@njit
def expand_bbox(bbox, pct):
    """Return an enlarged copy of ``bbox``; the input array is left untouched.

    The first two entries are multiplied by ``1 - pct`` and the remaining
    entries by ``1 + pct``. Note that this scales the absolute coordinates,
    so the amount of padding depends on where the box sits in the image,
    not on the size of the box.

    Args:
        bbox: 1-D array whose first two entries are the top-left corner and
            whose remaining entries are the bottom-right corner.
        pct: Relative scaling applied to the coordinates (0.05 == 5 %).

    Returns:
        New array with the scaled coordinates.
    """
    expanded = np.copy(bbox)
    shrink, grow = 1 - pct, 1 + pct
    # The two slices never overlap, so the order of the updates is irrelevant.
    expanded[2:] *= grow
    expanded[:2] *= shrink
    return expanded


@njit
def crop_face(img, bbox, pad_pct=0.05, square=True):
    """Cut the region described by ``bbox`` out of ``img``.

    Args:
        img: Image array of shape (height, width, channels).
        bbox: 1-D float array ``[x0, y0, x1, y1]`` in pixel coordinates.
        pad_pct: If > 0 the box is first enlarged with ``expand_bbox``.
        square: If True the shorter side of the box is grown (within the
            image bounds) and the crop is then trimmed so that height and
            width are equal. If False the (clamped) box is returned as is.

    Returns:
        A view into ``img`` (no copy is made).
    """
    img_h, img_w, _ = img.shape

    if pad_pct > 0:
        bbox = expand_bbox(bbox, pad_pct)

    # int64 instead of int16 so very large coordinates cannot wrap around.
    x0, y0, x1, y1 = bbox.astype(np.int64)

    # Clamp the box to the image. A negative start would otherwise be
    # interpreted as a from-the-end index and yield an empty or wrong crop,
    # and a box reaching past the right/bottom edge would report a width or
    # height larger than the pixels that actually exist.
    x0, y0 = max(0, x0), max(0, y0)
    x1, y1 = min(img_w, x1), min(img_h, y1)

    if not square:
        # Previously this path failed because w/h were only defined inside
        # the ``if square`` branch.
        return img[y0:y1, x0:x1]

    w, h = x1 - x0, y1 - y0
    if w > h:
        # Box is wider than tall: grow it vertically.
        pad = (w - h) // 2
        y0, y1, h = calc_axis(y0, y1, pad, img_h)
    elif h > w:
        # Box is taller than wide: grow it horizontally.
        pad = (h - w) // 2
        x0, x1, w = calc_axis(x0, x1, pad, img_w)

    # Growth may have been limited by the image border (or be off by one
    # because of the integer division), so trim to the common size.
    size = min(w, h)
    return img[y0:y1, x0:x1][:size, :size]

In [5]:
def round_num_faces(num_faces, frac_thresh=0.25):
    """Turn per-frame face counts into a single integer face count.

    The mean of ``num_faces`` is rounded down when its fractional part is
    below ``frac_thresh`` and rounded up otherwise.

    Args:
        num_faces: Array (or any sequence) with the number of detected
            faces in each frame.
        frac_thresh: Fractional part at or above which the mean is rounded up.

    Returns:
        The rounded face count as ``int``; 0 when ``num_faces`` is empty.
    """
    num_faces = np.asarray(num_faces)
    if num_faces.size == 0:
        # The mean of an empty array is NaN and int(NaN) raises.
        return 0
    fraction, integral = np.modf(num_faces.mean())
    if fraction >= frac_thresh:
        integral += 1
    return int(integral)

In [6]:
class VideoPipe(dali.pipeline.Pipeline):
    """DALI pipeline with a single GPU video reader as its only operator.

    Args:
        filenames: Paths of the video files to read.
        seq_len: Number of frames per returned sequence.
        stride: Distance between consecutive frames of a sequence.
        batch_size: Forwarded to ``dali.pipeline.Pipeline``.
        num_threads: Forwarded to ``dali.pipeline.Pipeline``.
        device_id: Forwarded to ``dali.pipeline.Pipeline``.
    """

    def __init__(self, filenames: List[str], seq_len=30, stride=10,
                 batch_size=1, num_threads=1, device_id=0):
        # The seed is fixed to 3.
        super().__init__(batch_size, num_threads, device_id, seed=3)
        reader_args = dict(
            device='gpu',
            filenames=filenames,
            sequence_length=seq_len,
            stride=stride,
            # Single shard: this pipeline reads all given files itself.
            shard_id=0,
            num_shards=1,
        )
        self.input = dali.ops.VideoReader(**reader_args)

    def define_graph(self):
        """Return the output of the video reader, registered as 'reader'."""
        return self.input(name='reader')
    
    
def get_file_list(df: pd.DataFrame, start: int, end: int,
                  base_dir: str = DATA_DIR) -> List[str]:
    """Build the video paths for rows ``start`` (inclusive) to ``end`` (exclusive).

    Each path is ``base_dir/<dir column>/<row label>``: the sub-directory
    comes from the ``dir`` column and the file name is the row's index label.

    Args:
        df: Frame with a ``dir`` column and the file name as index.
        start: First row position (as for ``DataFrame.iloc``).
        end: Row position one past the last row.
        base_dir: Root directory the paths are built under.

    Returns:
        List of path strings, one per selected row, in row order.
    """
    chunk = df.iloc[start:end]
    return [
        os.path.join(base_dir, sub_dir, file_name)
        for sub_dir, file_name in zip(chunk['dir'], chunk.index)
    ]


def init_detector(cfg, weights, use_cpu=False):
    """Build a RetinaFace network in test phase, load weights, set eval mode.

    Args:
        cfg: RetinaFace configuration dict. It is not modified; a copy with
            ``'pretrain'`` forced to False is handed to the model.
        weights: Path of the checkpoint passed to ``load_model``.
        use_cpu: Forwarded to ``load_model``.

    Returns:
        The network returned by ``load_model``, switched to eval mode.
    """
    # Copy instead of writing into the caller's dict, so a shared config
    # such as cfg_mnet is never changed as a side effect.
    # NOTE(review): 'pretrain' presumably controls loading of backbone
    # weights inside RetinaFace -- confirm against models/retinaface.py.
    cfg = {**cfg, 'pretrain': False}
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, weights, use_cpu)
    net.eval()
    return net


def mkdirs(base_dir, chunk_dirs):
    """Make sure ``base_dir/<chunk_dir>`` exists for every given chunk dir.

    Missing parent directories (including ``base_dir`` itself) are created
    as well, and directories that already exist are left alone.

    Args:
        base_dir: Root directory under which the chunk directories live.
        chunk_dirs: Iterable of directory names relative to ``base_dir``.
    """
    for chunk_dir in chunk_dirs:
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs(os.path.join(base_dir, chunk_dir), exist_ok=True)

In [7]:
def prepare_imgs(sample):
    """Convert a batch of frames into the float, channels-first model input.

    Args:
        sample: Tensor of shape (frames, height, width, channels). It is
            never modified, whatever its dtype.

    Returns:
        Tuple ``(imgs, scale)``:
        imgs: float tensor of shape (frames, channels, height, width) with
            the per-channel constants (104, 117, 123) subtracted.
        scale: tensor ``[w, h, w, h]`` (created on the CPU).
    """
    _, h, w, _ = sample.shape
    # NOTE(review): these constants are applied in channel order as given;
    # verify they match the channel order (RGB vs. BGR) of ``sample``.
    mean = torch.tensor([104, 117, 123], device=sample.device)
    # Out-of-place subtraction: ``sample.float()`` returns the very same
    # tensor when the input is already float32, so an in-place ``-=`` would
    # silently alter the caller's data.
    imgs = sample.float() - mean
    imgs = imgs.permute(0, 3, 1, 2)
    scale = torch.tensor([w, h, w, h])
    return imgs, scale


def detect(sample, model, cfg, device):
    """Run the face detector over all frames of ``sample`` in mini-batches.

    This definition shadows the ``detect`` imported from ``detect_utils``
    earlier in the notebook.

    Args:
        sample: Tensor of shape (frames, height, width, channels).
            NOTE(review): the frames are not moved to ``device`` here, so
            they presumably must already live on it -- confirm with caller.
        model: Detector; called with a batch of images and expected to
            return ``(loc, conf, landms)``.
        cfg: Configuration dict; ``cfg['batch_size']`` sets the mini-batch
            size, and the dict is forwarded to ``PriorBox`` and
            ``postproc_detections``.
        device: Device the priors and the scale tensor are moved to.

    Returns:
        The result of ``postproc_detections`` when there is a single
        mini-batch, otherwise the per-batch results joined with ``np.vstack``.
    """
    batch_size = cfg['batch_size']
    num_frames, height, width, _ = sample.shape
    imgs, scale = prepare_imgs(sample)
    scale = scale.to(device)
    # The priors depend only on the frame size, so build them once.
    priors = PriorBox(cfg, image_size=(height, width)).forward().to(device)

    per_batch = []
    for lo in range(0, num_frames, batch_size):
        with torch.no_grad():
            loc, conf, landms = model(imgs[lo:lo + batch_size])
        # Landmarks are not used; drop the reference right away.
        del landms
        per_batch.append(postproc_detections(loc, conf, priors, scale, cfg))
        # Release the raw network outputs before the next forward pass.
        del loc, conf

    if len(per_batch) > 1:
        return np.vstack(per_batch)
    return per_batch[0]

In [8]:
def prepare_data(
        start=0, end=None,
        num_frames_fake=30, num_frames_real=120,
        use_cpu=False, bs=32, verbose=False,
        base_dir=BASE_DIR, data_dir=DATA_DIR, save_dir=SAVE_DIR,
        chunk_dirs=None, max_open_files=100):
    """Detect and crop faces for a range of labelled videos.

    The label frame is read with ``read_labels``; rows ``start`` to ``end``
    are processed in groups of at most ``max_open_files`` videos, each group
    with its own DALI pipeline. For every video a directory
    ``save_dir/<dir>/<file name without extension>`` is created, faces are
    detected on the sampled frames and cropped with ``crop_face``.

    NOTE: writing the crops to disk (``cv2.imwrite``) is currently commented
    out, so the function only creates directories and does the computation.

    Args:
        start: Position of the first row of the label frame to process.
        end: Position one past the last row; None means all rows.
        num_frames_fake: Number of frames sampled per video.
        num_frames_real: Currently unused.
        use_cpu: Run the detector on the CPU. The DALI reader in
            ``VideoPipe`` is created with ``device='gpu'`` regardless.
        bs: Mini-batch size used by the detector.
        verbose: Print index, elapsed time and output directory per video.
        base_dir: Project root; the detector weights are looked up below it.
        data_dir: Root directory of the input videos.
        save_dir: Root directory of the output crops.
        chunk_dirs: Forwarded to ``read_labels``.
        max_open_files: Maximum number of videos handed to one DALI pipeline.
    """
    df = read_labels(data_dir, chunk_dirs=chunk_dirs)
    mkdirs(save_dir, df['dir'].unique())

    device = torch.device("cpu" if use_cpu else "cuda")
    weights_mnet = os.path.join(base_dir, 'data/weights/mobilenet0.25_Final.pth')
    cfg = {**cfg_mnet, 'batch_size': bs}
    detector = init_detector(cfg, weights_mnet, use_cpu).to(device)

    if end is None:
        end = len(df)

    # NOTE(review): the same frame count is used for every video;
    # num_frames_real is not applied anywhere yet.
    num_frames = num_frames_fake

    for start_pos in range(start, end, max_open_files):
        end_pos = min(start_pos + max_open_files, end)
        # Fix: forward data_dir. Previously get_file_list silently fell back
        # to its default and ignored a non-default data_dir.
        files = get_file_list(df, start_pos, end_pos, data_dir)
        # NOTE(review): 300 is presumably the frame count of one video, so
        # that the sampled frames span the whole clip -- confirm.
        pipe = VideoPipe(files, seq_len=num_frames, stride=300//num_frames)
        pipe.build()

        data_iter = DALIGenericIterator(
            [pipe], ['images'], len(files), dynamic_shape=True)
        for idx, batch in tqdm(enumerate(data_iter), total=len(files)):
            # NOTE(review): assumes the iterator yields exactly one sequence
            # per file, in file order. Original author's remark:
            meta = df.iloc[start_pos + idx] # <- check this !!!
            # fake = bool(meta['label'])

            # splitext instead of name[:-4] so any extension length works.
            sample_name = os.path.splitext(meta.name)[0]
            sample_dir = os.path.join(save_dir, meta.dir, sample_name)
            os.makedirs(sample_dir, exist_ok=True)
            if verbose:
                t0 = time.time()

            images = batch[0]['images'].squeeze(0)
            # Fix: pass cfg (which carries the requested batch size) instead
            # of cfg_mnet, so the bs argument actually takes effect.
            detections = detect(images, detector, cfg, device)
            # Default integer dtype: uint8 cannot hold counts above 255.
            num_faces = np.array([len(dets) for dets in detections])
            max_faces_per_frame = round_num_faces(num_faces, frac_thresh=0.25)
            images = images.cpu().numpy()

            for f in range(num_frames):
                for det in detections[f][:max_faces_per_frame]:
                    face = crop_face(images[f], det[:4])
                    # NOTE(review): the name depends on the frame only, so
                    # several faces of one frame would share a file name
                    # once writing is enabled.
                    file_path = os.path.join(sample_dir, '%03d.png' % f)
                    face = cv2.cvtColor(face, cv2.COLOR_RGB2BGR)
                    # cv2.imwrite(file_path, face)
            detections = None

            if verbose:
                t1 = time.time()
                print('[%6d][%.02f s] %s' % (start_pos + idx, t1 - t0, sample_dir))

        # Drop the pipeline and its iterator before building the next one.
        files, pipe, data_iter = None, None, None
        gc.collect()
    print('DONE')

In [9]:
!ls ../data/dfdc-videos/dfdc_train_part_49 | wc -l

3135


In [10]:
# df = read_labels(DATA_DIR, chunk_dirs=['dfdc_train_part_49'])
# files = get_file_list(df, 2680, 2730)

In [None]:
%%time
# Collect garbage left over from earlier cells before starting the long run.
gc.collect()
# Process chunk 'dfdc_train_part_49' from row 2650 to the end of its label
# frame, handing 50 videos at a time to a DALI pipeline; verbose=True prints
# one line with index, elapsed time and output directory per video.
prepare_data(start=2650, end=None, max_open_files=50, bs=30, verbose=True, 
             chunk_dirs=['dfdc_train_part_49'])

Loading pretrained model from /home/dmitry/projects/dfdc/data/weights/mobilenet0.25_Final.pth
remove prefix 'module.'
Missing keys:0
Unused checkpoint keys:0
Used keys:300


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

[  2660][1.01 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/ljcjouvznz
[  2661][0.46 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/tmfschpvyo
[  2662][0.47 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/rfykxekxer
[  2663][0.46 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/uyszohikoe
[  2664][0.45 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/hoaweiathp
[  2665][0.48 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/uxvarukjxl
[  2666][0.46 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/qbqimiiqil
[  2667][0.47 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/ncjqganhal
[  2668][0.45 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/jeclduagbh
[  2669][0.46 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/kyajzhnjjv
[  2670][0.45 s] /home/dmitry/projects/dfdc/data/dfdc-crops/dfdc_train_part_49/oygvvmtzjv
[  2671][0