In [1]:
import os
import torch
import numpy as np

import scanpy as sc
from anndata import AnnData

# Notebook-level configuration.
result_dir = "../edges/"  # directory string for the edge data (not referenced by the cells visible here)
# NOTE(review): the two switches below are not read anywhere in the visible
# cells -- presumably consumed further down or left over; confirm before removing.
raw_type = True
use_edge = True

In [2]:
def plot_all(adata, title, color='cell_type', select_type_number=None, min_cells=100):
    """Run scale -> PCA -> neighbours -> UMAP on ``adata`` and save the UMAP plot(s).

    Parameters
    ----------
    adata : AnnData
        Object to embed. ``adata.obs['cell_type']`` is required when
        ``select_type_number`` is given or when ``color`` is not ``"cell_type"``.
        When ``select_type_number`` is None the scanpy steps are applied to
        this very object (it is not copied), so the caller's data is modified.
    title : str
        Plot title; also used to build the ``save`` suffix handed to scanpy
        (``title + ".pdf"`` with "/" replaced by "_").
    color : str
        ``obs`` key used to colour the points.
    select_type_number : int or None
        If given, keep only observations whose ``cell_type`` is among the
        ``select_type_number`` most frequent labels.
    min_cells : int
        Minimum number of observations required to produce a plot
        (default 100, the previously hard-coded threshold).

    Returns
    -------
    AnnData or int
        The processed object, or ``0`` when fewer than ``min_cells``
        observations remain (nothing is computed or plotted in that case).
    """
    if select_type_number is not None:
        cell_types = adata.obs['cell_type']
        labels, counts = np.unique(cell_types, return_counts=True)
        # Most frequent labels first; keep the top `select_type_number`.
        select_type = labels[np.argsort(-counts)[:select_type_number]]
        print("selecting edge types:", select_type)
        # Vectorised membership test. The previous per-row `series[i]` lookup
        # used an integer key on a label-indexed Series, which pandas has
        # deprecated, and was a slow Python-level loop.
        keep = np.asarray(cell_types.isin(select_type))
        adata_filtered = adata[keep, :]
    else:
        adata_filtered = adata

    if adata_filtered.shape[0] < min_cells:
        print(title, "no enough number")
        return 0

    # Materialise a subset view so the in-place scanpy steps below operate on
    # an independent object rather than on a view of the caller's data.
    if getattr(adata_filtered, "is_view", False):
        adata_filtered = adata_filtered.copy()

    sc.pp.scale(adata_filtered)
    sc.tl.pca(adata_filtered, n_comps=50)
    sc.pp.neighbors(adata_filtered)  # Compute the neighborhood graph
    sc.tl.umap(adata_filtered)  # Compute UMAP

    # Plot UMAP. "/" is replaced in every file-name suffix (previously only
    # the per-cell-type plots were sanitised).
    sc.pl.umap(adata_filtered, title=title, color=color, show=True,
               save=(title + ".pdf").replace("/", "_"))

    if color != "cell_type":
        # One extra plot per cell type, coloured by the same key.
        for cell_typei in np.unique(adata_filtered.obs['cell_type']):
            adatai = adata_filtered[adata_filtered.obs['cell_type'] == cell_typei]
            sub_title = f"{title}_{cell_typei}"
            sc.pl.umap(adatai, title=sub_title, color=color, show=True,
                       save=(sub_title + ".pdf").replace("/", "_"))
    return adata_filtered

In [3]:
import pandas as pd
# Metadata table read from ./metadata.csv; its "Sample" column becomes the row index.
meta_df = pd.read_csv(
    "./metadata.csv",
    index_col="Sample",
)

