In [1]:
import tensorflow as tf
import numpy as np
import gc
import pandas as pd 
from tensorflow.keras.layers import ConvLSTM2D, Conv3D, Conv2D, Flatten, Dense, BatchNormalization
from utils import *
from sklearn.model_selection import train_test_split
from video_loader import DeepFakeTransformer
import sys
import pathlib
import datetime
import matplotlib.pyplot as plt

In [2]:
# Train params
# NOTE(review): none of the training parameters in this group are referenced
# by the cells visible in this part of the notebook; presumably they are
# consumed by later model/training cells -- confirm before changing them.
EPOCHS=10
batch_size=2
epoch_steps = 1000
val_steps = 1000
reg_penalty = 0.001
cls_wt = {0:3, 1:2.25}  # presumably per-class loss weights keyed by label -- TODO confirm

# Dataset params
data_pairs_path = '../data/source/labels/fake_to_real_mapping.csv'  # CSV read below; must provide 'real' and 'fake' columns
resize_shape = (224,224)  # passed to DeepFakeDualTransformer(resize_shape=...)
sequence_len = 30  # passed to DeepFakeDualTransformer(seq_length=...): frames sampled per video
prefetch_num = 10  # buffer size given to tf.data's .prefetch()
train_val_split = 0.015  # used as train_test_split(test_size=...): fraction of pairs held out for validation

In [3]:
# Load the real/fake pairing table; only the two path columns are needed.
df_pairs = pd.read_csv(data_pairs_path)[['real', 'fake']]

# A fixed random_state makes the train/validation membership reproducible
# across kernel restarts (otherwise validation videos from one session can
# end up in the training set of the next one).
# NOTE(review): the split is done per row (pair). If one real video is paired
# with several fakes, it can land on both sides of the split -- consider a
# group-aware split keyed on 'real' if that matters for validation.
train_df, val_df = train_test_split(
    df_pairs,
    test_size=train_val_split,
    random_state=42,
)

In [4]:
# Report the number of (real, fake) pairs in the training and validation
# splits, one count per line.
print(len(train_df), len(val_df), sep='\n')

69277
1055


In [5]:
# Inspect the training pairs as the (N, 2) object array of [real_path, fake_path]
# rows that is later handed to tf.data.Dataset.from_tensor_slices.
train_df.to_numpy()

array([['../data/source/train_val_sort/train/REAL/clcjjbtmnm.mp4',
        '../data/source/train_val_sort/train/FAKE/qsdjaqcmzl.mp4'],
       ['../data/source/train_val_sort/train/REAL/sdzzjnfxtw.mp4',
        '../data/source/train_val_sort/train/FAKE/sbyckvrrlz.mp4'],
       ['../data/source/train_val_sort/train/REAL/nzrcmggcfp.mp4',
        '../data/source/train_val_sort/train/FAKE/mamywmhzvm.mp4'],
       ...,
       ['../data/source/train_val_sort/train/REAL/rjlgchzmfv.mp4',
        '../data/source/train_val_sort/train/FAKE/uwnwipuzvk.mp4'],
       ['../data/source/train_val_sort/train/REAL/crmobgizwo.mp4',
        '../data/source/train_val_sort/train/FAKE/prykceywau.mp4'],
       ['../data/source/train_val_sort/train/REAL/fuusxfhmrc.mp4',
        '../data/source/train_val_sort/train/FAKE/mgkkbmfumw.mp4']],
      dtype=object)

In [7]:
# Peek at one raw (real, fake) path pair.
# The dataset is built here from train_df instead of relying on the global
# `train_ds`: that name is only created in a later cell, where it is also
# rebound to decoded video tensors. Building it locally lets this cell run on
# a fresh kernel and guarantees that `f` (reused further down for a manual
# transform check) is a string tensor of two paths.
raw_pairs_ds = tf.data.Dataset.from_tensor_slices(train_df.to_numpy())
for f in raw_pairs_ds.take(1):
    print(f)

tf.Tensor(
[b'../data/source/train_val_sort/train/REAL/clcjjbtmnm.mp4'
 b'../data/source/train_val_sort/train/FAKE/qsdjaqcmzl.mp4'], shape=(2,), dtype=string)


