In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import os.path as osp
import networkx as nx

import sys

In [None]:
sys.path.append('../')

In [None]:
from common_functions import eval_g_hat_with_DnX

In [None]:
# Run configuration: which online cvxNDL experiment to analyze.
p2root = '../results_0905_logging/'
K, version, iteration, iid_sample = 25, 'Rr', 10000, 20000
chromo = 'chr3L'

expr_name = f'{chromo}_drosophila_ChIA_Drop_0.1_PASS'
prefix = f'{expr_name}_{iid_sample}_MCMC_pivot_K_{K}_iter_{iteration}'

# Files written by the training run, all located under p2root.
p2ocmf_dict, p2representative_regions, p2A_t, p2W_hat, p2X = (
    osp.join(p2root, file_name)
    for file_name in (
        f'{version}_{prefix}_ocmf_centroid_df',
        f'{version}_{prefix}_x_hat_df',
        f'A_t_{version}_{prefix}_ocmf.csv',
        f'W_hat_{version}_{prefix}_ocmf.csv',
        f'{prefix}_X_df',
    )
)

In [None]:
# Load the training matrix, the online cvxNDL dictionary table and the representative regions.
# NOTE: read_pickle can execute arbitrary code; only load files produced by this project.
X_df = pd.read_pickle(p2X)
ocmf_dict, rep_region = (
    pd.read_pickle(path) for path in (p2ocmf_dict, p2representative_regions)
)

In [None]:
# Convex weight matrix W_hat and matrix A_t, both stored as comma-separated text.
W_hat, A_t = (np.loadtxt(path, delimiter=',') for path in (p2W_hat, p2A_t))

In [None]:
# Binary pattern of W_hat: 1 where a weight is positive, 0 otherwise.
(W_hat > 0).astype(int)

In [None]:
# Number of representative-region rows per label.
rep_region.label.value_counts()

In [None]:
def plot_one_row(ax, row, k = 21, threshold = 1, title = ''):
    """Draw one flattened k x k matrix as a greyscale image on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    row : array-like
        Flattened matrix with k * k entries (Series, one-row DataFrame or array).
    k : int
        Side length of the square matrix.
    threshold : float or None
        If not None, entries >= threshold are drawn as 1 and all others as 0;
        if None, the raw values are drawn.
    title : str
        Axes title.

    Returns
    -------
    (matrix, ax)
        The k x k matrix that was drawn and the axes it was drawn on.
    """
    values = np.array(row)
    if threshold is None:
        matrix = values.reshape(k, k)
    else:
        matrix = (values >= threshold).astype(int).reshape(k, k)

    ax.imshow(matrix, cmap = 'Greys')
    ax.set_title(title)
    # Hide both tick axes; the image is read as a matrix, not a chart.
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    return matrix, ax

In [None]:
# Sanity check: draw the training sample with index label 5, binarized at the default threshold.
fig, ax = plt.subplots()
plot_one_row(ax, X_df.loc[5], title = 'row_5')
plt.show()
plt.close()

