In [34]:
import torch 
import argparse
import os
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from utils import midi_to_song, log_midis
from loss import kl_normal, log_bernoulli_with_logits
import logging
import torch.nn.functional as F
from dataloader import MusicDataset
from model import DVAE 
from einops import repeat, rearrange


In [35]:
device = 'cpu'

In [36]:
config = OmegaConf.load('config.yaml')

model = DVAE(input_dim=config.model.input_dim, 
                hidden_dim=config.model.hidden_dim,
                hidden_dim_em=config.model.hidden_dim_em, 
                hidden_dim_tr=config.model.hidden_dim_tr, 
                latent_dim=config.model.latent_dim,
                dropout=config.model.dropout,
                combiner_type=config.model.combiner_type,
                rnn_type=config.model.rnn_type).to(device)

dataset = MusicDataset(config.dataset, split=config.sample.split)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
#load weights
ckpt_path = config.test.ckpt_path
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [37]:
model.eval()

a = 0
b = 0
c = 0

In [63]:


with torch.no_grad():   
    for j, (encodings, sequence_lengths) in enumerate(dataloader):
        
        print(f'sequence_lengths: {sequence_lengths}')
        encodings = encodings.to(device)
        sequence_lengths = sequence_lengths.to(device)
        
        x_hat, mus_inference, sigmas_inference, mus_generator, sigmas_generators = model(encodings)
        
        #get loss with only sum over latent dim dimension
        reconstruction_loss = log_bernoulli_with_logits(encodings, x_hat, sequence_lengths, T_reduction='none') 
        kl_loss = kl_normal(mus_inference, 
                            sigmas_inference, 
                            mus_generator, 
                            sigmas_generators, 
                            sequence_lengths,
                            T_reduction='none')
        
        kl_loss = kl_loss.mean(-1) #sum over T
        reconstruction_loss = reconstruction_loss.mean(-1) #sum over T
        
        #for a: #importance sampling
        z, mu_q, var_q = model.encoder(encodings)
        bs = encodings.shape[0]
        max_sequence_length = encodings.shape[1]
        loss_s = torch.zeros(bs)
        all_exponent_args = []
        break
        

sequence_lengths: tensor([129,  65])


In [69]:
sequence_lengths.shape[0]

2

In [64]:
reconstruction_loss = reconstruction_loss[0]
kl_loss = kl_loss[0]

print(f'reconstruction_loss: {reconstruction_loss}')
print(f'kl_loss: {kl_loss}')

reconstruction_loss: 10.08811092376709
kl_loss: 1.011856198310852


In [65]:
nelbo_matrix = reconstruction_loss + kl_loss
nelbo_matrix

tensor(11.1000)

In [47]:
# nelbo_matrix = nelbo_matrix.sum(-1) #sum over batch_size
# nelbo_matrix

tensor(2494.1921)

In [66]:
sequence_lengths_sum = sequence_lengths[0]
sequence_lengths_sum
nelbo_b = nelbo_matrix / sequence_lengths_sum


In [67]:
nelbo_b

tensor(0.0860)

In [42]:
#for b:
nelbo_matrix = reconstruction_loss + kl_loss
nelbo_matrix = nelbo_matrix.sum(-1) #sum over batch_size
sequence_lengths_sum = sequence_lengths.sum(-1)
nelbo_b = nelbo_matrix / sequence_lengths_sum
b += nelbo_b

In [43]:
b

tensor(12.8567)