In [None]:
from tempo_sc import TEMPO_Selector
import anndata as ad
import numpy as np

# Load the bulk dataset used for time-point selection.
adata = ad.read_h5ad("bulk_data_lungalveoli_TPS.h5ad")

# Selector settings for the TEMPO beam search.
tempo_params = {
    "seed": 1234,
    "max_ref": 15,
    "beam_width": 32,
    "precompute_gcn": False,
}
tempo = TEMPO_Selector(**tempo_params)

# Beam search over time points (called with normalize_data=False).
S_sel, pack_sel, hist = tempo.fit(adata, normalize_data=False, verbose=True)
print("Selected time points:", S_sel)

# Reconstruct the full prediction matrix from the selected time points.
P = tempo.predict_full_from_pack(S_sel, pack_sel)
print("Prediction matrix shape:", P.shape)  # (G, T)


[Step 1] S=[6] added_t=6 MAE_train=0.074042 MSE_train=0.083836 R2_train=0.754327 | MAE_val=0.070176 MSE_val=0.082016 R2_val=0.705493 time=42.418s
[Step 2] S=[6, 0] added_t=0 MAE_train=0.046651 MSE_train=0.033271 R2_train=0.902492 | MAE_val=0.046739 MSE_val=0.034628 R2_val=0.871500 time=487.578s


In [None]:
import pdb,sys,os
import warnings
warnings.filterwarnings('ignore')
import anndata
import scanpy as sc
sc.settings.verbosity = 0
import argparse
import copy
import numpy as np
import scipy
import timeit

from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
from typing import Tuple
import scSemiProfiler as semi
from scSemiProfiler.utils import *
# Project configuration shared by the cells below.
name = 'single_cell_inference_project_lung_Alveolus_high'  # project / output folder
bulk = 'bulk_data_lung_Alveolus.h5ad'                       # bulk data (h5ad)
logged, normed = False, False                               # preprocessing state of `bulk`
geneselection = False
batch = 3

# Single-cell data for each time point t0 ... t6, read in order.
t0, t1, t2, t3, t4, t5, t6 = (
    anndata.read_h5ad(f"{name}/sample_sc/t{idx}.h5ad") for idx in range(7)
)

In [None]:
import pdb, sys, os
import anndata
import scanpy as sc
import argparse
import copy
import numpy as np
from sklearn.metrics import pairwise_distances
from typing import Union
import matplotlib.pyplot as plt

