In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from utils import *
import os
import sys
import glob
import pathlib

In [2]:
EPOCHS=1
validate_epochs = [1,2,10]
batch_size=1
test_fraction = 0.2
train_label_path = '../data/source/labels/train_meta.json'
train_path = '../data/source/train/'
checkpoint_prefix = 'models/ckpt_{epoch}'
resize_shape = (224,224)
sequence_len = 16
n_workers = 1
use_mult_prc = False

In [11]:
vid_root = '/home/kevin/deepfake-proj/data/source/train_val_sort/train/'
vid_root = pathlib.Path(vid_root)
vid_ds = tf.data.Dataset.list_files(str(vid_root/'*/*'))

In [12]:
len(list(vid_ds))

320

In [4]:
for f in vid_ds.take(5):
    print(f.numpy())

b'/home/kevin/deepfake-proj/data/source/train_val_sort/train/FAKE/eukvucdetx.mp4'
b'/home/kevin/deepfake-proj/data/source/train_val_sort/train/FAKE/ahfazfbntc.mp4'
b'/home/kevin/deepfake-proj/data/source/train_val_sort/train/FAKE/awhmfnnjih.mp4'
b'/home/kevin/deepfake-proj/data/source/train_val_sort/train/FAKE/dnhvalzvrt.mp4'
b'/home/kevin/deepfake-proj/data/source/train_val_sort/train/REAL/ellavthztb.mp4'


In [205]:
class DeepFakeTransformer(object):
    def __init__(self, chan_means=[0.485, 0.456, 0.406],
                       chan_std_dev=[0.229, 0.224, 0.225],
                       resize_shape=(300,300),
                       seq_length=298):
        self.chan_means = chan_means
        self.chan_std_dev = chan_std_dev
        self.resize_shape = resize_shape
        self.seq_length = seq_length
        
    def get_frames(filename):
        '''
        method for getting the frames from a video file
        args: 
            filename: exact path of the video file
            first_only: whether to detect the first frame only or all of the frames
        out:
            video_frames, label:  
        '''

        filepath = filename.numpy().decode('utf-8')


        cap = cv2.VideoCapture(filepath) 
        # captures the video. Think of it as if life is a movie so we ask the method to focus on patricular event
        # that is our video in this case. It will concentrate on the video
        frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frameWidth = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frameHeight = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        all_frames = np.empty((frameCount, frameHeight, frameWidth, 3), np.dtype('uint8'))



        fc = 0
        while(cap.isOpened() and fc < frameCount): # as long as all the frames have been traversed
            ret, frame = cap.read()
            # capture the frame. Again, if life is a movie, this function acts as camera

            if ret==True:
                all_frames[fc] = frame
                fc += 1
                if cv2.waitKey(1) & 0xFF == ord('q'): # break in between by pressing the key given
                    break
            else:
                break

        cap.release()
        # release whatever was held by the method for say, resources and the video itself
        return all_frames

    # tensorflow functions to pre-process videos
    def normalize(video, chan_means, chan_std_dev):
        """[summary]

        Arguments:
            video {tf.Tensor} -- tensorflow reshaped video data
            chan_means {array} -- [description]
            chan_std_dev {array} -- [description]

        Returns:
            [tf.Tensor] -- normalized video data
        """

        video /= 255
        video -= chan_means
        video /= chan_std_dev

        return video

    def transform_vid(self, filename):
        
        chan_means = self.chan_means
        chan_std_dev = self.chan_std_dev
        resize_shape = self.resize_shape
        seq_length = self.seq_length
 
        parts = tf.strings.split(filename, '/')
        label = parts[-2]
        # Don't want to exceed frames, available, using 198 as limit
        if seq_length == 298:
            start = 0
        else:
            start = np.random.randint(298 - seq_length)

        vid = get_frames(filename)[start:(start+seq_length),:,:,:]
        vid = tf.image.resize(vid, size=resize_shape).numpy()
        vid = normalize(vid, chan_means, chan_std_dev)

        return vid, label

In [206]:
transformer = DeepFakeTransformer(resize_shape=(224,224))
trf_func = transformer.transform_vid

In [207]:
vid_ds = vid_ds.map(lambda x: tf.py_function(trf_func, [x], [tf.float32, tf.string]))

