# Merge modules from WGCNA
#### Jianfeng
#### 05/14/2024

In [1]:
# import libraries
# import torch
# import torch.nn.functional as F
# from torch_geometric.nn import GCNConv
import os
import csv
import copy
import math
import time
import random
import pickle
import argparse
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
import matplotlib.pyplot as plt
from statsmodels.tools.eval_measures import mse
from scipy import stats
from scipy.stats import spearmanr
from scipy.spatial.distance import cosine
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
# Load gtex data
data_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/12052023/'
gt = pd.read_csv(data_dir+"new_normed_gtex_gtex_allen_gene.txt", low_memory=False, index_col=0, sep="\t")
region_pick = ['Amygdala', 'Anterior_cingulate_cortex_BA24', 'Caudate_basal_ganglia', 
               'Cerebellar_Hemisphere', 'Frontal_Cortex_BA9', 'Hippocampus', 'Hypothalamus', 
               'Nucleus_accumbens_basal_ganglia', 'Putamen_basal_ganglia', 'Substantia_nigra']

# build a dictionary to count the freq of each subject 
sample_subject_list = gt.loc['subject'].tolist()
subject_region_mat = pd.DataFrame(np.zeros((len(set(sample_subject_list)), len(region_pick)), dtype=int))
subject_region_mat.index = sorted(set(sample_subject_list))
subject_region_mat.columns = region_pick
for i in range(gt.shape[1]):
    region = gt.loc['region'][i]
    subject = gt.loc['subject'][i]
    region_idx = region_pick.index(region)
    subject_idx = subject_region_mat.index.tolist().index(subject)
    subject_region_mat.iloc[subject_idx, region_idx] = 1

In [5]:
# find ME for each module, {'M1': ['A1BG', 'A2M', 'A2ML1'], 'M2':...}
# output is the module-MEs dictionary
def find_ME(go_mat, input_dict):
    ME_dict = {}
    for key, value in input_dict.items():
        exp_mat = go_mat.loc[value]
        pca = PCA(n_components=1)
        pca.fit(exp_mat)
        ME = pca.components_[0]
        ME_dict[key] = ME
        
    return ME_dict


def find_max_element(mat):
    max_value = -2
    max_value_pos = [0, 0]
    for i in range(len(mat)):
        for j in range(i+1, len(mat)):
            if mat.iloc[i,j] > max_value:
                max_value = mat.iloc[i,j]
                max_value_pos = [i, j]  
                
    return max_value, max_value_pos
                

def merge_module(input_dict):
    ME_dict = find_ME(go_mat, input_dict)
    ME_mat = pd.DataFrame(ME_dict)
    module_list = list(input_dict.keys())
    corr_mat = ME_mat.corr()
    max_corr, max_corr_pair = find_max_element(corr_mat)
    # print(f'Max correlation between two modules is {max_corr}')
    mod1_idx, mod2_idx = max_corr_pair
    new_dict = input_dict.copy()
    # print(f'The module to extend: {module_list[mod1_idx]}, there are {len(input_dict[module_list[mod1_idx]])} in this module')
    # print(f'The module to remove: {module_list[mod2_idx]}, there are {len(input_dict[module_list[mod2_idx]])} in this module')
    if max_corr>0.8:
        new_dict[module_list[mod1_idx]].extend(new_dict.pop(module_list[mod2_idx]))
        # print(f'Now there are {len(new_dict[module_list[mod1_idx]])} in module {module_list[mod1_idx]} in the new dict')
    # print(f'After merging modules, the number of modules left: {len(new_dict.keys())}')
        
    return new_dict


def reassign_gene_to_module(go_mat, input_dict):
    ME_dict = find_ME(go_mat, input_dict)
    ME_mat = pd.DataFrame(ME_dict)
    module_list = ME_mat.columns.tolist()
    gene_list = go_mat.index.tolist()
    kME_mat = np.corrcoef(ME_mat.T, go_mat)[len(ME_mat.T):(len(go_mat)+len(ME_mat.T)), :len(ME_mat.T)]
    # create a new dictionary with same modules
    new_dict = {}
    for mod in module_list:
        new_dict[mod] = []
    for i in range(len(gene_list)):
        gene_name = gene_list[i]
        gene_kME = kME_mat[i]
        mod = module_list[np.argmax(gene_kME)]
        new_dict[mod].append(gene_name)
    # remove key with no genes
    # new_dict = {k: v for k, v in new_dict.items() if v}
    
    return new_dict


