In [None]:
import pandas as pd
import numpy as np
import anndata as ad
import scanpy as sc
import scipy as sci
import squidpy as sq
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Notebook-wide plotting and logging configuration.
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus']=False
plt.rc('font', family='Helvetica')
# 42 selects TrueType (Type 42) font embedding for PDF output.
plt.rcParams['pdf.fonttype'] = 42
sc.settings.verbosity = 3             # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_header()
# Default scanpy figure style; a later cell calls set_figure_params again with a taller figsize.
sc.set_figure_params(dpi=120,facecolor='w',frameon=True,figsize=(4,4)) 
%config InlineBackend.figure_format='retina'
%matplotlib inline

# Integration

In [None]:
# Load the AD mouse dataset and keep only observations whose `label` is one of
# the four control samples (8- and 13-month, two replicates each).
adata=sc.read_h5ad('/data2/liyuzhe/data/SPACEX/adata_ADmouse.h5ad')
adata=adata[adata.obs.label.isin(['8months-control-replicate_1', '8months-control-replicate_2', 
                                  '13months-control-replicate_1', '13months-control-replicate_2'])]
# Bare expression: display the AnnData summary of the subset.
adata

In [None]:
from function import TISCOPE_integration

In [None]:
# Run the project-local TISCOPE_integration helper on the control subset and rebind
# `adata` to its return value.
# NOTE(review): GPU=2 presumably selects a device index and outdir='AD' the output
# directory — confirm against function.py.
adata=TISCOPE_integration(adata,GPU=2,outdir='AD')

In [None]:
# UMAP of the integrated data, one panel coloured by `batch` and one by `tissue`.
sc.pl.umap(adata,color=['batch','tissue'])

In [None]:
# Louvain clustering on the neighbour graph registered under the 'SPACE' key; the
# resulting adata.obs['louvain'] is compared with the tissue annotation below.
# NOTE(review): the 'SPACE' graph is presumably written by TISCOPE_integration — confirm.
sc.tl.louvain(adata,resolution=0.02,neighbors_key='SPACE')

In [None]:
from sklearn.metrics.cluster import adjusted_rand_score as ARI
from sklearn.metrics.cluster import normalized_mutual_info_score as NMI
from sklearn.metrics import homogeneity_score

In [None]:
# Adjusted Rand index between the annotated tissue labels and the Louvain clusters.
ARI(adata.obs['tissue'].values,adata.obs['louvain'].values)

In [None]:
# Normalised mutual information between the annotated tissue labels and the Louvain clusters.
NMI(adata.obs['tissue'].values,adata.obs['louvain'].values)

# Projection

In [None]:
from function import TISCOPE_projection

In [None]:
# Reference for the projection.
# NOTE(review): 'AD/adata.h5ad' is presumably the file written by the integration
# step (outdir='AD') — confirm.
adata_ref=sc.read('AD/adata.h5ad')
# Query: reload the full dataset and keep only the four disease samples
# (8- and 13-month, two replicates each). This rebinds `adata`.
adata=sc.read('/data2/liyuzhe/data/SPACEX/adata_ADmouse.h5ad')
adata=adata[adata.obs.label.isin(['8months-disease-replicate_1', '8months-disease-replicate_2',
                                  '13months-disease-replicate_1', '13months-disease-replicate_2'])]
# Bare expression: display the AnnData summary of the query subset.
adata

In [None]:
# Project the disease samples onto the reference with the project-local helper and
# rebind `adata` to its return value.
# NOTE(review): model_path='./AD' presumably points at the model saved by the
# integration step — confirm against function.py.
adata=TISCOPE_projection(adata,adata_ref,outdir='./AD',model_path='./AD')

In [None]:
# Slightly taller figures than the notebook default, then a UMAP with stacked
# panels (ncols=1) coloured by `tissue` and by `projection`.
sc.set_figure_params(dpi=120,facecolor='w',frameon=True,figsize=(4,4.5)) 
sc.pl.umap(adata,color=['tissue','projection'],ncols=1,legend_fontsize=10)

