In [None]:
import os

import anndata
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib import patches, rcParams

import scglue

In [None]:
# Apply scglue's publication plotting parameters, then set the default
# figure size to 7x7 (done afterwards so this value is the one in effect).
scglue.plot.set_publication_params()
rcParams["figure.figsize"] = (7, 7)

# Root directory for every file this notebook writes (models, figures, h5ad);
# created up front if it does not exist yet.
PATH = "s04_glue_final"
os.makedirs(PATH, exist_ok=True)

# Aggregated data

## Read data

In [None]:
# Load the balanced, aggregated RNA and ATAC AnnData objects and the prior
# graph (gzipped GraphML). Presumably these are outputs of the earlier
# s01 / s03 notebooks -- the paths must exist before this cell runs.
rna_agg = anndata.read_h5ad("s03_unsupervised_balancing/rna_agg_balanced.h5ad")
atac_agg = anndata.read_h5ad("s03_unsupervised_balancing/atac_agg_balanced.h5ad")
prior = nx.read_graphml("s01_preprocessing/sub.graphml.gz")

## GLUE

In [None]:
# Vertex list handed to the model: all prior-graph nodes in sorted order.
vertices = sorted(prior.nodes)

# The two modalities are configured identically apart from the key of the
# precomputed representation (PCA for RNA, LSI for ATAC).
for adata, rep_key in ((rna_agg, "X_pca"), (atac_agg, "X_lsi")):
    scglue.models.configure_dataset(
        adata, "NB",
        use_highly_variable=True,
        use_rep=rep_key,
        use_dsc_weight="np_balancing",
    )

# Fresh model over both modalities; fixed seed for reproducibility.
glue = scglue.models.SCGLUEModel(
    {"rna": rna_agg, "atac": atac_agg},
    vertices,
    h_dim=512,
    random_seed=0,
)

In [None]:
# Load the model saved by the pretraining step and hand it to the new model.
# NOTE(review): presumably adopt_pretrained_model copies compatible weights
# from the pretrained model -- confirm against the scglue documentation.
glue_pretrain = scglue.models.load_model("s02_glue_pretrain/final.dill")
glue.adopt_pretrained_model(glue_pretrain)

In [None]:
# Compile with learning rate 1e-3 and train on the aggregated data, using the
# "weight" and "sign" edge attributes of the prior graph. Training artifacts
# go under {PATH}/agg, and the trained model is saved there as final.dill
# (it is reloaded later to initialise the full-data model).
glue.compile(lr=1e-3)
glue.fit(
    {"rna": rna_agg, "atac": atac_agg}, prior,
    edge_weight="weight", edge_sign="sign",
    directory=f"{PATH}/agg"
)
glue.save(f"{PATH}/agg/final.dill")

In [None]:
# Store each modality's encoding from the trained model under obsm["X_glue"].
for modality, adata in (("rna", rna_agg), ("atac", atac_agg)):
    adata.obsm["X_glue"] = glue.encode_data(modality, adata)

## Visualization

In [None]:
# Stack both modalities into one AnnData: obs keeps only the columns the two
# tables share (inner join); X_glue rows are concatenated in the same
# RNA-then-ATAC order so they stay aligned with obs.
agg_obs = pd.concat([rna_agg.obs, atac_agg.obs], join="inner")
agg_glue = np.concatenate([rna_agg.obsm["X_glue"], atac_agg.obsm["X_glue"]])
combined_agg = anndata.AnnData(obs=agg_obs, obsm={"X_glue": agg_glue})

In [None]:
# Neighbor graph on the full-width GLUE embedding (cosine metric), then UMAP.
agg_glue_dim = combined_agg.obsm["X_glue"].shape[1]
sc.pp.neighbors(
    combined_agg,
    use_rep="X_glue",
    n_pcs=agg_glue_dim,
    metric="cosine",
)
sc.tl.umap(combined_agg)

