In [None]:
import os
import torch
import numpy as np

import scanpy as sc
from anndata import AnnData

# Notebook-level settings. None of these names is read by the cells visible
# in this notebook; they are kept in case other code relies on them.
result_dir = "../edges/"  # relative directory path
raw_type = True
use_edge = True

In [None]:
def plot_all(adata,title,color='cell_type',select_type_number=None):
    """Embed ``adata`` with UMAP and plot it coloured by ``color``.

    Pipeline: ``sc.pp.scale`` -> ``sc.tl.pca`` -> ``sc.pp.neighbors`` ->
    ``sc.tl.umap`` -> ``sc.pl.umap``. The embedding is recomputed on every
    call, even when the same object was embedded by an earlier call.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obs['cell_type']`` and ``obs[color]``.
    title : str
        Plot title; ``title + ".pdf"`` is passed as ``save=`` to scanpy.
    color : str
        Name of the ``obs`` column used to colour the points. When it is not
        ``'cell_type'``, one additional UMAP per distinct ``cell_type`` value
        is drawn and saved with ``"_<cell_type>"`` appended to the title.
    select_type_number : int or None
        If given, keep only rows whose ``cell_type`` is among the
        ``select_type_number`` most frequent values.

    Returns
    -------
    AnnData
        The object that was embedded. Without filtering this is ``adata``
        itself, modified in place (``X`` is scaled and PCA / neighbours /
        UMAP results are stored on it). With filtering it is an independent
        copy of the selected rows.
    """
    cell_types = adata.obs['cell_type']
    if select_type_number is not None:
        labels, counts = np.unique(cell_types, return_counts=True)
        select_type = labels[np.argsort(-counts)[:select_type_number]]
        print("selecting edge types:", select_type)
        # Vectorised membership test. The previous per-row `obs[...][i]`
        # lookup used an integer key on a label-indexed Series, which pandas
        # has deprecated, and was a slow Python-level loop.
        keep = cell_types.isin(select_type).to_numpy()
        # Materialise the subset explicitly: the steps below write into the
        # object, which should not be a view of the caller's data.
        adata_filtered = adata[keep, :].copy()
    else:
        adata_filtered = adata

    sc.pp.scale(adata_filtered)
    # PCA needs n_comps < min(n_obs, n_vars); clamp so small inputs do not fail.
    n_comps = min(50, min(adata_filtered.shape) - 1)
    sc.tl.pca(adata_filtered, n_comps=n_comps)
    sc.pp.neighbors(adata_filtered)  # Compute the neighborhood graph
    sc.tl.umap(adata_filtered)  # Compute UMAP
    # Plot UMAP
    sc.pl.umap(adata_filtered,title=title,color=color, show=True, save=title+".pdf")
    if color!="cell_type":
        # One extra figure per cell type, reusing the embedding computed above.
        filtered_types = adata_filtered.obs['cell_type']
        for cell_typei in np.unique(filtered_types):
            # str() keeps title/filename construction working for non-string labels.
            sub_title = title+"_"+str(cell_typei)
            adatai = adata_filtered[(filtered_types==cell_typei).to_numpy()]
            sc.pl.umap(adatai,title=sub_title,color=color, show=True, save=sub_title+".pdf")
    return adata_filtered

In [None]:
# Sample names. The position of a name in this list is what `sample_dict`
# (built at the end of this cell) uses as the batch code for that sample.
samples = [
    "Lung6",
    "Lung5_Rep1",
    "Lung5_Rep3",
    "Lung5_Rep2",
    "Lung9_Rep1",
    "Lung9_Rep2",
    "Lung12",
    "Lung13",
]

# Sex, keyed by sample name.
sex = {
    "Lung5_Rep1": "F",
    "Lung5_Rep2": "F",
    "Lung5_Rep3": "F",
    "Lung6": "M",
    "Lung9_Rep1": "F",
    "Lung9_Rep2": "F",
    "Lung12": "F",
    "Lung13": "M",
}

# Histological diagnosis, keyed by sample name.
histological_diagnosis = {
    "Lung5_Rep1": "adenocarcinoma",
    "Lung5_Rep2": "adenocarcinoma",
    "Lung5_Rep3": "adenocarcinoma",
    "Lung6": "squamous cell carcinoma",
    "Lung9_Rep1": "adenocarcinoma",
    "Lung9_Rep2": "adenocarcinoma",
    "Lung12": "adenocarcinoma",
    "Lung13": "adenocarcinoma",
}

# Grade, keyed by sample name.
grade = {
    "Lung5_Rep1": "G1",
    "Lung5_Rep2": "G1",
    "Lung5_Rep3": "G1",
    "Lung6": "G2",
    "Lung9_Rep1": "G3",
    "Lung9_Rep2": "G3",
    "Lung12": "G3",
    "Lung13": "G1",
}

