In [1]:
from pprint import pprint
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
import numpy as np
import random
import smart_open
from collections import defaultdict
from tqdm import tqdm
from joblib import Parallel, delayed
import networkx as nx

GloVe evaluation (refactored)

In [2]:
def generate_base(model_path, vocab_path):
    """Load GloVe vectors and build a row-normalised embedding matrix.

    Parameters
    ----------
    model_path : str
        Text file with one space-separated ``word v1 v2 ...`` entry per line.
    vocab_path : str
        Text file whose first space-separated token on every line is a word.
        The line order defines the row index assigned to each word.

    Returns
    -------
    tuple
        ``(W_norm, vocab, ivocab)``: ``W_norm`` is a
        ``(number of vocab lines, vector_dim)`` array whose rows are scaled to
        unit Euclidean length, ``vocab`` maps word -> row index and ``ivocab``
        maps row index -> word.

    Notes
    -----
    Vectors for ``'<unk>'`` and for words missing from the vocabulary file are
    ignored. A vocabulary word without a vector keeps an all-zero row, which
    the normalisation turns into NaN (0 / 0).
    """
    # Plain read access is enough; 'r+' needlessly required write permission.
    with open(vocab_path, 'r') as f:
        words = [x.rstrip().split(' ')[0] for x in f]

    vectors = {}
    with open(model_path, 'r') as f:
        for line in f:
            vals = line.rstrip().split(' ')
            vectors[vals[0]] = [float(x) for x in vals[1:]]

    vocab_size = len(words)
    vocab = {w: idx for idx, w in enumerate(words)}
    ivocab = {idx: w for idx, w in enumerate(words)}

    # Read the dimensionality from any loaded vector so that the first
    # vocabulary word is not required to have one.
    vector_dim = len(next(iter(vectors.values())))
    W = np.zeros((vocab_size, vector_dim))
    for word, v in vectors.items():
        # Skip the unknown-word token and vectors that have no vocabulary row.
        if word == '<unk>' or word not in vocab:
            continue
        W[vocab[word], :] = v

    # Scale every word vector to unit Euclidean (L2) length.
    d = (np.sum(W ** 2, 1) ** (0.5))
    W_norm = (W.T / d).T
    return (W_norm, vocab, ivocab)


def distance(base, input_term, topn):
    """Return the ``topn`` vocabulary words most similar to ``input_term``.

    Parameters
    ----------
    base : tuple
        ``(W, vocab, ivocab)`` as returned by ``generate_base``.
    input_term : str
        One or more space-separated words; their vectors are summed to form
        the query.
    topn : int
        Maximum number of neighbours to return.

    Returns
    -------
    list or None
        ``[word, score]`` pairs ordered by decreasing score, where the score
        is the dot product of each row of ``W`` with the length-normalised
        query (cosine similarity when the rows of ``W`` have unit length).
        ``None`` if any query word is out of the vocabulary.
    """
    W, vocab, ivocab = base
    terms = input_term.split(' ')
    if any(term not in vocab for term in terms):
        return None  # Word: Out of dictionary!

    # Sum the query vectors; copy the first row so W itself is never modified.
    vec_result = np.copy(W[vocab[terms[0]], :])
    for term in terms[1:]:
        vec_result += W[vocab[term], :]

    d = (np.sum(vec_result ** 2,) ** (0.5))
    vec_norm = vec_result / d

    dist = np.dot(W, vec_norm)

    # Push the query words to the bottom of the ranking.
    # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
    for term in terms:
        dist[vocab[term]] = -np.inf

    a = np.argsort(-dist)[:topn]
    return [[ivocab[x], dist[x]] for x in a]

Get the most similar words for each ICD-10 code