In [None]:
# Combined UMAP colored by cell type. The legend entries are captured for the
# standalone legend figure and the legend is then removed from this plot.
fig = sc.pl.umap(
    combined_agg,
    color="cell_type",
    title="Cell type (aggregated)",
    return_fig=True,
)
ct_axes = fig.axes[0]
ct_handles, ct_labels = ct_axes.get_legend_handles_labels()
ct_axes.get_legend().remove()
fig.savefig(f"{PATH}/agg/combined_ct.pdf")

In [None]:
# Combined UMAP colored by omics layer. The legend entries are captured for
# the standalone legend figure and the legend is then removed from this plot.
fig = sc.pl.umap(
    combined_agg,
    color="domain",
    title="Omics layer (aggregated)",
    return_fig=True,
)
domain_axes = fig.axes[0]
domain_handles, domain_labels = domain_axes.get_legend_handles_labels()
domain_axes.get_legend().remove()
fig.savefig(f"{PATH}/agg/combined_domain.pdf")

In [None]:
# Legend-only figure: the axes are hidden, and an invisible rectangle acts as
# the handle for the group titles ("Omics layer", "Cell type") and for the
# blank spacer entry between the two groups.
fig, ax = plt.subplots()
ax.set_visible(False)
placeholder = patches.Rectangle((0, 0), 1, 1, visible=False)
handles = [placeholder] + list(domain_handles) + [placeholder, placeholder] + list(ct_handles)
labels = ["Omics layer"] + list(domain_labels) + ["", "Cell type"] + list(ct_labels)
fig.legend(handles, labels, ncol=5, frameon=False)
fig.savefig(f"{PATH}/agg/combined_legend.pdf")

In [None]:
# Copy the joint UMAP coordinates back to each modality, selecting rows of the
# combined object by that modality's observation names.
for adata in (rna_agg, atac_agg):
    adata.obsm["X_glue_umap"] = combined_agg[adata.obs_names, :].obsm["X_umap"]

In [None]:
# Give each modality the same cell-type categories (and order) as the combined
# object, so the shared "cell_type_colors" palette copied afterwards lines up
# with the categories.
# `set_categories(..., inplace=True)` was deprecated in pandas 1.3 and removed
# in pandas 2.0; assigning the returned Series works on every pandas version.
combined_agg_categories = combined_agg.obs["cell_type"].cat.categories
rna_agg.obs["cell_type"] = rna_agg.obs["cell_type"].cat.set_categories(combined_agg_categories)
atac_agg.obs["cell_type"] = atac_agg.obs["cell_type"].cat.set_categories(combined_agg_categories)

In [None]:
# Reuse the combined plot's cell-type palette in both per-modality objects.
for adata in (rna_agg, atac_agg):
    adata.uns["cell_type_colors"] = combined_agg.uns["cell_type_colors"]

In [None]:
# RNA-only view of the joint UMAP, cell-type labels drawn on the data.
fig = sc.pl.embedding(
    rna_agg, "X_glue_umap",
    color="cell_type",
    legend_loc="on data",
    legend_fontsize=4,
    legend_fontoutline=0.5,
    title="scRNA-seq cell type (aggregated)",
    return_fig=True,
)
rna_axes = fig.axes[0]
rna_axes.set_xlabel("UMAP1")
rna_axes.set_ylabel("UMAP2")
fig.savefig(f"{PATH}/agg/rna_ct.pdf")

In [None]:
# ATAC-only view of the joint UMAP, cell-type labels drawn on the data.
fig = sc.pl.embedding(
    atac_agg, "X_glue_umap",
    color="cell_type",
    legend_loc="on data",
    legend_fontsize=4,
    legend_fontoutline=0.5,
    title="scATAC-seq cell type (aggregated)",
    return_fig=True,
)
atac_axes = fig.axes[0]
atac_axes.set_xlabel("UMAP1")
atac_axes.set_ylabel("UMAP2")
fig.savefig(f"{PATH}/agg/atac_ct.pdf")

## Save results

