In [1]:
import numpy as np

import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions
import torch.optim as optim
import numpy as np
import pdb  # noqa: F401
from collections import OrderedDict

from utils import SMALL, log_sum_exp  # noqa: F401
import shared

In [2]:
ar = range(0,44100*10,1)
wav = np.sin(ar*400)
wavs = [wav[i*44100:(i+1)*44100] for i in range(10)]
wavs = torch.Tensor(wavs)

In [3]:
class MF_IBP(nn.Module):

    def __init__(self,N,N_SAMPLES,K=6,a=1,b=1,sigma_n=0.5,mu_phi=100,sigma_phi=0.5,mu_w=0,sigma_w=0.5):
        super(MF_IBP, self).__init__()
        self.truncation = K
        self.num_features = K
        self.sigma_n = sigma_n
        self.N_SAMPLES = N_SAMPLES
        
        # phi are the logs of the frequencies of the sines
        self.phi_mean = nn.Parameter(torch.zeros(self.num_features,1))
        self.phi_logvar = nn.Parameter(torch.zeros(self.num_features,1))

        # weights
        self.w_logmean = nn.Parameter(torch.zeros(self.num_features,N))
        self.w_logvar = nn.Parameter(torch.zeros(self.num_features,N))

        self.p_pi_alpha = a/float(K)
        self.p_pi_beta = b*(K-1)/float(K)

        # inverse softplus
        a_val = np.log(np.exp(self.p_pi_alpha) - 1)
        b_val = np.log(np.exp(self.p_pi_beta) - 1)
        self.q_pi_alpha = nn.Parameter(torch.Tensor(self.truncation).zero_() + a_val)
        self.q_pi_beta = nn.Parameter(torch.Tensor(self.truncation).zero_() + b_val)
        
        # These are broadcast up into the right shape (they are diagonal)
        self.p_phi = distributions.Normal(loc=mu_phi,scale=sigma_phi)
        self.p_w = distributions.Normal(loc=mu_w,scale=sigma_w)

    def forward(self, indices):
        
        # p(pi)
        batch_sz = indices.size()[0]
        num_feat = self.q_pi_alpha.size()
        p_pi = distributions.Beta(torch.ones(num_feat)*self.p_pi_alpha,torch.ones(num_feat)*self.p_pi_beta)

        # q(pi)
        beta_a = F.softplus(self.q_pi_alpha) + 0.01
        beta_b = F.softplus(self.q_pi_beta) + 0.01
        q_pi = distributions.Beta(beta_a, beta_b)

        # q(z) and sample z
        qpi_sample = q_pi.rsample() # Differentiable Sample Knowles et al. 
        qpi_sample_batch = qpi_sample.repeat(batch_sz,1) # Repeat for whole batch size
        z = shared.STRelaxedBernoulli(temperature=0.1,probs=qpi_sample_batch).rsample() # z is BATCH_SZ X K
        q_z = distributions.Bernoulli(probs=qpi_sample)
        
        print "ZZZZZ"
        print z
        
        # q(w), sample w just for data indices in batch
        w_mean = (self.w_logmean.index_select(1, indices.long())).exp()
        w_std = (self.w_logvar.index_select(1, indices.long())/2).exp()
        q_w = distributions.Normal(loc=w_mean,scale=w_std)
        w = q_w.mean # w is K x BATCH_SZ
        
        # q(phi) and sample phi
        q_phi = distributions.Normal(loc=self.phi_mean,scale=(self.phi_logvar/2).exp())
        phi = q_phi.mean # phi is K x 1
        
        # Alternatively, sample instead of taking mean
        # phi = q_phi.rsample()
        # w = q_w.rsample()
        
    
        # FEATURE BASIS
        samples_onedatapoint = wavs[0].shape[0]
        ones = torch.ones(self.truncation,samples_onedatapoint)
        time = ones*torch.arange(0,samples_onedatapoint,1)
        sinbasis = torch.sin(time*phi)
    
        # MASKED WEIGHTS
        zw = torch.mul(torch.transpose(z, 0, 1),w) # z and w multiplied elementwise, zw is K x BATCH_SZ
          
        nll = 0
        # MUST BE CLEANER WAY TO DO THIS
        for i in indices:
            i_index = torch.Tensor([i]).long()
            wavs_i = wavs.index_select(0,i_index)
            weightings = zw.index_select(1,i_index).view(-1,1).repeat(1,samples_onedatapoint)
            contributions = torch.mul(weightings,sinbasis)
            xi_mean = torch.sum(contributions, dim=0)
            nll_i = torch.sum(distributions.Normal(loc=xi_mean,scale=self.sigma_n).log_prob(wavs_i))
            nll += nll_i
                    
        return nll, p_pi, q_pi, q_z, q_phi, q_w, sinbasis

    def NEG_ELBO(self,nll,p_pi,q_pi,q_z,q_phi,q_w):
        kl_div = distributions.kl_divergence
        # need kl_z
        kl_pi = kl_div(q_pi,p_pi).sum()
        kl_phi = kl_div(q_phi,self.p_phi).sum()
        kl_w = kl_div(q_w,self.p_w).sum()
        return nll + kl_pi + kl_phi + kl_w

In [6]:
SAMPLE_RATE=44100
NUM_SECONDS=10
N_SAMPLES = SAMPLE_RATE*NUM_SECONDS
m = MF_IBP(10,N_SAMPLES)
params = filter(lambda x: x.requires_grad, m.parameters())
print params
optimizer = optim.Adam(params, lr=.0001)

[Parameter containing:
tensor([[ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.]]), Parameter containing:
tensor([[ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 0.]]), Parameter containing:
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]), Parameter containing:
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]), Par

In [7]:
for i in range(100):
    nll,p_pi,q_pi,q_z,q_phi,q_w,sinbasis = m.forward(torch.Tensor([0,1,2,3,4,5,6,7,8,9]))
    loss = m.NEG_ELBO(nll,p_pi,q_pi,q_z,q_phi,q_w)
    loss.backward()
    optimizer.step() 

ZZZZZ
tensor([[ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  1.]])
ZZZZZ
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.,  0.]])
ZZZZZ
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  

KeyboardInterrupt: 