In [104]:
import re
import os
import json
import pandas as pd
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, RandomSampler
from tqdm import tqdm, trange
import torch.nn.functional as F
import torch.optim as optim

import pickle
import random

In [50]:
def split_dataset(df, p=0.9):
    train_size = int(len(df)*p)
    train = df[:train_size]
    test = df[train_size:]
    train_labels = train.pop('labels')
    test_labels = test.pop('labels')
    return train, test, train_labels, test_labels

## Define simple model to do n-class classification
class NodeClassifier(nn.Module):
    def __init__(self, input_dim, num_labels):
        super(NodeClassifier, self).__init__()
        self.num_labels = num_labels
        self.input_dim = input_dim
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.input_dim, self.num_labels)
    
    def forward(self, inputs, labels=None):
        inputs = self.dropout(inputs)
        logits = self.classifier(inputs)
        
        return logits
    
class NodeClassifier2(nn.Module):
    def __init__(self, input_dim, num_labels):
        super(NodeClassifier2, self).__init__()
        self.num_labels = num_labels
        self.input_dim = input_dim
        self.dropout = nn.Dropout(0.1)
        self.layer1 = nn.Linear(self.input_dim, 64)
        self.classifier = nn.Linear(64, self.num_labels)
        self.activation = nn.ReLU()
    
    def forward(self, inputs, labels=None):
        logits = self.classifier(self.activation(self.layer1(self.dropout(inputs))))
                
        
        return logits

def evaluate(model, dataloader):
    results = {}
    
    preds = None
    out_label_ids = None
    for batch in dataloader:
        model.eval()
        inputs, labels = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(inputs)
        if preds is None:
            preds = logits.detach().cpu().numpy()
            out_label_ids = labels.detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0)
    preds = np.argmax(preds, axis=1)
    result = (preds==out_label_ids).mean()
    results['acc'] = result
    return results
    

In [51]:
## Set up paths
embeddings_dir = '/home/dc925/project/data/embeddings'

snomed2vec_emb_file = os.path.join(embeddings_dir, 'snomed2vec/Node2Vec/snomed2vec.txt')
cui2vec_emb_file = os.path.join(embeddings_dir, 'cui2vec/cui2vec_pretrained.csv')
kge_models = ['TransE', 'DistMult', 'SimplE', 'ComplEx', 'RotatE']
kge_models_paths = {}
for m in kge_models:
    kge_models_paths[m] = os.path.join(embeddings_dir, 'kge/{}.pkl'.format(m))

In [52]:
kge_models_paths;

{'TransE': '/home/dc925/project/data/embeddings/kge/TransE.pkl',
 'DistMult': '/home/dc925/project/data/embeddings/kge/DistMult.pkl',
 'SimplE': '/home/dc925/project/data/embeddings/kge/SimplE.pkl',
 'ComplEx': '/home/dc925/project/data/embeddings/kge/ComplEx.pkl',
 'RotatE': '/home/dc925/project/data/embeddings/kge/RotatE.pkl'}

In [53]:
## Load in mappings

# scui to cui map
scui2cui = pd.read_csv(os.path.join(embeddings_dir, 'snomed2vec/concept_maps/cui_scui.tsv'), sep='\t', header=None)
scui2cui = scui2cui[:-1]
scui2cui.columns = ['CUI', 'SCUI']
scui2cui = scui2cui.set_index('SCUI')['CUI'].to_dict()
# cui to semtype and semgroup maps
# cui2semtype = pd.read_csv('/home/dc925/project/data/embeddings/snomed2vec/concept_maps/cui_node_type.tsv', sep='\t', header= None)
# cui2semtype.columns = ['CUI', 'TYPE']
# cui2semtype = cui2semtype[-cui2semtype.duplicated()]
# cui2semtype = cui2semtype.set_index('CUI')['TYPE'].to_dict()
semantic_info = pd.read_csv('/home/dc925/project/clinical_kge/semantic_info.csv', sep='\t', index_col=0)
semantic_info = semantic_info.drop_duplicates(subset='CUI')
cui2sty = semantic_info.set_index('CUI')['STY'].to_dict()
cui2sg = semantic_info.set_index('CUI')['SemGroup'].to_dict()

  interactivity=interactivity, compiler=compiler, result=result)


In [54]:
## Load in embeddings

# load in snomed2vec
snomed2vec = {}
with open(snomed2vec_emb_file, 'r') as fin:
    for i, line in enumerate(fin):
        if i == 0:
            continue
        line = line.strip().split()
        scui = int(line[0])
        embedding = np.array(line[1:], dtype=float)
        if scui in scui2cui:
            snomed2vec[scui2cui[scui]] = embedding