def initsetup(name: str, bulk: str, logged: bool = False, normed: bool = True,
              geneselection: Union[bool, int] = True, representatives: list = None) -> None:
    """
    Initial setup of the semi-profiling pipeline: preprocess the bulk data and
    assign each sample to the nearest fixed representative.

    The project folder ``name`` is created (the function aborts if it already
    exists). The bulk data is library-size normalized and/or log1p-transformed
    as needed. Highly variable genes are optionally selected, and PCA is run.
    Every sample is then assigned to the representative closest to it in PCA
    space.

    Files written under ``name``: ``figures/``, ``sids.txt``, ``hvgenes.npy``,
    ``processed_bulkdata.h5ad``. When representatives are given, it also writes
    ``status/init_cluster_labels.txt`` and ``status/init_representatives.txt``.

    Parameters
    ----------
    name : str
        Project name (also the output folder).
    bulk : str
        Path to bulk data as an h5ad file. Must contain ``obs['sample_ids']``.
    logged : bool
        Whether the data has been logged or not.
    normed : bool
        Whether the library size has been normalized or not.
    geneselection : bool or int
        ``False`` keeps all genes. ``True`` keeps the top 6000 highly variable
        genes. An int keeps that many highly variable genes.
    representatives : list
        Indices of fixed representative samples. Cluster label ``i`` in the
        output means "closest to ``representatives[i]``".

    Returns
    -------
    None

    Example
    -------
    >>> name = 'runexample'
    >>> bulk = 'example_data/bulkdata.h5ad'
    >>> logged = False
    >>> normed = True
    >>> geneselection = False
    >>> representatives = [0, 2, 5]  # Fixed representative indices
    >>> initsetup(name, bulk, logged, normed, geneselection, representatives)
    """

    print('Start initial setup')

    if os.path.isdir(name):
        print(name + ' exists. Please choose another name.')
        return
    # Creates both `name` and `name/figures`. Unlike os.system('mkdir ...')
    # this needs no shell and raises on failure instead of continuing silently.
    os.makedirs(name + '/figures')

    bulkdata = anndata.read_h5ad(bulk)

    if not normed:
        if logged:
            print('Bad data preprocessing. Normalize library size before log-transformation.')
            return
        sc.pp.normalize_total(bulkdata, target_sum=1e4)

    if not logged:
        sc.pp.log1p(bulkdata)

    sids = list(bulkdata.obs['sample_ids'])
    with open(name + '/sids.txt', 'w') as f:
        for sid in sids:
            # str() so non-string sample ids (e.g. ints) can be written as well.
            f.write(str(sid) + '\n')

    if geneselection is not False:
        n_top_genes = 6000 if geneselection is True else int(geneselection)
        sc.pp.highly_variable_genes(bulkdata, n_top_genes=n_top_genes)
        # .copy() materializes the subset so that PCA below works on a real
        # AnnData object rather than a view of the unfiltered one.
        bulkdata = bulkdata[:, bulkdata.var.highly_variable].copy()
    # After subsetting, every remaining gene is a selected gene.
    hvgenes = np.array(bulkdata.var.index)
    np.save(name + '/hvgenes.npy', hvgenes)

    # Keep n_comps below both the number of samples and the number of genes;
    # cap it at 100 components.
    n_comps = min(100, bulkdata.n_obs - 1, bulkdata.n_vars - 1)
    sc.tl.pca(bulkdata, n_comps=n_comps)

    bulkdata.write(name + '/processed_bulkdata.h5ad')

    if representatives is None or len(representatives) == 0:
        print("Please provide fixed representative indices.")
        return

    # Distance of every sample to every representative in PCA space. Each
    # sample's label is the position (in `representatives`) of its nearest one.
    pca_coords = bulkdata.obsm['X_pca']
    distances = pairwise_distances(pca_coords, pca_coords[representatives])
    cluster_labels = np.argmin(distances, axis=1)

    os.makedirs(name + '/status', exist_ok=True)

    with open(name + '/status/init_cluster_labels.txt', 'w') as f:
        f.writelines(str(label) + '\n' for label in cluster_labels)

    with open(name + '/status/init_representatives.txt', 'w') as f:
        f.writelines(str(rep) + '\n' for rep in representatives)

    print('Initial setup finished. Among ' + str(len(sids)) +
          ' total samples, assigned to fixed representatives:')
    for i, rep in enumerate(representatives):
        print(f"Cluster {i} representative: {sids[rep]}")

    return
     

In [None]:
# Preprocess the bulk data and assign every sample to its nearest fixed
# representative (samples 0, 3 and 6).
# NOTE(review): gene selection is forced on here, overriding the
# `geneselection = False` config value -- confirm this is intended.
initsetup(
    name,
    bulk,
    logged=logged,
    normed=normed,
    geneselection=True,
    representatives=[0, 3, 6],
)

In [None]:
import anndata as ad
import hdf5plugin

import numpy as np

# Pool the single-cell data of the representative samples (t0, t3, t6),
# keeping only the genes present in all of them.
reps_processed = ad.concat([t0, t3, t6], axis=0, join='inner')

print(f"Number of observations (cells): {reps_processed.n_obs}")
print(f"Number of variables (genes): {reps_processed.n_vars}")

# Fill in annotation columns if they are absent.
if 'cell_id' not in reps_processed.obs.columns:
    reps_processed.obs['cell_id'] = reps_processed.obs_names

if 'n_genes' not in reps_processed.obs.columns:
    # For a sparse X, `.sum(axis=1)` returns an (n_obs, 1) np.matrix; flatten
    # it to a 1-D vector so it is stored as a plain obs column.
    reps_processed.obs['n_genes'] = np.asarray((reps_processed.X > 0).sum(axis=1)).ravel()

if 'gene_ids' not in reps_processed.var.columns:
    reps_processed.var['gene_ids'] = reps_processed.var_names

# Normalize column names and object-dtype columns to str (presumably so that
# writing the h5ad file later does not choke on mixed-type columns).
reps_processed.obs.columns = reps_processed.obs.columns.astype(str)
reps_processed.var.columns = reps_processed.var.columns.astype(str)

