In [1]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from sda.encoder_image import Encoder
from sda.img_generator import Generator
from sda.rnn_audio import RNN
from sda.encoder_audio import Encoder as AEncoder

from scipy import signal
from skimage import transform as tf
import numpy as np
from PIL import Image
import contextlib
import shutil
import skvideo.io as sio
import scipy.io.wavfile as wav
import ffmpeg
import face_alignment
from pydub import AudioSegment
from pydub.utils import mediainfo

import glob

# Seed every random number generator this notebook has imported so that a
# fresh "Restart & Run All" reproduces the same results.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)  # NumPy keeps its own global RNG state
torch.manual_seed(manualSeed)

# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook can still be executed on a machine without CUDA.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Random Seed:  999


# Network

In [2]:
# Location of the pretrained checkpoint. The default is machine-specific, so
# the SDA_MODEL_PATH environment variable may be used to point elsewhere.
model_path = os.environ.get(
    "SDA_MODEL_PATH",
    "/home/jarrod/dev/speech-driven-animation/sda/data/grid.dat")

# Load onto GPU 0 when CUDA is available, otherwise onto the CPU.
device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu")
# NOTE(security): torch.load unpickles the file and can therefore execute
# arbitrary code; only load checkpoints from a trusted source.
model_dict = torch.load(model_path, map_location=device)
# The landmark detector is run on the CPU.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device="cpu", flip_input=False)

# Landmark indices used to estimate the alignment transform in preprocess_img.
stablePntsIDs = [33, 36, 39, 42, 45]

# Hyper-parameters and reference data stored alongside the weights.
mean_face = model_dict["mean_face"]
img_size = model_dict["img_size"]
audio_rate = model_dict["audio_rate"]
video_rate = model_dict["video_rate"]
audio_feat_len = model_dict['audio_feat_len']
audio_feat_samples = model_dict['audio_feat_samples']
id_enc_dim = model_dict['id_enc_dim']
rnn_gen_dim = model_dict['rnn_gen_dim']
aud_enc_dim = model_dict['aud_enc_dim']
# presumably the size of the noise vector -- TODO confirm
aux_latent = model_dict['aux_latent']
# reportedly a boolean flag -- TODO confirm
sequential_noise = model_dict['sequential_noise']
# Maps the 'sample_fmt' string reported by mediainfo to the NumPy dtype that
# genSample casts decoded audio samples to.
conversion_dict = {'s16': np.int16, 's32': np.int32}

# image preprocessing: resize to the model's image size and scale each of the
# three channels with mean 0.5 and std 0.5
img_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size[0], img_size[1])),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


def preprocess_img(img):
    """Warp a face image so that its stable landmarks line up with the mean face.

    Landmarks are detected with the module-level ``fa`` detector, a similarity
    transform is estimated from the points listed in ``stablePntsIDs`` to the
    matching points of ``mean_face``, and the image is warped to ``img_size``.

    Args:
        img: image array accepted by ``fa.get_landmarks`` and ``tf.warp``.

    Returns:
        The warped image as a ``uint8`` array.

    Raises:
        ValueError: if the landmark detector finds no face in ``img``.
    """
    landmarks = fa.get_landmarks(img)
    # The detector may return None (or nothing useful) when no face is found;
    # fail with a clear message instead of an opaque subscript error.
    if landmarks is None or len(landmarks) == 0:
        raise ValueError("No face detected in the input image; cannot align it to the mean face.")

    # Only the first detected face is used.
    src = landmarks[0][stablePntsIDs, :]
    dst = mean_face[stablePntsIDs, :]
    tform = tf.estimate_transform('similarity', src, dst)  # find the transformation matrix
    warped = tf.warp(img, inverse_map=tform.inverse, output_shape=img_size)  # warp the frame image
    # warp returns a floating-point image in [0, 1]; rescale to 8-bit values
    warped = (warped * 255).astype('uint8')

    return warped