# T component of TNM, keyed by sample name.
t_classification = {
    "Lung5_Rep1": "T2a",
    "Lung5_Rep2": "T2a",
    "Lung5_Rep3": "T2a",
    "Lung6": "T2b",
    "Lung9_Rep1": "T3",
    "Lung9_Rep2": "T3",
    "Lung12": "T4",
    "Lung13": "T3",
}

# N component of TNM, keyed by sample name.
n_classification = {
    "Lung5_Rep1": "N2",
    "Lung5_Rep2": "N2",
    "Lung5_Rep3": "N2",
    "Lung6": "N2",
    "Lung9_Rep1": "N1",
    "Lung9_Rep2": "N1",
    "Lung12": "N0",
    "Lung13": "N0",
}

# M component of TNM, keyed by sample name. Every sample has the same value,
# and this table is not registered in `meta` below.
m_classification = {
    "Lung5_Rep1": "M0",
    "Lung5_Rep2": "M0",
    "Lung5_Rep3": "M0",
    "Lung6": "M0",
    "Lung9_Rep1": "M0",
    "Lung9_Rep2": "M0",
    "Lung12": "M0",
    "Lung13": "M0",
}

# Stage, keyed by sample name.
stage = {
    "Lung5_Rep1": "IIIA",
    "Lung5_Rep2": "IIIA",
    "Lung5_Rep3": "IIIA",
    "Lung6": "IIIA",
    "Lung9_Rep1": "IIIA",
    "Lung9_Rep2": "IIIA",
    "Lung12": "IIIA",
    "Lung13": "IIB",
}

# Tables that become obs columns, keyed by the column name to create.
# Key order here is the order in which the plotting cells iterate.
meta = {
    "sex": sex,
    "histological_diagnosis": histological_diagnosis,
    "grade": grade,
    "t_classification": t_classification,
    "n_classification": n_classification,
    "stage": stage,
}

# Batch code ("0", "1", ...) -> sample name, following the order of `samples`.
sample_dict = {str(position): name for position, name in enumerate(samples)}
cnt = 0  # not read anywhere in the visible cells; retained as-is

def add_meta(adata):
    """Attach the per-sample metadata tables to ``adata.obs``.

    Each row's ``obs["batch"]`` value is converted to ``str`` and looked up in
    the module-level ``sample_dict`` to obtain a sample name. For every key in
    the module-level ``meta`` dict, a column of that name is then written to
    ``adata.obs`` holding the value recorded for the row's sample.

    Parameters
    ----------
    adata : AnnData
        Must have an ``obs["batch"]`` column whose values, as strings, are
        keys of ``sample_dict``.

    Returns
    -------
    AnnData
        The same ``adata`` object, modified in place.

    Raises
    ------
    KeyError
        If some ``obs["batch"]`` value has no entry in ``sample_dict``.
    """
    # Vectorised lookup. The previous per-row `obs["batch"][j]` relied on an
    # integer key being treated as a position on a label-indexed Series
    # (deprecated in pandas) and cost one Python-level lookup per row per key.
    batch_codes = adata.obs["batch"].astype(str)
    sample_names = batch_codes.map(sample_dict)

    unmatched = sample_names.isna()
    if unmatched.any():
        # Fail loudly rather than writing NaN metadata for unknown batches.
        raise KeyError(
            "batch values missing from sample_dict: "
            + repr(sorted(batch_codes[unmatched].unique()))
        )

    for metai, per_sample in meta.items():
        # Assign plain values so the result is positional, like the original list.
        adata.obs[metai] = sample_names.map(per_sample).to_numpy()
    return adata

In [None]:
# Read the first dataset, then print its summary before and after the
# per-sample metadata columns are added.
adata20 = sc.read_h5ad("merged_adata_filtered20_3.h5ad")
print(adata20)
adata20 = add_meta(adata20)
print(adata20)

In [None]:
# One plot_all call (and saved figure set) per obs column: batch, cell type,
# then every metadata column registered in `meta`.
to_draw = ["batch", "cell_type", *meta]
for obs_key in to_draw:
    plot_all(adata20, f"20_{obs_key}", color=obs_key, select_type_number=None)

In [None]:
# Same workflow for the second dataset: load, annotate with metadata, then
# plot once per obs column of interest.
adata10 = sc.read_h5ad("merged_adata_filtered10_3.h5ad")
print(adata10)
adata10 = add_meta(adata10)
print(adata10)

to_draw = ["batch", "cell_type", *meta]
for obs_key in to_draw:
    plot_all(adata10, f"10_{obs_key}", color=obs_key, select_type_number=None)

In [None]:
# Same workflow for the third dataset: load, annotate with metadata, then
# plot once per obs column of interest.
adata = sc.read_h5ad("merged_adata_filtered_all_3.h5ad")
print(adata)
adata = add_meta(adata)
print(adata)

to_draw = ["batch", "cell_type", *meta]
for obs_key in to_draw:
    plot_all(adata, f"all_{obs_key}", color=obs_key, select_type_number=None)