In [1]:
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

# Load the pretrained TF1 checkpoint and pull every trainable variable out as a
# NumPy array so that later cells can rebuild the network without TensorFlow.
sess = tf.compat.v1.InteractiveSession()
path = "synthesizing_obama_network_training/obama/obama/"
# Derive the .meta location from `path` so the directory is spelled out once.
saver = tf.compat.v1.train.import_meta_graph(path + "model.ckpt-300.meta")
ckpt = tf.train.get_checkpoint_state(path)
# get_checkpoint_state returns None when no checkpoint state file is found;
# fail with a clear message instead of an AttributeError on None below.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError("no TensorFlow checkpoint state found in " + path)
saver.restore(sess, ckpt.model_checkpoint_path)

variables = tf.compat.v1.trainable_variables()
# A single session call evaluates all variables; the resulting list is in the
# same order as `variables` (see the printed variable list for the indices).
weights = sess.run(variables)
print(variables)

INFO:tensorflow:Restoring parameters from synthesizing_obama_network_training/obama/obama/model.ckpt-300
[<tf.Variable 'rnnlm/output_w:0' shape=(60, 20) dtype=float32_ref>, <tf.Variable 'rnnlm/output_b:0' shape=(20,) dtype=float32_ref>, <tf.Variable 'rnnlm/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(88, 240) dtype=float32_ref>, <tf.Variable 'rnnlm/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(240,) dtype=float32_ref>]


In [2]:
import _pickle
# Read the pickled training bundle; 'latin1' is the encoding passed to the
# unpickler for this file.
with open("./synthesizing_obama_network_training/data/training_obama.cpkl", "rb") as f:
    data = _pickle.load(f, encoding='latin1')

# Mean / standard deviation of the inputs and of the outputs, as stored in the
# bundle; they are used to normalise data fed to the network.
meani, stdi, meano, stdo = (
    data[key] for key in ("inputmean", "inputstd", "outputmean", "outputstd")
)

In [10]:
import torch
from torch import nn
import numpy as np
import os
class FacePredict(nn.Module):
    """Single-layer LSTM (28 -> 60) followed by a linear layer (60 -> 20),
    initialised from the variables of a pretrained TensorFlow model.

    Args:
        tf_weights: optional sequence of four NumPy arrays ordered as
            ``[output_w (60, 20), output_b (20,), lstm_kernel (88, 240),
            lstm_bias (240,)]``. When omitted, the notebook-level ``weights``
            list extracted from the checkpoint is used.
    """

    def __init__(self, tf_weights=None):
        super().__init__()
        if tf_weights is None:
            # Fall back to the list built by the checkpoint-loading cell.
            tf_weights = weights
        out_w, out_b, kernel, bias = tf_weights[:4]

        self.lstm = nn.LSTM(28, 60)
        #self.dropout = nn.Dropout(p=0.5)
        self.dense = nn.Linear(60, 20)

        def tf_to_torch_gates(arr):
            # The TF kernel/bias are laid out along the last axis as
            # [input, cell, forget, output]; PyTorch expects
            # [input, forget, cell, output].
            gate_i, gate_c, gate_f, gate_o = np.split(arr, 4, axis=-1)
            return np.concatenate((gate_i, gate_f, gate_c, gate_o), axis=-1)

        with torch.no_grad():
            # The first 28 kernel rows multiply the input, the remaining rows
            # multiply the previous hidden state. TF stores (in, 4*hidden)
            # while PyTorch stores (4*hidden, in), hence the transpose.
            wih = tf_to_torch_gates(kernel[:28, :])
            whh = tf_to_torch_gates(kernel[28:, :])
            # copy_ keeps the parameters contiguous and checks the shapes.
            self.lstm.weight_ih_l0.copy_(torch.from_numpy(wih).t())
            self.lstm.weight_hh_l0.copy_(torch.from_numpy(whh).t())

            # The bias shares the kernel's gate layout, so it needs the same
            # reordering; otherwise the forget and cell gates swap biases.
            # TF has one bias vector and PyTorch adds bias_ih + bias_hh, so
            # the whole bias goes into bias_hh and bias_ih is zeroed.
            # NOTE(review): TF LSTM cells may add a `forget_bias` constant at
            # run time that is not stored in the checkpoint -- confirm against
            # the training code whether it must be folded in here.
            self.lstm.bias_hh_l0.copy_(torch.from_numpy(tf_to_torch_gates(bias)))
            self.lstm.bias_ih_l0.zero_()

            # TF dense weights are (in, out); nn.Linear stores (out, in).
            self.dense.weight.copy_(torch.from_numpy(out_w).t())
            self.dense.bias.copy_(torch.from_numpy(out_b))

    def forward(self, inputs):
        """Run the LSTM over ``inputs`` and project each hidden state with
        the linear layer; returns the projected sequence."""
        hid0, _ = self.lstm(inputs)
        #hiddrop = self.dropout(hid0)
        return self.dense(hid0)


