# Case Study of Cancer Spatial Transcriptome Data Analysis

In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import os
import sklearn
import Spanve
import matplotlib.pyplot as plt
import seaborn as sns
import squidpy as sq
print(sklearn.__version__)

### Quality Control

In [None]:
# Load the 'Visium_Human_Breast_Cancer' sample through scanpy's Visium loader.
adata = sc.datasets.visium_sge('Visium_Human_Breast_Cancer')
adata.var_names_make_unique()

# Mark genes whose name starts with "MT-" and compute QC metrics that include
# this flag as a QC variable; results are written into `adata` in place.
is_mito = adata.var_names.str.startswith("MT-")
adata.var["mt"] = is_mito
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

In [None]:
# Four histograms in one row: total counts (all spots / spots below 10000) and
# detected genes per spot (all spots / spots below 4000).
fig, axs = plt.subplots(1, 4, figsize=(18, 4))
total_counts = adata.obs["total_counts"]
genes_per_spot = adata.obs["n_genes_by_counts"]
sns.histplot(total_counts, kde=False, ax=axs[0])
sns.histplot(total_counts[total_counts < 10000], kde=False, bins=40, ax=axs[1])
sns.histplot(genes_per_spot, kde=False, bins=60, ax=axs[2])
sns.histplot(genes_per_spot[genes_per_spot < 4000], kde=False, bins=60, ax=axs[3])

In [None]:
# Drop spots with fewer than 10 total counts and genes seen in fewer than 5 spots,
# then report the remaining (n_obs, n_vars) shape.
sc.pp.filter_cells(adata, min_counts=10)
sc.pp.filter_genes(adata, min_cells=5)
print(adata.shape)

### Preprocess

In [None]:
# Densify X; calling `.toarray()` implies X is a sparse matrix at this point.
adata.X = adata.X.toarray()
# Store Spanve's preprocessed matrix as the "normalized" layer.
adata.layers["normalized"] = Spanve.adata_preprocess(adata).X
# NOTE(review): this snapshot is taken after `adata_preprocess` has run, so it only
# holds the untouched counts if that helper works on a copy of `adata` — confirm.
adata.layers['counts'] = adata.X.copy()
# NOTE(review): the layer key is spelled 'normlized_counts'; any consumer must use
# the same spelling.
adata.layers['normlized_counts'] =  Spanve.adata_preprocess_int(adata).X

### Run Spanve

In [None]:
# adata.X = adata.layers['normlized_counts']
# Fit a Spanve model on `adata` (using whatever is currently in adata.X).
svmodel = Spanve.Spanve(adata)
svmodel.fit(verbose=True)
# Displays the sum of `rejects` — presumably the number of genes called spatially
# variable; confirm against the Spanve documentation.
svmodel.rejects.sum()

In [None]:
# Run Spanve's graph-based imputation on the "normalized" layer and keep the
# result as a new layer. NOTE(review): the key is spelled 'imputated' and is
# read back under that exact spelling by later cells.
X = adata.layers["normalized"]
newX = svmodel.impute_from_graph(X,verbose=True)
adata.layers['imputated'] = newX

In [None]:
# Persist the fitted model as a pickle; the ./Results directory must already
# exist unless `save` creates it (not visible here — confirm).
svmodel.save('./Results/case.svmodel.pkl',format='pickle')

### Cluster

In [None]:
# Build the squidpy spatial neighbor graph (generic coordinates, no Delaunay
# triangulation); the neighborhood-enrichment cell below relies on it.
sq.gr.spatial_neighbors(adata,coord_type="generic", delaunay=False)

In [None]:
from sklearn.decomposition import PCA

# Reduce the imputed expression of the Spanve-selected genes to 50 components.
# The result gets its own name so `newX` (the imputed matrix) is not overwritten
# and this cell can be re-run safely.
cluster_input = PCA(n_components=50).fit_transform(newX[:,svmodel.rejects])
cluster = Spanve.AutoCluster(init_k=3,max_k=15,criteria = 'bic')
labelx = cluster.fit_predict(cluster_input,verbose=True)
# Every later cell reads the clustering from adata.obs['KMeans'] and addresses
# clusters by string labels ('0', '1', ...), so store the labels as a string
# categorical. Assumes `fit_predict` returns one integer label per spot — confirm.
adata.obs['KMeans'] = pd.Categorical(np.asarray(labelx).astype(str))
cluster.plot_elbow()

In [None]:
# sq.gr.spatial_neighbors(adata,coord_type="generic", delaunay=True)
# Neighborhood enrichment between the clusters stored in adata.obs['KMeans'].
sq.gr.nhood_enrichment(adata, cluster_key="KMeans")

In [None]:
# Left panel: neighborhood-enrichment matrix; right panel: clusters on the tissue.
fig, ax = plt.subplots(1, 2, figsize=(13, 7))
sq.pl.nhood_enrichment(
    adata,
    cluster_key="KMeans",
    figsize=(8, 8),
    title="Neighborhood enrichment adata",
    ax=ax[0],
    cmap = 'gist_earth'
)
# sq.pl.spatial_scatter(adata, color="KMeans", shape=None, ax=ax[1])
sc.pl.spatial(adata, img_key="hires", color="KMeans",ax=ax[1])
# Stand-alone cluster map. `plt.rc_context` is used because the bare name
# `rc_context` is only imported by a later cell, which made this cell fail
# with a NameError when the notebook is run top to bottom.
with plt.rc_context({'figure.figsize': (7, 7)}):
    sc.pl.spatial(adata, img_key="hires", color="KMeans",show=False)
plt.subplots_adjust(left=0.05,right=0.95,bottom=0.05,top=0.95)
# plt.savefig('./subplot1_cluster.pdf')

