In [1]:
import os
import sys
import numpy as np
import pandas as pd
import networkx as nx
from collections import defaultdict
import random
from networkx.algorithms import community
from sklearn.cluster import DBSCAN
from scipy import stats
from scipy.stats import norm
import h5py
import pickle

import warnings
warnings.filterwarnings('ignore')


import itertools
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import plotly.graph_objects as go
#plot 
from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import matplotlib.ticker as ticker
from matplotlib.ticker import EngFormatter
import seaborn as sns
from statsmodels.stats.multitest import multipletests
from joblib import Parallel, delayed

In [2]:
def read_distance_matrices(h5_file, chromosome):
    """Read the distance-matrix dataset for one chromosome from an HDF5 file.

    Args:
        h5_file: path to the HDF5 file.
        chromosome: name of the dataset inside the file (e.g. 'chr2').

    Returns:
        The complete dataset loaded into memory (``[:]`` reads all of it); the file
        is closed before the function returns.
    """
    with h5py.File(h5_file, 'r') as h5:
        return h5[chromosome][:]
def plot_distance_matrix(matrix):
    """Show ``matrix`` as a heatmap ('hot' colormap, nearest interpolation) with a colorbar."""
    heatmap = plt.imshow(matrix, cmap='hot', interpolation='nearest')
    plt.colorbar(heatmap)
    plt.title("Distance Matrix Heatmap")
    plt.show()
    
def visualize_graph(G):
    """Draw ``G`` with a spring layout and label each edge with its 'weight' attribute."""
    # No seed is passed to spring_layout, so node placement is not fixed between calls.
    layout = nx.spring_layout(G)
    edge_weights = nx.get_edge_attributes(G, 'weight')
    draw_style = dict(node_size=700, node_color='lightblue', edge_color='gray',
                      linewidths=1, font_size=15)
    plt.figure(figsize=(12, 8))
    nx.draw(G, layout, with_labels=True, **draw_style)
    nx.draw_networkx_edge_labels(G, layout, edge_labels=edge_weights, font_color='red')
    plt.title("Clique Graph Visualization")
    plt.show()


def proximity_matrices(distance_matrices, thresholds, dtype=int):
    """Binarise each distance matrix against its own threshold.

    Args:
        distance_matrices: stack of matrices; element ``k`` is paired with ``thresholds[k]``.
        thresholds: one threshold per matrix (same length as ``distance_matrices``).
        dtype: dtype of the returned array. Defaults to ``int`` for backward
            compatibility; a narrow type such as ``np.int8`` cuts memory 8x for
            large stacks.

    Returns:
        Array shaped like ``distance_matrices`` holding 1 where the distance is
        ``<= threshold`` and not NaN, else 0.

    Raises:
        ValueError: if the number of thresholds differs from the number of matrices
            (``zip`` would otherwise silently leave the unmatched matrices all-zero).
    """
    if len(distance_matrices) != len(thresholds):
        raise ValueError(
            f"got {len(thresholds)} thresholds for {len(distance_matrices)} distance matrices"
        )
    result = np.zeros_like(distance_matrices, dtype=dtype)
    # Process one matrix at a time to avoid full-size boolean temporaries for the whole stack.
    for cell_idx, (matrix, threshold) in enumerate(zip(distance_matrices, thresholds)):
        # Missing (NaN) distances are never counted as proximal.
        mask = (matrix <= threshold) & ~np.isnan(matrix)
        result[cell_idx][mask] = 1
    return result


def build_clique_graph(real_frequencies):
    """Build a weighted graph connecting every pair of nodes within each triplet.

    Each key of ``real_frequencies`` is indexed as a triplet (only its first three
    entries are used); its three pairwise edges receive ``weight=frequency``. When
    several triplets share an edge, the weight of the triplet iterated last wins
    (weights are overwritten, not summed).
    """
    G = nx.Graph()
    for triplet, weight in real_frequencies.items():
        first, second, third = triplet[0], triplet[1], triplet[2]
        G.add_edges_from([(first, second), (first, third), (second, third)], weight=weight)
    return G

