In [239]:
%matplotlib widget

In [240]:
from matplotlib import pyplot as plot

In [241]:
import numpy

In [242]:
import torch
from torch import nn
from torch import distributions
from torch import optim

In [243]:
import itertools

In [244]:
def softplus(x):
    return torch.log(1+torch.exp(x))

In [259]:
class GKF(nn.Module):
    
    def __init__(self, x_dim, h_dim, log_noise_level=0.): 
        super(GKF, self).__init__()
        
        # we assume an identity covariance in p(x_t|h_t) and p(h_t|h_{t-1})
#         self.transition = nn.Linear(h_dim, h_dim)
#         self.emission = nn.Linear(h_dim, x_dim)
        
        self.transition = nn.Sequential(nn.Linear(h_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, h_dim))
        self.emission = nn.Sequential(nn.Linear(h_dim, h_dim), nn.ReLU(), nn.Linear(h_dim,x_dim))
        
        self.initial = nn.Parameter(torch.zeros(h_dim))
        
        self.h_dcov = nn.Parameter(torch.zeros(h_dim)+log_noise_level)
        self.x_dcov = nn.Parameter(torch.zeros(x_dim)+log_noise_level)
        
        self.h_dim = h_dim
        self.x_dim = x_dim
        
    def joint_likelihood(self, x, z):
        assert type(x) == list, 'x must be a list of vectors'
        assert type(z) == list, 'z must be a list of vectors'
        assert len(x)+1 == len(z), 'z must have one more element than x does'
        
        T = len(x)
        
        logp = self._compute_normal_log_p(z[0], self.initial, 1e-6+softplus(self.h_dcov))
        
        for t in range(T):
            mu_z = self.transition(z[t])
            logp += self._compute_normal_log_p(z[t+1], mu_z, 1e-6+softplus(self.h_dcov))
            mu_x = self.emission(z[t+1])
            logp += self._compute_normal_log_p(x[t], mu_x, 1e-6+softplus(self.x_dcov))
            
        return logp
    
    def emit(self, z_list, sample=False):
        x_list = []
        
        for z in z_list:
            mu_x = self.emission(z)
            if sample:
                x = distributions.normal.Normal(mu_x, 1e-6+softplus(self.x_dcov)).sample()
            else:
                x = mu_x
            x_list.append(x)
            
        return x_list
            
    def sample(self, T, z0=None):
        if z0 is None:
            z0 = distributions.normal.Normal(self.initial, 1e-6+softplus(self.h_dcov)).sample()
            
        z_list = [z0]
        z = z0
        
        for t in range(T):
            mu_z = self.transition(z)
            z = distributions.normal.Normal(mu_z, 1e-6+softplus(self.h_dcov)).sample()
            z_list.append(z)

        x_list = self.emit(z_list[1:], sample=True)
            
        return x_list, z_list
        
    def _compute_normal_log_p(self, x, mu, diag_cov):
        return distributions.normal.Normal(mu, torch.sqrt(1e-6+softplus(diag_cov))).log_prob(x).sum()
#         return (-(((x - mu) ** 2) / (diag_cov+1e-6)) - torch.log(numpy.sqrt(2.) * torch.sqrt(diag_cov+1e-6))).sum()

In [260]:
class GKFQ(nn.Module):
    def __init__(self, h_dim, T):
        super(GKFQ, self).__init__()
        
        self.diag_cov = nn.ParameterList()
        self.mean = nn.ParameterList()
        
        for t in range(T):
            self.diag_cov.append(nn.Parameter(torch.zeros(h_dim)))
            self.mean.append(nn.Parameter(torch.zeros(h_dim)))
            
        self.T = T
        self.h_dim = h_dim
        
    def sample(self):
        samples = []
        
        for t in range(self.T):
            ss = distributions.normal.Normal(self.mean[t], torch.sqrt(1e-6+softplus(self.diag_cov[t]))).rsample()
            samples.append(ss)
            
        return samples
    
    def compute_log_p(self, z_list):
        log_ps = []
        for t in range(self.T):
            log_ps.append(distributions.normal.Normal(self.mean[t], torch.sqrt(1e-6+softplus(self.diag_cov[t]))).log_prob(z_list[t]))
        return log_ps
    
    def compute_entropy(self):
        entropy = 0.
        for t in range(self.T):
            entropy += distributions.normal.Normal(self.mean[t], torch.sqrt(1e-6+softplus(self.diag_cov[t]))).entropy().sum()
        return entropy