for col in reps_processed.obs.columns:
    if reps_processed.obs[col].dtype == 'object':
        reps_processed.obs[col] = reps_processed.obs[col].astype(str)

for col in reps_processed.var.columns:
    if reps_processed.var[col].dtype == 'object':
        reps_processed.var[col] = reps_processed.var[col].astype(str)

print("Data types in obs:")
print(reps_processed.obs.dtypes)
print("Data types in var:")
print(reps_processed.var.dtypes)

# Genes selected by initsetup (saved as an object array, hence allow_pickle).
hvgenes = np.load(name + '/hvgenes.npy', allow_pickle=True)

print("First few genes in hvgenes:", hvgenes[:5])

reps_genes = reps_processed.var_names

common_genes = np.intersect1d(hvgenes, reps_genes)

print(f"Number of genes in hvgenes: {len(hvgenes)}")
print(f"Number of genes in reps_processed: {len(reps_genes)}")
print(f"Number of common genes: {len(common_genes)}")

missing_in_reps = np.setdiff1d(hvgenes, reps_genes)
print(f"Number of genes in hvgenes not in reps_processed: {len(missing_in_reps)}")

# Keep the hvgenes order (intersect1d would return the genes sorted).
hvgenes_in_reps_ordered = [gene for gene in hvgenes if gene in reps_genes]

reps_filtered = reps_processed[:, hvgenes_in_reps_ordered].copy()

assert list(reps_filtered.var_names) == hvgenes_in_reps_ordered, "Gene order does not match!"

In [None]:
# Save the gene-filtered representative data with the zstd filter from
# hdf5plugin. Reading it back presumably needs hdf5plugin imported too.
zstd_filter = hdf5plugin.FILTERS["zstd"]
reps_filtered.write_h5ad(name + '/representative_sc.h5ad', compression=zstd_filter)

In [None]:
# Preprocess the representative single-cell data written above.
semi.scprocess(
    name=name,
    singlecell=name + '/representative_sc.h5ad',
    normed=True,
    logged=False,
    cellfilter=False,
    threshold=1e-3,
    geneset=True,
    weight=0.5,
    k=15,
)

In [None]:
# Read back the sample ids, representative indices and cluster labels written
# by initsetup. `with` guarantees each file is closed even if parsing fails.
with open(name + '/sids.txt', 'r') as f:
    sids = [line.strip() for line in f]

with open(name + '/status/init_representatives.txt', 'r') as f:
    repres = [int(line.strip()) for line in f]

with open(name + '/status/init_cluster_labels.txt', 'r') as f:
    cl = [int(line.strip()) for line in f]

print('representatives:', repres)
print('cluster labels:', cl)

In [None]:
import torch

# Release cached, unused CUDA memory before running inference.
torch.cuda.empty_cache()

# Status files produced by initsetup.
status_dir = name + '/status'
representatives = status_dir + '/init_representatives.txt'
cluster = status_dir + '/init_cluster_labels.txt'
bulktype = 'pseudobulk'

semi.scinfer(name, representatives, cluster, bulktype, device='cuda:0')

In [None]:
# Assemble the semi-profiled cohort from the inferred data.
cluster_labels = cl  # module-level alias of the cluster labels
semisdata = assemble_cohort(
    name,
    repres,
    cl,
    celltype_key='celltype',
    sample_info_keys=['sample_ids'],
    bulkpath=bulk,  # same bulk file as passed to initsetup
)
     

In [None]:

# Read the combined ground-truth single-cell data used for comparison below.
combined_adata = anndata.read_h5ad(f"{name}/combined_data.h5ad")

In [None]:
#filter out NA celltypes
import pandas as pd

# Values treated as "no cell type". Non-string entries (None, pd.NA, NaN) are
# caught via isna(). The string entries catch labels that were already
# stringified upstream, including str(None) == 'None' and str(pd.NA) == '<NA>'.
invalid_values = [None, pd.NA, float('nan'), 'nan', 'NA', 'None', '<NA>']

