In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from scem import gen, kernel, ebm, net
from scem import util
import scem.loss as scem_loss

In [None]:
import torch
import torch.nn as nn
import torch.distributions as dists

In [None]:
import matplotlib.pyplot as plt

In [None]:
from matplotlib import rcParams

# Global figure styling: thicker lines and larger markers for every plot.
plt.rc('lines', linewidth=2, markersize=10)
# Leave matplotlib's automatic layout adjustment switched off.
rcParams['figure.autolayout'] = False

In [None]:
# Run the whole notebook in float64 and fix torch's global RNG so that
# parameter initialisation is repeatable after a kernel restart.
torch.set_default_dtype(torch.float64)
torch.manual_seed(13)

In [None]:
class CategoricalMixture(nn.Module):
    """Noise-driven sampler of (relaxed) one-hot vectors.

    Standard-normal noise of width ``dnoise`` is mapped by
    ``net.TwoLayerFC(dnoise, dh1, dh2, dout)`` to a feature vector, rectified,
    and converted to logits by ``net.MultipleLinear(dout, n_classes,
    n_logits, bias=True)``. In training mode samples come from
    ``RelaxedOneHotCategorical`` (reparameterised, so gradients reach the
    network); in eval mode they come from ``OneHotCategorical``.

    Args:
        dh1, dh2: hidden widths handed to ``net.TwoLayerFC``.
        dout: width of the feature vector produced by ``forward``.
        dnoise: width of the input noise.
        n_classes: handed to ``net.MultipleLinear``.
        n_logits: handed to ``net.MultipleLinear``.
        temperature: temperature of the relaxed distribution.

    NOTE(review): the ``seed`` arguments of ``sample`` and ``sample_noise``
    are accepted but never used; noise is always drawn from torch's global
    RNG. Confirm whether seeding was intended here.
    """

    def __init__(self, dh1, dh2, dout, dnoise,
                 n_classes, n_logits, temperature=1.):
        super().__init__()
        self.dnoise = dnoise
        self.dout = dout
        self.n_classes = n_classes
        self.n_logits = n_logits
        self.temperature = temperature
        # Sub-modules are created in this order (feature net, then logit
        # heads) so that parameter initialisation consumes the RNG the same
        # way on every run.
        self.feat = net.TwoLayerFC(dnoise, dh1, dh2, dout)
        self.mlinear = net.MultipleLinear(
            dout, n_classes, n_logits, bias=True)

    def forward(self, noise):
        """Return the feature-network output for ``noise`` (not a sample)."""
        return self.feat(noise)

    def sample_noise(self, n_sample, seed=14):
        """Draw an ``(n_sample, dnoise)`` standard-normal tensor.

        ``seed`` is currently ignored.
        """
        return torch.randn(n_sample, self.dnoise)

    def in_out_shapes(self):
        """Return ``((dnoise,), dout)``: the noise shape and feature width."""
        noise_shape = (self.dnoise,)
        return (noise_shape, self.dout)

    def sample(self, n_sample, seed=13):
        """Draw ``n_sample`` samples, one per noise vector.

        Returns a relaxed (differentiable) sample while ``self.training`` is
        true and an exact one-hot sample otherwise. ``seed`` is only
        forwarded to ``sample_noise``, which ignores it.
        """
        hidden = self.forward(self.sample_noise(n_sample, seed)).relu()
        logits = self.mlinear(hidden)
        if not self.training:
            return dists.OneHotCategorical(logits=logits).sample()
        relaxed = dists.RelaxedOneHotCategorical(
            self.temperature,
            logits=logits,
        )
        return relaxed.rsample()

