In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from matplotlib.colors import to_hex
from scarches.models.scpoli import scPoli
from sklearn.metrics import classification_report

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2


# Load adata

In [None]:
# --- Run configuration ------------------------------------------------------
# prep_HVGs=True  -> load the full object; the HVG-preparation cell below then
#                    runs and writes "<adata_path>.HVGS".
# prep_HVGs=False -> load that previously written "<adata_path>.HVGS" instead.
prep_HVGs = True
hvg_number=6_000          # n_top_genes for sc.pp.highly_variable_genes (reassigned by the HVG cell)
N_NEIGHBOR=30             # n_neighbors for sc.pp.neighbors on the scPoli latent space
MIN_DIST=0.1              # min_dist for sc.tl.umap
SET_TARGET_MISSING=True   # mask some reference labels and blank all query labels before training

ETA_QUERY=10              # `eta` passed to scpoli_query.train
ETA_TRAIN=5               # `eta` passed to model.train
MAX_EPOCHS_QUERY=80
MAX_EPOCHS_TRAIN=80

adata_path='/nfs/team298/ls34/disease_atlas/mrvi/adata_scvi5_lesional_plus_nonlesional_novascmural_noHS.h5ad'
# Truthiness test (not `== False`) so this branch is always the exact complement
# of the `if prep_HVGs:` check in the HVG-preparation cell; with `== False` a
# value such as None would skip HVG preparation and still load the raw file.
if not prep_HVGs:
    adata_path = adata_path + ".HVGS"
adata=sc.read_h5ad(adata_path)
adata

In [None]:
def apply_qc_thresholds(adata, MIN_N_GENES, MAX_TOTAL_COUNT, MAX_PCT_MT, label, MIN_TOTAL_COUNT=0,):
    """
    Flag gene families in ``adata.var`` and store a per-cell QC verdict in ``adata.obs[label]``.

    ``adata`` is modified in place; nothing is returned.

    * ``adata.var`` gains boolean columns ``mt``, ``ribo``, ``hb`` and ``cc_fetal``.
    * ``sc.pp.calculate_qc_metrics`` is run with ``qc_vars=["mt", "ribo"]``.
    * ``adata.obs[label]`` becomes a categorical holding the FIRST rule that
      matches, checked in this order: ``'Low_nFeature'``, ``'High_MT'``,
      ``'High total count'``, ``'Low total count'``, then ``'Pass_<suffix>'``
      where ``<suffix>`` is the text after the last ``_`` in ``label``.
      Cells matching no rule are labelled ``'Unclassified'``.

    Parameters
    ----------
    adata : AnnData to annotate.
    MIN_N_GENES : cells with ``n_genes_by_counts`` below this fail.
    MAX_TOTAL_COUNT : cells with ``total_counts`` above this fail.
    MAX_PCT_MT : cells with ``pct_counts_mt`` above this fail.
    label : name of the ``adata.obs`` column to write.
    MIN_TOTAL_COUNT : cells with ``total_counts`` below this fail (default 0).
    """
    ## Cell cycle gene list
    cc_genes_csv = pd.read_csv(
        "/lustre/scratch126/cellgen/team298/sko_expimap_2023/pan_fetal_cc_genes.csv",
        names=["ind", "gene_ids"],
        skiprows=1,
    )["gene_ids"].tolist()

    # Mark MT/ribo/Hb/cell cycle genes
    adata.var['mt'] = adata.var_names.str.startswith('MT-')
    adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
    adata.var["hb"] = adata.var_names.str.contains(("^HB[^(P)]"))
    adata.var["cc_fetal"] = adata.var_names.isin(cc_genes_csv)

    # Calculate QC metrics
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt", "ribo"], inplace=True, log1p=False)

    obs = adata.obs
    low_features = obs['n_genes_by_counts'] < MIN_N_GENES
    high_mt = obs['pct_counts_mt'] > MAX_PCT_MT
    high_counts = obs['total_counts'] > MAX_TOTAL_COUNT
    low_counts = obs['total_counts'] < MIN_TOTAL_COUNT
    passes = (
        (obs['pct_counts_mt'] <= MAX_PCT_MT)
        & (obs['n_genes_by_counts'] >= MIN_N_GENES)
        & (obs['total_counts'] <= MAX_TOTAL_COUNT)
        & (obs['total_counts'] >= MIN_TOTAL_COUNT)
    )
    conditions = [low_features, high_mt, high_counts, low_counts, passes]

    label_suffix = label.split("_")[-1]
    print(label_suffix)
    pass_name = "Pass_" + label_suffix
    values = ['Low_nFeature', 'High_MT', 'High total count', 'Low total count', pass_name]

    # np.select's implicit default is the integer 0, which cannot be combined
    # with string choices on recent NumPy; give it an explicit string instead.
    # It is only used for cells matching none of the rules (e.g. NaN metrics).
    adata.obs[label] = np.select(conditions, values, default="Unclassified")
    adata.obs[label] = adata.obs[label].astype('category')

    print(adata.obs[label].value_counts())