In [None]:
# Wilcoxon marker-gene test per cluster.
sc.tl.rank_genes_groups(adata, "KMeans", method="wilcoxon")
# Collect the results through the public `sc.get` API (the original reached into
# the private `sc.get.get` submodule). Cluster ids '0'..'6' are hard-coded.
cluster_ids = [str(i) for i in range(7)]
dfgroup = sc.get.rank_genes_groups_df(adata, group=cluster_ids)
# dfgroup.to_csv('./Results/case_markers.csv')

In [None]:
from matplotlib.pyplot import rc_context

# Draw the heatmap from the imputed expression values.
adata.X = adata.layers['imputated']
heatmap_clusters = [str(i) for i in range(7)]
with rc_context({'figure.figsize': (10, 4)}):
    # Top 5 ranked genes for each of the clusters '0'..'6'.
    sc.pl.rank_genes_groups_heatmap(
        adata,
        groups=heatmap_clusters,
        n_genes=5,
        groupby="KMeans",
        cmap='viridis',
        show=False,
    )
plt.subplots_adjust(left=0.05,right=0.95,bottom=0.15,top=0.95)
# plt.savefig('./subplot3_heatmap.pdf')

### Spatial Co-expression

In [None]:
from itertools import combinations
from tqdm import tqdm
import networkx as nx
def lighten_color(color, amount=0.5):
    """Return a lightened RGB tuple for *color*.

    Parameters
    ----------
    color
        A matplotlib named color, or anything `matplotlib.colors.to_rgb`
        accepts (hex string, RGB tuple, ...).
    amount
        The HLS lightness ``l`` is mapped to ``1 - amount * (1 - l)``:
        ``amount=1`` leaves the color unchanged, ``amount=0`` yields white.

    Returns
    -------
    tuple
        ``(r, g, b)`` floats as produced by `colorsys.hls_to_rgb`.
    """
    import colorsys
    import matplotlib.colors as mc

    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named color (or not hashable, e.g. an RGB list): let to_rgb
        # interpret the value directly. Narrowed from a bare `except:` so that
        # unrelated errors are no longer swallowed.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)
def draw_net(df_net,show_label=False,weight_scale=1,**kwargs):
    """Build an undirected weighted graph from an edge table.

    `df_net` must provide the columns 'node1', 'node2' and 'weight'. Each row
    becomes an edge whose 'weight' attribute is the row's weight divided by
    `weight_scale`. Despite its name this function draws nothing: it only
    returns the `networkx.Graph`. `show_label` and `**kwargs` are accepted but
    not used.
    """
    sources = df_net['node1'].values
    targets = df_net['node2'].values
    scaled_weights = df_net['weight'].values / weight_scale

    graph = nx.Graph()
    # Register the (stringified, sorted) node names first so node order is stable.
    graph.add_nodes_from(np.unique(np.concatenate([sources, targets]).astype(str)))
    for u, v in zip(sources, targets):
        graph.add_edge(u, v)
    # Nodes that ended up without any edge are dropped.
    graph.remove_nodes_from(list(nx.isolates(graph)))

    # Attach the weights; for repeated pairs the last row wins.
    for u, v, w in zip(sources, targets, scaled_weights):
        graph[u][v]['weight'] = w

    return graph

#### Markers

In [None]:
def _ten_best_scored(group_df):
    # Rows sorted by ascending score; the last ten are the highest-scoring ones.
    return group_df.sort_values(by='scores').iloc[-10:]


# Ten top-scoring marker genes per cluster, indexed by cluster label.
dfmarkers = dfgroup.groupby('group').apply(_ten_best_scored).reset_index(drop=True).set_index('group')
# Gene name -> cluster label lookup (a gene can appear once per cluster it marks).
marker_groups = dfmarkers.reset_index().set_index('names')['group']

In [None]:
# Spatial co-expression among all pairs of the selected marker genes, computed
# on the normalized expression layer.
adata.X = adata.layers['normalized']
scmodel = Spanve.Spanve(adata)
# search_space = list(combinations(adata.var_names,2))
marker_pairs = list(combinations(dfmarkers['names'].unique().tolist(),2))
# The bare name `spatial_coexp` is never imported in this notebook; call it
# through the package like the other Spanve helpers.
# NOTE(review): assumes the Spanve package exposes `spatial_coexp` at top level — confirm.
Spanve.spatial_coexp(scmodel,search_space = marker_pairs ,verbose=True)

scmodel.fit(verbose=True)
scmodel.rejects.sum()

In [None]:
# One row per tested gene pair; result_df is indexed by "geneA~geneB".
pair_index = scmodel.result_df.index
scvdf = pd.DataFrame(np.array(pair_index.str.split('~').tolist()),columns=['node1','node2'])

# Min-max scale the 'ent' statistic to [0, 1]. The original divided by max
# instead of (max - min), so the scaled weights never reached 1.
ent = scmodel.result_df['ent'].to_numpy(dtype=float)
ent_range = ent.max() - ent.min()
scvdf['weight'] = (ent - ent.min()) / ent_range if ent_range > 0 else 0.0

# Keep only the pairs flagged in the 'rejects' column.
scvdf = scvdf[scmodel.result_df['rejects'].to_numpy(dtype=bool)]

In [None]:
# scvdf = pd.read_csv('./Results/case.scvdf.csv',index_col=0)
# dfmarkers = pd.read_csv('./Results/case_markers.csv',index_col=0)
def _five_heaviest(edges):
    # Ascending sort by weight; the last five rows carry the largest weights.
    return edges.sort_values('weight')[-5:]


# For every source gene keep its five heaviest edges, then build the graph.
temp = scvdf.groupby('node1').apply(_five_heaviest)
# temp = temp.query("weight > 0.2")
G = draw_net(temp,weight_scale=1)