# Label transfer

In [None]:
# Reload the projection result from disk; the cells below split it by the
# `projection` column into reference and query.
# NOTE(review): the file is presumably written by TISCOPE_projection(outdir='./AD') — confirm.
adata=sc.read('./AD/adata_projection.h5ad')
adata

In [None]:
# Observations flagged as 'query' in the `projection` column.
adata_query=adata[adata.obs.projection=='query']
adata_query

In [None]:
# Observations flagged as 'reference' in the `projection` column (rebinds `adata_ref`).
adata_ref=adata[adata.obs.projection=='reference']
adata_ref

In [None]:
from sklearn.neighbors import KNeighborsClassifier

# Transfer tissue labels from the reference to the query with a distance-weighted
# kNN classifier trained on the 'latent' embedding.
N_NEIGHBORS = 20

# Boolean indexing of an AnnData yields a view; materialise it once so that the
# writes to .obs below act on a real object instead of forcing an implicit copy.
if adata_query.is_view:
    adata_query = adata_query.copy()

# Fit kNN classifier on the reference. The labels are passed as a 1-D array: a
# one-column DataFrame makes scikit-learn emit a DataConversionWarning.
knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS, weights='distance')
knn.fit(X=adata_ref.obsm['latent'], y=adata_ref.obs['tissue'].to_numpy())

# Predict labels with the knn classifier
proba = knn.predict_proba(adata_query.obsm['latent'])
k_dist, k_indx = knn.kneighbors(adata_query.obsm['latent'], n_neighbors=N_NEIGHBORS, return_distance=True)

# One row per query observation:
#   tissue_transfer - class with the highest predicted probability
#   probability     - that highest probability
#   mean_dist       - mean distance to the N_NEIGHBORS nearest reference points
#   k_dist          - distance to the farthest of those neighbours
predictions = pd.DataFrame({'tissue_transfer': knn.classes_[np.argmax(proba, axis=1)],
                            'probability': np.max(proba, axis=1),
                            'mean_dist': np.mean(k_dist, axis=1),
                            'k_dist': k_dist[:, -1]},
                           index=adata_query.obs.index)

# Assign column by column (positionally) rather than concatenating, so re-running
# this cell overwrites the prediction columns instead of duplicating them.
for col in predictions.columns:
    adata_query.obs[col] = predictions[col].to_numpy()

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import itertools

In [None]:
# Per-class precision / recall / F1 of the transferred labels against the annotated
# tissue of the query observations.
report = classification_report(adata_query.obs['tissue'], adata_query.obs['tissue_transfer'])
print(report)

In [None]:
# Adjusted Rand index between annotated and transferred tissue labels (query only).
ARI(adata_query.obs['tissue'], adata_query.obs['tissue_transfer'])

In [None]:
# Normalised mutual information between annotated and transferred tissue labels (query only).
NMI(adata_query.obs['tissue'], adata_query.obs['tissue_transfer'])

In [None]:
# Confusion matrix of annotated vs. transferred tissue labels, normalised per row
# so that every row (true label) sums to 1.
# `classes` is taken from the combined reference + query object, so it can contain
# tissues that have no true observation in the query.
classes=adata.obs.tissue.unique()
confusion = confusion_matrix(adata_query.obs['tissue'],adata_query.obs['tissue_transfer'],labels=classes)
row_sums = confusion.sum(axis=1)
# Rows without any true observation would be 0/0 (NaN plus a RuntimeWarning);
# leave them at 0 instead.
new_matrix = np.divide(confusion, row_sums[:, np.newaxis],
                       out=np.zeros(confusion.shape, dtype=float),
                       where=row_sums[:, np.newaxis] > 0)

In [None]:
# Heatmap of the row-normalised confusion matrix (rows: true label, columns:
# predicted label), annotated with the fraction in every cell above `thresh`.
plt.rcParams["axes.grid"] = False

# A single explicit figure. The former plt.clf() + plt.figure() pair created an
# extra empty figure whenever no figure was open yet.
fig, ax = plt.subplots(figsize=(6, 6))

