# Sing2Ani Training Pipeline

### Dataset Configuration

In [1]:
%matplotlib inline

import os
import shutil
import time
from datetime import datetime
from pathlib import Path
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.insert(1, 'pytorch-mdn/mdn')
import mdn

from pydub import AudioSegment

from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import torch.nn as nn

from sklearn.preprocessing import MinMaxScaler

print(torchaudio.get_audio_backend())
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

print(torch.__version__)
print(torchaudio.__version__)

from IPython.display import Audio
from torchaudio.utils import download_asset

torch.random.manual_seed(0)

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

TRAINING_DATA_PATH = "./sampledata/Training" # "C:/Users/Kevin/AppData/LocalLow/kevinjycui/Training"
TESTING_DATA_PATH = "./sampledata/Testing" # "C:/Users/Kevin/AppData/LocalLow/kevinjycui/Testing"

BLENDSHAPE_PARAMS = ["A", "Angry", "Blink", "Blink_L", "Blink_R", "E", "Fun", "I", "Joy", "LookDown", "LookLeft", "LookRight", "LookUp", "Neutral", "O", "Sorrow", "Surprised", "U"]
BONE_PARAMS = ["Chest", "Head", "Hips", "LeftEye", "LeftFoot", "LeftHand", "LeftIndexDistal", "LeftIndexIntermediate", "LeftIndexProximal", "LeftLittleDistal", "LeftLittleIntermediate", "LeftLittleProximal", "LeftLowerArm", "LeftLowerLeg", "LeftMiddleDistal", "LeftMiddleIntermediate", "LeftMiddleProximal", "LeftRingDistal", "LeftRingIntermediate", "LeftRingProximal", "LeftShoulder", "LeftThumbDistal", "LeftThumbIntermediate", "LeftThumbProximal", "LeftToes", "LeftUpperArm", "LeftUpperLeg", "Neck", "RightEye", "RightFoot", "RightHand", "RightIndexDistal", "RightIndexIntermediate", "RightIndexProximal", "RightLittleDistal", "RightLittleIntermediate", "RightLittleProximal", "RightLowerArm", "RightLowerLeg", "RightMiddleDistal", "RightMiddleIntermediate", "RightMiddleProximal", "RightRingDistal", "RightRingIntermediate", "RightRingProximal", "RightShoulder", "RightThumbDistal", "RightThumbIntermediate", "RightThumbProximal", "RightToes", "RightUpperArm", "RightUpperLeg", "Spine", "UpperChest"]

def print_metadata(metadata, blendshapes, bones, src=None):
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    print(" - sample_rate:", metadata.sample_rate)
    print(" - num_channels:", metadata.num_channels)
    print(" - num_frames:", metadata.num_frames)
    print(" - bits_per_sample:", metadata.bits_per_sample)
    print(" - encoding:", metadata.encoding)
    duration = metadata.num_frames / metadata.sample_rate
    print(" - duration:", duration, end='s\n')
    print(" - num_blendshape_frames:", len(blendshapes))
    print(" - num_bone_frames:", len(bones))
    print(" - blendshape fps:", len(blendshapes) / duration)
    print(" - bone fps:", len(bones) / duration)
    frames_per_blendshape = metadata.num_frames / len(blendshapes)
    frames_per_bone = metadata.num_frames / len(bones)
    print(" - frames per blendshape:", frames_per_blendshape)
    print(" - frames per bone:", frames_per_bone)
    seconds_per_blendshape = duration / len(blendshapes)
    seconds_per_bone = duration / len(bones)
    print(" - seconds per blendshape:", seconds_per_blendshape, end='s\n')
    print(" - seconds per bone:", seconds_per_bone, end='s\n')
    assert frames_per_blendshape == frames_per_bone and seconds_per_blendshape == seconds_per_bone
    print()

def plot_waveform(waveform, sr, title="Waveform"):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sr

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes.plot(time_axis, waveform[0], linewidth=1)
        axes.grid(True)
    else:
        for axis in axes:
            axis.plot(time_axis, waveform[0], linewidth=1)
            axis.grid(True)
    figure.suptitle(title)

def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto")
    fig.colorbar(im, ax=axs)