In [None]:
def plot_df_in_grid(df_in, grid_size = (2, 5), figsize = (10, 4), 
                    k = 21, threshold = 1, title_prefix = 'row', title = '', 
                   sub_titles = None, p2savefig = None):
    """Plot the leading rows of `df_in` as k x k images on a grid of subplots.

    Grid cells past the last row of `df_in` are drawn as all-zero images.

    Parameters
    ----------
    df_in : pandas.DataFrame
        One flattened k x k matrix per row. It must have at least one row,
        because that row's shape is used for the zero padding.
    grid_size : (int, int)
        Number of subplot (rows, columns).
    figsize : (float, float)
        Figure size in inches.
    k : int
        Side length of each square matrix.
    threshold : float or None
        Passed to `plot_one_row`: binarization threshold, or None for raw values.
    title_prefix : str
        Subplot titles are f'{title_prefix}_{idx}' when `sub_titles` is None.
    title : str
        Figure-level title.
    sub_titles : sequence of str or None
        Explicit per-cell titles; cells past its end get an empty title.
    p2savefig : str or None
        If given, save the figure to this path instead of showing it.
    """
    n_rows, n_cols = grid_size
    fig = plt.figure(figsize = figsize)
    # squeeze=False keeps `axes` 2-D, so 1 x n, n x 1 and 1 x 1 grids index the same way.
    axes = fig.subplots(n_rows, n_cols, squeeze = False)
    # Placeholder for grid cells beyond the end of df_in.
    empty_row = np.zeros_like(df_in.iloc[0].values)

    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            idx = row_idx * n_cols + col_idx
            cur_row = df_in.iloc[idx] if idx < df_in.shape[0] else empty_row

            if sub_titles is None:
                subtitle = f'{title_prefix}_{idx}'
            elif idx < len(sub_titles):
                subtitle = sub_titles[idx]
            else:
                subtitle = ''

            plot_one_row(axes[row_idx, col_idx], cur_row, k = k,
                         threshold = threshold, title = subtitle)

    fig.suptitle(title, y = 1.05)
    fig.tight_layout()

    if p2savefig is None:
        plt.show()
    else:
        # The suptitle sits above the axes area (y > 1); 'tight' keeps it in the saved file.
        fig.savefig(p2savefig, bbox_inches = 'tight')
    plt.close(fig)

In [None]:
# Show up to 10 *distinct* random training samples; a fixed seed keeps the figure reproducible.
sample_rng = np.random.default_rng(0)
sample_labels = sample_rng.choice(X_df.index, size = min(10, len(X_df)), replace = False)
plot_df_in_grid(X_df.loc[sample_labels], title_prefix= 'random_sample')

# online cvxMF results

In [None]:
# Value columns of the dictionary table: every column whose name does not contain 'label'.
feature = list(filter(lambda col: 'label' not in col, ocmf_dict.columns))
ocmf_dict_val_df = ocmf_dict.loc[:, feature]
# Distribution of all dictionary entries.
ocmf_dict_val_df.stack().hist()

In [None]:
def get_importance_from_At(A_t):
    """Normalize the diagonal of `A_t` so that it sums to one.

    Entry j of the result is A_t[j, j] / sum_i A_t[i, i]. In this notebook it is
    used as the importance score of dictionary j.

    Parameters
    ----------
    A_t : numpy.ndarray
        2-D array (any object with a `.diagonal()` method).

    Returns
    -------
    numpy.ndarray
        The normalized diagonal entries.
    """
    diagonal_entries = A_t.diagonal()
    return diagonal_entries / sum(diagonal_entries)

In [None]:
importance_score = get_importance_from_At(A_t)
# One subplot title per dictionary row, showing its importance score.
subtitles = [
    f'online cvxNDL\n dictionaries {atom}\n score {importance_score[atom]:.2f}'
    for atom in range(len(ocmf_dict_val_df))
]

In [None]:
# Dictionary positions ordered from highest to lowest importance score.
descending_order_of_At = np.argsort(importance_score)[::-1]

In [None]:
# Importance scores listed from largest to smallest.
importance_score[descending_order_of_At]

In [None]:
# plot_df_in_grid(ocmf_dict_val_df, threshold = 0.6, 
#                 grid_size = (5, 5), figsize = (10, 10),
#                 sub_titles= subtitles,
#                 title = 'online cvxNDL')

## sorted dictionary by importance score

In [None]:
# descending_order_of_At holds row *positions* (from argsort), so select with iloc;
# .loc would silently depend on ocmf_dict having a 0..K-1 RangeIndex.
sorted_ocmf_dict_val_df_by_At = ocmf_dict_val_df.iloc[descending_order_of_At]
sorted_subtitles_by_At = [subtitles[i] for i in descending_order_of_At]

### Binarized dictionaries