In [4]:
def get_audio_derivatives(audio):
    """Append frame-to-frame differences to the audio features.

    The last column of ``audio`` is treated as the timestamp column; all
    other columns are features. Returns ``(features, times)`` where
    ``features`` holds, for every frame except the last, the frame's features
    followed by the difference to the next frame, and ``times`` is the full
    timestamp column (one entry more than ``features`` has rows).
    """
    feature_cols = audio[:, :-1]
    timestamps = audio[:, -1]
    # Forward difference along the time axis: row k is frame k+1 minus frame k.
    deltas = np.diff(feature_cols, axis=0)
    stacked = np.concatenate((feature_cols[:-1], deltas), axis=1)
    return stacked, timestamps

def shifted_time(i, times):
    """Return the timestamp 20 positions before index ``i``, or the first
    timestamp when ``i`` is smaller than 20."""
    shifted_index = i - 20 if i >= 20 else 0
    return times[shifted_index]

In [5]:
# Sanity check: run the converted weights on a preprocessed Obama test clip.
test_audio = np.load("./synthesizing_obama_network_training/obama_data/audio/normalized-cep13/test_audio_preprocessed.wav.npy")
inputs, times = get_audio_derivatives(test_audio)
# Standardise with the training-set statistics, then insert a singleton axis
# at position 1 before handing the sequence to the model.
inputs = torch.Tensor((inputs - meani) / stdi).unsqueeze(1)

fp = FacePredict()
outputs = fp(inputs)
outputs

In [14]:
class FacePredictFineTune(FacePredict):
    """FacePredict variant for fine-tuning: the pretrained LSTM is kept, the
    dense output layer is re-initialised, and batch normalisation is applied
    to the 28 input features before the LSTM."""

    def __init__(self):
        super().__init__()
        # Re-initialise the output layer. nn.Linear exposes its matrix as
        # `.weight`; `.weights` does not exist and raised AttributeError.
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.zeros_(self.dense.bias)

        # Batch normalisation on inputs. The PyTorch class is `BatchNorm1d`
        # (lower-case d); `nn.BatchNorm1D` is not defined.
        self.bn = nn.BatchNorm1d(28)

    def forward(self, inputs):
        """Normalise the inputs, run the LSTM, and project with the dense layer.

        ``inputs`` has shape T*B*D (time * batch * num_feat).
        """
        # BatchNorm1d normalises over dim 1, so move the feature axis there
        # (T*D*B) and swap it back afterwards.
        inputs_norm = self.bn(inputs.transpose(1,2)).transpose(1,2)
        hid0, _ = self.lstm(inputs_norm)
        #hiddrop = self.dropout(hid0)
        return self.dense(hid0)

In [15]:
from torch.utils.data import Dataset
import os
class FacePredictDataset(Dataset):
    """Dataset over precomputed (post-MFCC) audio arrays and landmark arrays.

    Each item is looked up by video name: the audio array is read from
    ``<audio_dir>/<name>.wav.npy`` and the landmarks from
    ``<landmarks_dir>/<name>.npy``.
    """

    def __init__(self, video_names, audio_dir, landmarks_dir):
        self.video_names, self.audio_dir, self.landmarks_dir = (
            video_names, audio_dir, landmarks_dir
        )

    def __len__(self):
        """Number of videos in the dataset."""
        return len(self.video_names)

    def __getitem__(self, idx):
        """Load one video's audio features and landmarks as a dict."""
        name = self.video_names[idx]

        raw_audio = np.load(os.path.join(self.audio_dir, name + '.wav.npy'))
        # NOTE(review): get_audio_derivatives returns a (features, times)
        # pair, so 'audio' holds that whole pair -- confirm this is what the
        # consumers of this dataset expect.
        audio_features = get_audio_derivatives(raw_audio)

        landmarks = np.load(os.path.join(self.landmarks_dir, name + '.npy'))
        return dict(audio=audio_features, landmarks=landmarks)