# Load data

In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
import matplotlib.pyplot as plt
np.set_printoptions(precision=3, linewidth=120)
import sys
sys.path.append("..")
from scem import ebm, stein, kernel, util, gen
from scem.datasets import *
import matplotlib.pyplot as plt

In [2]:
# Target distribution, selected by name (load_data presumably comes from the
# star import of scem.datasets -- confirm).
dname = "banana"
p = load_data(
    dname,
    D=2,
    noise_std=0.0,
    seed=0,
    itanh=False,
    whiten=False,
)

# Two independent draws from the target: 1000 points used for training and a
# second sample of 100 points.
# NOTE(review): `x_eval` is reassigned by the grid cell further down and is not
# used in between.
x = p.sample(1000)
x_eval = p.sample(100)

In [4]:
import torch 
import torch.nn as nn
from torch.nn.utils import spectral_norm
import numpy as np
import torch.distributions as td
torch.random.manual_seed(13)

class EBM(nn.Module):
    """Single-hidden-layer network producing one scalar per (x, z) pair.

    The observed and latent inputs are concatenated along the last axis,
    passed through a ReLU hidden layer of width ``Dh`` and mapped to a single
    output unit by a final linear layer.
    """

    def __init__(self, Dx, Dz, Dh):
        """Build the two linear layers.

        Args:
            Dx: size of the last axis of the observed input ``X``.
            Dz: size of the last axis of the latent input ``Z``.
            Dh: number of hidden units.
        """
        super().__init__()
        self.layer_1 = nn.Linear(Dz + Dx, Dh)
        self.layer_2 = nn.Linear(Dh, 1)
        # Only needed for the ELU variant; forward() currently applies ReLU.
        self.elu = nn.ELU()

    def forward(self, X, Z):
        """Return the network output for every row of ``(X, Z)``."""
        joint_input = torch.cat((X, Z), dim=-1)
        hidden = self.layer_1(joint_input).relu()
        out = self.layer_2(hidden)
        # Column 0 is the single output unit.
        return out[:, 0]
    

# Model dimensions: observed size, latent size and hidden-layer width.
Dx, Dz, Dh = 2, 2, 100

# Wrap the network in the scem adapter, declaring both the observed and the
# latent variables as continuous.
lebm = ebm.LatentEBMAdapter(
    EBM(Dx, Dz, Dh),
    var_type_obs='continuous',
    var_type_latent='continuous',
)

In [5]:
# Shapes handed to gen.CSNoiseTransformerAdapter; the first entry is the shape
# of one noise vector.
# NOTE(review): the second entry is presumably the output shape; (100,) does
# not obviously match the Dz-dimensional output of Net -- confirm how the
# adapter uses it.
in_out_shapes = ((10,), (100,))
in_shape = in_out_shapes[0]


def noise_sampler(n_sample, n, seed=1):
    """Draw standard-normal noise of shape ``(n_sample, n, in_shape[0])``.

    The draw is made inside ``util.TorchSeedContext(seed)``.

    Args:
        n_sample: size of the leading axis of the returned tensor.
        n: size of the second axis of the returned tensor.
        seed: value passed to ``util.TorchSeedContext``.

    Returns:
        The tensor produced by ``torch.randn``.
    """
    noise_dim = in_shape[0]
    with util.TorchSeedContext(seed):
        return torch.randn(n_sample, n, noise_dim)

dh = 100  # default hidden width of Net


class Net(nn.Module):
    """Two-layer ReLU network mapping ``(noise, X)`` to latent samples.

    ``X`` is repeated once per noise sample, concatenated with the noise along
    the last axis and passed through ``Linear -> ReLU -> Linear``.

    Args:
        dx: size of the last axis of ``X``; defaults to the notebook-level ``Dx``.
        dnoise: size of the last axis of the noise; defaults to ``in_shape[0]``.
        dhidden: number of hidden units; defaults to the notebook-level ``dh``.
        dz: size of the last axis of the output; defaults to the notebook-level ``Dz``.
    """

    def __init__(self, dx=None, dnoise=None, dhidden=None, dz=None):
        super().__init__()
        # Fall back to the notebook-level settings so that a bare `Net()`
        # builds exactly the same layers as before.
        dx = Dx if dx is None else dx
        dnoise = in_shape[0] if dnoise is None else dnoise
        dhidden = dh if dhidden is None else dhidden
        dz = Dz if dz is None else dz

        self.lin1 = nn.Linear(dx + dnoise, dhidden)
        self.lin2 = nn.Linear(dhidden, dz)
        self.module = nn.Sequential(
            self.lin1,
            nn.ReLU(),
            self.lin2,
        )

    def forward(self, noise, X):
        """Transform ``noise`` conditioned on ``X``.

        Args:
            noise: tensor whose leading axis indexes noise samples and whose
                remaining leading axes match those of ``X``.
            X: conditioning inputs; repeated along a new leading axis of size
                ``noise.shape[0]``.

        Returns:
            Output of the network, with the same leading axes as ``noise``.
        """
        n_sample = noise.shape[0]
        # Broadcast X over the sample axis as a view rather than stacking
        # n_sample Python-level copies.
        X_ = X.unsqueeze(0).expand(n_sample, *X.shape)
        Y = torch.cat([X_, noise], dim=-1)
        # Call the module (not .forward) so hooks registered on it still run.
        return self.module(Y)


