# Audio Continuation Transformer Decoder (Inference)

## Imports

In [None]:
import numpy as np
import math
import time
import csv
import os
import re
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import torchaudio.transforms as transforms
from collections import OrderedDict

from vocos import Vocos

import IPython
from IPython.display import display
import ipywidgets as widgets

## Settings

### Compute Device

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('Using {} device'.format(device))

### Audio Settings

In [None]:
# Sample rate in Hz used for slicing excerpts and for writing audio files.
audio_sample_rate = 48_000

### Model Settings

In [None]:
# Transformer decoder hyperparameters. They are passed to the model
# constructor further down; they presumably have to match the configuration
# the checkpoint in `decoder_weights_file` was trained with.
decoder_layer_count = 6
decoder_head_count = 8
decoder_embed_dim = 512
decoder_ff_dim = 2048
decoder_dropout = 0.1

# Whether to load pretrained weights, and where to find them.
load_weights = True
decoder_weights_file = "../../../../Data/Models/AudioContinuation/TransformerDecoder/Gutenberg/weights/decoder_weights_epoch_200"

# One style definition shared by every settings widget below.
_settings_widget_style = {'description_width': 'initial'}

decoder_layer_count_gui = widgets.IntText(value=decoder_layer_count, description="Decoder Layer Count:", style=_settings_widget_style)
decoder_head_count_gui = widgets.IntText(value=decoder_head_count, description="Decoder Head Count:", style=_settings_widget_style)
decoder_embed_dim_gui = widgets.IntText(value=decoder_embed_dim, description="Decoder Embed Dim:", style=_settings_widget_style)
decoder_ff_dim_gui = widgets.IntText(value=decoder_ff_dim, description="Decoder Feedforward Dim:", style=_settings_widget_style)
decoder_dropout_gui = widgets.FloatText(value=decoder_dropout, description="Decoder Dropout:", style=_settings_widget_style)

decoder_weights_file_gui = widgets.Text(value=decoder_weights_file, description="Decoder Weights File:", style=_settings_widget_style)

# Show the widgets in the same order as the settings above.
for _settings_widget in (
    decoder_layer_count_gui,
    decoder_head_count_gui,
    decoder_embed_dim_gui,
    decoder_ff_dim_gui,
    decoder_dropout_gui,
    decoder_weights_file_gui,
):
    display(_settings_widget)

In [None]:
# Copy the current widget values back into the settings variables so that
# any edits made in the GUI take effect in the cells that follow.
(
    decoder_layer_count,
    decoder_head_count,
    decoder_embed_dim,
    decoder_ff_dim,
    decoder_dropout,
    decoder_weights_file,
) = (
    decoder_layer_count_gui.value,
    decoder_head_count_gui.value,
    decoder_embed_dim_gui.value,
    decoder_ff_dim_gui.value,
    decoder_dropout_gui.value,
    decoder_weights_file_gui.value,
)

### Training settings

In [None]:
# Number of feature frames given to the decoder as context; it also sets the
# length of the positional encoding table when the model is constructed.
seq_input_length = 64

seq_input_length_gui = widgets.IntText(
    value=seq_input_length,
    description="Sequence Input Length:",
    style={'description_width': 'initial'},
)

display(seq_input_length_gui)

In [None]:
# Read back the sequence length currently entered in the widget.
seq_input_length = seq_input_length_gui.value

## Load Vocos Model

In [None]:
# Load the pretrained Vocos vocoder. It is used both to extract audio
# features from waveforms and to decode features back into waveforms.
vocos = Vocos.from_pretrained("kittn/vocos-mel-48khz-alpha1")

# Probe the feature extractor with one second of silence to find out how many
# feature channels it produces. The probe deliberately stays on the CPU:
# `vocos` is never moved to `device`, so sending only the input to CUDA would
# put the input and the vocoder on different devices.
with torch.no_grad():
    dummy_waveform = torch.zeros((1, audio_sample_rate), dtype=torch.float32)
    dummy_mels = vocos.feature_extractor(dummy_waveform)

# Size of the feature dimension (dim 1 of the extractor output).
audio_features_dim = dummy_mels.shape[1]

