In [1]:
import torch 
import argparse
import os
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from utils import midi_to_song, log_midis
from loss import kl_normal, log_bernoulli_with_logits
import logging
import torch.nn.functional as F
from dataloader import MusicDataset
from model import DVAE 
from einops import repeat, rearrange


In [10]:
device = 'cpu'

In [15]:
config = OmegaConf.load('config.yaml')

model = DVAE(input_dim=config.model.input_dim, 
                hidden_dim=config.model.hidden_dim,
                hidden_dim_em=config.model.hidden_dim_em, 
                hidden_dim_tr=config.model.hidden_dim_tr, 
                latent_dim=config.model.latent_dim,
                dropout=config.model.dropout,
                combiner_type=config.model.combiner_type,
                rnn_type=config.model.rnn_type).to(device)

dataset = MusicDataset(config.dataset, split=config.sample.split)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
#load weights
ckpt_path = config.test.ckpt_path
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [16]:
model.eval()

a = 0
b = 0
c = 0

In [31]:


with torch.no_grad():   
    for j, (encodings, sequence_lengths) in enumerate(dataloader):
        
        print(f'sequence_lengths: {sequence_lengths}')
        encodings = encodings.to(device)
        sequence_lengths = sequence_lengths.to(device)
        
        x_hat, mus_inference, sigmas_inference, mus_generator, sigmas_generators = model(encodings)
        
        #get loss with only sum over latent dim dimension
        reconstruction_loss = log_bernoulli_with_logits(encodings, x_hat, sequence_lengths, T_reduction='none') 
        kl_loss = kl_normal(mus_inference, 
                            sigmas_inference, 
                            mus_generator, 
                            sigmas_generators, 
                            sequence_lengths,
                            T_reduction='none')
        
        kl_loss = kl_loss.sum(-1) #sum over T
        reconstruction_loss = reconstruction_loss.sum(-1) #sum over T
        
        #for a: #importance sampling
        z, mu_q, var_q = model.encoder(encodings)
        bs = encodings.shape[0]
        max_sequence_length = encodings.shape[1]
        loss_s = torch.zeros(bs)
        all_exponent_args = []
        
        for _ in range(config.test.S):
            z_s = mu_q + torch.sqrt(var_q) * torch.randn_like(mu_q)
            x_hat_s, mu_p, var_p = model.decoder(z_s)

            range_tensor = repeat(torch.arange(max_sequence_length), 'l -> b l', b=bs).to(sequence_lengths.device) #shape: (batch, seq_len)
            mask = range_tensor < rearrange(sequence_lengths, 'b -> b ()')
            mask = mask.to(sequence_lengths.device)
            mask = rearrange(mask, 'b s -> b s ()') #shape : (bs, seq_len, latent_dim)
            
            #binary cross entropy
            log_s_recosntruction_loss = log_bernoulli_with_logits(encodings, x_hat_s, sequence_lengths, T_reduction='mean')
            
            #gaussian log prob p(z)
            nll_p_z = F.gaussian_nll_loss(mu_p, z_s, var_p, reduction='none')
            nll_p_z = nll_p_z * mask.float()
            log_p_z = nll_p_z.sum(-1).mean(-1) #sum over latent dim and T #final shape (batch,)
            
            #gaussian log prob q(z|x)
            nll_q_z = F.gaussian_nll_loss(mu_q, z_s, var_q, reduction='none')
            nll_q_z = nll_q_z * mask.float()
            log_q_z = nll_q_z.sum(-1).mean(-1) #sum over latent dim and T #final shape (batch,)
            
            # loss_s += torch.exp(-(log_s_recosntruction_loss + log_p_z - log_q_z))
            exponent_arg = -(log_s_recosntruction_loss + log_p_z - log_q_z)
            all_exponent_args.append(exponent_arg)
                
            
        loss_s = -torch.logsumexp(torch.stack(all_exponent_args), dim=0) + torch.log(torch.tensor(config.test.S, dtype=torch.float))
        loss_s = loss_s.mean()
        break

sequence_lengths: tensor([129,  65])


In [32]:
all_exponent_args

[tensor([-13.2727,  -9.4349]), tensor([-13.1605,  -9.4529])]

In [33]:
loss_s

tensor(11.3294)