if prep_HVGs:
    apply_qc_thresholds(adata, MIN_N_GENES=500, MAX_TOTAL_COUNT=300_000, MAX_PCT_MT=20,  MIN_TOTAL_COUNT=2000, label="QC_hi")

    HVG_BATCH_KEY = "DonorID"
    HVG_BATCH_MINIMUM=80

    original_hvg = str(hvg_number) + "select" + str(HVG_BATCH_MINIMUM)

    # Drop cell-cycle, haemoglobin, mitochondrial and ribosomal genes before HVG selection
    mask_to_exclude = (adata.var.cc_fetal |
                       adata.var.hb |
                       adata.var.mt |
                       adata.var.ribo
                       )
    mask_to_include = ~mask_to_exclude
    # Explicit copy: highly_variable_genes writes columns into `.var`, and doing
    # that on a view makes AnnData copy implicitly (with a warning).
    adata = adata[:, mask_to_include].copy()
    # NOTE(review): no `flavor` is passed, so scanpy's default flavour is used,
    # while `check_values` is documented as a seurat_v3 option. Confirm which
    # flavour was intended and that adata.X holds the values it expects.
    sc.pp.highly_variable_genes(adata,
                                n_top_genes=hvg_number,
                                subset=False,
                                batch_key=HVG_BATCH_KEY,
                                check_values=False,
                                )
    var_genes_all = adata.var.highly_variable
    var_genes_batch = adata.var.highly_variable_nbatches > HVG_BATCH_MINIMUM
    # Genes flagged as variable in at least HVG_BATCH_MINIMUM batches
    var_select = adata.var.highly_variable_nbatches >= HVG_BATCH_MINIMUM
    var_genes = var_select.index[var_select]
    # NOTE: this reassigns the configured `hvg_number`; later cells that read it
    # (e.g. the model save path) see the count below, not the configured value.
    hvg_number = len(var_genes)
    print(f"selected {hvg_number} HVGs!")

    # Reload the unfiltered object and carry the HVG annotations over to it.
    hvg_output_path = adata_path + ".HVGS"
    adata2=sc.read_h5ad(adata_path)
    adata2.obs["DonorID"]=adata2.obs["sample_id"]
    # Genes removed by the exclusion mask have no entry and therefore get NaN here ...
    adata2.var['highly_variable_nbatches'] = adata2.var.index.map(
        adata.var['highly_variable_nbatches'].to_dict()
    )
    # ... and False here. `isin` yields a plain boolean array.
    # NOTE(review): this stores scanpy's `highly_variable` flag, not the
    # `var_select` batch-threshold selection counted above — confirm intent.
    hvg_names = adata.var_names[adata.var['highly_variable'].to_numpy()]
    adata2.var['highly_variable'] = adata2.var_names.isin(hvg_names)

    adata2.write(hvg_output_path)
    print(f"Saved HVGs -> {hvg_output_path}")

    # Rebind rather than copy: the gene-filtered object is no longer needed.
    adata = adata2
    del adata2
