## Experiment with Gaussian mixtures without kernel optimisation

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import torch
import math
import torch.distributions as dists
from scem import loss, util, kernel, net, stein, cpdkernel
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from os.path import dirname, join
import ksdmom.sampler as samp
from collections import namedtuple

In [None]:
# Seed PyTorch's global random number generator; the samplers in this
# notebook draw from it via torch.randn / torch.randperm.
torch.manual_seed(101)

In [None]:
# Root directory for cached results and the sub-directory name of this
# experiment; the two are joined to form the output directory.
results_path = './results/mixture'
problem = 'Gaussian'

In [None]:
# Directory in which per-kernel result arrays are cached.
dir_path = join(results_path, problem)
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(dir_path, exist_ok=True)

In [None]:
# Global figure styling: serif 24pt text rendered through LaTeX, thick lines
# and large markers.
font = {'family': 'serif', 'size': 24}
plt.rc('font', **font)
plt.rc('lines', linewidth=3, markersize=10)

matplotlib.rcParams.update({
    'text.usetex': True,
    # Font type 42 for PDF/PS output (see matplotlib's rcParams reference).
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})



### Distributions

In [None]:
class Normal:
    """Isotropic Gaussian with mean ``m`` and covariance ``s**2 * I``.

    Args:
        m: mean, a 1-D tensor whose length is the dimension.
        s: scalar standard deviation shared by all coordinates.
    """

    def __init__(self, m, s):
        self.m = m
        self.s = s

    def den(self, X):
        """Return the density at each row of ``X`` (shape ``(n, d)``)."""
        m = self.m
        s = self.s
        # The dimension comes from the data rather than from a global ``d``,
        # so the class works outside the notebook's namespace.
        d = X.shape[1]

        den = torch.exp(-torch.sum((X - m)**2, dim=1) / (2 * s**2))
        den = den / (2 * math.pi * s**2)**(d / 2)
        return den

    def log_den(self, X):
        """Return the log-density at each row of ``X`` (shape ``(n, d)``)."""
        m = self.m
        s = self.s
        d = X.shape[1]
        ld = -torch.sum((X - m)**2, dim=1) / (2 * s**2)
        ld = ld - d / 2 * torch.log(torch.as_tensor(2 * math.pi * s**2))
        return ld

    def score(self, X):
        """Return the gradient of the log-density w.r.t. ``X``, row by row."""
        m = self.m
        s = self.s
        return -(X - m) / s**2

    def sample(self, n):
        """Draw ``n`` i.i.d. points; returns a tensor of shape ``(n, len(m))``."""
        m = self.m
        d = len(m)
        return m + self.s * torch.randn(n, d)
    
