In [1]:
from __future__ import print_function 
from __future__ import division

is_alchemy_used = True
from datetime import datetime
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from skimage import io, transform
import torch
from torch.utils import data
from torch.utils.data import DataLoader, SubsetRandomSampler,Dataset
from random import randint
from tqdm import tqdm
from PIL import Image
from random import shuffle
if is_alchemy_used:
    from catalyst.dl import SupervisedAlchemyRunner as SupervisedRunner
else:
    from catalyst.dl import SupervisedRunner

import random

import cv2
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize, RandomResizedCrop
from albumentations.pytorch import ToTensor
from alchemy import Logger
# Alchemy API token, read from the environment instead of being committed to the
# notebook: a hard-coded token leaks to anyone who can see the file or its history
# (the previously committed value should be revoked/rotated).
# NOTE: `token` is None when ALCHEMY_TOKEN is unset; the training cell then passes
# None to Logger / monitoring_params.
token = os.environ.get("ALCHEMY_TOKEN")
if token is None:
    print("WARNING: ALCHEMY_TOKEN is not set; Alchemy logging will not be authenticated.")

# Record the library versions in the notebook output for reproducibility.
for _label, _module in (("PyTorch Version: ", torch), ("Torchvision Version: ", torchvision)):
    print(_label, _module.__version__)
del _label, _module

def seed_everything(seed=12345):
    """Seed every random number generator used here so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    visible CUDA device), sets ``PYTHONHASHSEED``, and pins cuDNN to
    deterministic, non-autotuned kernels.

    Args:
        seed: integer seed applied to every generator (default 12345).
    """
    random.seed(seed)
    # Only affects Python interpreters started after this call; the current
    # process's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers all GPUs, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes and may pick different kernels run to run,
    # which defeats `deterministic`; turn it off explicitly.
    torch.backends.cudnn.benchmark = False
# seed_everything()

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


PyTorch Version:  1.3.1
Torchvision Version:  0.4.2


In [2]:
from typing import Callable, List, Tuple 

import os
import torch
import catalyst

from catalyst.dl import utils

print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

SEED = 42
utils.set_global_seed(SEED)
# deterministic=True alone is not enough for reproducible runs: cuDNN benchmark
# mode may select different kernels between runs. NOTE(review): in catalyst
# 20.02 `benchmark` appears to default to the CUDNN_BENCHMARK env var, treated
# as True when unset — confirm. Disable it explicitly either way.
utils.prepare_cudnn(deterministic=True, benchmark=False)

torch: 1.3.1, catalyst: 20.02.3


In [3]:
import getpass

# Project paths live under /home/<user>/projects/dfdc. Fall back to getpass
# when $USER is unset (some containers / schedulers) instead of raising KeyError.
_USER = os.environ.get("USER") or getpass.getuser()
BASE_DIR = f'/home/{_USER}/projects/dfdc/'
DATA_DIR = os.path.join(BASE_DIR, 'data/dfdc-videos')
# Derived from BASE_DIR (same value as before) so the two cannot drift apart.
HDF5_DIR = os.path.join(BASE_DIR, 'data/dfdc-crops/hdf5')

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "resnet"

# Number of classes in the dataset
num_classes = 2

# Batch size for training (change depending on how much memory you have; 32 was used before)
batch_size = 16

# Number of epochs to train for
num_epochs = 10

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = False

