In [None]:
import os
from dotenv import load_dotenv
from pathlib import Path
import pandas as pd
import pickle
import umap
from sklearn.preprocessing import StandardScaler
from pycisTopic.clust_vis import plot_metadata
from pycisTopic.cistopic_class import CistopicObject
# Load environment variables via python-dotenv (by default from a .env file);
# OUTPUT_PATH is read from the environment in the next cell.
load_dotenv()

In [None]:
# Resolve the preprocessing output directory from the OUTPUT_PATH environment variable and
# fail with an explicit message when it is missing (Path(None) would raise an opaque TypeError).
output_root = os.getenv("OUTPUT_PATH")
if output_root is None:
    raise EnvironmentError(
        "OUTPUT_PATH is not set; define it in the environment or in the .env file."
    )
out_dir = Path(output_root) / "garcia_ATAC/atac_preprocessing_combined"

# NOTE(security): pickle.load can execute arbitrary code embedded in the file, so only load
# objects produced by a trusted pipeline. The context manager closes the file handle afterwards.
with open(out_dir / "cistopic_obj.pkl", "rb") as pickle_file:
    cistopic_obj = pickle.load(pickle_file)

In [None]:
def _dataset_samples_by_features(
    feature_matrix: pd.DataFrame,
    cell_data: pd.DataFrame,
    dataset_column_name: str,
    dataset_value: str,
) -> pd.DataFrame:
    """
    Slice ``feature_matrix`` down to a single dataset and return it as samples x features.

    The members of the dataset are the index labels of ``cell_data`` whose
    ``dataset_column_name`` equals ``dataset_value``; they are selected as *columns* of
    ``feature_matrix`` and the result is transposed.

    Raises
    ------
    ValueError
        If no row of ``cell_data`` carries ``dataset_value``.
    KeyError
        If some selected labels are not columns of ``feature_matrix``.
    """
    members = cell_data.index[cell_data[dataset_column_name] == dataset_value].tolist()
    if not members:
        raise ValueError(
            f"No entries with {dataset_column_name} == {dataset_value!r} "
            "found in cistopic_obj.cell_data."
        )

    # Report missing labels ourselves: pandas' own KeyError does not say which dataset failed.
    unknown = pd.Index(members).difference(feature_matrix.columns, sort=False)
    if len(unknown) > 0:
        raise KeyError(
            f"{len(unknown)} of the {len(members)} names selected for {dataset_value!r} are "
            f"not columns of the feature matrix (e.g. {unknown[0]!r})."
        )

    return feature_matrix.loc[:, members].T