In [None]:
f,ax = plt.subplots(figsize=(5.6,6))
cmap = plt.get_cmap('Paired')
# The layout was referenced below but its assignment was commented out
# (NameError). Restore it, with a fixed seed so the figure is reproducible.
positions = nx.spring_layout(G, k=0.35, scale=50, seed=0)

# Color every marker gene with a lightened version of its cluster color.
# Cluster labels are parsed with int() rather than eval(); a gene marking
# several clusters takes the color of the last one.
node_color2 = {}
for cluster_label, genes in dfmarkers['names'].groupby('group'):
    shade = lighten_color(adata.uns['KMeans_colors'][int(cluster_label)],amount=0.5)
    for gene in genes:
        node_color2[gene] = shade

# draw
show_label = False
nx.draw_networkx(
    G,
    # node size grows with node degree
    node_size=[degree*40 for _, degree in G.degree],
    labels={node: node for node in G.nodes()} if show_label else None,
    font_color='.3',
    edgecolors='w',
    # constant edge width; the weight is encoded in the edge color instead
    width=2,#[G[u][v]['weight'] for u,v in G.edges()]*5,
    edge_color = [G[u][v]['weight'] for u,v in G.edges()],
    edge_cmap = plt.get_cmap('binary'),
    edge_vmin=0.2,edge_vmax=1,
    font_size=4.5,
    node_color = [node_color2[i] for i in G.nodes],
    ax=ax,
    pos=positions,
    )
plt.subplots_adjust(left=0.05,right=0.95,bottom=0.05,top=0.95)

plt.savefig('figures/subplot4_geneNet.pdf')

In [None]:
# Map every marker gene to the set of clusters it marks. Building real sets
# avoids `list(label)`, which splits a multi-character label such as '10'
# into single characters.
gene_clusters = {}
for gene, cluster_label in marker_groups.items():
    gene_clusters.setdefault(gene, set()).add(cluster_label)

# True when both genes of an edge share at least one cluster, False otherwise.
edge_types = [
    not gene_clusters[gene_a].isdisjoint(gene_clusters[gene_b])
    for gene_a, gene_b in zip(scvdf['node1'], scvdf['node2'])
]
scvdf['edge_types'] = edge_types

In [None]:
# Annotate each retained edge in `temp` with the cluster of both endpoints by
# merging the gene -> cluster lookup (`marker_groups`) twice: the inner merge
# matches on 'node1' (outer join), the outer merge matches on 'node2'.
# Because both merges contribute a 'group' column, pandas suffixes them:
# 'group_x' comes from the node2 lookup and 'group_y' from the node1 lookup.
cross_edges = pd.merge(
    right = pd.merge(
        right = temp.reset_index(drop=True),left = marker_groups,
        right_on='node1',
        left_index=True,how='outer'
    ),left = marker_groups,right_on='node2',left_index=True)

In [None]:
# Total edge weight between every pair of clusters. Only 'weight' is summed:
# summing the whole frame also aggregates the string columns node1/node2 on
# recent pandas, which breaks the plot. Cluster pairs without any edge have a
# total weight of 0.
weight_matrix = (
    cross_edges.groupby(['group_x','group_y'])['weight'].sum().unstack().fillna(0)
)
n_rows, n_cols = weight_matrix.shape
sns.clustermap(
    weight_matrix,
    cmap="Blues",
    # hide the strictly upper triangle; sized from the data instead of a fixed 7
    mask=np.tri(n_cols, n_rows, k=-1).T,
    # cbar=False,
    # columns keep their order; no column linkage is needed when col_cluster=False
    col_cluster=False,
    # row_linkage=False
)

In [None]:
f = plt.figure()
ax1 = plt.subplot2grid((1, 3), (0, 0), colspan=2)
ax2 = plt.subplot2grid((1, 3), (0, 2))

# Total edge weight between every pair of clusters. Only 'weight' is summed so
# the string columns node1/node2 never enter the matrix.
weight_matrix = cross_edges.groupby(['group_x','group_y'])['weight'].sum().unstack()
n_rows, n_cols = weight_matrix.shape
sns.heatmap(
    weight_matrix,
    cmap="Blues",
    # hide the strictly upper triangle; sized from the data instead of a fixed 7
    mask=np.tri(n_cols, n_rows, k=-1).T,
    cbar=False,
    xticklabels=list(weight_matrix.columns),
    yticklabels=list(weight_matrix.index),
    square=True,ax=ax1
)
ax1.set_xlabel('cluster')
ax1.set_ylabel('cluster')

# scvdf_copy = scvdf.copy()
# scvdf_copy['edge_types'].apply(lambda x: 'Cross' if x else 'Within')
# `data=` is passed by keyword (older seaborn treats the first positional
# argument as x) and the category order is pinned so that the tick labels
# below always line up: False -> 'Cross', True -> 'Within'.
sns.boxplot(data=scvdf, y='weight', x='edge_types', order=[False, True], ax=ax2, palette='Blues')
ax2.set_xlabel('')
ax2.set_ylabel('Weight')
# plt.subplots_adjust(left=0.1,right=0.97)
ax2.set_xticklabels(['Cross','Within'])
plt.tight_layout()
plt.savefig('figures/subplot5_weight.pdf')

In [None]:
# Persist the co-expression edge table (including the 'edge_types' column added above).
scvdf.to_csv('./Results/case.scvdf.csv')

#### Pathway

In [None]:
def parse_sif_file(path):
    """Read a header-less, tab-separated three-column interaction file.

    Returns a DataFrame whose columns are named 'node1', 'edge_attr' and
    'node2', in file order. A file with a different number of columns raises
    the usual pandas length-mismatch error.
    """
    edges = pd.read_table(path, header=None)
    edges.columns = ['node1', 'edge_attr', 'node2']
    return edges