In [4]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all of ``model``'s parameters when ``feature_extracting`` is truthy.

    Layers attached after this call (e.g. a new head) stay trainable. When
    ``feature_extracting`` is falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for p in model.parameters():
        p.requires_grad = False
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Create a torchvision model and replace its classifier with ``num_classes`` outputs.

    Args:
        model_name: one of "resnet", "alexnet", "vgg", "squeezenet", "densenet",
            "inception".
        num_classes: number of outputs of the new classification head.
        feature_extract: if True, freeze the existing parameters (see
            ``set_parameter_requires_grad``) before the new head is attached,
            so only the head is trainable.
        use_pretrained: load torchvision's pretrained weights.

    Returns:
        ``(model, input_size)`` where ``input_size`` is the square image side
        the architecture expects (299 for inception, 224 otherwise).

    Raises:
        ValueError: for an unknown ``model_name``. (This used to call
            ``exit()``, which shuts down the interpreter / notebook kernel.)
    """
    if model_name == "resnet":
        # Resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # Squeezenet: the head is a 1x1 convolution, not a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # Densenet121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects (299, 299) images and has an auxiliary output.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        raise ValueError(
            "Invalid model name {!r}; expected one of resnet, alexnet, vgg, "
            "squeezenet, densenet, inception".format(model_name))

    return model_ft, input_size

def my_initialize_model(file_checkpoint, model_name, feature_extract, emb_len):
    """Build a backbone whose ``fc`` layer outputs an ``emb_len``-dim embedding.

    Steps:
    1. Create the model as a 2-class classifier via ``initialize_model``, with
       pretrained weights.
    2. Optionally overwrite its weights from ``file_checkpoint``.
    3. Switch it to eval mode.
    4. Replace its ``fc`` head with a freshly initialised
       ``nn.Linear(in_features, emb_len)``.

    Args:
        file_checkpoint: path to a checkpoint dict containing
            ``'model_state_dict'`` for the 2-class model, or None.
        model_name: architecture accepted by ``initialize_model``. It must have
            an ``fc`` head (in ``initialize_model`` that is resnet and inception).
        feature_extract: forwarded to ``initialize_model`` (freezes the backbone).
        emb_len: output size of the new ``fc`` layer.

    Returns:
        ``(model, input_size)``.

    Raises:
        ValueError: if the architecture has no ``fc`` attribute to replace.
    """
    model, input_size = initialize_model(model_name, 2, feature_extract, use_pretrained=True)
    # Fail fast with a clear message instead of an AttributeError at the end.
    if not hasattr(model, 'fc'):
        raise ValueError(
            "model {!r} has no 'fc' head to replace with an embedding layer".format(model_name))

    if file_checkpoint is not None:
        # torch.load unpickles the file: only load trusted checkpoints.
        # map_location='cpu': the model has not been moved to a device yet and
        # load_state_dict copies values into its existing parameters, so this
        # also works for GPU-saved checkpoints on CPU-only machines.
        checkpoint = torch.load(file_checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint

    model.eval()

    # The new head is randomly initialised; it is not restored from the checkpoint.
    model.fc = nn.Linear(model.fc.in_features, emb_len)
    return model, input_size

In [5]:
# model, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

In [6]:
import torch.nn as nn
import torch.nn.functional as F
from network.models import model_selection

class Net(nn.Module):
    """Per-frame CNN embeddings fed to an LSTM; each clip is scored at its last frame.

    Args:
        model_name: backbone name accepted by ``my_initialize_model``.
        emb_len: per-frame embedding size (the backbone's new ``fc`` output and
            the LSTM input size).
        hidden_dim: LSTM hidden size.
        num_classes: number of output logits (default 2).
    """

    def __init__(self, model_name, emb_len, hidden_dim, num_classes=2):
        super(Net, self).__init__()
        self.backbone, self.input_size = my_initialize_model(None, model_name, False, emb_len)
        self.lstm = nn.LSTM(emb_len, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, num_classes)

    def forward(self, sentences):
        """Score a batch of clips.

        Args:
            sentences: batch of clips whose frames are channels-last, i.e.
                (batch, frames, H, W, C), as produced by VideoDataset. The
                ``permute(0, 3, 1, 2)`` below relies on that layout.

        Returns:
            (batch, num_classes) logits taken from the LSTM output at the last
            frame of each clip, on the same device as ``sentences``.
        """
        self.lstm.flatten_parameters()

        scores = []
        # Clips are run through the backbone one at a time on purpose: its
        # BatchNorm layers then see exactly one clip's frames as a batch.
        for clip in sentences:
            embeds = self.backbone(clip.permute(0, 3, 1, 2))
            # nn.LSTM's default layout is (seq_len, batch, features): one clip = batch of 1.
            lstm_out, _ = self.lstm(embeds.view(len(clip), 1, -1))
            tag_space = self.hidden2tag(lstm_out.view(len(clip), -1))
            scores.append(tag_space[-1])

        if not scores:
            return sentences.new_zeros((0, self.hidden2tag.out_features), dtype=torch.float32)
        # Stacking keeps the result on the input's device; the previous
        # preallocated buffer was hard-coded to .cuda() and failed on CPU.
        return torch.stack(scores)


# Use the configured backbone name instead of repeating the "resnet" literal.
# emb_len: per-frame embedding size; hidden_dim: LSTM hidden size.
model = Net(model_name, emb_len=16, hidden_dim=16)

In [7]:
import math
# import os
import gc
import sys
import time

from pathlib import Path

from functools import partial
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

# from tqdm.notebook import tqdm

import cv2
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchvision
from torch import Tensor

In [8]:
# Make the project's src/ importable. The guard keeps re-runs of this cell from
# prepending duplicate entries to sys.path.
_src_dir = os.path.join(BASE_DIR, 'src')
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)
from dataset.utils import read_labels
from prepare_data import get_file_list

In [9]:
def show_images(images, cols = 1, titles = None):
    """Display a list of images in a single figure with matplotlib.

    Parameters
    ---------
    images: List of np.arrays compatible with plt.imshow.

    cols (Default = 1): Number of columns in figure (number of rows is
                        set to np.ceil(n_images/float(cols))).

    titles: List of titles corresponding to each image. Must have
            the same length as images.
    """
    assert (titles is None) or (len(images) == len(titles))
    n_images = len(images)
    if n_images == 0:
        return
    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
    # add_subplot takes integer (nrows, ncols, index). The grid used to be passed
    # as (cols, rows), which transposed the layout, and with a float row count,
    # which recent matplotlib versions reject.
    rows = int(np.ceil(n_images / float(cols)))
    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        ax = fig.add_subplot(rows, cols, n + 1)
        # A per-image grayscale colormap; plt.gray() changed the global default
        # colormap for every later plot in the session.
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax.set_title(title)
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()

In [10]:
def check_len_hdf5(path):
    """Map each regular file in directory ``path`` to its number of top-level HDF5 entries.

    Subdirectories are skipped; every regular file is opened as HDF5, so a
    non-HDF5 file in ``path`` makes h5py raise.
    """
    lens = {}
    for name in os.listdir(path):
        full_path = os.path.join(path, name)
        if not os.path.isfile(full_path):
            continue
        # Read-only: 'r+' needlessly requires write permission and can conflict
        # with other processes reading the same file under HDF5 file locking.
        with h5py.File(full_path, 'r') as f:
            lens[name] = len(f)
    return lens

In [11]:
def find_num_frames(df, idxs, hdf5_dir=None):
    """Return the number of top-level entries in the HDF5 crop file for each row of ``df``.

    Args:
        df: frame with a ``dir`` column. Each row's index label (``meta.name``)
            must be the source video file name; its extension is replaced with
            ``.h5``. Presumably this is the transposed metadata.json frame — confirm.
        idxs: positional row indices (used with ``iloc``).
        hdf5_dir: root directory of the chunk folders; defaults to ``HDF5_DIR``.

    Returns:
        list of entry counts, with -1 for rows whose file does not exist.
    """
    root = HDF5_DIR if hdf5_dir is None else hdf5_dir
    lens = []
    for idx in idxs:
        meta = df.iloc[idx]
        stem = os.path.splitext(meta.name)[0]
        path = os.path.join(root, meta.dir, stem + '.h5')
        if os.path.isfile(path):
            # Read-only access; 'r+' needlessly demanded write permission.
            with h5py.File(path, 'r') as f:
                lens.append(len(f))
        else:
            lens.append(-1)
    return lens

In [12]:
def read_hdf5(path: str, num_frames: int, size: int,
              sample_fn: Callable[[int], np.ndarray]) -> np.ndarray:
    """Decode a sampled subset of the encoded image frames stored in an HDF5 file.

    Each top-level entry of the file holds one encoded image (decoded with
    ``cv2.imdecode``). ``sample_fn(total_frames)`` returns the frame positions
    to keep. Kept frames are converted BGR -> RGB and resized to ``size`` x
    ``size`` with nearest-neighbour interpolation.

    Args:
        path: HDF5 file path.
        num_frames: not used here (the count is already bound into
            ``sample_fn``); kept for interface compatibility.
        size: output side length in pixels.
        sample_fn: maps the number of stored frames to frame positions.

    Returns:
        uint8 array (n_selected, size, size, 3); (0, size, size, 3) for a file
        with no entries.

    Raises:
        ValueError: if a selected entry cannot be decoded as an image.
    """
    img_size = (size, size)
    # Read-only: 'r+' requires write permission and can clash with concurrent
    # readers (e.g. DataLoader workers) under HDF5 file locking.
    with h5py.File(path, 'r') as file:
        total_frames = len(file)
        if total_frames == 0:
            return np.empty((0, size, size, 3), dtype=np.uint8)
        pick = create_mask(sample_fn(total_frames), total_frames)
        images = []
        # NOTE(review): positions follow h5py's key iteration order (by name);
        # if keys are unpadded numbers ('10' sorts before '2') that is not
        # temporal order — confirm the key naming.
        for i, key in enumerate(file.keys()):
            if not pick[i]:
                continue
            img = cv2.imdecode(np.uint8(file[key]), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError('Cannot decode frame {!r} in {}'.format(key, path))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            images.append(cv2.resize(img, img_size, interpolation=cv2.INTER_NEAREST))
        return np.stack(images)
        
        
def sparse_frames(n: int, total: int) -> np.ndarray:
    """Pick up to ``n`` evenly spaced frame positions out of ``total``, randomly shifted.

    The positions start on a regular grid over ``[0, total)``. The whole grid is
    then shifted by a random amount that keeps the last position below ``total``.
    """
    count = min(n, total)
    grid = np.linspace(0, total, count, endpoint=False, dtype=int)
    max_shift = total - grid[-1]
    return grid + np.random.randint(0, max_shift)


def rnd_slice_frames(n: int, total: int, stride=1.) -> np.ndarray:
    """Take up to ``n`` positions spaced ``stride`` apart, starting at a random offset.

    Positions are truncated to uint16. The random start keeps the last position
    below ``total``. Fewer than ``n`` positions are returned when they do not fit.
    """
    positions = np.arange(0, total, stride)
    positions = positions[:n].astype(np.uint16)
    offset = np.random.randint(0, total - positions[-1])
    return positions + offset


def create_mask(idxs: np.ndarray, total: int) -> np.ndarray:
    """Return a boolean array of length ``total`` that is True at positions ``idxs``."""
    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    mask = np.zeros(total, dtype=bool)
    mask[idxs] = True
    return mask


def pad(frames: np.ndarray, amount: int, where: str = 'start') -> np.ndarray:
    """Zero-pad ``frames`` with ``amount`` entries along its first axis.

    Args:
        frames: array to pad; only axis 0 grows.
        amount: number of zero entries to add.
        where: 'start' to prepend, 'end' to append.

    Raises:
        ValueError: for any other ``where`` (previously treated as 'start').
    """
    if where not in ('start', 'end'):
        raise ValueError("where must be 'start' or 'end', got {!r}".format(where))
    # Plain int widths: the former int8 array overflowed for amounts above 127.
    widths = np.zeros((frames.ndim, 2), dtype=int)
    widths[0, 1 if where == 'end' else 0] = amount
    return np.pad(frames, widths, 'constant')

In [13]:
class FrameSampler():
    """Choose, per clip, how frame positions are sampled from its HDF5 file.

    With probability ``p_sparse`` the frames are spread over the whole clip
    (``sparse_frames``). Otherwise a strided slice is taken
    (``rnd_slice_frames``) whose stride depends on the label.

    Args:
        num_frames: number of frame positions to request.
        real_fake_ratio: ratio of stored real-clip frames to fake-clip frames.
            It becomes the stride for real clips; fake clips use stride 1.
        p_sparse: probability of sparse sampling.
    """

    def __init__(self, num_frames: int, real_fake_ratio: float,
                 p_sparse: float):
        self.num_frames = num_frames
        self.real_fake_ratio = real_fake_ratio
        self.p_sparse = p_sparse

    def __call__(self, label: Union[bool, int]) -> Callable[[int], np.ndarray]:
        """Return a ``total_frames -> positions`` function for a clip with ``label``.

        ``label`` is a single flag, truthy for fake clips (VideoDataset stores
        ``metadata label == 'FAKE'``). The old ``Tuple[int, bool]`` annotation
        was wrong.
        """
        if np.random.rand() < self.p_sparse:
            return partial(sparse_frames, self.num_frames)
        # Stored frames: fake - 30, real - 150,
        # the real_fake_ratio should be set to 150 / 30 = 5
        # stride for fake: 5 - (4 * 1) = 1
        # stride for real: 5 - (4 * 0) = 5
        ratio = self.real_fake_ratio
        stride = ratio - ((ratio - 1) * int(label))
        return partial(rnd_slice_frames, self.num_frames, stride=stride)

In [14]:
# Module-level sampler: 15 frames, real_fake_ratio 1 (stride 1 for both labels),
# always sparse sampling.
# NOTE(review): get_loader builds its own sampler, so this object looks unused
# in the visible cells.
sampler = FrameSampler(15, 1.0, 1.0)


In [15]:
class VideoDataset(torch.utils.data.Dataset):
    """Clips of frames read from per-video HDF5 files.

    Each item is ``(frames, label)``:
    - ``frames`` is a float32 array (num_frames, size, size, 3), normalised
      per channel with the mean/std below.
    - ``label`` is 1 for FAKE videos and 0 otherwise.

    Args:
        data_path: directory of chunk folders, each holding a metadata.json.
        hdf5_path: directory with the matching chunk folders of ``.h5`` files.
        size: ``(num_frames, side)`` — frames per clip and square side length.
        sampler: callable mapping a label to a frame-sampling function
            (a ``FrameSampler``).
    """

    # Per-channel mean/std for albumentations' Normalize (standard ImageNet values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, data_path: str, hdf5_path: str,
                 size: Tuple[int, int],
                 sampler: 'FrameSampler'):
        self.df = VideoDataset._read_annotations(data_path, hdf5_path)
        self.size = size
        self.hdf5_path = hdf5_path
        self.sampler = sampler
        # Built once here rather than on every __getitem__ call.
        self.transform = Compose([Normalize(self._NORM_MEAN, self._NORM_STD)])

    @staticmethod
    def _read_annotations(data_path: str, hdf5_path: str) -> pd.DataFrame:
        """Collect metadata for every chunk that has a matching HDF5 folder.

        Returns one row per video whose ``.h5`` file exists, with these columns:
        - ``file``: the ``.h5`` name.
        - ``label``: True for FAKE.
        - ``dir``: the chunk folder.
        - the remaining metadata.json fields except ``split``.

        Chunks are read in sorted order so the row order does not depend on the
        filesystem. Each chunk's index labels are kept, not renumbered.

        Raises:
            RuntimeError: if ``data_path`` is not a directory.
            ValueError: if no chunk could be read.
        """
        if not os.path.isdir(data_path):
            raise RuntimeError('Invalid data dir.. Wake up your drive brah!')
        parts = []
        for chunk_dir in sorted(os.listdir(data_path)):
            hdf5_chunk_path = Path(hdf5_path) / chunk_dir
            if not hdf5_chunk_path.is_dir():
                print('{}: dir is missing..'.format(hdf5_chunk_path))
                continue
            df = pd.read_json(Path(data_path) / chunk_dir / 'metadata.json').T
            df = df.reset_index().rename({'index': 'file'}, axis=1)
            df['label'] = df['label'] == 'FAKE'
            df['file'] = df['file'].apply(lambda file: file[:-4] + '.h5')
            df['dir'] = chunk_dir
            # Vectorised existence check (was a row-by-row .loc loop).
            present = df['file'].map(lambda f: (hdf5_chunk_path / f).is_file()).astype(bool)
            num_miss = int((~present).sum())
            if num_miss > 0:
                print('{}: {} files missing..'.format(hdf5_chunk_path, num_miss))
            # Build a new frame instead of dropping in place on a filtered slice,
            # which triggered pandas' SettingWithCopyWarning.
            parts.append(df.loc[present].drop(columns=['split']))
        if not parts:
            raise ValueError('No annotated chunks in {} have HDF5 data in {}'.format(
                data_path, hdf5_path))
        return pd.concat(parts)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx) -> Tuple[np.ndarray, int]:
        """Load clip ``idx`` as ``(float32 frames (num_frames, size, size, 3), int label)``.

        A missing or empty HDF5 file yields zero-valued frames (before
        normalisation) and prints a message. Clips with fewer than
        ``num_frames`` frames are zero-padded at the start.
        """
        num_frames, size = self.size
        meta = self.df.iloc[idx]
        label = int(meta.label)
        path = os.path.join(self.hdf5_path, meta.dir, meta.file)

        if os.path.isfile(path):
            sample_fn = self.sampler(meta.label)
            frames = read_hdf5(path, num_frames, size, sample_fn=sample_fn)
        else:
            print('Unable to read {}'.format(path))
            frames = np.zeros((num_frames, size, size, 3), dtype=np.uint8)

        if len(frames) == 0:
            print('Empty file {}'.format(path))
            frames = np.zeros((num_frames, size, size, 3), dtype=np.uint8)
        elif len(frames) < num_frames:
            frames = pad(frames, num_frames - len(frames), 'start')

        frames = np.asarray(frames, dtype=np.float32)
        frames = np.asarray([self.transform(image=frame)['image'] for frame in frames],
                            dtype=np.float32)
        return frames, label

In [16]:
def shuffled_idxs(values: np.ndarray, val: int) -> List[int]:
    """Return the positions (first-axis indices) where ``values == val``, in random order.

    The result is a NumPy integer array.
    """
    return np.random.permutation(np.nonzero(values == val)[0])


class BalancedSampler(torch.utils.data.RandomSampler):
    def __init__(self, data_source, replacement=False, num_samples=None):
        
        super().__init__(data_source, replacement, num_samples)
        if not hasattr(data_source, 'df'):
            raise ValueError("DataSource must have a 'df' property")
            
        if not 'label' in data_source.df: 
            raise ValueError("DataSource.df must have a 'label' column")
    
    def __iter__(self):
        df = self.data_source.df
        all_labels = df['label'].values
        uniq_labels, label_freq = np.unique(all_labels, return_counts=True)
        rev_freq = (len(all_labels) / label_freq)
        
        idxs = []
        for freq, label in zip(rev_freq, uniq_labels):
            fraction, times = np.modf(freq)
            label_idxs = (all_labels == label).nonzero()[0]
            for _ in range(int(times)):
                label_idxs = np.random.permutation(label_idxs)
                idxs.append(label_idxs)
            if fraction > 0.05:
                label_idxs = np.random.permutation(label_idxs)
                chunk = int(len(label_idxs) * fraction)
                idxs.append(label_idxs[:chunk])
        idxs = np.concatenate(idxs)
        idxs = np.random.permutation(idxs)[:self.num_samples]
        
        for i in idxs:
            yield i
        # return iter(idxs.tolist())

In [17]:
input_size = model.input_size
# FASTPART: quick smoke-test mode with very short clips.
FASTPART = False
num_frames = 2 if FASTPART else 15
def get_loader(num_frames=15, real_fake_ratio=1, p_sparse=0.5, input_size=input_size, hdf_dir=None,
               num_workers=0):
    """Build a DataLoader over ``VideoDataset`` with label-balanced batches.

    Args:
        num_frames: frames per clip.
        real_fake_ratio: forwarded to ``FrameSampler``.
        p_sparse: forwarded to ``FrameSampler``.
        input_size: square frame side in pixels (defaults to the model's input
            size at definition time).
        hdf_dir: root directory of the per-chunk HDF5 files (required).
        num_workers: DataLoader worker processes. 0 (the default) loads in the
            main process, as before.

    Returns:
        A DataLoader yielding the notebook-level ``batch_size`` clips per batch.
        Incomplete final batches are dropped.

    Raises:
        ValueError: if ``hdf_dir`` is not given.
    """
    if hdf_dir is None:
        raise ValueError('hdf_dir must point to a directory of HDF5 files')
    sampler = FrameSampler(num_frames, real_fake_ratio=real_fake_ratio, p_sparse=p_sparse)
    ds = VideoDataset(DATA_DIR, hdf_dir, size=(num_frames, input_size), sampler=sampler)
    print(len(ds))
    # NOTE(review): valid/test loaders also use balanced random resampling with
    # drop_last, so one epoch does not visit every clip exactly once — confirm intended.
    batch_sampler = torch.utils.data.BatchSampler(
        BalancedSampler(ds),
        batch_size=batch_size,
        drop_last=True
    )
    return torch.utils.data.DataLoader(ds, batch_sampler=batch_sampler, num_workers=num_workers)
    
# One balanced loader per split; the crops for split X live in data/dfdc-crops-X/.
loaders = {
    split: get_loader(num_frames=num_frames, real_fake_ratio=1, p_sparse=0.5, input_size=input_size,
                      hdf_dir=os.path.join(BASE_DIR, 'data', f'dfdc-crops-{split}/'))
    for split in ('train', 'valid', 'test')
}

/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_46: dir is missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_41: dir is missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_32: 22 files missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_34: 1 files missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_45: dir is missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_21: 18 files missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_16: 1 files missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_18: 3 files missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_48: dir is missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_47: dir is missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_43: dir is missing..
/home/kb/projects/dfdc/data/dfdc-crops-train/dfdc_train_part_12: 44 files missing..
/home/k

/home/kb/projects/dfdc/data/dfdc-crops-test/dfdc_train_part_42: dir is missing..
12551


In [18]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:


project = 'dfdc_v1_resnet'
# NOTE: overrides num_epochs = 10 from the configuration cell.
num_epochs = 5

group = datetime.now().strftime("%m_%d_%Y__%H_%M_%S")
if FASTPART:
    group = f'fast_{group}'

expnum = 0
experiment = f"exp{expnum}"
# Machine-specific log root; adjust per environment.
LOG_ROOT = "/home/kb/hdd/logs/deepfake"
logdir = f"{LOG_ROOT}/{project}/{group}/{experiment}"

model = model.to(device)
# Optimise only trainable parameters. This list used to be computed but then
# ignored (the optimizer got model.parameters()), so feature_extract had no effect.
if feature_extract:
    params_to_update = [param for param in model.parameters() if param.requires_grad]
else:
    params_to_update = model.parameters()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=params_to_update, lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# model runner
runner = SupervisedRunner()

print(f'----------------Experiment: {experiment}')
# NOTE(review): the Logger is opened and closed immediately; presumably only to
# register the experiment with Alchemy before training — confirm.
logger = Logger(
    token=token,
    experiment=experiment,
    group=group,
    project=project,
)
logger.close()

runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True,
    monitoring_params={
        "token": token,
        "project": project,
        "experiment": experiment,
        "group": group,
    }
)

----------------Experiment: exp0
1/5 * Epoch (train):   2% 106/5882 [01:37<1:34:50,  1.02it/s, loss=0.688]