In [208]:
for vid, label in vid_ds.take(1):
    print(label)
    print(vid.shape)

tf.Tensor(b'REAL', shape=(), dtype=string)
(298, 224, 224, 3)


In [166]:
transformer.resize_shape

[224, 224]

In [15]:
test_dims = (batch_size, sequence_len, *resize_shape, 3)


In [16]:
test_dims

(1, 16, 224, 224, 3)

In [17]:
class VideoReader:
    """Helper class for reading one or more frames from a video file."""

    def __init__(self, verbose=True, insets=(0, 0)):
        """Creates a new VideoReader.

        Arguments:
            verbose: whether to print warnings and error messages
            insets: amount to inset the image by, as a percentage of 
                (width, height). This lets you "zoom in" to an image 
                to remove unimportant content around the borders. 
                Useful for face detection, which may not work if the 
                faces are too small.
        """
        self.verbose = verbose
        self.insets = insets

    def read_frames(self, path, num_frames, jitter=0, seed=None):
        """Reads frames that are always evenly spaced throughout the video.

        Arguments:
            path: the video file
            num_frames: how many frames to read, -1 means the entire video
                (warning: this will take up a lot of memory!)
            jitter: if not 0, adds small random offsets to the frame indices;
                this is useful so we don't always land on even or odd frames
            seed: random seed for jittering; if you set this to a fixed value,
                you probably want to set it only on the first video 
        """
        assert num_frames > 0

        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0: return None

        start = np.random.randint(frame_count-num_frames)
        frame_idxs = np.linspace(start, start+num_frames, num=num_frames, dtype=np.int)
        if jitter > 0:
            np.random.seed(seed)
            jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
            frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)

        result = self._read_frames_at_indices(path, capture, frame_idxs)
        capture.release()
        return result

    def read_random_frames(self, path, num_frames, seed=None):
        """Picks the frame indices at random.
        
        Arguments:
            path: the video file
            num_frames: how many frames to read, -1 means the entire video
                (warning: this will take up a lot of memory!)
        """
        assert num_frames > 0
        np.random.seed(seed)

        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0: return None

        frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
        result = self._read_frames_at_indices(path, capture, frame_idxs)

        capture.release()
        return result

    def read_frames_at_indices(self, path, frame_idxs):
        """Reads frames from a video and puts them into a NumPy array.

        Arguments:
            path: the video file
            frame_idxs: a list of frame indices. Important: should be
                sorted from low-to-high! If an index appears multiple
                times, the frame is still read only once.

        Returns:
            - a NumPy array of shape (num_frames, height, width, 3)
            - a list of the frame indices that were read

        Reading stops if loading a frame fails, in which case the first
        dimension returned may actually be less than num_frames.

        Returns None if an exception is thrown for any reason, or if no
        frames were read.
        """
        assert len(frame_idxs) > 0
        capture = cv2.VideoCapture(path)
        result = self._read_frames_at_indices(path, capture, frame_idxs)
        capture.release()
        return result

    def _read_frames_at_indices(self, path, capture, frame_idxs):
        try:
            frames = []
            idxs_read = []
            for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
                # Get the next frame, but don't decode if we're not using it.
                ret = capture.grab()
                if not ret:
                    if self.verbose:
                        print("Error grabbing frame %d from movie %s" % (frame_idx, path))
                    break

                # Need to look at this frame?
                current = len(idxs_read)
                if frame_idx == frame_idxs[current]:
                    ret, frame = capture.retrieve()
                    if not ret or frame is None:
                        if self.verbose:
                            print("Error retrieving frame %d from movie %s" % (frame_idx, path))
                        break

                    frame = self._postprocess_frame(frame)
                    frames.append(frame)
                    idxs_read.append(frame_idx)

            if len(frames) > 0:
                return np.stack(frames), idxs_read
            if self.verbose:
                print("No frames read from movie %s" % path)
            return None
        except:
            if self.verbose:
                print("Exception while reading movie %s" % path)
            return None    

    def read_middle_frame(self, path):
        """Reads the frame from the middle of the video."""
        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        result = self._read_frame_at_index(path, capture, frame_count // 2)
        capture.release()
        return result

    def read_frame_at_index(self, path, frame_idx):
        """Reads a single frame from a video.
        
        If you just want to read a single frame from the video, this is more
        efficient than scanning through the video to find the frame. However,
        for reading multiple frames it's not efficient.
        
        My guess is that a "streaming" approach is more efficient than a 
        "random access" approach because, unless you happen to grab a keyframe, 
        the decoder still needs to read all the previous frames in order to 
        reconstruct the one you're asking for.

        Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
        or None if reading failed.
        """
        capture = cv2.VideoCapture(path)
        result = self._read_frame_at_index(path, capture, frame_idx)
        capture.release()
        return result

    def _read_frame_at_index(self, path, capture, frame_idx):
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = capture.read()    
        if not ret or frame is None:
            if self.verbose:
                print("Error retrieving frame %d from movie %s" % (frame_idx, path))
            return None
        else:
            frame = self._postprocess_frame(frame)
            return np.expand_dims(frame, axis=0), [frame_idx]
    
    def _postprocess_frame(self, frame):
#         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

#         if self.insets[0] > 0:
#             W = frame.shape[1]
#             p = int(W * self.insets[0])
#             frame = frame[:, p:-p, :]

#         if self.insets[1] > 0:
#             H = frame.shape[1]
#             q = int(H * self.insets[1])
#             frame = frame[q:-q, :, :]

        return frame

In [4]:
vids = glob.glob('../data/source/microtest/*')
vids

['../data/source/microtest/aassnaulhq.mp4',
 '../data/source/microtest/aqrsylrzgi.mp4',
 '../data/source/microtest/apvzjkvnwn.mp4',
 '../data/source/microtest/apedduehoy.mp4',
 '../data/source/microtest/aktnlyqpah.mp4',
 '../data/source/microtest/ahjnxtiamx.mp4',
 '../data/source/microtest/ayipraspbn.mp4',
 '../data/source/microtest/bcbqxhziqz.mp4',
 '../data/source/microtest/alrtntfxtd.mp4',
 '../data/source/microtest/axfhbpkdlc.mp4',
 '../data/source/microtest/bcvheslzrq.mp4',
 '../data/source/microtest/ajiyrjfyzp.mp4',
 '../data/source/microtest/adohdulfwb.mp4',
 '../data/source/microtest/acazlolrpz.mp4',
 '../data/source/microtest/aayfryxljh.mp4',
 '../data/source/microtest/aomqqjipcp.mp4']

In [18]:
vr = VideoReader()

In [22]:
%%time
for vid in vids:
    vid = vr.read_frames(vid, num_frames=30)

CPU times: user 13.4 s, sys: 3.46 s, total: 16.8 s
Wall time: 3.7 s


In [21]:
%%time
for vid in vids:
    vid = get_frames(vid)


CPU times: user 1min 3s, sys: 1.8 s, total: 1min 5s
Wall time: 22.3 s


In [16]:
num_frames = 30
frame_ct = 300
start = np.random.randint(frame_count-num_frames)
np.linspace(start, start+num_frames, num=num_frames, dtype=int)

array([110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
       123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135,
       136, 137, 138, 140])

In [24]:
vid

(array([[[[164,  44,  62],
          [166,  46,  64],
          [144,  60,  66],
          ...,
          [160, 149, 142],
          [163, 152, 145],
          [163, 152, 145]],
 
         [[163,  43,  61],
          [164,  44,  62],
          [142,  58,  64],
          ...,
          [154, 143, 136],
          [156, 145, 138],
          [156, 145, 138]],
 
         [[151,  46,  59],
          [152,  47,  60],
          [133,  58,  61],
          ...,
          [150, 139, 132],
          [147, 136, 129],
          [147, 136, 129]],
 
         ...,
 
         [[ 72,  18,  27],
          [ 70,  16,  25],
          [ 69,  17,  25],
          ...,
          [  9,   3,   8],
          [  9,   3,   8],
          [  9,   3,   8]],
 
         [[ 70,  16,  25],
          [ 69,  15,  24],
          [ 68,  16,  24],
          ...,
          [  9,   3,   8],
          [  9,   3,   8],
          [  9,   3,   8]],
 
         [[ 70,  16,  25],
          [ 69,  15,  24],
          [ 68,  16,  24],
   