In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from collections import defaultdict
from gensim.models import doc2vec 
from tqdm import tqdm
import networkx as nx
import pandas as pd
from pylab import rcParams

In [2]:
def get_icd10_codes(path='../icd10cm_codes_2018.txt'):
    """Load an ICD-10-CM code list into a mapping of code -> description.

    Every non-blank line is split on whitespace: the first token is taken as
    the code and the remaining tokens, re-joined with single spaces, as its
    description. Both are lower-cased before being stored.

    Args:
        path: Location of the code list. Defaults to the file this notebook
            has always read, so existing zero-argument calls are unaffected.

    Returns:
        dict mapping lower-cased code to lower-cased description. A line that
        holds only a code maps that code to an empty string. If a code occurs
        more than once, the last occurrence wins.

    Raises:
        OSError: if the file cannot be opened for reading.
    """
    icd10 = {}
    # Read-only mode: the file is never written to, and 'r+' would fail on a
    # file (or filesystem) without write permission.
    with open(path, 'r') as f:
        # Iterate the handle directly instead of materialising every line.
        for line in f:
            tokens = line.split()
            if not tokens:
                # Blank / whitespace-only lines have no code token to index.
                continue
            icd10[tokens[0].lower()] = ' '.join(tokens[1:]).lower()
    return icd10

    
def create_relations_plot(path_to_model, path_to_graph, glove_model, output_file, min_nodes=1):
    """Plot the 2-D embeddings of codes that belong to connected relations.

    The graph at ``path_to_graph`` is read as an edge list; every connected
    component with more than ``min_nodes`` nodes is drawn as one scatter
    series, and each point is annotated with its code. The figure is saved
    as SVG to ``output_file``.

    Args:
        path_to_model: For GloVe, a space-separated text file with one
            ``word x y`` row per line and no header. Otherwise, a saved
            gensim ``Doc2Vec`` model.
        path_to_graph: Edge-list file readable by ``nx.read_edgelist``.
        glove_model: Truthy to treat ``path_to_model`` as GloVe vectors,
            falsy to load it as a Doc2Vec model.
        output_file: Destination of the SVG figure.
        min_nodes: Only components with strictly more than this many nodes
            are plotted (default 1, i.e. components of two or more nodes).

    Returns:
        None. Side effects: sets ``rcParams['figure.figsize']`` and writes
        ``output_file``.
    """
    rcParams['figure.figsize'] = 20, 15
    graph = nx.read_edgelist(path_to_graph)

    if glove_model:
        # GloVe vectors: keep only rows whose word is a known ICD-10 code.
        # The code list is loaded here because only this branch uses it.
        vectors = pd.read_csv(path_to_model, header=None, names=['word', 'x', 'y'], sep=' ')
        plot_data = vectors[vectors.word.isin(get_icd10_codes())]
    else:
        # Doc2vec: look up the vector of every node in the graph.
        # The two-value unpacking assumes the model was trained with 2-D vectors.
        model = doc2vec.Doc2Vec.load(path_to_model)
        rows = []
        for code in graph.nodes:
            x, y = model.wv.get_vector(code)
            rows.append([code, x, y])
        plot_data = pd.DataFrame(rows, columns=['word', 'x', 'y'])

    # Draw on a dedicated figure so repeated calls do not pile up on whatever
    # figure happens to be current.
    fig, ax = plt.subplots()

    # nx.connected_components yields one node set per component; it replaces
    # connected_component_subgraphs, which newer networkx releases removed.
    for component in nx.connected_components(graph):
        # Keep only relations with more than `min_nodes` codes.
        if len(component) <= min_nodes:
            continue
        chosen_vectors = plot_data[plot_data.word.isin(component)]
        if chosen_vectors.empty:
            continue
        # One scatter call per component (not one per point): each relation
        # gets a single colour from the cycle and is drawn exactly once.
        ax.scatter(chosen_vectors.x, chosen_vectors.y)
        for row in chosen_vectors.itertuples(index=False):
            ax.annotate(
                row.word,
                xy=(row.x, row.y),
                xytext=(-14, 14),
                textcoords='offset points',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0')
            )
    fig.savefig(output_file, format='svg', bbox_inches='tight', dpi=1200)
    

