In [1]:
%matplotlib widget

In [2]:
from matplotlib import pyplot as plot

In [3]:
import numpy

In [4]:
import torch
from torch import nn
from torch import distributions
from torch import optim

In [326]:
import itertools

In [220]:
class GKF(nn.Module):
    
    def __init__(self, x_dim, h_dim): 
        super(GKF, self).__init__()
        
        # we assume an identity covariance in p(x_t|h_t) and p(h_t|h_{t-1})
        self.transition = nn.Linear(h_dim, h_dim)
        self.emission = nn.Linear(h_dim, x_dim)
        self.initial = nn.Parameter(torch.zeros(h_dim))
        
        self.h_dim = h_dim
        self.x_dim = x_dim
        
    def joint_likelihood(self, x, z):
        assert type(x) == list, 'x must be a list of vectors'
        assert type(z) == list, 'z must be a list of vectors'
        assert len(x)+1 == len(z), 'z must have one more element than x does'
        
        T = len(x)
        
        logp = self._compute_normal_log_p(z[0], self.initial, torch.ones(self.h_dim))
        
        for t in range(T):
            mu_z = self.transition(z[t])
            logp += self._compute_normal_log_p(z[t+1], mu_z, torch.ones(self.h_dim))
            mu_x = self.emission(z[t+1])
            logp += self._compute_normal_log_p(x[t], mu_x, torch.ones(self.x_dim))
            
        return logp
    
    def emit(self, z_list, sample=False):
        x_list = []
        
        for z in z_list:
            mu_x = self.emission(z)
            if sample:
                x = distributions.normal.Normal(mu_x, torch.ones(self.x_dim)).sample()
            else:
                x = mu_x
            x_list.append(x)
            
        return x_list
            
    def sample(self, T, z0=None):
        if z0 is None:
            z0 = distributions.normal.Normal(self.initial, torch.ones(self.h_dim)).sample()
            
        z_list = [z0]
        z = z0
        
        for t in range(T):
            mu_z = self.transition(z)
            z = distributions.normal.Normal(mu_z, torch.ones(self.h_dim)).sample()
            z_list.append(z)

        x_list = self.emit(z_list[1:], sample=True)
            
        return x_list, z_list
        
    def _compute_normal_log_p(self, x, mu, diag_cov):
        return (-(((x - mu) ** 2) / diag_cov) - torch.log(numpy.sqrt(2.) * torch.sqrt(diag_cov))).sum()

In [287]:
class GKFQ(nn.Module):
    def __init__(self, h_dim, T):
        super(GKFQ, self).__init__()
        
        self.diag_cov = nn.ParameterList()
        self.mean = nn.ParameterList()
        
        for t in range(T):
            self.diag_cov.append(nn.Parameter(torch.zeros(h_dim)))
            self.mean.append(nn.Parameter(torch.zeros(h_dim)))
            
        self.T = T
        self.h_dim = h_dim
        
    def sample(self):
        samples = []
        
        for t in range(self.T):
            ss = distributions.normal.Normal(self.mean[t], torch.sqrt(torch.exp(self.diag_cov[t]))).rsample()
            samples.append(ss)
            
        return samples
    
    def compute_log_p(self, z_list):
        log_ps = []
        for t in range(self.T):
            log_ps.append(distributions.normal.Normal(self.mean[t], torch.sqrt(torch.exp(self.diag_cov[t]))).log_prob(z_list[t]))
        return log_ps
    
    def compute_entropy(self):
        entropy = 0.
        for t in range(self.T):
            entropy += distributions.normal.Normal(self.mean[t], torch.sqrt(torch.exp(self.diag_cov[t]))).entropy().sum()
        return entropy

In [350]:
x_dim = 2
h_dim = 2
T = 15

In [351]:
gkf_target = GKF(x_dim, h_dim)

In [352]:
x_list, z_list = gkf_model.sample(T)

In [369]:
''' inference '''
gkf_model = GKF(x_dim, h_dim)
q_model = GKFQ(h_dim, len(x_list)+1)

In [370]:
optimizer = optim.SGD(itertools.chain(q_model.parameters(), gkf_model.parameters()), lr=0.01)

In [371]:
running_loss = None

n_iter = 10000
disp_int = 100
n_samples = 5
entropy_beta = 1.

for ni in range(n_iter):
    optimizer.zero_grad()

    loss = 0.
    for si in range(n_samples):
        z_inferred = q_model.sample()
        loss = loss - gkf_model.joint_likelihood(x_list, z_inferred)
    loss = loss / n_samples
    if running_loss is None:
        running_loss = loss
    else:
        running_loss = 0.9 * running_loss + 0.1 * loss
    if numpy.mod(ni+1, disp_int) == 0:
        print('loss ', running_loss.item())
    loss = loss - entropy_beta * q_model.compute_entropy()
    loss.backward()
    optimizer.step()

loss  96.9422607421875
loss  72.27911376953125
loss  63.01311111450195
loss  58.87183380126953
loss  56.976463317871094
loss  56.40394973754883
loss  54.90300369262695
loss  55.34176254272461
loss  55.50751495361328
loss  54.84132766723633
loss  55.342041015625
loss  55.46296310424805
loss  54.6760368347168
loss  54.62382888793945
loss  55.1097297668457
loss  54.8911247253418
loss  55.161338806152344
loss  54.46751022338867
loss  55.76375198364258
loss  56.20075988769531
loss  55.401947021484375
loss  55.84090042114258
loss  55.509742736816406
loss  55.34404373168945
loss  54.98444366455078
loss  54.99190139770508
loss  56.105613708496094


KeyboardInterrupt: 

In [372]:
list(q_model.parameters())[0].grad

tensor([0., 0.])

In [373]:
q_list = q_model.mean
r_list = gkf_model.emit(q_list[1:])

In [374]:
x_list_ = numpy.array([x.numpy() for x in x_list])
z_list_ = numpy.array([z.numpy() for z in z_list])

r_list_ = numpy.array([r.detach().numpy() for r in r_list])
q_list_ = numpy.array([m.detach().numpy() for m in q_list])

plot.figure()

plot.subplot(1,2,1)
plot.plot(x_list_[:,0], x_list_[:,1], 'b-')
for i, x in enumerate(x_list_):
    plot.text(x[0], x[1], '{}'.format(i), color='b')
    
plot.plot(r_list_[:,0], r_list_[:,1], 'r--')
for i, r in enumerate(r_list_):
    plot.text(r[0], r[1], '{}'.format(i), color='r')

plot.subplot(1,2,2)
plot.plot(z_list_[:,0], z_list_[:,1], 'b-')
for i, z in enumerate(z_list_):
    plot.text(z[0], z[1], '{}'.format(i), color='b')
plot.plot(q_list_[:,0], q_list_[:,1], 'r--')
for i, q in enumerate(q_list_):
    plot.text(q[0], q[1], '{}'.format(i), color='r')

plot.show()

FigureCanvasNbAgg()

In [375]:
plot.figure()

plot.subplot(2,1,1)
plot.imshow(x_list_.T)
plot.colorbar()

plot.subplot(2,1,2)
plot.imshow(r_list_.T)
plot.colorbar()

plot.show()

FigureCanvasNbAgg()

In [323]:
plot.close('all')

In [195]:
x_list_.shape

(11, 2)