def filter_invalid_celltypes(adata, celltype_key='celltype'):
    """
    Return a copy of ``adata`` without cells whose cell type is missing.

    A cell is dropped when ``adata.obs[celltype_key]`` is a missing value
    (None / NaN / pd.NA). It is also dropped when its stripped string form
    equals one of the string tokens in ``invalid_values``.

    Parameters
    ----------
    adata:
        AnnData-like object with ``.obs`` and boolean-mask indexing.
    celltype_key:
        Column of ``.obs`` holding the cell type labels.
    """
    celltypes = adata.obs[celltype_key]
    text_tokens = [v for v in invalid_values if isinstance(v, str)]
    # astype(str) turns None/pd.NA into 'None'/'<NA>', so real missing values
    # must be detected with isna() rather than by membership in the list.
    invalid = celltypes.isna() | celltypes.astype(str).str.strip().isin(text_tokens)
    return adata[~invalid].copy()

# Drop cells with a missing cell-type annotation from both datasets.
combined_adata = filter_invalid_celltypes(combined_adata)
semisdata = filter_invalid_celltypes(semisdata)

print("Filtered combined_adata cells: {}".format(combined_adata.n_obs))
print("Filtered semisdata cells: {}".format(semisdata.n_obs))

# Save the filtered datasets (to the current working directory).
combined_adata.write_h5ad('combined_adata_filtered.h5ad')
semisdata.write_h5ad('semisdata_filtered.h5ad')

In [None]:

# Visualize the assembled ground-truth data next to the semi-profiled data;
# `save` points at the project's figures folder.
status_dir = name + '/status'
combined_data, gtdata, semidata = compare_umaps(
    semidata=semisdata,
    gtdata=combined_adata,
    name=name,
    representatives=status_dir + '/init_representatives.txt',
    cluster_labels=status_dir + '/init_cluster_labels.txt',
    celltype_key='celltype',
    save=name + "/figures",
)
     

In [None]:
def composition_by_group(
    adata: anndata.AnnData,
    colormap: Union[str, list] = None,
    groupby: str = None,
    title: str = 'Cell type composition',
    save: str = None,
    name: str = None
) -> None:
    """
    Visualize the cell type composition of each group as stacked horizontal bars.

    One row (axis) is drawn per unique value of ``adata.obs[groupby]``. Each
    row stacks the cell type proportions of that group, in the order of the
    categories of ``adata.obs['celltype']``.

    Parameters
    ----------
    adata:
        The dataset to investigate. ``adata.obs['celltype']`` must be categorical.
    colormap:
        One color per cell type category, in category order. Defaults to
        ``adata.uns['celltype_colors']``, falling back to
        ``adata.uns['celltypes_colors']``.
    groupby:
        The key in .obs specifying groups.
    title:
        Plot title (drawn above the first row).
    save:
        File stem for saving the plot as a PDF (``.pdf`` is appended).
    name:
        Folder to save the file in, if provided.

    Returns
    -------
        None

    Example
    -------
    >>> composition_by_group(
    >>>     adata=gtdata,
    >>>     groupby='sample_ids',
    >>>     title='Ground truth'
    >>> )
    """
    import matplotlib.patches as mpatches

    totaltypes = np.array(adata.obs['celltype'].cat.categories)

    if colormap is None:
        # scanpy stores category colors under '<key>_colors' ('celltype_colors'
        # here); the older 'celltypes_colors' key is kept as a fallback.
        if 'celltype_colors' in adata.uns:
            colormap = adata.uns['celltype_colors']
        else:
            colormap = adata.uns['celltypes_colors']

    conditions = np.unique(adata.obs[groupby])
    n = conditions.shape[0]
    percentages = [
        celltype_proportion(adata[adata.obs[groupby] == condition], totaltypes)
        for condition in conditions
    ]

    # squeeze=False keeps `axs` indexable even when there is only one group.
    fig, axs = plt.subplots(n, 1, figsize=(n, 1), squeeze=False)
    axs = axs[:, 0]
    axs[0].set_title(title)

    for j, ax in enumerate(axs):
        left = 0
        for i in range(len(totaltypes)):
            ax.barh(conditions[j], percentages[j][i], left=left, color=colormap[i])
            left += percentages[j][i]
        ax.set_xlim([0, 1])
        ax.set_yticklabels([])
        ax.yaxis.set_tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
        # Only the bottom row keeps its x tick labels (the previous `j != n`
        # was always true, hiding them on every row).
        if j != n - 1:
            ax.set_xticklabels([])
        ax.text(-0.01, 0, conditions[j], ha='right', va='center')

    patches = [mpatches.Patch(color=colormap[i], label=totaltypes[i]) for i in range(len(totaltypes))]
    axs[-1].legend(handles=patches, loc='center left', bbox_to_anchor=(1.1, n))

    axs[-1].set_xlabel('Proportion')

    if save is not None:
        save_path = f"{name}/{save}.pdf" if name else f"{save}.pdf"
        fig.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')

    plt.show()


