# Maximum Mean Discrepancy (MMD)

## Gram matrix
Given a set of vectors $a_1, a_2, \dots, a_n$, the Gram matrix collects all of their pairwise inner products and is computed as:

\begin{equation}
G=A^TA=
\begin{bmatrix}
a_1^T\\
a_2^T\\
\vdots\\
a_n^T
\end{bmatrix}
\begin{bmatrix}
a_1 & a_2 & \cdots & a_n
\end{bmatrix}=
\begin{bmatrix}
a_1^Ta_1 & a_1^Ta_2 & \cdots & a_1^Ta_n \\
a_2^Ta_1 & a_2^Ta_2 & \cdots & a_2^Ta_n \\
\vdots  & \vdots  & \ddots & \vdots  \\
a_n^Ta_1 & a_n^Ta_2 & \cdots & a_n^Ta_n 
\end{bmatrix},
\end{equation}

where $A$ is a matrix whose columns are the vectors $a_k$.
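
A short NumPy sketch of the definition, using three arbitrary 2-D vectors as the columns of `A`: the matrix product `A.T @ A` and the entry-by-entry inner products $a_i^Ta_j$ give the same $G$, and $G$ is symmetric.

```python
import numpy as np

# Columns of A are the vectors a_1 = (1, 0), a_2 = (0, 1), a_3 = (2, 1)
A = np.array([[1.0, 0.0, 2.0],
              [0.0, 1.0, 1.0]])

G = A.T @ A  # Gram matrix G = A^T A

# The same matrix built entry by entry: G[i, j] = a_i^T a_j
n = A.shape[1]
G_loop = np.array([[A[:, i] @ A[:, j] for j in range(n)] for i in range(n)])

print(G)
# [[1. 0. 2.]
#  [0. 1. 1.]
#  [2. 1. 5.]]
print(np.allclose(G, G_loop), np.allclose(G, G.T))  # True True
```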

## Hermitian matrix

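A square complex matrix $A$ is Hermitian if it equals its own conjugate transpose, entry by entry:

$a_{ij}=\overline{a_{ji}}$, or equivalently $A = \overline{A^T} = A^H$,

where $A^H$ denotes the conjugate transpose. A Hermitian matrix is the complex extension of a real symmetric matrix: for real entries the conjugation has no effect and the condition reduces to $A = A^T$.

For complex vectors, the transpose in the Gram matrix is replaced by the conjugate transpose, $G=A^HA$, so $G_{ij} = a_i^Ha_j$ and $G$ is Hermitian.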
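
A minimal NumPy sketch with two arbitrary complex vectors as the columns of `A`: the complex Gram matrix `A.conj().T @ A` equals its own conjugate transpose and has the real squared norms on its diagonal, while the plain transpose does not give a squared norm.

```python
import numpy as np

# Columns of A are the complex vectors a_1 = (1+i, i) and a_2 = (2, 1-i)
A = np.array([[1 + 1j, 2 + 0j],
              [0 + 1j, 1 - 1j]])

G = A.conj().T @ A  # complex Gram matrix G = A^H A

# G equals its conjugate transpose, so it is Hermitian
print(np.allclose(G, G.conj().T))  # True
# Diagonal entries are the squared norms ||a_i||^2, hence real: 3 and 6
print(G.diagonal().real)  # [3. 6.]

# With the plain transpose, a_1^T a_1 = (1+i)^2 + i^2 = -1+2i, not a squared norm
print((A.T @ A)[0, 0])  # (-1+2j)
```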


## Kernel function

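A kernel function returns the inner product of two inputs after they are mapped into a feature space by a feature map $\Phi$:

$K(\pmb x_i, \pmb x_j) = \langle\Phi(\pmb x_i), \Phi(\pmb x_j)\rangle$

For many kernels, $K$ can be evaluated directly from $\pmb x_i$ and $\pmb x_j$ without ever computing $\Phi$ (the kernel trick). Evaluating $K$ on every pair of samples gives the kernel matrix $K_{ij} = K(\pmb x_i, \pmb x_j)$, which is exactly the Gram matrix of the mapped vectors $\Phi(\pmb x_1), \dots, \Phi(\pmb x_n)$. The code below uses the Gaussian kernel $K(\pmb x, \pmb y) = \exp(-\|\pmb x - \pmb y\|^2 / h)$ with bandwidth $h$, summed over several bandwidths.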
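
To make the definition concrete, here is a minimal sketch with a kernel other than the Gaussian one: the homogeneous degree-2 polynomial kernel $K(\pmb x, \pmb y) = (\pmb x^T\pmb y)^2$ on 2-D inputs, whose explicit feature map is $\Phi(\pmb x) = (x_1^2, \sqrt{2}x_1x_2, x_2^2)$. The helper names `phi` and `poly2_kernel` are hypothetical. The sketch checks that evaluating the kernel directly gives the same matrix as mapping with $\Phi$ first and then forming the Gram matrix.

```python
import numpy as np

def phi(x):
    """Explicit feature map of the degree-2 polynomial kernel for a 2-D input."""
    return np.array([x[0] ** 2, np.sqrt(2) * x[0] * x[1], x[1] ** 2])

def poly2_kernel(x, y):
    """Degree-2 polynomial kernel K(x, y) = (x^T y)^2, computed without phi."""
    return (x @ y) ** 2

# Three 2-D samples, one per row
X = np.array([[1.0, 2.0],
              [3.0, -1.0],
              [0.5, 0.5]])

# Kernel matrix from the kernel alone: K[i, j] = K(x_i, x_j)
K = np.array([[poly2_kernel(xi, xj) for xj in X] for xi in X])

# Gram matrix of the mapped vectors phi(x_i)
Phi = np.array([phi(xi) for xi in X])
G = Phi @ Phi.T

print(np.allclose(K, G))  # True
```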

## Code Example
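
Read off the code, the `mmd` function returns the biased empirical estimate of the squared MMD between source samples $\pmb x_1, \dots, \pmb x_n$ and target samples $\pmb y_1, \dots, \pmb y_m$, where $K$ is the sum of Gaussian kernels computed by `guassian_kernel`. Each term is the mean of one block of the kernel matrix:

\begin{equation}
\widehat{\mathrm{MMD}}^2=
\frac{1}{n^2}\sum_{i=1}^{n}\sum_{i'=1}^{n}K(\pmb x_i,\pmb x_{i'})
+\frac{1}{m^2}\sum_{j=1}^{m}\sum_{j'=1}^{m}K(\pmb y_j,\pmb y_{j'})
-\frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m}K(\pmb x_i,\pmb y_j).
\end{equation}

Because the kernel matrix is symmetric, the `XY` and `YX` blocks in the code have the same mean, which is where the factor 2 in the last term comes from.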

```python
import numpy as np
import torch
```

```python
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Compute the multi-kernel Gaussian Gram matrix of source and target.

    source: tensor of shape (sample_size_1, feature_size)
    target: tensor of shape (sample_size_2, feature_size)
    kernel_mul: ratio between the bandwidths of consecutive kernels; the
        kernel_num bandwidths form a geometric sequence around the base bandwidth
    kernel_num: number of Gaussian kernels that are summed
    fix_sigma: if given, a fixed base bandwidth used instead of the
        data-driven estimate
    return: a (sample_size_1 + sample_size_2) x (sample_size_1 + sample_size_2)
        matrix with the block form
            [ K_ss  K_st
              K_ts  K_tt ]
    """
    n_samples = int(source.size(0)) + int(target.size(0))
    total = torch.cat([source, target], dim=0)  # stack source and target samples

    # Pairwise squared Euclidean distances ||x_i - x_j||^2 by broadcasting:
    # (1, N, D) - (N, 1, D) -> (N, N, D), then sum over the feature dimension
    L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

    # Bandwidth of each kernel in the multi-kernel sum
    if fix_sigma is not None:
        bandwidth = fix_sigma
    else:
        # Mean squared distance over the off-diagonal pairs (the diagonal is zero);
        # detached so that no gradient flows through the bandwidth
        bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    # Gaussian kernel: exp(-||x - y||^2 / bandwidth)
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                  for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)  # sum the kernels into a single Gram matrix
```

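
The bandwidth schedule is easier to see with numbers. A minimal sketch, assuming a hypothetical base bandwidth of 1.0 in place of the data-driven mean squared distance: with the default `kernel_mul=2.0` and `kernel_num=5`, the base is first divided by `kernel_mul ** (kernel_num // 2)` and then multiplied by successive powers of `kernel_mul`.

```python
kernel_mul, kernel_num = 2.0, 5
base = 1.0  # hypothetical base bandwidth; guassian_kernel uses the mean squared distance

# Same two steps as guassian_kernel: shift down, then scale up geometrically
start = base / kernel_mul ** (kernel_num // 2)
bandwidth_list = [start * kernel_mul ** i for i in range(kernel_num)]

print(bandwidth_list)  # [0.25, 0.5, 1.0, 2.0, 4.0]
```

For odd `kernel_num` the middle bandwidth equals the base, so the kernels cover scales from a quarter to four times the typical squared distance in this configuration.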

```python
def mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Biased empirical estimate of the squared MMD between source and target."""
    n = int(source.size(0))  # number of source samples

    kernels = guassian_kernel(source, target, kernel_mul=kernel_mul,
                              kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:n, :n]  # K_ss: source-source block
    YY = kernels[n:, n:]  # K_tt: target-target block
    XY = kernels[:n, n:]  # K_st: source-target block
    YX = kernels[n:, :n]  # K_ts: target-source block

    # Each .mean() divides a block's sum by its size (n*n, m*m or n*m,
    # where m is the number of target samples)
    loss = XX.mean() + YY.mean() - XY.mean() - YX.mean()
    return loss
```

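
What the block-mean formula measures is clearest with a linear kernel $K(\pmb x, \pmb y) = \pmb x^T\pmb y$, used here only for illustration. Then $\frac{1}{n^2}\sum_{i,i'}\pmb x_i^T\pmb x_{i'} = \|\bar{\pmb x}\|^2$, and likewise for the other blocks, so the estimate collapses to $\|\bar{\pmb x} - \bar{\pmb y}\|^2$, the squared distance between the sample means. The sketch below checks this numerically on synthetic data; `linear_kernel` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(100, 3))  # synthetic source samples (rows)
Y = rng.normal(1.0, 1.0, size=(90, 3))   # synthetic target samples (rows)

def linear_kernel(A, B):
    """Linear kernel matrix K[i, j] = a_i^T b_j (hypothetical helper)."""
    return A @ B.T

# Same block-mean formula as mmd(): mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)
mmd2 = (linear_kernel(X, X).mean() + linear_kernel(Y, Y).mean()
        - linear_kernel(X, Y).mean() - linear_kernel(Y, X).mean())

# For the linear kernel this equals the squared distance between sample means
mean_gap = np.sum((X.mean(axis=0) - Y.mean(axis=0)) ** 2)
print(np.isclose(mmd2, mean_gap))  # True
```

A linear kernel can therefore only detect a difference in means; a Gaussian kernel, as used in `guassian_kernel`, also responds to other differences between the distributions, such as their spread.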
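```python
# 100 source samples from N(0, 10^2) and 90 target samples from N(10, 10^2),
# each with 50 features (float64 tensors, since NumPy returns float64)
data_1 = torch.tensor(np.random.normal(loc=0, scale=10, size=(100, 50)))
data_2 = torch.tensor(np.random.normal(loc=10, scale=10, size=(90, 50)))

mmd(data_1, data_2)
```

Output of one run (no random seed is set, so the exact value differs between runs):

```
tensor(1.0417, dtype=torch.float64)
```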
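
A quick sanity check is to compare two samples from the same distribution against two samples whose means differ. To keep it runnable on its own, `mmd_sq` below is a hypothetical compact restatement of `guassian_kernel` plus `mmd` that uses `torch.cdist` for the pairwise distances; the seed, sample sizes, and shift of 1.0 are arbitrary choices for illustration.

```python
import torch

def mmd_sq(x, y, kernel_mul=2.0, kernel_num=5):
    """Biased multi-kernel Gaussian estimate of MMD^2 (hypothetical compact version)."""
    total = torch.cat([x, y], dim=0)
    n, N = x.size(0), total.size(0)
    d2 = torch.cdist(total, total) ** 2               # pairwise squared distances
    base = d2.detach().sum() / (N * N - N)            # mean off-diagonal squared distance
    start = base / kernel_mul ** (kernel_num // 2)    # smallest bandwidth
    K = sum(torch.exp(-d2 / (start * kernel_mul ** i)) for i in range(kernel_num))
    return (K[:n, :n].mean() + K[n:, n:].mean()
            - K[:n, n:].mean() - K[n:, :n].mean())

torch.manual_seed(0)
a = torch.randn(200, 5)
b_same = torch.randn(200, 5)         # same distribution as a
b_shift = torch.randn(200, 5) + 1.0  # every feature's mean shifted by 1

same = mmd_sq(a, b_same).item()
shifted = mmd_sq(a, b_shift).item()
print(f"same distribution: {same:.4f}")
print(f"shifted mean:      {shifted:.4f}")
```

The first value should be close to zero (the biased estimate is a squared distance, so it does not go meaningfully below zero), while the second should be clearly larger.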