In [1]:

import torch
import torch.nn.functional as F

In [2]:
a = torch.randn(1, 129, 88)
input_size = a.shape[2]
input_size = 88
hidden_size = 400

rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, bidirectional=True, batch_first=True)
out, hidden = rnn(a) # out: [1, 129, 800], hidden: [2, 1, 400]
print(f'out.shape: {out.shape}')
print(f'hidden.shape: {hidden.shape}')
h_left = out[:, :, :hidden_size]
h_right = out[:, :, hidden_size:]

print(f'h_left.shape: {h_left.shape}')
print(f'h_right.shape: {h_right.shape}')


out.shape: torch.Size([1, 129, 800])
hidden.shape: torch.Size([2, 1, 400])
h_left.shape: torch.Size([1, 129, 400])
h_right.shape: torch.Size([1, 129, 400])


In [3]:
z_0 = torch.randn(1, 1, 100)

In [4]:
combiner = torch.nn.Linear(100, 400)
h_combined = combiner(z_0)
h_combined = .5 * (F.tanh(h_combined) + h_right)
print(f'h_combined.shape: {h_combined.shape}')

h_combined.shape: torch.Size([1, 129, 400])


In [5]:
def softplus(x):
    return torch.log(1 + torch.exp(x))

In [6]:
mu_linear = torch.nn.Linear(400, 100)
sigma_linear = torch.nn.Linear(400, 100)

mu = mu_linear(h_combined)
sigma = softplus(sigma_linear(h_combined))

print(f'mu.shape: {mu.shape}')
print(f'sigma.shape: {sigma.shape}')    


mu.shape: torch.Size([1, 129, 100])
sigma.shape: torch.Size([1, 129, 100])


In [7]:
z_1 = mu + sigma * torch.randn_like(mu)


In [8]:
#emission function
emission = torch.nn.Sequential(
    torch.nn.Linear(100, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 88),
    torch.nn.Sigmoid()
)

#transition function
G = torch.nn.Sequential(
    torch.nn.Linear(100, 200), #gating unit
    torch.nn.ReLU(),
    torch.nn.Linear(200, 100), #proposed mean
    torch.nn.Sigmoid()
)

H = torch.nn.Sequential(
    torch.nn.Linear(100, 200), #proposed mean
    torch.nn.ReLU(),
    torch.nn.Linear(200, 100), #proposed mean
    torch.nn.Identity(),
)

mu_gated_linear = torch.nn.Linear(100, 100) #w_{mu_p} * z_{t-1} + b_{mu_p}
sigma_gated_linear = torch.nn.Sequential( #w_{sigma_p} * relu(h_t) + b_{sigma_p}
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.Softplus(),
)
    




In [9]:
emission(z_1).shape

torch.Size([1, 129, 88])

In [10]:
#forward pass of transition function
z_0 = torch.zeros(1, 1, 100)
g_out = G(z_0)
one_minus_g = 1 - g_out # 1 - g
mu_linear_out = mu_gated_linear(z_0) #w_{mu_p} * z_{t-1} + b_{mu_p}
elementwise_mu_out = mu_linear_out * one_minus_g
proposed_mean_out = H(z_0)
mu_generator = (proposed_mean_out * g_out) + elementwise_mu_out

sigma_generator = sigma_gated_linear(proposed_mean_out)


print(f'out.shape: {one_minus_g.shape}')
print(f'sigma_generator.shape: {sigma_generator.shape}')

out.shape: torch.Size([1, 1, 100])
sigma_generator.shape: torch.Size([1, 1, 100])


In [11]:
z_1 = mu_generator + sigma_generator * torch.randn_like(mu_generator)

In [1]:
from model import DVAE

import torch
import torch.nn.functional as F

In [2]:
model = DVAE()

In [3]:
x = torch.randn(3, 129, 88)

In [4]:
x_hat, mus_p_z, sigmas_p_z, mus_generator, sigmas_generators = model(x)

In [5]:
def kl_normal(qm, qv, pm, pv):
    """
    Computes the elem-wise KL divergence between two normal distributions KL(q || p) and
    sum over the last dimension

    Args:
        qm: tensor: (batch, dim): q mean
        qv: tensor: (batch, dim): q variance
        pm: tensor: (batch, dim): p mean
        pv: tensor: (batch, dim): p variance

    Return:
        kl: tensor: (batch,): kl between each sample
    """
    
    element_wise = 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1)
    print(f'element_wise.shape: {element_wise.shape}')
    kl = element_wise.sum(-1)
    return kl