def create_histogram_of_relations_size(path_to_graph, output_file):
    """Save a histogram of connected-component sizes for one relation graph.

    Args:
        path_to_graph: Edge-list file readable by ``nx.read_edgelist``.
        output_file: Destination of the SVG figure.

    Returns:
        None. Side effects: sets ``rcParams['figure.figsize']`` and
        ``rcParams['axes.labelsize']`` and writes ``output_file``.
    """
    rcParams['figure.figsize'] = 12, 8
    rcParams['axes.labelsize'] = 'xx-large'
    graph = nx.read_edgelist(path_to_graph)

    # nx.connected_components yields one node set per component; it replaces
    # connected_component_subgraphs, which newer networkx releases removed.
    component_sizes = [len(component) for component in nx.connected_components(graph)]

    # Dedicated figure so the histogram never lands on a leftover figure.
    fig, ax = plt.subplots()
    # NOTE(review): six bins over (0, 6) put sizes 5 and 6 in the same (last,
    # right-closed) bin and drop anything larger — confirm this is intended.
    ax.hist(component_sizes, bins=6, range=(0, 6))
    ax.set_xlabel('Liczba kodów składających się na relację')
    ax.set_ylabel('Liczba relacji')
    fig.savefig(output_file, format='svg', bbox_inches='tight', dpi=1200)


Histogram of relation sizes for all GloVe and doc2vec models

In [None]:
# Combined histogram: how many relations (connected components) of each size
# every GloVe / doc2vec configuration produced.

# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old alias, so use whichever name is installed.
seaborn_style = 'seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in plt.style.available else 'seaborn-deep'
plt.style.use(seaborn_style)
rcParams['axes.labelsize'] = 'x-large'
rcParams['xtick.labelsize'] = 'large'
rcParams['ytick.labelsize'] = 'large'
rcParams['figure.figsize'] = 12, 8

glove50_2 = nx.read_edgelist('glove_analysis/dimension_2/50/graph_glove_100')
glove50_10 = nx.read_edgelist('glove_analysis/dimension_10/50/graph_glove_100')
glove150_2 = nx.read_edgelist('glove_analysis/dimension_2/150/graph_glove_100')
glove150_10 = nx.read_edgelist('glove_analysis/dimension_10/150/graph_glove_100')
graph_doc2vec_50_2 = nx.read_edgelist('doc2vec_analysis/cbow/50/graph_100')
graph_doc2vec_50_10 = nx.read_edgelist('doc2vec_analysis/dimension10/50/graph_100')
graph_doc2vec_150_2 = nx.read_edgelist('doc2vec_analysis/cbow/150/graph_100')
graph_doc2vec_150_10 = nx.read_edgelist('doc2vec_analysis/dimension10/150/graph_100')

# Legend label paired with its graph, in plotting order.
labelled_graphs = [
    ('GloVe: 50x2', glove50_2),
    ('GloVe: 50x10', glove50_10),
    ('GloVe: 150 x 2', glove150_2),
    ('GloVe: 150x10', glove150_10),
    ('Doc2vec: 50 x 2', graph_doc2vec_50_2),
    ('Doc2vec: 50x10', graph_doc2vec_50_10),
    ('Doc2vec: 150 x 2', graph_doc2vec_150_2),
    ('Doc2vec: 150x10', graph_doc2vec_150_10),
]

# nx.connected_components yields one node set per component; it replaces
# connected_component_subgraphs, which newer networkx releases removed.
component_sizes = [
    [len(component) for component in nx.connected_components(graph)]
    for _, graph in labelled_graphs
]

fig, ax = plt.subplots()
# NOTE(review): no `bins` is given, so matplotlib's default bin count is
# spread over (1, 10); with the stock default of 10 the bins are 0.9 wide and
# the bars drift off the integer ticks — consider explicit integer bin edges.
ax.hist(component_sizes,
        range=(1, 10),
        label=[label for label, _ in labelled_graphs],
        align='left')
ax.legend(loc='upper right')
ax.set_xlabel('Liczba kodów składających się na relację')
ax.set_ylabel('Liczba relacji')
fig.savefig('hist_all.svg', format='svg', bbox_inches='tight', dpi=1200)