In [36]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from collections import defaultdict
from gensim.models import doc2vec 
from tqdm import tqdm
import networkx as nx
import pandas as pd
from pylab import rcParams
import numpy as np
from sklearn.decomposition import PCA

def get_icd10_codes(path='../icd10cm_codes_2018.txt'):
    """Load an ICD-10-CM code file into a ``{code: description}`` mapping.

    Each non-blank line holds a code followed by its description, separated
    by whitespace. Both parts are lower-cased, and runs of whitespace inside
    the description are collapsed to single spaces. If a code appears more
    than once, its last description wins.

    Args:
        path: Location of the code file. Defaults to
            ``../icd10cm_codes_2018.txt`` relative to the working directory.

    Returns:
        dict mapping lower-case code -> lower-case description.
    """
    icd10 = {}
    # Read-only access is all that is needed; 'r+' would demand write permission.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines instead of failing with IndexError.
                continue
            code, *description = fields
            icd10[code.lower()] = ' '.join(description).lower()
    return icd10

# Module-level lookup table (lower-case ICD-10 code -> lower-case description),
# loaded once from ../icd10cm_codes_2018.txt when this cell is executed.
icd10_codes = get_icd10_codes()
    
def _load_node_vectors(path_to_model, graph, glove_model):
    """Return ``(words, vectors)`` for the graph nodes found in the embedding model."""
    if glove_model:
        import csv
        # GloVe text format: one token per row followed by its vector, space-separated.
        # QUOTE_NONE / keep_default_na=False / str dtype keep tokens such as '"',
        # 'nan' or '123' from being mangled by pandas' default parsing.
        vectors = pd.read_csv(
            path_to_model,
            header=None,
            sep=' ',
            quoting=csv.QUOTE_NONE,
            keep_default_na=False,
            dtype={0: str},
        )
        # keep only tokens that are nodes of the relation graph
        vectors = vectors[vectors[0].isin(list(graph.nodes))]
        return list(vectors[0]), vectors[vectors.columns[1:]]

    # Doc2Vec: look every graph node up in the model's word vectors
    # (raises KeyError if a node is missing from the vocabulary).
    model = doc2vec.Doc2Vec.load(path_to_model)
    words = list(graph.nodes)
    return words, [model.wv.get_vector(code) for code in words]


def create_relations_plot(path_to_model, path_to_graph, glove_model, output_file, min_nodes=1):
    """Plot a 2-D PCA projection of code embeddings, grouped by graph component.

    The nodes of the edge-list graph are looked up in the embedding model and
    their vectors are projected to two dimensions with PCA. Each connected
    component with more than ``min_nodes`` nodes is drawn as one scatter
    series with every point labelled. The figure is saved as SVG.

    Args:
        path_to_model: Embedding model. If ``glove_model`` is truthy, a
            space-separated text file (token followed by vector components,
            no header); otherwise a file loadable by ``Doc2Vec.load``.
        path_to_graph: Edge list readable by ``nx.read_edgelist``.
        glove_model: Selects the GloVe text loader (truthy) or Doc2Vec (falsy).
        output_file: Destination of the SVG figure.
        min_nodes: Components with ``min_nodes`` nodes or fewer are not drawn
            (the default of 1 keeps components with at least 2 nodes).
    """
    rcParams['figure.figsize'] = 20, 15
    # Start a fresh figure so repeated calls do not draw on top of each other.
    plt.figure()

    graph = nx.read_edgelist(path_to_graph)
    words, coords = _load_node_vectors(path_to_model, graph, glove_model)

    projected = PCA(n_components=2).fit_transform(coords)
    plot_data = pd.DataFrame({'word': words, 'x': projected[:, 0], 'y': projected[:, 1]})

    # nx.connected_components yields node sets; connected_component_subgraphs
    # was removed in NetworkX 2.4.
    for nodes in nx.connected_components(graph):
        if len(nodes) <= min_nodes:
            continue
        chosen = plot_data[plot_data.word.isin(nodes)]
        # one scatter call per component -> one colour per component
        plt.scatter(chosen.x, chosen.y)
        for label, x, y in zip(chosen.word, chosen.x, chosen.y):
            plt.annotate(
                label,
                xy=(x, y),
                xytext=(-14, 14),
                textcoords='offset points',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
            )
    plt.savefig(output_file, format='svg', bbox_inches='tight', dpi=1200)
    

def create_histogram_of_relations_size(path_to_graph, output_file, max_size=None):
    """Save a histogram of the connected-component sizes of an edge-list graph.

    Each connected component is counted as one "relation"; its size is the
    number of codes (nodes) in it. One bar is drawn per integer size.

    Args:
        path_to_graph: Edge list readable by ``nx.read_edgelist``.
        output_file: Destination of the SVG figure.
        max_size: Largest size shown on the x axis; components larger than
            this are not counted. Defaults to the size of the largest
            component, so every component is included.
    """
    rcParams['figure.figsize'] = 12, 8
    rcParams['axes.labelsize'] = 'xx-large'
    # Start a fresh figure so the figsize above applies and earlier plots are not reused.
    plt.figure()

    graph = nx.read_edgelist(path_to_graph)
    sizes = [len(nodes) for nodes in nx.connected_components(graph)]
    if max_size is None:
        max_size = max(sizes, default=1)

    # Half-integer edges centre one bar on each integer size. With integer
    # edges the last bin is closed on both sides and would merge the two
    # largest sizes into one bar.
    edges = np.arange(0.5, max_size + 1.5)
    plt.hist(sizes, bins=edges)

    # Polish labels: "Number of codes making up the relation" / "Number of relations".
    plt.xlabel('Liczba kodów składających się na relację')
    plt.ylabel('Liczba relacji')
    plt.savefig(output_file, format='svg', bbox_inches='tight', dpi=1200)