In [6]:
def weight_reset(m):
    """Re-initialise ``m`` in place if it is a ``Conv2d`` or ``Linear`` layer.

    Any other module type is left untouched, which makes the function
    suitable for ``module.apply(weight_reset)``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    m.reset_parameters()

In [7]:
# Training sample `x` as a float32 torch tensor; this is the data used by the
# kernel set-up and the training loop below.
X = torch.as_tensor(x, dtype=torch.float32)

In [8]:
# Kernel on the observed variables.

# Squared value of util.pt_meddistance(X) (presumably the median pairwise
# distance -- confirm in scem.util).  It is only needed by the Gaussian-kernel
# alternative, e.g. kernel.KGauss(torch.tensor(med2)); the IMQ kernel used
# here takes fixed hyper-parameters.
med2 = util.pt_meddistance(X) ** 2
kx = kernel.KIMQ(b=-0.5, c=1, s2=1.)

In [9]:
# q(z|x): conditional sampler built from Net and noise_sampler.
# (Alternative: gen.CSFactorisedGaussian(Dx, Dz, Dh).)
cs = gen.CSNoiseTransformerAdapter(Net(), noise_sampler, in_out_shapes)

# Optimiser settings for q(z|x).
learning_rate_q = 1e-2
weight_decay_q = 0  # alternative value: 1e-3
optimizer_q = torch.optim.Adam(
    cs.parameters(),
    lr=learning_rate_q,
    weight_decay=weight_decay_q,
)

# Approximate score built from the model's joint score and the sampler; it is
# configured to use 500 samples.
approx_score = stein.ApproximateScore(lebm.score_joint_obs, cs)
approx_score.n_sample = 500

# Optimiser settings for p(x).
learning_rate_p = 1e-1  # !!!
weight_decay_p = 0.
optimizer_p = torch.optim.Adam(
    lebm.parameters(),
    lr=learning_rate_p,
    weight_decay=weight_decay_p,
)

In [10]:
iter_p = 500  # number of outer iterations (one optimizer_p step each)
iter_q = 10  # inner_loop steps on q(z|x) before every optimizer_p step
batch_size = 100  # number of rows of X drawn for each mini-batch

In [11]:
def inner_loop(niter, X, cs, opt):
    """Run ``niter`` optimiser steps on the conditional sampler ``cs``.

    Each step draws one latent sample from ``cs`` given ``X``, evaluates
    ``stein.kcsd_ustat`` on ``(X, Z)`` and takes one step of ``opt`` on that
    loss.  The function reads the notebook-level ``lebm`` (for
    ``score_joint_latent``) and ``kx`` (kernel on the observations).

    Args:
        niter: number of optimisation steps to run.
        X: batch of observations passed to ``cs.sample``.
        cs: conditional sampler exposing ``sample(n_sample, X)``.
        opt: optimiser to step; presumably built over ``cs.parameters()`` --
            the caller decides.

    Returns:
        None.
    """
    # The latent kernel has fixed hyper-parameters, so it is built once
    # rather than on every iteration.
    kz = kernel.KIMQ(b=-0.5, c=1, s2=1.)
    for _ in range(niter):
        # One sample per observation; squeeze(0) drops the leading sample
        # axis (presumably of size 1 -- confirm against cs.sample).
        Z = cs.sample(1, X).squeeze(0)
        loss = stein.kcsd_ustat(
            X, Z, lebm.score_joint_latent, kx, kz)
        opt.zero_grad()
        loss.backward()
        opt.step()
    

In [None]:
# Alternating optimisation: a few inner_loop steps on q(z|x), then one
# optimizer_p step on the loss computed with approx_score.
losses = []

for t in range(iter_p):
    # q(z|x) is warm-started across outer iterations.  To restart it each time
    # instead, call cs.apply(weight_reset) and rebuild optimizer_q here.

    # Fresh random mini-batch of observations.
    idx = torch.randperm(X.shape[0])[:batch_size]
    X_ = X[idx].detach()

    # Inner problem: update the sampler on this mini-batch.
    inner_loop(iter_q, X_, cs, optimizer_q)

    # Outer objective on the same mini-batch.
    loss = stein.ksd_ustat(X_, approx_score, kx)
    losses.append(loss.item())

    # Every 100 iterations also report the objective on the full data set.
    if t % 100 == 0:
        loss_ = stein.ksd_ustat(X, approx_score, kx).item()
        print(loss.item(), loss_)

    # zero_grad comes right before backward so that only this loss
    # contributes to the step taken by optimizer_p.
    optimizer_p.zero_grad()
    loss.backward()
    optimizer_p.step()

0.10812857747077942 0.10262227803468704
0.03973084315657616 0.037283483892679214


In [None]:
# Mini-batch loss recorded at every outer iteration of the training loop.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel("outer iteration")
ax.set_ylabel("mini-batch KSD")
ax.set_title("Training loss");

In [None]:
# form a grid for numerical normalisation
from itertools import product  # still needed by the true-log-density cell

ngrid = 50
grid = torch.linspace(-10, 10, ngrid)
# All ngrid**4 combinations of grid values, one row per combination, in the
# same row order as itertools.product(grid, grid, grid, grid) (last column
# varies fastest).  torch.cartesian_prod builds the tensor directly instead
# of materialising millions of Python tuples of 0-d tensors.
xz_eval = torch.cartesian_prod(grid, grid, grid, grid)
# First two columns are the x coordinates, last two the z coordinates.
x_eval = xz_eval[:, :2]
z_eval = xz_eval[:, 2:]

In [None]:
# true log density
# Evaluate the target on every pair of grid values (ngrid * ngrid points),
# then shift in place so that the maximum is 0.
x_grid = torch.tensor(list(product(grid, repeat=2)))
E_true = p.logpdf_multiple(x_grid)
E_true -= E_true.max()

In [None]:
# EBM log density
# Marginal over the two latent grid axes, in log space:
#   log p(x) = logsumexp_z lebm(x, z) + const.
# The constant (the normaliser over the whole grid) cancels against the
# max-shift below, so it is not computed.  logsumexp avoids the overflow /
# underflow that exp -> sum -> log can produce, and no_grad stops autograd
# from recording the forward pass over all ngrid**4 grid points.
with torch.no_grad():
    log_joint = lebm(x_eval, z_eval).reshape(ngrid, ngrid, ngrid, ngrid)
    E_eval = torch.logsumexp(log_joint, dim=(-2, -1))
# Shift so that the maximum is 0, matching the true log density.
E_eval -= E_eval.max()

In [None]:
def normalise(E):
    """Exponentiate ``E`` and rescale the result so that it sums to one.

    Args:
        E: NumPy array or torch tensor of (unnormalised) log values.

    Returns:
        A new array/tensor ``exp(E) / exp(E).sum()``; ``E`` is not modified.
    """
    weights = np.exp(E) if isinstance(E, np.ndarray) else E.exp()
    return weights / weights.sum()

In [None]:
# Compare the true density (left column) with the fitted model (right column):
# shifted log densities on the top row, normalised densities on the bottom row.
fig, axes = plt.subplots(2, 2, figsize=(6, 6), sharex=True, sharey=True)

ax = axes[0, 0]
ax.pcolor(grid, grid, E_true.reshape(ngrid, ngrid), shading='auto', vmin=-10, vmax=0)
# Rows of the grid arrays index the first coordinate and columns the second,
# so the samples are drawn with x[:, 1] on the horizontal axis.
ax.scatter(x[:, 1], x[:, 0], c="r", s=1, alpha=0.05)

ax = axes[1, 0]
ax.pcolor(grid, grid, normalise(E_true).reshape(ngrid, ngrid), shading='auto')

ax = axes[0, 1]
ax.pcolor(grid, grid, E_eval, shading='auto', vmin=-10, vmax=0)
ax.scatter(x[:, 1], x[:, 0], c="r", s=1, alpha=0.05)

ax = axes[1, 1]
ax.pcolor(grid, grid, normalise(E_eval), shading='auto')

# Top row shows log densities, bottom row the normalised densities.
axes[0, 0].set_ylabel("logp")
axes[1, 0].set_ylabel("p")

axes[0, 0].set_title("data")
axes[0, 1].set_title("KSD")

# The x axis is shared, so setting the limits once applies to every panel.
axes[0, 0].set_xlim(-10, 10);