Tests the train.py module.

In [1]:
import os
import time 

import cv2
import torch
import torchvision
import numpy as np
from torch.utils.data.dataloader import DataLoader

import cheapfake.contrib.dataset as dataset
import cheapfake.contrib.models_contrib as models 
import cheapfake.contrib.transforms as transforms

#### Testing AugmentedFAN

In [2]:
# Configuration for this smoke test: fixed seed, metadata CSV location and
# compute device (first GPU when available, CPU otherwise).
random_seed = 41
metadata_path = "/home/shu/cheapfake/cheapfake/contrib/wide_balanced_metadata_fs03.csv"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Two-sample dataset served as a single shuffled batch of size two.
dfdataset = dataset.DeepFakeDataset(
    metadata_path=metadata_path,
    frame_transform=transforms.BatchRescale(4),
    sequential_audio=True,
    sequential_frames=True,
    random_seed=random_seed,
    num_samples=2,
)
dfdataloader = DataLoader(dfdataset, batch_size=2, shuffle=True)

# Only the values unpacked from the last batch survive this loop; they are
# what the rest of the notebook operates on.
for batch_index, batch in enumerate(dfdataloader):
    frames, _, audio_stft, label = batch
    # Keep at most the first 75 entries along dim 1
    # (presumably the frame axis -- TODO confirm).
    frames = frames[:, :75]

face_model = models.AugmentedFAN(device=device)
frames_model = models.AugmentedLipNet(device=device)

# Time one forward pass of the face model on the selected frames.
frames = frames.float().to(device)
start_time = time.time()
landmarks, fan_embedding = face_model(frames)
end_time = time.time()
print(f"Entire operation took {end_time - start_time} seconds")

print(fan_embedding.shape, landmarks.shape)

KeyboardInterrupt: 

In [None]:
def _find_bounding_box(landmarks, tol=(2, 2, 2, 2)):
    """Finds the minimum bounding box containing the points, with tolerance in the left, right, top, and bottom directions (in pixels).

    Parameters
    ----------
    landmarks : numpy.ndarray or torch.Tensor instance
        Numpy array or Torch tensor containing the predicted xy-coordinates of the detected facial landmarks.
    tol : tuple, optional
        The tolerance (in pixels) in each direction (left, top, right, bottom) by default (2, 2, 2, 2). 
    
    Returns
    -------
    bbox : tuple (of ints)
        Tuple (min_x, min_y, max_x, max_y) containing the coordinates of the bounding box, with tolerance in the left, right, top and bottom directions. 

    """
    assert isinstance(tol, tuple)
    assert len(tol) == 4, "Need four values for the tolerance."

    x_coords, y_coords = zip(*landmarks)
    bbox = (
        min(x_coords) - tol[0],
        min(y_coords) - tol[1],
        max(x_coords) + tol[2],
        max(y_coords) + tol[3],
    )
    bbox = tuple([int(item) for item in bbox])

    return bbox


def _find_bounding_boxes(landmarks, tol=(2, 2, 2, 2)):
    """Computes one padded bounding box per entry of a batch of landmarks.

    Only landmark indices 48 through 67 of each entry are used (presumably
    the mouth points of a 68-point facial annotation -- TODO confirm).

    Parameters
    ----------
    landmarks : numpy.ndarray or torch.Tensor instance
        Numpy array or Torch tensor of xy-coordinates, indexed as
        (entry, landmark, coordinate).
    tol : tuple, optional
        Padding (in pixels) applied as (left, top, right, bottom), by default
        (2, 2, 2, 2).

    Returns
    -------
    bboxes : list (of tuples)
        One (min_x, min_y, max_x, max_y) tuple per entry, in input order.

    """
    selected_points = landmarks[:, 48:68]
    return [_find_bounding_box(points, tol) for points in selected_points]