def louvain_clustering(G, seed=None):
    """Partition ``G`` into communities with the Louvain method on edge 'weight'.

    The previous body called ``community_louvain.best_partition`` (python-louvain),
    but ``community_louvain`` is not imported anywhere in the visible notebook, so
    every call raised NameError. This version uses NetworkX's own implementation
    through the already-imported ``networkx.algorithms.community`` module
    (NetworkX >= 2.8). It returns the same ``{node: cluster_id}`` shape that
    ``best_partition`` produced.

    Args:
        G: networkx graph; edges may carry a 'weight' attribute.
        seed: optional random seed. Louvain is randomised, so pass a value for
            reproducible partitions.

    Returns:
        dict mapping every node of ``G`` to an integer cluster id in ``0..k-1``.
    """
    communities = community.louvain_communities(G, weight='weight', seed=seed)
    return {node: cluster_id
            for cluster_id, members in enumerate(communities)
            for node in members}

def calculate_cluster_frequencies(partition, real_frequencies):
    """Aggregate clique frequencies per cluster.

    Each clique's frequency is credited to exactly one cluster: among the clusters
    containing at least one of the clique's nodes, the one that first appears when
    iterating ``partition``. Cliques with no node in ``partition`` are skipped.

    Args:
        partition: mapping ``node -> cluster_id``.
        real_frequencies: mapping ``clique (iterable of nodes) -> frequency``.

    Returns:
        ``(cluster_frequencies, cluster_nodes)``, both ``defaultdict``. The first is
        ``cluster_id -> summed frequency`` (only clusters that received something).
        The second is ``cluster_id -> list of nodes`` in ``partition`` order.
    """
    cluster_nodes = defaultdict(list)
    for node, cluster_id in partition.items():
        cluster_nodes[cluster_id].append(node)

    # Position of each cluster in first-appearance order. This reproduces the
    # "first matching cluster" rule with a dict lookup per node instead of
    # scanning every cluster for every clique.
    cluster_rank = {cluster_id: rank for rank, cluster_id in enumerate(cluster_nodes)}

    cluster_frequencies = defaultdict(float)
    for clique, frequency in real_frequencies.items():
        member_clusters = [partition[node] for node in clique if node in partition]
        if member_clusters:
            first_cluster = min(member_clusters, key=cluster_rank.__getitem__)
            cluster_frequencies[first_cluster] += frequency

    return cluster_frequencies, cluster_nodes

def merge_frequencies(partition, real_frequencies):
    """Sum triplet frequencies by the set of clusters the triplet's nodes fall into.

    Args:
        partition: mapping ``node -> cluster_id``; every node of every triplet must
            be present (KeyError otherwise).
        real_frequencies: mapping ``triplet -> frequency``.

    Returns:
        Plain dict keyed by the sorted tuple of distinct cluster ids, valued by the
        summed frequency.
    """
    merged_frequencies = {}
    for triplet, frequency in real_frequencies.items():
        cluster_key = tuple(sorted({partition[node] for node in triplet}))
        if cluster_key not in merged_frequencies:
            merged_frequencies[cluster_key] = frequency
        else:
            merged_frequencies[cluster_key] += frequency
    return merged_frequencies


def build_contact_graph(pairwise_contacts, resolution, min_distance,max_distance):
    """Build one cell's contact graph from its binary pairwise-contact matrix.

    An edge (i, j), i < j, is added when ``pairwise_contacts[i, j]`` is truthy and the
    linear separation ``j - i`` (in bins) lies within
    ``[min_distance // resolution, max_distance // resolution]``.

    Args:
        pairwise_contacts: 2-D contact matrix; rows index bins ``0..shape[0]-1``.
        resolution: bin size, in the same units as the distances.
        min_distance, max_distance: allowed linear separation range (inclusive).

    Returns:
        ``nx.Graph`` whose nodes are Python-int bin indices.
    """
    graph = nx.Graph()
    num_nodes = pairwise_contacts.shape[0]
    min_bin_distance = min_distance // resolution  # minimum separation in bins
    max_bin_distance = max_distance // resolution  # maximum separation in bins

    # Vectorised form of the i < j double loop. np.nonzero returns indices in
    # row-major order, so edges (and therefore nodes) are added in the same order
    # the nested loop visited them.
    rows, cols = np.nonzero(pairwise_contacts)
    separation = cols - rows
    keep = (
        (separation > 0)                        # upper triangle only (j > i)
        & (cols < num_nodes)                    # the loop only visited j < shape[0]
        & (separation >= min_bin_distance)
        & (separation <= max_bin_distance)
    )
    # .tolist() keeps node ids as Python ints, as the loop produced.
    graph.add_edges_from(zip(rows[keep].tolist(), cols[keep].tolist()))
    return graph