def _cut_sequence_(seq, cutting_stride, pad_samples):
    """Zero-pad a sequence and slice it into overlapping fixed-length windows.

    ``pad_samples`` rows of zeros are split between the start (``pad_samples // 2``)
    and the end (the remainder) of ``seq``. Windows of ``audio_feat_samples`` rows
    (a module-level value) are then taken every ``cutting_stride`` rows.

    Args:
        seq: 2-D tensor with a single column (it is concatenated with
            ``(n, 1)`` zero tensors along dim 0).
        cutting_stride: number of rows between the starts of consecutive windows.
        pad_samples: total number of zero rows to add.

    Returns:
        Tensor of shape ``(num_windows, audio_feat_samples, 1)``.
    """
    pad_left = torch.zeros(pad_samples // 2, 1)
    pad_right = torch.zeros(pad_samples - pad_samples // 2, 1)
    seq = torch.cat((pad_left, seq, pad_right), 0)

    iterations = (seq.size()[0] - audio_feat_samples) // cutting_stride + 1
    # Collect the windows and stack them once; concatenating inside the loop
    # re-copied the growing result on every iteration. The first window is
    # always requested so that a too-short input still fails inside narrow().
    windows = [seq.narrow(0, i * cutting_stride, audio_feat_samples)
               for i in range(max(1, iterations))]
    return torch.stack(windows)

def genSample(img, audio, fs=None, aligned=False):
    """Turn a still image and an audio clip into the tensors fed to the model.

    Args:
        img: path to an image file, or an already-loaded image array.
        audio: path to an audio file, or a NumPy array of integer samples.
        fs: sample rate of ``audio`` when an array is given. It is replaced by
            the rate read from the file metadata when ``audio`` is a path.
        aligned: when False the image is first passed through ``preprocess_img``.

    Returns:
        A tuple ``(speech, audio_feat_seq, audio_feat_seq_length, frame)``:
        ``speech`` is the float waveform as a ``(num_samples, 1)`` tensor at
        ``audio_rate``, divided by the maximum of its integer dtype;
        ``audio_feat_seq`` is the output of ``_cut_sequence_`` with a leading
        dimension of 1 added; ``audio_feat_seq_length`` is the size of dim 1 of
        ``audio_feat_seq``; ``frame`` is ``img_transform(frame)`` with a leading
        dimension of 1 added.

    Raises:
        AttributeError: if no sample rate is known for the audio.
    """
    if isinstance(img, str):  # if we have a path then grab the image
        # The context manager closes the underlying file once the pixels have
        # been copied into the array.
        with Image.open(img) as frm:
            frm.thumbnail((400, 400))
            frame = np.array(frm)
    else:
        frame = img

    # handle aligning the face with the model's "mean face"
    if not aligned:
        frame = preprocess_img(frame)

    # if we have a path then grab the audio clip
    if isinstance(audio, str):
        info = mediainfo(audio)
        fs = int(info['sample_rate'])
        samples = AudioSegment.from_file(audio, info['format_name']).set_channels(1).get_array_of_samples()
        audio = np.array(samples)

        if info['sample_fmt'] in conversion_dict:
            audio = audio.astype(conversion_dict[info['sample_fmt']])
        else:
            # Unknown sample format: keep 16 bits only if every sample fits.
            # Both bounds are checked so large negative samples do not wrap
            # around, and the vectorised reductions avoid a Python-level scan.
            int16_range = np.iinfo(np.int16)
            if audio.max() > int16_range.max or audio.min() < int16_range.min:
                audio = audio.astype(np.int32)
            else:
                audio = audio.astype(np.int16)

    if fs is None:
        raise AttributeError("Audio provided without specifying the rate. Specify rate or use audio file!")

    # keep only the first channel of multi-channel input
    if audio.ndim > 1 and audio.shape[1] > 1:
        audio = audio[:, 0]

    # np.iinfo requires an integer dtype
    max_value = np.iinfo(audio.dtype).max

    if fs != audio_rate:
        # resample to the rate the model expects before scaling
        target_length = int(audio.shape[0] * audio_rate / float(fs))
        scaled = signal.resample(audio, target_length) / float(max_value)
    else:
        scaled = audio / float(max_value)
    speech = torch.from_numpy(scaled).float().view(-1, 1)

    # take the input image and preprocess it
    frame = img_transform(frame)

    # one audio window per video frame
    cutting_stride = int(audio_rate / float(video_rate))
    audio_seq_padding = audio_feat_samples - cutting_stride

    # Create new sequences of the audio windows
    audio_feat_seq = _cut_sequence_(speech, cutting_stride, audio_seq_padding)
    frame = frame.unsqueeze(0)
    audio_feat_seq = audio_feat_seq.unsqueeze(0)
    audio_feat_seq_length = audio_feat_seq.size()[1]

    return speech, audio_feat_seq, audio_feat_seq_length, frame
   

In [11]:
# Build the model inputs from the example image and audio clip. aligned=True
# makes genSample skip the preprocess_img alignment step. The unpacked values
# are, in order: the waveform tensor, the windowed audio features, the number
# of windows, and the transformed image frame.
a, a1, a2, frame = genSample("example/male_face2.jpg", "example/hello_world.wav", aligned=True)

# Debug

In [12]:
class RNN(nn.Module):
    """Audio encoder followed by a GRU, instrumented with ``requires_grad`` prints.

    NOTE: defining this class rebinds the name ``RNN`` that was imported from
    ``sda.rnn_audio`` at the top of the notebook.
    """

    def __init__(self, feat_length, enc_code_size, rnn_code_size, rate, n_layers=2, init_kernel=None,
                 init_stride=None):
        """Build the per-window encoder and the recurrent layer.

        ``feat_length`` multiplied by ``rate`` gives the number of samples in
        one audio window; the remaining arguments are forwarded to ``AEncoder``
        and ``nn.GRU``.
        """
        super().__init__()
        self.audio_feat_samples = int(rate * feat_length)
        self.enc_code_size = enc_code_size
        self.rnn_code_size = rnn_code_size
        self.encoder = AEncoder(
            enc_code_size, rate, feat_length,
            init_kernel=init_kernel, init_stride=init_stride,
        )
        self.rnn = nn.GRU(enc_code_size, rnn_code_size, n_layers, batch_first=True)

    def forward(self, x, lengths):
        """Encode every audio window, run the GRU, and return the padded output.

        The ``requires_grad`` flag is printed after each stage.
        """
        n_steps = x.size(1)
        print("encoder11 ", x.requires_grad)

        # fold the batch and step dimensions together: one row per window
        windows = x.reshape(-1, 1, self.audio_feat_samples)
        print("encoder22 ", windows.requires_grad)

        # encode each window, then restore the (batch, steps, code) layout
        codes = self.encoder(windows).view(-1, n_steps, self.enc_code_size)
        print("encoder33 ", codes.requires_grad)

        packed = nn.utils.rnn.pack_padded_sequence(codes, lengths, batch_first=True)
        rnn_out, _ = self.rnn(packed)
        padded, _ = nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
        print("encoder66 ", padded.requires_grad)
        return padded.contiguous()

In [13]:
# Hand-set hyper-parameters for the debug audio encoder.
# NOTE: these assignments rebind names that were read from model_dict earlier
# in the notebook.
audio_rate = 50000
audio_feat_len = 0.2
aud_enc_dim = 256
rnn_gen_dim = 256
id_enc_dim = 128
img_size = (128, 96)
aux_latent = 10  # size of noise vector
sequential_noise = True

# audio encoder
encoder = RNN(
    audio_feat_len,
    aud_enc_dim,
    rnn_gen_dim,
    audio_rate,
    init_kernel=0.005,
    init_stride=0.001,
)

In [14]:


# Dummy all-ones input, created directly as a tensor that tracks gradients
# (no .to(dev) transfer is applied).
X = torch.ones(1, 71, 10000, 1, requires_grad=True)
# Echo the flag as the cell's output.
X.requires_grad

True

In [16]:
# Move the audio feature windows to the GPU and enable gradient tracking on
# the result. Requires a1 from the genSample cell to exist in the kernel.
a1 = a1.cuda().requires_grad_(True)

NameError: name 'a1' is not defined

In [15]:
# Sequence length handed to the encoder, flagged to track gradients.
# NOTE(review): X is created with 71 steps on dim 1 while the length passed
# here is 10 -- confirm that this truncation is intended.
a3 = torch.Tensor([10]).requires_grad_(True)
z = encoder(X, a3)

encoder11  True
encoder22  True
encoder33  True
encoder66  True