In [261]:
x_dim = 10
h_dim = 2
T = 10

In [262]:
gkf_target = GKF(x_dim, h_dim, log_noise_level=-2.)

In [263]:
x_list, z_list = gkf_target.sample(T)

In [264]:
''' inference '''
inference_only = True

if inference_only:
    gkf_model = gkf_target
else:
    gkf_model = GKF(x_dim, h_dim)
q_model = GKFQ(h_dim, len(x_list)+1)

In [265]:
if inference_only:
    optimizer = optim.SGD(q_model.parameters(), lr=0.01)
#     optimizer = optim.Adam(q_model.parameters(), lr=.01)
else:
    optimizer = optim.SGD(itertools.chain(q_model.parameters(), gkf_model.parameters()), lr=0.01)

In [266]:
running_loss = None

n_iter = 10000
disp_int = 100
n_samples = 1
entropy_beta = 1.

for ni in range(n_iter):
    optimizer.zero_grad()
    gkf_target.zero_grad()

    loss = 0.
    for si in range(n_samples):
        z_inferred = q_model.sample()
        loss = loss - gkf_model.joint_likelihood(x_list, z_inferred)
    loss = loss / n_samples
    if running_loss is None:
        running_loss = loss
    else:
        running_loss = 0.9 * running_loss + 0.1 * loss
    if numpy.mod(ni+1, disp_int) == 0:
        print('loss ', running_loss.item())
    loss = loss - entropy_beta * q_model.compute_entropy()
    
    loss.backward()
    optimizer.step()

loss  112.23902893066406
loss  109.11211395263672
loss  112.1075210571289
loss  108.31175231933594
loss  109.36283874511719
loss  108.25534057617188
loss  111.44081115722656
loss  108.15746307373047
loss  108.10391998291016
loss  108.84854125976562
loss  108.41986846923828
loss  109.97013854980469
loss  109.0621337890625
loss  109.462646484375
loss  108.6527099609375
loss  108.2092056274414
loss  110.35690307617188
loss  107.99728393554688
loss  108.22374725341797
loss  108.29542541503906
loss  109.77043151855469
loss  108.06400299072266


KeyboardInterrupt: 

In [267]:
list(q_model.parameters())[0].grad

tensor([0., 0.])

In [268]:
q_list = q_model.mean
r_list = gkf_model.emit(q_list[1:])

In [269]:
x_list_ = numpy.array([x.numpy() for x in x_list])
z_list_ = numpy.array([z.numpy() for z in z_list])

r_list_ = numpy.array([r.detach().numpy() for r in r_list])
q_list_ = numpy.array([m.detach().numpy() for m in q_list])

plot.figure()

plot.subplot(1,2,1)
plot.plot(x_list_[:,0], x_list_[:,1], 'b-')
for i, x in enumerate(x_list_):
    plot.text(x[0], x[1], '{}'.format(i), color='b')
    
plot.plot(r_list_[:,0], r_list_[:,1], 'r--')
for i, r in enumerate(r_list_):
    plot.text(r[0], r[1], '{}'.format(i), color='r')

plot.subplot(1,2,2)
plot.plot(z_list_[:,0], z_list_[:,1], 'b-')
for i, z in enumerate(z_list_):
    plot.text(z[0], z[1], '{}'.format(i), color='b')
plot.plot(q_list_[:,0], q_list_[:,1], 'r--')
for i, q in enumerate(q_list_):
    plot.text(q[0], q[1], '{}'.format(i), color='r')

plot.show()

FigureCanvasNbAgg()

In [238]:
plot.figure()

plot.subplot(2,1,1)
plot.imshow(x_list_.T)
plot.colorbar()

plot.subplot(2,1,2)
plot.imshow(r_list_.T)
plot.colorbar()

plot.show()



FigureCanvasNbAgg()

In [323]:
plot.close('all')

In [195]:
x_list_.shape

(11, 2)