# Sing2Ani Training Pipeline

### Dataset Configuration

In [1]:
%matplotlib inline

import os
import shutil
import time
from datetime import datetime
from pathlib import Path
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.insert(1, 'pytorch-mdn/mdn')
import mdn

from pydub import AudioSegment

from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import torch.nn as nn

from sklearn.preprocessing import MinMaxScaler

print(torchaudio.get_audio_backend())
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

print(torch.__version__)
print(torchaudio.__version__)

from IPython.display import Audio
from torchaudio.utils import download_asset

# torch.random.manual_seed(0)

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

TRAINING_DATA_PATH = "./sampledata/Training" # "C:/Users/Kevin/AppData/LocalLow/kevinjycui/Training"
TESTING_DATA_PATH = "./sampledata/Testing" # "C:/Users/Kevin/AppData/LocalLow/kevinjycui/Testing"

# VMC Protocol Standard
BLENDSHAPE_PARAMS = ["A", "Angry", "Blink", "Blink_L", "Blink_R", "E", "Fun", "I", "Joy", "LookDown", "LookLeft", "LookRight", "LookUp", "Neutral", "O", "Sorrow", "Surprised", "U"]
BONE_PARAMS = ["Chest", "Head", "Hips", "LeftEye", "LeftFoot", "LeftHand", "LeftIndexDistal", "LeftIndexIntermediate", "LeftIndexProximal", "LeftLittleDistal", "LeftLittleIntermediate", "LeftLittleProximal", "LeftLowerArm", "LeftLowerLeg", "LeftMiddleDistal", "LeftMiddleIntermediate", "LeftMiddleProximal", "LeftRingDistal", "LeftRingIntermediate", "LeftRingProximal", "LeftShoulder", "LeftThumbDistal", "LeftThumbIntermediate", "LeftThumbProximal", "LeftToes", "LeftUpperArm", "LeftUpperLeg", "Neck", "RightEye", "RightFoot", "RightHand", "RightIndexDistal", "RightIndexIntermediate", "RightIndexProximal", "RightLittleDistal", "RightLittleIntermediate", "RightLittleProximal", "RightLowerArm", "RightLowerLeg", "RightMiddleDistal", "RightMiddleIntermediate", "RightMiddleProximal", "RightRingDistal", "RightRingIntermediate", "RightRingProximal", "RightShoulder", "RightThumbDistal", "RightThumbIntermediate", "RightThumbProximal", "RightToes", "RightUpperArm", "RightUpperLeg", "Spine", "UpperChest"]

def print_metadata(metadata, vrm, src=None):
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    print(" - sample_rate:", metadata.sample_rate)
    print(" - num_channels:", metadata.num_channels)
    print(" - num_frames:", metadata.num_frames)
    print(" - bits_per_sample:", metadata.bits_per_sample)
    print(" - encoding:", metadata.encoding)
    duration = metadata.num_frames / metadata.sample_rate
    print(" - duration:", duration, end='s\n')
    print(" - num_vrm_frames:", len(vrm))
    print(" - vrm fps:", len(vrm) / duration)
    frames_per_vrm = metadata.num_frames / len(vrm)
    print(" - frames_per_vrm:", frames_per_vrm)
    seconds_per_vrm = duration / len(vrm)
    print(" - seconds_per_vrm:", seconds_per_vrm, end='s\n')
    print()

def plot_waveform(waveform, sr, title="Waveform"):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sr

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes.plot(time_axis, waveform[0], linewidth=1)
        axes.grid(True)
    else:
        for axis in axes:
            axis.plot(time_axis, waveform[0], linewidth=1)
            axis.grid(True)
    figure.suptitle(title)

def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto")
    fig.colorbar(im, ax=axs)

