# Maximum Mean Discrepancy for Ontology Maintenance
In this notebook, the MMD can be calculated for sentences containing two similar concepts to determine whether or not to merge them in the ontology. Note for users: The MMD score calculation requires 40-45 GB RAM. 

### Calculation can be completed with virtual machine providers such as Google cloud.



In [None]:
# ! gcloud init

Imports

In [1]:
import numpy as np
import os
import pandas as pd
from ast import literal_eval
import json
import torch
import torch.nn as nn

## Example Data Input

DataFrame for Term 1 including embeddings and term *name*:

In [None]:
tachycardia = Data_term1_2[Data_term1_2['allterms'] == 'tachycardia']
tachycardia

Unnamed: 0.1,Unnamed: 0,all_asjc,broad_asjc_exploded,docId,embedding,issn,label,offset,pos,term,allterms
0,0,['2705' '2737'],27,S1547527118309196,"[0.14747348427772522,0.1459769904613495,-0.511...",15475271,1,258,NN,tachycardia,tachycardia
1,1,['2705'],27,S2352906714000700,"[0.1062718853354454,-0.439564049243927,0.05402...",23529067,1,90,NN,tachycardia,tachycardia
2,2,['2711'],27,S0196064418301707,"[-0.2664773464202881,-0.1106829047203064,-0.19...",01960644,1,178,NN,tachycardia,tachycardia
3,3,['2706'],27,S0883944117310675,"[0.038361646234989166,-0.36324775218963623,-0....",08839441,1,362,NN,tachycardia,tachycardia
4,4,['2705' '2737'],27,S1547527114011205,"[0.09001897275447845,0.022551456466317177,-0.3...",15475271,1,225,NN,tachycardia,tachycardia
...,...,...,...,...,...,...,...,...,...,...,...
89995,89995,['2705' '2741'],27,S1936878X10001300,"[0.24832430481910706,-0.04433631896972656,-0.2...",1936878X,1,131,NN,tachycardia,tachycardia
89996,89996,['2705' '2737'],27,S1547527106000671,"[-0.047779105603694916,-0.3508870005607605,-0....",15475271,1,84,NN,tachycardia,tachycardia
89997,89997,['2705'],27,S073510971831934X,"[0.2683626711368561,0.24180784821510315,0.1825...",07351097,1,63,NN,tachycardia,tachycardia
89998,89998,['2705' '2737'],27,S1547527109000022,"[0.09048610180616379,-0.2487306296825409,0.066...",15475271,1,52,NN,tachycardia,tachycardia


DataFrame for Term 2 including embeddings and term *name*:

In [None]:
tachyarrhythmia = Data_term1_2[Data_term1_2['allterms'] == 'tachyarrhythmia']
tachyarrhythmia

Unnamed: 0.1,Unnamed: 0,all_asjc,broad_asjc_exploded,docId,embedding,issn,label,offset,pos,term,allterms
90000,0,['2705' '2737'],27,S0972629216307409,"[0.15877412259578705,0.20141197741031647,0.412...",09726292,1,87,NN,tachyarrhythmia,tachyarrhythmia
90001,1,['2705'],27,S016752731631213X,"[-0.05676533281803131,-0.1182749941945076,0.08...",01675273,1,223,NN,tachyarrhythmia,tachyarrhythmia
90002,2,['2705' '2740' '2746'],27,S0022522317309182,"[-0.25963646173477173,0.3415696322917938,-0.35...",00225223,1,77,NN,tachyarrhythmia,tachyarrhythmia
90003,3,['2705' '2706' '2740'],27,S0012369215513500,"[-0.20454366505146027,-0.14279726147651672,-0....",00123692,1,178,NN,tachyarrhythmia,tachyarrhythmia
90004,4,['2705'],27,S0167527315008529,"[-0.42706137895584106,-0.2416134476661682,-0.2...",01675273,1,180,NN,tachyarrhythmia,tachyarrhythmia
...,...,...,...,...,...,...,...,...,...,...,...
99909,9909,['3004'],30,S0014299907003032,"[-0.2808833718299866,-0.1459818333387375,0.140...",00142999,0,47,NN,tachyarrhythmia,tachyarrhythmia
99910,9910,['3004' '3005'],30,S0041008X05002371,"[0.2768738269805908,-0.05070939660072327,0.059...",0041008X,0,113,NN,tachyarrhythmia,tachyarrhythmia
99911,9911,['3004'],30,S0014299913006626,"[-0.021888669580221176,-0.2529616951942444,0.0...",00142999,0,276,NN,tachyarrhythmia,tachyarrhythmia
99912,9912,['1303' '1304' '1307' '1312'],13,S0006291X06018699,"[-0.2943525016307831,-0.16969314217567444,0.45...",0006291X,0,5,NN,tachyarrhythmia,tachyarrhythmia


## Parsing the data

In [None]:
def safe_parse(x):
    try:
        # return np.array(literal_eval(x))
        return json.loads(x)
    except (SyntaxError, ValueError):
        print(x)
        return np.nan # replace with any suitable placeholder value

Suggested Sample Size for MMD Calculation is 1000

In [None]:
SAMPLE_SIZE = 1000

Example output for parsed data, should include embeddings per concept

In [None]:
tachycardia_sample = tachycardia.sample(n=SAMPLE_SIZE)
tachycardia_sample["embedding"] = tachycardia["embedding"].apply(safe_parse)
tachyarrhythmia_sample = tachyarrhythmia.sample(n=SAMPLE_SIZE)
tachyarrhythmia_sample["embedding"] = tachyarrhythmia["embedding"].apply(safe_parse)
print(tachycardia_sample.head())
print()

       Unnamed: 0         all_asjc  ...         term     allterms
78203       78203         ['2705']  ...  tachycardia  tachycardia
60948       60948  ['2705' '2737']  ...  tachycardia  tachycardia
29040       29040         ['2700']  ...  tachycardia  tachycardia
51006       51006         ['3004']  ...  tachycardia  tachycardia
32650       32650         ['2700']  ...  tachycardia  tachycardia

[5 rows x 11 columns]



# Class for calculating maximum mean discrepancy loss

In [None]:
class MMD_loss(nn.Module):
    def __init__(self, kernel_mul = 2.0, kernel_num = 5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        return
    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        n_samples = int(source.size()[0])+int(target.size()[0])
        total = torch.cat([source, target], dim=0)

        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0-total1)**2).sum(2) 
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        batch_size = int(source.size()[0])
        kernels = self.gaussian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        loss = torch.mean(XX + YY - XY -YX)
        return loss

In [None]:
mdd_loss = MMD_loss()

# Generating vectors with desired size

In [None]:
source_np = np.array([np.array(x) for x in tachycardia_sample["embedding"].to_numpy()])
source = torch.from_numpy(source_np)
print(source.size())

target_np = np.array([np.array(x) for x in tachyarrhythmia_sample["embedding"].to_numpy()])
target = torch.from_numpy(target_np)
print(target.size())

torch.Size([1000, 768])
torch.Size([1000, 768])


Calculating MMD Score

In [None]:
mdd_loss.forward(source=source, target=target)

## Logarithmic transformation

In [None]:
import torch.nn.functional as F
from torch.distributions import log_normal

source_normal = log_normal.LogNormal(source)
print(source_normal)
target_normal = LogNormal(source)
out = F.kl_div(source, target)

TypeError: __init__() missing 1 required positional argument: 'scale'

## Printing out MMD Score

FYI the below score does not reflect the actual score for the example shown above, but instead is a score for a different set of synonyms. 

In [None]:
print(out)

tensor(-0.6618, dtype=torch.float64)