else:
    print("Skipping HVGs")

# Build condition labels and split reference / query

In [None]:
# String forms of the obs columns that are combined into condition labels
site_status = adata.obs["Site_status_binary"]
site_label = site_status.astype(str)
dataset_label = adata.obs["dataset_id"].astype(str)
patient_label = adata.obs["Patient_status"].astype(str)

# All non-lesional cells share one label; every other cell gets
# "<Patient_status>_<Site_status_binary>_<dataset_id>".
adata.obs["dataset_and_status"] = np.where(
    site_status == "Nonlesional",
    "Nonlesional",
    patient_label + "_" + site_label + "_" + dataset_label,
)

# "<Site_status_binary>_<dataset_id>" for every cell
adata.obs["dataset_and_status2"] = site_label + "_" + dataset_label



In [None]:
# Expose the `corefb_names` annotation under the column name that is used as
# the cell-type key in the cells below.
adata.obs['lvl3_annotation']=adata.obs['corefb_names']

In [None]:
condition_key = 'dataset_and_status2'
cell_type_key = 'lvl3_annotation'

# Read the conditions through `condition_key` (instead of repeating the column
# name) so the reference/query lists always describe the column the model is
# conditioned on; the unique values are computed once.
all_conditions = adata.obs[condition_key].unique()

# Substring tests are case-sensitive: "Nonlesional_<dataset>" contains
# "Nonlesional" but not "Lesional_", so the two lists do not overlap.
reference = [x for x in all_conditions if "Nonlesional" in x]

query = [x for x in all_conditions if "Lesional_" in x]

In [None]:
# Boolean masks over cells, both read through `condition_key` so the split
# cannot drift from the column the model is conditioned on. Each mask is
# computed once and reused.
is_query = adata.obs[condition_key].isin(query)
is_reference = adata.obs[condition_key].isin(reference)

# Store the query flag on the full object (as a categorical) before subsetting,
# so both subsets carry it.
adata.obs['query'] = is_query.astype('category')

source_adata = adata[is_reference].copy()   # reference cells
target_adata = adata[is_query].copy()       # query cells



In [None]:
if SET_TARGET_MISSING:
    def _mask_reference_label(name):
        """Return '.' for labels containing 'activated' (any case) or starting with 'F6'; otherwise return the label unchanged."""
        hidden = 'activated' in name.lower() or name.startswith('F6')
        return '.' if hidden else name

    # Reference: collapse the selected labels into the placeholder '.'
    source_adata.obs['lvl3_annotation'] = source_adata.obs['lvl3_annotation'].apply(_mask_reference_label)
    # Query: every cell gets the same placeholder label
    target_adata.obs['lvl3_annotation']="Missing_lesional"