In [None]:
# Cell type composition per sample (time point) in the ground-truth data,
# colored with the palette of the semi-profiled data from compare_umaps.
# NOTE(review): assumes both datasets share the same 'celltype' category order.
groupby = 'sample_ids'
composition_by_group(
    adata=combined_adata,
    groupby=groupby,
    title='Ground truth',
    colormap=semidata.uns['celltype_colors'],
    save="/composition_gt",
    name=name,
)

In [None]:
# Compare enrichment results for AT1 between the real-profiled and the
# semi-profiled data (library helper from scSemiProfiler).
enrichment_comparison(
    name,
    combined_adata,
    semisdata,
    celltype_key='celltype',
    selectedtype="AT1",
    save="figures",
)

In [None]:
def _top_degs(adata, celltype_key, celltype, n_degs):
    """Run a t-test rank_genes_groups on `adata` and return the top `n_degs` gene names for `celltype`."""
    sc.tl.rank_genes_groups(adata, celltype_key, method='t-test')
    # Look the group up by name: positional indexing would silently pick the
    # wrong column whenever group order differs between datasets.
    names = adata.uns['rank_genes_groups']['names']
    return list(names[str(celltype)][:n_degs])


def _read_enrichr_report(path: str) -> Tuple[list, list, dict]:
    """
    Parse a gseapy Enrichr report (tab-separated, first line is a header).

    Returns the terms (column 1) and the p-values (column 4; presumably the
    adjusted p-value in gseapy's report layout -- verify), both in file order.
    Also returns a term -> p-value dict.
    """
    with open(path, 'r') as f:
        lines = f.readlines()
    terms, pvals, term_to_p = [], [], {}
    for line in lines[1:]:
        fields = line.split('\t')
        term = fields[1]
        p = float(fields[4])
        terms.append(term)
        pvals.append(p)
        term_to_p[term] = p
    return terms, pvals, term_to_p


