In [1]:
import torch
from torch import nn
import numpy as np
import os
class FacePredict(nn.Module):
    def __init__(self):
        """
        Initialize using a pretrained tf model
        """
        super().__init__()
        self.lstm = nn.LSTM(28, 60)
        #self.dropout = nn.Dropout(p=0.5)
        self.dense = nn.Linear(60, 20)
        
    def forward(self, inputs):
        hid0, _ = self.lstm(inputs)
        #hiddrop = self.dropout(hid0)
        return self.dense(hid0)
    
    def load_weights_tf(self):
        #get the weights from tf model
        with torch.no_grad():
            #reorder weights to convert from tf to torch
            wii, wic, wif, wio = np.split(weights[2][:28, :], 4, 1)
            whi, whc, whf, who = np.split(weights[2][28:, :], 4, 1)
            wih = np.concatenate((wii, wif, wic, wio), axis = 1)
            whh = np.concatenate((whi, whf, whc, who), axis = 1)

            self.lstm.weight_ih_l0.data = torch.from_numpy(wih).transpose(0,1)
            self.lstm.weight_hh_l0.data = torch.from_numpy(whh).transpose(0,1)
            self.lstm.bias_hh_l0.data = torch.from_numpy(weights[3])
            self.lstm.bias_ih_l0.data = torch.zeros((240))

            self.dense.weight.data = torch.from_numpy(weights[0].T)
            self.dense.bias.data = torch.from_numpy(weights[1])


In [2]:
def get_audio_derivatives(audio):
    #calculate audio derivatives, return timestamps too
    audiodiff = audio[1:,:-1] - audio[:-1, :-1]
    times = audio[:, -1]
    return np.concatenate((audio[:-1, :-1], audiodiff[:, :]), axis=1), times

def shifted_time(i, times):
      if i >= 20:
        return times[i - 20]
      else:
        return times[0]

In [3]:
class FacePredictFineTune(FacePredict):
    def __init__(self):
        super().__init__()
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.zeros_(self.dense.bias)
        
        self.bn = nn.BatchNorm1d(28) #batch normalization on inputs
    def forward(self, inputs):
        #shape T*B*D (time*batch*num_feat)
        inputs_norm = self.bn(inputs.transpose(1,2)).transpose(1,2)
        hid0, _ = self.lstm(inputs_norm)
        #hiddrop = self.dropout(hid0)
        return self.dense(hid0)

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
model = FacePredictFineTune()
model.load_state_dict(torch.load('/content/drive/MyDrive/6869/face predict'))
model.double()
#torch.save(fpf.state_dict(), 'face predict')

FacePredictFineTune(
  (lstm): LSTM(28, 60)
  (dense): Linear(in_features=60, out_features=20, bias=True)
  (bn): BatchNorm1d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [6]:
audio_preprocessed = np.load('/content/drive/MyDrive/6869/xAAmF3H0-ek_audio.npy')
audio_data = get_audio_derivatives(audio_preprocessed)[0]
print(audio_data.shape)

(57190, 28)


In [7]:
video_data = np.load('/content/drive/MyDrive/6869/xAAmF3H0-ek_landmarks_frontalized.npy').reshape(-1, 25, 2)
video_lip_fiducials = video_data[:, 5:].reshape(-1, 40)
video_lip_fiducials.shape

(16947, 40)

In [8]:
#crop and shift of video, in seconds
video_start = 12*100//30
video_end = 16958*100//30 #inclusive
video_shft = 200
video_start

40

In [9]:
#preprocess video using pca
from sklearn.decomposition import PCA
pca = PCA(n_components = 20)
lip_features = pca.fit_transform(video_lip_fiducials)

#upsampling
from scipy.interpolate import interp1d
video_times = np.arange(12, 16959)/30
lips_interpolate = interp1d(video_times, lip_features, axis = 0)
audio_times = np.arange(video_start, video_end)/100
lips_upsampled = lips_interpolate(audio_times)

In [10]:
from torch.utils.data import Dataset, DataLoader
import os
class FacePredictDataset(Dataset):
    def __init__(self, inputs, outputs, predict_delay, output_begin, num_cuts = 18):
        #temporally inputs[output_begin] matches with outputs[0]
        #in rnn match inputs[output_begin + predict_delay] with outputs[0] 

        #crop outputs
        output_length = len(outputs)
        crop_len = output_length // num_cuts
        self.outputs = [outputs[crop_len*n:crop_len*(n+1)] for n in range(num_cuts)]

        #find matching parts of inputs
        self.inputs = [inputs[crop_len*n + output_begin: crop_len*(n+1) + output_begin + predict_delay] for n in range(num_cuts)]
        self.len = num_cuts
        self.crop_len = crop_len

    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        return [self.inputs[idx], self.outputs[idx], self.crop_len * idx]

In [11]:
data = FacePredictDataset(audio_data, lips_upsampled, video_shft, video_start, 18)
train_data, test_data = torch.utils.data.random_split(data, [16, 2])
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=2, shuffle=True)

