In [2]:
!pip3 install pypolyagamma

Collecting pypolyagamma
[?25l  Downloading https://files.pythonhosted.org/packages/f1/4f/8f444c95283613ebacb2703905f3454ccd6c045c87f8f3a48d8e3206eb9c/pypolyagamma-1.2.2.tar.gz (233kB)
[K     |█▍                              | 10kB 16.7MB/s eta 0:00:01[K     |██▉                             | 20kB 23.1MB/s eta 0:00:01[K     |████▏                           | 30kB 27.6MB/s eta 0:00:01[K     |█████▋                          | 40kB 30.7MB/s eta 0:00:01[K     |███████                         | 51kB 34.1MB/s eta 0:00:01[K     |████████▍                       | 61kB 36.5MB/s eta 0:00:01[K     |█████████▉                      | 71kB 37.6MB/s eta 0:00:01[K     |███████████▎                    | 81kB 39.1MB/s eta 0:00:01[K     |████████████▋                   | 92kB 40.6MB/s eta 0:00:01[K     |██████████████                  | 102kB 42.4MB/s eta 0:00:01[K     |███████████████▍                | 112kB 42.4MB/s eta 0:00:01[K     |████████████████▉               | 122kB 42.4

In [0]:
from pypolyagamma import PyPolyaGamma

In [5]:
!git clone https://github.com/slinderman/ssm.git
%cd ssm
!pip install -e .

Cloning into 'ssm'...
remote: Enumerating objects: 202, done.[K
remote: Counting objects: 100% (202/202), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 2218 (delta 134), reused 139 (delta 84), pack-reused 2016[K
Receiving objects: 100% (2218/2218), 16.58 MiB | 5.90 MiB/s, done.
Resolving deltas: 100% (1527/1527), done.
/content/ssm
Obtaining file:///content/ssm
Installing collected packages: ssm
  Running setup.py develop for ssm
Successfully installed ssm


In [0]:
from ssm import messages

In [0]:
import numpy as np
import numpy.random as npr
from scipy.special import logsumexp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.stats import norm
from scipy.stats import multivariate_normal as MVN
import sys

In [0]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [0]:
# Define the piecewise linear tanh approximation
class PiecewiseLinearSigmoid(object):
    """
    Piecewise linear approximation to sigma(x) with the specified 
    window size and number of linear parts.
    """
    def __init__(self, num_parts=16, window=4):
        self.num_parts = num_parts
        print('num_parts={}'.format(num_parts))
        self.knots = np.linspace(-window, window, num_parts-1)
        self.values = sigmoid(self.knots)

        # Compute slopes and intercepts for each bin
        self.slopes = np.concatenate(([0], np.diff(self.values) / np.diff(self.knots), [0]))
        self.intercepts = np.concatenate(([self.values[0]], self.values - self.slopes[1:] * self.knots))

    def f(self, x):
        z = np.digitize(x, self.knots)
        return self.intercepts[z] + self.slopes[z] * x
        
class PiecewiseLinearTanh(PiecewiseLinearSigmoid):
    def f(self, x):
        return 2 * super(PiecewiseLinearTanh, self).f(2 * x) - 1



In [0]:
# Define the piecewise linear tanh approximation
class StochPiecewiseLinearSigmoid(PiecewiseLinearSigmoid):
    """
    Stochastic p.w.l. approximation to sigma(x) with the specified 
    window size and number of linear parts.  Here, the discrete
    part is chosen randomly based on a tree-structured stick breaking
    of the interval.
    """
    def __init__(self, num_parts=16, window=4, temp=0.25):
        super(StochPiecewiseLinearSigmoid, self).__init__(num_parts, window)
        self.temp = temp

        # Precompute the transition probabilities on a dense grid        
        # num_pts = 1000
        # x = np.linspace(-1.5 * window, 1.5 * window, num_pts)
        # p = np.ones((num_pts, num_parts))

        # Get length of binary representation (i.e. depth of tree)
        decimal, depth = np.modf(np.log2(num_parts))
        assert np.allclose(decimal, 0), "number of parts must be a power of 2."
        depth = int(depth)
        self.depth = depth

        # Get a binary representation of each of the discrete parts
        self._bs = np.unpackbits(
            np.arange(self.num_parts, 
                      dtype=np.uint8)[:, None], axis=1)[:, -(self.depth):]

        # Precompute indices of decision knots for each discrete state
        mids = 2**np.arange(self.depth - 1, -1, -1)
        lefts = np.column_stack((np.zeros(self.num_parts), 
                                 np.cumsum(self._bs[:, :-1] * mids[:-1], axis=1)))
        self._inds = np.array(lefts + mids - 1, dtype=int)

    
    def f_given_z(self, x, z):
      return self.intercepts[z] + self.slopes[z] * x
    
    def f(self, x, return_z=False):
        """
        Sample a discrete state given x, then return the corresponding
        linear function of x.
        """
        x = np.atleast_1d(x)
        z = np.zeros_like(x, dtype=int)
        shp = x.shape
        for d in range(self.depth):
            mid = 2**(self.depth-d-1) - 1
            choice = npr.rand(*shp) < sigmoid((x - self.knots[z + mid]) / self.temp)
            z += choice * (mid + 1)
        
        if return_z:
            return self.f_given_z(x,z), z
        else:
            return self.f_given_z(x,z)
        
        '''
        if return_z:
            return self.intercepts[z] + self.slopes[z] * x, z
        else:
            return self.intercepts[z] + self.slopes[z] * x
        '''

    def discrete_prior(self, x):
        """
        Get the prior on discrete states for input x
        """

        N = x.shape[0]
        d = x.shape[1]
      
        # To compute the posterior distribution of discrete states,
        # first evaluate sigmoids at each (input, knot) pair.
        s = sigmoid((x - self.knots[None,None, :]) / self.temp)
        oms = 1 - s

        # Discrete state probabilities are products of sigmoids for subsets
        # of knots.
        prior = np.ones((N, d, self.num_parts))

        for k in range(self.num_parts):
            bk = self._bs[k]
            ik = self._inds[k]

            prior[:,:, k] = np.prod(s[:,:, ik[bk==1]], axis=2) * \
                          np.prod(oms[:,:, ik[bk==0]], axis=2)
        return prior


    def vectorized_sample_categorical(self, prob_matrix, items):
      prob_matrix=prob_matrix.T
      s = prob_matrix.cumsum(axis=0)
      r = np.random.rand(prob_matrix.shape[1])
      k = (s < r).sum(axis=0)
      return items[k]

    def resample_discrete_states(self, x, y, sigmasq):
        """
        Sample discrete variable z given (input, output) pair (x, y) and 
        Gaussian noise variance sigmasq.
        """
        
        # Get the log prior on discrete states
        log_prior = np.log(self.discrete_prior(x))

        # Compute log likelihood (up to constant) under each discrete state
        yhat = x * self.slopes + self.intercepts
        log_lkhd = -0.5 * (y - yhat)**2 / sigmasq

        # Posterior is proportional to prior * lkhd
        log_post = log_prior + log_lkhd
        post = np.exp(log_post - logsumexp(log_post, axis=2, keepdims=True))
        
        self.cond_z = post ##for unit testing   
        
        # Sample the posterior
        post = post.reshape(x.shape[0]*x.shape[1] ,self.num_parts)
        z = self.vectorized_sample_categorical(post, np.arange(self.num_parts))
        z = z.reshape(x.shape[0], x.shape[1],1)
        return z

    def resample_auxiliary_variables(self, x, z):
        # Sample the conditional distribution of the Polya-gamma auxiliary 
        # variable \omega given the input x.  Here the discrete state really
        # corresponds to a set of Bernoulli random variables.  Each one is 
        # parameterized as,
        #
        #     b_i ~ Bern(\sigma((x - \theta_i) / tau)) 
        # 
        # where \theta_i is the knot for the i-th choice.  The conditional
        # distribution is \omega ~ PG(1, (x - \theta_i) / tau)), and is 
        # independent of the binary outcome b_i. 
        
        z = z.ravel()

        b = self._bs[z]
        i = self._inds[z]        
        u = (x.ravel()[:, None] - self.knots[i]) / self.temp

        
        # Sample the auxiliary variable and compute the potential on u
        # TODO: Sample rather than just returning the mean
        #Ju = np.tanh(u / 2) / (2 * u)
        
        pg_shp = u.ravel().shape
        pg = PyPolyaGamma(seed=npr.randint(0,100000000))
        Ju = np.empty(pg_shp)        
        pg.pgdrawv(np.ones(pg_shp), u.ravel(), Ju)
        Ju = Ju.reshape(u.shape)
    
        hu = b - 0.5

        # Convert to mean parameters
        mu = hu / Ju
        Vu = 1 / Ju

        # Convert the potential on u to a potential on x.
        mx = mu * self.temp + self.knots[i]
        Vx = Vu * self.temp**2

        # Convert back to natural parameters, sum, and return mean
        Jx = np.sum(1 / Vx, axis=1)
        hx = np.sum(mx / Vx, axis = 1)
        mx = hx / Jx
        Vx = 1 / Jx

        mx = mx.reshape(x.shape)
        Vx = Vx.reshape(x.shape)
        return mx, Vx

    def resample(self, x, y, sigmasq):
        """
        Resample the discrete state and auxiliary variables for given inputs (x) 
        and outputs (y), and given noise variance sigmasq. 
        """
        # First sample the discrete states
        zs = self.resample_discrete_states(x, y, sigmasq)
        #bs = self._bs[zs]

        # Then sample auxiliary variables from conditional and compute the 
        # effective Gaussian observation potential.
        mx, Vx = self.resample_auxiliary_variables(x, zs)
        return zs, mx, Vx

          


In [0]:
class StochPiecewiseLinearTanh(StochPiecewiseLinearSigmoid):
    def f(self, x, return_z=False):
        if return_z:
            y, z = super(StochPiecewiseLinearTanh, self).f(2 * x, return_z)
            return 2 * y - 1, z
        else:
            y = super(StochPiecewiseLinearTanh, self).f(2 * x, return_z)
            return 2 * y - 1

    def resample(self, x, y, sigmasq):
        """
        Resample the discrete states and auxiliary variables given observations
        of y = tanh(x) + N(0, sigmasq).  This is equivalent to,

            y' = sigmoid(x') + N(0, sigmasq')

        where y' = (y + 1) / 2
              x' = 2 * x
              sigmasq' = sigmasq / 4

        The output is a set of discrete state samples and Gaussian potentials
        on the input to the sigmoid, here p(z | x) ~ N(x' | mx', Vx'), which is
        equivalent to N(x | mx'/2, Vx'/4) 
        """
        zs, mx, Vx = super(StochPiecewiseLinearTanh, self).\
            resample(2 * x, (y + 1) / 2, sigmasq / 4)

        return zs, mx / 2, Vx / 4


    def x_recurrence_params(self, state, sigmasq):

      A = 4*self.slopes[state.z]*state.W
      Bu_tilde = (state.B @ state.u[1:])
      Bu = 4*self.slopes[state.z]*Bu_tilde +2*self.intercepts[state.z]-1
      
      J_ini = np.diag(1/sigmasq[:,0]) 
      J_dyn_11 = A[1:].transpose(0,2,1)@(1/sigmasq*A[1:])

      #J_dyn_21 = (-1/sigmasq*A[1:]).transpose(0,2,1)
      J_dyn_21 = -1/sigmasq*A[1:]

      J_dyn_22 = np.diag(1/sigmasq[:,0])   

      h_ini = (1/sigmasq*(A[0]@state.x[0]+Bu[0]))[:,0]
      h_dyn_1 = (-Bu[1:].transpose((0,2,1)) @ (1/sigmasq*A[1:]))[:,0,:]
      
      h_dyn_2 = (1/sigmasq*Bu[1:])[:,:,0]
      return J_ini, J_dyn_11, J_dyn_21, J_dyn_22, h_ini, h_dyn_1, h_dyn_2

    def x_obs_params(self, state, sigmasq_y):
      J_obs = state.C.T@(1/sigmasq_y*state.C)
      J_obs = J_obs*np.ones((state.y.shape[0],J_obs.shape[0], J_obs.shape[1]))
      h_obs = ((1/sigmasq_y*state.C).T @ (state.y-state.by))[:,:,0]
      return J_obs, h_obs

    def Wbar_pg_params(self, state, r, rrT, ms, Vs):
        d =  state.W.shape[0]
        ud = state.B.shape[1]
        J_tildes = np.zeros((Vs.shape[0],d, d+ud,d+ud))
        h_tildes = np.zeros((Vs.shape[0],d,d+ud,1))
        for j in range(0,d):
          J_tildes[:,j] = 1/Vs[:,j,None]*rrT
          h_tildes[:,j] = 1/Vs[:,j,None]*ms[:,j,None]*r

        return J_tildes, h_tildes

    def Wbar_recurrence_params(self, state, r, rrT, sigmasq):
      d =  state.W.shape[0]
      ud = state.B.shape[1]
      J = np.zeros((state.x[1:].shape[0],d, d+ud,d+ud))
      h = np.zeros((state.x[1:].shape[0],d,d+ud,1))
            
      for j in range(0,d):
        J[:,j] = ((4*self.slopes[state.z][:,j,None])**2)*(1/sigmasq[j,0])*rrT 
        h[:,j] = (1/sigmasq[j,0]*4*self.slopes[state.z][:,j,None]*(state.x[1:,j,None]-(2*self.intercepts[state.z][:,j,None]-1)))*r

      return J, h

    def Wbar_prior_potentials(self, state, Wbar_sigma, Wbar_prior):
      d =  state.W.shape[0]
      ud = state.B.shape[1]
      J_prior = np.zeros((d, d+ud,d+ud))
      h_prior = np.zeros((d,d+ud,1))
      for j in range(0,d):
        J_prior[j] = np.diag(1/(Wbar_sigma[j]**2))
        h_prior[j] =  (1/(Wbar_sigma[j]**2)*Wbar_prior[j]).reshape(-1,1)

      return J_prior, h_prior

    def Wbary_params(self, state,r, rrT, sigmaysq,Wbary_prior, Wbary_sigma):
      yd = state.C.shape[0]
      d = state.C.shape[1]
      J = np.zeros((yd, d+1,d+1))
      h = np.zeros((yd,d+1,1))
      sum_rrT = np.sum(rrT, axis=0)

      for j in range(0,yd):
        J[j] = 1/sigmaysq[j,0]*sum_rrT + np.diag(1/(Wbary_sigma[j]**2))

        ydr = state.y[:,j,:].reshape(state.y.shape[0],1,1)*r
        h[j] = 1/sigmaysq[j,0]*np.sum(ydr, axis=0)
        h[j] += ( 1/(Wbary_sigma[j]**2)*Wbary_prior[j]).reshape(-1,1)

      return J, h

    def gibbs_step(self, state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma, train_weights):
      
      # 1. Sample discrete states and auxiliary variables given continuous states x and Wbar.
      #First compute the input to the tanh
      inx = state.get_inx()
      iny = state.x[1:]

      state.z, ms, Vs = spwl_tanh.resample(inx, iny, sigma**2)

      '''
      ###########################
      #z unit tests
      state.z, _, _ = spwl_tanh.resample(inx, iny, sigma**2)
      zr = state.z.reshape(-1, 1).ravel()
      condz = self.cond_z.reshape(-1,self.cond_z.shape[2])
      log_cond1 = np.sum(np.log(condz[np.arange(len(zr)),zr]))


      log_joint1 = self.log_joint_nopgs(state, sigma, sigma_y, Wbar_prior, Wbar_sigma)

      state.z, _, _ = spwl_tanh.resample(inx, iny, sigma**2)
      zr = state.z.reshape(-1, 1).ravel()
      condz = self.cond_z.reshape(-1,self.cond_z.shape[2])
      log_cond2 = np.sum(np.log(condz[np.arange(len(zr)),zr]))
      log_joint2 = self.log_joint_nopgs(state, sigma, sigma_y, Wbar_prior, Wbar_sigma)

      print('z')
      print('log_joint1-log_joint2')
      print(log_joint1-log_joint2)
      print('log_cond1-log_cond2')
      print(log_cond1-log_cond2)
      sys.exit()
      ############################
      ''' 


      
      # 2. Sample continuous states x given (linear Gaussian) observations y and 
      #    current discrete states and auxiliary variables 

      #first Convert mean and variance on Wx + Bu to natural parameters on x
      J_tildes = state.W.T @ ( state.W*1/Vs[1:] )
      h_tildes = ((state.W* 1/Vs[1:]).transpose((0,2,1)) @ (ms[1:]- state.B @ state.u[2:]))[:,:,0]
      
      #Get recurrence natural parameters for x 
      J_ini, J_dyn_11, J_dyn_21, J_dyn_22, h_ini, h_dyn_1, h_dyn_2 = self.x_recurrence_params(state, sigma**2)
      

      #Get observation natural parameters for x
      J_obs, h_obs = self.x_obs_params(state, sigma_y**2)
      
      #Combine parameters
      J_dyn_11 += J_tildes
      h_dyn_1 += h_tildes
      
      #Sample x using message passing
      log_Z_obs = np.zeros(h_obs.shape[0])

      state.x[1:,:,0] = messages.kalman_info_sample(J_ini, h_ini, 0, J_dyn_11, J_dyn_21,     
                                        J_dyn_22, h_dyn_1, h_dyn_2, 0,           
                                        J_obs, h_obs, log_Z_obs) 
      
      #3. Train weights:
      if train_weights == True:
        r, rrT = state.get_r()
        #discrete and auxiliary potential
        #convert mean and variance on Wbar @ rt to natural parameters on Wbar=[W,B] where r=[x_{t-1}^T, u_t^T]^T
        J_tildes, h_tildes = self.Wbar_pg_params(state, r, rrT, ms, Vs)

        #recurrence potential 
        J_rec, h_rec = self.Wbar_recurrence_params(state, r, rrT, sigma**2)

        #prior potential
        J_prior, h_prior = self.Wbar_prior_potentials(state, Wbar_sigma, Wbar_prior)

        #Combine discrete/auxiliary, recurrence and prior
        Jw = np.sum(J_tildes+J_rec, axis=0)+J_prior
        hw = np.sum(h_tildes+h_rec, axis=0)+h_prior

        #convert to mean parameters
        Vw = np.linalg.inv(Jw)
        mw = Vw @ hw

        #Sample weights
        L = np.linalg.cholesky(Vw)
        Wbar = (mw + L @ npr.randn(*mw.shape))[:,:,0]
        state.split_Wbar(Wbar)

        '''
        ################################
        #Wbar unit tests
        #NEED to change squeeze
        Wbar = (mw + L @ npr.randn(*mw.shape))[:,:,0]]
        log_cond1=0
        for j in range(0,Wbar.shape[0]):
          log_cond1 += MVN.logpdf(Wbar[j], mw[:,:,0][j], Vw[j])
        state.split_Wbar(Wbar)
        log_joint1,_ = self.log_joint_base(state, sigma, sigma_y, Wbar_prior, Wbar_sigma)       
        Wr = Wbar @ r
        log_joint1 += np.sum(norm.logpdf(Wr,ms, np.sqrt(Vs) ))

        Wbar = (mw + L @ npr.randn(*mw.shape))[:,:,0]
        log_cond2=0
        for j in range(0,Wbar.shape[0]):
          log_cond2 += MVN.logpdf(Wbar[j], mw[:,:,0][j], Vw[j])
        state.split_Wbar(Wbar)
        log_joint2,_ = self.log_joint_base(state, sigma, sigma_y, Wbar_prior, Wbar_sigma)
        Wr = Wbar @ r
        log_joint2 += np.sum(norm.logpdf(Wr,ms, np.sqrt(Vs) ))

        

        print('Wbar')
        print('log_joint1-log_joint2')
        print(log_joint1-log_joint2)
        print('log_cond1-log_cond2')
        print(log_cond1-log_cond2)
        sys.exit()
        ################################
        '''
        
        #Update Wybar weights
        ry, ryryT = state.get_ry()
        J, h = self.Wbary_params(state,ry, ryryT, sigma_y**2,Wbary_prior, Wbary_sigma)

        #convert to mean parameters
        V = np.linalg.inv(J)
        m = V @ h

        #Sample weights
        L = np.linalg.cholesky(V)
        Wbary = (m + L @ npr.randn(*m.shape))[:,:,0]
        state.split_Wybar(Wbary)
        
        

        '''
        ################################
        #Wbary unit tests
        #NEED to change squeeze
        Wbary = (m + L @ npr.randn(*m.shape))[:,:,0]
        log_cond1=0
        for j in range(0,Wbary.shape[0]):
          log_cond1 += MVN.logpdf(Wbary[j], m[:,:,0][j], V[j])
        state.split_Wybar(Wbary)
        log_joint1 = self.log_joint_nopgs(state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma)
      

        Wbary = (m + L @ npr.randn(*m.shape))[:,:,0]
        log_cond2=0
        for j in range(0,Wbary.shape[0]):
          log_cond2 += MVN.logpdf(Wbary[j], m[:,:,0][j], V[j])
        state.split_Wybar(Wbary)
        log_joint2 = self.log_joint_nopgs(state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma)
        

        print('Wbar')
        print('log_joint1-log_joint2')
        print(log_joint1-log_joint2)
        print('log_cond1-log_cond2')
        print(log_cond1-log_cond2)
        sys.exit()
        ################################
        '''



    def gewecke_step(self, state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma, train_weights):
      self.gibbs_step(state, sigma, sigma_y, Wbar_prior, Wbar_sigma, Wbary_prior, Wbary_sigma,train_weights)

      #Samples ys for Gewecke testing
      state.y = state.C @ state.x[1:]+state.by+sigma_y*npr.randn(*state.y.shape)

    def x_log_prior(self, inx, state, sigma):
      mu = 2*self.f_given_z(2*inx, state.z) - 1 
      scale = sigma * np.ones(mu.shape)
      return norm.logpdf(state.x[1:], mu, scale) 

    def y_log_prior(self, state, sigma_y):
      mu = state.C @ state.x[1:]+state.by
      scale = sigma_y * np.ones(mu.shape)
      return norm.logpdf(state.y, mu, scale)

    def Wbar_log_prior(self, state, Wbar_prior, Wbar_sigma):
      Wbar = state.get_Wbar()
      return norm.logpdf(Wbar, Wbar_prior, Wbar_sigma)

    def Wbary_log_prior(self, state, Wbar_prior, Wbar_sigma):
      Wbar = state.get_Wybar()
      return norm.logpdf(Wbar, Wbar_prior, Wbar_sigma)

    def log_joint_base(self, state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma):
      inx = state.get_inx()
      log_like = np.sum(self.x_log_prior(inx, state, sigma))
      log_like += np.sum(self.y_log_prior(state, sigma_y))
      log_like += np.sum(self.Wbar_log_prior(state, Wbar_prior, Wbar_sigma))
      log_like += np.sum(self.Wbary_log_prior(state, Wbary_prior, Wbary_sigma))
      return log_like, inx
    
    def log_joint_nopgs(self, state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma):
      log_like, inx = self.log_joint_base(state, sigma, sigma_y, Wbar_prior, Wbar_sigma,Wbary_prior, Wbary_sigma)

      zr = state.z.reshape(-1, 1).ravel()
      priorz = self.discrete_prior(2*inx)
      priorz = priorz.reshape(-1,priorz.shape[2])
      priorz = priorz[np.arange(len(zr)),zr]

      return log_like + np.sum(np.log(priorz))


In [0]:
class State:
  def __init__(self, y, x, W, B, u, C,by):
    self.x0 = x[0]
    self.x = x
    self.u = u
    self.y = y
    self.W = W
    self.B = B
    self.C = C
    self.by = by
    self.omega = 0
    self.z = 0
    
  def get_inx(self):
    return self.W @ self.x[:-1]  + self.B @ self.u[1:]
  def get_Wbar(self):
    Wbar = np.concatenate((self.W, self.B), axis=1)
    return Wbar
  def split_Wbar(self, Wbar):
    d = self.W.shape[0]
    ud = self.B.shape[1]
    self.W = Wbar[:,:d]
    self.B = Wbar[:,d:]
  def get_r(self):
    r = np.concatenate((self.x[:-1], self.u[1:]), axis=1)
    rrT = (r[...,None]*r[:,None,:]).reshape(r.shape[0], r.shape[1], r.shape[1])
    return r, rrT
  def get_Wybar(self):
    Wybar = np.concatenate((self.C, self.by), axis=1)
    return Wybar
  def split_Wybar(self, Wybar):
    d = self.C.shape[1]
    self.C = Wybar[:,:d]
    self.by = Wybar[:,d:]
  def get_ry(self):
    r = np.concatenate((self.x[1:], np.ones((self.x[1:].shape[0],1,1))), axis=1)
    rrT = (r[...,None]*r[:,None,:]).reshape(r.shape[0], r.shape[1], r.shape[1])
    return r, rrT

In [0]:
def plot_y_obs(y, T,T_train,  ypred_mean, ypred_std, ytrue_mean, ytrue_std, ypred2_mean, ypred2_std, Wbar_prior, Wbar_sigma, sigma, d, num_parts, temp):
  plt.figure(figsize=(12, 8))
  plt.plot(y[:,0,0], label='observed')
  
  plt.plot(ytrue_mean,color='green', label='true_mean')
  plt.fill_between(np.arange(T), 
                  ytrue_mean - 2 * ytrue_std, 
                  ytrue_mean + 2 * ytrue_std,
                  color='green',
                  alpha=0.3)
  

  plt.plot(ypred_mean, color='orange', label='gen mean using W samples')
  plt.fill_between(np.arange(T), 
                  ypred_mean - 2 * ypred_std, 
                  ypred_mean + 2 * ypred_std,
                  color='orange',
                  alpha=0.5)

  plt.plot(np.arange(T_train-1,T), ypred2_mean, color='purple', label='post_pred')

  plt.fill_between(np.arange(T_train-1,T), 
                  ypred2_mean - 2 * ypred2_std, 
                  ypred2_mean + 2 * ypred2_std,
                  color='purple',
                  alpha=0.3)
  
  plt.title('d={}, sigma={}, Wbar_prior_mean={}, Wbar_sigma={},\n num_parts={}, temp={}'.format(d, sigma,Wbar_prior,Wbar_sigma,num_parts, temp  ))
  plt.legend()

# Run Gibbs sampler 

In [54]:
#seed = np.random.randint(low=0, high=1000000)
seed = 499017#126828#665346 #97865
print('seed={}'.format(seed))
npr.seed(seed)

num_iters = 3000000
prior_iters = 100000 #for Gewecke testing
burn_rate = .8
burn_iters = int(burn_rate*num_iters)
train_weights=True
window=4 #4
numparts=4 #16
temp= .5#1#.25

d = 3
ud=2 #u' dimension
yd = 2 #y dimension
T = 10
T_train = 6

####inx = Wx_{t-1}+Bu_t, where B = [B', b] and u_t = [u_t'^T, 1]^T (includes 1 for bias term)####
x0 = npr.randn(d,1)

###sine wave data
'''
u_ = np.zeros((T,ud,1))
data = np.sin(.2*(np.arange(0,T,1))).reshape(T,ud,1)
u_[1:] = data[:-1]
'''
###

u_ =  npr.uniform(size=(T,ud,1))#1*np.ones((T, ud,1)) #npr.uniform(size=(T,ud,1))
#u_[0:100:12] = -1
#u_[100:-1:20] = -1
u = np.concatenate((u_, np.ones((T,1,1))), axis=1)


sigma = 0.1*np.ones((d,1))
sigma_true = .1*np.ones((d,1))
sigma_y = 0.1*np.ones((yd,1)) 


Wbar_prior = .1*np.ones((d,d+ud+1))
Wbar_sigma = .1*np.ones(Wbar_prior.shape)
Wbar = npr.normal(Wbar_prior, Wbar_sigma)
W = Wbar[:,:d]
B = Wbar[:,d:]


Wbary_prior = -.4*np.ones((yd,d+1))
Wbary_sigma = .1*np.ones(Wbary_prior.shape)
Wbary = npr.normal(Wbary_prior, Wbary_sigma)
C = Wbary[:,:d]
by = Wbary[:,d:]

#C = np.ones((yd,d))
#by = .5*np.ones((yd,1))



#obs_sample
'''
lo1 = -.5
hi1=.5

lo2 = -.5
hi2 = .5

W_obs = npr.uniform(low=lo1, high = hi1, size=W.shape)
B_obs = npr.uniform(low=lo2, high= hi2, size=B.shape)
'''

W_obs = W+.4
B_obs = B-.2
C=C+.2
# Sample the vanilla RNN
x = np.zeros((T, d, 1))
x[0] = x0

for t in range(1, T):
    x[t] = np.tanh(W_obs @ x[t-1] + B_obs @ u[t]) + sigma_true * npr.randn(d,1)


ymu = C @ x[1:]+by
y = ymu + sigma_y*npr.randn(*ymu.shape)

#sine wave
#y = data[1:]


seed=499017


Gewecke Testing

In [0]:
spwl_tanh = StochPiecewiseLinearTanh(num_parts=numparts, window=window,  temp=temp)
state = State(y[:T_train-1], x[:T_train], W, B, u[:T_train], C, by)



#Burn-in period
for i in range(burn_iters):
  spwl_tanh.gewecke_step(state,sigma, sigma_y, Wbar_prior, Wbar_sigma, Wbary_prior, Wbary_sigma,train_weights)
  print(i)

#Collect samples
x_samples = np.zeros((num_iters-burn_iters,T_train,d,1))
y_samples = np.zeros((num_iters-burn_iters,T_train-1,yd,1))
Wbar_samples = np.zeros((num_iters-burn_iters,Wbar_prior.shape[0], Wbar_prior.shape[1]))
Wbary_samples = np.zeros((num_iters-burn_iters,Wbary_prior.shape[0], Wbary_prior.shape[1]))
z_samples =  np.zeros((num_iters-burn_iters,T_train-1,d,1))
for i in range(num_iters-burn_iters):
  spwl_tanh.gewecke_step(state,sigma, sigma_y, Wbar_prior, Wbar_sigma, Wbary_prior, Wbary_sigma, train_weights)
  x_samples[i] = state.x
  Wbar_samples[i] = state.get_Wbar()
  Wbary_samples[i] = state.get_Wybar()
  y_samples[i] = state.y
  z_samples[i] = state.z
  print(i)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
2352764
2352765
2352766
2352767
2352768
2352769
2352770
2352771
2352772
2352773
2352774
2352775
2352776
2352777
2352778
2352779
2352780
2352781
2352782
2352783
2352784
2352785
2352786
2352787
2352788
2352789
2352790
2352791
2352792
2352793
2352794
2352795
2352796
2352797
2352798
2352799
2352800
2352801
2352802
2352803
2352804
2352805
2352806
2352807
2352808
2352809
2352810
2352811
2352812
2352813
2352814
2352815
2352816
2352817
2352818
2352819
2352820
2352821
2352822
2352823
2352824
2352825
2352826
2352827
2352828
2352829
2352830
2352831
2352832
2352833
2352834
2352835
2352836
2352837
2352838
2352839
2352840
2352841
2352842
2352843
2352844
2352845
2352846
2352847
2352848
2352849
2352850
2352851
2352852
2352853
2352854
2352855
2352856
2352857
2352858
2352859
2352860
2352861
2352862
2352863
2352864
2352865
2352866
2352867
2352868
2352869
2352870
2352871
2352872
2352873
2352874
2352875
2352876
2352877
2352878
2352879
2352880

In [0]:
#sample from prior
x_preds = np.zeros((prior_iters, T_train, d, 1))
Wbar_preds = np.zeros((prior_iters, Wbar_prior.shape[0], Wbar_prior.shape[1]))
Wbary_preds = np.zeros((prior_iters, Wbary_prior.shape[0], Wbary_prior.shape[1]))

z_preds =  np.zeros((prior_iters,T_train-1,d,1))
y_preds =  np.zeros((prior_iters,T_train-1,yd,1))

for n in range(prior_iters):
  x_preds[n,0] = x0

  if train_weights== True:
    Wbar = Wbar_prior + Wbar_sigma*npr.randn(*Wbar_prior.shape)
    Wbar_preds[n] = Wbar
    W = Wbar[:,:d]
    B = Wbar[:,d:]

    Wbary = Wbary_prior + Wbary_sigma*npr.randn(*Wbary_prior.shape)
    Wbary_preds[n] = Wbary
    C = Wbary[:,:d]
    by = Wbary[:,d:]

  for t in range(1,T_train):
    #Using Sampled Weights
    mu_pred,z = spwl_tanh.f(W @ x_preds[n, t-1] + B @ u[t],return_z=True)
    x_preds[n, t] = mu_pred + sigma * npr.randn(*mu_pred.shape)
    z_preds[n, t-1] = z
  ymu = C @ x_preds[n,1:]+by
  y_preds[n] = ymu + sigma_y*npr.randn(*ymu.shape)



print('Prior z avg')
prior_z_avg = np.sum(z_preds,axis=0)/prior_iters
print(prior_z_avg)

print('gibbs Wbar avg')
Ez = np.sum(z_samples, axis=0)/(num_iters-burn_iters)
print(Ez) 


print('Prior Wbar avg')
prior_Wbar_avg = np.sum(Wbar_preds,axis=0)/prior_iters
print(prior_Wbar_avg)

print('gibbs Wbar avg')
EWbar = np.sum(Wbar_samples, axis=0)/(num_iters-burn_iters)
print(EWbar) 

print('Prior Wbary avg')
prior_Wbary_avg = np.sum(Wbary_preds,axis=0)/prior_iters
print(prior_Wbary_avg)

print('gibbs Wbaryy avg')
EWbary = np.sum(Wbary_samples, axis=0)/(num_iters-burn_iters)
print(EWbary) 

In [0]:
#Pairwise Plots of Gewecke x Results
xsamples = x_samples[:,1:,:,0]
print(xsamples.shape)
xpreds = x_preds[:,1:,:,0]
print(xpreds.shape)
xsamples = xsamples.reshape(-1, xsamples.shape[1]*xsamples.shape[2])
xpreds = xpreds.reshape(-1, xpreds.shape[1]*x_preds.shape[2])


df = pd.DataFrame(xpreds[:,1:])
col = ['prior']*prior_iters
df[T-1]= col

df2 = pd.DataFrame(xsamples[:,1:])
col = ['gibbs']*(num_iters-burn_iters)
df2[T-1] = col 

df3 = df.append(df2)

sns.pairplot(df3, hue=T-1, diag_kind='kde')

In [0]:
#Pairwise Plots of Gewecke y Results
ysamples = y_samples[:,:,:,0]
print(ysamples.shape)
ypreds = y_preds[:,:,:,0]
print(ypreds.shape)
ysamples = ysamples.reshape(-1, ysamples.shape[1]*ysamples.shape[2])
ypreds = ypreds.reshape(-1, ypreds.shape[1]*y_preds.shape[2])


df = pd.DataFrame(ypreds)
col = ['prior']*prior_iters
df[T-1]= col

df2 = pd.DataFrame(ysamples)
col = ['gibbs']*(num_iters-burn_iters)
df2[T-1] = col 

df3 = df.append(df2)

sns.pairplot(df3, hue=T-1, diag_kind='kde')

In [0]:
#Pairwise Plots of Gewecke x Results
xsamples = x_samples[:,1:,:,0]
print(xsamples.shape)
xpreds = x_preds[:,1:,:,0]
print(xpreds.shape)
xsamples = xsamples.reshape(-1, xsamples.shape[1]*xsamples.shape[2])
xpreds = xpreds.reshape(-1, xpreds.shape[1]*x_preds.shape[2])


df = pd.DataFrame(xpreds[:,1:])
col = ['prior']*prior_iters
df[T-1]= col

df2 = pd.DataFrame(xsamples[:,1:])
col = ['gibbs']*(num_iters-burn_iters)
df2[T-1] = col 

df3 = df.append(df2)

sns.pairplot(df3, hue=T-1, diag_kind='kde')

In [0]:
#Pairwise Plots of Gewecke y Results
ysamples = y_samples[:,:,:,0]
print(ysamples.shape)
ypreds = y_preds[:,:,:,0]
print(ypreds.shape)
ysamples = ysamples.reshape(-1, ysamples.shape[1]*ysamples.shape[2])
ypreds = ypreds.reshape(-1, ypreds.shape[1]*y_preds.shape[2])


df = pd.DataFrame(ypreds)
col = ['prior']*prior_iters
df[T-1]= col

df2 = pd.DataFrame(ysamples)
col = ['gibbs']*(num_iters-burn_iters)
df2[T-1] = col 

df3 = df.append(df2)

sns.pairplot(df3, hue=T-1, diag_kind='kde')

In [28]:
#Pairwise Plots of Gewecke Results
xsamples = x_samples[:,1:,:,0]
print(xsamples.shape)
xpreds = x_preds[:,1:,:,0]
print(xpreds.shape)
xsamples = xsamples.reshape(-1, xsamples.shape[1]*xsamples.shape[2])
xpreds = xpreds.reshape(-1, xpreds.shape[1]*x_preds.shape[2])


df = pd.DataFrame(xpreds[:,1:])
col = ['prior']*prior_iters
df[T-1]= col

df2 = pd.DataFrame(xsamples[:,1:])
col = ['gibbs']*(num_iters-burn_iters)
df2[T-1] = col 

df3 = df.append(df2)

sns.pairplot(df3, hue=T-1, diag_kind='kde')

Output hidden; open in https://colab.research.google.com to view.