class VRMParamsDataset(Dataset):
    """ VRM Parameter Dataset """

    n_mfcc = 39
    n_fft_per_vrm = 16
    n_vrmframes = 64

    def __init__(self, filename, DATA_PATH=TRAINING_DATA_PATH, effects=None, audio_only=False, scaler=None):
        """
            filename (string): Path to wav/csv files.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.audio_only = audio_only
        self.name = Path(filename).stem
        self.DATA_PATH = DATA_PATH

        self.audio_path = self.DATA_PATH + f"/Audio/{self.name}.wav"

        self.SPEECH_WAVEFORM, self.SAMPLE_RATE = torchaudio.load(self.audio_path)
        self.SPEECH_WAVEFORM = torch.mean(self.SPEECH_WAVEFORM, dim=0).unsqueeze(0)

        if self.audio_only:
            return

        blendshapes = pd.read_csv(DATA_PATH + "/Blendshapes/" + self.name + ".csv")
        bones = pd.read_csv(DATA_PATH + "/Bones/" + self.name + ".csv")

        # Combine blendshape and bone data
        self.data = pd.concat([blendshapes, bones.iloc[:, 1:]], axis=1)

        metadata = torchaudio.info(self.audio_path)
        print_metadata(metadata, self.data, src=self.audio_path)

        # Set static time window for each row
        self.STATIC_FRAME = (metadata.num_frames / self.SAMPLE_RATE) / len(self.data)

        for idx in range(len(self.data)):
            approx_time = (idx + 1) * self.STATIC_FRAME
            self.data.iloc[idx, 0] = approx_time

        # Normalize data
        self.scaler = scaler
        if self.scaler == None:
            self.scaler = MinMaxScaler(feature_range=(0, 1)).fit(self.data.iloc[:, 1:])

        data_scaled = self.scaler.transform(self.data.iloc[:, 1:].values)
        data_scaled = pd.DataFrame(data_scaled)
        self.data = pd.concat([self.data.iloc[:, 0], data_scaled], axis=1)

        # Apply effects such as noise reduction
        if effects:
            self.SPEECH_WAVEFORM, self.SAMPLE_RATE = torchaudio.sox_effects.apply_effects_tensor(self.SPEECH_WAVEFORM, self.SAMPLE_RATE, effects)

        self.init_mfcc()

    def init_mfcc(self):
        '''
        Create MFCC from audio
        '''
        def mfcc_transform(n_mfcc, n_mels, n_fft, win_length, hop_length):
            print("MFCC with")
            print(" - number of mfcc:", n_mfcc)
            print(" - number of mels:", n_mels)
            print(" - number of fft:", n_fft)
            print(" - window length:", win_length)
            print(" - hop length:", hop_length)
            return T.MFCC(
                sample_rate=self.SAMPLE_RATE,
                n_mfcc=n_mfcc,
                melkwargs={
                    "n_fft": n_fft,
                    "n_mels": n_mels,
                    "win_length": win_length,
                    "hop_length": hop_length,
                    "window_fn": torch.hann_window
                },
            )
        # Number of MFCC
        n_mfcc = self.n_mfcc
        n_mels = n_mfcc * 2
        # FFT is every nth of a window, allows for n*2 FFTs per VRM/VMC with potential overlap
        n_fft = int(self.STATIC_FRAME * self.SAMPLE_RATE) // self.n_fft_per_vrm
        win_length = n_fft
        hop_length = n_fft // 2

        self.mfcc = mfcc_transform(n_mfcc, n_mels, n_fft, win_length, hop_length)(self.SPEECH_WAVEFORM)
        self.hop_length = hop_length
        print(" - number of mfcc frames:", len(self.mfcc[0][0]))
        print()

    def input_dim(self):
        return self.n_mfcc * self.n_vrmframes
    
    def output_dim(self):
        return len(self.data.iloc[0]) - 1

    def _set_static_frame(self, _sf):
        self.STATIC_FRAME = _sf
        self.init_mfcc()

    def __len__(self):
        if self.audio_only:
            metadata = torchaudio.info(self.audio_path)
            return int((metadata.num_frames / metadata.sample_rate) / self.STATIC_FRAME)
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # Number of frames before and after timestamp to get
        n_vrmframes_half = self.n_vrmframes // 2

        mfcc_index = min((idx + 1) * n_vrmframes_half, self.mfcc.shape[2] - n_vrmframes_half - 1)
        mfcc_frame = torch.zeros(self.n_mfcc, self.n_vrmframes)
        # print(idx, mfcc_index, n_vrmframes_half)
        mfcc_frame[:] = self.mfcc[0,:,mfcc_index - n_vrmframes_half:mfcc_index + n_vrmframes_half]

        if self.audio_only:
            return mfcc_frame, torch.empty(1)
                
        time_window = self.data.iloc[idx, 0]

        vrm_params = self.data.iloc[idx, 1:]
        vrm_params = np.asarray(vrm_params)
        vrm_params = vrm_params.astype('float')

        return mfcc_frame, torch.Tensor(vrm_params)
    

effect = [["sinc", "300-3k"]]
effect = None # libsox not available for Windows

train_file = ["8-8-2023 3-15-42 PM"]
valid_file = "2023-08-15 10-05-57 PM"

train0 = VRMParamsDataset(train_file[0], TRAINING_DATA_PATH, effect)
train = [VRMParamsDataset(file, TRAINING_DATA_PATH, effect, scaler=train0.scaler) for file in train_file[1:]]
train.insert(0, train0)
assert all([t.input_dim() == train0.input_dim() and t.output_dim() == train0.output_dim() for t in train])

valid = VRMParamsDataset(valid_file, TRAINING_DATA_PATH, effect, scaler=train0.scaler)
test = VRMParamsDataset("7-17-2023 4-53-18 PM", TRAINING_DATA_PATH, effect, scaler=train0.scaler)
# test._set_static_frame(train.STATIC_FRAME)

soundfile
cuda:0
2.0.1+cu118
2.0.2+cu118
----------
Source: ./sampledata/Training/Audio/8-8-2023 3-15-42 PM.wav
----------
 - sample_rate: 48000
 - num_channels: 1
 - num_frames: 86400000
 - bits_per_sample: 16
 - encoding: PCM_S
 - duration: 1800.0s
 - num_vrm_frames: 95449
 - vrm fps: 53.02722222222222
 - frames_per_vrm: 905.195444687739
 - seconds_per_vrm: 0.018858238430994562s

MFCC with
 - number of mfcc: 39
 - number of mels: 78
 - number of fft: 56
 - window length: 56
 - hop length: 28
 - number of mfcc frames: 3085715

----------
Source: ./sampledata/Training/Audio/7-17-2023 4-53-18 PM.wav
----------
 - sample_rate: 48000
 - num_channels: 1
 - num_frames: 2880000
 - bits_per_sample: 16
 - encoding: PCM_S
 - duration: 60.0s
 - num_vrm_frames: 537
 - vrm fps: 8.95
 - frames_per_vrm: 5363.128491620112
 - seconds_per_vrm: 0.11173184357541899s

MFCC with
 - number of mfcc: 39
 - number of mels: 78
 - number of fft: 335
 - window length: 335
 - hop length: 167
 - number of mfcc fram

In [2]:
# Plot waveforms and spectrograms
VISUAL = False

### Play audio of training data

In [3]:
if (VISUAL):
    Audio(train[0].SPEECH_WAVEFORM, rate=train[0].SAMPLE_RATE)

### Play audio of testing data

In [4]:
if (VISUAL):
    Audio(test.SPEECH_WAVEFORM, rate=test.SAMPLE_RATE)

### Waveform of training data

In [5]:
if (VISUAL):
    plot_waveform(train[0].SPEECH_WAVEFORM, train.SAMPLE_RATE, title="Training audio")
    plt.show()

In [6]:
print(train[0].mfcc[0].shape)
if (VISUAL):
    plot_spectrogram(train[0].mfcc[0])
    plt.show()

torch.Size([39, 3085715])


### Waveform of testing data

In [7]:
if (VISUAL):
    plot_waveform(test.SPEECH_WAVEFORM, test.SAMPLE_RATE, title="Testing audio")
    plt.show()

In [8]:
print(test.mfcc[0].shape)
if (VISUAL):
    plot_spectrogram(test.mfcc[0])
    plt.show()

torch.Size([39, 411429])


### Prepare model

In [9]:
batch_size = 100

train_loader = [DataLoader(t, batch_size=batch_size, shuffle=False) for t in train]
valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

_x, _y = train[0][0]
_x.shape, _y.shape

(torch.Size([39, 64]), torch.Size([396]))

In [10]:
class VRMLoss(nn.Module):
    bs_smooth_diff = 0.1
    bone_smooth_diff = 0.4

    def __init__(self, weights): #, non_expr_face_bs_weight=1.0):
        super(VRMLoss, self).__init__()
        self.bs_huber_weight = weights['bs']['huber']
        self.bs_smooth_weight = weights['bs']['smooth']

        self.bone_huber_weight = weights['bone']['huber']
        self.bone_smooth_weight = weights['bone']['smooth']
        
        # self.non_expr_face_bs_weight = non_expr_face_bs_weight

    def forward(self, output, target):
        # Separate data into blendshape and bone
        bs_output = output.squeeze()[:len(BLENDSHAPE_PARAMS)]
        bs_target = target.squeeze()[:len(BLENDSHAPE_PARAMS)]

        bone_output = output.squeeze()[len(BLENDSHAPE_PARAMS):]
        bone_target = target.squeeze()[len(BLENDSHAPE_PARAMS):]
        
        # n_loss = torch.nn.MSELoss()(bs_output, bs_target)

        # def expr_blendshapes(blendshapes):
        #     params = ['A', 'Angry', 'E', 'Fun', 'I', 'Joy', 'Neutral', 'O', 'U', 'Sorrow', 'Surprised']
        #     # print(blendshapes.shape)
        #     return torch.cat(tuple([blendshapes[:,BLENDSHAPE_PARAMS.index(p)] for p in params]), dim=0)

        # bs_output = expr_blendshapes(bs_output)
        # bs_target = expr_blendshapes(bs_target)
        
        bs_prev = None
        bs_smooth = 0.
        
        bone_prev = None
        bone_smooth = 0.

        bs_huber = nn.HuberLoss()(bs_output, bs_target)
        for sample in range(len(bs_target)):
            # print(output[sample].shape, target[sample].shape)
            # print(bs_output[sample].shape, bs_target[sample].shape)
            bs_smooth += 0 if bs_prev is None else 1. - abs(self.bs_smooth_diff - nn.CosineSimilarity(dim=0)(bs_output[sample], bs_prev).mean())
            bs_prev = bs_output[sample]
        
        bone_huber = nn.HuberLoss()(bone_output, bone_target)
        for sample in range(len(bone_target)):
            bone_smooth += 0 if bone_prev is None else 1. - abs(self.bone_smooth_diff - nn.CosineSimilarity(dim=0)(bone_output[sample], bone_prev).mean())
            bone_prev = bone_output[sample]

        def a2f_loss(huber, smooth, length, w=[1.0, 1.0]):
            # Calculate loss as in Audio2Face (Guanzhong Tian; Yi Yuan; Yong Liu)
            return (w[0]*huber + w[1]*smooth) / length

        # Loss is computed separately for blendshape and bone with different weights and combined
        return a2f_loss(bs_huber, bs_smooth, len(bs_target), [self.bs_huber_weight, self.bs_smooth_weight]) + a2f_loss(bone_huber, bone_smooth, len(bone_target), [self.bone_huber_weight, self.bone_smooth_weight]) # + self.non_expr_face_bs_weight * n_loss

In [11]:
class BiLSTM(nn.Module):

    def __init__(self, input_dim, hidden_dim, batch_size, output_dim, num_layers, p):
        """
        Arguments:
            input_dim: Input layer dimension
            hidden_dim: Hidden layer dimension
            batch_size: Batch size of data
            output_dim: Output layer dimension
            num_layers: Number of layers
            p: Dropout
        """
        super(BiLSTM, self).__init__()

        # Dropout to prevent overfitting
        self.dropout = nn.Dropout(p)
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        
        # self.init_linear = nn.Linear(self.input_dim, self.input_dim)

        # Bidirectional LSTM to predict sequence with memory
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True, dropout=p)

        # Mixed Density Network layer to prevent sequence from staying still, outputs a probability density function
        self.mdn = mdn.MDN(self.hidden_dim * 2, self.hidden_dim * 2, 1)

        # Attention layer to determine important parameters from hidden weights
        self.linear_hidden = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.energy = nn.Linear(self.hidden_dim*3, 1)
        self.softmax = nn.Softmax(dim=0)
        self.relu = nn.ReLU()

        # Fully connected layer to produce output
        self.fc = nn.Linear(self.hidden_dim*4, output_dim)

    def init_hidden(self):
        # Hidden layers of LSTM initialised as Gaussian distribution
        w1 = torch.randn(self.num_layers, self.batch_size, self.hidden_dim)
        w2 = torch.randn(self.num_layers, self.batch_size, self.hidden_dim)
        return w1, w2

    def forward(self, input):
        input = self.dropout(input)

        # print('----------------------------------------------------------------------------------')
        # print(input)
        
        lstm_out, (hidden, cell) = self.lstm(input)
        # print('               LSTM')
        # print(lstm_out)

        # print('               MDN')
        # print(mdn_out)

        # print(hidden[0:2].shape)
        
        hidden = self.linear_hidden(hidden[0:2].reshape(1, -1, self.hidden_dim * 2)).permute(1, 0, 2)
        # print('               HIDD')
        # print(hidden)

        attn = self.softmax(self.relu(self.energy(torch.cat((hidden, lstm_out), dim=2))))
        # print('               ATTN')
        # print(attn)
        context = torch.bmm(attn, lstm_out).permute(1, 0, 2)
        mdn_out = self.mdn(context)[2].permute(1, 0, 2)

        # print(context.shape, lstm_out.shape)
        y_pred = self.fc(torch.cat((mdn_out, lstm_out.permute(1, 0, 2)), dim=2)).squeeze()
        # print('               PRED')
        # print(y_pred)
        return y_pred

# Training hyperparameters
n_epochs = 30
lr = 0.0001
lstm_input_size = train[0].input_dim()
hidden_state_size = 512
num_sequence_layers = 2
output_dim = train[0].output_dim()
save_interval = 10
dropoff = 0.0

model = BiLSTM(lstm_input_size, hidden_state_size, batch_size, output_dim, num_sequence_layers, dropoff)
model = model.to(device)

# Adam optimizer (gradient descent)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

### Training

In [12]:
weights = {
    'bs': {
        'huber': 9.0,
        'smooth': 0.2
    },
    'bone': {
        'huber': 1.0,
        'smooth': 0.2
    }
}
loss_fn = VRMLoss(weights)
valid_loss_fn = nn.MSELoss()

In [13]:
def save_model_state(state_dict, name, train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss):
    torch.save(state_dict, 'models/{}.pt'.format(name))
    with open('models/meta-{}.txt'.format(name), 'w') as f:
        metastring = 'Train file: {}\nBatch size: {}\nNumber of epochs: {}\nLearning rate: {}\nHidden state size: {}\nNumber of LSTM layers: {}\nWeights: {}\nValidation loss: {}'
        f.write(metastring.format(train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, str(weights), valid_loss))

def save_model(model, name, train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss):
    save_model_state(model.state_dict(), name, train_file, batch_size, epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss)

In [14]:
min_valid_loss = np.inf
best_model = None
best_model_epoch = -2

for epoch in range(n_epochs):
    try:
        start_time = time.time()
        train_loss = 0.0
        model.train()
        for loader in train_loader:
            for i, (mfcc, vrm_params) in enumerate(loader):
                mfcc = mfcc.to(device=device).reshape(-1, 1, lstm_input_size)
                vrm_params = vrm_params.to(device=device).reshape(-1, 1, output_dim)
        
                pred = model(mfcc)
        
                loss = loss_fn(pred, vrm_params)
        
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
            
        train_loss /= sum([len(loader) for loader in train_loader])
            
        valid_loss = 0.0
        model.eval()
        for i, (mfcc, vrm_params) in enumerate(valid_loader):
            mfcc = mfcc.to(device=device).reshape(-1, 1, lstm_input_size)
            vrm_params = vrm_params.to(device=device).reshape(-1, 1, output_dim)
            
            pred = model(mfcc)
            
            loss = valid_loss_fn(pred, vrm_params)
            valid_loss += loss.item()
    
        valid_loss /= len(valid_loader)
    
        print('Epoch {}/{} \tTrain loss={:.10f} \tValid loss={:.10f}\tTime={:.2f}s'.format(epoch + 1, n_epochs, train_loss, valid_loss, time.time() - start_time))
    
        if ((epoch+1) % save_interval == 0):
            save_model(model, 'sample-model_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%SZ") + '_EPOCH' + str(epoch+1), train_file, batch_size, epoch+1, lr, hidden_state_size, num_sequence_layers, weights, valid_loss)
        
        if min_valid_loss > valid_loss:
            min_valid_loss = valid_loss
            best_model = model.state_dict()
            best_model_epoch = epoch
    
            print('Best model: Model at epoch {} with valid loss {:.10f}'.format(best_model_epoch+1, min_valid_loss))
    except Exception as e:
        print(e)

print('--\nBest model: Model at epoch {} with valid loss {:.10f}'.format(best_model_epoch+1, min_valid_loss))

Epoch 1/30 	Train loss=0.1001580049 	Valid loss=0.0080250793	Time=92.62s
Best model: Model at epoch 1 with valid loss 0.0080250793
Epoch 2/30 	Train loss=0.0995194562 	Valid loss=0.0073329620	Time=92.40s
Best model: Model at epoch 2 with valid loss 0.0073329620
Epoch 3/30 	Train loss=0.0995294764 	Valid loss=0.0073845716	Time=92.37s
Epoch 4/30 	Train loss=0.0995377378 	Valid loss=0.0073998995	Time=95.35s
Epoch 5/30 	Train loss=0.0995425839 	Valid loss=0.0074140476	Time=94.46s
Epoch 6/30 	Train loss=0.0995465403 	Valid loss=0.0074324143	Time=94.36s
Epoch 7/30 	Train loss=0.0995503215 	Valid loss=0.0074371019	Time=94.36s
Epoch 8/30 	Train loss=0.0995535396 	Valid loss=0.0074490316	Time=94.49s
Epoch 9/30 	Train loss=0.0995561626 	Valid loss=0.0074461555	Time=94.04s
Epoch 10/30 	Train loss=0.0995585641 	Valid loss=0.0074552527	Time=94.35s
Epoch 11/30 	Train loss=0.0995606384 	Valid loss=0.0074477620	Time=94.39s
Epoch 12/30 	Train loss=0.0995620016 	Valid loss=0.0074486457	Time=95.35s
Epoch

In [15]:
weights_file = 'sample-model_2023-08-16_22-22-02Z_EPOCH60'

In [16]:
weights_file = 'model_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%SZ")
# Save latest model
save_model(model, weights_file + '_LATEST', str(train_file), batch_size, n_epochs, lr, hidden_state_size, num_sequence_layers, weights, valid_loss)
# Save best model
save_model_state(best_model, weights_file + '_BEST', str(train_file), batch_size, best_model_epoch+1, lr, hidden_state_size, num_sequence_layers, weights, min_valid_loss)
weights_file += '_BEST'

In [17]:
model = BiLSTM(lstm_input_size, hidden_state_size, batch_size, output_dim, num_sequence_layers, dropoff)
model = model.to(device)
model.load_state_dict(torch.load('models/{}.pt'.format(weights_file), map_location=device))

model

BiLSTM(
  (dropout): Dropout(p=0.0, inplace=False)
  (lstm): LSTM(2496, 512, num_layers=2, batch_first=True, bidirectional=True)
  (mdn): MDN(
    (pi): Sequential(
      (0): Linear(in_features=1024, out_features=1, bias=True)
      (1): Softmax(dim=1)
    )
    (sigma): Linear(in_features=1024, out_features=1024, bias=True)
    (mu): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (linear_hidden): Linear(in_features=1024, out_features=512, bias=True)
  (energy): Linear(in_features=1536, out_features=1, bias=True)
  (softmax): Softmax(dim=0)
  (relu): ReLU()
  (fc): Linear(in_features=2048, out_features=396, bias=True)
)

### Prediction

In [18]:
model.eval()

BiLSTM(
  (dropout): Dropout(p=0.0, inplace=False)
  (lstm): LSTM(2496, 512, num_layers=2, batch_first=True, bidirectional=True)
  (mdn): MDN(
    (pi): Sequential(
      (0): Linear(in_features=1024, out_features=1, bias=True)
      (1): Softmax(dim=1)
    )
    (sigma): Linear(in_features=1024, out_features=1024, bias=True)
    (mu): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (linear_hidden): Linear(in_features=1024, out_features=512, bias=True)
  (energy): Linear(in_features=1536, out_features=1, bias=True)
  (softmax): Softmax(dim=0)
  (relu): ReLU()
  (fc): Linear(in_features=2048, out_features=396, bias=True)
)

In [19]:
test_preds = torch.zeros(output_dim * len(test))
test_actual = torch.zeros(output_dim * len(test))

for i, (mfcc, vrm_params) in enumerate(test_loader):
    mfcc = mfcc.to(device=device).reshape(-1, 1, lstm_input_size)
    if not test.audio_only:
        vrm_params = vrm_params.cpu().reshape(-1)
    y_pred = model(mfcc).cpu().detach().reshape(-1)
    test_preds[i*batch_size*output_dim:(i+1)*batch_size*output_dim] = y_pred
    if not test.audio_only:
        test_actual[i*batch_size*output_dim:(i+1)*batch_size*output_dim] = vrm_params

if not test.audio_only:
    rmse = torch.sqrt(torch.nn.functional.mse_loss(test_preds, test_actual))
    print('RMSE: {:.10f}'.format(rmse))
        
test_preds = test_preds.reshape(len(test), output_dim)
test_actual = test_actual.reshape(len(test), output_dim)

test_preds.shape

RMSE: 0.1353097111


torch.Size([12343, 396])

In [20]:
test_preds

tensor([[ 0.1788,  0.0004,  0.1114,  ...,  0.0060,  0.0003, -0.0004],
        [ 0.1788,  0.0004,  0.1114,  ...,  0.0060,  0.0003, -0.0004],
        [ 0.1788,  0.0004,  0.1114,  ...,  0.0060,  0.0003, -0.0004],
        ...,
        [ 0.1660, -0.0003,  0.1143,  ...,  0.0054,  0.0008,  0.0007],
        [ 0.1660, -0.0003,  0.1143,  ...,  0.0054,  0.0008,  0.0007],
        [ 0.1660, -0.0003,  0.1143,  ...,  0.0054,  0.0008,  0.0007]])

In [21]:
PREDICTION_DATA_PATH = './sampledata/Prediction'

def to_csv(test_preds, name='prediction'):
    test_preds = train[0].scaler.inverse_transform(test_preds)
    
    blendshape_params = pd.DataFrame(test_preds[:,:len(BLENDSHAPE_PARAMS)])
    blendshape_params.columns = BLENDSHAPE_PARAMS
    
    bone_params = pd.DataFrame(test_preds[:,len(BLENDSHAPE_PARAMS):])
    bone_params.columns = [bone + t for t in ('PosX', 'PosY', 'PosZ', 'RotX', 'RotY', 'RotZ', 'RotW') for bone in BONE_PARAMS]
    
    STATIC_FRAME = test.STATIC_FRAME
    
    time_column = pd.DataFrame({'Time': [float(i)*STATIC_FRAME for i in range(len(test_preds))]})
    
    blendshape_params = pd.concat([time_column, blendshape_params], axis=1)
    bone_params = pd.concat([time_column, bone_params], axis=1)
    
    blendshape_params.to_csv(PREDICTION_DATA_PATH + '/Blendshapes/' + name + '.csv', index=False)
    bone_params.to_csv(PREDICTION_DATA_PATH + '/Bones/' + name + '.csv', index=False)
    
    shutil.copyfile(test.DATA_PATH + '/Audio/' + test.name + '.wav', PREDICTION_DATA_PATH + '/Audio/' + name + '.wav')

datestring = datetime.now().strftime("%Y-%m-%d_%H-%M-%SZ")
to_csv(test_preds, 'prediction-' + datestring)
if not test.audio_only:
    to_csv(test_actual, 'actual-' + datestring)