def find_cliques(contact_graph, min_size=3, max_body=6):
    """Return the maximal cliques of ``contact_graph`` whose size is in ``[min_size, max_body]``."""
    selected = []
    for clique in nx.find_cliques(contact_graph):
        if min_size <= len(clique) <= max_body:
            selected.append(clique)
    return selected

def find_3body_cliques(contact_graph):
    """Return every 3-node clique derived from the maximal cliques of ``contact_graph``.

    Handling by maximal-clique size:
    - size 3: the clique itself is appended, as the list ``nx.find_cliques`` yields.
    - size 4 or 5: all 3-node combinations are appended, as tuples.
    - size 6 or more, or fewer than 3: ignored.

    A triplet can appear more than once when it lies in several maximal cliques.
    """
    triplets = []
    for clique in nx.find_cliques(contact_graph):
        size = len(clique)
        if size == 3:
            triplets.append(clique)
        elif 3 < size < 6:
            triplets.extend(itertools.combinations(clique, 3))
    return triplets


def calculate_multibody_frequencies(pairwise_contacts_list,min_distance,max_distance,resolution):
    """Fraction of cells in which each 3-body clique (triplet of bins) occurs.

    For every cell's contact matrix, a contact graph is built and its 3-body cliques
    are enumerated. Each distinct sorted triplet is counted at most once per cell,
    and the count is divided by the number of cells.

    Returns:
        dict ``sorted (a, b, c) tuple -> fraction of cells``, in first-seen order.
    """
    multibody_counts = defaultdict(int)
    num_matrices = len(pairwise_contacts_list)  # number of cells
    # Handle each cell independently.
    for pairwise_contacts in pairwise_contacts_list:
        contact_graph = build_contact_graph(pairwise_contacts, resolution, min_distance, max_distance)
        # A triplet can lie inside several overlapping maximal cliques of the same
        # cell. Deduplicate so it is counted once per cell and the result is a true
        # fraction of cells (never > 1). dict.fromkeys keeps first-seen order.
        cell_triplets = dict.fromkeys(tuple(sorted(clique)) for clique in find_3body_cliques(contact_graph))
        for triplet in cell_triplets:
            multibody_counts[triplet] += 1

    return {triplet: count / num_matrices for triplet, count in multibody_counts.items()}


def proximity_frequency_two(x, y, proximity_matrix_list):
    """Fraction of cells whose proximity entry ``[cell, x, y]`` equals 1."""
    hits = proximity_matrix_list[:, x, y] == 1
    return hits.mean()


def proximity_frequency_three(x, y, z, proximity_matrix_list):
    """Fraction of cells where pairs (x, y), (y, z) and (z, x) are all marked 1."""
    pair_hits = [proximity_matrix_list[:, first, second] == 1
                 for first, second in ((x, y), (y, z), (z, x))]
    # All three pairs must be proximal in the same cell.
    return np.logical_and.reduce(pair_hits).mean()


def generate_random_indices(chr_idx_list, span_range, idx1, idx2, idx3, n_samples=10, rng=None):
    """Draw random triplets with the same bin offsets as ``(idx1, idx2, idx3)``.

    Each start position is drawn uniformly from the first
    ``len(chr_idx_list) - span_range`` entries of ``chr_idx_list``, so the shifted
    triplet stays within that range.

    Args:
        chr_idx_list: sequence of candidate start indices (e.g. ``np.arange(n_bins)``).
        span_range: ``idx3 - idx1``, the span the triplet must fit in.
        idx1, idx2, idx3: reference triplet whose offsets are reproduced.
        n_samples: number of random triplets to draw (default 10, as before).
        rng: optional ``random.Random`` instance for reproducible draws. Defaults to
            the module-level ``random`` functions, giving the same sequence as before.

    Returns:
        List of ``n_samples`` entries, each a one-element list holding an
        ``(i1, i2, i3)`` tuple. Callers index ``entry[0][k]``.

    Raises:
        IndexError: if no start position fits (from ``choice`` on an empty pool).
    """
    chooser = random if rng is None else rng
    # Slice by clamped length rather than [:-span_range]. With span_range == 0,
    # [:-0] is [:0], which is empty even though every start position is valid.
    n_starts = max(len(chr_idx_list) - span_range, 0)
    candidate_starts = chr_idx_list[:n_starts]
    offset2 = idx2 - idx1
    offset3 = idx3 - idx1
    random_indices = []
    for _ in range(n_samples):
        start = chooser.choice(candidate_starts)
        random_indices.append([(start, start + offset2, start + offset3)])
    return random_indices