In [None]:
def run_path_spa_coexp(dfpathway):
    """Fit a Spanve spatial co-expression model on pathway gene pairs.

    Parameters
    ----------
    dfpathway : DataFrame
        Edge table with at least the columns 'node1' and 'node2'.

    Returns
    -------
    The fitted Spanve model.

    Notes
    -----
    Works on the notebook-level `adata` and, as a side effect, sets `adata.X`
    to the 'normalized' layer. Only unique pairs whose two genes are both
    present in `adata.var_names` are tested. The original body carried a second
    analysis after its first `return`; that code could never run and has been
    removed.
    """
    adata.X = adata.layers['normalized']

    sccmodel = Spanve.Spanve(adata)
    edge_pairs = dfpathway[['node1','node2']].drop_duplicates()
    both_measured = edge_pairs['node1'].isin(adata.var_names) & edge_pairs['node2'].isin(adata.var_names)
    edge_pairs = edge_pairs[both_measured]

    # The bare name `spatial_coexp` is never imported in this notebook; call it
    # through the package like the other Spanve helpers.
    # NOTE(review): assumes the Spanve package exposes `spatial_coexp` at top level — confirm.
    Spanve.spatial_coexp(sccmodel,search_space = edge_pairs.values, verbose=True, groupby='KMeans')
    sccmodel.fit(verbose=True)

    return sccmodel

In [None]:
# Collect every '*.txt' interaction file below ./data/pathway_sif/ into one
# edge table; the 'path' column holds the file name without its extension.
pathes = []
for root,_,files in os.walk('./data/pathway_sif/'):
    # sorted() makes the load order independent of the file system
    for file in sorted(files):
        # Split off the real extension: `endswith('txt')` also accepted names
        # such as 'footxt', and `replace('.txt','')` removed '.txt' anywhere
        # in the file name.
        stem, ext = os.path.splitext(file)
        if ext != '.txt': continue
        print(file)
        dfpathway = parse_sif_file(os.path.join(root,file))
        dfpathway['path'] = stem
        pathes.append(dfpathway)
pathes = pd.concat(pathes,axis=0)
# pathes = pathes[['node1','node2']].drop_duplicates()

In [None]:
# Fit the pathway-pair co-expression model on all loaded pathway edges.
sccmodel = run_path_spa_coexp(pathes)

In [None]:
# Pair identifier in the same "geneA~geneB" form used by result_df's index.
pathes['id'] = pathes['node1']+'~'+pathes['node2']
# For each pair, the comma-joined list of distinct pathway files containing it.
pathways_per_pair = pathes.groupby('id')['path'].apply(lambda x: ','.join(x.unique().tolist()))
sccmodel.result_df['pathes'] = pathways_per_pair[sccmodel.result_df.index]

In [None]:
# Among the pairs selected by `rejects`, count how often each pathway annotation occurs.
sccmodel.result_df[sccmodel.rejects]['pathes'].value_counts()

In [None]:
# Pairs selected by `rejects`, ordered by ascending 'fdrs'.
sccmodel.result_df[sccmodel.rejects].sort_values('fdrs')

In [None]:
# Give the model's AnnData the same `.uns` content as the main object.
sccmodel.adata.uns = adata.uns.copy()

scale_factor = 5
f,axes = plt.subplots(3,4,figsize = (3*scale_factor,2*scale_factor))
axes = axes.flatten()
# The 12 selected pairs with the smallest 'fdrs', one per panel.
best_pairs = sccmodel.result_df[sccmodel.rejects].sort_values('fdrs').index[0:12]
for panel, genepair in zip(axes, best_pairs):
    sc.pl.spatial(
        sccmodel.adata,
        img_key="lowres",
        color=genepair,
        vmax=3,
        vmin=-3,
        alpha_img=0.8,
        cmap='vlag',
        ax=panel,
        show=False,
    )
f.savefig('./Results/case.path.spatial_coex.pdf',bbox_inches='tight')

### Chemokine score

In [None]:
from sklearn.decomposition import PCA

# 12-gene chemokine signature; every gene must be present in the data.
tls_sig = "CCL2 CCL3 CCL4 CCL5 CCL8 CCL18 CCL19 CCL21 CXCL9 CXCL10 CXCL11 CXCL13".split(' ')
missing_genes = [gene for gene in tls_sig if gene not in adata.var_names]
if missing_genes:
    # explicit error instead of `assert`, which is stripped under `python -O`
    raise KeyError(f"signature genes missing from adata.var_names: {missing_genes}")

# Score = first principal component of the signature genes (taken from adata.X).
tlsX = adata.to_df().loc[:,tls_sig]
model = PCA(n_components=1)
tls_score = model.fit_transform(tlsX)
adata.obs['CS'] = tls_score.flatten()
# Plot the column that was just written: the original asked for color="TLS",
# a key that is never created, while the score is stored under 'CS'.
sc.pl.spatial(adata, img_key="hires", color="CS",save='case.TLS.pdf',spot_size=250, cmap = 'Reds',alpha = 0.7,
)

### Tumor marker score