# place labels at the top
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')

# plot the matrix per se
im = ax.imshow(new_matrix, interpolation='nearest', cmap=plt.cm.Blues, vmax=1)

# plot colorbar to the right
fig.colorbar(im, ax=ax)

fmt = '.2f'

# write the fraction in each cell; cells at or below thresh stay unlabelled
thresh = 0.2
for i, j in itertools.product(range(new_matrix.shape[0]), range(new_matrix.shape[1])):
    if new_matrix[i, j] > thresh:
        ax.text(j, i, format(new_matrix[i, j], fmt),
                horizontalalignment="center",
                fontsize=10,
                color="white")

tick_marks = np.arange(len(classes))
ax.set_xticks(tick_marks)
ax.set_xticklabels(classes, rotation=90, fontsize=12)
ax.set_yticks(tick_marks)
ax.set_yticklabels(classes, fontsize=12)
ax.set_ylabel('True label', size=10)
ax.set_xlabel('Predicted label', size=10)
# Run the layout only after the axis labels exist so they are included in it.
fig.tight_layout()
# fig.savefig('figures/AD_projection_confusion_matrix.pdf',bbox_inches='tight')
# fig.savefig('figures/AD_projection_confusion_matrix.png',bbox_inches='tight')
plt.show()

# Condition-associated TM

In [None]:
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

def calculate_tm_enrichment(adata, 
                           sample_type_col='Sample_type', 
                           batch_col='batch', 
                           cc_col='TM'):
    """
    Score how strongly each TM category is enriched in every sample type.

    For each sample (batch) the fraction of its observations falling into each TM
    is computed. For every sample type, the per-sample fractions of that type are
    compared with those of all other samples by a two-sided Mann-Whitney U test;
    U / (n1 * n2) is used as the AUC and ``AUC - 0.5`` as the enrichment score
    (range -0.5 .. 0.5, positive = enriched in that sample type).

    Parameters:
        adata: object exposing an ``obs`` DataFrame (e.g. AnnData).
        sample_type_col: str, obs column with the sample type / condition,
            default 'Sample_type'.
        batch_col: str, obs column identifying the sample, default 'batch'.
        cc_col: str, obs column with the TM category, default 'TM'.

    Returns:
        dict: sample type -> DataFrame with columns 'TM', 'score', 'AUC',
        'p_value' and 'p_adjusted' (Benjamini-Hochberg, corrected within that
        sample type), sorted by descending score.

    Raises:
        ValueError: if one sample is annotated with more than one sample type.
    """
    # Work on a copy with fixed internal column names to simplify the code below.
    obs_df = adata.obs[[sample_type_col, batch_col, cc_col]].copy()
    obs_df.columns = ['Sample_type', 'batch', 'cc']

    # Observations per (sample, TM). observed=False keeps TM categories that are
    # declared but unused, as before, and makes that choice explicit.
    count_df = obs_df.groupby(['batch', 'cc'], observed=False).size().unstack(fill_value=0)

    # Total observations per sample.
    sample_totals = count_df.sum(axis=1)

    # Categorical sample levels without any observation would give 0/0 = NaN
    # proportions and a NaN sample type; they carry no information, so drop them.
    has_cells = sample_totals > 0
    prop_df = count_df.loc[has_cells].div(sample_totals.loc[has_cells], axis=0)

    # Sample type of every sample, aligned with the rows of prop_df.
    sample_conditions = (
        obs_df[['batch', 'Sample_type']]
        .drop_duplicates()
        .set_index('batch')['Sample_type']
    )
    if sample_conditions.index.has_duplicates:
        raise ValueError(
            f"Each '{batch_col}' must map to exactly one '{sample_type_col}'"
        )
    sample_conditions = sample_conditions.reindex(prop_df.index)

    results_dict = {}
    conditions = sample_conditions.unique()
    tms = count_df.columns

    for c in conditions:
        in_condition = (sample_conditions == c).to_numpy()
        scores = []
        auc_values = []
        p_values = []
        tm_names = []

        for t in tms:
            # Per-sample proportion of this TM inside / outside the condition.
            prop_c = prop_df.loc[in_condition, t]
            prop_other = prop_df.loc[~in_condition, t]

            # Neutral result when one side has no samples.
            auc = 0.5
            p_value = 1.0
            if len(prop_c) > 0 and len(prop_other) > 0:
                try:
                    u_stat, p_value = mannwhitneyu(prop_c, prop_other, alternative='two-sided')
                    # U of the first sample divided by the number of pairs is the AUC.
                    auc = u_stat / (len(prop_c) * len(prop_other))
                except ValueError:
                    # Older SciPy raises ValueError when all values are identical;
                    # treat that as "no difference" and let any other error surface.
                    auc = 0.5
                    p_value = 1.0

            tm_names.append(t)
            scores.append(auc - 0.5)
            auc_values.append(auc)
            p_values.append(p_value)

        # Multiple-testing correction (FDR) across the TMs of this condition.
        rejected, p_adjusted, _, _ = multipletests(p_values, alpha=0.05, method='fdr_bh')

        result_df = pd.DataFrame({
            'TM': tm_names,
            'score': scores,
            'AUC': auc_values,
            'p_value': p_values,
            'p_adjusted': p_adjusted
        })

        # Highest enrichment first.
        results_dict[c] = result_df.sort_values('score', ascending=False).reset_index(drop=True)

    return results_dict

