## Requirements and Imports


In [1]:
import copy
import os
from pathlib import Path
from itertools import combinations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
from sklearn.preprocessing import normalize
from sklearn.metrics import roc_auc_score , pairwise_distances
import torch
import torch.optim as optim
# Default number of training epochs (re-assigned to 100 in the training cells further down).
epoch_num=50
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
import scipy.sparse as sp
from numba import jit
import models
import metrics
device

device(type='cpu')

In [2]:
# Font sizes (in points) shared by every figure in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 14, 16, 18

# Apply the whole typography configuration in a single rcParams update.
plt.rcParams.update({
    "font.family": "Verdana",
    "font.size": SMALL_SIZE,            # default text size
    "axes.titlesize": SMALL_SIZE,       # axes title
    "axes.labelsize": MEDIUM_SIZE,      # x and y axis labels
    "xtick.labelsize": SMALL_SIZE,      # x tick labels
    "ytick.labelsize": SMALL_SIZE,      # y tick labels
    "legend.fontsize": SMALL_SIZE,      # legend entries
    "figure.titlesize": BIGGER_SIZE,    # figure title
})


In [None]:
import scvi
# Load the PBMC dataset provided by scvi.
adata = scvi.data.pbmc_dataset()
# Copy the gene identifiers into a column that is later overwritten with gene symbols for plotting.
adata.var['alt_names']=adata.var_names
# NOTE(review): PCA / neighbours / UMAP are recomputed further down after filtering and
# normalisation, so this first pass looks redundant — confirm before removing.
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
adata

Global seed set to 0


