In [1]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from moviepy.editor import *

from sda.encoder_image import Encoder
from sda.img_generator import Generator
from sda.rnn_audio import RNN
from sda.encoder_audio import Encoder as AEncoder

from scipy import signal
from skimage import transform as tf
import numpy as np
from PIL import Image
import contextlib
import shutil
import skvideo.io as sio
import scipy.io.wavfile as wav
import ffmpeg
import face_alignment
from pydub import AudioSegment
from pydub.utils import mediainfo

import glob

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

dev = torch.device("cuda:0")

Random Seed:  999


In [222]:
video_filenames = glob.glob('/home/jarrod/dev/speech-driven-animation/data/*/*.mpg')

# Network

# Preprocess Video Data

In [2]:
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device="cuda:" + str(0), flip_input=False)
mean_face = np.load('/home/jarrod/dev/speech-driven-animation/data/mean_face.npy')

In [277]:
def preprocess_img(img):
    stablePntsIDs = [33, 36, 39, 42, 45]
    
    src = fa.get_landmarks(img)
    if src != None:
        dst = mean_face[stablePntsIDs, :]
        tform = tf.estimate_transform('similarity', src[0][stablePntsIDs, :], dst)  # find the transformation matrix
        warped = tf.warp(img, inverse_map=tform.inverse, output_shape=(256,256))  # wrap the frame image
        warped = warped * 255  # note output from wrap is double image (value range [0,1])
        warped = warped.astype('uint8')

    else:
        warped = np.zeros((256,256,3))
        
    return warped

def cropParams(frame):
    
    src = fa.get_landmarks(frame)
    err = 0
    
    if src != None and src[0].shape[0] == 68 and src[0].shape[1] == 2:
    
        max_x = int(np.max(src[0][:,0]))
        min_x = int(np.min(src[0][:,0]))
        
        max_y = int(np.max(src[0][:,1]))
        min_y = int(np.min(src[0][:,1]))

        center_x = int(min_x + (max_x - min_x)/2)
        center_y = int(min_y + (max_y - min_y)/2)

        img_height = 256
        img_width = 256

        left_crop = int(center_x - img_width/2)
        right_crop = int(center_x + img_width/2)

        top_crop = int(center_y - img_height/2)
        bottom_crop = int(center_y + img_height/2)
    
    else:
        err = 1
        crop = 0
        return [0], err, src
    
    return [top_crop, bottom_crop, left_crop,right_crop], err, src

def zeroPad(frame, size):
    
    # height
    if frame.shape[0] < size[0]:
        
        orig_frame_shape = np.copy(frame.shape)
        
        top_diff = int(np.floor((size[0]-frame.shape[0])/2))
        top_pad = np.zeros((top_diff,frame.shape[1],3),dtype=np.uint8)
        frame = np.concatenate((top_pad,frame),axis=0)
        
        bot_diff = int(np.ceil((size[0]-orig_frame_shape[0])/2))
        bot_pad = np.zeros((bot_diff,orig_frame_shape[1],3),dtype=np.uint8)
        frame = np.concatenate((frame, bot_pad),axis=0)
    
    return frame
    
    
def alignFace(vid_data, fname):
    
    new_vid = []
    crop_params, err, src = cropParams(vid_data[0])
    
    if err == 1:
        return [],err,src
    
    for i, frame in enumerate(vid_data):
          
        crop = frame[16:272, crop_params[2]:crop_params[3], :]
        # zero pad height of image evenly on top and bottom
        
#         crop = zeroPad(crop, (256,256))
        
        if err == 0:
#             print("immediately before", crop.shape)
#             new_vid.append(preprocess_img(crop))
            new_vid.append(crop)
        else:
            print("error in ", fname, " at frame ", i)
    
    
    out = np.stack(new_vid)
    
    return out, err, src

In [292]:
# NOTE: this outputs un-normalized faces (np.uint8)

count = 0
    
for f in range(0,len(video_filenames)):
    
    fname = video_filenames[f]
    
    
    if fname.find('id') == -1:
        continue
    
    if count % 10 == 0:
        print(f, " of " + str(len(video_filenames)), ' ', end='')
    
    vid_data = sio.vread(fname)
    
    if vid_data.shape[0] != 75:
        print("video does not have 75 frames")
        continue
        
    out, err, src = alignFace(vid_data, fname)
    
    if err == 1:
        print('error with face track')
        continue
    
    np.save('./256_vid_data/' + fname.split('/')[-1].split('.mpg')[0] + ".npy", out)

    count += 1

29  of 3000  908  of 3000  1602  of 3000  2535  of 3000  

In [284]:
video_filenames

['/home/jarrod/dev/speech-driven-animation/data/s3/lgbf7s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/pgbe1s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/srbu5s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/lgiz3a.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/bwwh2p.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/lwiy5s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/bgbn8p.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/lbakzp.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/sgwx3a.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/sban3a.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/swau9a.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/bris1s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/pgbk5s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/bras5s.mpg',
 '/home/jarrod/dev/speech-driven-animation/data/s3/lgwgzn.mpg',
 '/home/jarrod/dev/speech-driven-animati

In [264]:
filenames = glob.glob('/home/jarrod/dev/vq-vae-2-pytorch/vid_test_data/*vid*.npy')

# Debug and Clean

In [21]:
# Check How Many Videos Have Less Than 75 Frames

count = 0

# for f in audio_filenames:
audio_filenames = glob.glob('/home/jarrod/dev/speech-driven-animation/data/aligned_faces/*.npy') 
   
bad_list = []
    
for i, f in enumerate(audio_filenames):
    
    
    vid_data = np.load(f)
     
    if i % 100 == 0:
        print(i)


    if vid_data.shape[0] < 75:
        print("file: ", f)
        count += 1
        print(count)
    
    

0
100
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/brwa4p.npy
1
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/pwbd6s.npy
2
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/sbaa4p.npy
3
200
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/brwg8p.npy
4
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/bgbn9a.npy
5
300
400
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/pbwxzs.npy
6
500
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/bwwuzn.npy
7
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/bramzn.npy
8
600
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/bgit2n.npy
9
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/srwi5a.npy
10
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/lrarzn.npy
11
700
800
900
1000
file:  /home/jarrod/dev/speech-driven-animation/data/aligned_faces/swao7a.n

In [6]:
vid_data = sio.vread('./data/s1/bbizzn.mpg')

In [7]:
vid_data.shape

(75, 288, 360, 3)

In [8]:
test = np.load('./data/aligned_faces/bbizzn.npy')

In [9]:
test.shape

(63, 128, 96, 3)

In [15]:
out, err = alignFace(vid_data, './data/s1/bbizzn.mpg')

error in  ./data/s1/bbizzn.mpg  at frame  0
error in  ./data/s1/bbizzn.mpg  at frame  1
error in  ./data/s1/bbizzn.mpg  at frame  2
error in  ./data/s1/bbizzn.mpg  at frame  3
error in  ./data/s1/bbizzn.mpg  at frame  4
error in  ./data/s1/bbizzn.mpg  at frame  5
error in  ./data/s1/bbizzn.mpg  at frame  6
error in  ./data/s1/bbizzn.mpg  at frame  7
error in  ./data/s1/bbizzn.mpg  at frame  8
error in  ./data/s1/bbizzn.mpg  at frame  9
error in  ./data/s1/bbizzn.mpg  at frame  10
error in  ./data/s1/bbizzn.mpg  at frame  11
