In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import random
import os
import warnings
import sys
import argparse
import scipy.sparse as sp
import logging
from anndata import AnnData

from grn_inference import utils

logging.basicConfig(level=logging.INFO, format="%(message)s")

In [None]:

PROJECT_DIR = "/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER"
RAW_MESC_DATA_DIR = "/gpfs/Labs/Uzun/DATA/PROJECTS/2024.SC_MO_TRN_DB.MIRA/REPOSITORY/CURRENT/SINGLE_CELL_DATASETS/DS014_DOI496239_MOUSE_ESC_RAW_FILES"
MESC_PEAK_MATRIX_FILE = "/gpfs/Labs/Uzun/DATA/PROJECTS/2024.SC_MO_TRN_DB.MIRA/REPOSITORY/CURRENT/SINGLE_CELL_DATASETS/DS014_DOI496239_MOUSE_ESCDAYS7AND8/scATAC_PeakMatrix.txt"

MM10_GENOME_DIR = os.path.join(PROJECT_DIR, "data/reference_genome/mm10")
MM10_GENE_TSS_FILE = os.path.join(PROJECT_DIR, "data/genome_annotation/mm10/mm10_TSS.bed")
GROUND_TRUTH_DIR = os.path.join(PROJECT_DIR, "ground_truth_files")
SAMPLE_INPUT_DIR = os.path.join(PROJECT_DIR, "input/mESC/")
OUTPUT_DIR = os.path.join(PROJECT_DIR, "output/transformer_testing_output")


def get_adata_from_peakmatrix(peak_matrix_file: str, label: pd.DataFrame) -> AnnData:
    # Read header only
    all_cols = pd.read_csv(peak_matrix_file, sep="\t", nrows=10).columns
    
    print("First few ATAC barcodes:", all_cols[:10].tolist())
    print("Overlap count after normalization:", 
        len(set(label["barcode_use"]) & set(all_cols)))

    # Identify barcodes shared between RNA and ATAC
    matching_barcodes = set(label["barcode_use"]) & set(all_cols)

    # Map from original index -> normalized barcode
    col_map = {i: bc for i, bc in enumerate(all_cols)}

    # Always keep the first column (peak IDs)
    keep_indices = [0] + [i for i, bc in col_map.items() if bc in matching_barcodes]

    # Read only those columns
    peak_matrix = pd.read_csv(
        peak_matrix_file,
        sep="\t",
        usecols=keep_indices,
        index_col=0
    )

    # Replace column names with normalized barcodes
    new_cols = [col_map[i] for i in keep_indices[1:]]
    peak_matrix.columns = new_cols

    # Construct AnnData
    X = sp.csr_matrix(peak_matrix.values)
    adata_ATAC = AnnData(X=X.T)

    # Assign metadata
    adata_ATAC.obs_names = new_cols
    adata_ATAC.obs["barcode"] = new_cols
    adata_ATAC.obs["sample"] = sample_name
    adata_ATAC.obs["label"] = label.set_index("barcode_use").loc[new_cols, "label"].values

    adata_ATAC.var_names = peak_matrix.index
    adata_ATAC.var["gene_ids"] = peak_matrix.index

    return adata_ATAC


sample_name = "E7.5_rep1"

# for sample_name in os.listdir(RAW_MESC_DATA_DIR):
sample_raw_data_dir = os.path.join(RAW_MESC_DATA_DIR, sample_name)
sample_processed_data_dir = os.path.join(RAW_MESC_DATA_DIR, sample_name)

adata_RNA = sc.read_10x_mtx(
    sample_raw_data_dir,
    var_names="gene_symbols",   # or "gene_ids"
    make_unique=True,
    prefix="GSM6205416_E7.5_rep1_GEX_"
)

adata_RNA.obs_names = [(sample_name + "." + i).replace("-", ".") for i in adata_RNA.obs_names]
print(f"Found {len(adata_RNA.obs_names)} cell barcodes")

print("First few RNA barcodes:", adata_RNA.obs_names[:10])

label = pd.DataFrame({"barcode_use":adata_RNA.obs_names, "label":["mESC"] * len(adata_RNA.obs_names)})

adata_ATAC = get_adata_from_peakmatrix(MESC_PEAK_MATRIX_FILE, label)

# Add barcode column
adata_RNA.obs['barcode'] = adata_RNA.obs_names

print("RNA barcodes example:", adata_RNA.obs['barcode'][:5].tolist())
print("ATAC barcodes example:", adata_ATAC.obs['barcode'][:5].tolist())
print("Overlap count:", len(set(adata_RNA.obs['barcode']) & set(adata_ATAC.obs['barcode'])))