In [None]:
can_markers = "ITK	ARHGAP10	EDNRA	SELP	TLL1	SEMA6A	CDH5	CPNE1	TPST2	CRYBB2	SULF2	KLRF1	SELE	ALPI	FAM177A1	ADH1B	AKR1A1	CAMK1D	CHST15	GOLM1	ISLR2	CD36	PRDM1	B3GNT2	TMPRSS11D	STOM	TNS2	MET	VCAM1	JAG1	THSD1	PSD	IL3RA	KIN	BCAM	C1GALT1C1	ENG	RSPO3	DOCK9	NOTCH1	KRT19	KRT8	EPCAM	ESR1	KRT18	ERBB2	BRCA1	BRCA2	PDCD1	VIM	MKI67	TRPS1	PIP	NKX2-1	GATA3	EREG	CYP3A4	CYP3A5	CYP3A7	STAT1	RIC8A	IRF9	ISG15	STAT2	JAK1	CDH1	CTNNA1	GNA13	OAS1"
# Whitespace split: robust even if the tab separators get converted to spaces.
can_markers = can_markers.split()
# Markers present in the data, kept in list order. The original built this
# from a set intersection, whose order changes between interpreter runs.
measured_markers = [gene for gene in dict.fromkeys(can_markers) if gene in adata.var_names]
# Fraction of the marker list found in the data (the original computed this
# value mid-cell without displaying it).
print(len(measured_markers) / len(can_markers))
adata.X = adata.layers['imputated']
canX = adata.to_df().loc[:,measured_markers]
# Score = absolute value of the first principal component of the markers.
model = PCA(n_components=1)
can_score = model.fit_transform(canX)
adata.obs['cancer marker score'] = np.abs(can_score.flatten())
sc.pl.spatial(
    adata, color = 'cancer marker score',
    spot_size=250, cmap = 'Reds_r',alpha = 0.7,
    save = 'can_marker'
)

### Tumor heterogeneity

In [None]:
# Sub-object restricted to clusters '0', '1', '3', '4' (presumably the tumor
# clusters, given the section title — confirm) and to the genes flagged in
# adata.var['spanve_spatial_features'].
selected_clusters = ['0','1','3','4']
spot_mask = adata.obs['KMeans'].isin(selected_clusters)
gene_mask = adata.var['spanve_spatial_features']
tmdata = adata[spot_mask, gene_mask].copy()

# Embed the imputed expression: PCA -> neighbor graph -> UMAP, then rebuild
# the squidpy spatial graph for the subset.
tmdata.X = tmdata.layers['imputated']
sc.pp.pca(tmdata)
sc.pp.neighbors(tmdata)
sc.tl.umap(tmdata)
sq.gr.spatial_neighbors(tmdata)

In [None]:
# Diffusion map on the subset; the root is the spot with the smallest value on
# diffusion component index 2.
sc.tl.diffmap(tmdata)
root_ixs = tmdata.obsm["X_diffmap"][:, 2].argmin()
tmdata.uns["iroot"] = root_ixs
# Pseudotime must be computed on `tmdata`: that is where the diffusion map and
# the root were stored, and the next cell plots tmdata's 'dpt_pseudotime'.
# The original called sc.tl.dpt(adata).
sc.tl.dpt(tmdata)

In [None]:
# Diffusion-map scatter (components 2 and 3) colored by pseudotime and by cluster.
# The `save` suffix is spelled 'diffuesemap' and ends up in the output file name.
sc.pl.scatter(
    tmdata,
    basis="diffmap",
    color=["dpt_pseudotime",'KMeans'],
    color_map="gnuplot2",
    components=[2, 3],
    save = 'diffuesemap'
)

In [None]:
# Wilcoxon marker-gene test between the clusters kept in `tmdata`.
sc.tl.rank_genes_groups(tmdata, "KMeans", method="wilcoxon")

In [None]:
# Heatmap of the top 5 ranked genes for each cluster kept in `tmdata`,
# drawn from the imputed expression values.
tmdata.X = tmdata.layers['imputated']
sc.tl.dendrogram(tmdata,groupby='KMeans')
heatmap_groups = ['0','1','3','4']
sc.pl.rank_genes_groups_heatmap(
    tmdata,
    groups=heatmap_groups,
    n_genes=5,
    groupby="KMeans",
    cmap='viridis',
    show=False,
    save='tumor_hetero',
)

In [None]:
# Marker table of cluster '4', highest score first (from the full-data `dfgroup`).
dfgroup.query("group == '4'").sort_values(by='scores',ascending=False)

In [None]:
# Names of the 15 highest-scoring cluster '4' markers, one per line.
cluster4_ranked = dfgroup.query("group == '4'").sort_values(by='scores',ascending=False)
cluster4_top_names = cluster4_ranked['names'][0:15].tolist()
print(*cluster4_top_names, sep='\n')

### Cluster Functions

In [None]:
import pandas as pd
import numpy as np
import gseapy as gp
import os
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Directory holding the Enrichr libraries as local GMT files (machine-specific path).
base_dir = 'K:/Caiguoxin/Data/Annotation/gseapy/'
library_names = [
    'Reactome_2022',
    'KEGG_2019_Human',
    'GO_Biological_Process_2021',
    'GO_Molecular_Function_2021',
    'GO_Cellular_Component_2021',
    'WikiPathways_2019_Human',
    'MSigDB_Hallmark_2020',
]
# Each library lives in "<base_dir>/enrichr.<library>.gmt".
gene_sets = [os.path.join(base_dir, f"enrichr.{library}.gmt") for library in library_names]
def func_enrichr(genes):
    """Run a gseapy Enrichr analysis of `genes` against the configured `gene_sets`.

    Nothing is written to disk (`outdir=None`). Returns the rows of the result
    table whose 'Adjusted P-value' is below 0.05.
    """
    enrich_res = gp.enrichr(
        gene_list=list(genes),
        gene_sets=gene_sets,
        outdir=None,
        verbose=False,
    )
    table = enrich_res.results
    is_significant = table['Adjusted P-value'] < 0.05
    return table[is_significant]