In [None]:
# Write the annotated aggregated objects as gzip-compressed h5ad files.
for adata, out_file in (
    (rna_agg, f"{PATH}/agg/rna_agg.h5ad"),
    (atac_agg, f"{PATH}/agg/atac_agg.h5ad"),
    (combined_agg, f"{PATH}/agg/combined_agg.h5ad"),
):
    adata.write(out_file, compression="gzip")

# Full data

## Read data

In [None]:
# Load the balanced full-resolution RNA and ATAC AnnData objects. The prior
# graph is read again from the same file as in the aggregated section, so this
# section does not depend on that earlier cell's `prior` variable.
rna = anndata.read_h5ad("s03_unsupervised_balancing/rna_balanced.h5ad")
atac = anndata.read_h5ad("s03_unsupervised_balancing/atac_balanced.h5ad")
prior = nx.read_graphml("s01_preprocessing/sub.graphml.gz")

## GLUE

In [None]:
# Vertex list handed to the model: all prior-graph nodes in sorted order.
vertices = sorted(prior.nodes)

# The two modalities are configured identically apart from the key of the
# precomputed representation (PCA for RNA, LSI for ATAC).
for adata, rep_key in ((rna, "X_pca"), (atac, "X_lsi")):
    scglue.models.configure_dataset(
        adata, "NB",
        use_highly_variable=True,
        use_rep=rep_key,
        use_dsc_weight="nc_balancing",
    )

# Rebinds `glue` to a fresh model over the full data; fixed seed.
glue = scglue.models.SCGLUEModel(
    {"rna": rna, "atac": atac},
    vertices,
    h_dim=512,
    random_seed=0,
)

In [None]:
# Reload the model trained on the aggregated data (saved earlier in this
# notebook as {PATH}/agg/final.dill) and hand it to the full-data model.
# NOTE(review): presumably adopt_pretrained_model copies compatible weights
# from the aggregated model -- confirm against the scglue documentation.
glue_agg = scglue.models.load_model(f"{PATH}/agg/final.dill")
glue.adopt_pretrained_model(glue_agg)

In [None]:
# Compile with learning rate 5e-4 (the aggregated stage used 1e-3) and train
# on the full data with the same prior-graph edge attributes.
# align_burnin=0 and data_batch_size=512 are passed only in this stage.
# NOTE(review): align_burnin=0 presumably skips the alignment burn-in because
# the model starts from the aggregated-stage weights -- confirm.
# Artifacts go under {PATH}/full; the trained model is saved as final.dill.
glue.compile(lr=5e-4)
glue.fit(
    {"rna": rna, "atac": atac}, prior,
    edge_weight="weight", edge_sign="sign",
    align_burnin=0, data_batch_size=512,
    directory=f"{PATH}/full"
)
glue.save(f"{PATH}/full/final.dill")

In [None]:
# Store each modality's encoding from the trained model under obsm["X_glue"].
for modality, adata in (("rna", rna), ("atac", atac)):
    adata.obsm["X_glue"] = glue.encode_data(modality, adata)

## Visualization

In [None]:
# Stack both modalities into one AnnData: obs keeps only the columns the two
# tables share (inner join); X_glue rows are concatenated in the same
# RNA-then-ATAC order so they stay aligned with obs.
full_obs = pd.concat([rna.obs, atac.obs], join="inner")
full_glue = np.concatenate([rna.obsm["X_glue"], atac.obsm["X_glue"]])
combined = anndata.AnnData(obs=full_obs, obsm={"X_glue": full_glue})

In [None]:
# Neighbor graph on the full-width GLUE embedding (cosine metric), then UMAP.
full_glue_dim = combined.obsm["X_glue"].shape[1]
sc.pp.neighbors(
    combined,
    use_rep="X_glue",
    n_pcs=full_glue_dim,
    metric="cosine",
)
sc.tl.umap(combined)