In [None]:
# Early-stopping settings; they are handed to the reference `model.train` call.
early_stopping_kwargs = dict(
    early_stopping_metric="val_prototype_loss",
    mode="min",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

# Reference model, built on the reference (source) cells only
model = scPoli(
    recon_loss='nb',
    embedding_dims=10,
    cell_type_keys=cell_type_key,
    condition_keys=condition_key,
    adata=source_adata,
)


In [None]:
print("Training data:", source_adata.shape)
model.train(
    n_epochs=MAX_EPOCHS_TRAIN,
    # Epoch counts are whole numbers. `MAX_EPOCHS_TRAIN*0.8` is a float and is
    # non-integral for most epoch budgets, so truncate it to an int explicitly.
    pretraining_epochs=int(MAX_EPOCHS_TRAIN * 0.8),
    early_stopping_kwargs=early_stopping_kwargs,
    eta=ETA_TRAIN,
    # NOTE(review): confirm that this scPoli version honours `layer` and
    # `accelerator`; the object produced by the HVG-preparation cell carries no
    # "counts" layer.
    layer="counts",
    accelerator='gpu'
)

# Query

In [None]:
# Build the query model from the trained reference `model` and the query cells.
scpoli_query = scPoli.load_query_data(
    adata=target_adata,
    reference_model=model,
    # Empty list: presumably no query cell is treated as labelled — confirm
    # against the scPoli documentation.
    labeled_indices=[],
)

In [None]:
scpoli_query.train(
    n_epochs=MAX_EPOCHS_QUERY,
    # Epoch counts are whole numbers. `MAX_EPOCHS_QUERY*0.8` is a float and is
    # non-integral for most epoch budgets, so truncate it to an int explicitly.
    pretraining_epochs=int(MAX_EPOCHS_QUERY * 0.8),
    eta=ETA_QUERY,
    # NOTE(review): confirm that this scPoli version honours `layer` and
    # `accelerator`.
    layer="counts",
    accelerator='gpu'
)


# Transfer labels

In [None]:
# Cast the query expression matrix to float32 before classification.
target_adata.X = target_adata.X.astype('float32')
# `results_dict` is indexed below by cell-type key and then by "preds" / "uncert".
results_dict = scpoli_query.classify(target_adata, scale_uncertainties=True)


In [None]:
# `cell_type_key` is a single obs column name here. Iterating over
# `range(len(cell_type_key))` walks the characters of that string and rebuilds
# the same report once per character, so normalise to a list of keys instead
# (a list/tuple of keys is accepted too).
eval_keys = [cell_type_key] if isinstance(cell_type_key, str) else list(cell_type_key)

# NOTE: when SET_TARGET_MISSING is True, every label in `target_adata` was
# overwritten with "Missing_lesional" earlier, so the report then compares the
# predictions against that placeholder rather than against real annotations.
classification_reports = {}
for key in eval_keys:
    preds = results_dict[key]["preds"]
    classification_df = pd.DataFrame(
        classification_report(
            y_true=target_adata.obs[key],
            y_pred=preds,
            output_dict=True,
        )
    ).transpose()
    classification_reports[key] = classification_df
    print(classification_df)

In [None]:
#get latent representation of reference data
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
    source_adata,
    mean=True
)

adata_latent_source = sc.AnnData(data_latent_source)
adata_latent_source.obs = source_adata.obs.copy()

data_latent= scpoli_query.get_latent(
    target_adata,
    mean=True
)


In [None]:
# Persist the query model together with its AnnData, overwriting any earlier save.
# The directory name embeds `hvg_number`: that is the configured value when
# prep_HVGs is False, but the HVG-preparation cell reassigns it to the number of
# batch-filtered genes, so the two paths produce differently named directories.
scpoli_query.save(f'/nfs/team298/ls34/fibroblast_atlas/fig1/model_scpoli_allfibroblasts_{hvg_number}_2',
           save_anndata=True, 
           overwrite=True)



In [None]:
# Wrap the query latent matrix in an AnnData and attach a copy of the query obs.
# At this point `target_adata.obs` still holds the labels as modified for
# training (all "Missing_lesional" when SET_TARGET_MISSING is True).
adata_latent = sc.AnnData(data_latent)
adata_latent.obs = target_adata.obs.copy()


In [None]:
# Recompute the "<Site_status_binary>_<dataset_id>" condition label on the full object
site_and_dataset = adata.obs["Site_status_binary"].astype(str) + "_" + adata.obs["dataset_id"].astype(str)
adata.obs["dataset_and_status2"] = site_and_dataset

# Re-extract the query cells from `adata`; this replaces the `target_adata`
# whose labels were overwritten for training.
in_query = adata.obs["dataset_and_status2"].isin(query)
target_adata = adata[in_query].copy()
target_adata.X = target_adata.X.astype('float32')


In [None]:
# Predictions and uncertainties for the lvl3 annotation key
lvl3_results = results_dict['lvl3_annotation']
target_adata.obs['cell_type_pred'] = lvl3_results['preds'].tolist()
target_adata.obs['cell_type_uncert'] = lvl3_results['uncert'].tolist()

# True where the transferred label equals the cell's own annotation
predicted = target_adata.obs['cell_type_pred']
annotated = target_adata.obs['lvl3_annotation']
target_adata.obs['classifier_outcome'] = (predicted == annotated)

