# Librerías

In [1]:
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import pandas as pd
import numpy as np
import pickle
import h5py

In [2]:
import sys
sys.path.append('../src/')

from utils import *

# Lectura de datos

In [3]:
path_distancias = '../data/correlaciones_symsim/correlaciones.pickle'
path_datos = '../data/Macosko_mouse_retina.h5'
path_out = '../data/correlaciones_macosko/'

In [4]:
with open(path_distancias, 'rb') as f:
    correlaciones_hl = pickle.load(f)

with h5py.File(path_datos) as f:
    #X = np.array(f['X'])
    y = np.array(f['Y'])

In [5]:
assert len(y) == len(correlaciones_hl)

In [5]:
#sns.heatmap(correlaciones_hl)

# Pipeline completo

In [6]:
def create_kMST(distance_matrix, k = None, threshold = 1e-5):
    if k is None:
        N = np.log(len(distance_matrix))
        k = int(np.floor(N))
    
    print(f'k = {k}')
    grafo = nx.Graph()
    nodos = range(len(distance_matrix))

    # Crear nodo inicial
    grafo.add_nodes_from(nodos)

    for i in range(len(distance_matrix)):
        for j in range(i + 1, len(distance_matrix[i])):
            peso = distance_matrix[i][j]
            if peso > threshold:
                grafo.add_edge(i, j, weight=peso)

    print(f'---> Number of edges: {grafo.number_of_edges()}')

    mst_antes = None
    # Creamos los MSTs
    for iter in range(k):
        mst_new = nx.minimum_spanning_tree(grafo)

        edges_to_remove = list(mst_new.edges)
        grafo.remove_edges_from(edges_to_remove)
        print(f'---> {iter}. Number of edges: {grafo.number_of_edges()}')

        if mst_antes is None:
            mst_antes = mst_new.copy()
        else:
            mst_new.add_edges_from(list(mst_antes.edges()))
            mst_antes = mst_new.copy()

    return mst_antes 

In [8]:
union_graph_msts = create_kMST(correlaciones_hl)

In [None]:
union_graph_msts.number_of_edges()

14693

In [None]:
with open(path_out + 'grafo_kMST_correlaciones.pickle', 'wb') as f:
    pickle.dump(union_graph_msts, f)

# Louvain sobre el kMST

In [11]:
with open('../data/correlaciones_10x/grafo_kMST_correlaciones.pickle', 'rb') as f:
    mst = pickle.load(f) 
    
with h5py.File('../data/10X_PBMC_select_2100.h5') as f:
    X = np.array(f['X'])
    y = np.array(f['Y'])

In [14]:
assert mst.number_of_nodes() == len(y)

In [17]:
particiones = nx.community.louvain_communities(mst, seed=123)

diccionario = {}
for i, conjunto in enumerate(particiones):
    for elemento in conjunto:
        diccionario[elemento] = i 

max_elemento = max(max(particiones, key = max), default=-1)
clusters = np.array([diccionario.get(i, -1) for i in range(max_elemento + 1)])
clusters

array([ 9, 12, 10, ...,  1, 16,  3])

In [19]:
len(set(clusters)), len(set(y))

(20, 8)

In [20]:
from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score

acc = round(cluster_acc(clusters,y), 3)
nmi = round(normalized_mutual_info_score(clusters,y), 3)
ari = round(adjusted_mutual_info_score(clusters,y), 3)

print(f'ACC: {acc}. NMI: {nmi}. ARI: {ari}')

ACC: 0.098. NMI: 0.018. ARI: 0.004


In [22]:
with open('../data/correlaciones_human_liver/grafo_kMST_correlaciones.pickle', 'rb') as f:
    mst = pickle.load(f) 
    
with h5py.File('../data/HumanLiver_counts_top5000.h5') as f:
    X = np.array(f['X'])
    y = np.array(f['Y'])

In [23]:
assert mst.number_of_nodes() == len(y)

In [24]:
particiones = nx.community.louvain_communities(mst, seed=123)

diccionario = {}
for i, conjunto in enumerate(particiones):
    for elemento in conjunto:
        diccionario[elemento] = i 

max_elemento = max(max(particiones, key = max), default=-1)
clusters = np.array([diccionario.get(i, -1) for i in range(max_elemento + 1)])
clusters

array([4, 3, 2, ..., 3, 4, 8])

In [25]:
len(set(clusters)), len(set(y))

(9, 11)

In [26]:
from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score

acc = round(cluster_acc(clusters,y), 3)
nmi = round(normalized_mutual_info_score(clusters,y), 3)
ari = round(adjusted_mutual_info_score(clusters,y), 3)

print(f'ACC: {acc}. NMI: {nmi}. ARI: {ari}')

ACC: 0.336. NMI: 0.105. ARI: 0.103