def func_gsea(scores):
    """Run gseapy pre-ranked GSEA of `scores` against the configured `gene_sets`.

    Parameters
    ----------
    scores
        Ranking passed unchanged to `gp.prerank(rnk=...)`.

    Returns
    -------
    DataFrame
        Rows of the result table with 'FDR q-val' < 0.05. The combined 'Term'
        value ("<library>__<term>") is split into 'Gene_set' (with the literal
        'enrichr.' prefix removed) and 'Term', and 'FDR q-val' is renamed to
        'Adjusted P-value'.
    """
    gsea_res = gp.prerank(
        rnk=scores,
        gene_sets=gene_sets,
        outdir=None,
        # min_size=1
    )
    # Copy the filtered rows so the column assignments below act on an
    # independent frame rather than on a slice of res2d.
    ret = gsea_res.res2d.query("`FDR q-val` < 0.05").copy()
    term_parts = ret['Term'].str.split('__')
    # regex=False: 'enrichr.' is a literal prefix; older pandas would otherwise
    # treat the '.' as a regex wildcard.
    ret['Gene_set'] = term_parts.str[0].str.replace('enrichr.', '', regex=False)
    ret['Term'] = term_parts.str[1]
    return ret.rename(columns = {'FDR q-val':'Adjusted P-value'})

#### cell type annotation

In [None]:
# Mean marker-gene score per (cluster, cell type).
df = pd.read_csv('./Results/case_markers.csv',index_col=0)
# Per-cluster score series indexed by gene name. Computed once; the original
# redid this groupby for each of the 7 clusters.
scores_by_group = df.groupby('group').apply(lambda x: x.set_index('names')['scores'])
# Cell-type marker sets, read once instead of once per cluster.
celltype_markers = gp.read_gmt('/share/home/biopharm/Caiguoxin/Data/Annotation/Cell_marker_Breast.gmt')

score_rows = []
for gid in range(7):
    scores = scores_by_group[gid]
    for cell_type,markers in celltype_markers.items():
        # only markers that actually have a score in this cluster contribute
        measured = pd.Index(markers).intersection(scores.index)
        score_rows.append((gid, cell_type, np.nanmean(scores[measured])))
celltype_scores = pd.DataFrame(score_rows, columns = ['group','celltype','scores'])

In [None]:
cmap = plt.get_cmap('autumn_r')
n_clusters = 7
n_show = 3
f,axes = plt.subplots(
    n_clusters,1,dpi=150,
    figsize=(3.7,7),
    sharey=True,sharex=True
)
# plt.subplots_adjust(wspace=1)
# The n_show best-scoring cell types of every cluster, computed once.
top_celltypes = (
    celltype_scores.groupby(by='group')
    .apply(lambda x: x.sort_values('scores',ascending=False).iloc[0:n_show,:])
    .reset_index(drop=True)
)
for n,ax in enumerate(axes.flatten()):
    dfvis = top_celltypes.query('group == @n')
    ax.set_ylim(0,n_show)

    scores = dfvis['scores'].to_numpy(dtype=float)
    # Best score at the top. Labels and bars share these positions; the
    # original drew the bars bottom-up and the labels top-down, so each label
    # sat on another cell type's bar.
    y_pos = dfvis.shape[0] - np.arange(dfvis.shape[0]) - 0.5
    for y,label in zip(y_pos, dfvis['celltype'].values):
        ax.text(
            x = 0.2,
            y = y,
            s = label,
            verticalalignment ='center',
            color = 'k',
            fontsize=10,
            # transform = ax.transAxes
        )
    # Bar shade scales with the score relative to the cluster's best score.
    # The original referenced an undefined `logp` here (NameError); the score
    # itself is used as the color driver — NOTE(review): confirm intended.
    peak = np.nanmax(scores) if scores.size else 0.0
    shade = scores / peak / 1.5 if peak > 0 else np.zeros_like(scores)
    ax.barh(
        y = y_pos,
        width = scores,
        color = cmap(shade),alpha=0.5
    )

    # ax.set_title(n)
    ax.set_ylabel(n)
    ax.set_yticks([])
plt.tight_layout()
# x axes are shared, so labelling the last (bottom) axis is enough
ax.set_xlabel('scores')
# ax.set_xlim(0,10)
# ax.set_ylim(0,10)
plt.savefig('figures/subplot2_celltype.pdf')

#### Cluster function

In [None]:
def _scores_of_most_significant(group_df):
    # Scores of the 1000 rows with the smallest adjusted p-value.
    return group_df.sort_values('pvals_adj')['scores'].iloc[0:1000]


# Pre-ranked GSEA for each cluster 0..6 on its most significant marker scores.
df = pd.read_csv('./Results/case_svmarkers.csv',index_col=0)
df = df.set_index('names').groupby('group',).apply(_scores_of_most_significant)
func_res = {}
for gid in range(7):
    res = func_gsea(df[gid])
    print(gid,res.shape[0])
    func_res[gid] = res

In [None]:
# One sheet per cluster. The context manager closes (and thereby saves) the
# workbook even if writing one of the sheets raises.
with pd.ExcelWriter('./Results/case.cluster.func.gsea.xlsx') as writer:
    for k,v in func_res.items():
        v.to_excel(writer, sheet_name=f'Cluster{k}',index=False)

In [None]:
# Enrichr analysis of the 100 highest-scoring genes of each cluster.
# Note that this replaces the GSEA results previously held in `func_res`.
df = pd.read_csv('./Results/case_svmarkers.csv',index_col=0)
top_genes_by_cluster = df.groupby('group').apply(
    lambda x: x.sort_values('scores',ascending=False)['names'][0:100].values
).to_dict()
func_res = {}
for gid,genes in top_genes_by_cluster.items():
    res = func_enrichr(genes)
    print(gid,res.shape[0])
    func_res[gid] = res

In [None]:
# draw functional heatmap: x axis: clusters, y axis: functional terms, color: -log10(p-value)
dfvis = pd.concat(func_res).reset_index().rename(columns={'level_0':'group'}).drop(columns=['level_1'])
dfvis.loc[:,'Gene_set'] = dfvis['Gene_set'].str.replace('enrichr.','').str.replace('.gmt','')