def count_max_distance(distance_matrices, idx_a, idx_b, idx_c):
    """Per-cell maximum of the three pairwise distances within triplet (a, b, c).

    Args:
        distance_matrices: array indexed as ``[cell, bin, bin]``.
        idx_a, idx_b, idx_c: bin indices of the triplet.

    Returns:
        1-D array with one value per cell: the largest of d(a,b), d(b,c), d(c,a),
        ignoring NaN distances; NaN only when all three are NaN.
    """
    pair_distances = np.stack([distance_matrices[:, idx_a, idx_b],
                               distance_matrices[:, idx_b, idx_c],
                               distance_matrices[:, idx_c, idx_a]], axis=-1)
    # np.fmax skips NaN operands (NaN only if both are NaN), so one reduction gives
    # the NaN-ignoring maximum without RuntimeWarnings or temporary copies.
    # Unlike np.nan_to_num, it leaves +/-inf untouched. An infinite distance
    # therefore stays inf and is caught by downstream isinf() filters instead of
    # becoming 1.8e308.
    return np.fmax.reduce(pair_distances, axis=-1)

# Process one row of df (one candidate triplet).
def process_row(i):
    """Observed vs. expected co-proximity and distance statistics for triplet row ``i``.

    Reads these notebook globals:
    - ``df``: columns 'A', 'B', 'C' hold the triplet's bin indices.
    - ``proximity_matrix_list`` and ``distance_matrices``: both indexed as ``[cell, bin, bin]``.

    Returns:
        ``[exp, obs, obs / exp, obs_distance, exp_distance, p_single_tailed]``:
        - ``exp``: product of the three pairwise proximity frequencies.
        - ``obs``: fraction of cells in which all three pairs are proximal.
        - ``obs_distance``: NaN-ignoring mean over cells of the triplet's per-cell
          maximum pairwise distance.
        - ``exp_distance``: the same quantity for randomly placed triplets with
          identical bin offsets, first averaged per cell.
        - ``p_single_tailed``: one-tailed p-value of a Student t-test (equal
          variances) for "random-triplet distances are larger than the observed ones".
    """
    idx_a, idx_b, idx_c = (df.loc[i, column] for column in ('A', 'B', 'C'))

    # Expected triple frequency under independence = product of the pairwise frequencies.
    freq_ab, freq_bc, freq_ca = (
        proximity_frequency_two(left, right, proximity_matrix_list)
        for left, right in ((idx_a, idx_b), (idx_b, idx_c), (idx_c, idx_a))
    )
    exp = freq_ab * freq_bc * freq_ca
    obs = proximity_frequency_three(idx_a, idx_b, idx_c, proximity_matrix_list)

    # Per-cell largest pairwise distance within the observed triplet.
    observed_max = np.array(count_max_distance(distance_matrices, idx_a, idx_b, idx_c))
    obs_distance = np.nanmean(observed_max)

    # Background: random triplets sharing the observed bin offsets.
    random_index_list = generate_random_indices(
        np.arange(proximity_matrix_list.shape[1]), idx_c - idx_a, idx_a, idx_b, idx_c
    )
    random_max = [count_max_distance(distance_matrices, *sample[0]) for sample in random_index_list]
    # Average the random triplets per cell, widened to float64.
    expected_max = np.nanmean(random_max, axis=0).astype(float)

    # Two-sample t-test on the finite per-cell values only.
    t_stat, p_double_tailed = stats.ttest_ind(
        expected_max[np.isfinite(expected_max)],
        observed_max[np.isfinite(observed_max)],
        equal_var=True,
    )
    # Convert the two-sided p-value to the one-sided alternative mean(expected) > mean(observed).
    if t_stat > 0:
        p_single_tailed = p_double_tailed / 2
    else:
        p_single_tailed = 1 - p_double_tailed / 2
    return [exp, obs, obs / exp, obs_distance, np.nanmean(expected_max), p_single_tailed]


## load data

In [3]:
chromosome = 'chr2'
# Directory of the mESC 40 kb distance-matrix files. Set HIMULTI_MESC_DISTANCE_DIR to
# run on a machine without the /shareb mount; the default is the original location.
distance_dir = os.environ.get('HIMULTI_MESC_DISTANCE_DIR', '/shareb/mliu/HiMulti/data/mESC/distance_matrix')
h5_file = os.path.join(distance_dir, f'distance_mESC_40kb_{chromosome}.h5')
# Very large in memory: shape (229, 3738, 3738) for chr2 ("too full" in the original note).
distance_matrices = read_distance_matrices(h5_file, chromosome)
# Downstream code indexes this as [cell, bin, bin]; fail early if the file has another layout.
assert distance_matrices.ndim == 3 and distance_matrices.shape[1] == distance_matrices.shape[2], distance_matrices.shape