def run_umap_projection_within_cistopic_obj(
    cistopic_obj: CistopicObject,
    reference_dataset_value: str,
    query_dataset_value: str,
    dataset_column_name: str = "dataset",
    target: str = "cell",
    scale: bool = False,
    reduction_name: str = "UMAP_projected",
    random_state: int = 555,
    harmony: bool = False,
    **kwargs,
):
    """
    Fit UMAP on a reference dataset and project a query dataset into the learned space.

    Both datasets live in the same CistopicObject and are told apart by the value of one
    column of `cistopic_obj.cell_data`.

    Parameters
    ----------
    cistopic_obj: CistopicObject
        A cisTopic object with a model in `selected_model` containing both datasets.
    reference_dataset_value: str
        The value in the 'dataset_column_name' that identifies the reference dataset.
    query_dataset_value: str
        The value in the 'dataset_column_name' that identifies the query dataset.
    dataset_column_name: str
        The name of the column in `cistopic_obj.cell_data` that distinguishes the datasets. Default: "dataset"
    target: str
        Whether cells ('cell') or regions ('region') should be used. Default: 'cell'
    scale: bool
        Whether to standardise the features before the dimensionality reduction. The scaler is
        fitted on the reference data only and re-applied to the query data. Default: False
    reduction_name: str
        Reduction name to use as key in the dimensionality reduction dictionary. Default: 'UMAP_projected'
    random_state: int
        Seed parameter for running UMAP. Default: 555
    harmony: bool
        If target is 'cell', whether to use harmony processed topic contributions. Default: False.
    **kwargs
        Parameters to pass to umap.UMAP.

    Returns
    -------
    The fitted `umap.UMAP` reducer.

    Side effects
    ------------
    Three data frames are written to `cistopic_obj.projections[target]`:

    - `<reduction_name>`: coordinates of the reference dataset,
    - `<reduction_name>_<query_dataset_value>`: coordinates of the projected query dataset,
    - `<reduction_name>_<reference_dataset_value>_<query_dataset_value>`: both stacked
      (reference first), with an extra `dataset_column_name` column holding the dataset label.

    Raises
    ------
    ValueError
        If `target` is neither 'cell' nor 'region', or if either dataset is empty.
    KeyError
        If the names selected for a dataset are not columns of the feature matrix.
    """
    if target not in ("cell", "region"):
        raise ValueError(f"target must be 'cell' or 'region', got {target!r}.")

    model = cistopic_obj.selected_model

    # 1. Pick the feature matrix (features x samples) for the requested target
    if target == "cell":
        feature_matrix = model.cell_topic_harmony if harmony else model.cell_topic
    else:
        # NOTE(review): dataset membership is defined per cell in cistopic_obj.cell_data, so this
        # branch only works when the columns of topic_region.T carry those same names. If they
        # are region names (TODO confirm against pycisTopic), the selection below raises KeyError.
        feature_matrix = model.topic_region.T

    # 2. Slice out both datasets up front so that a bad dataset value fails before the
    #    (potentially slow) UMAP fit rather than after it
    data_mat_reference = _dataset_samples_by_features(
        feature_matrix, cistopic_obj.cell_data, dataset_column_name, reference_dataset_value
    )
    data_mat_query = _dataset_samples_by_features(
        feature_matrix, cistopic_obj.cell_data, dataset_column_name, query_dataset_value
    )

    # 3. Optional scaling: fit on the reference only, then apply the same transform to the query
    if scale:
        scaler = StandardScaler()
        data_mat_reference = pd.DataFrame(
            scaler.fit_transform(data_mat_reference),
            index=data_mat_reference.index.to_list(),
            columns=data_mat_reference.columns,
        )
        data_mat_query = pd.DataFrame(
            scaler.transform(data_mat_query),
            index=data_mat_query.index.to_list(),
            columns=data_mat_query.columns,
        )

    # 4. Train UMAP on the reference and project the query onto the learned space
    reducer = umap.UMAP(random_state=random_state, **kwargs)
    embedding_reference = reducer.fit_transform(data_mat_reference)
    embedding_query = reducer.transform(data_mat_query)

    # 5. Store the embeddings. Column names follow the embedding width so that a custom
    #    n_components passed through **kwargs does not break the DataFrame construction.
    umap_columns = [f"UMAP_{axis}" for axis in range(1, embedding_reference.shape[1] + 1)]
    dr_reference = pd.DataFrame(
        embedding_reference, index=data_mat_reference.index.to_list(), columns=umap_columns
    )
    dr_query = pd.DataFrame(
        embedding_query, index=data_mat_query.index.to_list(), columns=umap_columns
    )

    # Combined embedding with a label column, convenient for plotting both datasets together
    dr_combined = pd.concat([dr_reference, dr_query])
    dr_combined[dataset_column_name] = (
        [reference_dataset_value] * len(dr_reference) + [query_dataset_value] * len(dr_query)
    )

    projections = cistopic_obj.projections[target]
    projections[reduction_name] = dr_reference
    projections[f"{reduction_name}_{query_dataset_value}"] = dr_query
    projections[
        f"{reduction_name}_{reference_dataset_value}_{query_dataset_value}"
    ] = dr_combined

    return reducer

In [None]:
import numpy as np

# Tag every cell with its dataset of origin: sample names containing "24047" are labelled
# 'meiotic', all remaining cells 'garcia_ATAC'.
# regex=False matches the identifier literally; na=False keeps cells with a missing sample
# name out of the 'meiotic' group (a NaN mask entry would otherwise be truthy in np.where).
is_meiotic_sample = cistopic_obj.cell_data['sample'].str.contains('24047', regex=False, na=False)
cistopic_obj.cell_data['dataset'] = np.where(is_meiotic_sample, 'meiotic', 'garcia_ATAC')

In [None]:
# Number of cells per cell type within the meiotic dataset.
cistopic_obj.cell_data["celltype"][cistopic_obj.cell_data["dataset"].eq("meiotic")].value_counts()

In [None]:
# Number of cells per cell type within the garcia_ATAC dataset.
cistopic_obj.cell_data["celltype"][cistopic_obj.cell_data["dataset"].eq("garcia_ATAC")].value_counts()

In [None]:
from pycisTopic.clust_vis import harmony
# Run pycisTopic's Harmony integration on the object, using the 'sample' column as the batch variable.
# The projection call further down passes harmony=True, which reads
# selected_model.cell_topic_harmony — presumably populated by this call (confirm in the pycisTopic docs).
harmony(cistopic_obj, 'sample')

