In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are re-imported before each cell runs and edits to local packages are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline

import numpy as np

In [3]:
# Notebook-wide matplotlib styling: thicker lines, larger markers, and no
# automatic layout adjustment of figures.
from matplotlib import rcParams
rcParams['figure.autolayout'] = False
plt.rc('lines', markersize=10, linewidth=2)

In [4]:
from scem import gen, stein, kernel, ebm
from scem import util
import torch

In [5]:
# Reproducibility: seed torch's global RNG before anything is sampled.
seed = 13
torch.manual_seed(seed)

# Sample sizes for the training and evaluation sets.
n, neval = 300, 1000
# Dimensions of the observed variable x and the latent variable z.
dx, dz = 5, 3
# Variance of the Gaussian noise added to the observations.
var = 2.
# Ground-truth linear map from latent to observed space (dx x dz), drawn
# from a standard normal and scaled by 1 / sqrt(dx * dz).
W = torch.randn(dx, dz) / (dx * dz) ** 0.5

In [6]:
n_cat = 2

# Latent draws for the training and evaluation sets (standard normal).
Z = torch.randn(n, dz)
Zeval = torch.randn(neval, dz)

# Observations: linear image of the latents under W plus isotropic Gaussian
# noise with variance `var`.  The noise for X is drawn before the noise for
# Xeval, so the RNG stream is consumed in a fixed order.
X = torch.matmul(Z, W.T) + torch.randn(n, dx) * var**0.5
Xeval = torch.matmul(Zeval, W.T) + torch.randn(neval, dx) * var**0.5

In [7]:
# Kernel on the observed variable x.
# med2 is the square of util.pt_meddistance(X); presumably the median
# pairwise distance of the training inputs (judging by the helper's name),
# used here as the kernel's s2 parameter.
med2 = util.pt_meddistance(X) ** 2
kx = kernel.KIMQ(s2=med2, c=1, b=-0.5)
# Alternative kernel kept for reference:
#kx = kernel.PTKGauss(torch.tensor([med2]))

In [8]:
def init_weights(m):
    """Initialise a ``torch.nn.Linear`` layer in place; ignore anything else.

    Intended to be passed to ``Module.apply``, which calls it on every
    submodule.  For modules whose type is exactly ``torch.nn.Linear`` the
    weight is drawn from N(0, 0.01^2) and the bias, when the layer has one,
    is set to zero.

    Args:
        m: any ``torch.nn.Module``; modules that are not exactly
            ``torch.nn.Linear`` are left untouched.
    """
    # Exact type match on purpose: subclasses of Linear are not touched.
    if type(m) is not torch.nn.Linear:
        return
    torch.nn.init.normal_(m.weight, std=0.01)
    # Layers built with bias=False have m.bias set to None; skip those
    # instead of failing on the attribute access.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

In [9]:
# Conditional sampler q(z|x): built for inputs of size dx and latents of
# size dz, then initialised by visiting every submodule with init_weights.
# Module.apply returns the module itself, so the result can be bound
# directly; the bare name on the last line echoes the module's repr.
cs = gen.PTCSGaussLinearMean(dx, dz).apply(init_weights)
cs

PTCSGaussLinearMean(
  (mean_fn): Linear(in_features=5, out_features=3, bias=True)
)

In [10]:
# Adam settings for the conditional sampler q(z|x).
learning_rate_q = 1e-2
weight_decay_q = 1e-2
optimizer_q = torch.optim.Adam(
    params=cs.parameters(),
    lr=learning_rate_q,
    weight_decay=weight_decay_q,
)

In [11]:
# Model to be fitted: a PPCA model from scem.ebm whose weight starts as an
# all-ones (dx x dz) matrix.  The second argument is presumably the initial
# noise variance (the model later exposes `p.var`); confirm against ebm.PPCA.
W_init = torch.ones(dx, dz)
p = ebm.PPCA(W_init, torch.tensor([1.0]))

In [12]:
# Adam settings for the model p.
learning_rate_p = 1e-2
weight_decay_p = 1e-2
optimizer_p = torch.optim.Adam(
    params=p.parameters(),
    lr=learning_rate_p,
    weight_decay=weight_decay_p,
)

In [13]:
# iter_q: optimiser steps on q(z|x) per outer iteration.
# iter_p: number of outer iterations (one optimiser step on p each).
iter_q, iter_p = 100, 300

In [14]:
# Score-function surrogate built from the model's joint score with respect
# to the observation and the conditional sampler q(z|x).
approx_score = stein.ApproximateScore(p.score_joint_obs, cs)
# Override the object's n_sample attribute (presumably the number of latent
# samples drawn per score evaluation; confirm in stein.ApproximateScore).
approx_score.n_sample = 500