snomed2vec = pd.DataFrame.from_dict(snomed2vec, orient='index')

# load in cui2vec
cui2vec = pd.read_csv(cui2vec_emb_file, index_col=0)

# load in kge
kge_embeddings = {}
for m, p in kge_models_paths.items():
    print('loading {}'.format(m))
    with open(p, 'rb') as fin:
        model = pickle.load(fin)
        embeddings = model.solver.entity_embeddings
        embeddings = pd.DataFrame(embeddings)
        embeddings['CUI'] = [model.graph.id2entity[i] for i in range(len(embeddings))]
#         model_dict = {model.graph.id2entity[i]: embeddings[i] for i in range(len(embeddings))}
        embeddings = embeddings.set_index('CUI')
    kge_embeddings[m] = embeddings

loading TransE
loading DistMult
loading SimplE
loading ComplEx
loading RotatE


In [7]:
# snomed2vec.to_csv(os.path.join(embeddings_dir, 'snomed2vec/snomed2vec.csv'), sep='\t', header=None) #this is for ease of repeated use of snomed2vec, i.e. for viz

In [55]:
## Get intersecting CUIs and subset

cuis_intersection = list(set(kge_embeddings['TransE'].index) & set(cui2vec.index) & set(snomed2vec.index) - set(['C0015919']))
random.seed(42)
random.shuffle(cuis_intersection)

snomed2vec = snomed2vec.loc[cuis_intersection]
cui2vec = cui2vec.loc[cuis_intersection]
for m in kge_models:
    kge_embeddings[m] = kge_embeddings[m].loc[cuis_intersection]
    


In [61]:
# saving each embedding file w identical concepts list
snomed2vec.to_csv('notebooks/embeddings_for_bootstrapping/snomed2vec.csv', sep=',')

In [62]:
cui2vec.to_csv('notebooks/embeddings_for_bootstrapping/cui2vec_pretrained.csv', sep=',')

In [64]:
for model in kge_models:
    kge_embeddings[model].to_csv('notebooks/embeddings_for_bootstrapping/{}.csv'.format(model), sep=',')

In [None]:
## Get corresponding semantic type labels
# labels = [cui2sg[cui] for cui in cuis_intersection]
labels = [cui2sty[cui] for cui in cuis_intersection]
label_map = {label: i for i, label in enumerate(np.unique(labels))}

models = {m:kge_embeddings[m] for m in kge_models}
models['snomed2vec'] = snomed2vec
models['cui2vec'] = cui2vec
for name, model in models.items():
    model['labels'] = labels

In [88]:
models['TransE']['labels'].value_counts()

Disease or Syndrome                        6857
Organic Chemical                           3558
Finding                                    3257
Amino Acid, Peptide, or Protein            1970
Body Part, Organ, or Organ Component       1874
Injury or Poisoning                        1862
Therapeutic or Preventive Procedure        1814
Neoplastic Process                         1672
Pathologic Function                        1246
Congenital Abnormality                      797
Sign or Symptom                             751
Laboratory Procedure                        554
Mental or Behavioral Dysfunction            553
Diagnostic Procedure                        513
Pharmacologic Substance                     415
Body Location or Region                     311
Acquired Abnormality                        302
Body Space or Junction                      256
Anatomical Abnormality                      256
Health Care Activity                        229
Body Substance                          

In [65]:
cui2vec

