# Audio Continuation LSTM (Inference)

## Imports

In [1]:
import numpy as np
import time
import csv
import os
import re
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import torchaudio.transforms as transforms
from collections import OrderedDict

from vocos import Vocos

import IPython
from IPython.display import display
import ipywidgets as widgets



## Settings

### Compute Device

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('Using {} device'.format(device))

Using cpu device


### Audio Settings

In [3]:
# Sample rate in Hz used to convert seconds to sample indices and to write the exported files.
audio_sample_rate = 48_000

### Model Settings

In [4]:
# Default LSTM architecture and checkpoint path. These values only pre-fill the widgets;
# the values actually used are read back from the widgets in the next cell.
rnn_layer_count = 2
rnn_layer_dim = 512
rnn_weights_file = "../../../../Data/Models/AudioContinuation/RNN/Gutenberg/weights/rnn_weights_epoch_200"

# 'initial' description width keeps the long labels from being truncated.
rnn_layer_count_gui = widgets.IntText(
    value=rnn_layer_count,
    description="LSTM Layer Count:",
    style={'description_width': 'initial'},
)
rnn_layer_dim_gui = widgets.IntText(
    value=rnn_layer_dim,
    description="LSTM Layer Dim:",
    style={'description_width': 'initial'},
)
rnn_weights_file_gui = widgets.Text(
    value=rnn_weights_file,
    description="RNN Weights File:",
    style={'description_width': 'initial'},
)

display(rnn_layer_count_gui, rnn_layer_dim_gui, rnn_weights_file_gui)

IntText(value=2, description='LSTM Layer Count:', style=DescriptionStyle(description_width='initial'))

IntText(value=512, description='LSTM Layer Dim:', style=DescriptionStyle(description_width='initial'))