In [None]:
# Fit UMAP on the garcia_ATAC cells and project the meiotic cells into that embedding,
# working on the scaled, Harmony-processed cell-topic contributions.
run_umap_projection_within_cistopic_obj(
    cistopic_obj=cistopic_obj,
    dataset_column_name="dataset",
    reference_dataset_value="garcia_ATAC",
    query_dataset_value="meiotic",
    target="cell",
    harmony=True,
    scale=True,
)

In [None]:
# Fixed palette for the cell types. Colours are assigned by position and reused cyclically when
# there are more cell types than palette entries, so every category is guaranteed an entry
# (zipping against the 8-colour list would silently drop every category past the eighth).
# NOTE: with more than 8 cell types some colours repeat — extend the palette if that matters.
celltype_palette = ["red", "green", "blue", "cyan", "yellow", "purple", "orange", "pink"]
categories = list(cistopic_obj.cell_data.celltype.unique())
colors = {
    category: celltype_palette[position % len(celltype_palette)]
    for position, category in enumerate(categories)
}

# One colour mapping per metadata variable that gets plotted below.
color_dictionary = {"dataset": {"meiotic": "orange", "garcia_ATAC": "blue"}, "celltype": colors}


In [None]:
# Combined embedding (garcia_ATAC reference plus projected meiotic cells),
# coloured by dataset and by cell type.
plot_metadata(
    cistopic_obj,
    reduction_name="UMAP_projected_garcia_ATAC_meiotic",
    target="cell",
    variables=["dataset", "celltype"],
    color_dictionary=color_dictionary,
    num_columns=2,
    dot_size=5,
    text_size=10,
    show_legend=True,
    show_label=False,
)

In [None]:

# Same combined embedding, restricted to the cells of the garcia_ATAC (reference) dataset.
plot_metadata(
    cistopic_obj,
    reduction_name="UMAP_projected_garcia_ATAC_meiotic",
    target="cell",
    variables=["dataset", "celltype"],
    selected_features=cistopic_obj.cell_data.index[
        cistopic_obj.cell_data["dataset"] == "garcia_ATAC"
    ].tolist(),
    color_dictionary=color_dictionary,
    num_columns=2,
    dot_size=5,
    text_size=10,
    show_legend=True,
    show_label=False,
)

In [None]:
# Same combined embedding, restricted to the cells of the meiotic (query) dataset.
plot_metadata(
    cistopic_obj,
    reduction_name="UMAP_projected_garcia_ATAC_meiotic",
    target="cell",
    variables=["dataset", "celltype"],
    selected_features=cistopic_obj.cell_data.index[
        cistopic_obj.cell_data["dataset"] == "meiotic"
    ].tolist(),
    color_dictionary=color_dictionary,
    num_columns=2,
    dot_size=5,
    text_size=10,
    show_legend=True,
    show_label=False,
)

In [None]:
import anndata as ad

# One AnnData holding every cell: X is the transposed fragment matrix cast to float32,
# observations are named after the cell_data index and variables after the region names.
adata_all = ad.AnnData(
    X=cistopic_obj.fragment_matrix.T.astype(np.float32),
    obs=pd.DataFrame(index=cistopic_obj.cell_data.index.tolist()),
    var=pd.DataFrame(index=cistopic_obj.region_names),
)

# Carry the annotations over; the assignments align on the index labels.
adata_all.obs["celltype"] = cistopic_obj.cell_data["celltype"]
adata_all.obs["dataset"] = cistopic_obj.cell_data["dataset"]

In [None]:
import scanpy as sc

# Joint embedding of all cells; each step modifies adata_all in place and the order matters:
# depth normalisation -> log1p -> PCA -> neighbour graph -> UMAP.
for pipeline_step in (
    sc.pp.normalize_total,
    sc.pp.log1p,
    sc.pp.pca,
    sc.pp.neighbors,
    sc.tl.umap,
):
    pipeline_step(adata_all)

In [None]:
from matplotlib import pyplot as plt

# Side-by-side UMAPs of the jointly embedded cells: dataset of origin (left), cell type (right).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
for panel_axis, obs_column in ((ax1, "dataset"), (ax2, "celltype")):
    sc.pl.umap(adata_all, color=obs_column, ax=panel_axis, show=False)
plt.tight_layout()
plt.show()

In [None]:
# The scanpy ingest way