In [8]:
class VideoReader:
    """Helper class for reading one or more frames from a video file.

    Every public reader opens the file with ``cv2.VideoCapture`` and releases
    the capture again in a ``finally`` block, so captures are not leaked on
    early returns or errors.
    """

    def __init__(self, verbose=True, insets=(0, 0)):
        """Creates a new VideoReader.

        Arguments:
            verbose: whether to print warnings and error messages
            insets: amount to inset the image by, as a percentage of 
                (width, height). This lets you "zoom in" to an image 
                to remove unimportant content around the borders. 
                Useful for face detection, which may not work if the 
                faces are too small. The value is stored, but the inset
                logic in _postprocess_frame is currently disabled.
        """
        self.verbose = verbose
        self.insets = insets

    def _apply_jitter(self, frame_idxs, frame_count, jitter, seed):
        """Adds random offsets in [-jitter, jitter) to the indices, clipped to the video."""
        if jitter <= 0:
            return frame_idxs
        np.random.seed(seed)
        jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
        return np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)

    def read_frames_test(self, path, num_frames, jitter=0, seed=None):
        """Reads every frame of the video, in order.

        Arguments:
            path: the video file
            num_frames: accepted for signature compatibility with
                read_frames; it is not used, the whole video is always read
                (warning: this will take up a lot of memory!)
            jitter: if not 0, adds small random offsets to the frame indices;
                this is useful so we don't always land on even or odd frames
            seed: random seed for jittering; if you set this to a fixed value,
                you probably want to set it only on the first video 

        Returns:
            a NumPy array of shape (frames_read, height, width, 3), or None
            if the video reports no frames or nothing could be read.
        """
        capture = cv2.VideoCapture(path)
        try:
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count <= 0:
                return None

            # np.int was removed from NumPy; the builtin int is the same dtype.
            frame_idxs = np.arange(frame_count, dtype=int)
            frame_idxs = self._apply_jitter(frame_idxs, frame_count, jitter, seed)
            result = self._read_frames_at_indices(path, capture, frame_idxs)
        finally:
            capture.release()

        # The low-level reader returns None on failure; do not try to unpack it.
        return None if result is None else result[0]

    def read_frames(self, path, num_frames, jitter=0, seed=None):
        """Reads a window of roughly consecutive frames starting at a random position.

        Arguments:
            path: the video file
            num_frames: how many frames to read
            jitter: if not 0, adds small random offsets to the frame indices;
                this is useful so we don't always land on even or odd frames
            seed: random seed for jittering; if you set this to a fixed value,
                you probably want to set it only on the first video 

        Returns:
            a NumPy array of shape (frames_read, height, width, 3), or None
            if the video reports no frames or nothing could be read. When
            the video has fewer than num_frames usable frames, the window
            starts at frame 0 and fewer frames are returned.
        """
        assert num_frames > 0

        capture = cv2.VideoCapture(path)
        try:
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count <= 0:
                return None

            # randint needs a positive upper bound; short videos start at 0.
            max_start = frame_count - num_frames
            start = np.random.randint(max_start) if max_start > 0 else 0
            frame_idxs = np.linspace(start, start + num_frames, num=num_frames, dtype=int)
            # No-op for videos that are long enough; keeps short ones in range.
            frame_idxs = np.clip(frame_idxs, 0, frame_count - 1)
            frame_idxs = self._apply_jitter(frame_idxs, frame_count, jitter, seed)
            result = self._read_frames_at_indices(path, capture, frame_idxs)
        finally:
            capture.release()

        return None if result is None else result[0]

    def read_random_frames(self, path, num_frames, seed=None):
        """Picks the frame indices at random.
        
        Arguments:
            path: the video file
            num_frames: how many indices to draw (with replacement, so the
                number of distinct frames read can be smaller)
            seed: random seed used for drawing the indices

        Returns:
            the (frames, indices_read) tuple produced by
            _read_frames_at_indices, or None on failure.
        """
        assert num_frames > 0
        np.random.seed(seed)

        capture = cv2.VideoCapture(path)
        try:
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count <= 0:
                return None

            frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
            return self._read_frames_at_indices(path, capture, frame_idxs)
        finally:
            capture.release()

    def read_frames_at_indices(self, path, frame_idxs):
        """Reads frames from a video and puts them into a NumPy array.

        Arguments:
            path: the video file
            frame_idxs: a list of frame indices. If an index appears multiple
                times, the frame is still read only once.

        Returns:
            - a NumPy array of shape (num_frames, height, width, 3)
            - a list of the frame indices that were read

        Reading stops if loading a frame fails, in which case the first
        dimension returned may actually be less than num_frames.

        Returns None if an exception is thrown for any reason, or if no
        frames were read.
        """
        assert len(frame_idxs) > 0
        capture = cv2.VideoCapture(path)
        try:
            return self._read_frames_at_indices(path, capture, frame_idxs)
        finally:
            capture.release()

    def _read_frames_at_indices(self, path, capture, frame_idxs):
        """Streams through `capture` and decodes the requested frames.

        The indices are de-duplicated and sorted first, then the capture is
        positioned on the first requested frame so that the frames returned
        really are the ones asked for. The caller keeps ownership of
        `capture` and must release it.

        Returns (frames, indices_read) or None if nothing could be read.
        """
        try:
            # Duplicates / unsorted input would otherwise stall the matching
            # below: an index that was already passed is never seen again.
            wanted = sorted({int(idx) for idx in frame_idxs})
            if not wanted:
                if self.verbose:
                    print("No frames read from movie %s" % path)
                return None

            # grab() only advances from the current position, so seek first.
            capture.set(cv2.CAP_PROP_POS_FRAMES, wanted[0])

            frames = []
            idxs_read = []
            for frame_idx in range(wanted[0], wanted[-1] + 1):
                # Get the next frame, but don't decode if we're not using it.
                ret = capture.grab()
                if not ret:
                    if self.verbose:
                        print("Error grabbing frame %d from movie %s" % (frame_idx, path))
                    break

                # Need to look at this frame?
                if frame_idx == wanted[len(idxs_read)]:
                    ret, frame = capture.retrieve()
                    if not ret or frame is None:
                        if self.verbose:
                            print("Error retrieving frame %d from movie %s" % (frame_idx, path))
                        break

                    frames.append(self._postprocess_frame(frame))
                    idxs_read.append(frame_idx)

            if len(frames) > 0:
                return np.stack(frames), idxs_read
            if self.verbose:
                print("No frames read from movie %s" % path)
            return None
        except Exception as exc:
            # Best effort by design: report and signal failure with None, but
            # let KeyboardInterrupt / SystemExit propagate.
            if self.verbose:
                print("Exception while reading movie %s: %r" % (path, exc))
            return None

    def read_middle_frame(self, path):
        """Reads the frame from the middle of the video.

        Returns the same value as read_frame_at_index.
        """
        capture = cv2.VideoCapture(path)
        try:
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            return self._read_frame_at_index(path, capture, frame_count // 2)
        finally:
            capture.release()

    def read_frame_at_index(self, path, frame_idx):
        """Reads a single frame from a video.
        
        If you just want to read a single frame from the video, this is more
        efficient than scanning through the video to find the frame. However,
        for reading multiple frames it's not efficient.
        
        My guess is that a "streaming" approach is more efficient than a 
        "random access" approach because, unless you happen to grab a keyframe, 
        the decoder still needs to read all the previous frames in order to 
        reconstruct the one you're asking for.

        Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
        or None if reading failed.
        """
        capture = cv2.VideoCapture(path)
        try:
            return self._read_frame_at_index(path, capture, frame_idx)
        finally:
            capture.release()

    def _read_frame_at_index(self, path, capture, frame_idx):
        """Seeks `capture` to `frame_idx` and decodes that single frame."""
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = capture.read()
        if not ret or frame is None:
            if self.verbose:
                print("Error retrieving frame %d from movie %s" % (frame_idx, path))
            return None
        frame = self._postprocess_frame(frame)
        return np.expand_dims(frame, axis=0), [frame_idx]

    def _postprocess_frame(self, frame):
        """Returns the frame unchanged.

        Colour conversion and inset cropping are intentionally disabled, so
        frames keep the channel order delivered by the capture. The disabled
        logic is kept below for reference.
        """
        # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # if self.insets[0] > 0:
        #     W = frame.shape[1]
        #     p = int(W * self.insets[0])
        #     frame = frame[:, p:-p, :]

        # if self.insets[1] > 0:
        #     H = frame.shape[0]
        #     q = int(H * self.insets[1])
        #     frame = frame[q:-q, :, :]

        return frame

In [54]:
class DeepFakeDualTransformer(object):
    """Loads a (real, fake) video pair as one normalized tensor.

    For each pair of paths the same frame indices are sampled from both
    videos, the frames are resized and normalized per channel, and the two
    clips are stacked along a new leading axis.
    """

    def __init__(self, chan_means=[0.485, 0.456, 0.406],
                       chan_std_dev=[0.229, 0.224, 0.225],
                       resize_shape=(300,300),
                       seq_length=298,
                       mode="train"):
        """Stores the preprocessing configuration.

        Keyword Arguments:
            chan_means {list} -- per-channel mean subtracted after scaling
                pixels by 1/255 (default: {[0.485, 0.456, 0.406]})
            chan_std_dev {list} -- per-channel divisor applied after the
                mean subtraction (default: {[0.229, 0.224, 0.225]})
            resize_shape {tuple} -- size passed to tf.image.resize
                (default: {(300,300)})
            seq_length {int} -- number of frames sampled from each video
                (default: {298})
            mode {str} -- stored on the instance; not read by any method of
                this class (default: {"train"})
        """

        self.chan_means = chan_means
        self.chan_std_dev = chan_std_dev
        self.resize_shape = resize_shape
        self.seq_length = seq_length
        self.mode = mode
        self.reader = VideoReader()

    def get_frames(self, fnames):
        """Reads the same randomly placed frame window from both videos.

        Arguments:
            fnames -- object whose .numpy() yields [real_path, fake_path] as
                utf-8 encoded bytes (e.g. an eager tf string tensor)

        Returns:
            (real_frames, fake_frames) -- two arrays with the same number of
            frames. Fewer than seq_length frames are returned when the
            videos are too short.

        Raises:
            ValueError -- if either video reports no frames or cannot be read
        """
        num_frames = self.seq_length

        paths = fnames.numpy()
        real = paths[0].decode('utf-8')
        fake = paths[1].decode('utf-8')

        real_capture = cv2.VideoCapture(real)
        fake_capture = cv2.VideoCapture(fake)
        try:
            # Counts should be equal between real and fakes; take the smaller
            # one so the indices are valid for both if they are not.
            frame_count = min(int(real_capture.get(cv2.CAP_PROP_FRAME_COUNT)),
                              int(fake_capture.get(cv2.CAP_PROP_FRAME_COUNT)))
            if frame_count <= 0:
                raise ValueError("No frames available for pair (%s, %s)" % (real, fake))

            # Base inds on same frame grab to use matching video frames.
            # randint needs a positive bound, so short videos start at 0.
            max_start = frame_count - num_frames
            start = np.random.randint(max_start) if max_start > 0 else 0
            frame_idxs = np.linspace(start, start + num_frames, num=num_frames, dtype=int)
            # No-ops for long enough videos; for short ones this keeps the
            # indices in range, sorted and free of duplicates.
            frame_idxs = np.unique(np.clip(frame_idxs, 0, frame_count - 1))

            # Position both captures on the first wanted frame so the decoded
            # frames match frame_idxs even if the reader only streams forward
            # from the capture's current position.
            first = int(frame_idxs[0])
            real_capture.set(cv2.CAP_PROP_POS_FRAMES, first)
            fake_capture.set(cv2.CAP_PROP_POS_FRAMES, first)

            real_result = self.reader._read_frames_at_indices(real, real_capture, frame_idxs)
            fake_result = self.reader._read_frames_at_indices(fake, fake_capture, frame_idxs)
        finally:
            real_capture.release()
            fake_capture.release()

        # The reader signals failure with None; surface that with the paths
        # instead of an opaque unpacking error.
        if real_result is None or fake_result is None:
            raise ValueError("Could not read frames for pair (%s, %s)" % (real, fake))

        real_vid, fake_vid = real_result[0], fake_result[0]
        # If one read stopped early, keep only the frames both clips have so
        # they can still be stacked.
        common = min(len(real_vid), len(fake_vid))
        return real_vid[:common], fake_vid[:common]

    def normalize(self, video, chan_means, chan_std_dev):
        """Scales pixels by 1/255, then standardizes per channel.

        Arguments:
            video {tf.Tensor} -- tensorflow reshaped video data
            chan_means {array} -- per-channel values subtracted after scaling
            chan_std_dev {array} -- per-channel divisors applied last

        Returns:
            [tf.Tensor] -- normalized video data (the input is not modified)
        """
        # NOTE(review): frames come straight from OpenCV without colour
        # conversion, so the channel order here is presumably BGR -- confirm
        # that chan_means / chan_std_dev are given in the same order.
        return (video / 255 - chan_means) / chan_std_dev

    def transform_vid(self, filenames):
        """Loads, resizes and normalizes one (real, fake) pair.

        Arguments:
            filenames -- passed through to get_frames

        Returns:
            tensor stacking the real clip (index 0) and the fake clip
            (index 1) along a new first axis
        """
        # For kaggle only
        # fname = parts[-1].numpy().decode('utf-8')
        # global filelog
        # filelog.append(fname)

        real_vid, fake_vid = self.get_frames(filenames)

        processed = []
        for clip in (real_vid, fake_vid):
            clip = tf.image.resize(clip, size=self.resize_shape)
            processed.append(self.normalize(clip, self.chan_means, self.chan_std_dev))

        return tf.stack(processed)

    def transform_map(self, x):
        """tf.data-compatible wrapper around transform_vid.

        Runs transform_vid through tf.py_function and declares the static
        rank of the result: a leading axis of size 2 followed by four
        unknown dimensions.
        """
        result_tensor = tf.py_function(func=self.transform_vid,
                                        inp=[x],
                                        Tout=[tf.float32])
        result_tensor[0].set_shape((2,None,None,None,None))
        return result_tensor[0]

In [10]:
def get_nframes_test(path):
    """Returns the frame count OpenCV reports for the video at `path`.

    The capture is opened only to query CAP_PROP_FRAME_COUNT and is released
    again before returning; no frame is decoded.
    """
    video = cv2.VideoCapture(path)
    n_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    video.release()
    return n_frames
    
def _test_frames(df):
    """Prints every (real, fake) pair whose videos report different frame counts.

    Arguments:
        df -- DataFrame whose rows are (real_path, fake_path), in that order

    For each mismatching pair the real path, its frame count, the fake path
    and its frame count are printed on separate lines. A progress line is
    printed after every 100 pairs.
    """
    for tested, (r, f) in enumerate(df.to_numpy(), start=1):
        fs = get_nframes_test(f)
        rs = get_nframes_test(r)

        if tested % 100 == 0:
            # `tested` already counts the current pair; the previous code
            # printed one more than the number of videos actually checked.
            print('vids tested: ', tested)

        if fs != rs:
            print(r)
            print(rs)
            print(f)
            print(fs)
            

In [55]:
# One transformer per split, both configured from the dataset params.
train_transformer = DeepFakeDualTransformer(resize_shape=resize_shape, seq_length=sequence_len)
# TODO add in random crops, rotations, etc to make this non-redundant
val_transformer = DeepFakeDualTransformer(resize_shape=resize_shape, seq_length=sequence_len)

# Path pairs -> decoded, normalized (real, fake) clips, with prefetching.
train_ds = (
    tf.data.Dataset.from_tensor_slices(train_df.to_numpy())
    .map(train_transformer.transform_map)
    .prefetch(prefetch_num)
)
val_ds = (
    tf.data.Dataset.from_tensor_slices(val_df.to_numpy())
    .map(val_transformer.transform_map)
    .prefetch(prefetch_num)
)

In [23]:
# Standalone transformer (same settings as the dataset transformers) for
# calling transform_vid directly, outside the tf.data pipeline.
df_trf = DeepFakeDualTransformer(resize_shape=resize_shape, seq_length=sequence_len)

In [56]:
# Eager check of the transform on a single pair.
# NOTE(review): `f` is the variable left behind by the earlier
# `for f in train_ds.take(1)` cell; transform_vid calls f.numpy()[i].decode(),
# so `f` must still be the string tensor [real_path, fake_path] here.
res = df_trf.transform_vid(f)

In [58]:
# Pull elements from the training pipeline and print their shapes, stopping
# once the fourth one has been shown.
i = 0
for i, f1 in enumerate(train_ds.as_numpy_iterator(), start=1):
    print(f1.shape)
    if i > 3:
        break

(2, 30, 224, 224, 3)
(2, 30, 224, 224, 3)
(2, 30, 224, 224, 3)
(2, 30, 224, 224, 3)