In [4]:
# Per-cell distance thresholds; presumably the 200 nm proximity cut-off per cell (TODO confirm).
# NOTE: pickle.load can execute arbitrary code - only load files you produced yourself.
path = "mesc_200nm_t_list.pkl"
with open(path,"rb") as f:
    t_list = pickle.load(f)
# proximity_matrices() pairs thresholds with distance matrices via zip, which would
# silently drop unmatched cells, so require exactly one threshold per matrix.
assert len(t_list) == len(distance_matrices), (len(t_list), len(distance_matrices))

In [5]:
# Binary per-cell contact maps: 1 where a distance is within that cell's threshold.
proximity_matrix_list = proximity_matrices(distance_matrices=distance_matrices, thresholds=t_list)

In [6]:
# Bin size and allowed linear-separation window, all in bp; build_contact_graph
# converts the window to bins by integer division with the resolution.
resolution = 40_000        # bin size
min_distance = 40_000      # minimum linear separation between contacting bins
max_distance = 4_000_000   # maximum linear separation between contacting bins
real_frequencies = calculate_multibody_frequencies(
    proximity_matrix_list,
    min_distance=min_distance,
    max_distance=max_distance,
    resolution=resolution,
)

In [7]:
# One row per candidate triplet: the sorted bin tuple and the fraction of cells containing it.
df = pd.DataFrame({
    'triplet': list(real_frequencies.keys()),
    'frequency': list(real_frequencies.values()),
})
# Split each triplet tuple into separate A / B / C bin columns.
df = df.join(pd.DataFrame(df['triplet'].tolist(), index=df.index, columns=['A', 'B', 'C']))
# Empty placeholders for the statistics that process_row fills in later.
df = df.assign(**dict.fromkeys(['exp', 'obs', 'O/E', 'obs_distance', 'exp_distance', 'distance_pvalue']))
# Checkpoint the triplet table before the long-running parallel step.
df.to_csv("mesc_tmp.csv", index=False)

In [None]:
# Silence RuntimeWarnings (e.g. nanmean over all-NaN slices) in this process.
# NOTE(review): joblib's default process-based workers may not inherit this filter.
# process_row also reaches df / proximity_matrix_list / distance_matrices through
# globals rather than arguments; confirm how these large arrays reach the workers.
warnings.simplefilter('ignore', category=RuntimeWarning)
row_jobs = (delayed(process_row)(row) for row in range(len(df)))
results = Parallel(n_jobs=40)(row_jobs)



In [None]:
# results[k] belongs to row k: Parallel preserves the order of range(len(df)).
# Build a float64 frame first. The placeholder columns were created with None
# (object dtype), and assigning into them in place would keep the statistics as
# objects for the FDR step and the CSV.
result_columns = ['exp', 'obs', 'O/E', 'obs_distance', 'exp_distance', 'distance_pvalue']
df[result_columns] = pd.DataFrame(results, index=df.index, columns=result_columns, dtype=float)


In [None]:
# Collect every non-missing distance p-value. Coerce to float first, because the
# column may still be object dtype; NaN p-values (t-test not computable) are dropped.
p_values = pd.to_numeric(df['distance_pvalue'], errors='coerce').dropna()

# Benjamini-Hochberg FDR correction: alpha is the significance level and
# method='fdr_bh' selects Benjamini-Hochberg. Rows without a p-value stay NaN.
df['fdr_corrected_pvalue'] = np.nan
if not p_values.empty:
    fdr_results = multipletests(p_values.to_numpy(), alpha=0.05, method='fdr_bh')
    adjusted_pvalues = fdr_results[1]  # corrected p-values, in input order
    # Write back by index label so each corrected value lands on its own row.
    df.loc[p_values.index, 'fdr_corrected_pvalue'] = adjusted_pvalues

In [None]:
# Name the output after the analysed chromosome (chr2 gives "mesc_chr2_40kb_bin_200nm.csv"),
# so a run with a different `chromosome` does not overwrite or mislabel the chr2 results.
df.to_csv(f"mesc_{chromosome}_40kb_bin_200nm.csv", index=False)