In [15]:
# History buffers with one entry (or one row) per outer iteration.
# They are filled with NaN rather than allocated with torch.empty, which
# returns uninitialised memory: any entry that is never written then shows
# up as NaN (a gap when plotted) instead of an arbitrary number.
true_train_loss = torch.full([iter_p], float('nan'))
true_eval_loss = torch.full([iter_p], float('nan'))
approx_train_loss = torch.full([iter_p], float('nan'))
grad_loss_pq = torch.full([iter_p], float('nan'))
# Per-iteration snapshots of the model parameters.
params = {
    'W': torch.full([iter_p, dx, dz], float('nan')),
    'var': torch.full([iter_p, dx], float('nan')),
}

In [16]:
def inner_loop(niter):
    """Run ``niter`` optimiser steps on the conditional sampler ``cs``.

    Each step samples latents from ``cs`` given ``X``, evaluates
    ``stein.kcsd_ustat`` on ``(X, Z)`` using ``p.score_joint_latent`` and the
    kernels ``kx`` / ``kz``, back-propagates, and takes one ``optimizer_q``
    step.

    Reads the notebook-level names ``cs``, ``X``, ``seed``, ``p``, ``kx`` and
    ``optimizer_q``.  Returns nothing.

    Args:
        niter: number of optimisation steps to perform.
    """
    # X does not change inside this loop, so the bandwidth and the kernel
    # built from it are computed once instead of once per step.
    # NOTE(review): the kernel on z takes its bandwidth from X, not from Z.
    # Confirm this is intended.
    med2 = util.pt_meddistance(X)**2
    kz = kernel.KIMQ(b=-0.5, c=1., s2=med2)

    for _ in range(niter):
        # NOTE(review): the same seed is passed on every step; presumably
        # this fixes the sampling noise across steps. Confirm.
        # sample(1, X, seed) carries a leading axis of size 1; drop it.
        Z = cs.sample(1, X, seed).squeeze(0)
        loss = stein.kcsd_ustat(
            X, Z, p.score_joint_latent, kx, kz)
        optimizer_q.zero_grad()
        loss.backward()
        optimizer_q.step()

In [None]:
# Outer loop: for each of the iter_p iterations, refit q(z|x) with iter_q
# inner steps, then take one optimiser step on the model p using the KSD
# computed with the approximate score.
for t in range(iter_p):
    inner_loop(iter_q)
    loss = stein.ksd_ustat(X, approx_score, kx)

    # True-KSD diagnostics are currently disabled; while they stay commented
    # out, true_train_loss and true_eval_loss are never written.
    #ksd_true = stein.ksd_ustat(X, p.score_marginal_obs, kx)
    #ksd_eval_true = stein.ksd_ustat(Xeval, p.score_marginal_obs, kx)
#     true_train_loss[t] = ksd_true
#     true_eval_loss[t] = ksd_eval_true

    # Record the training loss so the loss history is actually populated.
    # Detach first so the history buffer does not keep the autograd graph of
    # every iteration alive.
    approx_train_loss[t] = loss.detach()

    if (t%10 == 0):
        print(loss)# , ksd_true, ksd_eval_true)

    optimizer_p.zero_grad()
    loss.backward()
    optimizer_p.step()

    # Snapshot the parameters after the update; detached copies keep the
    # history tensors free of autograd state.
    params['W'][t] = p.weight.detach().clone()
    params['var'][t] = p.var.detach().clone()

In [None]:
# Loss curves over the outer iterations: true KSD on the evaluation and
# training sets, and the approximate KSD used for training.
plt.plot(
    np.arange(iter_p),
    true_eval_loss.detach()[:iter_p].numpy(),
    '-',
    label='true KSD test',
)
plt.plot(
    np.arange(iter_p),
    true_train_loss.detach()[:iter_p].numpy(),
    '-.',
    label='true KSD train',
)
plt.plot(
    np.arange(iter_p),
    approx_train_loss.detach()[:iter_p].numpy(),
    '--',
    label='approx KSD train',
)
plt.ylabel('loss')
plt.xlabel('iteration')
# Anchor the legend outside the axes, to the right of the plot.
plt.legend(loc='upper right', bbox_to_anchor=(1.75, 1))

In [None]:
# Stack the recorded weight snapshots into one NumPy array with the
# iteration index as the leading axis.
p_weight_history = np.array(
    [snapshot.detach().numpy() for snapshot in params['W'][:iter_p]]
)