# compare the original dict and the new dict, 
# find the module for which the largest number of genes were reassigned,
# if <70% of the genes in this module remain in the module after re-assignment, remove the module
def check_after_reassign(new_dict, input_dict):
    reassign_number_list = []
    module_list = list(input_dict.keys())
    for mod in module_list:
        count = 0
        old_genes = input_dict[mod]
        new_genes = new_dict[mod]
        for gene in old_genes:
            if gene in new_genes:
                count = count + 1
        reassign_number_list.append(count/len(old_genes))
        print(f'{count}/{len(old_genes)}')
    min_value = reassign_number_list[np.argmin(reassign_number_list)]
    pick_mod = module_list[np.argmin(reassign_number_list)]
    
    return pick_mod, min_value


# remove empty module
def clear_empty_module(input_dict):
    new_dict = {k: v for k, v in input_dict.items() if v}
    if len(new_dict.keys())<len(input_dict.keys()):
        diff = len(input_dict.keys())-len(new_dict.keys())
        # print(f'There are empty modules, removing {diff} empty modules')
    
    return new_dict


# remove the module after reassignment
# input a dictionary, and remove a module if needed, output a new dict afterward
def remove_module(go_mat, input_dict, pick_mod):
    ME_dict = find_ME(go_mat, input_dict)
    ME_mat = pd.DataFrame(ME_dict)
    module_list = ME_mat.columns.tolist()
    gene_list = go_mat.index.tolist()
    kME_mat = np.corrcoef(ME_mat.T, go_mat)[len(ME_mat.T):(len(go_mat)+len(ME_mat.T)), :len(ME_mat.T)]
    # create a new dictionary with same modules
    update_dict = {}
    pick_mod_idx = module_list.index(pick_mod)
    new_kME_mat = np.delete(kME_mat, pick_mod_idx, axis=1)
    # remove the module from keys
    module_list.pop(pick_mod_idx)
    for mod in module_list:
        update_dict[mod] = []
    for i in range(len(gene_list)):
        gene_name = gene_list[i]
        gene_kME = new_kME_mat[i]
        mod = module_list[np.argmax(gene_kME)]
        update_dict[mod].append(gene_name)
    
    return update_dict


# merge modules
def MERGE_WORKFLOW(go_mat, input_dict):
    ME_dict = find_ME(go_mat, input_dict)
    merged_input_dict = merge_module(input_dict)
    while len(merged_input_dict.keys()) < len(input_dict.keys()):
        input_dict = merged_input_dict.copy()
        merged_input_dict = merge_module(input_dict)
    # print(f'End merging module phase 1. The number of modules left: {len(merged_input_dict.keys())}\n')
    # print('Starting steps for coherency and incoherent:')
    # reassign_dict = reassign_gene_to_module(go_mat, merged_input_dict)
    # pick_mod, min_value = check_after_reassign(reassign_dict, merged_input_dict)
    # while min_value<0.7:
    #     print(f'Min value is {min_value}, and the module to remove is {pick_mod}')
    #     update_dict = remove_module(go_mat, reassign_dict, pick_mod)
    #     reassign_dict = reassign_gene_to_module(go_mat, update_dict)
    #     reassign_dict = clear_empty_module(reassign_dict)
    #     pick_mod, min_value = check_after_reassign(update_dict, reassign_dict)
        
    return merged_input_dict