dfvis = dfvis.loc[~dfvis['Gene_set'].isin(['GO_Cellular_Component_2021']),:]
dfvis = dfvis.loc[dfvis['Gene_set'].isin(['Reactome_2022','MSigDB_Hallmark_2020']),:]
dfvis = dfvis[dfvis['Term'].isin(dfvis.groupby('group',group_keys=False).apply(lambda x: x.sort_values(by= 'Adjusted P-value',ascending=True)['Term'][0:15]).values)]
dfheatmap = dfvis.pivot_table(
    index = 'Term',
    columns = 'group',
    values = 'Adjusted P-value'
).fillna(1).applymap(lambda x: -np.log10(x))
# draw functional heatmap: x axis: clusters, y axis: functional terms, color: -log10(p-value)
dfvis = pd.concat(func_res).reset_index().rename(columns={'level_0':'group'}).drop(columns=['level_1'])
dfvis.loc[:,'Gene_set'] = dfvis['Gene_set'].str.replace('enrichr.','').str.replace('.gmt','')
print("before filter: ",dfvis.shape[0])
dfvis = dfvis.loc[~dfvis['Gene_set'].isin(['GO_Cellular_Component_2021','KEGG_2019_Human','WikiPathways_2019_Human']),:]
# delete terms that include: viral, infection SARS-CoV-2 riboso SPR subunit
dfvis = dfvis.loc[~dfvis['Term'].str.lower().str.contains('RB1|kappa|platelet|EPH|via|EPHB|NCA|ROBO|19221|60333|2479|metal|calnexin|69278|neuron|NMD|Chaperones|structure|Elongation|Cristae|GTPase|43S|Seleno|viral|infection|SARS-CoV-2|riboso|SRP|subunit|disease|Rejection|diabetes|Leishmaniasis|Complement'.lower()),:]
dfvis = dfvis.loc[dfvis['Term'].str.lower().str.contains('Myc|Adhesion|telomere|mitotic|cycle|Signaling|antigen|MHC|immune|T cell|Estrogen|tumor|integrin|death'.lower()),:]
print("after filter: ",dfvis.shape[0],dfvis['Term'].nunique())
# dfvis = dfvis.loc[dfvis['Gene_set'].isin(['Reactome_2022','MSigDB_Hallmark_2020']),:]
# dfvis = dfvis[dfvis['Term'].isin(dfvis.groupby('group',group_keys=False).apply(lambda x: x.sort_values(by= 'Adjusted P-value',ascending=True)['Term'][0:10]).values)]
dfheatmap = dfvis.pivot_table(
    index = 'Term',
    columns = 'group',
    values = 'NES'
).fillna(0)#.applymap(lambda x: -np.log10(x))

In [None]:
sns.set_style('white')
cmap = plt.get_cmap('Set1')
# Infinite entries are turned into missing values, which are then set to 10
# (outside the displayed -5..5 color range).
heat_values = dfheatmap.replace([np.inf, -np.inf], np.nan).fillna(10)
colorbar_options = {'label':'NES','ticks':[-5,0,5],}
sns.clustermap(
    data=heat_values,
    cmap='RdBu_r',
    figsize=(12,12),
    cbar_kws=colorbar_options,
    yticklabels=True,
    xticklabels=True,
    vmin=-5,
    vmax=5,
    cbar_pos=(0.03, 0.85, 0.015, 0.1),
)
# plt.tight_layout()
# save
plt.savefig('figures/heatmap_functional.pdf')

### Survival analysis (OS)

In [None]:
from utils import os_utils
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
def load_cancer_exp(cancer_type,survival_use = 'OS'):
    """Return the expression matrix and survival table for one cancer type.

    Reads the notebook-level frames `all_samples`, `data` and `dataos`.

    Parameters
    ----------
    cancer_type : str
        Value matched against the 'cancer type abbreviation' column.
    survival_use : str
        Survival endpoint; the columns `<survival_use>` and
        `<survival_use>.time` are returned.

    Returns
    -------
    df : DataFrame
        Genes x samples. Missing values are set to 0, genes that are 0 in
        every sample are dropped, and for duplicated gene names only the
        first row is kept.
    samples_os : DataFrame
        Survival columns indexed by patient id, one row per patient.
    """
    samples = all_samples[all_samples['cancer type abbreviation'] == cancer_type]
    df = data.loc[:,samples.index]

    # delete all 0 or na genes
    df = df.fillna(0)
    df = df[~(df == 0).all(axis=1)]

    # drop duplicated genes
    df = df[~df.index.duplicated()]

    survival = dataos.set_index('_PATIENT',drop=True)
    # A patient can appear on several rows of the survival table; keep one so
    # downstream merges do not multiply samples.
    # NOTE(review): assumes the survival columns are identical across a patient's rows.
    survival = survival[~survival.index.duplicated()]
    # Keep the patients of this cancer type that actually have a survival
    # record. The original intersected the patient list with itself, which
    # filtered nothing.
    patients = pd.Index(samples['sample'].unique()).intersection(survival.index)
    samples_os = survival.loc[patients,[survival_use,f'{survival_use}.time']]
    return df,samples_os