common_barcodes = adata_RNA.obs['barcode'].isin(adata_ATAC.obs['barcode'])
adata_RNA = adata_RNA[common_barcodes].copy()
adata_ATAC = adata_ATAC[adata_ATAC.obs['barcode'].isin(adata_RNA.obs['barcode'])].copy()


# Add sample column (strip suffix "-1" if you want sample numbers, else default 1)
adata_RNA.obs['sample'] = sample_name

# Add label column from your label DataFrame
label_lookup = label.set_index("barcode_use").loc[adata_RNA.obs['barcode']]
adata_RNA.obs['label'] = label_lookup['label'].values

# QC fields like in the reference get_adata
adata_RNA.var['mt'] = adata_RNA.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(adata_RNA, qc_vars=["mt"], inplace=True)
adata_RNA = adata_RNA[adata_RNA.obs.pct_counts_mt < 5].copy()

# Ensure gene IDs are unique
adata_RNA.var.index = adata_RNA.var_names
adata_RNA.var_names_make_unique()
adata_RNA.var['gene_ids'] = adata_RNA.var.index

Only considering the two last: ['.mtx', '.gz'].
Only considering the two last: ['.mtx', '.gz'].


First few ATAC barcodes: ['peak_coord', 'E8.5_rep1.TTACGTTTCTGGCATG.1', 'E8.5_rep1.GAATTTGTCGGTCAAT.1', 'E8.5_rep1.GGAGTCTGTGTTTCAC.1', 'E8.5_rep1.AGTTATGTCTCACTCA.1', 'E8.5_rep1.GAGCTAGCAACTCGCG.1', 'E8.5_rep1.GGCTGAGAGCTTAACA.1', 'E8.5_rep1.GCTTTATTCTTGGATA.1', 'E8.5_rep1.GGCTGAGAGCTCAATA.1', 'E8.5_rep1.TTGACTAAGGTCGAGG.1']
Overlap count after normalization: 7416
RNA barcodes example: ['E7.5_rep1.AAACAGCCAAACCCTA.1', 'E7.5_rep1.AAACAGCCAAACTCAT.1', 'E7.5_rep1.AAACAGCCACAACCTA.1', 'E7.5_rep1.AAACAGCCAGGAACTG.1', 'E7.5_rep1.AAACAGCCATCCTGAA.1']
ATAC barcodes example: ['E7.5_rep1.TCTAGCCTCTCACTAT.1', 'E7.5_rep1.GGAAGTATCCGGGACT.1', 'E7.5_rep1.CTTCGCGTCATTCATC.1', 'E7.5_rep1.TAGTGGCGTTCATCTA.1', 'E7.5_rep1.CTACTAAAGTTCCCAC.1']
Overlap count: 7416


In [29]:
# --- Save aligned & QC-filtered AnnData objects ---
sample_data_dir = os.path.join(SAMPLE_INPUT_DIR, sample_name)
os.makedirs(sample_data_dir, exist_ok=True)

adata_RNA.write_h5ad(os.path.join(sample_data_dir, f"{sample_name}_RNA_qc.h5ad"))
adata_ATAC.write_h5ad(os.path.join(sample_data_dir, f"{sample_name}_ATAC_qc.h5ad"))

print(f"Saved RNA and ATAC AnnData objects for {sample_name} to {sample_data_dir}")


Saved RNA and ATAC AnnData objects for E7.5_rep1 to /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/input/mESC/E7.5_rep1


In [34]:
adata_RNA = sc.read_h5ad(os.path.join(sample_data_dir, "E7.5_rep1_RNA_qc.h5ad"))
adata_ATAC = sc.read_h5ad(os.path.join(sample_data_dir, "E7.5_rep1_ATAC_qc.h5ad"))

In [None]:


def tfidf(atac_matrix: np.ndarray) -> np.ndarray:
    """
    Performs a TF-IDF-like transformation on the ATAC-seq matrix to highlight important regulatory elements.

    Parameters:
        atac_matrix (np.ndarray):
            A matrix of ATAC-seq data, where rows are regulatory elements (peaks) and columns are cells. 
            Values represent the accessibility of the peaks in each cell.

    Returns:
        transformed_matrix (np.ndarray):
            Transformed matrix where rows represent regulatory elements (peaks) and columns represent cells,
            with values weighted by TF-IDF-like scores.
    """

    # Create a binary matrix indicating presence/absence of peaks
    binary_matrix: np.ndarray = 1 * (atac_matrix > 0)
    
    # Calculate term frequency (TF) normalized by log of total accessibility per cell
    term_freq: np.ndarray = binary_matrix / (np.ones((binary_matrix.shape[0], 1)) * np.log(1 + np.sum(binary_matrix, axis=0))[np.newaxis, :])
    
    # Calculate inverse document frequency (IDF) based on peak occurrence across cells
    inverse_doc_freq: np.ndarray = np.log(1 + binary_matrix.shape[1] / (1 + np.sum(binary_matrix > 0, axis=1)))
    
    # Compute the TF-IDF-like matrix
    tfidf_matrix: np.ndarray = term_freq * (inverse_doc_freq[:, np.newaxis] * np.ones((1, binary_matrix.shape[1])))
    
    # Replace any NaN values with 0 (due to division by zero)
    tfidf_matrix[np.isnan(tfidf_matrix)] = 0
    
    # Return the transposed matrix (cells as rows, peaks as columns)
    transformed_matrix: np.ndarray = tfidf_matrix.T
    return transformed_matrix

