In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions
import numpy as np
import pdb  # noqa: F401
from collections import OrderedDict

from utils import SMALL, log_sum_exp  # noqa: F401
import shared

In [2]:
ar = range(0,44100*10,1)
wav = np.sin(ar*400)
wavs = [wav[i*44100:(i+1)*44100] for i in range(10)]
wavs = torch.Tensor(wavs)

In [12]:
class MF_IBP(nn.Module):

    def __init__(self,N_SAMPLES,K=6,a=1,b=1,sigma_n=0.5,mu_phi=100,sigma_phi=0.5,mu_w=0,sigma_w=0.5):
        super(MF_IBP, self).__init__()
        self.truncation = K
        self.num_features = K
        self.sigma_n = sigma_n
        self.N_SAMPLES = N_SAMPLES
        
        # phi are the logs of the frequencies of the sines
        self.phi_mean = nn.Parameter(torch.zeros(self.num_features,1))
        self.phi_logvar = nn.Parameter(torch.zeros(self.num_features,1))

        # weights
        self.w_mean = nn.Parameter(torch.ones(self.num_features,10))
        self.w_logvar = nn.Parameter(torch.zeros(self.num_features,10))

        self.p_pi_alpha = a/float(K)
        self.p_pi_beta = b*(K-1)/float(K)

        # inverse softplus
        a_val = np.log(np.exp(self.p_pi_alpha) - 1)
        b_val = np.log(np.exp(self.p_pi_beta) - 1)
        self.q_pi_alpha = nn.Parameter(torch.Tensor(self.truncation).zero_() + a_val)
        self.q_pi_beta = nn.Parameter(torch.Tensor(self.truncation).zero_() + b_val)
        
        # These are broadcast up into the right shape (they are diagonal)
        self.p_phi = distributions.Normal(loc=mu_phi,scale=sigma_phi)
        self.p_w = distributions.Normal(loc=mu_w,scale=sigma_w)


    def forward(self, x):
        
        batch_sz = x.size()[0]
        sz = self.q_pi_alpha.size()
        print "batch sz", batch_sz
        print "sz",sz

        p_pi = distributions.Beta(torch.ones(sz)*self.p_pi_alpha,torch.ones(sz)*self.p_pi_beta)

        beta_a = F.softplus(self.q_pi_alpha) + 0.01
        beta_b = F.softplus(self.q_pi_beta) + 0.01
        q_pi = distributions.Beta(beta_a, beta_b)

        # Differentiable Sample Knowles et al. 
        qpi_sample = q_pi.rsample()
        qpi_sample_N = qpi_sample.repeat(batch_sz,1)
        q_z = shared.STRelaxedBernoulli(temperature=0.1,probs=qpi_sample_N)
        z = q_z.rsample() 
        q_z = distributions.Bernoulli(probs=qpi_sample)
        
        q_phi = distributions.Normal(loc=self.phi_mean,scale=(self.phi_logvar/2).exp())
        q_w = distributions.Normal(loc=self.w_mean,scale=(self.w_logvar/2).exp())
        
        # For now, just take the mean
        phi = q_phi.mean
        w = q_w.mean

        # Alternatively, sample
        # phi = q_phi.rsample()
        # w = q_w.rsample()

        # NLL
        samples_onedatapoint = x[0].shape[0]
        sinbasis = torch.ones(self.truncation,samples_onedatapoint)*torch.arange(0,samples_onedatapoint,1)

        for k in range(self.truncation):
            sinbasis[k] = torch.sin(sinbasis[k]*phi[k])
        
        z = torch.transpose(z, 0, 1)
        zw = torch.mul(z,w) # z and w multiplied elementwise
        #print "x shape", x.shape
        #x_mean = torch.mm(zw,sinbasis)
        
        nll = 0
        
        print "NSAMPLES is ",N_SAMPLES
        for i in range(batch_sz):
            masked_weights = zw[:,i].view(-1,1).repeat(1,samples_onedatapoint)
            contributions = torch.mul(masked_weights,sinbasis)
            xi_mean = torch.sum(contributions, dim=0)
            
            #xi_mean = torch.sin(400*torch.arange(0,samples_onedatapoint,1))
                        
            #nll_i = -(distributions.
            #          MultivariateNormal(loc=xi_mean,covariance_matrix=torch.eye(xi_mean.shape[0])*self.sigma_n).log_prob(x[i]))
            
            nll_i = torch.sum(distributions.Normal(loc=xi_mean,scale=self.sigma_n).log_prob(x[i]))
            
            print nll_i
            nll += nll_i
                    
        return nll, p_pi, q_pi, q_z, q_phi, q_w, sinbasis

In [13]:
SAMPLE_RATE=44100
NUM_SECONDS=10
N_SAMPLES = SAMPLE_RATE*NUM_SECONDS
m = MF_IBP(N_SAMPLES)

In [14]:
m.forward(wavs)

batch sz 10
sz torch.Size([6])
NSAMPLES is  441000
tensor(-98154.0703)
tensor(-98157.6172)
tensor(-98156.2891)
tensor(-98157.0781)
tensor(-98154.1562)
tensor(-98157.9219)
tensor(-98156.4453)
tensor(-98156.0938)
tensor(-98155.1797)
tensor(-98157.4375)


(tensor(1.00000e+05 *
        -9.8156),
 Beta(),
 Beta(),
 Bernoulli(),
 Normal(),
 Normal(),
 tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.]]))

In [9]:
ones = torch.ones(3)


In [10]:
ones

tensor([ 1.,  1.,  1.])

In [11]:
torch.arange(0,3,1)

tensor([ 0.,  1.,  2.])