In [12]:
from torch import optim
def delay_loss(preds, y, loss, delay):
    return loss(preds[:, delay:, :], y)

loss = nn.MSELoss()
optim = optim.Adam(model.parameters())

In [13]:
from tqdm import tqdm
for epoch in tqdm(range(260)):
    for X, y, _ in train_dataloader:
        preds = model(X.double())
        l = delay_loss(preds, y, loss, video_shft)
        l.backward()
        optim.step()
        optim.zero_grad()
    if epoch % 20 == 19:
        print('epoch: ', epoch)
        print('training loss:', l)
        for X_val, y_val, _ in test_dataloader:
            preds_val = model(X_val.double())
            l_val = delay_loss(preds_val, y_val, loss, video_shft)
            print('validation loss:', l_val)

  8%|▊         | 20/260 [00:12<02:28,  1.62it/s]

epoch:  19
training loss: tensor(8.4759, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(7.9280, dtype=torch.float64, grad_fn=<MseLossBackward>)


 15%|█▌        | 40/260 [00:24<02:15,  1.63it/s]

epoch:  39
training loss: tensor(8.0662, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(7.5390, dtype=torch.float64, grad_fn=<MseLossBackward>)


 23%|██▎       | 60/260 [00:37<02:08,  1.55it/s]

epoch:  59
training loss: tensor(7.8522, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(7.2958, dtype=torch.float64, grad_fn=<MseLossBackward>)


 31%|███       | 80/260 [00:49<01:51,  1.61it/s]

epoch:  79
training loss: tensor(7.6935, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.9541, dtype=torch.float64, grad_fn=<MseLossBackward>)


 38%|███▊      | 100/260 [01:01<01:39,  1.61it/s]

epoch:  99
training loss: tensor(7.5195, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.9716, dtype=torch.float64, grad_fn=<MseLossBackward>)


 46%|████▌     | 120/260 [01:14<01:29,  1.56it/s]

epoch:  119
training loss: tensor(7.4042, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.8356, dtype=torch.float64, grad_fn=<MseLossBackward>)


 54%|█████▍    | 140/260 [01:26<01:15,  1.59it/s]

epoch:  139
training loss: tensor(7.2061, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.5243, dtype=torch.float64, grad_fn=<MseLossBackward>)


 62%|██████▏   | 160/260 [01:39<01:02,  1.59it/s]

epoch:  159
training loss: tensor(7.1122, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.5680, dtype=torch.float64, grad_fn=<MseLossBackward>)


 69%|██████▉   | 180/260 [01:51<00:50,  1.59it/s]

epoch:  179
training loss: tensor(6.9889, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.4661, dtype=torch.float64, grad_fn=<MseLossBackward>)


 77%|███████▋  | 200/260 [02:03<00:37,  1.58it/s]

epoch:  199
training loss: tensor(7.0152, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.2190, dtype=torch.float64, grad_fn=<MseLossBackward>)


 85%|████████▍ | 220/260 [02:16<00:25,  1.59it/s]

epoch:  219
training loss: tensor(6.9694, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.2009, dtype=torch.float64, grad_fn=<MseLossBackward>)


 92%|█████████▏| 240/260 [02:28<00:12,  1.61it/s]

epoch:  239
training loss: tensor(6.9211, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.1688, dtype=torch.float64, grad_fn=<MseLossBackward>)


100%|██████████| 260/260 [02:41<00:00,  1.61it/s]

epoch:  259
training loss: tensor(6.9410, dtype=torch.float64, grad_fn=<MseLossBackward>)
validation loss: tensor(6.2837, dtype=torch.float64, grad_fn=<MseLossBackward>)





In [20]:
X_val, y_val, val_starts = test_dataloader.__iter__().next()
preds_val = model(X_val.double())
sample_lips = pca.inverse_transform(preds_val[0].detach().numpy())
(print(val_starts[0])*30/100) + 12

tensor(25104)


In [21]:
np.save('/content/drive/MyDrive/6869/sample_lips', sample_lips)

In [22]:
25104*30/100 + 12

7543.2

941.4