In [3]:
def get_icd10_codes(path='../results/icd10cm_codes_2018.txt'):
    """Read the ICD-10 code list into a ``code -> title`` mapping.

    Parameters
    ----------
    path : str, optional
        Text file with one whitespace-separated ``CODE title words ...`` entry
        per line. Defaults to the 2018 ICD-10-CM list used by this notebook.

    Returns
    -------
    collections.defaultdict
        Lower-cased code mapped to its lower-cased title, with runs of
        whitespace in the title collapsed to single spaces. It is created
        without a default factory, so a missing key raises ``KeyError`` just
        like a plain dict.
    """
    icd10 = defaultdict()
    # Plain read access is enough; 'r+' needlessly required write permission.
    with open(path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank line carries no code; indexing it would raise IndexError.
                continue
            icd10[fields[0].lower()] = ' '.join(fields[1:]).lower()
    return icd10
    
    
def validate_relations(similarity_results, icd10):
    """Keep only the neighbours that are themselves valid ICD-10 codes.

    ``similarity_results`` maps a reference code to ``(word, score)`` pairs.
    The result maps each reference code that has at least one valid neighbour
    to the list of those neighbours, lower-cased and in their original order.
    """
    valid_by_reference = defaultdict(list)
    for ref_code, neighbours in tqdm(similarity_results.items()):
        for word, _score in neighbours:
            candidate = word.lower()
            if candidate not in icd10.keys():
                continue
            valid_by_reference[ref_code].append(candidate)
    return valid_by_reference


def describe_relations(found_relations, icd10):
    """Attach the ICD-10 titles of both codes to every found relation.

    Returns a mapping from ``(reference, relation)`` code pairs to a list of
    ``(reference title, relation title)`` tuples.
    """
    titled = defaultdict(list)
    for source_code, related_codes in found_relations.items():
        for related_code in related_codes:
            bucket = titled[(source_code, related_code)]
            bucket.append((icd10[source_code], icd10[related_code]))
    return titled


def perform_analysis(model_path, vocab_path, topn=50):
    """Run the full pipeline that relates ICD-10 codes to each other.

    Loads the code list and the embedding, looks up the ``topn`` nearest
    words of every code in parallel threads, keeps the neighbours that are
    ICD-10 codes and returns them annotated with the titles of both codes.
    """
    icd10 = get_icd10_codes()

    print('Loading vocabulary and vectors...')
    base = generate_base(model_path, vocab_path)

    print('Retrieving similar words...')
    run_parallel = Parallel(n_jobs=-1, backend='threading', verbose=10)
    lookup = delayed(distance)
    similarity = run_parallel(lookup(base, code, topn) for code in icd10.keys())

    # Codes whose lookup returned nothing (e.g. out of vocabulary) are dropped.
    similarity_results = defaultdict(list)
    similarity_results.update(
        (code, relations)
        for code, relations in zip(icd10.keys(), similarity)
        if relations
    )

    print('Filtering similarities...')
    validated_relations = validate_relations(similarity_results, icd10=icd10)

    print('Describing relations...')
    described = describe_relations(validated_relations, icd10=icd10)
    return described

In [4]:
# Relate ICD-10 codes using the 50-dimensional symmetric GloVe vectors,
# considering the 1000 nearest words of every code.
results = perform_analysis(
    model_path='../results/glove_data/symmetric_vectors/glove_vectors50.txt',
    vocab_path='../results/filtered_vocab_10.txt',
    topn=1000,
)

Loading vocabulary and vectors...
Retrieving similar words...


[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  33 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  53 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  64 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  77 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done  90 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done 105 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done 120 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done 137 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done 154 tasks      | elapsed:    0.8s
[Parallel(n_jobs=-1)]: Done 173 tasks      | elapse

[Parallel(n_jobs=-1)]: Done 9105 tasks      | elapsed:   25.7s
[Parallel(n_jobs=-1)]: Done 9240 tasks      | elapsed:   26.0s
[Parallel(n_jobs=-1)]: Done 9377 tasks      | elapsed:   26.0s
[Parallel(n_jobs=-1)]: Done 9514 tasks      | elapsed:   26.4s
[Parallel(n_jobs=-1)]: Done 9653 tasks      | elapsed:   26.5s
[Parallel(n_jobs=-1)]: Done 9792 tasks      | elapsed:   27.2s
[Parallel(n_jobs=-1)]: Done 9933 tasks      | elapsed:   27.4s
[Parallel(n_jobs=-1)]: Done 10074 tasks      | elapsed:   27.4s
[Parallel(n_jobs=-1)]: Done 10217 tasks      | elapsed:   28.2s
[Parallel(n_jobs=-1)]: Done 10360 tasks      | elapsed:   28.8s
[Parallel(n_jobs=-1)]: Done 10505 tasks      | elapsed:   29.6s
[Parallel(n_jobs=-1)]: Done 10650 tasks      | elapsed:   29.7s
[Parallel(n_jobs=-1)]: Done 10797 tasks      | elapsed:   30.1s
[Parallel(n_jobs=-1)]: Done 10944 tasks      | elapsed:   30.2s
[Parallel(n_jobs=-1)]: Done 11093 tasks      | elapsed:   30.7s
[Parallel(n_jobs=-1)]: Done 11242 tasks      | 

[Parallel(n_jobs=-1)]: Done 34840 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 35105 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 35370 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 35637 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 35904 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 36173 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 36442 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 36713 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 36984 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 37257 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 37530 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 37805 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 38080 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 38357 tasks      | elapsed:  1.9min
[Parallel(n_jobs=-1)]: Done 38634 tasks      | elapsed:  1.9min
[Parallel(n_jobs=-1)]: Done 38913 tasks 

Filtering similarities...


100%|██████████| 1766/1766 [00:01<00:00, 1348.81it/s]

Describing relations...





Saving relations

In [5]:
# Display the mapping of (reference, relation) code pairs to their titles.
# NOTE(review): this renders every entry; consider showing only a slice to keep the notebook small.
results

defaultdict(list,
            {('p0081',
              'n329'): [('newborn affected by periodontal disease in mother',
               'bladder disorder, unspecified')],
             ('n361',
              'n151'): [('urethral diverticulum',
               'renal and perinephric abscess')],
             ('r52',
              'd259'): [('pain, unspecified',
               'leiomyoma of uterus, unspecified')],
             ('k225',
              'c002'): [('diverticulum of esophagus, acquired',
               'malignant neoplasm of external lip, unspecified')],
             ('q312',
              'd210'): [('laryngeal hypoplasia',
               'benign neoplasm of connective and other soft tissue of head, face and neck')],
             ('n882',
              'r0989'): [('stricture and stenosis of cervix uteri',
               'other specified symptoms and signs involving the circulatory and respiratory systems')],
             ('k599',
              'o1205'): [('functional intestinal dis

In [6]:
# Dump every relation as "<code> <code>:" followed by one indented
# "title || title" line per title pair.
with open('glove_analysis/50/found_relations_glove50_1000', 'w+') as output_file:
    for code_pair, title_pairs in results.items():
        header = ' '.join(code_pair)
        output_file.write('{}:\n'.format(header))
        output_file.writelines(
            '    {}\n'.format(' || '.join(title_pair)) for title_pair in title_pairs
        )

Create network for found relations

In [7]:
# Start from an empty undirected graph.
graph = nx.Graph()

In [8]:
# Every key of results is a (reference, relation) code pair, so each key becomes one edge.
graph.add_edges_from(results.keys())

In [9]:
# Persist the graph as an edge list so it can be reloaded without recomputing the relations.
nx.write_edgelist(graph, 'glove_analysis/50/graph_glove50_1000')

Load graph

In [10]:
# Reload the graph from disk; this rebinds `graph`, replacing any graph already in memory.
graph = nx.read_edgelist('glove_analysis/50/graph_glove50_1000')

Analysing found relations

In [11]:
# Mapping of lower-cased ICD-10 code -> lower-cased title, used to label graph nodes.
icd10 = get_icd10_codes()

In [12]:
def retrieve_titles_for_subgraph(graph, icd10, min_nodes=2):
    """Collect the ICD-10 titles of every connected component of ``graph``.

    Parameters
    ----------
    graph : networkx.Graph
        Undirected graph whose nodes are ICD-10 code strings.
    icd10 : mapping
        Code -> title lookup; it must contain every node of a reported
        component.
    min_nodes : int, optional
        Components with fewer nodes than this are skipped.

    Returns
    -------
    collections.defaultdict
        Maps the comma-separated codes of a component to the list of their
        titles. Codes are sorted so that keys and title order do not change
        between runs.
    """
    results = defaultdict(list)
    # nx.connected_component_subgraphs was removed in NetworkX 2.4; the node
    # sets yielded by nx.connected_components carry all that is needed here.
    for component in nx.connected_components(graph):
        if len(component) < min_nodes:
            continue
        nodes = sorted(component)
        results[', '.join(nodes)] = [icd10[node] for node in nodes]
    return results

In [13]:
# One entry per connected component with at least two codes (the default min_nodes=2).
results_titles = retrieve_titles_for_subgraph(graph, icd10)

In [14]:
# Write each component as "<codes>:" followed by one title per line and a
# blank separator line.
with open('glove_analysis/50/described_relations_glove50_1000', 'w+') as output_file:
    for component_codes, component_titles in results_titles.items():
        output_file.writelines((
            '{}:\n'.format(component_codes),
            '{}\n'.format('\n'.join(component_titles)),
            '\n',
        ))