In [27]:
import sys
import os
import torch
import argparse
import pyhocon
import random
import pandas as pd
import pickle
from src.dataCenter import *
from src.utils import *
from src.models import *

In [106]:
# --- Run configuration: every value a reader might want to tune ---

# Dataset name and experiment config file.
args_dataSet = 'icd10'
args_config = './src/experiments.conf'

# Seed applied to the random, numpy and torch generators.
args_seed = 824

# Options forwarded to the GraphSage constructor.
args_gcn, args_agg_func = False, 'MEAN'

# Unsupervised training settings: loss variant, batch size,
# maximum number of epochs and early-stopping patience.
args_unsup_loss = 'normal'
args_b_sz = 20
args_epochs, args_patience = 100, 10

# CSV file the node table with embeddings is written to.
outpath = './embs/icd10_node_embs_mean.csv'

In [2]:
# Report the CUDA device visible to PyTorch, if there is one.
# This notebook has no `args` object (only the flat `args_*` variables) and no
# `--cuda` flag, so there is nothing to check the device against: referencing
# `args.cuda` here would raise NameError on every machine that has a GPU.
if torch.cuda.is_available():
    device_id = torch.cuda.current_device()
    print('using device', device_id, torch.cuda.get_device_name(device_id))


In [3]:
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print('DEVICE:', device)

DEVICE: cpu


In [4]:
# Seed every random number generator used by the notebook with the same value.
# NOTE(review): `np` is not imported explicitly anywhere in this notebook; it is
# presumably provided by one of the `from src... import *` lines — confirm.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(args_seed)

# Parse the HOCON experiment configuration.
config = pyhocon.ConfigFactory.parse_file(args_config)

In [5]:
# Load the dataset named by `args_dataSet` through DataCenter, then convert its
# `<ds>_feats` attribute to a float tensor on `device`.
ds = args_dataSet
dataCenter = DataCenter(config)
dataCenter.load_dataSet(ds)
raw_feats = getattr(dataCenter, f'{ds}_feats')
features = torch.FloatTensor(raw_feats).to(device)

In [6]:
# Build the GraphSage model that the training loop below optimises.
sage_adj_lists = getattr(dataCenter, f'{ds}_adj_lists')
graphSage = GraphSage(
    config['setting.num_layers'],
    features.size(1),
    config['setting.hidden_emb_size'],
    features,
    sage_adj_lists,
    device,
    gcn=args_gcn,
    agg_func=args_agg_func,
)
# Kept as the last expression so the cell displays the module summary.
graphSage.to(device)


GraphSage(
  (sage_layer1): SageLayer()
  (sage_layer2): SageLayer()
)

In [7]:
# A second GraphSage instance built from the same arguments as `graphSage`.
# NOTE(review): `model` is constructed again further down before any weights are
# loaded into it, so this instance looks unused. Confirm before deleting the
# cell: building a module presumably draws random numbers for its weights, so
# removing it could change the subsequent training run.
sage_args = (
    config['setting.num_layers'],
    features.size(1),
    config['setting.hidden_emb_size'],
    features,
    getattr(dataCenter, f'{ds}_adj_lists'),
    device,
)
model = GraphSage(*sage_args, gcn=args_gcn, agg_func=args_agg_func)