# Candidate sample identifiers. The list is narrowed further down to the
# entries that also appear in the metadata index.
samples=['H20.33.004.Cx26.MTG.02.007.1.02.04', 'H20.33.004.Cx26.MTG.02.007.1.01.04', 'H20.33.004.Cx26.MTG.02.007.1.01.05', 'H21.33.011.Cx26.MTG.02.007.3.01.06', 'H21.33.016.Cx26.MTG.02.007.3.01.01', 'H21.33.028.CX28.MTG.02.007.1.01.01', 'H21.33.038.Cx20.MTG.02.007.3.01.02', 'H21.33.040.Cx22.MTG.02.007.3.03.03', 'H21.33.022.Cx26.MTG.02.007.2.M.02', 'H21.33.038.Cx20.MTG.02.007.3.01.04', 'H21.33.005.Cx18.MTG.02.007.02.04', 'H20.33.012.Cx24.MTG.02.007.1.01.01', 'H20.33.012.Cx24.MTG.02.007.1.03.03', 'H21.33.023.Cx26.MTG.02.007.1.03.01', 'H20.33.025.Cx28.MTG.02.007.1.01.02', 'H21.33.012.Cx26.MTG.02.007.1.01.06', 'H20.33.025.Cx28.MTG.02.007.1.01.04', 'H20.33.044.Cx26.MTG.02.007.1.01.04', 'H21.33.023.Cx26.MTG.02.007.1.03.05', 'H20.33.004.Cx26.MTG.02.007.1.02.03', 'H21.33.016.Cx26.MTG.02.007.3.01.02', 'H20.33.040.Cx25.MTG.02.007.1.01.03', 'H21.33.001.Cx22.MTG.02.007.1.01.04', 'H20.33.012.Cx24.MTG.02.007.1.03.02', 'H21.33.015.Cx26.MTG.02.007.1.2', 'H21.33.022.Cx26.MTG.02.007.2.M.03', 'H21.33.005.Cx18.MTG.02.007.02.03', 'H21.33.032.CX24.MTG.02.007.1.01.04', 'H21.33.022.Cx26.MTG.02.007.2.M.04', 'H21.33.006.Cx28.MTG.02.007.1.01.09.03', 'H21.33.015.Cx26.MTG.02.007.1.0', 'H20.33.035.Cx26.MTG.02.007.1.01.03', 'H20.33.015.Cx24.MTG.02.007.1.03.03', 'H21.33.021.Cx26.MTG.02.007.1.04', 'H21.33.025.CX26.MTG.02.007.4.01.04', 'H20.33.025.Cx28.MTG.02.007.1.01.06', 'H21.33.014.CX26.MTG.02.007.1.02.02', 'H21.33.040.Cx22.MTG.02.007.3.03.01', 'H21.33.016.Cx26.MTG.02.007.3.01.03', 'H21.33.021.Cx26.MTG.02.007.1.06', 'H21.33.013.Cx24.MTG.02.007.1.06', 'H21.33.015.Cx26.MTG.02.007.1.1', 'H20.33.001.CX28.MTG.02.007.1.02.03', 'H21.33.028.Cx28.MTG.02.007.1.02.04', 'H21.33.019.Cx30.MTG.02.007.5.01.02', 'H20.33.044.Cx26.MTG.02.007.1.01.03', 'H21.33.011.Cx26.MTG.02.007.3.01.04', 'H21.33.006.Cx28.MTG.02.007.1.01.09.04', 'H21.33.025.CX26.MTG.02.007.4.01.06', 'H21.33.012.Cx26.MTG.02.007.1.01.05', 'H20.33.015.CX24.MTG.02.007.1.03.01', 'H21.33.019.Cx30.MTG.02.007.5.0', 'H20.33.035.Cx26.MTG.02.007.1.01.04', 
'H21.33.012.Cx26.MTG.02.007.1.01.04', 'H21.33.031.CX24.MTG.02.007.1.01.01', 'H21.33.040.Cx22.MTG.02.007.3.03.04', 'H20.33.015.CX24.MTG.02.007.1.03.02', 'H21.33.028.Cx28.MTG.02.007.1.02.02', 'H21.33.011.Cx26.MTG.02.007.3.01.05', 'H20.33.004.Cx26.MTG.02.007.1.02.02', 'H21.33.023.Cx26.MTG.02.007.1.03.04', 'H21.33.031.CX24.MTG.02.007.1.01.02', 'H20.33.001.CX28.MTG.02.007.1.02.02', 'H21.33.006.Cx28.MTG.02.007.1.01.09.02', 'H20.33.001.Cx28.MTG.02.007.1.01.03', 'H21.33.025.CX26.MTG.02.007.4.01.02', 'H20.33.040.Cx25.MTG.02.007.1.01.04', 'H21.33.001.Cx22.MTG.02.007.1.01.03', 'H20.33.044.Cx26.MTG.02.007.1.01.02']
# Keep only the samples that have a row in the metadata table.
# The filter preserves the order of the `samples` list (and drops repeats).
# The previous set-intersection produced an order that depends on per-process
# string hashing, so the index -> sample mapping built below could change
# between kernel sessions.
# NOTE(review): this mapping is only meaningful if the order here matches the
# order used when the batch labels in the AnnData file were assigned -- confirm
# against the notebook/script that produced the merged file.
_samples_with_meta = set(meta_df.index)
samples = [name for name in dict.fromkeys(samples) if name in _samples_with_meta]
print(len(samples))