class MixNormal:
    """Two-component mixture of isotropic Gaussians.

    The density is ``w[0] * N(m1, s1**2 I) + w[1] * N(m2, s2**2 I)`` with
    ``w = mweights``.

    Args:
        m1, m2: component means, 1-D tensors of equal length.
        s1, s2: scalar standard deviations of the two components.
        mweights: the two mixture weights (indexable, e.g. a length-2 tensor).
    """

    def __init__(self, m1, m2, s1, s2, mweights):
        self.m1 = m1
        self.m2 = m2
        self.s1 = s1
        self.s2 = s2
        # Component objects are kept as public attributes for callers.
        self.n1 = Normal(m1, s1)
        self.n2 = Normal(m2, s2)
        self.mweights = mweights

    @staticmethod
    def _component_log_den(X, m, s):
        """Log-density of N(m, s**2 I) at each row of ``X`` (shape ``(n, d)``)."""
        d = X.shape[1]
        sq_dist = torch.sum((X - m)**2, dim=1)
        log_norm = d / 2 * torch.log(torch.as_tensor(2 * math.pi * s**2))
        return -sq_dist / (2 * s**2) - log_norm

    def _weighted_log_dens(self, X):
        """Return ``log(w[k]) + log N_k(x)`` for both components."""
        log_w = torch.log(torch.as_tensor(self.mweights, dtype=X.dtype))
        lw1 = log_w[0] + self._component_log_den(X, self.m1, self.s1)
        lw2 = log_w[1] + self._component_log_den(X, self.m2, self.s2)
        return lw1, lw2

    def score(self, X):
        """Return the gradient of the mixture log-density at each row of ``X``.

        The score is the posterior-weighted average of the component scores.
        """
        lw1, lw2 = self._weighted_log_dens(X)
        # Posterior probability of component 1 as a sigmoid of the log-odds.
        # Unlike 1 / (1 + (w1/w0) * exp(...)), this stays finite when a
        # weight is exactly zero or the exponential under/overflows.
        post_prob1 = torch.sigmoid(lw1 - lw2).unsqueeze(1)
        post_prob2 = 1. - post_prob1

        score1 = -(X - self.m1) / (self.s1**2)
        score2 = -(X - self.m2) / (self.s2**2)
        return post_prob1 * score1 + post_prob2 * score2

    def den(self, X):
        """Return the mixture density at each row of ``X``."""
        return torch.exp(self.log_den(X))

    def log_den(self, X):
        """Return the mixture log-density at each row of ``X``.

        Computed with log-sum-exp so that points far from both means give a
        finite value instead of ``log(0) = -inf``.
        """
        lw1, lw2 = self._weighted_log_dens(X)
        return torch.logsumexp(torch.stack((lw1, lw2)), dim=0)

    def sample(self, n):
        """Draw ``n`` i.i.d. points from the mixture; shape ``(n, len(m1))``."""
        probs = torch.tensor([float(self.mweights[0])])
        # Number of points assigned to the first component.
        n1 = int(torch.distributions.Binomial(n, probs).sample().item())
        n2 = n - n1
        d = len(self.m1)
        X1 = self.s1 * torch.randn(n1, d) + self.m1
        X2 = self.s2 * torch.randn(n2, d) + self.m2
        X = torch.cat([X1, X2])
        # One uniformly random permutation already gives a uniformly random
        # ordering; repeated shuffling adds nothing.
        return X[torch.randperm(n)]

In [None]:
# Problem definition: both distributions share the same two unit-variance
# components in d dimensions and differ only in their mixture weights.
d = 5
m1 = torch.full((d,), -30.)
m2 = torch.full((d,), -10.)
s1 = 1.
s2 = 1.
# Target: equal weights.
target = MixNormal(m1, m2, s1, s2, torch.tensor([0.5, 0.5]))
# Example model with weights (0.2, 0.8); the experiment loop rebinds `model`.
model = MixNormal(m1, m2, s1, s2, torch.tensor([0.2, 0.8]))

### Define kernels

In [None]:
# No explicit location is passed to the linear kernel or the weight function.
loc = None

# Base kernels.
kimq = kernel.KIMQ(b=-0.5)
kmat = kernel.KMatern(scale=1)

# "Tilted" linear kernel: product of a linear kernel and a kernel built from
# a multiquadratic weight function.
klin = kernel.KLinear(scale=1, loc=loc, bias=1)
w = kernel.MultiquadraticWeight(p=-0.5, bias=1, loc=loc)
kw = kernel.KSTWeight(w_func=w)
ktilted_lin = kernel.KSTProduct(klin, kw)

# Each compared kernel is the sum of the tilted linear kernel and a base kernel.
kimq_sum = kernel.KSTSumKernel([ktilted_lin, kimq])
kmat_sum = kernel.KSTSumKernel([ktilted_lin, kmat])

# Kernels included in the experiment, in plotting order. The plain IMQ
# kernel (`kimq`) is constructed above but deliberately left out.
kernels = {
    'IMQ-sum': kimq_sum,
    'Mat-sum': kmat_sum,
}

### Compute KSD