def enrichment_comparison_reactome(name:str,
                                   gtdata:anndata.AnnData,
                                   semisdata:anndata.AnnData,
                                   celltype_key:str,
                                   selectedtype:str,
                                   save = None,
                                   n_degs: int = 100,
                                   n_terms: int = 10,
                                  ) -> Tuple[int, float, float, float]:
    """
    Compare Reactome enrichment results between real-profiled and semi-profiled data.

    For ``selectedtype``, the top ``n_degs`` t-test DEGs are taken from each
    dataset and run through ``gseapy.enrichr`` with ``Reactome_2022``. The
    -log10(p) of the union of the top ``n_terms`` terms from each side is then
    plotted as mirrored bar charts. A term absent from one side gets p = 1
    there.

    Parameters
    ----------
    name:
        Project name
    gtdata:
        Real-profiled (ground truth) data (AnnData object)
    semisdata:
        Semi-profiled dataset (AnnData object)
    celltype_key:
        The key in anndata.AnnData.obs that stores cell type information
    selectedtype:
        The selected cell type to analyze
    save:
        Prefix (inside the project's 'figures' folder) for the saved plot files
    n_degs:
        Number of top-ranked DEGs taken from each dataset (default 100)
    n_terms:
        Number of top enrichment terms taken from each dataset (default 10)

    Returns
    -------
    CommonDEGs : int
        The number of overlapping DEGs between real and semi-profiled data
    HypergeometricP : float
        P-value of hypergeometric test examining the overlap between two versions of DEGs
    PearsonR : float
        Pearson correlation between bar lengths in real-profiled and semi-profiled bar plots
    PearsonP : float
        P-value of the Pearson correlation test

    Example
    -------
    >>> _ = enrichment_comparison_reactome(name, gtdata, semisdata, celltype_key='celltypes', selectedtype='CD4')
    """
    import scipy.stats

    # rank_genes_groups also stores its results in each object's .uns.
    gtdeg = _top_degs(gtdata, celltype_key, selectedtype, n_degs)
    semideg = _top_degs(semisdata, celltype_key, selectedtype, n_degs)
    gtdeg_set = set(gtdeg)
    c = sum(1 for g in semideg if g in gtdeg_set)

    hyperpval = hypert(semisdata.X.shape[1], len(gtdeg), len(semideg), c)
    print('p-value of hypergeometric test for overlapping DEGs:', str(float(hyperpval)))

    gt_outdir = name + '/gseapygt'
    semi_outdir = name + '/gseapysemi'
    os.makedirs(gt_outdir, exist_ok=True)
    os.makedirs(semi_outdir, exist_ok=True)

    gseapy.enrichr(gene_list=gtdeg, gene_sets='Reactome_2022', outdir=gt_outdir)
    gtsets, gtps, gtdic = _read_enrichr_report(gt_outdir + '/Reactome_2022.human.enrichr.reports.txt')

    gseapy.enrichr(gene_list=semideg, gene_sets='Reactome_2022', outdir=semi_outdir)
    semisets, semips, semidic = _read_enrichr_report(semi_outdir + '/Reactome_2022.human.enrichr.reports.txt')

    # Top real-profiled terms first, each paired with ITS OWN semi-profiled
    # p-value (1 if the term is absent from the semi-profiled results).
    n_gt = min(n_terms, len(gtsets))
    n_semi = min(n_terms, len(semisets))
    terms = list(gtsets[:n_gt])
    real_data = list(gtps[:n_gt])
    sim_data = [semidic.get(term, 1) for term in terms]

    # Then the top semi-profiled terms not listed yet, paired with their
    # real-profiled p-value (1 if absent).
    for term, p in zip(semisets[:n_semi], semips[:n_semi]):
        if term in terms:
            continue
        terms.append(term)
        sim_data.append(p)
        real_data.append(gtdic.get(term, 1))

    # Reverse so the first term ends up at the top of the horizontal bar chart.
    real_data = np.flip(real_data)
    sim_data = np.flip(sim_data)
    terms = np.flip(terms)
    sim_bar_lengths = [-np.log10(p) for p in sim_data]
    real_bar_lengths = [-np.log10(p) for p in real_data]

    res = scipy.stats.pearsonr(np.array(sim_bar_lengths), np.array(real_bar_lengths))
    print('Significance correlation:', res)

    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(8, 5))
    bar_width = 0.4
    y = np.arange(len(sim_data)) + 1
    ax1.barh(y, real_bar_lengths, height=bar_width, color='green', label='Real')
    ax1.set_xlabel('-log10(p)')
    ax1.set_ylabel('Term')
    ax1.set_title('Real Data (' + str(len(gtdeg)) + ' DEGs)')
    ax2.barh(y, sim_bar_lengths, height=bar_width, color='blue', label='Simulated')
    ax2.set_xlabel('-log10(p)')
    ax2.set_title('Semi-profiled Data(' + str(len(semideg)) + ' DEGs)')

    max_val = max(max(sim_bar_lengths), max(real_bar_lengths))
    ax1.set_xlim(0, max_val + 1)
    ax2.set_xlim(0, max_val + 1)
    ax1.invert_xaxis()
    ax1.set_yticks(y)
    ax2.set_yticklabels(terms)
    fig.suptitle(selectedtype + ' Reactome (' + str(c) + ' Overlap DEGs)')

    if save is not None:
        plt.savefig(name + '/figures/' + save + selectedtype + ' Reactome.pdf', bbox_inches='tight')
        plt.savefig(name + '/figures/' + save + selectedtype + ' Reactome.jpg', dpi=600, bbox_inches='tight')
    plt.show()

    return c, float(hyperpval), res[0], res[1]


In [None]:
# Compare Reactome enrichment of AT1 DEGs between the real-profiled and the
# semi-profiled data; the returned tuple is shown as the cell output.
enrichment_comparison_reactome(
    name,
    combined_adata,
    semisdata,
    celltype_key='celltype',
    selectedtype="AT1",
    save="figures",
)