def find_neighbors(rna_data: AnnData, atac_data: AnnData) -> tuple[AnnData, AnnData]:
    """
    Combines RNA and ATAC-seq data in a joint PCA space and identifies neighbors based on combined features.

    Parameters:
        rna_data (AnnData):
            AnnData object containing RNA expression data.
        atac_data (AnnData):
            AnnData object containing ATAC-seq data.

    Returns:
        tuple (AnnData, AnnData):
            Updated `rna_data` and `atac_data` objects with combined PCA representation.
    """
    neighbors_k: int = 20  # Number of neighbors to find
    
    ### RNA Data Preprocessing ###
    # Normalize RNA expression data and log-transform
    sc.pp.normalize_total(rna_data, target_sum=1e4)
    sc.pp.log1p(rna_data)
    
    # Identify highly variable genes
    sc.pp.highly_variable_genes(rna_data, min_mean=0.0125, max_mean=3, min_disp=0.5)
    
    # Save raw data and subset highly variable genes
    rna_data.raw = rna_data
    rna_data = rna_data[:, rna_data.var.highly_variable]
    
    # Scale the data and perform PCA for dimensionality reduction
    sc.pp.scale(rna_data, max_value=10)
    sc.tl.pca(rna_data, n_comps=15, svd_solver="arpack")
    
    # Store the PCA results for RNA
    pca_rna: np.ndarray = rna_data.obsm['X_pca']
    
    ### ATAC Data Preprocessing ###
    # Log-transform ATAC-seq data
    sc.pp.log1p(atac_data)
    
    # Identify highly variable peaks
    sc.pp.highly_variable_genes(atac_data, min_mean=0.0125, max_mean=3, min_disp=0.5)
    
    # Save raw ATAC data and subset highly variable peaks
    atac_data.raw = atac_data
    atac_data = atac_data[:, atac_data.var.highly_variable]
    
    # Scale the ATAC data and perform PCA
    sc.pp.scale(atac_data, max_value=10, zero_center=True)
    sc.tl.pca(atac_data, n_comps=15, svd_solver="arpack")
    
    # Store the PCA results for ATAC
    pca_atac: np.ndarray = atac_data.obsm['X_pca']
    
    ### Combine RNA and ATAC PCA Results ###
    combined_pca: np.ndarray = np.concatenate((pca_rna, pca_atac), axis=1)
    
    # Store the combined PCA representation in both AnnData objects
    rna_data.obsm['pca'] = combined_pca
    atac_data.obsm['pca'] = combined_pca

    return rna_data, atac_data