In [None]:
class Categorical(nn.Module):
    """Sampler with one learnable ``(n_logits, n_classes)`` logit table.

    There is no noise input: every draw uses the same logits. In training
    mode samples are relaxed one-hot vectors (reparameterised); in eval mode
    they are exact one-hot vectors.

    Args:
        n_classes: size of the last axis of the logit table.
        n_logits: number of rows of the logit table.
        temperature: temperature of the relaxed distribution.
    """

    def __init__(self, n_classes, n_logits, temperature=1.):
        super().__init__()
        self.n_classes = n_classes
        self.n_logits = n_logits
        self.temperature = temperature
        # Allocate uninitialised storage, then fill it with N(0, 1) draws.
        self.logits = nn.Parameter(torch.empty(n_logits, n_classes))
        nn.init.normal_(self.logits)

    def sample(self, n_sample, seed=13):
        """Draw ``n_sample`` samples inside ``util.TorchSeedContext(seed)``.

        The result has a leading axis of length ``n_sample`` followed by the
        shape of the logit table.
        """
        table = self.logits
        with util.TorchSeedContext(seed):
            if not self.training:
                return dists.OneHotCategorical(logits=table).sample([n_sample])
            relaxed = dists.RelaxedOneHotCategorical(
                self.temperature,
                logits=table,
            )
            return relaxed.rsample([n_sample])

In [None]:
class NeuralMachine(ebm.LatentEBM):
    """Neural energy function over one-hot observations.

    Each one-hot row of ``X`` is embedded with the learnable matrix ``W``
    (``emb_d x n_classes``); the embeddings of all ``din`` positions are
    concatenated and passed through a three-layer MLP ending in ``tanh``.
    The latent argument ``Z`` is accepted but not used.

    Args:
        din: number of one-hot positions per sample.
        emb_d: embedding width per position.
        n_classes: number of categories per position.
        d1, d2: hidden widths of the MLP.
    """

    # NOTE(review): presumably required by ebm.LatentEBM; confirm.
    var_type_latent = None
    var_type_obs = None

    def __init__(self, din, emb_d, n_classes, d1=10, d2=10):
        super().__init__()
        self.din = din
        self.emb_d = emb_d
        self.n_classes = n_classes
        # The attribute was originally misspelled; keep the old name as an
        # alias so existing readers of ``n_classs`` do not break.
        self.n_classs = n_classes

        # Allocate uninitialised storage, then fill it with N(0, 1) draws.
        self.W = nn.Parameter(torch.empty(emb_d, n_classes))
        nn.init.normal_(self.W)
        self.lin1 = nn.Linear(emb_d*din, d1)
        self.lin2 = nn.Linear(d1, d2)
        self.lin3 = nn.Linear(d2, 1)

    def forward(self, X, Z):
        """Return one value in (-1, 1) per sample.

        Args:
            X: tensor of shape ``(n, din, n_classes)``.
            Z: ignored.

        Returns:
            Tensor of shape ``(n,)``.
        """
        # Embed every position: (n, din, n_classes) -> (n, din, emb_d).
        Y = torch.einsum('ijk, dk->ijd', X, self.W)
        Y = Y.reshape(Y.shape[0], -1)
        Y = self.lin1(Y).relu()
        Y = self.lin2(Y).relu()
        # Drop only the trailing unit axis so that a batch of one sample
        # stays 1-D instead of collapsing to a 0-d scalar.
        return self.lin3(Y).tanh().squeeze(-1)

    def score_marginal_obs(self, X):
        """Return ``exp(D) - 1`` for ``D = util.forward_diff_onehot(...)``.

        ``D`` is computed by ``util.forward_diff_onehot`` on ``self.forward``
        with respect to input 0 (``X``); ``Z`` is passed as ``None``.
        NOTE(review): presumably ``D`` holds differences of ``forward``
        between ``X`` and its one-hot perturbations; confirm against
        ``util.forward_diff_onehot``.
        """
        D = util.forward_diff_onehot(self.forward, 0,
                                     [X, None])
        # expm1 keeps precision when D is close to zero.
        return torch.expm1(D)

In [None]:
# Number of one-hot positions per sample; reused below as `din` for the model,
# `n_logits` for both samplers, and inside the kernel constructor.
dx = 50

In [None]:
# Target model: binary positions (n_classes=2), 3-d embedding per position.
# Its score_marginal_obs is what the KSD losses below are built on.
blzm = NeuralMachine(din=dx, emb_d=3, n_classes=2)

In [None]:
# Noise-driven sampler; sizes are spelled out by keyword for readability.
cm = CategoricalMixture(
    dh1=100,
    dh2=10,
    dout=10,
    dnoise=30,
    n_classes=2,
    n_logits=dx,
    temperature=1.,
)

