# Comparaison des distances inter-distributions
Dans ce notebook, on compare selon plusieurs critères les performances de plusieurs distances entre distributions (KLD, MMD-RBF, MMD-IMQ).

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from math import *

def KLD(var, mu):
    """Closed-form KL divergence KL(N(mu, var) || N(0, 1)).

    Evaluated element-wise with NumPy broadcasting, so ``var`` and ``mu`` may
    be scalars or arrays of broadcast-compatible shapes. Note the argument
    order: variance first, then mean.

    Parameters
    ----------
    var : variance(s) of the Gaussian; ``np.log`` requires var > 0.
    mu : mean(s) of the Gaussian.

    Returns
    -------
    0.5 * (var - 1 + mu**2 - log(var)), broadcast over the inputs.
    """
    log_var = np.log(var)
    # Keep the original left-to-right summation order so rounding is unchanged.
    return 0.5 * (var - 1 + mu ** 2 - log_var)

In [None]:
# Means as a row vector (1, 6): 0, 2, ..., 10 -- one curve per mean.
mu = np.linspace(0, 10, num=6)[np.newaxis, :]
# Variances as a column vector (1000, 1) so it broadcasts against `mu`.
var = np.linspace(0.1, 15, num=1000)[:, np.newaxis]

In [None]:
def KL_bench(var, mu):
    """Plot KLD(var, mu) against the variance, one curve per mean.

    Draws on the current matplotlib axes and does not call ``plt.show()``,
    so it can be used inside a subplot.

    Parameters
    ----------
    var : array of variances, expected as a column vector (k, 1) so that it
        broadcasts against ``mu``.
    mu : array of means, e.g. a row vector (1, m); each column of the
        broadcast result is drawn as one curve.

    Returns
    -------
    The array of KL values that was plotted (``var`` broadcast with ``mu``).
    """
    # Distinct local name: the original reused `KL_bench`, shadowing this function.
    kl_values = KLD(var, mu)
    plt.plot(var, kl_values)
    # Without a legend the curves for the different means are indistinguishable.
    plt.legend([f"mu={m:g}" for m in np.ravel(mu)])
    plt.xlabel("var")
    plt.ylabel("KLD")
    return kl_values

In [None]:
def compute_kernel(x, y):
    """Gaussian (RBF) kernel matrix between two batches of points.

    Parameters
    ----------
    x : tensor of shape (x_size, dim), one point per row.
    y : tensor of shape (y_size, dim), with the same ``dim`` as ``x``.

    Returns
    -------
    Tensor of shape (x_size, y_size) whose (i, j) entry is
    exp(-mean_d((x_i - y_j)_d ** 2) / dim), i.e. exp(-||x_i - y_j||^2 / dim**2).

    Note: the squared difference is averaged over ``dim`` and then divided by
    ``dim`` once more, so the effective bandwidth scales with dim**2. This is
    the formula as written and is kept unchanged.
    """
    x_rows, y_rows, n_features = x.size(0), y.size(0), x.size(1)
    grid = (x_rows, y_rows, n_features)
    # Lay every (x_i, y_j) pair out on a (x_size, y_size, dim) grid; expand()
    # returns views, the real allocation happens at the subtraction.
    x_grid = x.unsqueeze(1).expand(*grid)
    y_grid = y.unsqueeze(0).expand(*grid)
    scaled_sq_dist = (x_grid - y_grid).pow(2).mean(2) / float(n_features)
    return torch.exp(-scaled_sq_dist)

def compute_mmd(x, y):
    """Biased (V-statistic) estimate of the squared MMD between two samples.

    Uses ``compute_kernel`` as the kernel k and averages over the full kernel
    matrices (diagonal included):
        mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)

    Parameters
    ----------
    x : tensor (n_x, dim), one sample point per row.
    y : tensor (n_y, dim), one sample point per row.

    Returns
    -------
    0-d tensor holding the estimate.
    """
    within_x = compute_kernel(x, x).mean()
    within_y = compute_kernel(y, y).mean()
    cross = compute_kernel(x, y).mean()
    return within_x + within_y - 2 * cross

def RBF_bench(var, mu, n):
    """Plot a sample-based RBF-kernel MMD between N(mu, var) and N(0, 1).

    For every (mu_i, var_j) pair, ``n`` points are drawn from N(mu_i, var_j)
    and ``n`` from N(0, 1), and ``compute_mmd`` is evaluated on the two
    samples. One curve per mean is drawn against the variance, then
    ``plt.show()`` is called.

    The samples are shaped (n, 1), i.e. n one-dimensional points. They were
    previously shaped (1, n), a single point in an n-dimensional space. With
    that shape the "MMD" was a kernel distance between two random vectors,
    not a distance between the two distributions.

    Cost: the kernel matrices are (n, n), so each grid point costs O(n^2)
    time and memory. Keep n moderate (a few thousand at most) on large grids.

    Parameters
    ----------
    var : numpy array of variances, expected as a column vector (k, 1).
    mu : numpy array of means, expected as a row vector (1, m).
    n : number of samples drawn from each distribution per grid point.

    Returns
    -------
    numpy array of shape (m, k) with the MMD estimates. The original returned
    None; returning the data is backward-compatible.
    """
    var = torch.from_numpy(var).float().squeeze(1)
    mu = torch.from_numpy(mu).float().squeeze(0)

    mmd = torch.zeros([len(mu), len(var)])
    for i, mean in enumerate(mu):
        for j, variance in enumerate(var):
            # (n, 1): n points of dimension 1. A (1, n) tensor would be ONE
            # point of dimension n, which makes the MMD meaningless here.
            sample = (torch.randn(n) * torch.sqrt(variance) + mean).unsqueeze(1)
            prior = torch.randn(n).unsqueeze(1)
            mmd[i, j] = compute_mmd(sample, prior)

    mmd = mmd.numpy()
    plt.plot(var.numpy(), mmd.T)
    plt.legend([f"mu={m:g}" for m in mu.tolist()])
    plt.xlabel("var")
    plt.ylabel("MMDRBF")
    plt.show()
    return mmd

In [None]:
# Side-by-side comparison: analytic KLD (left) vs. sample-based MMD-RBF (right).
# NOTE(review): if RBF_bench builds (n, n) kernel matrices, n=10000 over this
# 6 x 1000 grid is very expensive -- confirm the intended sample size.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
KL_bench(var, mu)
plt.subplot(1, 2, 2)
RBF_bench(var, mu, n=10000)