def pseudo_bulk(rna_data: AnnData, atac_data: AnnData, single_pseudo_bulk: int) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generates pseudo-bulk RNA and ATAC profiles by aggregating cells with similar profiles based on neighbors.

    Parameters:
        rna_data (AnnData):
            AnnData object containing RNA expression data.
        atac_data (AnnData):
            AnnData object containing ATAC-seq data.
        single_pseudo_bulk (int):
            If set to a value greater than 0, limits each cluster to 1 sample for pseudo-bulk creation.

    Returns:
        tuple (pd.DataFrame, pd.DataFrame):
            pseudo_bulk_rna : pd.DataFrame
                Pseudo-bulk RNA expression matrix (genes x pseudo-bulk samples).
            pseudo_bulk_atac : pd.DataFrame
                Pseudo-bulk ATAC accessibility matrix (peaks x pseudo-bulk samples).
    """
    neighbors_k: int = 20  # Number of neighbors to use for aggregation

    ### RNA Data Preprocessing ###
    sc.pp.normalize_total(rna_data, target_sum=1e4)
    sc.pp.log1p(rna_data)
    sc.pp.filter_genes(rna_data, min_cells=3)
    sc.pp.highly_variable_genes(rna_data, min_mean=0.0125, max_mean=3, min_disp=0.5)
    rna_data.raw = rna_data
    rna_data = rna_data[:, rna_data.var.highly_variable]
    sc.pp.scale(rna_data, max_value=10)
    sc.tl.pca(rna_data, n_comps=15, svd_solver="arpack")
    
    ### ATAC Data Preprocessing ###
    sc.pp.log1p(atac_data)
    sc.pp.filter_genes(atac_data, min_cells=3)
    sc.pp.highly_variable_genes(atac_data, min_mean=0.0125, max_mean=3, min_disp=0.5)
    atac_data.raw = atac_data
    atac_data = atac_data[:, atac_data.var.highly_variable]
    sc.pp.scale(atac_data, max_value=10, zero_center=True)
    sc.tl.pca(atac_data, n_comps=15, svd_solver="arpack")
    
    ### Combine RNA and ATAC PCA Results ###
    pca_rna: np.ndarray = rna_data.obsm['X_pca']
    pca_atac: np.ndarray = atac_data.obsm['X_pca']
    combined_pca: np.ndarray = np.concatenate((pca_rna, pca_atac), axis=1)
    rna_data.obsm['pca'] = combined_pca
    atac_data.obsm['pca'] = combined_pca
    
    ### Neighbor Graph Construction ###
    sc.pp.neighbors(rna_data, n_neighbors=neighbors_k, n_pcs=30, use_rep='pca')
    connectivity_matrix: np.ndarray = (rna_data.obsp['distances'] > 0)
    
    ### Label Processing and Pseudo-bulk Generation ###
    cell_labels: pd.DataFrame = pd.DataFrame(rna_data.obs['label'])
    cell_labels.index = rna_data.obs_names.tolist()
    
    # Identify unique clusters of cells
    unique_clusters: list = list(set(cell_labels['label'].values))
    selected_indices: list = []
    
    np.random.seed(42)  # Set seed for reproducibility
    
    for cluster_label in unique_clusters:
        cluster_indices: pd.Index = cell_labels.index
        num_cells_in_cluster: int = len(cluster_indices)  # Total number of elements in the cluster
        
        if num_cells_in_cluster >= 10:
            sample_size: int = int(np.floor(np.sqrt(num_cells_in_cluster))) + 1  # Number of elements to sample
            
            if single_pseudo_bulk > 0:
                sample_size = 1  # If single_pseudo_bulk is greater than 0, limit to 1 sample
                
            sampled_elements = random.sample(range(num_cells_in_cluster), sample_size)
            cluster_indices = cluster_indices[sampled_elements]
            selected_indices += cluster_indices.tolist()
    
    ### Aggregating RNA and ATAC Profiles ###
    connectivity_df: pd.DataFrame = pd.DataFrame(connectivity_matrix.toarray(), index=rna_data.obs_names.tolist())
    selected_connectivity_matrix: np.ndarray = connectivity_df.loc[selected_indices].values
    
    # Aggregate RNA expression and ATAC accessibility
    aggregated_rna: np.ndarray = (selected_connectivity_matrix @ rna_data.raw.X.toarray())
    pseudo_bulk_rna: pd.DataFrame = pd.DataFrame(
        (aggregated_rna / (neighbors_k - 1)).T, 
        columns=selected_indices, 
        index=rna_data.raw.var['gene_ids'].tolist())
    
    aggregated_atac: np.ndarray = (selected_connectivity_matrix @ atac_data.raw.X.toarray())
    pseudo_bulk_atac: pd.DataFrame = pd.DataFrame(
        (aggregated_atac / (neighbors_k - 1)).T, 
        columns=selected_indices, 
        index=atac_data.raw.var['gene_ids'].tolist())
    
    return pseudo_bulk_rna, pseudo_bulk_atac

adata_RNA.obs['sample'] = sample_name
adata_ATAC.obs['sample'] = sample_name

print(f'\tscRNAseq Dataset: {adata_RNA.shape[1]} genes, {adata_RNA.shape[0]} cells')
print(f'\tscATACseq Dataset: {adata_ATAC.shape[1]} peaks, {adata_ATAC.shape[0]} cells')

# Remove low count cells and genes
print('\nFiltering Data')
print(f'\tFiltering out cells with less than 200 genes...')
sc.pp.filter_cells(adata_RNA, min_genes=200)
adata_RNA = adata_RNA.copy()
print(f'\t\tShape of the RNA dataset = {adata_RNA.shape[1]} genes, {adata_RNA.shape[0]} cells')

print(f'\tFiltering out genes expressed in fewer than 3 cells...')
sc.pp.filter_genes(adata_RNA, min_cells=3)
adata_RNA = adata_RNA.copy()
print(f'\t\tShape of the RNA dataset = {adata_RNA.shape[1]} genes, {adata_RNA.shape[0]} cells')

print(f'\tFiltering out cells with less than 200 ATAC-seq peaks...')
sc.pp.filter_cells(adata_ATAC, min_genes=200)
adata_ATAC = adata_ATAC.copy()
print(f'\t\tShape of the ATAC dataset = {adata_ATAC.shape[1]} peaks, {adata_ATAC.shape[0]} cells')

print(f'\tFiltering out peaks expressed in fewer than 3 cells...')
sc.pp.filter_genes(adata_ATAC, min_cells=3)
adata_ATAC = adata_ATAC.copy()
print(f'\t\tShape of the ATAC dataset = {adata_ATAC.shape[1]} peaks, {adata_ATAC.shape[0]} cells')

print('\nShape of the dataset after filtering')
print(f'\tscRNAseq Dataset: {adata_RNA.shape[1]} genes, {adata_RNA.shape[0]} cells')
print(f'\tscATACseq Dataset: {adata_ATAC.shape[1]} peaks, {adata_ATAC.shape[0]} cells')

common_barcodes = set(adata_RNA.obs['barcode']).intersection(set(adata_ATAC.obs['barcode']))
print(f"\nNumber of common barcodes: {len(common_barcodes)}")

adata_RNA = adata_RNA[adata_RNA.obs['barcode'].isin(common_barcodes)].copy()
adata_ATAC = adata_ATAC[adata_ATAC.obs['barcode'].isin(common_barcodes)].copy()

print('\nOnly keeping shared barcodes')
print(f'\tscRNAseq Dataset: {adata_RNA.shape[1]} genes, {adata_RNA.shape[0]} cells')
print(f'\tscATACseq Dataset: {adata_ATAC.shape[1]} peaks, {adata_ATAC.shape[0]} cells')

print(f'\nGenerating pseudo-bulk / metacells')
samplelist = list(set(adata_ATAC.obs['sample'].values))
tempsample = samplelist[0]

TG_pseudobulk = pd.DataFrame([])
RE_pseudobulk = pd.DataFrame([])

singlepseudobulk = (adata_RNA.obs['sample'].unique().shape[0] * adata_RNA.obs['sample'].unique().shape[0] > 100)

for tempsample in samplelist:
    adata_RNAtemp = adata_RNA[adata_RNA.obs['sample'] == tempsample].copy()
    adata_ATACtemp = adata_ATAC[adata_ATAC.obs['sample'] == tempsample].copy()

    TG_pseudobulk_temp, RE_pseudobulk_temp = pseudo_bulk(adata_RNAtemp, adata_ATACtemp, singlepseudobulk)

    TG_pseudobulk = pd.concat([TG_pseudobulk, TG_pseudobulk_temp], axis=1)
    RE_pseudobulk = pd.concat([RE_pseudobulk, RE_pseudobulk_temp], axis=1)

    RE_pseudobulk[RE_pseudobulk > 100] = 100


sample_data_dir = os.path.join(SAMPLE_INPUT_DIR, sample_name)
if not os.path.exists(sample_data_dir):
    os.makedirs(sample_data_dir)

print(f'Writing adata_ATAC.h5ad and adata_RNA.h5ad')
adata_ATAC.write_h5ad(os.path.join(sample_data_dir, f'{sample_name}_ATAC.h5ad'))
adata_RNA.write_h5ad(os.path.join(sample_data_dir, f'{sample_name}_RNA.h5ad'))

TG_pseudobulk = TG_pseudobulk.fillna(0)
RE_pseudobulk = RE_pseudobulk.fillna(0)

print(f'Writing out peak gene ids')
pd.DataFrame(adata_ATAC.var['gene_ids']).to_csv(os.path.join(sample_data_dir, "Peaks.txt"), header=None, index=None)

print(f'Writing out pseudobulk...')
TG_pseudobulk.to_csv(os.path.join(sample_data_dir, "TG_pseudobulk.tsv"), sep='\t', index=True)
RE_pseudobulk.to_csv(os.path.join(sample_data_dir, "RE_pseudobulk.tsv"), sep='\t', index=True)

	scRNAseq Dataset: 32285 genes, 7416 cells
	scATACseq Dataset: 192248 peaks, 7416 cells

Filtering Data
	Filtering out cells with less than 200 genes...
		Shape of the RNA dataset = 32285 genes, 7411 cells
	Filtering out genes expressed in fewer than 3 cells...
		Shape of the RNA dataset = 23581 genes, 7411 cells
	Filtering out cells with less than 200 ATAC-seq peaks...
		Shape of the ATAC dataset = 192248 peaks, 7416 cells
	Filtering out peaks expressed in fewer than 3 cells...
		Shape of the ATAC dataset = 192248 peaks, 7416 cells

Shape of the dataset after filtering
	scRNAseq Dataset: 23581 genes, 7411 cells
	scATACseq Dataset: 192248 peaks, 7416 cells
Number of common barcodes: 7411

Generating pseudo-bulk / metacells


  view_to_actual(adata)
  view_to_actual(adata)


Writing adata_ATAC.h5ad and adata_RNA.h5ad
Writing out peak gene ids
Writing out pseudobulk...


In [40]:
print(f"TG_pseudobulk: {TG_pseudobulk.shape[0]:,} Genes x {TG_pseudobulk.shape[1]} metacells")
print(f"RE_pseudobulk: {RE_pseudobulk.shape[0]:,} Peaks x {RE_pseudobulk.shape[1]} metacells")

TG_pseudobulk: 23,581 Genes x 87 metacells
RE_pseudobulk: 192,248 Peaks x 87 metacells


In [41]:
TG_pseudobulk.head()

Unnamed: 0,E7.5_rep1.CGCTCAGCATTATGGT.1,E7.5_rep1.CTCCATCAGCTGTCAG.1,E7.5_rep1.CGCTCCATCTACCTCA.1,E7.5_rep1.GTCGAAGCAGGTTCAC.1,E7.5_rep1.AATTGTGTCCGTGACA.1,E7.5_rep1.GCGGGTTTCAACCAAC.1,E7.5_rep1.GTTCGCTTCTGTGCCT.1,E7.5_rep1.CCTAAGGTCAAGCCTG.1,E7.5_rep1.CATATCGCAGGACCAA.1,E7.5_rep1.GCGGTTATCCGCACAA.1,...,E7.5_rep1.TTGTTCCCATGAATAG.1,E7.5_rep1.CGCTTAACACCCACAG.1,E7.5_rep1.GTTAAGCTCCCTCATA.1,E7.5_rep1.GTTAGACTCTAAATCG.1,E7.5_rep1.TTCCACGGTTGGATCA.1,E7.5_rep1.GCGGTTGGTTTAAAGC.1,E7.5_rep1.TGGACCGGTTGGGTTA.1,E7.5_rep1.TTTAACGAGTTAGACC.1,E7.5_rep1.CTAATGTCACTTCATC.1,E7.5_rep1.GAAGTATAGAACCTAC.1
Xkr4,0.146312,0.0,0.223025,0.350406,0.155437,0.040867,0.264128,0.064394,0.342229,0.376776,...,0.202064,0.636426,0.050891,0.64607,0.438066,0.224572,0.167441,0.402592,0.771012,0.773692
Gm1992,0.0,0.0,0.0,0.0,0.015247,0.0,0.0,0.0,0.0,0.078195,...,0.0,0.0,0.0,0.0,0.024955,0.0,0.0,0.050171,0.0,0.0
Gm19938,0.035174,0.0,0.0,0.103212,0.050677,0.0,0.0,0.0,0.103981,0.0,...,0.041829,0.0,0.0,0.106992,0.0,0.0,0.0,0.089593,0.125776,0.083795
Gm37381,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Rp1,0.0,0.0,0.0,0.0,0.019938,0.0,0.0,0.0,0.020998,0.0,...,0.0,0.040058,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [42]:
RE_pseudobulk.head()

Unnamed: 0,E7.5_rep1.CGCTCAGCATTATGGT.1,E7.5_rep1.CTCCATCAGCTGTCAG.1,E7.5_rep1.CGCTCCATCTACCTCA.1,E7.5_rep1.GTCGAAGCAGGTTCAC.1,E7.5_rep1.AATTGTGTCCGTGACA.1,E7.5_rep1.GCGGGTTTCAACCAAC.1,E7.5_rep1.GTTCGCTTCTGTGCCT.1,E7.5_rep1.CCTAAGGTCAAGCCTG.1,E7.5_rep1.CATATCGCAGGACCAA.1,E7.5_rep1.GCGGTTATCCGCACAA.1,...,E7.5_rep1.TTGTTCCCATGAATAG.1,E7.5_rep1.CGCTTAACACCCACAG.1,E7.5_rep1.GTTAAGCTCCCTCATA.1,E7.5_rep1.GTTAGACTCTAAATCG.1,E7.5_rep1.TTCCACGGTTGGATCA.1,E7.5_rep1.GCGGTTGGTTTAAAGC.1,E7.5_rep1.TGGACCGGTTGGGTTA.1,E7.5_rep1.TTTAACGAGTTAGACC.1,E7.5_rep1.CTAATGTCACTTCATC.1,E7.5_rep1.GAAGTATAGAACCTAC.1
chr1:3105825-3106425,0.057822,0.0,0.0,0.057822,0.0,0.057822,0.0,0.115643,0.115643,0.057822,...,0.0,0.057822,0.115643,0.0,0.0,0.0,0.0,0.0,0.0,0.0
chr1:3132876-3133476,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.115643,0.30425,0.0,...,0.072963,0.0,0.173465,0.0,0.0,0.0,0.0,0.0,0.0,0.072963
chr1:3142536-3143136,0.0,0.0,0.0,0.0,0.215492,0.0,0.0,0.0,0.0,0.0,...,0.0,0.057822,0.0,0.0,0.0,0.0,0.036481,0.0,0.0,0.057822
chr1:3261719-3262319,0.057822,0.057822,0.094303,0.0,0.0,0.258172,0.0,0.0,0.0,0.0,...,0.0,0.130785,0.0,0.0,0.036481,0.0,0.0,0.0,0.0,0.057822
chr1:3410798-3411398,0.0,0.036481,0.057822,0.0,0.0,0.0,0.0,0.0,0.057822,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [45]:
import os
import torch
import pandas as pd
import logging
import pybedtools

from transformer_2 import MultiomicTransformer

logging.basicConfig(level=logging.INFO, format="%(message)s")

PROJECT_DIR = "/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER"
RAW_MESC_DATA_DIR = "/gpfs/Labs/Uzun/DATA/PROJECTS/2024.SC_MO_TRN_DB.MIRA/REPOSITORY/CURRENT/SINGLE_CELL_DATASETS/DS014_DOI496239_MOUSE_ESC_RAW_FILES"
MESC_PEAK_MATRIX_FILE = "/gpfs/Labs/Uzun/DATA/PROJECTS/2024.SC_MO_TRN_DB.MIRA/REPOSITORY/CURRENT/SINGLE_CELL_DATASETS/DS014_DOI496239_MOUSE_ESCDAYS7AND8/scATAC_PeakMatrix.txt"

MM10_GENOME_DIR = os.path.join(PROJECT_DIR, "data/reference_genome/mm10")
MM10_CHROM_SIZES_FILE = os.path.join(MM10_GENOME_DIR, "chrom.sizes")
MM10_GENE_TSS_FILE = os.path.join(PROJECT_DIR, "data/genome_annotation/mm10/mm10_TSS.bed")
GROUND_TRUTH_DIR = os.path.join(PROJECT_DIR, "ground_truth_files")
SAMPLE_INPUT_DIR = os.path.join(PROJECT_DIR, "input/mESC/")
OUTPUT_DIR = os.path.join(PROJECT_DIR, "output/transformer_testing_output")

def load_homer_tf_to_peak_results():
    assert os.path.exists(os.path.join(OUTPUT_DIR, "homer_tf_to_peak.parquet")), \
        "ERROR: Homer TF to peak output parquet file required"
        
    homer_results = pd.read_parquet(os.path.join(OUTPUT_DIR, "homer_tf_to_peak.parquet"), engine="pyarrow")
    homer_results = homer_results.reset_index(drop=True)
    homer_results["source_id"] = homer_results["source_id"].str.capitalize()
    
    return homer_results

def create_or_load_genomic_windows(chrom_id, window_size, force_recalculate=False):
    genome_window_file = os.path.join(MM10_GENOME_DIR, f"mm10_{chrom_id}_windows_{window_size // 1000}kb.bed")
    if not os.path.exists(genome_window_file) or force_recalculate:
        
        logging.info("Creating genomic windows")
        mm10_genome_windows = pybedtools.bedtool.BedTool().window_maker(g=MM10_CHROM_SIZES_FILE, w=window_size)
        mm10_windows = (
            mm10_genome_windows
            .filter(lambda x: x.chrom == chrom_id)  # TEMPORARY Restrict to one chromosome for testing
            .saveas(genome_window_file)
            .to_dataframe()
        )
    else:
        
        logging.info("Loading existing genomic windows")
        mm10_windows = pybedtools.BedTool(genome_window_file).to_dataframe()
        
    return mm10_windows

def make_peak_to_window_map(peaks_bed: pd.DataFrame, windows_bed: pd.DataFrame) -> dict[str, int]:
    """
    peaks_bed: df with ['chrom','start','end','peak_id']
    windows_bed: df with ['chrom','start','end','win_idx']
    """
    bedtool_peaks = pybedtools.BedTool.from_dataframe(peaks_bed)
    bedtool_windows = pybedtools.BedTool.from_dataframe(windows_bed)
    
    mapping = {}
    for interval in bedtool_peaks.intersect(bedtool_windows, wa=True, wb=True):
        peak_id = interval.name  # the peak_id column from peaks_bed
        win_idx = int(interval.fields[-1])  # last column = win_idx
        mapping[peak_id] = win_idx
    return mapping

def prepare_inputs(TG_pseudobulk: pd.DataFrame,
                   RE_pseudobulk: pd.DataFrame,
                   tf_list: list[str],
                   window_map: dict[str, int]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Convert pseudobulk matrices into model inputs.
    
    Args:
      TG_pseudobulk : genes x samples dataframe
      RE_pseudobulk : peaks x samples dataframe
      tf_list       : list of TF gene symbols to keep
      window_map    : dict mapping peaks -> window index (0..num_windows-1)
    
    Returns:
      tf_expr   : Tensor [B, num_tf]
      atac_wins : Tensor [B, num_windows, 1]
    """
    # 1. Extract TF expression
    tf_expr = TG_pseudobulk.loc[TG_pseudobulk.index.intersection(tf_list)].T
    tf_tensor = torch.tensor(tf_expr.values, dtype=torch.float32)   # [B, num_tf]

    # 2. Collapse peaks into windows
    num_windows = max(window_map.values()) + 1
    atac_wins = torch.zeros((RE_pseudobulk.shape[1], num_windows, 1), dtype=torch.float32)

    peak_idx = [RE_pseudobulk.index.get_loc(p) for p in window_map if p in RE_pseudobulk.index]
    win_idx = [window_map[p] for p in window_map if p in RE_pseudobulk.index]

    peak_tensor = torch.tensor(RE_pseudobulk.iloc[peak_idx].values.T, dtype=torch.float32)  # [B, num_peaks]
    win_idx_tensor = torch.tensor(win_idx, dtype=torch.long)

    atac_wins.index_add_(1, win_idx_tensor, peak_tensor.unsqueeze(-1))
    return tf_tensor, atac_wins