Unnamed: 0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V491,V492,V493,V494,V495,V496,V497,V498,V499,V500
C0443935,-0.017254,0.008221,-8.413409e-17,-0.011074,-0.002464,0.016044,0.003069,0.011504,0.003296,1.561251e-17,...,0.001396,0.062629,0.055867,-0.009603,-0.084995,-0.025543,-0.024957,0.044795,0.057199,-0.036408
C0228201,-0.001670,0.001422,1.116728e-17,-0.001626,0.002968,0.001928,0.003999,-0.003508,0.009028,1.375310e-16,...,0.004589,0.042257,0.057367,0.032100,-0.068330,-0.052230,0.027640,0.053433,0.014709,-0.002867
C0269672,-0.180031,-0.192788,-4.732326e-15,-1.444332,0.250273,-0.728043,-1.311989,0.328245,0.045558,2.144118e-15,...,0.014824,-0.004525,0.074014,0.017153,-0.087824,0.038558,-0.001647,0.224950,-0.013966,0.105538
C0010042,-0.041295,0.034836,-1.966201e-16,-0.031706,0.025178,0.082290,0.040825,-0.008587,0.056868,1.278708e-15,...,0.070307,0.131027,0.136996,0.025898,0.055677,-0.006197,0.030585,-0.003776,-0.041328,0.079009
C0496763,-1.571538,-0.800497,1.998401e-15,-0.071885,-0.783863,0.706583,-0.033649,0.108674,0.357236,-6.661338e-16,...,-0.207904,-0.474824,0.452711,0.058087,-0.683948,0.129327,-0.169678,0.776891,-0.183082,-0.159835
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
C0023755,-0.002035,0.001601,-1.338990e-17,-0.003740,0.000244,0.002646,0.004495,0.000083,0.006649,2.406929e-17,...,-0.031750,-0.023817,-0.059648,0.062411,0.025560,-0.023620,0.038617,0.001838,0.045233,-0.013449
C0936139,-0.011319,0.003691,-9.540979e-18,-0.000172,0.006309,0.007840,0.000143,-0.000045,-0.000561,0.000000e+00,...,-0.052012,-0.033416,-0.025539,0.022128,0.047825,-0.001195,0.025920,-0.040574,-0.029975,-0.014550
C0030569,-4.640131,2.663372,-1.137979e-15,1.207900,0.207284,0.580321,-0.928158,-0.411286,0.519693,1.443290e-14,...,-0.573119,0.190212,-0.192988,0.162159,-0.205319,-0.202873,0.262925,0.152719,0.227312,0.403281
C0018546,-0.043551,0.016760,-1.734723e-17,0.014421,-0.002491,-0.028119,0.003232,-0.020914,0.051351,4.059253e-16,...,0.117871,0.098924,-0.005892,-0.068078,-0.125509,-0.059967,-0.050742,0.057204,0.166450,-0.100283


In [66]:
snomed2vec

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,190,191,192,193,194,195,196,197,198,199
C0443935,0.006983,0.227081,-0.228218,0.033706,0.182619,-0.018154,-0.356819,0.116498,0.122230,-0.437398,...,0.031006,0.758379,-0.012824,0.439657,0.062407,0.193092,-0.235947,-0.029845,0.121454,-0.076669
C0228201,0.524431,0.312281,-0.417633,-0.240308,-0.109762,0.459223,-0.356220,-0.023537,0.587343,-0.249630,...,0.235096,0.138988,-0.147969,-0.386817,0.018341,0.067686,-0.435164,0.064783,0.368568,-0.041101
C0269672,0.619270,-0.019828,0.011866,0.106377,-0.170606,0.310345,-0.216165,0.085807,0.237098,-0.195784,...,-0.086883,0.296947,0.006707,-0.276423,0.236906,0.256283,-0.367586,-0.270690,-0.408692,0.106499
C0010042,0.174152,-0.033354,-0.019611,-0.138394,-0.133457,0.364476,-0.167052,-0.157165,0.177113,0.200517,...,0.069982,0.170885,-0.418300,0.050114,0.165173,0.310699,-0.089806,-0.144133,0.173289,-0.200281
C0496763,0.124313,0.238585,0.115642,-0.199571,-0.351034,0.089719,-0.251198,0.150077,-0.096325,-0.606303,...,0.244622,0.189515,0.036486,-0.307186,0.370290,0.391348,-0.220769,-0.190920,0.181331,0.074503
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
C0023755,0.028931,-0.016492,-0.722396,0.114977,0.018558,-0.162352,-0.545657,-0.028831,0.217093,-0.082749,...,-0.099555,0.518899,-0.090963,-0.196097,-0.345124,0.315299,-0.202121,-0.024957,0.004487,0.033182
C0936139,-0.126591,0.267098,0.120419,0.237915,-0.361387,0.328127,-0.376621,-0.009967,0.501816,-0.346451,...,0.130774,0.503798,0.372493,-0.569864,-0.187750,0.297392,-0.034855,-0.110204,0.131759,-0.030913
C0030569,0.405239,0.218418,-0.194067,0.148784,0.157940,0.221929,-0.570809,-0.115324,0.476091,-0.386193,...,0.130107,0.151192,0.203324,-0.259831,0.017045,-0.105415,-0.147096,0.185738,0.001776,-0.080883
C0018546,0.282159,0.084547,-0.060834,0.260427,0.088809,0.543984,-0.333859,-0.239688,0.313252,-0.355661,...,0.005109,0.160503,0.530713,-0.475399,-0.599489,0.259753,-0.049854,-0.037934,0.441254,-0.153263