Text(value='../../../../Data/Models/AudioContinuation/RNN/Gutenberg/weights/rnn_weights_epoch_200', descriptio…

In [5]:
# Pick up whatever is currently entered in the model-settings widgets.
rnn_layer_count, rnn_layer_dim, rnn_weights_file = (
    rnn_layer_count_gui.value,
    rnn_layer_dim_gui.value,
    rnn_weights_file_gui.value,
)

### Sequence Settings

In [6]:
# Default number of feature frames given to the LSTM as context; editable through the widget.
seq_input_length = 64

seq_input_length_gui = widgets.IntText(
    value=seq_input_length,
    description="Sequence Input Length:",
    style={'description_width': 'initial'},
)

display(seq_input_length_gui)

IntText(value=64, description='Sequence Input Length:', style=DescriptionStyle(description_width='initial'))

In [7]:
# Read the context length back from the widget so that edits made in the GUI take effect.
seq_input_length = seq_input_length_gui.value

## Load Vocos Model

In [8]:
# Pretrained Vocos vocoder. It is never moved to `device`, and the export helpers in
# this notebook pass it CPU tensors, so every tensor handed to it must live on the CPU.
vocos = Vocos.from_pretrained("kittn/vocos-mel-48khz-alpha1")

# Probe the feature extractor with one second of silence to find the feature size.
# The probe stays on the CPU: moving it to `device` while `vocos` stays on the CPU
# causes a device mismatch on CUDA machines.
dummy_waveform = torch.zeros((1, audio_sample_rate), dtype=torch.float32)
with torch.no_grad():  # shape probe only, no gradients needed
    dummy_mels = vocos.feature_extractor(dummy_waveform)

# Dimension 1 of the extracted features is used as the per-frame feature size.
audio_features_dim = dummy_mels.shape[1]

## Create Recurrent Model

In [9]:
class Reccurent(nn.Module):
    """Stacked LSTM followed by a linear layer that predicts one frame per input sequence.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Hidden size of every LSTM layer.
        output_dim: Size of the predicted feature vector.
        layer_count: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, layer_count):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.layer_count = layer_count
        self.output_dim = output_dim

        # The submodule names ("rnn", "dense") are part of the state-dict keys,
        # so they must stay as they are for saved weights to load.
        lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.layer_count, batch_first=True)
        self.rnn_layers = nn.Sequential(OrderedDict([("rnn", lstm)]))

        projection = nn.Linear(self.hidden_dim, self.output_dim)
        self.dense_layers = nn.Sequential(OrderedDict([("dense", projection)]))

    def forward(self, x):
        """Map a batch of sequences (batch, time, input_dim) to (batch, output_dim)."""
        # The LSTM returns (outputs, (h_n, c_n)); only the per-step outputs are used.
        hidden_sequence, _final_state = self.rnn_layers(x)

        # Only the last time step feeds the prediction head.
        last_hidden = hidden_sequence[:, -1, :]
        return self.dense_layers(last_hidden)

# The model predicts the next feature frame, so input and output sizes are the same.
rnn = Reccurent(audio_features_dim, rnn_layer_dim, audio_features_dim, rnn_layer_count).to(device)

print(rnn)

# map_location remaps the stored tensors onto the active device, so one call covers
# both the CPU and the CUDA case (including GPU-trained weights on a CPU-only machine).
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
rnn.load_state_dict(torch.load(rnn_weights_file, map_location=device))

Reccurent(
  (rnn_layers): Sequential(
    (rnn): LSTM(128, 512, num_layers=2, batch_first=True)
  )
  (dense_layers): Sequential(
    (dense): Linear(in_features=512, out_features=128, bias=True)
  )
)


## Inference

In [10]:
# Put the model into evaluation mode before the export helpers below run inference with it.
rnn.eval()

def export_orig_audio(waveform_data, start_time, end_time, file_name):
    """Save the untouched excerpt [start_time, end_time) of a waveform as an audio file.

    Args:
        waveform_data: Waveform tensor indexed as (channel, sample).
        start_time: Excerpt start in seconds.
        end_time: Excerpt end in seconds.
        file_name: Path of the audio file to write.
    """
    # Convert both times from seconds to sample indices at the notebook's sample rate.
    first_sample, last_sample = (int(t * audio_sample_rate) for t in (start_time, end_time))

    excerpt = waveform_data[:, first_sample:last_sample]
    torchaudio.save(file_name, excerpt, audio_sample_rate)

def export_ref_audio(waveform_data, start_time, end_time, file_name):
    """Save a Vocos encode/decode round trip of an excerpt as the reference audio.

    The excerpt is converted to Vocos features and decoded straight back to a waveform,
    with no model prediction in between.

    Args:
        waveform_data: Waveform tensor indexed as (channel, sample).
        start_time: Excerpt start in seconds.
        end_time: Excerpt end in seconds.
        file_name: Path of the audio file to write.
    """
    start_time_samples = int(start_time * audio_sample_rate)
    end_time_samples = int(end_time * audio_sample_rate)

    # Pure inference: skip autograd so no graph is built for the vocoder pass.
    with torch.no_grad():
        # audio features
        audio_features = vocos.feature_extractor(waveform_data[:, start_time_samples:end_time_samples])

        ref_audio = vocos.decode(audio_features)

    torchaudio.save(file_name, ref_audio.detach().cpu(), audio_sample_rate)
    
def export_pred_audio(waveform_data, start_time, end_time, file_name):
    """Continue an excerpt autoregressively with the LSTM and save the predicted audio.

    The first `seq_input_length` feature frames of the excerpt seed the model. Each
    predicted frame is appended to a sliding input window (dropping the oldest frame)
    and fed back in, until as many frames were predicted as the excerpt has beyond the
    seed. Only the predicted frames are decoded and written; the seed is not included.

    Args:
        waveform_data: Waveform tensor indexed as (channel, sample).
        start_time: Excerpt start in seconds.
        end_time: Excerpt end in seconds.
        file_name: Path of the audio file to write.

    Raises:
        ValueError: If the excerpt yields no frames beyond the seed window.
    """
    start_time_samples = int(start_time * audio_sample_rate)
    end_time_samples = int(end_time * audio_sample_rate)

    # Pure inference: nothing below needs gradients.
    with torch.no_grad():
        # audio features
        audio_features = vocos.feature_extractor(waveform_data[:, start_time_samples:end_time_samples])

        # (1, features, frames) -> (frames, features)
        # NOTE(review): assumes a single-channel waveform; with more channels
        # squeeze(0) is a no-op and the permute below fails. Confirm the input.
        audio_features = audio_features.squeeze(0)
        audio_features = torch.permute(audio_features, (1, 0))
        audio_feature_count = audio_features.shape[0]

        output_features_length = audio_feature_count - seq_input_length
        if output_features_length < 1:
            raise ValueError(
                "Audio segment yields {} feature frames but more than {} are required "
                "(seq_input_length seed frames plus at least one frame to predict)".format(
                    audio_feature_count, seq_input_length))

        # Seed window, shaped (1, seq_input_length, features), on the model's device.
        _input_features = audio_features[:seq_input_length].unsqueeze(0).to(device)
        pred_features = []

        # Predict exactly output_features_length frames (the previous range(1, N)
        # stopped one frame short).
        for _ in range(output_features_length):
            # (1, features) -> (1, 1, features) so it can be appended along time.
            _pred_features = rnn(_input_features).unsqueeze(1)

            pred_features.append(_pred_features.cpu())

            # Slide the window: drop the oldest frame, append the new prediction.
            _input_features = torch.cat((_input_features[:, 1:, :], _pred_features), dim=1)

        # (1, frames, features) -> (1, features, frames), the layout the feature
        # extractor produced above.
        pred_features = torch.cat(pred_features, dim=1)
        pred_features = torch.permute(pred_features, (0, 2, 1))
        pred_audio = vocos.decode(pred_features)

    torchaudio.save(file_name, pred_audio.detach().cpu(), audio_sample_rate)

### Perform Audio Continuation

In [11]:
# Default source recording and excerpt bounds (in seconds). These values only pre-fill
# the widgets; the values actually used are read back from the widgets in the next cell.
audio_file = "../../../../Data/Audio/Gutenberg/Night_and_Day_by_Virginia_Woolf_48khz.wav"
audio_start_time_sec = 60.0
audio_end_time_sec = 80.0

audio_file_gui = widgets.Text(
    value=audio_file,
    description="Audio File:",
    style={'description_width': 'initial'},
)
audio_start_time_sec_gui = widgets.FloatText(
    value=audio_start_time_sec,
    description="Audio Start Time [Seconds]:",
    style={'description_width': 'initial'},
)
audio_end_time_sec_gui = widgets.FloatText(
    value=audio_end_time_sec,
    description="Audio End Time [Seconds]",
    style={'description_width': 'initial'},
)

display(audio_file_gui, audio_start_time_sec_gui, audio_end_time_sec_gui)

Text(value='../../../../Data/Audio/Gutenberg/Night_and_Day_by_Virginia_Woolf_48khz.wav', description='Audio Fi…

FloatText(value=60.0, description='Audio Start Time [Seconds]:', style=DescriptionStyle(description_width='ini…

FloatText(value=80.0, description='Audio End Time [Seconds]', style=DescriptionStyle(description_width='initia…

In [12]:
# Pick up whatever is currently entered in the audio-selection widgets.
audio_file, audio_start_time_sec, audio_end_time_sec = (
    audio_file_gui.value,
    audio_start_time_sec_gui.value,
    audio_end_time_sec_gui.value,
)

In [14]:
waveform_data, loaded_sample_rate = torchaudio.load(audio_file)

# The export helpers convert seconds to sample indices with audio_sample_rate and write
# files at that rate, so bring the recording to that rate if it was stored at another one.
if loaded_sample_rate != audio_sample_rate:
    waveform_data = torchaudio.functional.resample(waveform_data, loaded_sample_rate, audio_sample_rate)

# torchaudio.save does not create missing directories.
os.makedirs("results/audio", exist_ok=True)

# Shared "<start>-<end>" suffix so the three exports of one excerpt are easy to match up.
audio_time_span = "{}-{}".format(audio_start_time_sec, audio_end_time_sec)

export_orig_audio(waveform_data, audio_start_time_sec, audio_end_time_sec, "results/audio/orig_{}.wav".format(audio_time_span))
export_ref_audio(waveform_data, audio_start_time_sec, audio_end_time_sec, "results/audio/ref_{}.wav".format(audio_time_span))
export_pred_audio(waveform_data, audio_start_time_sec, audio_end_time_sec, "results/audio/pred_{}.wav".format(audio_time_span))