sample_name = "E7.5_rep1"
window_size = 50000

tf_list = list(load_homer_tf_to_peak_results()["source_id"].unique())
logging.info(f"TF List: {tf_list[:5]}, total {len(tf_list)} TFs")

sample_data_dir = os.path.join(SAMPLE_INPUT_DIR, sample_name)
TG_pseudobulk = pd.read_csv(os.path.join(sample_data_dir, "TG_pseudobulk.tsv"), sep="\t", index_col=0)
RE_pseudobulk = pd.read_csv(os.path.join(sample_data_dir, "RE_pseudobulk.tsv"), sep="\t", index_col=0)

peaks_df = (
    RE_pseudobulk.index.to_series()
    .str.split("[:-]", expand=True)
    .rename(columns={0: "chrom", 1: "start", 2: "end"})
)
peaks_df["start"] = peaks_df["start"].astype(int)
peaks_df["end"] = peaks_df["end"].astype(int)
peaks_df["peak_id"] = RE_pseudobulk.index

# Create genome windows and add index
mm10_windows = create_or_load_genomic_windows("chr19", window_size)
mm10_windows = mm10_windows.reset_index(drop=True)
mm10_windows["win_idx"] = mm10_windows.index

# Build peak -> window mapping
window_map = make_peak_to_window_map(peaks_df, mm10_windows)
logging.info(f"Mapped {len(window_map)} peaks to windows")

logging.info("TG Pseudobulk")
logging.info(f"TG_pseudobulk: {TG_pseudobulk.shape[0]:,} Genes x {TG_pseudobulk.shape[1]} metacells")
logging.info(TG_pseudobulk.head())
logging.info("")
logging.info("RE Pseudobulk")
logging.info(f"RE_pseudobulk: {RE_pseudobulk.shape[0]:,} Peaks x {RE_pseudobulk.shape[1]} metacells")
logging.info(RE_pseudobulk.head())

# Example setup
d_model = 128
num_heads = 8
d_ff = 256
dropout = 0.1
num_tf = len(tf_list)
num_windows = max(window_map.values()) + 1
num_tg = TG_pseudobulk.shape[0]   # or restrict to TGs only

model = MultiomicTransformer(d_model, num_heads, d_ff, dropout, num_tf, num_windows, num_tg)

# Prepare inputs
tf_tensor, atac_wins = prepare_inputs(TG_pseudobulk, RE_pseudobulk, tf_list, window_map)

# Forward pass
gene_logits = model(atac_wins, tf_tensor)   # [B, num_tg]

logging.info(gene_logits)
logging.info(gene_logits.shape)


ModuleNotFoundError: No module named 'transformer_2'