In [None]:
# Baseline sampler: a single learnable logit table with no noise input,
# using the default temperature of 1.
c = Categorical(n_classes=2, n_logits=dx)

In [None]:
# Kernel on one-hot data used by both KSD losses.
# NOTE(review): torch.tensor(dx) is an integer (int64) tensor; confirm that
# OHKGauss expects an integer for its second argument rather than a float.
k = kernel.OHKGauss(2, torch.tensor(dx))

In [None]:
# Both losses share the kernel and the target model's marginal score.
# `ksd` is applied to a whole sample (baseline training and all evaluation);
# `iksd` is applied to two index-selected sub-batches (mixture training).
ksd = scem_loss.KSD(k, blzm.score_marginal_obs)
iksd = scem_loss.IncompleteKSD(k, blzm.score_marginal_obs)

In [None]:
# Budget shared by both training loops.
niter = 300        # optimisation steps per sampler
n = 300            # samples drawn per training step
test_n = 300       # samples drawn per evaluation
batch_size = 100   # passed to util.sample_incomplete_ustat_batch

# One Adam optimiser per sampler, same settings for a like-for-like comparison.
opt_cm = torch.optim.Adam(cm.parameters(), lr=1e-3, weight_decay=0.)
opt_c = torch.optim.Adam(c.parameters(), lr=1e-3, weight_decay=0.)

In [None]:
def get_minibatch(batch_size, X, detach=True):
    """Return up to ``batch_size`` rows of ``X`` chosen without replacement.

    Rows are picked through a random permutation of the first axis, so when
    ``batch_size`` exceeds ``X.shape[0]`` every row is returned (shuffled).

    Args:
        batch_size: maximum number of rows to return.
        X: tensor indexed along its first axis.
        detach: when true, the result is detached from the autograd graph.
    """
    chosen = torch.randperm(X.shape[0])[:batch_size]
    batch = X[chosen]
    return batch.detach() if detach else batch

In [None]:
# Train the noise-driven sampler and record its test KSD after every step.
losses = []
for i in range(niter):
    # Training step: IncompleteKSD loss on the two index sets returned by
    # util.sample_incomplete_ustat_batch.
    X = cm.sample(n, seed=0)
    i1, i2 = util.sample_incomplete_ustat_batch(n, batch_size)
    loss = iksd.loss(X[i1], X[i2])
    opt_cm.zero_grad()
    loss.backward()
    opt_cm.step()

    # Evaluation: switch to eval mode so samples are exact one-hot vectors,
    # score them with the full KSD, then restore training mode.
    cm.eval()
    X_ = cm.sample(test_n, seed=i)
    test_loss = ksd.loss(X_)
    cm.train()
    print(test_loss.item())
    losses.append(test_loss.item())

In [None]:
# Test-KSD trajectory of the noise-driven sampler, one value per iteration.
plt.plot(losses)

In [None]:
# Train the baseline sampler and record its test KSD after every step.
c_losses = []
for i in range(niter):
    # Training step: full KSD on all n relaxed samples (fixed seed 0).
    X = c.sample(n, seed=0)
    loss = ksd.loss(X)
    opt_c.zero_grad()
    loss.backward()
    opt_c.step()

    # Evaluation: exact one-hot samples drawn with a per-iteration seed.
    c.eval()
    X_ = c.sample(test_n, seed=i)
    test_loss = ksd.loss(X_)
    c.train()
    print(test_loss.item())
    c_losses.append(test_loss.item())

In [None]:
import numpy as np

# Take magnitudes so the curves can be drawn on a logarithmic axis.
losses_ = list(map(np.abs, losses))
c_losses_ = list(map(np.abs, c_losses))

In [None]:
# Compare both samplers' |test KSD| per iteration on a log-scaled y axis.
ax = plt.gca()
ax.plot(c_losses_, label='independent')
ax.plot(losses_, label='mixture')
ax.set_xlabel('iter')
ax.set_ylabel('test KSD')
ax.set_yscale('log')

ax.legend()