In [None]:
# Experiment configuration.
# Number of sample points used for each KSD evaluation.
n = 500
# Number of independent repetitions per mixture ratio.
rep = 100
# Number of mixture ratios on the grid.
n_ps = 30
# Grid of weights for the first mixture component, from 0 to 0.5 inclusive.
ps = torch.linspace(0.0, 0.5, n_ps)
# When False, cached .npy results are loaded if they exist instead of recomputed.
rerun = False
# When True, n//2 points are drawn from each component and reweighted by the
# mixture ratio instead of sampling from the mixture directly.
separate = False

In [None]:
# data[k, i, j]: KSD value for kernel k, repetition i and mixture ratio ps[j].
data = np.empty([len(kernels), rep, n_ps])
# One KSD loss per kernel, all measured against the score of the target.
losses = {name: loss.KSD(kern, target.score) for name, kern in kernels.items()}
if not separate:
    # No per-point weights are passed when sampling the mixture directly.
    weight = None

for k_idx, (key, ksd) in enumerate(losses.items()):
    filename = f'{key}_separate.npy' if separate else f'{key}.npy'
    print(filename)
    path = os.path.join(dir_path, filename)

    # Reuse cached results unless a rerun was requested.
    if not rerun and os.path.exists(path):
        data[k_idx] = np.load(path)
        continue

    half = n // 2
    for j, p in enumerate(ps):
        model = MixNormal(m1, m2, s1, s2, torch.tensor([p, 1-p]))
        for i in range(rep):
            if separate:
                # Equal-sized draws from each component, reweighted so the
                # weighted sample represents the mixture with ratio p.
                X1 = s1*torch.randn(half, d) + m1
                X2 = s2*torch.randn(half, d) + m2
                X = torch.vstack([X1, X2])
                weight = torch.tensor([p]*half + [1-p]*half)/half
            else:
                X = model.sample(n)
            value = ksd.loss(X, vstat=True, weight=weight, ignore_diag=False)
            data[k_idx, i, j] = value.detach().numpy()
    np.save(path, data[k_idx])

In [None]:
# Line style for one curve.
Format = namedtuple('Format', ['color', 'linestyle'])

# Kernel key -> (legend text, line style).
label_format_tuples = {
    'IMQ': ("IMQ ", Format(color='C1', linestyle='-')),
    'IMQ-sum': (
        'IMQ sum (lin.) $\\theta=0$',
        Format(color='C2', linestyle='--'),
    ),
    'Mat-sum': (
        'Matérn sum (lin.) $\\theta=0.1$',
        Format(color='C7', linestyle='-.'),
    ),
}

In [None]:
# One panel per kernel: violin plots of the KSD values over repetitions,
# placed at each mixture ratio.
# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# indexing below also works when only a single kernel is configured.
fig, axes = plt.subplots(nrows=1, ncols=len(losses), sharey=True,
                         figsize=(8*len(losses), 6), squeeze=False)
axes = axes.ravel()
positions = ps.detach().numpy()

for ki, key in enumerate(losses):
    ax = axes[ki]
    ax.set_yscale('linear')

    # Raw strings keep the LaTeX backslashes without invalid escape sequences.
    ax.set_xlabel(r'Mixture ratio $\pi$', fontsize=24)
    ax.xaxis.set_label_coords(0.5, -0.15)
    if ki == 0:
        # The y-axis is shared, so only the first panel carries the label,
        # set horizontally above the axis.
        ylabel = ax.set_ylabel(r'$\mathrm{KSD}(P, Q_{\pi,N})$', fontsize=24)
        ylabel.set_rotation(0)
        ax.yaxis.set_label_coords(-0.1, 1.05)

    ax.set_xticks([0.1*i for i in range(0, 5+1)])

    fmt = label_format_tuples[key][1]
    # data[ki] has one row per repetition and one column per mixture ratio;
    # each column becomes one violin.
    violin_parts = ax.violinplot(data[ki], positions, showmeans=True,
                                 showextrema=False, widths=0.01)
    plt.setp(violin_parts['bodies'], facecolor=fmt.color, edgecolor='black', alpha=0.3)
    plt.setp(violin_parts['cmeans'], edgecolor=fmt.color)

plt.savefig('ksd_gauss_mixture_imq_vs_matern.pdf', bbox_inches='tight')