In [None]:
# Attach the query latent matrix as an embedding; the neighbour graph below is
# built from it via use_rep='X_scpoli'.
target_adata.obsm["X_scpoli"] = data_latent


In [None]:
# Neighbour graph on the scPoli embedding, stored under a key that records k
neighbor_id = f"neighbor_{N_NEIGHBOR}"
sc.pp.neighbors(
    target_adata,
    n_neighbors=N_NEIGHBOR,
    use_rep='X_scpoli',
    metric="euclidean",
    key_added=neighbor_id,
)
print("neighbours done")

# UMAP computed from that specific neighbour graph
sc.tl.umap(target_adata, neighbors_key=neighbor_id, min_dist=MIN_DIST)
print("UMAP done")


In [None]:


# One sequential colormap per fibroblast family, sampled at 10 shades each.
colors_f1 = plt.cm.YlOrBr(np.linspace(0, 1, 10))
colors_f2 = plt.cm.Blues(np.linspace(0.2, 1, 10))
colors_f3 = plt.cm.Reds(np.linspace(0.1, 1, 10))
colors_f4 = plt.cm.Greens(np.linspace(0.2, 1, 10))
colors_f5 = plt.cm.Purples(np.linspace(0.5, 1.0, 10))  # Brighter purple palette
colors_other = plt.cm.Greys(np.linspace(0.2, 1, 10))

# The two "F1 secretory" labels share the first YlOrBr shade; the remaining
# F1 labels draw from the other nine.
f1_shared_color = colors_f1[0]
f1_shared_labels = ("F1*: Secretory", "F1: Secretory superficial")

family_shades = {
    "F1": colors_f1[1:],
    "F2": colors_f2,
    "F3": colors_f3,
    "F4": colors_f4,
    "F5": colors_f5,
}
# Label prefix -> family whose shades it draws from, in matching order.
prefix_to_family = (
    ("F1", "F1"),
    ("F2", "F2"),
    ("F3", "F3"),
    ("Peric", "F3"),
    ("Vasc", "F3"),
    ("F4", "F4"),
    ("F5", "F5"),
    ("UNCERTAIN", "F5"),
)
shades_used = {family: 0 for family in family_shades}

custom_colors = {}
target_adata.obs['cell_type_pred'] = target_adata.obs['cell_type_pred'].astype('category')
pred_categories = target_adata.obs["cell_type_pred"].cat.categories

for category in pred_categories:
    if category in f1_shared_labels:
        color = f1_shared_color
    else:
        family = next(
            (fam for prefix, fam in prefix_to_family if category.startswith(prefix)),
            None,
        )
        if family is None:
            # Every label outside the known families shares the first grey shade.
            color = colors_other[0]
        else:
            shades = family_shades[family]
            # Wrap around instead of indexing past the end, which raised
            # IndexError once a family had more labels than shades.
            color = shades[shades_used[family] % len(shades)]
            shades_used[family] += 1
    # Store hex strings so the palette is one homogeneous list rather than a
    # mix of RGBA arrays and the hex string assigned below.
    custom_colors[category] = to_hex(color)

# Fixed red for the uncertain class (overrides the purple shade drawn above).
custom_colors['UNCERTAIN_CELLTYPE'] = "#EE4B2B"
target_adata.uns['cell_type_pred_colors'] = [custom_colors[cat] for cat in pred_categories]

sc.settings.figdir="/lustre/scratch126/cellgen/team298/adult_skin_visium/"

# Options shared by both UMAP panels
umap_common = dict(
    color='cell_type_pred',
    show=False,
    frameon=False,
    cmap='Reds',
    vmax=0.5,
    s=5,
)

# Panel 1: category names drawn on the embedding
sc.pl.umap(
    target_adata,
    legend_loc="on data",
    legend_fontsize=4,
    legend_fontoutline=2,
    **umap_common,
)

# Panel 2: same plot with the default legend placement
sc.pl.umap(target_adata, **umap_common)