In [None]:
def creat_merged_df(gene_set,thres=1500):
    """Merge standardized expression of `gene_set` with survival data.

    Reads the notebook-level `dfexp` (genes x samples), `dfos` (survival table
    whose first two columns are the event flag and the time) and `samples`.

    Parameters
    ----------
    gene_set : iterable of str
        Genes of interest; genes absent from `dfexp` are skipped and reported.
    thres : number
        Only rows with `OS.time` < thres are kept.

    Returns
    -------
    DataFrame
        One row per sample with the z-scored expression of each gene, the
        survival columns renamed to 'OS' / 'OS.time', and 'geneset_median'
        (row-wise median of the z-scores).
    """
    n = len(gene_set)
    measured = set(dfexp.index)
    # Keep the caller's order and drop repeats; the original set intersection
    # produced a different column order on every run.
    gene_set = [gene for gene in dict.fromkeys(gene_set) if gene in measured]
    if n > len(gene_set):
        print("Genes in the data:",len(gene_set),'not include:',n-len(gene_set))

    ret = pd.merge(
        dfexp.loc[gene_set,samples].T,
        dfos.loc[samples,:].rename(columns = {dfos.columns[0]:"OS",dfos.columns[1]:"OS.time"}),
        right_index=True,left_index=True
    ).query("`OS.time` < @thres").copy()  # copy: the frame is modified below
    # z-score every gene across the retained samples
    ret.loc[:,gene_set] = ret.loc[:,gene_set].apply(lambda x: (x-x.mean()) / x.std())
    ret.loc[:,'geneset_median'] = ret.loc[:,gene_set].median(axis=1).values
    return ret

In [None]:
from utils.os_utils import *
import seaborn as sns
def QuantileSampleSplitPlot(df_coxreg,gene,ax,q=0.1,**plotargs):
    """Compare survival between the high and low tails of one column.

    Rows of `df_coxreg` whose `gene` value is at or above the (1 - q) quantile
    are labelled 'High', rows at or below the q quantile 'Low'. A log-rank
    test is run between the two groups on 'OS.time' / 'OS', the curves are
    drawn on `ax` through `KMFFitPlot` (extra keyword arguments are forwarded)
    and the p-value is added as a legend entry. Returns `ax`.

    `logrank_test` and `KMFFitPlot` are expected to be in scope, presumably
    from the `utils.os_utils` star import.
    """
    values = df_coxreg.loc[:,gene]
    wanted_cols = [gene,'OS.time','OS']

    df_upper = df_coxreg.loc[values >= values.quantile(1-q), wanted_cols]
    df_upper['label'] = 'High'
    df_lower = df_coxreg.loc[values <= values.quantile(q), wanted_cols]
    df_lower['label'] = 'Low'
    df = pd.concat([df_upper,df_lower])

    try:
        results = logrank_test(
            durations_A=df_upper['OS.time'],
            event_observed_A=df_upper['OS'],
            durations_B=df_lower['OS.time'],
            event_observed_B=df_lower['OS'],
        )
    except Exception as e:
        # show both groups to help diagnose the failure, then propagate it
        print(df_lower,df_upper)
        raise e
    p_value = results.p_value

    KMFFitPlot(df,df['label'],ax=ax,**plotargs)
    # invisible line whose only purpose is to carry the p-value into the legend
    ax.plot([],[],label=f"p-value = {p_value:.2f}",color='white')
    ax.legend(frameon=False,loc='best')
    ax.set_title(gene)
    ax.set_xlabel('')
    return ax

In [None]:
# Input files: expression matrix, survival table and subtype table.
data_dir = '../../Data/TCGA/EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena.gz'
os_dir = '../../Data/TCGA/Survival_SupplementalTable_S1_20171025_xena_sp.txt'
type_dir = '../../Data/TCGA/TCGASubtype.20170308.tsv.gz'

data = pd.read_table(data_dir,index_col=0)
sample_type = pd.read_table(type_dir)
dataos = pd.read_table(os_dir)

# Keep only the expression columns whose 12-character prefix is a patient id
# present in the survival table.
valid_idx = data.columns.str[0:12].isin(dataos['_PATIENT'])
data = data.loc[:,valid_idx]

# patient id -> cancer type abbreviation
sample_map = dataos[['_PATIENT','cancer type abbreviation']].drop_duplicates().set_index('_PATIENT',drop=True)

# Per-column annotation: 'sample' = first 12 characters of the column name,
# 'site' = its last two characters; then attach the cancer type.
all_samples = pd.DataFrame(
    {'sample': data.columns.str[0:12], 'site': data.columns.str[-2:]},
    index=data.columns,
)
all_samples = pd.merge(left = all_samples,right = sample_map,right_index=True,left_on='sample')

In [None]:
# BRCA expression plus the PFI endpoint (the original note lists OS, DSS, DFI
# and PFI as the available endpoints).
dfexp,dfos = load_cancer_exp('BRCA','PFI')
dfsubtype = dataos.set_index('sample').loc[dfexp.columns,:]
# Restrict to 'Infiltrating Ductal Carcinoma' samples whose id ends in '01'.
samples = dfsubtype.query("histological_type == 'Infiltrating Ductal Carcinoma'").index
samples = samples[samples.str.endswith('01')]
# Re-key the survival table from patient id to "<patient>-01" so it can be
# looked up with the sample ids above.
dfos.index = dfos.index + '-01'

gene_set = [
    'LINC00645', 'PVALB', 'LINC02224', 'COLEC12', 'TNFSF10',
    'TMEM150C', 'SLC39A6', 'MUC5B', 'EXOC2', 'CFB',
    'RPS18', 'AFP', 'SDC4', 'AC037198.2', 'RPS23',
]

# Only the first five genes enter the signature.
gene_set_sep = gene_set[0:5]
dfcox = creat_merged_df(gene_set_sep,thres =360*10)
sns.set_style('white')
scale=1.7
f,ax = os_utils.plt.subplots(dpi=300,figsize=(5*scale,3.5*scale))
# Top versus bottom 20% of the signature median.
group_colors = {"High":"#d1464f","Low":"#455d9f"}
QuantileSampleSplitPlot(
    dfcox,
    gene='geneset_median',
    show_censors=True,
    q=0.2,
    ci_show=False,
    ax=ax,
    colors=group_colors,
)
ax.set_ylabel('Survival probability')
ax.set_xlabel('Progression free interval (days)')
ax.set_title('')
f.savefig('figures/survival.pdf',bbox_inches='tight')