In [None]:
# p2saved = f'/data/shared/jianhao/online_cvxNDL_results/{chromo}'
# if not osp.isdir(p2saved):
#     os.makedirs(p2saved)
    
# plot_df_in_grid(sorted_ocmf_dict_val_df_by_At, threshold = 0.6, 
#                 grid_size = (5, 5), figsize = (10, 10),
#                 sub_titles= sorted_subtitles_by_At,
#                 title = 'online cvxNDL',
#                p2savefig=osp.join(p2saved, f'{chromo}_all_dictionaries_updated_0905'))

### Raw dictionaries

In [None]:
# Raw (un-thresholded) dictionaries, most important first.
plot_df_in_grid(sorted_ocmf_dict_val_df_by_At,
                title = f'online cvxNDL: {chromo}',
                sub_titles = sorted_subtitles_by_At,
                grid_size = (5, 5),
                figsize = (10, 10),
                threshold = None)

## plot convex hull of representative regions

In [None]:
def plot_cvx_hull(ax, df_in, weight = None, k = 21, title = ''):
    """Draw a combination of the rows of `df_in` as a k x k greyscale image.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    df_in : pandas.DataFrame
        One flattened k x k matrix per row.
    weight : numpy.ndarray or None
        One weight per row of `df_in`. If None, the column-wise mean is drawn;
        otherwise the weighted sum  sum_i weight[i] * df_in.iloc[i]  is drawn.
        The weights are not re-normalized here.
    k : int
        Side length of the square matrix.
    title : str
        Axes title.

    Returns
    -------
    The axes that were drawn on.
    """
    if weight is not None:
        combined = (df_in * weight.reshape(-1, 1)).sum(axis = 0)
    else:
        combined = df_in.mean(axis = 0)
    image = combined.values.reshape(k, k)

    ax.imshow(image, cmap = 'Greys')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax

In [None]:
# Column j of W_hat is used below as the convex weights of dictionary j.
df_W_hat = pd.DataFrame(W_hat, columns = ['w_cluster_' + str(j) for j in range(W_hat.shape[1])])

In [None]:
# Inspect the convex-weight matrix (one w_cluster_* column per dictionary).
df_W_hat

In [None]:
# Dictionary positions in descending order of importance score.
descending_order_of_At

## Plot dictionaries in importance score order. Also plot the representative regions by their weights

In [None]:
# Output directory for the per-dictionary representative grids (created once).
p2saved = f'/data/shared/jianhao/online_cvxNDL_results/{chromo}'
os.makedirs(p2saved, exist_ok = True)

for i in descending_order_of_At:
    cur_label = f'group {i}'
    # Plain boolean array, so rows of df_W_hat are selected by position instead of
    # by aligning a pandas mask on rep_region's index.
    is_cur_group = (rep_region.label == cur_label).to_numpy()
    df_rep = rep_region[is_cur_group]
    df_rep_val_df = df_rep[feature]

    # NOTE(review): a 40-inch-tall figure for a 1 x 2 grid looks unintentional — confirm.
    fig, ax = plt.subplots(1, 2, figsize=(8, 40))
    # Left: unweighted average of the representatives.
    plot_cvx_hull(ax[0], df_rep_val_df, weight = None,
                 title = f'average of \nrep region {i}')

    # Column i of W_hat: convex weight of each representative of dictionary i.
    cvx_weight = df_W_hat.values[is_cur_group, i]
    descending_order_cvx_weight = cvx_weight.argsort()[::-1]

    # Right: convex combination of the representatives with those weights.
    plot_cvx_hull(ax[1], df_rep_val_df, weight = cvx_weight,
                  title = f'convex centroid \nof rep region {i}')
    plt.show()
    plt.close(fig)

    sub_titles = [f'rep_{x}, \nweight = {cvx_weight[x]:.2f}'
                  for x in descending_order_cvx_weight]

    # Reorder representatives by descending convex weight and save the top 10.
    sorted_df_rep_val_df_by_cvx_weight = df_rep_val_df.iloc[descending_order_cvx_weight]
    plot_df_in_grid(sorted_df_rep_val_df_by_cvx_weight.head(10), threshold = 1, 
                    sub_titles = sub_titles, title = f'dictionary {i}', 
                   p2savefig= osp.join(p2saved, f'representatives_of_dictionary_{i}_updated_0905_MCMC_pivot'))
    