In [None]:
# Enrichment of each `region` per `group`, using `batch` as the sample identifier.
# `adata` here is the projection result loaded in the label-transfer section.
results = calculate_tm_enrichment(adata,sample_type_col='group', 
                           batch_col='batch', 
                           cc_col='region')

In [None]:
def visualize_tm_enrichment_bubble(results_dict, top_n=15, condition_order=None, tm_order=None, 
                                   sig_threshold=0.05):
    """
    Draw a bubble plot of TM enrichment scores and ring the significant ones in red.

    Conditions are placed on the x axis and TMs on the y axis. Bubble area grows
    with |score| and the fill colour encodes the signed score on a fixed
    -0.5 .. 0.5 scale.

    Parameters:
        results_dict: dict mapping condition -> DataFrame, as returned by
            calculate_tm_enrichment; every DataFrame must contain the columns
            'TM', 'score' and 'p_adjusted'.
        top_n: int, the top_n highest- and top_n lowest-scoring TMs of every
            condition decide which TMs are displayed.
        condition_order: list, display order of the conditions, e.g.
            ['Control', 'Disease']; entries absent from the results are skipped.
            Defaults to sorted order.
        tm_order: list, optional display order of the TMs.
        sig_threshold: float, a point is significant when p_adjusted is below
            this value. Default 0.05.

    Raises:
        ValueError: if results_dict is empty or a DataFrame lacks a required column.
    """
    # Validate before any column is touched.
    required_columns = {'score', 'p_adjusted', 'TM'}
    if not results_dict:
        raise ValueError("results_dict is empty; nothing to plot")
    for df in results_dict.values():
        if not required_columns.issubset(df.columns):
            raise ValueError(f"Missing required columns: {required_columns - set(df.columns)}")

    # Stack all per-condition tables into one long table.
    combined_df = pd.concat(
        [df.assign(condition=condition) for condition, df in results_dict.items()],
        ignore_index=True,
    )

    # Top and bottom top_n TMs of every condition. With few TMs the two slices
    # overlap, so de-duplicate; otherwise every bubble would be drawn twice.
    top_bottom_dfs = []
    for condition in combined_df['condition'].unique():
        cond_df = combined_df[combined_df['condition'] == condition]
        cond_df = cond_df.sort_values('score', ascending=False)
        top_bottom_dfs.extend([cond_df.head(top_n), cond_df.tail(top_n)])
    selected_df = pd.concat(top_bottom_dfs).drop_duplicates(subset=['condition', 'TM'])

    # Conditions and TMs to display, in display order.
    if condition_order is not None:
        available = set(selected_df['condition'].unique())
        display_conditions = [c for c in condition_order if c in available]
    else:
        display_conditions = sorted(selected_df['condition'].unique())

    if tm_order is not None:
        available_tms = set(selected_df['TM'].unique())
        display_tms = [tm for tm in tm_order if tm in available_tms]
    else:
        display_tms = list(selected_df['TM'].unique())

    # Full condition x TM grid. Values are looked up in the complete table, so a
    # TM selected for one condition is also shown for the others.
    full_grid_df = pd.DataFrame(
        [{'condition': c, 'TM': tm} for c in display_conditions for tm in display_tms],
        columns=['condition', 'TM'],
    )
    lookup = combined_df[['condition', 'TM', 'score', 'p_adjusted']].drop_duplicates(
        subset=['condition', 'TM']
    )
    plot_df = full_grid_df.merge(lookup, on=['condition', 'TM'], how='left')

    # Missing p_adjusted (NaN) compares as not significant.
    plot_df['significant'] = plot_df['p_adjusted'] < sig_threshold

    # Map categories to integer axis positions.
    cond_to_num = {cond: i for i, cond in enumerate(display_conditions)}
    tm_to_num = {tm: i for i, tm in enumerate(display_tms)}
    plot_df['x_pos'] = plot_df['condition'].map(cond_to_num)
    plot_df['y_pos'] = plot_df['TM'].map(tm_to_num)
    plot_df['size'] = np.abs(plot_df['score']) * 1000 + 50

    plt.figure(figsize=(10, 8))

    # Main bubbles.
    scatter = plt.scatter(
        x=plot_df['x_pos'],
        y=plot_df['y_pos'],
        s=plot_df['size'],
        c=plot_df['score'],
        cmap='RdBu_r',
        alpha=0.7,
        edgecolors='black',
        linewidth=0.5,
        vmin=-0.5,
        vmax=0.5
    )

    # Red rings, same size as the bubbles, around the significant points.
    sig_points = plot_df[plot_df['significant']]
    if not sig_points.empty:
        plt.scatter(
            x=sig_points['x_pos'],
            y=sig_points['y_pos'],
            s=sig_points['size'],
            facecolors='none',
            edgecolors='red',
            linewidth=2,
            label=' Significant (p < {})'.format(sig_threshold)
        )

    cbar = plt.colorbar(scatter)
    cbar.set_label('Enrichment Score', fontsize=12)

    plt.xticks(ticks=list(cond_to_num.values()), labels=list(cond_to_num.keys()), rotation=45, ha='right')
    plt.yticks(ticks=list(tm_to_num.values()), labels=list(tm_to_num.keys()))
    plt.xlabel('Condition', fontsize=12, fontweight='bold')
    plt.ylabel('TM', fontsize=12, fontweight='bold')
    plt.xlim(-0.5, len(display_conditions) - 0.5)

    # The legend only has an entry when significant points were drawn; calling it
    # without labelled artists just produces a warning.
    if not sig_points.empty:
        plt.legend(loc=(1.2,0.8), frameon=False)

    plt.tight_layout()
    plt.grid(False)
    # plt.savefig('figures/dotplot_colitis_TM_enrich_score.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# Bubble plot of the enrichment results, control column first.