In [67]:
kge_embeddings['TransE']

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
CUI,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
C0443935,-0.005615,0.053907,0.064127,-0.017119,-0.012609,-0.003468,-0.030486,0.009772,-0.011773,0.035155,...,-0.027201,0.014443,0.041107,0.013423,-0.017699,-0.050566,0.031370,-0.039818,0.023471,0.046554
C0228201,0.056639,0.024973,0.043039,-0.010606,0.000297,0.028830,-0.005639,-0.012716,0.016615,0.013446,...,0.014388,-0.021831,-0.002546,0.021870,-0.042337,0.015798,-0.023749,-0.002659,0.094336,0.010797
C0269672,0.006275,0.035234,-0.005638,0.052735,0.027593,-0.034095,-0.000909,0.021540,-0.006118,-0.005323,...,-0.013453,0.033583,-0.053264,0.001736,0.028285,0.029924,-0.042194,-0.005361,-0.005018,-0.019387
C0010042,0.005165,-0.007816,0.005519,0.005191,-0.019536,-0.018363,0.034065,-0.018858,-0.015806,-0.023911,...,0.058378,-0.045572,0.005460,0.016206,-0.040299,0.009157,0.008565,0.010859,-0.001495,0.043285
C0496763,0.026411,-0.022139,-0.022034,0.004155,-0.004436,-0.003643,0.021189,0.009137,-0.017519,-0.109941,...,-0.010080,0.062235,-0.018077,0.057734,0.001609,-0.048640,-0.014018,-0.063896,-0.019638,0.014158
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
C0023755,-0.040562,-0.008575,-0.029939,-0.030448,-0.014047,-0.035917,-0.040634,-0.004045,-0.017335,0.035398,...,-0.020937,-0.002889,0.027856,-0.003465,-0.005584,0.009166,-0.024980,-0.030805,-0.022126,-0.010326
C0936139,-0.064008,0.010879,-0.021988,-0.007444,-0.018371,-0.035955,0.007486,0.021128,0.004214,0.035066,...,0.014685,0.008963,0.012840,0.004448,0.025604,-0.003902,0.024526,-0.015511,-0.048765,-0.012535
C0030569,-0.002874,-0.014771,-0.009262,0.051636,0.058948,0.006536,-0.007807,0.015736,0.027136,0.018170,...,-0.019772,-0.030501,0.002406,-0.005095,0.005002,-0.000051,-0.015331,0.002154,0.024111,0.015845
C0018546,-0.023810,-0.026229,0.087175,-0.021138,-0.013072,-0.026806,-0.016932,0.047351,-0.030469,0.048636,...,-0.000437,0.005676,0.006722,-0.008057,0.024764,0.006442,0.024149,-0.010890,0.018251,0.030575


In [9]:
# dataset = kge_embeddings['TransE']
# dataset = cui2vec
# dataset = snomed2vec



In [10]:
def run_experiment(dataset, num_epoch=20):    
    train, test, train_labels, test_labels = split_dataset(dataset)
    train_embeddings = torch.tensor(train.to_numpy()).float()
    test_embeddings = torch.tensor(test.to_numpy()).float()
    train_labels = torch.tensor([label_map[label] for label in train_labels], dtype=torch.long)
    test_labels = torch.tensor([label_map[label] for label in test_labels], dtype=torch.long)

    train_dataset = TensorDataset(train_embeddings, train_labels)
    test_dataset = TensorDataset(test_embeddings, test_labels)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=64)
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=64)

    dim = train.shape[1]
    
    model = NodeClassifier(dim, 4)
    
    model.to(device)
    model.zero_grad()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer.zero_grad()
    loss_fct = nn.CrossEntropyLoss()

    for _ in range(num_epoch):
        for step, batch in enumerate(train_dataloader):
            model.train()
            
            print(batch)
            
            inputs, labels = tuple(t.to(device) for t in batch)
            logits = model(inputs)
            loss = loss_fct(logits, labels)
            print(loss)
            loss.backward()
            optimizer.step()
            model.zero_grad()

        results = evaluate(model, test_dataloader)
        print('epoch acc: {}'.format(results['acc']))
    return results



In [37]:
keep_labels = model['labels'].value_counts()[models['TransE']['labels'].value_counts()>25]

In [89]:
keep_labels

