In [4]:

import torch
import torch.nn.functional as F

In [4]:
a = torch.randn(1, 129, 88)
input_size = a.shape[2]
input_size = 88
hidden_size = 400

rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, bidirectional=True, batch_first=True)
out, hidden = rnn(a) # out: [1, 129, 800], hidden: [2, 1, 400]
print(f'out.shape: {out.shape}')
print(f'hidden.shape: {hidden.shape}')
h_left = out[:, :, :hidden_size]
h_right = out[:, :, hidden_size:]

print(f'h_left.shape: {h_left.shape}')
print(f'h_right.shape: {h_right.shape}')


out.shape: torch.Size([1, 129, 800])
hidden.shape: torch.Size([2, 1, 400])
h_left.shape: torch.Size([1, 129, 400])
h_right.shape: torch.Size([1, 129, 400])


In [5]:
z_0 = torch.randn(1, 1, 100)

In [6]:
combiner = torch.nn.Linear(100, 400)
h_combined = combiner(z_0)
h_combined = .5 * (F.tanh(h_combined) + h_right)
print(f'h_combined.shape: {h_combined.shape}')

h_combined.shape: torch.Size([1, 129, 400])


In [7]:
def softplus(x):
    return torch.log(1 + torch.exp(x))

In [8]:
mu_linear = torch.nn.Linear(400, 100)
sigma_linear = torch.nn.Linear(400, 100)

mu = mu_linear(h_combined)
sigma = softplus(sigma_linear(h_combined))

print(f'mu.shape: {mu.shape}')
print(f'sigma.shape: {sigma.shape}')    


mu.shape: torch.Size([1, 129, 100])
sigma.shape: torch.Size([1, 129, 100])


In [9]:
z_1 = mu + sigma * torch.randn_like(mu)


In [10]:
#emission function
emission = torch.nn.Sequential(
    torch.nn.Linear(100, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 88),
    torch.nn.Sigmoid()
)

#transition function
G = torch.nn.Sequential(
    torch.nn.Linear(100, 200), #gating unit
    torch.nn.ReLU(),
    torch.nn.Linear(200, 100), #proposed mean
    torch.nn.Sigmoid()
)

H = torch.nn.Sequential(
    torch.nn.Linear(100, 200), #proposed mean
    torch.nn.ReLU(),
    torch.nn.Linear(200, 100), #proposed mean
    torch.nn.Identity(),
)

mu_gated_linear = torch.nn.Linear(100, 100) #w_{mu_p} * z_{t-1} + b_{mu_p}
sigma_gated_linear = torch.nn.Sequential( #w_{sigma_p} * relu(h_t) + b_{sigma_p}
    torch.nn.ReLU(),
    torch.nn.Linear(100, 100),
    torch.nn.Softplus(),
)
    




In [11]:
emission(z_1).shape

torch.Size([1, 129, 88])

In [12]:
#forward pass of transition function
z_0 = torch.zeros(1, 1, 100)
g_out = G(z_0)
one_minus_g = 1 - g_out # 1 - g
mu_linear_out = mu_gated_linear(z_0) #w_{mu_p} * z_{t-1} + b_{mu_p}
elementwise_mu_out = mu_linear_out * one_minus_g
proposed_mean_out = H(z_0)
mu_generator = (proposed_mean_out * g_out) + elementwise_mu_out

sigma_generator = sigma_gated_linear(proposed_mean_out)


print(f'out.shape: {one_minus_g.shape}')
print(f'sigma_generator.shape: {sigma_generator.shape}')

out.shape: torch.Size([1, 1, 100])
sigma_generator.shape: torch.Size([1, 1, 100])


In [13]:
z_1 = mu_generator + sigma_generator * torch.randn_like(mu_generator)

In [1]:
from model import DVAE

import torch
import torch.nn.functional as F

In [2]:
model = DVAE()

In [3]:
x = torch.randn(3, 129, 88)

In [4]:
x_hat, mus_p_z, sigmas_p_z, mus_generator, sigmas_generators = model(x)

In [6]:
def kl_normal(qm, qv, pm, pv):
    """
    Computes the elem-wise KL divergence between two normal distributions KL(q || p) and
    sum over the last dimension

    Args:
        qm: tensor: (batch, dim): q mean
        qv: tensor: (batch, dim): q variance
        pm: tensor: (batch, dim): p mean
        pv: tensor: (batch, dim): p variance

    Return:
        kl: tensor: (batch,): kl between each sample
    """
    element_wise = 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1)
    kl = element_wise.sum(-1)
    return kl

In [12]:
print(f'x_hat.shape: {x_hat.shape}')
print(f'mus_p_z.shape: {mus_p_z.shape}')
print(f'sigmas_p_z.shape: {sigmas_p_z.shape}')
print(f'mus_generator.shape: {mus_generator.shape}')
print(f'sigmas_generators.shape: {sigmas_generators.shape}')


x_hat.shape: torch.Size([3, 129, 88])
mus_p_z.shape: torch.Size([3, 129, 100])
sigmas_p_z.shape: torch.Size([3, 129, 100])
mus_generator.shape: torch.Size([3, 129, 100])
sigmas_generators.shape: torch.Size([3, 129, 100])


In [11]:
mu_z_1 = mus_p_z[:, 0, :]
sigma_z_1 = sigmas_p_z[:, 0, :]
print(f'mu_z_1.shape: {mu_z_1.shape}')
print(f'sigma_z_1.shape: {sigma_z_1.shape}')

mus_generator_1 = mus_generator[:, 0, :]
sigmas_generator_1 = sigmas_generators[:, 0, :]
print(f'mus_generator_1.shape: {mus_generator_1.shape}')
print(f'sigmas_generator_1.shape: {sigmas_generator_1.shape}')

single_kl = kl_normal(mus_generator_1, 
          sigmas_generator_1, 
          mu_z_1, 
          sigma_z_1)


mu_z_1.shape: torch.Size([3, 100])
sigma_z_1.shape: torch.Size([3, 100])
mus_generator_1.shape: torch.Size([3, 100])
sigmas_generator_1.shape: torch.Size([3, 100])


In [13]:
bce = torch.nn.BCEWithLogitsLoss(reduction='none')

In [18]:
def log_bernoulli_with_logits(x, logits):
    """
    Computes the log probability of a Bernoulli given its logits

    Args:
        x: tensor: (batch, dim): Observation
        logits: tensor: (batch, dim): Bernoulli logits

    Return:
        log_prob: tensor: (batch,): log probability of each sample
    """
    log_prob = -bce(input=logits, target=x).sum(-1).sum(-1) #TODO: check if suming over t is correct
    return log_prob

In [26]:
recon_loss = -log_bernoulli_with_logits(x_hat, x)

In [24]:
#KL(q, p) for t=2 to T

kl_q_p_2 = kl_normal(mus_generator[:, 1:, :],
          sigmas_generators[:, 1:, :],
          mus_p_z[:, 1:, :],
          sigmas_p_z[:, 1:, :])

In [34]:
kl_q_p_2.sum(-1).shape

torch.Size([3])

In [40]:
single_kl

tensor([3.2470, 3.5192, 3.6346], grad_fn=<SumBackward1>)

In [41]:
nelbo_loss = recon_loss.mean() + kl_q_p_2.sum(-1).mean() + single_kl.mean()