In [1]:
import torch 
import argparse
import os
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from utils import midi_to_song, log_midis
from loss import kl_normal, log_bernoulli_with_logits
import logging
import torch.nn.functional as F
from dataloader import MusicDataset
from model import DVAE 
from einops import repeat, rearrange


In [2]:
device = 'cpu'

In [3]:
config = OmegaConf.load('config.yaml')

model = DVAE(input_dim=config.model.input_dim, 
                hidden_dim=config.model.hidden_dim,
                hidden_dim_em=config.model.hidden_dim_em, 
                hidden_dim_tr=config.model.hidden_dim_tr, 
                latent_dim=config.model.latent_dim,
                dropout=config.model.dropout,
                combiner_type=config.model.combiner_type,
                rnn_type=config.model.rnn_type).to(device)

dataset = MusicDataset(config.dataset, split=config.sample.split)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
#load weights
ckpt_path = config.test.ckpt_path
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [4]:
model.eval()

a = 0
b = 0
c = 0

In [63]:


with torch.no_grad():   
    for j, (encodings, sequence_lengths) in enumerate(dataloader):
        
        print(f'sequence_lengths: {sequence_lengths}')
        encodings = encodings.to(device)
        sequence_lengths = sequence_lengths.to(device)
        
        x_hat, mus_inference, sigmas_inference, mus_generator, sigmas_generators = model(encodings)
        
        #get loss with only sum over latent dim dimension
        reconstruction_loss = log_bernoulli_with_logits(encodings, x_hat, sequence_lengths, T_reduction='none') 
        kl_loss = kl_normal(mus_inference, 
                            sigmas_inference, 
                            mus_generator, 
                            sigmas_generators, 
                            sequence_lengths,
                            T_reduction='none')
        
        kl_loss = kl_loss.mean(-1) #sum over T
        reconstruction_loss = reconstruction_loss.mean(-1) #sum over T
        
        #for a: #importance sampling
        z, mu_q, var_q = model.encoder(encodings)
        bs = encodings.shape[0]
        max_sequence_length = encodings.shape[1]
        loss_s = torch.zeros(bs)
        all_exponent_args = []
        break
        

sequence_lengths: tensor([129,  65])


In [69]:
sequence_lengths.shape[0]

2

In [64]:
reconstruction_loss = reconstruction_loss[0]
kl_loss = kl_loss[0]

print(f'reconstruction_loss: {reconstruction_loss}')
print(f'kl_loss: {kl_loss}')

reconstruction_loss: 10.08811092376709
kl_loss: 1.011856198310852


In [65]:
nelbo_matrix = reconstruction_loss + kl_loss
nelbo_matrix

tensor(11.1000)

In [47]:
# nelbo_matrix = nelbo_matrix.sum(-1) #sum over batch_size
# nelbo_matrix

tensor(2494.1921)

In [66]:
sequence_lengths_sum = sequence_lengths[0]
sequence_lengths_sum
nelbo_b = nelbo_matrix / sequence_lengths_sum


In [67]:
nelbo_b

tensor(0.0860)

In [42]:
#for b:
nelbo_matrix = reconstruction_loss + kl_loss
nelbo_matrix = nelbo_matrix.sum(-1) #sum over batch_size
sequence_lengths_sum = sequence_lengths.sum(-1)
nelbo_b = nelbo_matrix / sequence_lengths_sum
b += nelbo_b

In [43]:
b

tensor(12.8567)

In [71]:
#define an rnn with 2 layers
rnn = torch.nn.RNN(input_size=88, hidden_size=600, num_layers=2, batch_first=True)

In [72]:
z = torch.randn(2, 10, 88)

In [74]:
out, hidden = rnn(z)

In [75]:
out.shape

torch.Size([2, 10, 600])

In [77]:
a = torch.randn(2, 10, 88)

In [78]:
from torch.nn.utils.rnn import pad_sequence

In [80]:
pad_sequence(a).shape

torch.Size([10, 2, 88])

In [5]:
for i, (encodings, sequence_lengths) in enumerate(dataloader):
    print(f'encodings: {encodings.shape}')
    print(f'sequence_lengths: {sequence_lengths.shape}')

encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Size([2, 129, 88])
sequence_lengths: torch.Size([2])
encodings: torch.Siz

In [114]:
a = dataset.read_pickle_from_url('data/jsb_chorales.pickle')['test'][0] #84

In [115]:
import numpy as np

In [134]:
split_data = [a]
all_music_one_hot_list = []
note_range = 88
min_note = 21
sequence_lengths = []

In [135]:
for music in split_data:
    one_hot_matrix = np.zeros((len(music), note_range), dtype=int)
    
    for row_index, keys in enumerate(music):

        for note in keys:
            one_hot_matrix[row_index, note - min_note] = 1  
            
    all_music_one_hot_list.append(one_hot_matrix)
    sequence_lengths.append(len(music))

In [130]:
sample = all_music_one_hot_list

In [129]:
length = len(sample)

print(f'sample: {sample}')
print(f'length: {length}')


sample: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
length: 88


In [6]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

[array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]])]

In [9]:
x_packed[1].shape

torch.Size([65])

In [23]:
x_packed = pack_padded_sequence(encodings, sequence_lengths, batch_first=True, enforce_sorted=False)

In [24]:
rnn = torch.nn.RNN(input_size=88, hidden_size=600, num_layers=1, batch_first=True)

In [27]:
out = rnn(x_packed)

In [31]:
torch.tensor(34).unsqueeze(0)

tensor([34])

In [26]:
out, _ = pad_packed_sequence(out, batch_first=True)

AttributeError: 'tuple' object has no attribute 'batch_sizes'

In [11]:
out.shape

torch.Size([1, 65, 88])

In [5]:
for i, (encodings, sequence_lengths) in enumerate(dataloader):
    print(f'encodings: {encodings.shape}')
    print(f'sequence_lengths: {sequence_lengths.shape}')
    break

encodings shape: torch.Size([1, 129, 88])
encodings shape: torch.Size([1, 65, 88])


RuntimeError: stack expects each tensor to be equal size, but got [1, 129, 88] at entry 0 and [1, 65, 88] at entry 1