[34mINFO    [0m File data/gene_info_pbmc.csv already downloaded                                                           
[34mINFO    [0m File data/pbmc_metadata.pickle already downloaded                                                         
[34mINFO    [0m File data/pbmc8k/filtered_gene_bc_matrices.tar.gz already downloaded                                      
[34mINFO    [0m Extracting tar file                                                                                       


In [None]:
# Number of cells per annotated label.
adata.obs['str_labels'].value_counts()

In [None]:
# Labels kept for the analysis; cells with any other label are dropped.
kept_cell_types = [
    'B cells', 'CD4 T cells', 'CD8 T cells', 'CD14+ Monocytes',
    'Dendritic Cells', 'FCGR3A+ Monocytes', 'NK cells',
]
# Subsetting an AnnData returns a view. Copy it explicitly so the in-place
# preprocessing that follows works on a real object instead of relying on an
# implicit view-to-copy conversion.
adata = adata[adata.obs['str_labels'].isin(kept_cell_types)].copy()

In [None]:
# Normalise every cell to 1e4 total counts, then apply log1p.
# NOTE(review): scanpy documents normalize_per_cell as deprecated in favour of
# normalize_total — the two are not drop-in identical, confirm before switching.
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
sc.pp.log1p(adata)

In [None]:
# Independent copy that gets scaled; `adata` itself keeps the log-normalised values.
adata_scaled = adata.copy()
sc.pp.scale(adata_scaled)

In [None]:
# Embedding pipeline on the filtered, normalised data: PCA -> neighbour graph -> UMAP,
# then plot the UMAP coloured by the annotated label.
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.pl.umap(adata, color='str_labels')

In [None]:
# Number of epochs passed to the training helper below (as `iterNum`).
epoch_num=100
# Train on the `str_labels` annotation. Presumably returns the per-epoch class
# probabilities that the next cell converts to confidence/variability — confirm in models.py.
prob_list = models.follow_training_dyn_neural_net(adata, label_key='str_labels',iterNum=epoch_num, device=device)

In [None]:
# Reduce the recorded probabilities to two per-cell summaries (stored as `conf` and `var` below);
# the exact definitions live in models.probability_list_to_confidence_and_var.
all_conf , all_var = models.probability_list_to_confidence_and_var(prob_list, n_obs= adata.n_obs, epoch_num=epoch_num)

In [None]:
# Thresholds on confidence and variability used below to flag cells; how `probability`
# and `percentile` enter the computation is defined in models.find_cutoff_paramter.
cutoff_conf , cutoff_var = models.find_cutoff_paramter(adata,'str_labels',device,probability=0.05,percentile=95,epoch_num=epoch_num)

In [None]:
# Store the per-cell variability and confidence on the AnnData object.
# `.cpu()` is required before `.numpy()` when the tensors live on a CUDA device
# (`device` is CUDA whenever a GPU is available); it is a no-op for CPU tensors.
adata.obs["var"] = all_var.detach().cpu().numpy()
adata.obs["conf"] = all_conf.detach().cpu().numpy()

In [None]:
# True when a cell exceeds either cutoff (confidence OR variability). Later cells
# label True as 'Correct annotation' and False as 'Erroneous annotation'.
adata.obs['conf_binaries'] = pd.Categorical((adata.obs['conf'] > cutoff_conf) | (adata.obs['var'] > cutoff_var))

In [None]:
# Expose the annotated labels under the column name `CellType` as well.
adata.obs['CellType']=adata.obs['str_labels']

In [None]:

# (Same assignment as in the previous cell; kept so this cell can be run on its own.)
adata.obs['CellType']=adata.obs['str_labels']
# Joint distribution of variability (x) against confidence (y) over all cells.
sns.jointplot(data=adata.obs, x="var", y="conf",height=10,  s=25, ratio=2)
# pyplot labels go to the current axes — presumably the joint panel after jointplot;
# verify the labels land on the main scatter panel.
plt.ylabel('Confidence')
plt.xlabel('Variability')
plt.show()

In [None]:
# Display the two thresholds.
cutoff_conf , cutoff_var

In [None]:
# How many cells pass (True) or fail (False) the cutoff test.
adata.obs['conf_binaries'].value_counts()

In [None]:
# Recompute PCA -> neighbour graph -> UMAP on `adata` before the plots in the next cell.
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
# UMAPs coloured by annotated label, by the cutoff flag, and by the confidence score.
sc.pl.umap(adata, color=['str_labels'], title='Cell type')
sc.pl.umap(adata, color=['conf_binaries'])
sc.pl.umap(adata, color=['conf'], title='Confidence')

In [None]:
# Three-way label per cell, derived from its confidence / variability:
#   * fails both cutoffs                              -> 'Erroneously annotated'
#   * passes a cutoff, conf > 0.95 AND var < 0.15     -> 'Correctly annotated'
#   * passes a cutoff, otherwise                      -> 'Ambiguously annotated.'
# Computed on whole columns: the previous per-cell loop used integer `[i]` on a
# label-indexed Series, which relies on pandas' deprecated positional fallback.
passes_cutoff = ((adata.obs['conf'] > cutoff_conf) | (adata.obs['var'] > cutoff_var)).to_numpy()
clearly_correct = ((adata.obs['conf'] > 0.95) & (adata.obs['var'] < 0.15)).to_numpy()

annotation_list = np.where(
    passes_cutoff,
    np.where(clearly_correct, 'Correctly annotated', 'Ambiguously annotated.'),
    'Erroneously annotated',
).tolist()

adata.obs['Annotation']=annotation_list


In [None]:
# Flag used by later cells to separate "non-ambiguous" cells.
# NOTE(review): this flag combines the two conditions with OR, whereas the
# 'Annotation' labels require both (conf > 0.95 AND var < 0.15) — confirm which is intended.
adata.obs['conf_non_ambiguous'] = pd.Categorical((adata.obs['conf'] > 0.95) | (adata.obs['var'] < 0.15))
# Human-readable aliases so the plot axes are labelled 'Confidence' / 'Variability'.
adata.obs['Confidence']= adata.obs['conf']
adata.obs['Variability']= adata.obs['var']

# `fig` is the seaborn JointGrid returned by jointplot.
fig = sns.jointplot(data=adata.obs, x="Variability", y="Confidence",hue='Annotation',height=10,  s=25, ratio=2)
# pyplot.show() only accepts the keyword argument `block`; passing the grid
# positionally is an error on current matplotlib backends.
plt.show()

In [None]:
# Ensembl gene ID -> display name for the marker genes used in the plots and
# rank reports below.
gene_mapping = {
    'ENSG00000168685': 'IL7R',
    'ENSG00000126353': 'CCR7',
    'ENSG00000196154': 'S100A4',
    'ENSG00000105369': 'CD79A',
    'ENSG00000153563': 'CD8A',
    'ENSG00000170458': 'CD14',
    'ENSG00000131981': 'LGALS3',
    'ENSG00000105374': 'NKG7',
    'ENSG00000111796': 'KLRB1',
    'ENSG00000179639': 'FCER1A',
    'ENSG00000101439': 'CST3',
    'ENSG00000163736': 'PPBP',
    'ENSG00000116824': 'CD2',
    'ENSG00000188404': 'SELL',
    'ENSG00000111537': 'IFNG',
    'ENSG00000168329': 'CX3CR1',
    'ENSG00000160255': 'ITGB2',
    'ENSG00000125498': 'KIR2DL1',
    'ENSG00000243772': 'KIR2DL3',
    'ENSG00000139187': 'KLRG1',
    'ENSG00000180644': 'PRF1',
    'ENSG00000145649': 'GZMA',
    'ENSG00000100453': 'GZMB',
    'ENSG00000163221': 'S100A12',
    'ENSG00000132965': 'ALOX5AP',
    'ENSG00000178623': 'GPR35',
    'ENSG00000019169': 'MARCO',
    'ENSG00000105371': 'ICAM4',
    'ENSG00000007312': 'CD79b',
}



In [None]:
# Write the display names from `gene_mapping` into adata.var['alt_names'].
# The update is done on a private copy of the column and assigned back in one
# step: chained assignment (`adata.var['alt_names'].loc[...] = ...`) writes to
# an intermediate object and is not guaranteed to reach the DataFrame (it is a
# no-op under pandas Copy-on-Write).
alt_names = adata.var['alt_names'].astype(object)
# Genes that are absent from this dataset are skipped rather than appended.
known_gene_ids = [gene_id for gene_id in gene_mapping if gene_id in alt_names.index]
alt_names.loc[known_gene_ids] = [gene_mapping[gene_id] for gene_id in known_gene_ids]
adata.var['alt_names'] = alt_names

# Marker genes shown per group in the dot plot (keys are only used as group captions).
marker_genes_dict= {
    'CD4 T cells':['IL7R','CCR7'],
    'CD8 T cells':['CD8A','NKG7','KLRB1'],
    'B cells':['CD79A'],
    'CD14+ Monocytes':['CD14','LGALS3'],
    'NK Cells':['NKG7','KLRB1'],
    'Dendritic Cells':['FCER1A','CST3'],
    'Megakaryocytes':['PPBP']}


ax = sc.pl.dotplot(adata, marker_genes_dict, groupby='str_labels', gene_symbols='alt_names', cmap='BuGn')


In [None]:
# Two-way label from the cutoff flag: True -> 'Correct annotation', False -> 'Erroneous annotation'.
# Done on the whole column instead of integer `[i]` lookups on a label-indexed
# Series (deprecated positional fallback in pandas).
passed_cutoff = adata.obs['conf_binaries'].astype(bool).to_numpy()
corr_classified_list = np.where(passed_cutoff, 'Correct annotation', 'Erroneous annotation').tolist()
adata.obs['annotation']=corr_classified_list

In [None]:
# One marker dot plot per annotated cell type, split by correct / erroneous annotation.
# Iterate in sorted order: a plain `set` of strings has no stable iteration order
# across kernel restarts, which shuffled the plots between runs.
for celltype in sorted(adata.obs['str_labels'].unique()):
    adata_tmp= adata[adata.obs['str_labels'].isin([celltype])]
    print(celltype)

    ax = sc.pl.dotplot(adata_tmp, marker_genes_dict, groupby='annotation',cmap='BuGn', gene_symbols='alt_names')


In [None]:
# Mirror the per-cell scores and the cutoff flag onto the scaled copy.
for obs_key in ('conf', 'var', 'conf_binaries'):
    adata_scaled.obs[obs_key] = adata.obs[obs_key]

In [None]:
# Marker gene names grouped under the 'CD56 bright' and 'CD56 dim' captions;
# used as the gene groups of the matrix plots below.
cd56_genes_dict= {
    'CD56 bright':['IL7R','CCR7','CD2','SELL','IFNG'], 
    'CD56 dim':['CX3CR1','ITGB2','KIR2DL1','KIR2DL3','KLRG1','PRF1','GZMA','GZMB'],     

                   }


In [None]:
# Cells that passed the cutoff test, restricted to NK / CD8 T / CD4 T cells.
# The subset is copied explicitly because new .obs columns are added to it below
# (writing to a view forces an implicit copy and a warning).
adata_nk_cd8= adata[adata.obs['conf_binaries'].isin([True])]
adata_nk_cd8= adata_nk_cd8[adata_nk_cd8.obs['str_labels'].isin(['NK cells','CD8 T cells','CD4 T cells'])].copy()


adata_nk_cd8.obs['conf_non_ambiguous'] = pd.Categorical((adata_nk_cd8.obs['conf'] > 0.95) | (adata_nk_cd8.obs['var'] < 0.15))

# Group label per cell: T cells keep their own label, every other cell (NK cells)
# is split by the non-ambiguous flag. Computed column-wise instead of with
# integer `[i]` lookups on a label-indexed Series.
nk_cd8_labels = adata_nk_cd8.obs['str_labels'].astype(str).to_numpy()
nk_cd8_clear = adata_nk_cd8.obs['conf_non_ambiguous'].astype(bool).to_numpy()
nk_cd8_state = np.where(nk_cd8_clear, 'NK cells- Correct annotation', 'NK cells- Ambiguous annotation')
corr_classified_list = np.where(
    np.isin(nk_cd8_labels, ['CD8 T cells', 'CD4 T cells']),
    nk_cd8_labels,
    nk_cd8_state,
).tolist()
adata_nk_cd8.obs['amb_annotation']=corr_classified_list

ax = sc.pl.matrixplot(adata_nk_cd8, cd56_genes_dict, groupby='amb_annotation',cmap='BuGn', gene_symbols='alt_names')



In [None]:
# NK cells that passed the cutoff test. Copied explicitly because new .obs
# columns are added below (writing to a view forces an implicit copy and a warning).
adata_nk= adata[adata.obs['conf_binaries'].isin([True])]
adata_nk= adata_nk[adata_nk.obs['str_labels'].isin(['NK cells'])].copy()


adata_nk.obs['conf_non_ambiguous'] = pd.Categorical((adata_nk.obs['conf'] > 0.95) | (adata_nk.obs['var'] < 0.15))

# 'Correct annotation' for non-ambiguous cells, 'Ambiguous annotation' otherwise.
# Computed column-wise instead of with integer `[i]` lookups on a label-indexed Series.
nk_clear = adata_nk.obs['conf_non_ambiguous'].astype(bool).to_numpy()
corr_classified_list = np.where(nk_clear, 'Correct annotation', 'Ambiguous annotation').tolist()
adata_nk.obs['amb_annotation']=corr_classified_list


# NOTE(review): this copy is not used anywhere in the visible notebook (and the
# name looks like a typo of "scaled") — kept in case a later cell relies on it.
adata_nk_sacled = adata_nk.copy()
ax = sc.pl.matrixplot(adata_nk, cd56_genes_dict, groupby='amb_annotation',cmap='BuGn', gene_symbols='alt_names')



In [None]:
# Number of NK cells flagged non-ambiguous (True) versus ambiguous (False).
adata_nk.obs['conf_non_ambiguous'].value_counts()

In [None]:
# Cells that passed the cutoff test, split by annotated label.
bdata= adata[adata.obs['conf_binaries'].isin([True])]
# `adata_nk` gets a new .obs column below, so take a real copy rather than a view.
adata_nk= bdata[bdata.obs['str_labels'].isin(['NK cells'])].copy()
adata_cd4= bdata[bdata.obs['str_labels'].isin(['CD4 T cells'])]
adata_cd8= bdata[bdata.obs['str_labels'].isin(['CD8 T cells'])]
adata_nk.obs['conf_non_ambiguous'] = pd.Categorical((adata_nk.obs['conf'] > 0.95) | (adata_nk.obs['var'] < 0.15))
# NK cells split into the non-ambiguous and the ambiguous group.
adata_nk_cells_non_ambiguous= adata_nk[adata_nk.obs['conf_non_ambiguous'].isin([True])]
adata_nk_cells_ambiguou= adata_nk[adata_nk.obs['conf_non_ambiguous'].isin([False])]


In [None]:
# Mean expression vector (averaged over cells) for each of the four groups.
mean_non_ambiguous = np.mean(adata_nk_cells_non_ambiguous.X , axis=0)
mean_ambiguous = np.mean(adata_nk_cells_ambiguou.X , axis=0)
mean_cd8 = np.mean(adata_cd8.X , axis=0)
mean_cd4 = np.mean(adata_cd4.X , axis=0)

# Norm of the difference between group means, printed in this order:
# NK clear vs NK ambiguous, NK clear vs CD8, NK clear vs CD4,
# CD8 vs NK ambiguous, CD4 vs NK ambiguous.
for left_mean, right_mean in (
    (mean_non_ambiguous, mean_ambiguous),
    (mean_non_ambiguous, mean_cd8),
    (mean_non_ambiguous, mean_cd4),
    (mean_cd8, mean_ambiguous),
    (mean_cd4, mean_ambiguous),
):
    print(np.linalg.norm(left_mean - right_mean))


In [None]:
# Symmetric 4x4 matrix of distances between the group means.
# Row / column order:
#   0 - mean_non_ambiguous
#   1 - mean_ambiguous
#   2 - mean_cd8
#   3 - mean_cd4
group_means = [mean_non_ambiguous, mean_ambiguous, mean_cd8, mean_cd4]

A = np.zeros((4,4))
# Fill each unordered pair once and mirror it; the diagonal stays zero.
for row, col in combinations(range(4), 2):
    A[row, col] = A[col, row] = np.linalg.norm(group_means[row] - group_means[col])


In [None]:
# Annotated heat map of the distance matrix `A`; both axes use the same group
# order as the matrix rows (seaborn is already imported as `sns` at the top).
group_tick_labels = ['NK CA','NK  AA','CD8','CD4']
g = sns.heatmap(A, annot=True,cmap='BuGn')
g.set_xticklabels(group_tick_labels)
g.set_yticklabels(group_tick_labels)
g.set_title('Distance matrix')
plt.show()

In [None]:
# Monocyte labels of interest ('Dendritic Cells' deliberately left out).
celltype_list= ["CD14+ Monocytes", "FCGR3A+ Monocytes"]#,'Dendritic Cells']

# Boolean mask of the cells that passed the cutoff test; reused for both objects.
passed_cutoff_mask = adata.obs['conf_binaries'].isin([True])

adata_monocytes= adata[passed_cutoff_mask]
adata_monocytes= adata_monocytes[adata_monocytes.obs['str_labels'].isin(celltype_list)]

adata_monocytes_scaled= adata_scaled[passed_cutoff_mask]
adata_monocytes_scaled= adata_monocytes_scaled[adata_monocytes_scaled.obs['str_labels'].isin(celltype_list)]

# NOTE(review): despite the column name, True marks cells with conf > 0.95 AND var < 0.15.
high_conf_low_var = (adata_monocytes_scaled.obs['conf'] > 0.95) & (adata_monocytes_scaled.obs['var'] < 0.15)
adata_monocytes_scaled.obs['conf_ambiguous'] = pd.Categorical(high_conf_low_var)


In [None]:

# Make sure the scaled copy carries the per-cell scores.
for score_key in ('conf', 'var'):
    adata_scaled.obs[score_key] = adata.obs[score_key]

# Boolean mask of the cells that passed the cutoff test.
passed_cutoff_mask = adata.obs['conf_binaries'].isin([True])

# CD14+ monocytes (scaled expression) that passed the cutoff test.
adata_cd14_scaled= adata_scaled[passed_cutoff_mask]
adata_cd14_scaled= adata_cd14_scaled[adata_cd14_scaled.obs['str_labels'].isin(["CD14+ Monocytes"])]
cd14_flag = (adata_cd14_scaled.obs['conf'] > 0.95) & (adata_cd14_scaled.obs['var'] < 0.15)
adata_cd14_scaled.obs['conf_ambiguous'] = pd.Categorical(cd14_flag)

# FCGR3A+ monocytes (scaled expression) that passed the cutoff test.
adata_fc_scaled= adata_scaled[passed_cutoff_mask]
adata_fc_scaled= adata_fc_scaled[adata_fc_scaled.obs['str_labels'].isin(["FCGR3A+ Monocytes"])]
fc_flag = (adata_fc_scaled.obs['conf'] > 0.95) & (adata_fc_scaled.obs['var'] < 0.15)
adata_fc_scaled.obs['conf_ambiguous'] = pd.Categorical(fc_flag)


In [None]:
adata_cd14_scaled.obs['conf_non_ambiguous'] = pd.Categorical((adata_cd14_scaled.obs['conf'] > 0.95) | (adata_cd14_scaled.obs['var'] < 0.15))
adata_fc_scaled.obs['conf_non_ambiguous'] = pd.Categorical((adata_fc_scaled.obs['conf'] > 0.95) | (adata_fc_scaled.obs['var'] < 0.15))

# All CD14+ / FCGR3A+ monocytes from `adata` (log-normalised values). Copied
# explicitly because a new .obs column is added to each below.
# NOTE(review): unlike the *_scaled subsets above, these are not restricted to
# cells that passed the cutoff test — confirm that is intended.
adata_cd14= adata[adata.obs['str_labels'].isin(["CD14+ Monocytes"])].copy()
adata_fc= adata[adata.obs['str_labels'].isin(["FCGR3A+ Monocytes"])].copy()

# Cell state from the `conf_non_ambiguous` flag already stored on `adata`.
# Computed column-wise instead of with integer `[i]` lookups on a label-indexed
# Series (deprecated positional fallback in pandas).
cd14_clear = adata_cd14.obs['conf_non_ambiguous'].astype(bool).to_numpy()
annotation_list_cd14 = np.where(
    cd14_clear,
    'Correctly classified (annotated as CD14+)',
    'Intermediate state (annotated as CD14+)',
).tolist()
fc_clear = adata_fc.obs['conf_non_ambiguous'].astype(bool).to_numpy()
annotation_list_fc = np.where(
    fc_clear,
    'Correctly classified (annotated as FCGR3A+)',
    'Intermediate state (annotated as FCGR3A+)',
).tolist()
adata_cd14.obs['cell_state']=annotation_list_cd14
adata_fc.obs['cell_state']=annotation_list_fc

# Stack both monocyte populations into one object and scale it for the matrix plots.
adata_fc_cd14= adata_cd14.concatenate(adata_fc)

sc.pp.scale(adata_fc_cd14)


In [None]:

# Marker gene names grouped under the captions 'Classical', 'Intermediate' and
# 'Nonclassical'; used as gene groups of the monocyte matrix plots.
mono_genes_dict= {
    'Classical':['S100A12','ALOX5AP'], 
    'Intermediate':['GPR35','MARCO'], 
    'Nonclassical':['ICAM4','CD79b'], 
                   }

# Matrix plot of the marker genes per `cell_state` group (RdYlBu colour map).
sc.pl.matrixplot(adata_fc_cd14, mono_genes_dict, groupby='cell_state' ,cmap='RdYlBu', gene_symbols='alt_names')


In [None]:
# Same matrix plot as in the previous cell, drawn with the BrBG colour map.
sc.pl.matrixplot(adata_fc_cd14, mono_genes_dict, groupby='cell_state' ,cmap='BrBG', gene_symbols='alt_names')


In [None]:
import metrics
# Rank genes on the FCGR3A+ monocyte subset. The helper is expected to add the
# `conf_score_high` / `conf_score_low` columns to .var that are read below —
# see metrics.py for how they are computed.
adata_fc = metrics.rank_genes_conf(adata_fc)

# Five genes with the largest `conf_score_high`.
adata_fc.var['conf_score_high'].sort_values(ascending=False).index[:5]

In [None]:
def find_conf_rank(df, gene):
    """Return the 0-based position of `gene` in the index of `df`.

    `df` is any pandas object with an `.index` (the notebook passes a sorted
    Series, so the position is the gene's rank). The first matching label wins;
    `None` is returned when `gene` does not occur in the index.
    """
    matching_positions = (position for position, label in enumerate(df.index) if label == gene)
    return next(matching_positions, None)
# 0-based rank of gene ENSG00000103187 among all genes of the FCGR3A+ subset,
# ordered by descending `conf_score_low`.
find_conf_rank(adata_fc.var['conf_score_low'].sort_values(ascending=False),'ENSG00000103187')

In [None]:
# Ensembl IDs of the monocyte marker genes, in the order
# S100A12, ALOX5AP, GPR35, MARCO, ICAM4, CD79b.
mono_genes_list = ['ENSG00000163221','ENSG00000132965','ENSG00000178623','ENSG00000019169','ENSG00000105371','ENSG00000007312']

# Report each marker's rank in the FCGR3A+ subset, first by `conf_score_low`,
# then (after a "high" header) by `conf_score_high`.
for score_column in ('conf_score_low', 'conf_score_high'):
    if score_column == 'conf_score_high':
        print("high")
        print()
    df = adata_fc.var[score_column].sort_values(ascending=False)
    for gene in mono_genes_list:
        # One value per line: gene ID, display name, rank, separator.
        print(gene, gene_mapping[gene], find_conf_rank(df, gene), "***", sep="\n")


In [None]:
# Same rank report as for the FCGR3A+ subset, now on the CD14+ monocytes.
adata_cd14 = metrics.rank_genes_conf(adata_cd14)

for score_column in ('conf_score_low', 'conf_score_high'):
    if score_column == 'conf_score_high':
        print("high")
        print()
    df = adata_cd14.var[score_column].sort_values(ascending=False)
    for gene in mono_genes_list:
        # One value per line: gene ID, display name, rank, separator.
        print(gene, gene_mapping[gene], find_conf_rank(df, gene), "***", sep="\n")


In [None]:
# Marker gene names grouped under the 'CD56 bright' / 'CD56 dim' captions.
cd56_genes_dict = {
    'CD56 bright': ['IL7R', 'CCR7', 'CD2', 'SELL', 'IFNG'],
    'CD56 dim': ['CX3CR1', 'ITGB2', 'KIR2DL1', 'KIR2DL3', 'KLRG1', 'PRF1', 'GZMA', 'GZMB'],
}

# Ensembl IDs of the genes above, in the same order:
# IL7R, CCR7, CD2, SELL, IFNG, CX3CR1, ITGB2, KIR2DL1, KIR2DL3, KLRG1, PRF1, GZMA, GZMB.
nk_genes_list = [
    'ENSG00000168685', 'ENSG00000126353', 'ENSG00000116824', 'ENSG00000188404', 'ENSG00000111537',
    'ENSG00000168329', 'ENSG00000160255', 'ENSG00000125498', 'ENSG00000243772', 'ENSG00000139187',
    'ENSG00000180644', 'ENSG00000145649', 'ENSG00000100453',
]

adata_nk = metrics.rank_genes_conf(adata_nk)

# Report each marker's rank in the NK subset, first by `conf_score_low`,
# then (after a "high" header) by `conf_score_high`.
for score_column in ('conf_score_low', 'conf_score_high'):
    if score_column == 'conf_score_high':
        print("high")
        print()
    df = adata_nk.var[score_column].sort_values(ascending=False)
    for gene in nk_genes_list:
        # One value per line: gene ID, display name, rank, separator.
        print(gene, gene_mapping[gene], find_conf_rank(df, gene), "***", sep="\n")



In [None]:
# Five genes with the largest `conf_score_low` in the NK subset.
adata_nk.var['conf_score_low'].sort_values(ascending=False)[:5].index

In [None]:
# Second label per cell type: NK <-> CD8 T cells and CD14+ <-> FCGR3A+ monocytes
# are swapped, every other type maps to itself.
intermediate_state_mapping = {
    'B cells': 'B cells',
    'CD4 T cells': 'CD4 T cells',
    'CD8 T cells': 'NK cells',
    'FCGR3A+ Monocytes': 'CD14+ Monocytes',
    'CD14+ Monocytes': 'FCGR3A+ Monocytes',
    'Dendritic Cells': 'Dendritic Cells',
    'NK cells': 'CD8 T cells'
}

# Second label for every cell; labels without a mapping entry are kept unchanged.
# Iterates the column directly instead of using integer `[i]` lookups on a
# label-indexed Series (deprecated positional fallback in pandas).
cell_type_list = [
    intermediate_state_mapping.get(cluster_name, cluster_name)
    for cluster_name in adata.obs['str_labels']
]

# Store the second label and follow the training dynamics for both label sets.
adata.obs['str_labels_2'] = np.array(cell_type_list)
epoch_num=100
prob_list_1, prob_list_2 = models.follow_train_dyn_two_lables(adata, label_one='str_labels', label_two= 'str_labels_2', iterNum=epoch_num, device=device)

In [None]:
# Store the probabilities of each of the 100 epochs as obs columns
# `prob1_<epoch>` (first label set) and `prob2_<epoch>` (second label set).
for epoch in range(100):
    adata.obs[f'prob1_{epoch}'] = prob_list_1[epoch]
    adata.obs[f'prob2_{epoch}'] = prob_list_2[epoch]


In [None]:
# Cells that passed the cutoff test. Copied explicitly because a new .obs column
# is written on the next line (writing to a view forces an implicit copy and a warning).
adata_cd14= adata[adata.obs['conf_binaries'].isin([True])].copy()
adata_cd14.obs['conf_non_ambiguous'] = pd.Categorical((adata_cd14.obs['conf'] > 0.95) | (adata_cd14.obs['var'] < 0.15))

# CD14+ monocytes that passed the cutoff test ...
adata_cd14= adata_cd14[adata_cd14.obs['str_labels'].isin(['CD14+ Monocytes'])]
# ... and CD14+ monocytes that failed it.
adata_cd14_miss= adata[adata.obs['conf_binaries'].isin([False])]
adata_cd14_miss= adata_cd14_miss[adata_cd14_miss.obs['str_labels'].isin(['CD14+ Monocytes'])]
# Passing CD14+ cells split by the non-ambiguous flag.
adata_cd14_inter= adata_cd14[adata_cd14.obs['conf_non_ambiguous'].isin([False])]
adata_cd14_cd14= adata_cd14[adata_cd14.obs['conf_non_ambiguous'].isin([True])]


In [None]:

def _per_epoch(obs, column_prefix, reducer):
    """Apply `reducer` to the obs column `<column_prefix><epoch>` for epochs 0..99."""
    return [reducer(obs[column_prefix + str(epoch)]) for epoch in range(100)]


# Per-epoch mean and standard deviation of the stored probabilities for each group.
# `prob1_*` holds the first label set, `prob2_*` the second (swapped) label set.
prob_cd14_cd14 = _per_epoch(adata_cd14_cd14.obs, 'prob1_', np.mean)
std_cd14_cd14 = _per_epoch(adata_cd14_cd14.obs, 'prob1_', np.std)
prob_cd14_miss = _per_epoch(adata_cd14_miss.obs, 'prob1_', np.mean)
std_cd14_miss = _per_epoch(adata_cd14_miss.obs, 'prob1_', np.std)
prob_inter_cd14 = _per_epoch(adata_cd14_inter.obs, 'prob1_', np.mean)
std_inter_cd14 = _per_epoch(adata_cd14_inter.obs, 'prob1_', np.std)
prob_inter_fc = _per_epoch(adata_cd14_inter.obs, 'prob2_', np.mean)
std_inter_fc = _per_epoch(adata_cd14_inter.obs, 'prob2_', np.std)



In [None]:
# Convert the per-epoch lists to arrays so the mean +/- std bands can be computed.
std_cd14_cd14 = np.array(std_cd14_cd14)
prob_cd14_cd14 = np.array(prob_cd14_cd14)
prob_inter_cd14 = np.array(prob_inter_cd14)
std_inter_cd14 = np.array(std_inter_cd14)
prob_inter_fc = np.array(prob_inter_fc)
std_inter_fc = np.array(std_inter_fc)
prob_cd14_miss = np.array(prob_cd14_miss)
std_cd14_miss = np.array(std_cd14_miss)

# (mean curve, std curve, legend label, line colour) for each group, in drawing order.
probability_curves = [
    (prob_cd14_cd14, std_cd14_cd14, 'Correctly classified to be CD14+', 'b'),
    (prob_inter_cd14, std_inter_cd14, 'Intermediate to be CD14+', 'orange'),
    (prob_inter_fc, std_inter_fc, 'Intermediate to be FCGR3A+', 'g'),
    (prob_cd14_miss, std_cd14_miss, 'IC to be CD14+', 'r'),
]
for mean_curve, std_curve, curve_label, line_color in probability_curves:
    plt.plot(range(100), mean_curve[:100], label=curve_label, color=line_color)
    # Shaded band: one standard deviation around the mean.
    plt.fill_between(range(100), mean_curve - std_curve, mean_curve + std_curve, interpolate=True, alpha=0.2)
plt.xlabel('Epoch')
plt.ylabel('Probability')
plt.legend()
plt.show()



In [None]:
# Same curves as in the previous cell, with short legend labels and a title.
# 'ambiguous*' is the one curve showing the probability of FCGR3A+ instead of CD14+.
probability_curves = [
    (prob_cd14_cd14, std_cd14_cd14, 'correct', 'b'),
    (prob_inter_cd14, std_inter_cd14, 'ambiguous', 'orange'),
    (prob_inter_fc, std_inter_fc, 'ambiguous*', 'g'),
    (prob_cd14_miss, std_cd14_miss, 'erroneous', 'r'),
]
for mean_curve, std_curve, curve_label, line_color in probability_curves:
    plt.plot(range(100), mean_curve[:100], label=curve_label, color=line_color)
    # Shaded band: one standard deviation around the mean.
    plt.fill_between(range(100), mean_curve - std_curve, mean_curve + std_curve, interpolate=True, alpha=0.2)
plt.xlabel('Epoch')
plt.title('Mean probability to be CD14+, in ambiguous* to be FCGR3A+')
plt.ylabel('Mean probability')
plt.legend()
plt.show()