In [8]:
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# model settings
t_epoch = 300
f_epoch = 500
status = 'trainable'
missing_N = 5
# merge modules for every subject
include_subject = [s for s in subject_region_mat.index if subject_region_mat.loc[s].sum()>=(10-missing_N)]
for subject in include_subject:
    save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_8_or_less/Prediction/300_500_by_subject/'
    csv_file = f'{subject}_{t_epoch}_{f_epoch}_{status}_LMfromGTEx.csv'
    go_mat = pd.read_csv(save_dir+csv_file, index_col=0)
    # read the raw module output from WGCNA
    filename = f'{subject}_raw_modules.csv'
    save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/WGCNA/raw_modules/'
    module_df = pd.read_csv(save_dir+filename)
    # remove color genes with color 'grey'
    color_list = module_df['Cluster']
    remove_row = [i for i in range(module_df.shape[0]) if color_list[i]=='grey']
    update_df = module_df.drop(remove_row).reset_index(drop=True)
    module_dict = update_df.groupby('Module')['Gene'].apply(list).to_dict()
    update_mat = go_mat.drop(go_mat.index[remove_row])
    print(f'{15044-update_mat.shape[0]} genes with color grey are excluded from analysis for {subject}')
    # check if the file is already exist or not
    save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/WGCNA/output_modules/'
    file_name = f'{subject}_merged_module.csv'
    if os.path.isfile(save_dir+file_name):
        print(f'There is already a file for {subject}!')
        continue
    else:
        # run the merging module function
        output_dict = MERGE_WORKFLOW(update_mat, module_dict)
        # clean the modules
        final_dict = {}
        module_list = list(output_dict.keys())
        for key, value in output_dict.items():
            new_module_name = f'M{module_list.index(key)+1}'
            final_dict[new_module_name] = value
        # save the result 
        data = [(gene, key) for key, genes in final_dict.items() for gene in genes]
        final_mat = pd.DataFrame(data, columns=['gene', 'module'])
        save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/WGCNA/output_modules/'
        file_name = f'{subject}_merged_module.csv'
        final_mat.to_csv(save_dir+file_name, index=False)
        print(f"End merging modules for {subject}!")

4 genes with color grey are excluded from analysis for GTEX-1192X
There is already a file for GTEX-1192X!
1 genes with color grey are excluded from analysis for GTEX-11DXW
End merging modules for GTEX-11DXW!
2 genes with color grey are excluded from analysis for GTEX-11DXY
End merging modules for GTEX-11DXY!
0 genes with color grey are excluded from analysis for GTEX-11DYG
End merging modules for GTEX-11DYG!
0 genes with color grey are excluded from analysis for GTEX-11DZ1
End merging modules for GTEX-11DZ1!
0 genes with color grey are excluded from analysis for GTEX-11GSP
End merging modules for GTEX-11GSP!
0 genes with color grey are excluded from analysis for GTEX-11OF3
End merging modules for GTEX-11OF3!
0 genes with color grey are excluded from analysis for GTEX-11ONC
End merging modules for GTEX-11ONC!
1 genes with color grey are excluded from analysis for GTEX-11PRG
End merging modules for GTEX-11PRG!
0 genes with color grey are excluded from analysis for GTEX-11TTK
End merging 

In [None]:
# subject = 'GTEX-1192X'
# # read the allen region prediction
# save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_8_or_less/Prediction/300_500_by_subject/'
# csv_file = f'{subject}_300_500_trainable_LMfromGTEx.csv'
# go_mat = pd.read_csv(save_dir+csv_file, index_col=0)
# # read the raw module output from WGCNA
# filename = f'{subject}_raw_modules.csv'
# save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/WGCNA/raw_modules/'
# module_df = pd.read_csv(save_dir+filename)
# module_dict = module_df.groupby('Module')['Gene'].apply(list).to_dict()

# # run the script
# output_dict = MERGE_WORKFLOW(go_mat, module_dict)

In [None]:
# final_dict = {}
# module_list = list(output_dict.keys())
# for key, value in output_dict.items():
#     new_module_name = f'M{module_list.index(key)+1}'
#     final_dict[new_module_name] = value

# data = [(gene, key) for key, genes in final_dict.items() for gene in genes]
# final_mat = pd.DataFrame(data, columns=['gene', 'module'])
# save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/WGCNA/output_modules/'
# final_mat.to_csv(save_dir+f'{subject}_merged_module.csv', index=False)
# # y = pd.read_csv(save_dir+f'{subject}_merged_module.csv')