# Map the stringified position ("0", "1", ...) to the sample name.
sample_dict = {str(position): name for position, name in enumerate(samples)}

def add_meta(adata, metas=None):
    """Copy per-sample metadata columns from ``meta_df`` onto ``adata.obs``.

    Every observation's ``obs["batch"]`` value is translated to a sample name
    through the module-level ``sample_dict``; that name selects the row of the
    module-level ``meta_df`` whose values are written to the observation.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs`` has a ``"batch"`` column holding keys of
        ``sample_dict``. Modified in place.
    metas : sequence of str, optional
        Columns of ``meta_df`` to copy. Defaults to the ten columns this
        function has always copied.

    Returns
    -------
    AnnData
        The same ``adata`` object, with one new/overwritten ``obs`` column per
        entry of ``metas``.

    Raises
    ------
    KeyError
        If a batch value is not a key of ``sample_dict``, or a sample name or
        column is missing from ``meta_df``.
    """
    if metas is None:
        metas = ['Overall AD neuropathological Change', 'Thal', 'Braak', 'CERAD score', 'Overall CAA Score', 'Highest Lewy Body Disease', 'Atherosclerosis', 'Arteriolosclerosis', 'LATE', 'Cognitive Status']
    metas = list(metas)

    # Iterate over the values directly. The previous `obs["batch"][j]` used an
    # integer key on a label-indexed Series, which pandas has deprecated.
    sample_per_obs = [sample_dict[batch] for batch in adata.obs["batch"]]

    # One label-based selection for all columns instead of one scalar `.loc`
    # lookup per observation and per column.
    # NOTE(review): assumes the meta_df index has no duplicate sample names.
    per_obs_meta = meta_df.loc[sample_per_obs, metas]
    for metai in metas:
        # Assign by position; the selected rows are already in obs order.
        adata.obs[metai] = per_obs_meta[metai].to_numpy()
    return adata

In [None]:
# Load the merged AnnData file, then attach the metadata columns via add_meta.
# The object is printed before and after so the added obs columns are visible.
adata = sc.read_h5ad('../edges/merged_adata_filtered_all_3_softmax.h5ad')
print(adata)
adata = add_meta(adata)
print(adata)

metas = [
    'Overall AD neuropathological Change',
    'Thal',
    'Braak',
    'CERAD score',
    'Overall CAA Score',
    'Highest Lewy Body Disease',
    'Atherosclerosis',
    'Arteriolosclerosis',
    'LATE',
    'Cognitive Status',
]

# Plot order: batch, cell type, then the metadata columns in reverse list order.
to_draw = ["batch", 'cell_type', *reversed(metas)]
for metai in to_draw:
    # One UMAP figure per colouring key, with no cell-type filtering.
    plot_all(adata, f"all_{metai}", color=metai, select_type_number=None)