# Restrict the cisTopic object to the meiotic cells; copy=True is passed so that a new object
# is returned instead of the original being modified (per the argument name — confirm in pycisTopic).
cistopic_meiotic = cistopic_obj.subset(
    cistopic_obj.cell_data.index[cistopic_obj.cell_data["dataset"] == "meiotic"].tolist(),
    copy=True,
)


In [None]:
# Restrict the cisTopic object to the garcia_ATAC cells; copy=True is passed so that a new object
# is returned instead of the original being modified (per the argument name — confirm in pycisTopic).
cistopic_garcia = cistopic_obj.subset(
    cistopic_obj.cell_data.index[cistopic_obj.cell_data["dataset"] == "garcia_ATAC"].tolist(),
    copy=True,
)

In [None]:
import anndata as ad

# Create one AnnData per dataset. Cell and region names are taken from the *same* subsetted
# cisTopic object that provides the matrix, so obs/var always agree with the shape and order of
# X (previously the meiotic object borrowed its region names from the full, unsubsetted object,
# which breaks as soon as subsetting changes the set of regions).
adata_meiotic = ad.AnnData(
    X=cistopic_meiotic.fragment_matrix.T.astype(np.float32),
    obs=pd.DataFrame(index=cistopic_meiotic.cell_data.index.tolist()),
    var=pd.DataFrame(index=cistopic_meiotic.region_names),
)

adata_garcia = ad.AnnData(
    X=cistopic_garcia.fragment_matrix.T.astype(np.float32),
    obs=pd.DataFrame(index=cistopic_garcia.cell_data.index.tolist()),
    var=pd.DataFrame(index=cistopic_garcia.region_names),
)

In [None]:

# Add some metadata
# Dataset labels use the same values as cistopic_obj.cell_data['dataset'] and the keys of
# color_dictionary['dataset'] ('meiotic' / 'garcia_ATAC'), so colour mappings and filters
# written for one object also work for the other.
adata_meiotic.obs['dataset'] = 'meiotic'
adata_garcia.obs['dataset'] = 'garcia_ATAC'

# Cell-type annotations are aligned on the cell names in the index.
adata_meiotic.obs['celltype'] = cistopic_meiotic.cell_data.celltype
adata_garcia.obs['celltype'] = cistopic_garcia.cell_data.celltype

In [None]:
# Inspect the reference matrix; as the last expression of the cell its repr becomes the cell output.
adata_garcia.X

In [None]:
# Keep only the regions present in both objects. The shared regions are listed in adata_garcia's
# own order rather than by iterating a set, whose order for strings changes between kernel
# sessions and would make the feature order (and everything fitted on it) irreproducible.
meiotic_region_set = set(adata_meiotic.var_names)
common_regions = [region for region in adata_garcia.var_names if region in meiotic_region_set]
if not common_regions:
    raise ValueError("adata_garcia and adata_meiotic do not share any region.")

# .copy() turns the sliced views into independent objects, so the in-place preprocessing that
# follows operates on real data instead of triggering an implicit view-to-copy conversion.
adata_meiotic = adata_meiotic[:, common_regions].copy()
adata_garcia = adata_garcia[:, common_regions].copy()

In [None]:
import scanpy as sc

# sc.tl.ingest projects the query through the PCA fitted on the reference, so the query matrix
# has to be on the same scale as the reference. Apply the same depth normalisation and log1p to
# both objects; leaving the query as raw counts would project it far away from the reference.
# NOTE(review): normalize_total is called with its defaults on each object separately — pass an
# explicit, shared target_sum to both calls if the two datasets differ strongly in depth.
sc.pp.normalize_total(adata_garcia)
sc.pp.log1p(adata_garcia)
sc.pp.normalize_total(adata_meiotic)
sc.pp.log1p(adata_meiotic)

# PCA, neighbour graph and UMAP are fitted on the reference only; the query is mapped onto
# them by sc.tl.ingest.
sc.pp.pca(adata_garcia)
sc.pp.neighbors(adata_garcia)
sc.tl.umap(adata_garcia)

In [None]:
# Map the meiotic cells (query) onto the representations fitted on the garcia reference;
# adata_meiotic is updated in place.
sc.tl.ingest(adata=adata_meiotic, adata_ref=adata_garcia)

In [None]:
from matplotlib import pyplot as plt

# Reference UMAP (left) next to the ingested query cells (right), both coloured by cell type.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
for panel_axis, panel_adata, panel_title in (
    (ax1, adata_garcia, 'Garcia'),
    (ax2, adata_meiotic, 'Meiotic'),
):
    sc.pl.umap(panel_adata, color="celltype", title=panel_title, ax=panel_axis, show=False)
plt.tight_layout()
plt.show()