# top_n=100 is large enough to keep every TM when there are at most 100 of them.
visualize_tm_enrichment_bubble(results, top_n=100,sig_threshold=0.05,
                               condition_order=['control', 'disease'])

In [None]:
from sklearn.neighbors import NearestNeighbors
from scipy import sparse
import warnings

def overcorrection_score(emb, celltype, n_neighbors=100, n_samples=None, 
                                 min_cells_per_type=5, random_state=42, weighted=False):
    """
    Overcorrection score: how strongly an embedding mixes different cell types.

    For every cell the fraction of its nearest neighbours that share its cell type
    is computed. These fractions are averaged per cell type, and the score is
    ``1 - mean(per-type averages)``.

    Parameters:
    -----------
    emb : array-like, shape (n_cells, n_features)
        Embedding coordinates (e.g., UMAP, PCA)
    celltype : array-like, shape (n_cells,)
        Cell type labels for each cell
    n_neighbors : int, optional (default=100)
        Number of neighbors to consider for each cell (the cell itself is never
        counted as its own neighbor)
    n_samples : int, optional (default=None)
        Approximate total number of cells to score; each cell type contributes a
        random subset proportional to its size (at least one cell). None scores
        all cells.
    min_cells_per_type : int, optional (default=5)
        Minimum number of cells required for a cell type to be included in scoring
    random_state : int, optional (default=42)
        Random seed for the subsampling
    weighted : bool, optional (default=False)
        Whether to weight the per-type averages by cell type prevalence

    Returns:
    --------
    score : float
        Overcorrection score (higher values indicate more overcorrection)

    Raises:
    -------
    ValueError
        If emb and celltype differ in length, or fewer than two cell types reach
        min_cells_per_type.
    """
    emb = np.asarray(emb)
    celltype = np.asarray(celltype)

    # Input validation
    if len(emb) != len(celltype):
        raise ValueError("emb and celltype must have the same length")

    if n_neighbors >= len(emb):
        warnings.warn(f"n_neighbors ({n_neighbors}) is too large for dataset size ({len(emb)}). "
                      f"Reducing to {len(emb) - 1}")
        n_neighbors = len(emb) - 1

    # Keep only cell types with enough cells.
    unique_types, type_counts = np.unique(celltype, return_counts=True)
    keep_type = type_counts >= min_cells_per_type
    valid_types = unique_types[keep_type]
    valid_counts = type_counts[keep_type]
    if len(valid_types) < 2:
        raise ValueError(f"Need at least 2 cell types with ≥ {min_cells_per_type} cells each")

    valid_mask = np.isin(celltype, valid_types)
    if not np.all(valid_mask):
        warnings.warn(f"Ignoring {np.sum(~valid_mask)} cells from rare cell types "
                      f"(< {min_cells_per_type} cells)")

    emb_valid = emb[valid_mask]
    celltype_valid = celltype[valid_mask]

    # Neighbours of the fitted points themselves. Calling kneighbors() without a
    # query makes scikit-learn exclude each point from its own neighbour list, so
    # no "subtract the identity" step is needed. That step miscounts when
    # duplicate coordinates push a point out of its own k+1 list.
    k = min(n_neighbors, len(emb_valid) - 1)
    nne = NearestNeighbors(n_neighbors=k, n_jobs=-1)  # Use all available cores
    nne.fit(emb_valid)
    neighbor_idx = nne.kneighbors(return_distance=False)

    # Fraction of same-type neighbours per cell, computed on integer type codes.
    type_codes = np.unique(celltype_valid, return_inverse=True)[1]
    same_type_prop = (type_codes[neighbor_idx] == type_codes[:, np.newaxis]).mean(axis=1)

    # Average per cell type, optionally on a proportional random subset.
    rng = np.random.RandomState(random_state)
    celltype_scores = {}
    for ct in valid_types:
        ct_indices = np.flatnonzero(celltype_valid == ct)
        if n_samples is not None:
            n_sample_ct = max(1, int(n_samples * len(ct_indices) / len(emb_valid)))
            if n_sample_ct < len(ct_indices):
                ct_indices = rng.choice(ct_indices, size=n_sample_ct, replace=False)
        celltype_scores[ct] = same_type_prop[ct_indices].mean() if len(ct_indices) else 0

    per_type = list(celltype_scores.values())
    if weighted:
        # Weight by cell type prevalence
        overall_score = 1 - np.average(per_type, weights=valid_counts)
    else:
        # Simple average across cell types
        overall_score = 1 - np.mean(per_type)

    return overall_score

In [None]:
# Overcorrection score of the UMAP embedding with respect to the tissue labels
# (defaults: 100 neighbours, all cells, unweighted).
overcorrection_score(adata.obsm['X_umap'], adata.obs['tissue'])