In [6]:
print(f'x_hat.shape: {x_hat.shape}')
print(f'mus_p_z.shape: {mus_p_z.shape}')
print(f'sigmas_p_z.shape: {sigmas_p_z.shape}')
print(f'mus_generator.shape: {mus_generator.shape}')
print(f'sigmas_generators.shape: {sigmas_generators.shape}')


x_hat.shape: torch.Size([3, 129, 88])
mus_p_z.shape: torch.Size([3, 129, 100])
sigmas_p_z.shape: torch.Size([3, 129, 100])
mus_generator.shape: torch.Size([3, 129, 100])
sigmas_generators.shape: torch.Size([3, 129, 100])


In [7]:
mu_z_1 = mus_p_z[:, 0, :]
sigma_z_1 = sigmas_p_z[:, 0, :]
print(f'mu_z_1.shape: {mu_z_1.shape}')
print(f'sigma_z_1.shape: {sigma_z_1.shape}')

mus_generator_1 = mus_generator[:, 0, :]
sigmas_generator_1 = sigmas_generators[:, 0, :]
print(f'mus_generator_1.shape: {mus_generator_1.shape}')
print(f'sigmas_generator_1.shape: {sigmas_generator_1.shape}')

single_kl = kl_normal(mu_z_1,
          sigma_z_1,
          mus_generator_1,
          sigmas_generator_1,
          )

print(f'single_kl: {single_kl}')

mu_z_1.shape: torch.Size([3, 100])
sigma_z_1.shape: torch.Size([3, 100])
mus_generator_1.shape: torch.Size([3, 100])
sigmas_generator_1.shape: torch.Size([3, 100])
element_wise.shape: torch.Size([3, 100])
single_kl: tensor([3.3950, 3.1909, 3.2479], grad_fn=<SumBackward1>)


In [8]:
bce = torch.nn.BCEWithLogitsLoss(reduction='none')
def log_bernoulli_with_logits(x, logits):
    """
    Computes the log probability of a Bernoulli given its logits

    Args:
        x: tensor: (batch, dim): Observation
        logits: tensor: (batch, dim): Bernoulli logits

    Return:
        log_prob: tensor: (batch,): log probability of each sample
    """
    log_prob = -bce(input=logits, target=x).sum(-1).sum(-1) #TODO: check if suming over t is correct
    return log_prob

In [9]:
recon_loss = -log_bernoulli_with_logits(x_hat, x)

In [10]:
#KL(q, p) for t=2 to T

kl_q_p_2 = kl_normal(mus_generator[:, 1:, :],
          sigmas_generators[:, 1:, :],
          mus_p_z[:, 1:, :],
          sigmas_p_z[:, 1:, :])

element_wise.shape: torch.Size([3, 128, 100])


In [11]:
kl_q_p_2.sum(-1).shape

torch.Size([3])

In [12]:
single_kl

tensor([3.3950, 3.1909, 3.2479], grad_fn=<SumBackward1>)

In [13]:
nelbo_loss = recon_loss.mean() + kl_q_p_2.sum(-1).mean() + single_kl.mean()

In [2]:
from omegaconf import OmegaConf
import argparse
import os
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import wandb

from dataloader import MusicDataset
from model import DVAE 



In [3]:
path = 'config.yaml'

config = OmegaConf.load(path)

In [4]:


dataset = MusicDataset(config.dataset)
dataloader = DataLoader(dataset, 
                        batch_size=config.train.batch_size, 
                        num_workers=config.train.num_workers,
                        pin_memory=True, #important for speed
                        shuffle=True)

device = torch.device('mps')
model = DVAE(input_dim=config.model.input_dim, 
                hidden_dim=config.model.hidden_dim,
                hidden_dim_em=config.model.hidden_dim_em, 
                hidden_dim_tr=config.model.hidden_dim_tr, 
                latent_dim=config.model.latent_dim)
model = model.to(device)

    
for i, (encodings, sequence_lengths) in enumerate(dataloader):
    # print(f'encodings.shape: {encodings.shape}'
    #       f'masks.shape: {masks.shape}'
    #       f'sequence_lengths.shape: {sequence_lengths.shape}')
    encodings = encodings.to(device)
    # masks = masks.to(device)
#     sequence_lengths = sequence_lengths.to(device)
    x_hat, mus_inference, sigmas_inference, mus_generator, sigmas_generators = model(encodings)
    
    