------
# Get subgraph node embeddings

In [None]:
# Experiment name and sample count used to build the raw-data paths below.
expr_name, iid_sample

In [None]:
# Raw subgraph samples and their per-sample node table for this experiment.
# The directory and the file path get separate names instead of both being p2raw_data.
p2raw_data_dir = '/data/shared/jianhao/online_cvxNDL_data/updated_0905_data_with_node_id'
p2emb = osp.join(p2raw_data_dir, f'df_{expr_name}_{iid_sample}_MCMC_pivot_sample_node_matrix')
p2raw_data = osp.join(p2raw_data_dir, f'df_{expr_name}_{iid_sample}_MCMC_pivot')

In [None]:
# NOTE: read_pickle can execute arbitrary code; only load trusted project files.
df_node_all = pd.read_pickle(p2emb)
# Adjacency rows only: the label column is dropped so rows compare against X_df.
df_subgraph_all = pd.read_pickle(p2raw_data).drop('label', axis = 1)

In [None]:
# df_node_all = df_node_all.applymap(lambda s: s.replace('"', ''))

In [None]:
# Node-id table: row i supplies the node ids used below to relabel subgraph i's nodes.
df_node_all

### Check which dataset was used in training.
If the old subgraph data was used, its subgraphs may not have corresponding node embeddings.


In [None]:
# The model was trained on X_df; confirm it is exactly the subgraph table that has node ids,
# so that representatives can be mapped back to their nodes.
# np.array_equal returns False (instead of raising or warning) when the shapes differ.
if np.array_equal(X_df.values, df_subgraph_all.values):
    print('>>> used new dataset with node embeddings. GOOD to go.')
    df_raw_data = df_subgraph_all
    warning_flag = False
    df_raw_data_matched_emb = df_subgraph_all
else:
    # Fail here: otherwise later cells die with NameError on df_raw_data / warning_flag.
    raise ValueError(
        'X_df does not match the subgraph table with node ids '
        f'(shapes {X_df.shape} vs {df_subgraph_all.shape}); '
        'node embeddings cannot be matched to representatives.'
    )

### 1. `df_raw_data` is always the input (training) data, so `df_rep.isin(df_raw_data) == True`.
The only problematic case is when the old data was used:

=> **df_node_emb** and **df_raw_data** do not match.

### 2. `df_raw_data_matched_emb` always matches the node embeddings.

In [None]:
# Shapes of the raw data, the embedding-matched data and the node-id table. The loops below
# index df_node_all with row masks built on the raw data, so their rows must line up.
df_raw_data.shape, df_raw_data_matched_emb.shape, df_node_all.shape

--------
# For each representative, get its subgraph nodes

In [None]:
def plot_two_rows(row1, row2, titles = ['row1', 'row2']):
    """Show two flattened square matrices side by side, without thresholding.

    Parameters
    ----------
    row1, row2 : array-like
        Flattened matrices accepted by `plot_one_row` (its default side length is used).
    titles : sequence of str
        titles[0] labels the left panel and titles[1] the right one. The default
        list is only read, never mutated.
    """
    fig = plt.figure()
    left_ax, right_ax = fig.subplots(1, 2)

    left_title, right_title = titles[0], titles[1]
    plot_one_row(left_ax, row1, title = left_title, threshold = None)
    plot_one_row(right_ax, row2, title = right_title, threshold = None)

    plt.show()
    plt.close()