Disease or Syndrome                        6857
Organic Chemical                           3558
Finding                                    3257
Amino Acid, Peptide, or Protein            1970
Body Part, Organ, or Organ Component       1874
Injury or Poisoning                        1862
Therapeutic or Preventive Procedure        1814
Neoplastic Process                         1672
Pathologic Function                        1246
Congenital Abnormality                      797
Sign or Symptom                             751
Laboratory Procedure                        554
Mental or Behavioral Dysfunction            553
Diagnostic Procedure                        513
Pharmacologic Substance                     415
Body Location or Region                     311
Acquired Abnormality                        302
Body Space or Junction                      256
Anatomical Abnormality                      256
Health Care Activity                        229
Body Substance                          

In [48]:
set(model['labels'])

{'Acquired Abnormality',
 'Amino Acid, Peptide, or Protein',
 'Anatomical Abnormality',
 'Anatomical Structure',
 'Antibiotic',
 'Biologically Active Substance',
 'Biomedical or Dental Material',
 'Body Location or Region',
 'Body Part, Organ, or Organ Component',
 'Body Space or Junction',
 'Body Substance',
 'Body System',
 'Cell or Molecular Dysfunction',
 'Clinical Drug',
 'Congenital Abnormality',
 'Diagnostic Procedure',
 'Disease or Syndrome',
 'Educational Activity',
 'Element, Ion, or Isotope',
 'Finding',
 'Fully Formed Anatomical Structure',
 'Hazardous or Poisonous Substance',
 'Health Care Activity',
 'Hormone',
 'Immunologic Factor',
 'Indicator, Reagent, or Diagnostic Aid',
 'Injury or Poisoning',
 'Inorganic Chemical',
 'Laboratory Procedure',
 'Mental or Behavioral Dysfunction',
 'Molecular Biology Research Technique',
 'Neoplastic Process',
 'Nucleic Acid, Nucleoside, or Nucleotide',
 'Organic Chemical',
 'Pathologic Function',
 'Pharmacologic Substance',
 'Research A

In [39]:
models['TransE']

CUI
C0443935        Amino Acid, Peptide, or Protein
C0228201                 Body Space or Junction
C0269672                    Pathologic Function
C0010042    Therapeutic or Preventive Procedure
C0496763                     Neoplastic Process
                           ...                 
C0023755                       Organic Chemical
C0936139                       Organic Chemical
C0030569                    Disease or Syndrome
C0018546                       Organic Chemical
C0154685                    Disease or Syndrome
Name: labels, Length: 30095, dtype: object

In [17]:
models['snomed2vec']['labels'].value_counts()

DISO    17633
CHEM     6511
PROC     3147
ANAT     2804
OBJC        1
Name: labels, dtype: int64

In [18]:
models['cui2vec']['labels'].value_counts()

DISO    17633
CHEM     6511
PROC     3147
ANAT     2804
OBJC        1
Name: labels, dtype: int64

In [12]:
torch.cuda.set_device(0)
device = 0
for name, embeddings in models.items():
    print('running {}'.format(name))
    run_experiment(embeddings, num_epoch=20)

running TransE
[tensor([[-0.0313, -0.0153,  0.0024,  ..., -0.0529, -0.0065, -0.0062],
        [ 0.0181, -0.0050,  0.0216,  ..., -0.0054,  0.0076, -0.0274],
        [-0.0208, -0.0397,  0.0437,  ...,  0.0197,  0.0694,  0.0284],
        ...,
        [ 0.0020, -0.0048, -0.0056,  ..., -0.0421,  0.0070, -0.0259],
        [-0.0042,  0.0186,  0.0371,  ...,  0.0239,  0.0378,  0.0039],
        [ 0.0241, -0.0250,  0.0069,  ...,  0.0270, -0.0043,  0.0316]]), tensor([1, 2, 2, 2, 1, 1, 2, 1, 0, 4, 2, 1, 2, 0, 2, 1, 2, 2, 2, 2, 0, 2, 0, 2,
        1, 1, 2, 0, 2, 1, 1, 1, 0, 0, 0, 0, 4, 2, 2, 0, 1, 2, 0, 1, 0, 2, 2, 2,
        2, 2, 2, 2, 2, 1, 2, 1, 2, 4, 2, 2, 2, 2, 2, 2])]


RuntimeError: cuda runtime error (59) : device-side assert triggered at /tmp/pip-req-build-p5q91txh/aten/src/THC/THCTensorMathCompareT.cuh:69

In [1]:
"""
results: exactly the ordering of performance; cui2vec is worst, snomed2vec is better, but not as good as kge models.


"""

'\nresults: exactly the ordering of performance; cui2vec is worst, snomed2vec is better, but not as good as kge models.\n\n\n'