2
64
128
defaultdict(<class 'set'>, {3797: {1, 386, 265, 525, 781, 910, 553, 171, 944, 436, 327, 72, 584, 476, 734, 356, 871, 366, 625, 115, 12542, 895}, 12542: {4261, 4262, 4263, 4264, 11205, 11339, 11340, 11341, 11342, 11343, 11344, 11345, 11346, 11347, 11348, 3797, 11349, 11350, 11104, 11105, 11106, 10217}, 1: {0, 35, 68, 5, 11, 43, 17, 3797, 54, 24, 61}, 0: {1, 2, 3, 4}, 2: {0}, 3: {0}, 4: {0}, 5: {1, 6, 7, 8, 9, 10}, 6: {5}, 7: {5}, 8: {5}, 9: {5}, 10: {5}, 11: {1, 12, 13, 14, 15, 16}, 12: {11}, 13: {11}, 14: {11}, 15: {11}, 16: {11}, 17: {1, 18, 19, 20, 21, 22, 23}, 18: {17}, 19: {17}, 20: {17}, 21: {17}, 22: {17}, 23: {17}, 24: {32, 1, 33, 34, 25, 26, 27, 28, 29, 30, 31}, 25: {24}, 26: {24}, 27: {24}, 28: {24}, 29: {24}, 30: {24}, 31: {24}, 32: {24}, 33: {24}, 34: {24}, 35: {1, 36, 37, 38, 39, 40, 41, 42}, 36: {35}, 37: {35}, 38: {35}, 39: {35}, 40: {35}, 41: {35}, 42: {35}, 43: {1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53}, 44: {43}, 45: {43}, 46: {43}, 47: {43}, 48: {43}, 49: {4

In [8]:
# Unsupervised loss helper, built from the dataset's adjacency lists and its
# training node ids.
loss_adj_lists = getattr(dataCenter, f'{ds}_adj_lists')
loss_train_nodes = getattr(dataCenter, f'{ds}_train')
unsupervised_loss = UnsupervisedLoss(loss_adj_lists, loss_train_nodes, device)

In [9]:
# Train GraphSage with the unsupervised loss. A checkpoint is written every time
# the loss improves, and training stops once `args_patience` consecutive epochs
# bring no improvement.

# Initialise the early-stopping state here so the cell works on a fresh kernel
# and can be re-run; without it the first `loss < min_loss` raises NameError.
min_loss = float('inf')
early_stop = 0

# torch.save does not create missing parent directories.
os.makedirs('./models', exist_ok=True)

for epoch in range(args_epochs):
    print('----------------------EPOCH %d-----------------------' % epoch)
    graphSage, loss = apply_unsup_training(dataCenter, ds, graphSage, unsupervised_loss, args_b_sz, args_unsup_loss, device)

    if loss < min_loss:
        # New best loss: reset the patience counter and checkpoint the weights.
        min_loss = loss
        early_stop = 0
        ckpt_path = f'./models/unsup_graphsage_epoch_{epoch}'
        torch.save(graphSage.state_dict(), ckpt_path)
        print(f"Saved model state at epoch {epoch}: {ckpt_path}")
    else:
        early_stop += 1

    if early_stop >= args_patience:
        break

In [98]:
# Fresh GraphSage instance that receives the saved weights for inference.
# It is moved onto `device`, the same way `graphSage` is after construction, so
# that its parameters and `features` sit on the same device when CUDA is used.
model = GraphSage(config['setting.num_layers'], features.size(1),
                  config['setting.hidden_emb_size'], features,
                  getattr(dataCenter, ds+'_adj_lists'),
                  device, gcn=args_gcn, agg_func=args_agg_func).to(device)

In [99]:
# Restore trained weights into `model`.
# NOTE(review): the checkpoint epoch is hard-coded to 0. The training loop writes
# a new checkpoint whenever the loss improves, so epoch 0 is probably not the
# best one — confirm which checkpoint is intended.
ckpt_epoch = 0
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(torch.load(f'./models/unsup_graphsage_epoch_{ckpt_epoch}', map_location=device))

<All keys matched successfully>

In [100]:
# Sanity-check the training node ids: smallest id, largest id and array shape.
train_nodes = getattr(dataCenter, f'{ds}_train')
for label, value in (
    ("Min", train_nodes.min()),
    ("Max", train_nodes.max()),
    ("Num nodes", train_nodes.shape),
):
    print(f"{label}: {value}")

Min: 0
Max: 12542
Num nodes: (12543,)


### Generate Embeddings for all Nodes in Graph

In [101]:
# Read the pickled mapper whose keys become the `icd` column and whose values
# become the `encd` column.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(r"icd10-data/encdmapper.pickle", "rb") as input_file:
    encd_mapper = pickle.load(input_file)

# Split the mapping into two parallel lists, preserving iteration order.
mapper_pairs = list(encd_mapper.items())
node_lst = [icd for icd, _ in mapper_pairs]
encd_lst = [encd for _, encd in mapper_pairs]

df_nodes = pd.DataFrame({'icd': node_lst, 'encd': encd_lst})

# Load leaf nodes
leaf_nodes = pd.read_csv('icd10-data/lbls.csv')['icd'].tolist()

# leaf = 1 for codes listed in lbls.csv, 0 for every other node.
df_nodes['leaf'] = df_nodes['icd'].isin(leaf_nodes).astype('int64')

print(f"Shape: {df_nodes.shape}")
df_nodes.head()

Shape: (12543, 3)


Unnamed: 0,icd,encd,leaf
0,A00,0,0
1,A00-A09,1,0
2,A00.0,2,1
3,A00.1,3,1
4,A00.9,4,1


In [102]:
# Generate embedding for all nodes in graph
emb_arr = model(df_nodes['encd'].tolist()).detach().numpy()
print(emb_arr.shape)

df_emb = pd.DataFrame(emb_arr)
df_emb.columns = [f'emb_{i}' for i in df_emb.columns]
print(df_emb.shape)
df_emb.head()

(12543, 128)
(12543, 128)


Unnamed: 0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_118,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127
0,0.0,0.646565,0.0,0.143524,0.315669,0.357009,0.0,0.0,0.0,0.548923,...,0.095369,0.356,0.843799,0.0,0.0,0.821933,0.0,0.568395,0.0,0.0
1,0.0,0.551231,0.0,0.120007,0.0,0.200657,0.0,0.0,0.0,0.293617,...,0.0,0.0,0.352685,0.0,0.0,0.373554,0.0,0.405693,0.0,0.0
2,0.0,0.859074,0.0,0.247055,0.0,0.35249,0.0,0.0,0.0,0.476938,...,0.0,0.172888,0.170622,0.0,0.0,0.45812,0.0,0.67983,0.0,0.0
3,0.0,0.859296,0.0,0.24681,0.0,0.352816,0.0,0.0,0.0,0.476838,...,0.0,0.172619,0.170292,0.0,0.0,0.458356,0.0,0.679666,0.0,0.0
4,0.0,1.070201,0.0,0.159215,0.0,0.358874,0.0,0.0,0.077638,0.396076,...,0.0,0.0,0.385337,0.0,0.0,0.598882,0.0,0.858127,0.0,0.0


In [103]:
# Attach the embedding columns to the node table, row by row.
print(df_nodes.shape)
print(df_emb.shape)

# concat(axis=1) aligns on the index; a row-count mismatch would silently
# produce NaN-padded rows instead of failing, so check it up front.
if len(df_nodes) != len(df_emb):
    raise ValueError(
        f"Row count mismatch: df_nodes has {len(df_nodes)} rows, df_emb has {len(df_emb)}"
    )

# Drop any embedding columns left over from a previous run of this cell, so
# re-running it replaces them instead of appending duplicates.
df_nodes = pd.concat(
    [df_nodes.drop(columns=df_emb.columns, errors='ignore'), df_emb],
    axis=1,
)
print(df_nodes.shape)
if df_nodes.columns.duplicated().any():
    print("Error: Duplicated columns in DataFrame.")
df_nodes.head()

(12543, 3)
(12543, 128)
(12543, 131)


Unnamed: 0,icd,encd,leaf,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_118,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127
0,A00,0,0,0.0,0.646565,0.0,0.143524,0.315669,0.357009,0.0,...,0.095369,0.356,0.843799,0.0,0.0,0.821933,0.0,0.568395,0.0,0.0
1,A00-A09,1,0,0.0,0.551231,0.0,0.120007,0.0,0.200657,0.0,...,0.0,0.0,0.352685,0.0,0.0,0.373554,0.0,0.405693,0.0,0.0
2,A00.0,2,1,0.0,0.859074,0.0,0.247055,0.0,0.35249,0.0,...,0.0,0.172888,0.170622,0.0,0.0,0.45812,0.0,0.67983,0.0,0.0
3,A00.1,3,1,0.0,0.859296,0.0,0.24681,0.0,0.352816,0.0,...,0.0,0.172619,0.170292,0.0,0.0,0.458356,0.0,0.679666,0.0,0.0
4,A00.9,4,1,0.0,1.070201,0.0,0.159215,0.0,0.358874,0.0,...,0.0,0.0,0.385337,0.0,0.0,0.598882,0.0,0.858127,0.0,0.0


### Save Results

In [108]:
# Write the node table, including the embedding columns, to `outpath`.
# to_csv does not create missing directories, so make sure the target folder
# exists first. A bare filename has no directory part and needs nothing created.
out_dir = os.path.dirname(outpath)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
df_nodes.to_csv(outpath, index=False)
print(f"Saved: {outpath}")

Saved: ./embs/icd10_node_embs.csv