In [None]:
def draw_arc_of_graph(H_directed, pos, ax = None, p2savefig = None, min_label_gap = 3):
    """Draw a graph as an arc diagram and label its nodes.

    Parameters
    ----------
    H_directed : networkx graph
        Graph to draw; edges are drawn as curved arcs (connectionstyle 'arc3,rad=0.3').
    pos : dict
        Maps node -> (x, y). Each label is drawn 20 units below its node.
    ax : matplotlib Axes or None
        Axes to draw on. If None, a new 20 x 3 inch figure is created.
    p2savefig : str or None
        If given, save the figure containing `ax` to this path; otherwise show it.
    min_label_gap : float
        A label is skipped when its x position is closer than this to the last
        label drawn, to reduce overlapping text. The default 3 matches the
        previous hard-coded value.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize = (20, 3))
    else:
        fig = ax.figure

    nx.draw_networkx_nodes(H_directed, pos, 
                           node_size = 100, 
                           ax = ax)
    nx.draw_networkx_edges(H_directed, pos, 
                           connectionstyle='arc3,rad=0.3', 
                           arrowsize= 0.1,
                           ax = ax)

    # Walk labels left to right, so the gap test compares x-neighbours even when
    # `pos` is not already ordered by x.
    last_labeled_x = -float('inf')
    for value, (x, y) in sorted(pos.items(), key = lambda item: item[1][0]):
        if (x - last_labeled_x) < min_label_gap:
            continue
        last_labeled_x = x
        ax.text(x, y - 20, value,
                rotation = 315)

    # Set limits / save on the axes and figure we drew on, not pyplot's "current" ones.
    ax.set_ylim(-20, 100)
    if p2savefig is None:
        plt.show()
    else:
        fig.savefig(p2savefig)
    return ax

In [None]:
# Rows of df_W_hat are selected with masks built on rep_region.label below, so the row counts must agree.
df_W_hat.shape, rep_region.shape

In [None]:
# NOTE(review): shows cvx_weight left over from the last iteration of the earlier
# per-dictionary loop — hidden state that depends on execution order.
cvx_weight

## Plot dictionary + rep + arc plot of rep.

In [None]:
# Orient each edge so its first endpoint has the larger position (n1 > n2). Node ids are
# assumed to be one prefix character followed by an integer, as parsed by int(name[1:]).
backward_edge = lambda edge: (int(edge[0][1:]) < int(edge[1][1:]))
flip_nodes = lambda edge: (edge[1], edge[0], edge[2]) if backward_edge(edge) else edge

for cluster in descending_order_of_At:
    cur_importance_score = importance_score[cluster]
    # Dictionary row of this cluster; dictionary labels are 1-based ('ocmf: type {cluster + 1}').
    dict_row = ocmf_dict[ocmf_dict.label == f'ocmf: type {cluster + 1}']
    dict_row = dict_row[feature]

    # Representative regions of the current cluster. A plain boolean array selects rows of
    # df_W_hat by position instead of aligning on rep_region's index.
    cur_label = f'group {cluster}'
    is_cur_group = (rep_region.label == cur_label).to_numpy()
    df_rep_val_df = rep_region[is_cur_group].drop(columns = ['label'])

    # Column `cluster` of df_W_hat: convex weight of each selected representative.
    cvx_weight = df_W_hat.values[is_cur_group, cluster]
    descending_order_cvx_weight = cvx_weight.argsort()[::-1]

    # Representatives sorted by descending convex weight.
    sorted_df_rep_val_df_by_cvx_weight = df_rep_val_df.iloc[descending_order_cvx_weight]
    print('cluster :', cluster)

    p2saved = f'/data/shared/jianhao/online_cvxNDL_results/{chromo}/representatvies_edge_list_updated_0905_MCMC_pivot_train_iter_10k/dictionary_{cluster}'
    os.makedirs(p2saved, exist_ok = True)

    for idx, (_, row) in enumerate(sorted_df_rep_val_df_by_cvx_weight.iterrows()):
        row_in = row.values
        cur_rep_idx = descending_order_cvx_weight[idx]
        cur_weight = cvx_weight[cur_rep_idx]

        print('-' * 7, f'cluster {cluster}, rep {cur_rep_idx}, pos in {idx}', '-' * 7)

        if warning_flag:
            # Old training data: look for the row in the embedding-matched table first,
            # and fall back to the training table itself.
            sample_idx_in_subgraphs = (df_raw_data_matched_emb == row_in).all(axis = 1)
            samples_from_df_raw_data = df_raw_data_matched_emb.loc[sample_idx_in_subgraphs]
            if len(samples_from_df_raw_data) == 0:
                sample_idx_in_subgraphs = (df_raw_data == row_in).all(axis = 1)
                samples_from_df_raw_data = df_raw_data.loc[sample_idx_in_subgraphs]
        else:
            # New training data: df_raw_data == X_df == df_raw_data_matched_emb.
            sample_idx_in_subgraphs = (df_raw_data == row_in).all(axis = 1)
            samples_from_df_raw_data = df_raw_data.loc[sample_idx_in_subgraphs]

        if len(samples_from_df_raw_data) == 0:
            raise ValueError(f'cluster {cluster}, rep {cur_rep_idx}: no raw subgraph matches this representative.')
        # Several samples may share the same adjacency matrix; use the first one.
        samples_from_df_raw_data = samples_from_df_raw_data.iloc[0]

        # Draw the dictionary next to this representative.
        titles = [f'dictionary {cluster}\nscore: {cur_importance_score:.2f}', 
                  f'representatative {cur_rep_idx}\nweight: {cur_weight:.2f}']
        plot_two_rows(dict_row, samples_from_df_raw_data, titles)

        # Node ids of the matched sample (same row mask as the adjacency lookup).
        node_embedding_of_samples = df_node_all.loc[sample_idx_in_subgraphs].iloc[0].values

        # Undirected graph from the flattened square adjacency matrix.
        adj_values = samples_from_df_raw_data.values
        n_nodes = int(np.sqrt(adj_values.size))
        Adj = adj_values.reshape(n_nodes, n_nodes)
        # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array replaces it.
        G = nx.from_numpy_array(Adj)

        # Relabel node i with its node id.
        node_name_map = dict(enumerate(node_embedding_of_samples))
        G = nx.relabel.relabel_nodes(G, node_name_map)

        # Order nodes by the ascending integer position parsed from their ids.
        nodes = list(G.nodes(data = True))
        edges = list(G.edges(data = True))
        node_pos = np.array([int(name[1:]) for name, _ in nodes])
        new_node_order = node_pos.argsort()
        new_node_pos = node_pos[new_node_order]
        new_nodes = [nodes[i] for i in new_node_order]

        # Map positions linearly onto x in [0, 100] (y = 0); guard against a zero span.
        min_node_pos = new_node_pos.min()
        node_pos_span = (new_node_pos.max() - min_node_pos) or 1
        pos = {name: ((p - min_node_pos) / node_pos_span * 100, 0)
               for (name, _), p in zip(new_nodes, new_node_pos)}

        new_edges = [flip_nodes(x) for x in edges]

        # Directed graph H: the ordered nodes and the oriented edges of G.
        H_directed = nx.DiGraph()
        H_directed.add_nodes_from(new_nodes)
        H_directed.add_edges_from(new_edges)

        draw_arc_of_graph(H_directed, pos, 
                          p2savefig=osp.join(p2saved, 
                                             f'arc_plot_dictionary_{cluster}_rep_{cur_rep_idx}'))
        plt.close()

        # One 'u,v' line per edge.
        p2edge_list = osp.join(p2saved, 
                              f'edge_list_dicitonary_{cluster}_rep_{cur_rep_idx}.txt')
        with open(p2edge_list, 'w') as f:
            f.writelines(f'{u},{v}\n' for u, v, _ in new_edges)

        print('-' * 7, f'cluster {cluster}, rep {cur_rep_idx}, pos in {idx}', '-' * 7)
    

In [None]:
# NOTE(review): shows new_edges left over from the last iteration of the loop above —
# hidden state that depends on execution order.
new_edges

# randomly select original samples as representatives for each dictionary.

In [None]:
# Raw adjacency rows that the random baseline below samples from.
df_raw_data

In [None]:
# Node-id table read alongside df_raw_data by the random baseline below.
df_node_all

In [None]:
# Baseline: for each dictionary, save arc plots and edge lists of uniformly random raw
# samples (the samples themselves do not depend on the dictionary).
n_random_reps = 10
# Seeded generator so the random baseline is reproducible.
random_rng = np.random.default_rng(0)

# Orient each edge so its first endpoint has the larger position (n1 > n2). Node ids are
# assumed to be one prefix character followed by an integer, as parsed by int(name[1:]).
backward_edge = lambda edge: (int(edge[0][1:]) < int(edge[1][1:]))
flip_nodes = lambda edge: (edge[1], edge[0], edge[2]) if backward_edge(edge) else edge

for cluster in descending_order_of_At:
    p2saved = f'/data/shared/jianhao/online_cvxNDL_results/{chromo}/random_rep_edge_list/dictionary_{cluster}'
    os.makedirs(p2saved, exist_ok = True)

    for cur_rep_idx in range(n_random_reps):
        # random_idx is a row *position*, so select with iloc in both tables
        # (assumed row-aligned — TODO confirm if the old dataset is ever used).
        random_idx = random_rng.integers(len(df_raw_data))
        samples_from_df_raw_data = df_raw_data.iloc[random_idx]
        node_embedding_of_samples = df_node_all.iloc[random_idx].to_numpy()

        # Undirected graph from the flattened square adjacency matrix.
        adj_values = samples_from_df_raw_data.values
        n_nodes = int(np.sqrt(adj_values.size))
        Adj = adj_values.reshape(n_nodes, n_nodes)
        # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array replaces it.
        G = nx.from_numpy_array(Adj)

        # Relabel node i with its node id.
        node_name_map = dict(enumerate(node_embedding_of_samples))
        G = nx.relabel.relabel_nodes(G, node_name_map)

        # Order nodes by the ascending integer position parsed from their ids.
        nodes = list(G.nodes(data = True))
        edges = list(G.edges(data = True))
        node_pos = np.array([int(name[1:]) for name, _ in nodes])
        new_node_order = node_pos.argsort()
        new_node_pos = node_pos[new_node_order]
        new_nodes = [nodes[i] for i in new_node_order]

        # Map positions linearly onto x in [0, 100] (y = 0); guard against a zero span.
        min_node_pos = new_node_pos.min()
        node_pos_span = (new_node_pos.max() - min_node_pos) or 1
        pos = {name: ((p - min_node_pos) / node_pos_span * 100, 0)
               for (name, _), p in zip(new_nodes, new_node_pos)}

        new_edges = [flip_nodes(x) for x in edges]

        # Directed graph H: the ordered nodes and the oriented edges of G.
        H_directed = nx.DiGraph()
        H_directed.add_nodes_from(new_nodes)
        H_directed.add_edges_from(new_edges)

        draw_arc_of_graph(H_directed, pos, 
                          p2savefig=osp.join(p2saved, 
                                             f'arc_plot_dictionary_{cluster}_rep_{cur_rep_idx}'))
        plt.close()

        # One 'u,v' line per edge.
        p2edge_list = osp.join(p2saved, 
                              f'edge_list_dicitonary_{cluster}_rep_{cur_rep_idx}.txt')
        with open(p2edge_list, 'w') as f:
            f.writelines(f'{u},{v}\n' for u, v, _ in new_edges)
    

In [None]:
# NOTE(review): shows the last random row index drawn by the loop above — hidden state.
random_idx