In [None]:
# Combined UMAP colored by cell type. The legend entries are captured for the
# standalone legend figure and the legend is then removed from this plot.
fig = sc.pl.umap(
    combined,
    color="cell_type",
    title="Cell type",
    return_fig=True,
)
ct_axes = fig.axes[0]
ct_handles, ct_labels = ct_axes.get_legend_handles_labels()
ct_axes.get_legend().remove()
fig.savefig(f"{PATH}/full/combined_ct.pdf")

In [None]:
# Combined UMAP colored by omics layer. The legend entries are captured for
# the standalone legend figure and the legend is then removed from this plot.
fig = sc.pl.umap(
    combined,
    color="domain",
    title="Omics layer",
    return_fig=True,
)
domain_axes = fig.axes[0]
domain_handles, domain_labels = domain_axes.get_legend_handles_labels()
domain_axes.get_legend().remove()
fig.savefig(f"{PATH}/full/combined_domain.pdf")

In [None]:
# Legend-only figure: the axes are hidden, and an invisible rectangle acts as
# the handle for the group titles ("Omics layer", "Cell type") and for the
# blank spacer entry between the two groups.
fig, ax = plt.subplots()
ax.set_visible(False)
placeholder = patches.Rectangle((0, 0), 1, 1, visible=False)
handles = [placeholder] + list(domain_handles) + [placeholder, placeholder] + list(ct_handles)
labels = ["Omics layer"] + list(domain_labels) + ["", "Cell type"] + list(ct_labels)
fig.legend(handles, labels, ncol=5, frameon=False)
fig.savefig(f"{PATH}/full/combined_legend.pdf")

In [None]:
# Copy the joint UMAP coordinates back to each modality, selecting rows of the
# combined object by that modality's observation names.
for adata in (rna, atac):
    adata.obsm["X_glue_umap"] = combined[adata.obs_names, :].obsm["X_umap"]

In [None]:
# Give each modality the same cell-type categories (and order) as the combined
# object, so the shared "cell_type_colors" palette copied afterwards lines up
# with the categories.
# `set_categories(..., inplace=True)` was deprecated in pandas 1.3 and removed
# in pandas 2.0; assigning the returned Series works on every pandas version.
combined_categories = combined.obs["cell_type"].cat.categories
rna.obs["cell_type"] = rna.obs["cell_type"].cat.set_categories(combined_categories)
atac.obs["cell_type"] = atac.obs["cell_type"].cat.set_categories(combined_categories)

In [None]:
# Reuse the combined plot's cell-type palette in both per-modality objects.
for adata in (rna, atac):
    adata.uns["cell_type_colors"] = combined.uns["cell_type_colors"]

In [None]:
# RNA-only view of the joint UMAP, cell-type labels drawn on the data.
fig = sc.pl.embedding(
    rna, "X_glue_umap",
    color="cell_type",
    legend_loc="on data",
    legend_fontsize=4,
    legend_fontoutline=0.5,
    title="scRNA-seq cell type",
    return_fig=True,
)
rna_axes = fig.axes[0]
rna_axes.set_xlabel("UMAP1")
rna_axes.set_ylabel("UMAP2")
fig.savefig(f"{PATH}/full/rna_ct.pdf")

In [None]:
# ATAC-only view of the joint UMAP, cell-type labels drawn on the data.
fig = sc.pl.embedding(
    atac, "X_glue_umap",
    color="cell_type",
    legend_loc="on data",
    legend_fontsize=4,
    legend_fontoutline=0.5,
    title="scATAC-seq cell type",
    return_fig=True,
)
atac_axes = fig.axes[0]
atac_axes.set_xlabel("UMAP1")
atac_axes.set_ylabel("UMAP2")
fig.savefig(f"{PATH}/full/atac_ct.pdf")

## Save results

In [None]:
# Write the annotated full-resolution objects as gzip-compressed h5ad files.
for adata, out_file in (
    (rna, f"{PATH}/full/rna.h5ad"),
    (atac, f"{PATH}/full/atac.h5ad"),
    (combined, f"{PATH}/full/combined.h5ad"),
):
    adata.write(out_file, compression="gzip")