class VRMParamsDataset(Dataset):
    """ VRM Parameter Dataset """

    n_mfcc = 39
    n_vrmframes = 64

    def __init__(self, filename, DATA_PATH=TRAINING_DATA_PATH, effects=None, audio_only=False, bs_scaler=None, bone_scaler=None):
        """
        Arguments:
            filename (string): Path to wav/csv files.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.audio_only = audio_only
        self.name = Path(filename).stem
        self.DATA_PATH = DATA_PATH

        self.audio_path = self.DATA_PATH + f"/Audio/{self.name}.wav"

        self.SPEECH_WAVEFORM, self.SAMPLE_RATE = torchaudio.load(self.audio_path)
        self.SPEECH_WAVEFORM = torch.mean(self.SPEECH_WAVEFORM, dim=0).unsqueeze(0)

        if self.audio_only:
            return

        self.blendshapes = pd.read_csv(DATA_PATH + "/Blendshapes/" + self.name + ".csv")
        self.bones = pd.read_csv(DATA_PATH + "/Bones/" + self.name + ".csv")

        self.bs_scaler = bs_scaler
        if self.bs_scaler == None:
            self.bs_scaler = MinMaxScaler(feature_range=(0, 1)).fit(self.blendshapes.values)
        bs_scaled = self.bs_scaler.transform(self.blendshapes.values)
        self.blendshape = pd.DataFrame(bs_scaled)

        self.bone_scaler = bone_scaler
        if self.bone_scaler == None:
            self.bone_scaler = MinMaxScaler(feature_range=(0, 1)).fit(self.bones.values)
        bone_scaled = self.bone_scaler.transform(self.bones.values)
        self.bones = pd.DataFrame(bone_scaled)

        assert len(self.blendshapes) == len(self.bones)

        metadata = torchaudio.info(self.audio_path)
        print_metadata(metadata, self.blendshapes, self.bones, src=self.audio_path)

        self.STATIC_FRAME = (metadata.num_frames / self.SAMPLE_RATE) / len(self.blendshapes)

        for idx in range(len(self.blendshapes)):
            approx_time = (idx + 1) * self.STATIC_FRAME
            self.blendshapes.iloc[idx, 0] = approx_time
            self.bones.iloc[idx, 0] = approx_time

        if effects:
            self.SPEECH_WAVEFORM, self.SAMPLE_RATE = torchaudio.sox_effects.apply_effects_tensor(self.SPEECH_WAVEFORM, self.SAMPLE_RATE, effects)

        self.init_mfcc()

    def init_mfcc(self):
        def mfcc_transform(n_mfcc, n_mels, n_fft, win_length, hop_length):
            print("MFCC with")
            print(" - number of mfcc:", n_mfcc)
            print(" - number of mels:", n_mels)
            print(" - number of fft:", n_fft)
            print(" - window length:", win_length)
            print(" - hop length:", hop_length)
            return T.MFCC(
                sample_rate=self.SAMPLE_RATE,
                n_mfcc=n_mfcc,
                melkwargs={
                    "n_fft": n_fft,
                    "n_mels": n_mels,
                    "win_length": win_length,
                    "hop_length": hop_length,
                    "window_fn": torch.hann_window
                },
            )
        
        n_mfcc = self.n_mfcc
        n_mels = n_mfcc * 2
        n_fft = int(self.STATIC_FRAME * self.SAMPLE_RATE) // 16
        win_length = n_fft
        hop_length = n_fft // 2

        self.mfcc = mfcc_transform(n_mfcc, n_mels, n_fft, win_length, hop_length)(self.SPEECH_WAVEFORM)
        self.hop_length = hop_length
        print(" - number of mfcc frames:", len(self.mfcc[0][0]))
        print()

    def input_dim(self):
        return self.n_mfcc * self.n_vrmframes
    
    def output_dim(self):
        return len(self.blendshapes.iloc[0]) + len(self.bones.iloc[0]) - 2

    def _set_static_frame(self, _sf):
        self.STATIC_FRAME = _sf
        self.init_mfcc()

    def __len__(self):
        if self.audio_only:
            metadata = torchaudio.info(self.audio_path)
            return int((metadata.num_frames / metadata.sample_rate) / self.STATIC_FRAME)
        return len(self.blendshapes)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # Number of frames before and after timestamp to get
        n_vrmframes_half = self.n_vrmframes // 2

        mfcc_index = min((idx + 1) * n_vrmframes_half, self.mfcc.shape[2] - n_vrmframes_half - 1)
        mfcc_frame = torch.zeros(self.n_mfcc, self.n_vrmframes)
        # print(idx, mfcc_index, n_vrmframes_half)
        mfcc_frame[:] = self.mfcc[0,:,mfcc_index - n_vrmframes_half:mfcc_index + n_vrmframes_half]

        if self.audio_only:
            return mfcc_frame, torch.empty(1)
                
        time_window = self.blendshapes.iloc[idx, 0]

        blendshape_params = self.blendshapes.iloc[idx, 1:]
        blendshape_params = np.asarray(blendshape_params)
        blendshape_params = blendshape_params.astype('float')

        assert time_window == self.bones.iloc[idx, 0]

        bone_params = self.bones.iloc[idx, 1:]
        bone_params = np.asarray(bone_params)
        bone_params = bone_params.astype('float')

        return mfcc_frame, torch.Tensor(np.concatenate((blendshape_params, bone_params)))
    

effect = [["sinc", "300-3k"]]
effect = None # Windows

train_file = "8-8-2023 3-15-42 PM"
valid_file = "2023-08-15 10-05-57 PM"

train = VRMParamsDataset(train_file, TRAINING_DATA_PATH, effect)
valid = VRMParamsDataset(valid_file, TRAINING_DATA_PATH, effect, bs_scaler=train.bs_scaler, bone_scaler=train.bone_scaler)
test = VRMParamsDataset("7-17-2023 4-53-18 PM", TRAINING_DATA_PATH, effect, bs_scaler=train.bs_scaler, bone_scaler=train.bone_scaler)
# test._set_static_frame(train.STATIC_FRAME)

soundfile
cuda:0
2.0.1+cu118
2.0.2+cu118
----------
Source: ./sampledata/Training/Audio/8-8-2023 3-15-42 PM.wav
----------
 - sample_rate: 48000
 - num_channels: 1
 - num_frames: 86400000
 - bits_per_sample: 16
 - encoding: PCM_S
 - duration: 1800.0s
 - num_blendshape_frames: 95449
 - num_bone_frames: 95449
 - blendshape fps: 53.02722222222222
 - bone fps: 53.02722222222222
 - frames per blendshape: 905.195444687739
 - frames per bone: 905.195444687739
 - seconds per blendshape: 0.018858238430994562s
 - seconds per bone: 0.018858238430994562s

MFCC with
 - number of mfcc: 39
 - number of mels: 78
 - number of fft: 56
 - window length: 56
 - hop length: 28
 - number of mfcc frames: 3085715

----------
Source: ./sampledata/Training/Audio/2023-08-15 10-05-57 PM.wav
----------
 - sample_rate: 48000
 - num_channels: 1
 - num_frames: 8640000
 - bits_per_sample: 16
 - encoding: PCM_S
 - duration: 180.0s
 - num_blendshape_frames: 12343
 - num_bone_frames: 12343
 - blendshape fps: 68.5722222222

In [2]:
VISUAL = False

### Play audio of training data

In [3]:
if (VISUAL):
    Audio(train.SPEECH_WAVEFORM, rate=train.SAMPLE_RATE)

### Play audio of testing data

In [4]:
if (VISUAL):
    Audio(test.SPEECH_WAVEFORM, rate=test.SAMPLE_RATE)

### Waveform of training data

In [5]:
if (VISUAL):
    plot_waveform(train.SPEECH_WAVEFORM, train.SAMPLE_RATE, title="Training audio")
    plt.show()

In [6]:
print(train.mfcc[0].shape)
if (VISUAL):
    plot_spectrogram(train.mfcc[0])
    plt.show()

torch.Size([39, 3085715])


### Waveform of testing data

In [7]:
if (VISUAL):
    plot_waveform(test.SPEECH_WAVEFORM, test.SAMPLE_RATE, title="Testing audio")
    plt.show()

In [8]:
print(test.mfcc[0].shape)
if (VISUAL):
    plot_spectrogram(test.mfcc[0])
    plt.show()

torch.Size([39, 17246])


### Prepare model

In [9]:
batch_size = 50

train_loader = DataLoader(train, batch_size=batch_size, shuffle=False)
valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

_x, _y = train[0]
_x.shape, _y.shape

(torch.Size([39, 64]), torch.Size([396]))

In [10]:
class VRMLoss(nn.Module):
    bs_smooth_diff = 0.1
    bone_smooth_diff = 0.4

    def __init__(self, weights): #, non_expr_face_bs_weight=1.0):
        super(VRMLoss, self).__init__()
        self.bs_huber_weight = weights['bs']['huber']
        self.bs_smooth_weight = weights['bs']['smooth']

        self.bone_huber_weight = weights['bone']['huber']
        self.bone_smooth_weight = weights['bone']['smooth']
        
        # self.non_expr_face_bs_weight = non_expr_face_bs_weight

    def forward(self, output, target):

        bs_output = output.squeeze()[:len(BLENDSHAPE_PARAMS)]
        bs_target = target.squeeze()[:len(BLENDSHAPE_PARAMS)]

        bone_output = output.squeeze()[len(BLENDSHAPE_PARAMS):]
        bone_target = target.squeeze()[len(BLENDSHAPE_PARAMS):]
        
        # n_loss = torch.nn.MSELoss()(bs_output, bs_target)

        # def expr_blendshapes(blendshapes):
        #     params = ['A', 'Angry', 'E', 'Fun', 'I', 'Joy', 'Neutral', 'O', 'U', 'Sorrow', 'Surprised']
        #     # print(blendshapes.shape)
        #     return torch.cat(tuple([blendshapes[:,BLENDSHAPE_PARAMS.index(p)] for p in params]), dim=0)

        # bs_output = expr_blendshapes(bs_output)
        # bs_target = expr_blendshapes(bs_target)
        
        bs_prev = None
        bs_smooth = 0.
        
        bone_prev = None
        bone_smooth = 0.
        
        bs_huber = nn.HuberLoss()(bs_output, bs_target)
        for sample in range(len(bs_target)):
            # print(output[sample].shape, target[sample].shape)
            # print(bs_output[sample].shape, bs_target[sample].shape)
            bs_smooth += 0 if bs_prev is None else 1. - abs(self.bs_smooth_diff - nn.CosineSimilarity(dim=0)(bs_output[sample], bs_prev).mean())
            bs_prev = bs_output[sample]
        
        bone_huber = nn.HuberLoss()(bone_output, bone_target)
        for sample in range(len(bone_target)):
            bone_smooth += 0 if bone_prev is None else 1. - abs(self.bone_smooth_diff - nn.CosineSimilarity(dim=0)(bone_output[sample], bone_prev).mean())
            bone_prev = bone_output[sample]

        def a2f_loss(huber, smooth, length, w=[1.0, 1.0]):
            return (w[0]*huber + w[1]*smooth) / length
        
        return a2f_loss(bs_huber, bs_smooth, len(bs_target), [self.bs_huber_weight, self.bs_smooth_weight]) + a2f_loss(bone_huber, bone_smooth, len(bone_target), [self.bone_huber_weight, self.bone_smooth_weight]) # + self.non_expr_face_bs_weight * n_loss

In [11]:
class BiLSTM(nn.Module):

    def __init__(self, input_dim, hidden_dim, batch_size, output_dim, num_layers, p):
        """
        Arguments:
            input_dim: Input layer dimension
            hidden_dim: Hidden layer dimension
            batch_size: Batch size of data
            output_dim: Output layer dimension
            num_layers: Number of layers
            p: Dropout
        """
        super(BiLSTM, self).__init__()

        self.dropout = nn.Dropout(p)
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        
        # self.init_linear = nn.Linear(self.input_dim, self.input_dim)
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True, dropout=p)

        self.mdn = mdn.MDN(self.hidden_dim * 2, self.hidden_dim * 2, 1)
        self.linear_hidden = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        
        self.energy = nn.Linear(self.hidden_dim*3, 1)
        self.softmax = nn.Softmax(dim=0)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(self.hidden_dim*4, output_dim)

    def init_hidden(self):
        w1 = torch.randn(self.num_layers, self.batch_size, self.hidden_dim)
        w2 = torch.randn(self.num_layers, self.batch_size, self.hidden_dim)
        return w1, w2

    def forward(self, input):
        input = self.dropout(input)

        # print('----------------------------------------------------------------------------------')
        # print(input)
        
        lstm_out, (hidden, cell) = self.lstm(input)
        # print('               LSTM')
        # print(lstm_out)

        # print('               MDN')
        # print(mdn_out)

        # print(hidden[0:2].shape)
        
        hidden = self.linear_hidden(hidden[0:2].reshape(1, -1, self.hidden_dim * 2)).permute(1, 0, 2)
        # print('               HIDD')
        # print(hidden)

        attn = self.softmax(self.relu(self.energy(torch.cat((hidden, lstm_out), dim=2))))
        # print('               ATTN')
        # print(attn)
        context = torch.bmm(attn, lstm_out).permute(1, 0, 2)
        mdn_out = self.mdn(context)[2].permute(1, 0, 2)

        # print(context.shape, lstm_out.shape)
        y_pred = self.fc(torch.cat((mdn_out, lstm_out.permute(1, 0, 2)), dim=2)).squeeze()
        # print('               PRED')
        # print(y_pred)
        return y_pred

n_epochs = 3
lr = 0.0001
lstm_input_size = train.input_dim()
hidden_state_size = 512
num_sequence_layers = 2
output_dim = train.output_dim()
save_interval = 50
dropoff = 0.02

model = BiLSTM(lstm_input_size, hidden_state_size, batch_size, output_dim, num_sequence_layers, dropoff)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

### Training

In [12]:
model.train()

BiLSTM(
  (dropout): Dropout(p=0.02, inplace=False)
  (lstm): LSTM(2496, 512, num_layers=2, batch_first=True, dropout=0.02, bidirectional=True)
  (mdn): MDN(
    (pi): Sequential(
      (0): Linear(in_features=1024, out_features=1, bias=True)
      (1): Softmax(dim=1)
    )
    (sigma): Linear(in_features=1024, out_features=1024, bias=True)
    (mu): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (linear_hidden): Linear(in_features=1024, out_features=512, bias=True)
  (energy): Linear(in_features=1536, out_features=1, bias=True)
  (softmax): Softmax(dim=0)
  (relu): ReLU()
  (fc): Linear(in_features=2048, out_features=396, bias=True)
)

In [13]:
weights = {
    'bs': {
        'huber': 9.0,
        'smooth': 0.2
    },
    'bone': {
        'huber': 1.0,
        'smooth': 0.2
    }
}
loss_fn = VRMLoss(weights)
# valid_loss_fn = nn.MSELoss()

In [None]:
def save_model_state(state_dict, name, train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss):
    torch.save(state_dict, 'models/{}.pt'.format(name))
    with open('models/meta-{}.txt'.format(name), 'w') as f:
        metastring = 'Train file: {}\nBatch size: {}\nNumber of epochs: {}\nLearning rate: {}\nHidden state size: {}\nNumber of LSTM layers: {}\nWeights: {}\nValidation loss: {}'
        f.write(metastring.format(train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, str(weights), valid_loss))

def save_model(model, name, train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss):
    save_model_state(model.state_dict(), name, train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss)

In [None]:
min_valid_loss = np.inf
best_model = None
best_model_epoch = -1

for epoch in range(n_epochs):
    start_time = time.time()
    train_loss = 0.0
    model.train()
    for i, (mfcc, vrm_params) in enumerate(train_loader):
        mfcc = mfcc.to(device=device).reshape(-1, 1, lstm_input_size)
        vrm_params = vrm_params.to(device=device).reshape(-1, 1, output_dim)

        pred = model(mfcc)

        loss = loss_fn(pred, vrm_params)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    train_loss /= len(train_loader)
        
    valid_loss = 0.0
    model.eval()
    for i, (mfcc, vrm_params) in enumerate(valid_loader):
        mfcc = mfcc.to(device=device).reshape(-1, 1, lstm_input_size)
        vrm_params = vrm_params.to(device=device).reshape(-1, 1, output_dim)
        
        pred = model(mfcc)
        
        loss = loss_fn(pred, vrm_params)
        valid_loss += loss.item()

    valid_loss /= len(valid_loader)

    print('Epoch {}/{} \tTrain loss={:.10f} \tValid loss={:.10f}\tTime={:.2f}s'.format(epoch + 1, n_epochs, train_loss, valid_loss, time.time() - start_time))

    if ((epoch+1) % save_interval == 0):
        save_model(model, 'sample-model-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%SZ"), train_file, batch_size, epoch, lr, hidden_state_size, num_sequence_layers, weights, valid_loss)
    
    if min_valid_loss > valid_loss:
        min_valid_loss = valid_loss
        best_model = model.state_dict()
        best_model_epoch = epoch

        print('Best model: Model at epoch {} with valid loss {:.10f}'.format(best_model_epoch, min_valid_loss))

print('Best model: Model at epoch {} with valid loss {:.10f}'.format(best_model_epoch, min_valid_loss))

Epoch 1/50 	Train loss=0.0987156145 	Valid loss=0.1011988965	Time=109.99s
Best model: Model at epoch 0 with valid loss 0.1011988965
Epoch 2/50 	Train loss=0.0980066061 	Valid loss=0.1011967421	Time=109.98s
Best model: Model at epoch 1 with valid loss 0.1011967421
Epoch 3/50 	Train loss=0.0980107297 	Valid loss=0.1012119357	Time=109.15s


In [None]:
weights_file = 'model_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%SZ")
# Save latest model
save_model(model, weights_file + '_LATEST', train_file, batch_size, n_epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss)
# Save best model
save_model_state(best_model, weights_file + '_BEST', train_file, batch_size, best_model_epoch, lr, hidden_state_size, num_sequence_layers, weights, min_valid_loss)

In [None]:
model = BiLSTM(lstm_input_size, hidden_state_size, batch_size, output_dim, num_sequence_layers, dropoff)
model = model.to(device)
model.load_state_dict(torch.load('models/{}.pt'.format(weights_file + '_BEST'), map_location=device))

model

### Prediction

In [None]:
model.eval()

In [None]:
test_preds = torch.zeros(output_dim * len(test))
test_actual = torch.zeros(output_dim * len(test))

for i, (mfcc, vrm_params) in enumerate(test_loader):
    mfcc = mfcc.to(device=device).reshape(-1, 1, lstm_input_size)
    if not test.audio_only:
        vrm_params = vrm_params.cpu().reshape(-1)
    y_pred = model(mfcc).cpu().detach().reshape(-1)
    test_preds[i*batch_size*output_dim:(i+1)*batch_size*output_dim] = y_pred
    if not test.audio_only:
        test_actual[i*batch_size*output_dim:(i+1)*batch_size*output_dim] = vrm_params

if not test.audio_only:
    rmse = torch.sqrt(torch.nn.functional.mse_loss(test_preds, test_actual))
    print('RMSE: {:.10f}'.format(rmse))
        
test_preds = test_preds.reshape(len(test), output_dim)
test_actual = test_actual.reshape(len(test), output_dim)

test_preds.shape

In [None]:
test_preds

In [None]:
PREDICTION_DATA_PATH = './sampledata/Prediction'

def to_csv(test_preds, name='prediction'):
    blendshape_params = pd.DataFrame(test_preds[:,:len(BLENDSHAPE_PARAMS)])
    blendshape_params = pd.DataFrame(train.bs_scaler.inverse_transform(blendshape_params.values))
    blendshape_params.columns = BLENDSHAPE_PARAMS
    
    bone_params = pd.DataFrame(test_preds[:,len(BLENDSHAPE_PARAMS):])
    bone_params = pd.DataFrame(train.bone_scaler.inverse_transform(bone_params.values))
    bone_params.columns = [bone + t for t in ('PosX', 'PosY', 'PosZ', 'RotX', 'RotY', 'RotZ', 'RotW') for bone in BONE_PARAMS]
    
    STATIC_FRAME = test.STATIC_FRAME
    
    time_column = pd.DataFrame({'Time': [float(i)*STATIC_FRAME for i in range(len(test_preds))]})
    
    blendshape_params = pd.concat([time_column, blendshape_params], axis=1)
    bone_params = pd.concat([time_column, bone_params], axis=1)
    
    blendshape_params.to_csv(PREDICTION_DATA_PATH + '/Blendshapes/' + name + '.csv', index=False)
    bone_params.to_csv(PREDICTION_DATA_PATH + '/Bones/' + name + '.csv', index=False)
    
    shutil.copyfile(test.DATA_PATH + '/Audio/' + test.name + '.wav', PREDICTION_DATA_PATH + '/Audio/' + name + '.wav')

datestring = datetime.now().strftime("%Y-%m-%d_%H-%M-%SZ")
to_csv(test_preds, 'prediction-' + datestring)
if not test.audio_only:
    to_csv(test_actual, 'actual-' + datestring)