def _crop_lips(frames, landmarks, tol=(2, 2, 2, 2), channels_first=True):
    """Crops the lip area from a batch of frames and resizes it to 64 x 128.

    Parameters
    ----------
    frames : torch.Tensor instance
        Four-dimensional Torch tensor containing the frames to crop the lip
        areas from.
    landmarks : numpy.ndarray or torch.Tensor instance
        Numpy array or Torch tensor containing the xy-coordinates of the
        detected facial landmarks, one set of landmarks per frame.
    tol : tuple, optional
        The tolerance (in pixels) in each direction (left, top, right,
        bottom), by default (2, 2, 2, 2).
    channels_first : bool, optional
        If True the input is assumed to have shape (sample, channel, height,
        width), by default True. Otherwise the input is assumed to have shape
        (sample, height, width, channel).

    Returns
    -------
    extracted_lips : torch.Tensor instance
        Tensor of shape (sample, 64, 128, channel) holding the resized lip
        crops. The output is always channels-last, regardless of
        ``channels_first``.

    """
    assert isinstance(frames, torch.Tensor)
    assert isinstance(landmarks, (torch.Tensor, np.ndarray))
    assert isinstance(tol, tuple)
    assert isinstance(channels_first, bool)

    # Work in (sample, height, width, channel) layout from here on.
    if channels_first:
        frames = frames.permute(0, 2, 3, 1)

    bboxes = _find_bounding_boxes(landmarks, tol=tol)

    num_frames, frame_height, frame_width, num_channels = frames.shape
    out_height, out_width = 64, 128

    # The channel count follows the input instead of being fixed at 3.
    extracted_lips = torch.empty(num_frames, out_height, out_width, num_channels)
    for idx, (bbox, frame) in enumerate(zip(bboxes, frames)):
        # Clamp the box to the frame. Without this, a landmark close to the
        # left/top edge minus the tolerance gives a negative slice start,
        # which Python interprets as an offset from the end and therefore
        # selects an empty or unrelated region.
        x_min = min(max(bbox[0], 0), frame_width)
        y_min = min(max(bbox[1], 0), frame_height)
        x_max = min(max(bbox[2], 0), frame_width)
        y_max = min(max(bbox[3], 0), frame_height)

        lip_region = frame[y_min:y_max, x_min:x_max, :]
        # The permuted view is not contiguous in memory; hand OpenCV a
        # contiguous array.
        lip_array = np.ascontiguousarray(lip_region.cpu().numpy())
        # cv2 takes dsize as (width, height).
        resized = cv2.resize(
            lip_array, dsize=(out_width, out_height), interpolation=cv2.INTER_CUBIC
        )
        # OpenCV returns a 2-D array for single-channel input, so restore the
        # channel axis explicitly before storing the crop.
        extracted_lips[idx] = torch.from_numpy(resized).reshape(
            out_height, out_width, num_channels
        )

    return extracted_lips

def _crop_lips_batch(batch_frames, batch_landmarks, tol=(2, 2, 2, 2), channels_first=True):
    """Crops the lip area from every frame of a batch of frame sequences.

    Each element along the first axis of ``batch_frames`` is passed, together
    with the matching element of ``batch_landmarks``, to ``_crop_lips``.

    Parameters
    ----------
    batch_frames : torch.Tensor instance
        Five-dimensional Torch tensor of frames with shape (batch, sample,
        channel, height, width) if ``channels_first`` is True, otherwise
        (batch, sample, height, width, channel).
    batch_landmarks : numpy.ndarray or torch.Tensor instance
        Numpy array or Torch tensor containing the xy-coordinates of the
        detected facial landmarks for every frame of every sequence.
    tol : tuple, optional
        The tolerance (in pixels) in each direction (left, top, right,
        bottom), by default (2, 2, 2, 2).
    channels_first : bool, optional
        Layout of ``batch_frames`` as described above, by default True.

    Returns
    -------
    batch_extracted_lips : torch.Tensor instance
        Tensor of shape (batch, sample, 64, 128, channel) containing the
        resized lip crops in channels-last layout.

    """
    assert isinstance(batch_frames, torch.Tensor)
    assert isinstance(batch_landmarks, (torch.Tensor, np.ndarray))
    assert isinstance(tol, tuple)
    assert isinstance(channels_first, bool)

    # The channel axis depends on the layout: axis 2 for channels-first
    # input, the last axis otherwise. Reading axis 2 unconditionally would
    # pick up the frame height for channels-last input and mis-size the
    # output buffer.
    channel_axis = 2 if channels_first else -1
    output_shape = (
        batch_frames.shape[0],
        batch_frames.shape[1],
        64,
        128,
        batch_frames.shape[channel_axis],
    )

    batch_extracted_lips = torch.empty(output_shape)
    for idx, (frames, landmarks) in enumerate(zip(batch_frames, batch_landmarks)):
        batch_extracted_lips[idx] = _crop_lips(
            frames, landmarks, tol=tol, channels_first=channels_first
        )

    return batch_extracted_lips

In [None]:
# Time lip cropping plus the forward pass of the lip model.
start_time = time.time()

# _crop_lips_batch returns the channel axis last; move it to position 1
# before handing the crops to the model.
extracted_lips = _crop_lips_batch(frames, landmarks)
extracted_lips = extracted_lips.permute(0, -1, 1, 2, 3)
extracted_lips = extracted_lips.float().to(device)
print(f"extracted_lips has shape {extracted_lips.shape}")

lip_embedding = frames_model(extracted_lips)
end_time = time.time()

print(f"Entire process took {end_time - start_time} seconds")


In [None]:
# Show the shape of the embedding returned by frames_model in the previous cell.
print(lip_embedding.shape)

In [None]:
# Join the face-model and lip-model embeddings along dimension 1.
concat_features = torch.cat([fan_embedding, lip_embedding], dim=1)
print(concat_features.shape)

In [None]:
import cheapfake.contrib.ResNetSE34L as resnet_models

audio_model = resnet_models.ResNetSE34L().to(device)

# Flatten everything after the batch axis, then move the input to the same
# device as the model: audio_stft comes straight from the DataLoader and was
# never moved, which would cause a device mismatch when running on CUDA.
# NOTE(review): no dtype conversion is applied here -- confirm the dtype of
# audio_stft matches what ResNetSE34L expects.
audio_input = audio_stft.view(audio_stft.shape[0], -1).to(device)
audio_embedding = audio_model(audio_input)

print(audio_embedding.shape)