## Create Models

## PositionalEncoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedding, then apply dropout.

    Modified version of the implementation in
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        dim_model: Size of the embedding (last) dimension.
        dropout_p: Dropout probability applied to the summed result.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()

        self.dropout = nn.Dropout(dropout_p)

        # Column vector of position indices: 0, 1, ..., max_len - 1
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        # 1 / 10000^(2i / dim_model) for every even feature index 2i
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)

        pos_encoding = torch.zeros(max_len, dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i / dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)

        # Stored with shape (max_len, 1, dim_model) as a buffer: it follows the
        # module across devices and is saved in the state dict, but it is not
        # a trainable parameter. Name and shape must stay as they are so that
        # existing checkpoints keep loading.
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(token_embedding + encoding)``.

        The encoding table is sliced along its first dimension using
        ``token_embedding.size(0)`` and broadcast over dimension 1.
        """
        # NOTE(review): the table is indexed by dim 0 of the input. With
        # batch-first input, as the decoder in this file passes it, dim 0 is
        # the batch axis rather than the time axis. This is left unchanged
        # because the checkpoint was presumably trained with the same
        # behaviour - confirm against the training code before changing it.
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

## Create TransformerDecoder

In [None]:
class TransformerDecoder(nn.Module):
    """Causal transformer that predicts the next audio feature frame.

    A sequence of audio feature frames is projected into the embedding space,
    combined with positional encodings, passed through a stack of
    ``nn.TransformerDecoderLayer`` modules under a causal mask, and projected
    back to the audio feature space. Only the prediction for the last time
    step is returned.

    Args:
        audio_dim: Size of one audio feature frame.
        embed_dim: Embedding size used inside the transformer.
        num_heads: Number of attention heads per layer.
        num_decoder_layers: Number of decoder layers.
        ff_dim: Hidden size of the feed-forward sub-layers.
        dropout_p: Dropout probability.
        pos_encoding_max_length: Number of positions in the encoding table.
    """

    def __init__(
        self,
        audio_dim,
        embed_dim,
        num_heads,
        num_decoder_layers,
        ff_dim,
        dropout_p,
        pos_encoding_max_length
    ):
        super().__init__()

        self.embed_dim = embed_dim

        # map audio data to embedding
        self.audio2embed = nn.Linear(audio_dim, embed_dim)

        self.positional_encoder = PositionalEncoding(
            dim_model=embed_dim, dropout_p=dropout_p, max_len=pos_encoding_max_length
        )

        # Not used in forward(). It is kept because it is a registered
        # submodule and therefore contributes keys to the state dict; removing
        # it would break loading of checkpoints saved with this class.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout_p, batch_first=True)

        # Build the decoder stack directly from TransformerDecoderLayer
        # modules instead of nn.TransformerDecoder. The same ModuleList is
        # registered under both names ("decoder" and "layers"); both are kept
        # so that the state dict keys stay unchanged.
        self.decoder = self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
                dropout=dropout_p,
                activation='gelu',
                batch_first=True
            ) for _ in range(num_decoder_layers)
        ])

        # map embedding to audio data
        self.embed2audio = nn.Linear(embed_dim, audio_dim)

    def get_tgt_mask(self, size, device=None) -> torch.Tensor:
        """Return an additive causal attention mask of shape ``(size, size)``.

        Entry ``[i, j]`` is ``0.0`` where position ``i`` may attend to
        position ``j`` (``j <= i``) and ``-inf`` otherwise.

        Args:
            size: Sequence length.
            device: Optional device on which to create the mask; when omitted
                the default device is used.

        Example for ``size=5``::

            [[0., -inf, -inf, -inf, -inf],
             [0.,   0., -inf, -inf, -inf],
             [0.,   0.,   0., -inf, -inf],
             [0.,   0.,   0.,   0., -inf],
             [0.,   0.,   0.,   0.,   0.]]
        """
        # True on and below the diagonal: the positions that stay visible.
        visible = torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))
        mask = torch.zeros(size, size, device=device)
        return mask.masked_fill(~visible, float('-inf'))

    def forward(self, audio_data):
        """Predict the feature frame that follows ``audio_data``.

        Args:
            audio_data: Tensor of shape ``(batch, seq_len, audio_dim)``.

        Returns:
            Tensor of shape ``(batch, audio_dim)``: the output at the last
            time step.
        """
        batch_size, seq_len = audio_data.size(0), audio_data.size(1)

        # The decoder layers require a "memory" input for cross-attention.
        # There is no encoder here, so an all-zero memory is supplied.
        memory = torch.zeros(batch_size, seq_len, self.embed_dim, device=audio_data.device)

        # Causal mask, created directly on the input's device.
        tgt_mask = self.get_tgt_mask(seq_len, device=audio_data.device)

        # Scale the embedding by sqrt(embed_dim) before adding positions.
        x = self.audio2embed(audio_data) * math.sqrt(self.embed_dim)
        x = self.positional_encoder(x)

        for layer in self.layers:
            x = layer(x, memory, tgt_mask=tgt_mask)

        out = self.embed2audio(x)

        # only last time step
        return out[:, -1, :]

# Build the decoder from the settings chosen above and move it to the
# compute device.
decoder = TransformerDecoder(
    audio_dim=audio_features_dim,
    embed_dim=decoder_embed_dim,
    num_heads=decoder_head_count,
    num_decoder_layers=decoder_layer_count,
    ff_dim=decoder_ff_dim,
    dropout_p=decoder_dropout,
    pos_encoding_max_length=seq_input_length,
).to(device)

print(decoder)

if load_weights:
    # map_location remaps the checkpoint tensors onto the active device, so
    # the same call works on CPU and GPU, and also for checkpoints that were
    # saved from a different device than the one in use now.
    decoder.load_state_dict(torch.load(decoder_weights_file, map_location=device))


## Inference

In [None]:
# Put the decoder into evaluation mode before running inference.
decoder.eval()

def export_orig_audio(waveform_data, start_time, end_time, file_name):
    """Save the unprocessed excerpt of ``waveform_data`` to ``file_name``.

    ``start_time`` and ``end_time`` are given in seconds and converted to
    sample indices with ``audio_sample_rate``. The excerpt is taken along
    dimension 1 of ``waveform_data``.
    """
    first_sample, last_sample = (
        int(seconds * audio_sample_rate) for seconds in (start_time, end_time)
    )
    excerpt = waveform_data[:, first_sample:last_sample]

    torchaudio.save(file_name, excerpt, audio_sample_rate)

def export_ref_audio(waveform_data, start_time, end_time, file_name):
    """Save a vocoder round-trip of an excerpt as a reference recording.

    The excerpt between ``start_time`` and ``end_time`` (seconds) is converted
    to audio features with the Vocos feature extractor and decoded straight
    back to a waveform, which is then written to ``file_name``.
    """
    start_time_samples = int(start_time * audio_sample_rate)
    end_time_samples = int(end_time * audio_sample_rate)

    # Inference only: no_grad avoids building an autograd graph that would be
    # thrown away immediately afterwards.
    with torch.no_grad():
        audio_features = vocos.feature_extractor(waveform_data[:, start_time_samples:end_time_samples])
        ref_audio = vocos.decode(audio_features)

    torchaudio.save(file_name, ref_audio.detach().cpu(), audio_sample_rate)
    

def export_pred_audio(waveform_data, start_time, end_time, file_name):
    """Save an autoregressive continuation generated by the decoder.

    The first ``seq_input_length`` feature frames of the excerpt between
    ``start_time`` and ``end_time`` (seconds) seed the decoder. The decoder
    then predicts one frame at a time; each prediction is appended to a
    sliding input window from which the oldest frame is dropped. One frame is
    predicted for every excerpt frame that follows the seed. The predicted
    frames (without the seed) are decoded with Vocos and written to
    ``file_name``.

    Raises:
        ValueError: If the excerpt does not contain more than
            ``seq_input_length`` feature frames, so nothing could be predicted.
    """
    start_time_samples = int(start_time * audio_sample_rate)
    end_time_samples = int(end_time * audio_sample_rate)

    # Inference only: run everything without autograd bookkeeping.
    with torch.no_grad():

        audio_features = vocos.feature_extractor(waveform_data[:, start_time_samples:end_time_samples])

        # (1, feature_dim, frames) -> (frames, feature_dim)
        # NOTE(review): assumes a single-channel waveform - confirm for stereo files.
        audio_features = audio_features.squeeze(0)
        audio_features = torch.permute(audio_features, (1, 0))
        audio_feature_count = audio_features.shape[0]

        output_features_length = audio_feature_count - seq_input_length

        if output_features_length < 1:
            raise ValueError(
                "Excerpt yields {} feature frames but more than {} are required "
                "to generate a continuation".format(audio_feature_count, seq_input_length)
            )

        # Seed window with a leading batch dimension: (1, seq_input_length, feature_dim)
        _input_features = audio_features[:seq_input_length].unsqueeze(0).to(device)
        pred_features = []

        # Generate exactly output_features_length frames.
        for _ in range(output_features_length):

            # (1, feature_dim) -> (1, 1, feature_dim)
            _pred_features = decoder(_input_features).unsqueeze(1)

            pred_features.append(_pred_features.cpu())

            # Slide the window: drop the oldest frame, append the prediction.
            _input_features = torch.cat((_input_features[:, 1:, :], _pred_features), dim=1)

        # (1, frames, feature_dim) -> (1, feature_dim, frames), the layout the
        # feature extractor produced and the vocoder is given elsewhere.
        pred_features = torch.cat(pred_features, dim=1)
        pred_features = torch.permute(pred_features, (0, 2, 1))
        pred_audio = vocos.decode(pred_features)

    torchaudio.save(file_name, pred_audio.detach().cpu(), audio_sample_rate)

### Perform Audio Continuation

In [None]:
# Source recording and the excerpt (in seconds) to run the continuation on.
audio_file = "../../../../Data/Audio/Gutenberg/Night_and_Day_by_Virginia_Woolf_48khz.wav"
audio_start_time_sec = 60.0
audio_end_time_sec = 70.0

audio_file_gui = widgets.Text(
    value=audio_file,
    description="Audio File:",
    style={'description_width': 'initial'},
)
audio_start_time_sec_gui = widgets.FloatText(
    value=audio_start_time_sec,
    description="Audio Start Time [Seconds]:",
    style={'description_width': 'initial'},
)
audio_end_time_sec_gui = widgets.FloatText(
    value=audio_end_time_sec,
    description="Audio End Time [Seconds]",
    style={'description_width': 'initial'},
)

# Show the widgets in the order file, start time, end time.
for _excerpt_widget in (audio_file_gui, audio_start_time_sec_gui, audio_end_time_sec_gui):
    display(_excerpt_widget)

In [None]:
# Copy the current widget values back into the excerpt settings.
audio_file, audio_start_time_sec, audio_end_time_sec = (
    audio_file_gui.value,
    audio_start_time_sec_gui.value,
    audio_end_time_sec_gui.value,
)

In [None]:
waveform_data, file_sample_rate = torchaudio.load(audio_file)

# Every time-to-sample conversion in the export helpers uses
# audio_sample_rate, so bring the recording to that rate if the file differs.
if file_sample_rate != audio_sample_rate:
    waveform_data = transforms.Resample(orig_freq=file_sample_rate, new_freq=audio_sample_rate)(waveform_data)

# Make sure the output directory exists before anything is written to it.
results_dir = "results/audio"
os.makedirs(results_dir, exist_ok=True)

# Original excerpt, vocoder round-trip reference, and decoder continuation.
export_orig_audio(waveform_data, audio_start_time_sec, audio_end_time_sec, "{}/orig_{}-{}.wav".format(results_dir, audio_start_time_sec, audio_end_time_sec))
export_ref_audio(waveform_data, audio_start_time_sec, audio_end_time_sec, "{}/ref_{}-{}.wav".format(results_dir, audio_start_time_sec, audio_end_time_sec))
export_pred_audio(waveform_data, audio_start_time_sec, audio_end_time_sec, "{}/pred_{}-{}.wav".format(results_dir, audio_start_time_sec, audio_end_time_sec))