MUS SHAPE: torch.Size([64, 129, 100]) 
SIGMAS SHAPE: torch.Size([64, 129, 100]) 
MUS SHAPE: torch.Size([64, 129, 100]) 
SIGMAS SHAPE: torch.Size([64, 129, 100]) 
MUS SHAPE: torch.Size([64, 129, 100]) 
SIGMAS SHAPE: torch.Size([64, 129, 100]) 
MUS SHAPE: torch.Size([37, 129, 100]) 
SIGMAS SHAPE: torch.Size([37, 129, 100]) 


In [5]:
print(f'x_hat.shape: {x_hat.shape}')
print(f'mus_p_z.shape: {mus_inference.shape}')
print(f'sigmas_p_z.shape: {sigmas_inference.shape}')
print(f'mus_generator.shape: {mus_generator.shape}')
print(f'sigmas_generators.shape: {sigmas_generators.shape}')


x_hat.shape: torch.Size([37, 129, 88])


NameError: name 'mus_p_z' is not defined

In [101]:
nll = log_bernoulli_with_logits(x_hat, encodings)
nll.shape
nll

tensor([-8389.2754, -8827.4854, -8549.8721, -8581.2773, -8714.9258, -8580.8516,
        -8826.5322, -8788.3408, -8785.7773, -8091.1851, -8582.5850, -8470.4922,
        -8396.6348, -8581.1172, -8908.4766, -8825.5752, -8755.4609, -8580.4609,
        -8458.4258, -7926.2002, -8349.0488, -8079.7500, -8757.4297, -8622.7598,
        -8538.4902, -8746.5000, -8867.4033, -8747.2969, -8746.8955, -8621.9541,
        -8743.1934, -8252.8320, -8735.8770, -8755.2334, -8468.2754, -8293.1406,
        -8581.4688], device='mps:0', grad_fn=<NegBackward0>)

In [111]:
nll.mean().backward()

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [MPSFloatType [37, 100]], which is output 0 of AsStridedBackward0, is at version 129; expected version 128 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [102]:
from einops import repeat, rearrange
"""
Computes the elem-wise KL divergence between two normal distributions KL(q || p) and
sum over the last dimension

Args:
    qm: tensor: (batch, dim): q mean
    qv: tensor: (batch, dim): q variance
    pm: tensor: (batch, dim): p mean
    pv: tensor: (batch, dim): p variance

Return:
    kl: tensor: (batch,): kl between each sample
"""
qm = mus_generator
qv = sigmas_generators
pm = mus_inference
pv = sigmas_inference

element_wise = 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1)

# mask out the padding after sequence length for each datapoint
bs, max_sequence_length, _ = element_wise.shape

# range_tensor = torch.arange(max_sequence_length).unsqueeze(0).expand(bs, -1) # [seq] -> [1, seq] -> [bs, seq]
range_tensor = repeat(torch.arange(max_sequence_length), 'l -> b l', b=bs)
mask = range_tensor < rearrange(sequence_lengths, 'b -> b ()')
mask = rearrange(mask, 'b s -> b s ()')

kl = element_wise * mask.float()
kl = kl.sum(-1).sum(-1)


In [109]:
kl.mean().backward()

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [85]:
range_tensor = torch.arange(129).unsqueeze(0).expand(element_wise.shape[0], -1).to('cpu')

# Create the mask: True for elements we want to keep, False for elements to be masked
mask = range_tensor < sequence_lengths.unsqueeze(1).to('cpu')

# Apply the mask to the data
# You need to unsqueeze the mask to make its shape (batch_size, sequence_length, 1) so it broadcasts correctly over the feature dimension
masked_data = element_wise * mask.unsqueeze(2).float()

In [86]:
element_wise.sum(-1).sum(-1)

tensor([703.9436, 710.7096, 704.5284, 673.8802, 701.5371, 705.6766, 693.5052,
        696.0417, 708.2553, 719.8946, 676.9679, 706.6138, 710.3296, 703.0452,
        697.2532, 714.9741, 686.5779, 712.7692, 731.7780, 722.0126, 721.5752,
        678.3644, 711.3787, 705.0903, 693.5876, 665.4576, 686.4227, 665.4719,
        704.4774, 688.3771, 707.5785, 695.1637, 714.8768, 700.8198, 694.5972,
        691.2836, 700.0506], grad_fn=<SumBackward1>)

In [87]:
masked_data.sum(-1).sum(-1).shape

torch.Size([37])

In [88]:
torch.arange(129).unsqueeze(0).expand(element_wise.shape[0], 0).shape

RuntimeError: The expanded size of the tensor (0) must match the existing size (129) at non-singleton dimension 1.  Target sizes: [37, 0].  Tensor sizes: [1, 129]

In [91]:
masked_data.mean()

tensor(0.0254, grad_fn=<MeanBackward0>)