# Inference demo

## Step 1. Load a pretrained model

In [56]:
import os
import pickle
import torch
import torchaudio
import numpy as np
from torch import nn
from model import MyModel
from transformers import DistilBertTokenizer

# load checkpoint
# The checkpoint keeps its weights under the 'state_dict' key; map_location
# places the loaded tensors on DEVICE.
DATA_PATH = './../../data/'
DEVICE = 'cpu'
S = torch.load(os.path.join(DATA_PATH, 'pretrained/alm_cross.ckpt'), map_location=torch.device(DEVICE))['state_dict']

# Keep only the entries stored under the 'model.' prefix and strip that prefix
# so the keys can be handed to MyModel.load_state_dict. Matching the full
# prefix (dot included) keeps the filter and the strip length in agreement, so
# a key such as 'model_x.w' is skipped instead of being mangled into 'x.w'.
STATE_DICT_PREFIX = 'model.'
NS = {k[len(STATE_DICT_PREFIX):]: v for k, v in S.items() if k.startswith(STATE_DICT_PREFIX)}

# load model
model = MyModel()
model.load_state_dict(NS)
model = model.eval()  # evaluation mode: this notebook only runs inference

# load word2vec
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only point this at a w2v.pkl obtained from a trusted source.
# The context manager closes the file handle as soon as loading finishes.
with open(os.path.join(DATA_PATH, 'w2v.pkl'), 'rb') as w2v_file:
    word2vec = pickle.load(w2v_file)

# load tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

## Step 2. Map emotion tags to the shared embedding space

In [57]:
# Emotion tags to project into the shared space; extend the list with any other
# tags of interest (each tag is looked up in word2vec below).
emotions = ['angry', 'scary', 'happy', 'sad', 'tender']

# One word2vec vector per tag, gathered into a single tensor.
tag_vectors = [word2vec[tag] for tag in emotions]
emotion_w2v = torch.tensor(tag_vectors)

# Inference only, so no gradient tracking is needed.
with torch.no_grad():
    projected_tags = model.tag_to_embedding(emotion_w2v)
emotion_embeddings = projected_tags.detach().cpu()

In [58]:
# Sanity check: the recorded run shows torch.Size([5, 64]) for the five tags above.
print(emotion_embeddings.size())

torch.Size([5, 64])


## Step 3. Map text to the shared embedding space

In [72]:
text = 'I am super happy today!'

# The sentence is submitted twice so the batch holds two rows (the author's
# workaround for a batch-normalization issue); only row 0 is kept afterwards.
tokens = tokenizer(
    [text] * 2,
    return_tensors='pt',
    padding=True,
    truncation=True,
)

# Inference only, so no gradient tracking is needed.
with torch.no_grad():
    batch_embeddings = model.text_to_embedding(tokens['input_ids'], tokens['attention_mask'])
text_embedding = batch_embeddings[0].detach().cpu()

In [60]:
# Sanity check: the recorded run shows torch.Size([64]) for the single sentence.
print(text_embedding.size())

torch.Size([64])


## Step 4. Map music to the shared embedding space

In [61]:
INPUT_LENGTH = 80000  # samples per chunk handed to the model
NUM_CHUNKS = 8        # number of chunks taken from the song
get_spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, f_min=0.0, f_max=8000.0, n_mels=128)

# load audio
song = np.zeros((1, 16000 * 30)).astype('float32') # an example of 30-second of audio, shape (1, num_samples)

# get multiple chunks
# `song` is (1, num_samples), so len(song) is 1 and slicing `song[a:b]` walks
# the singleton leading axis, which returns the whole waveform for every chunk.
# Work on the 1-D sample axis instead so each chunk really is INPUT_LENGTH long.
waveform = song[0]
num_samples = waveform.shape[-1]
if num_samples < INPUT_LENGTH:
    raise ValueError(
        f'audio has {num_samples} samples but at least {INPUT_LENGTH} are required'
    )
hop = (num_samples - INPUT_LENGTH) // NUM_CHUNKS
# Chunk i starts at i * hop; the last chunk ends at or before num_samples.
chunks = np.stack([waveform[i * hop:i * hop + INPUT_LENGTH] for i in range(NUM_CHUNKS)])
# np.stack + from_numpy avoids the slow list-of-ndarrays path of torch.tensor.
song = torch.from_numpy(chunks)  # (NUM_CHUNKS, INPUT_LENGTH)
with torch.no_grad():
    spec = get_spec(song)
    # Average the per-chunk embeddings into one embedding for the whole song.
    song_embedding = model.spec_to_embedding(spec).detach().cpu().mean(dim=0)

In [62]:
# Sanity check: the recorded run shows torch.Size([64]) after averaging over chunks.
print(song_embedding.size())

torch.Size([64])
