In [14]:
import anndata as ad
import torch
import torch.nn as nn


## Load the hand annotated data

In [2]:
# Load the hand-annotated dataset (day-7 uninfected HSPCs) as an AnnData object.
adata = ad.read_h5ad('multivelo.h5ad') # hand annotated 7th day uninfected HSPC

In [3]:
# Show the AnnData summary: dimensions plus the names of the obs/var/uns/obsm fields.
adata


AnnData object with n_obs × n_vars = 11605 × 899
    obs: 'n_counts', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'initial_size_spliced', 'initial_size_unspliced', 'initial_size', 'S_score', 'G2M_score', 'phase', 'fractions_u', 'leiden'
    var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'mean', 'std', 'fit_alpha_c', 'fit_alpha', 'fit_beta', 'fit_gamma', 'fit_t_sw1', 'fit_t_sw2', 'fit_t_sw3', 'fit_scale_cc', 'fit_rescale_c', 'fit_rescale_u', 'fit_alignment_scaling', 'fit_c0', 'fit_u0', 'fit_s0', 'fit_model', 'fit_direction', 'fit_loss', 'fit_likelihood', 'fit_anchor_min_idx', 'fit_anchor_max_idx', 'velo_s_genes', 'velo_u_genes', 'velo_chrom_genes'
    uns: 'leiden', 'leiden_colors', 'neighbors', 'pca', 'umap', 'velo_chrom_params', 'velo_s_params', 'velo_u_params'
    obsm: 'X_pca', 'X_

In [4]:
# Peek at the main data matrix (cells x genes).
# NOTE(review): the float32 values include negatives, so the matrix is
# presumably already scaled/centred -- confirm before feeding it to a model.
adata.X

array([[-0.1972677 , -0.22719564, -0.3174438 , ..., -0.07487252,
        -0.1085362 , -0.24923815],
       [-0.17314151, -0.16308507,  0.03774771, ..., -0.0642124 ,
        -0.08726899, -0.1995333 ],
       [-0.30308467, -0.41306373, -0.3432359 , ..., -0.05103816,
        -0.10733148, -0.2633949 ],
       ...,
       [-0.13953881, -0.15732218, -0.04169727, ..., -0.06226784,
        -0.09160008, -0.23593524],
       [-0.16304561, -0.19282901,  0.06897046, ..., -0.0511967 ,
        -0.08410984, -0.22448511],
       [-0.04084875,  0.02475426,  0.35794598, ..., -0.06304198,
        -0.06911539, -0.18171224]], dtype=float32)

This dataset has 11605 cells, with 899 genes measured for each.

### Var Dataframe

In [26]:
# Per-gene annotation table (one row per gene, indexed by gene symbol).
adata.var

Unnamed: 0_level_0,Accession,Chromosome,End,Start,Strand,mt,n_cells_by_counts,mean_counts,pct_dropout_by_counts,total_counts,...,fit_s0,fit_model,fit_direction,fit_loss,fit_likelihood,fit_anchor_min_idx,fit_anchor_max_idx,velo_s_genes,velo_u_genes,velo_chrom_genes
Gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ABAT,ENSG00000183044,16,8784575,8674596,+,False,272,0.020566,98.057698,288.0,...,0.000000e+00,1.0,on,0.000651,0.053900,45.0,482.0,True,True,True
ABCA1,ENSG00000165029,9,104928155,104781006,-,False,245,0.018566,98.250500,260.0,...,5.823158e-05,1.0,complete,0.001142,0.039279,86.0,365.0,False,False,False
ABCC3,ENSG00000108846,17,50692253,50634777,+,False,108,0.010783,99.228792,151.0,...,8.646782e-04,1.0,on,0.000871,0.008601,20.0,482.0,False,False,False
ABCC4,ENSG00000125257,13,95301475,95019835,-,False,6723,1.259212,51.992288,17634.0,...,1.366475e-01,1.0,on,0.173222,0.507818,53.0,484.0,True,True,True
ABHD2,ENSG00000140526,15,89202355,89087459,+,False,6018,0.726078,57.026564,10168.0,...,2.874029e-05,1.0,complete,0.073722,0.097848,67.0,498.0,True,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZNF780B,ENSG00000128000,19,40056231,40028260,-,False,981,0.073693,92.994859,1032.0,...,1.833241e-05,1.0,on,0.002445,0.042543,348.0,487.0,False,False,False
ZNF783,ENSG00000204946,7,149297302,149262171,+,False,542,0.040988,96.129677,574.0,...,0.000000e+00,1.0,on,0.001448,0.065151,215.0,487.0,True,True,True
ZNF785,ENSG00000197162,16,30585769,30573740,-,False,81,0.005784,99.421594,81.0,...,2.132494e-05,1.0,complete,0.000157,0.016822,36.0,242.0,False,False,False
ZNF790-AS1,ENSG00000267254,19,36831596,36797502,+,False,141,0.010426,98.993145,146.0,...,9.586948e-09,1.0,complete,0.000302,0.020806,94.0,354.0,False,False,False


In [27]:
# Short alias for the per-gene annotation table.
# NOTE(review): this binds a reference, not a copy -- edits to var_df
# presumably also change adata.var; confirm if that matters downstream.
var_df = adata.var

In [28]:
# Count distinct gene accession IDs; a result equal to n_vars means no accession is repeated.
var_df['Accession'].nunique()

899

So there are 899 unique gene accessions, one for each of the 899 genes in this dataset.

### Obs DataFrame

In [5]:
# Per-cell annotation table (one row per cell barcode).
adata.obs

Unnamed: 0,n_counts,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt,initial_size_spliced,initial_size_unspliced,initial_size,S_score,G2M_score,phase,fractions_u,leiden
AAACAGCCAAACCTTG-1,4234.334961,1497,3384.0,0.0,0.0,3384,1563,3384.0,-0.634671,-0.215746,G1,0.308031,Granulocyte
AAACAGCCACCCTCAC-1,4246.441406,1673,3251.0,0.0,0.0,3251,1904,3251.0,0.423275,-0.362225,S,0.301643,Erythrocyte
AAACATGCAATCGCAC-1,4228.449219,1652,3272.0,0.0,0.0,3272,1849,3272.0,-0.801623,-0.463818,G1,0.374815,Erythrocyte
AAACATGCACAGCCTG-1,4223.146973,3637,10265.0,0.0,0.0,10265,3769,10265.0,0.408826,0.540197,G2M,0.290607,Prog MK
AAACATGCAGAGGCTA-1,4230.780762,1572,3526.0,0.0,0.0,3526,1792,3526.0,0.400493,-0.121910,S,0.318393,MPP
...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTGTTAAAC-1,4194.219238,1642,3244.0,0.0,0.0,3244,1723,3244.0,0.000000,-0.201697,S,0.340409,HSC
TTTGTTGGTGTTGTGA-1,4216.825195,2306,5480.0,0.0,0.0,5480,2348,5480.0,-0.291098,-0.117548,G1,0.345615,MEP
TTTGTTGGTTACCTGT-1,4240.065430,1664,3310.0,0.0,0.0,3310,1678,3310.0,0.000000,0.155516,G2M,0.304211,MEP
TTTGTTGGTTCCGGGA-1,4253.901855,2193,5188.0,0.0,0.0,5188,1870,5188.0,0.296031,0.044591,S,0.322785,Erythrocyte


In [6]:
# Short alias for the per-cell annotation table.
# NOTE(review): this binds a reference, not a copy -- columns later assigned
# on df presumably also appear in adata.obs; confirm that this is intended.
df = adata.obs

In [7]:
# List the per-cell annotation columns.
df.columns

Index(['n_counts', 'n_genes_by_counts', 'total_counts', 'total_counts_mt',
       'pct_counts_mt', 'initial_size_spliced', 'initial_size_unspliced',
       'initial_size', 'S_score', 'G2M_score', 'phase', 'fractions_u',
       'leiden'],
      dtype='object')

In [8]:
# Distinct hand-annotated cell-type labels held in the 'leiden' column.
df['leiden'].unique().tolist()

['Granulocyte',
 'Erythrocyte',
 'Prog MK',
 'MPP',
 'GMP',
 'HSC',
 'LMPP',
 'MEP',
 'Prog DC',
 'Platelet',
 'Prog B']

In [9]:
# Map each hand-annotated label to a Cell Ontology (CL) identifier.
#
# TODO: double-check at least 'Prog DC' -- it was the label whose matching
# term was least clear.
#
# All of these terms fall under "hematopoietic cell".
# CL:0000553 and CL:0001012 are not in the cell census.

mapping_dict = {'Granulocyte' : 'CL:0000094',
                'Erythrocyte' : 'CL:0000232',
                'Prog MK' : 'CL:0000553',
                'MPP' : 'CL:0000837',
                'GMP' : 'CL:0000557',
                'HSC' : 'CL:0000037',
                'LMPP' : 'CL:0000936',
                'MEP' : 'CL:0000050',
                'Prog DC' : 'CL:0001012',
                'Platelet' : 'CL:0000233',
                'Prog B' : 'CL:0000826'}

df['cell_type'] = df['leiden'].map(mapping_dict)

# .map() silently yields NaN for any label that is missing from mapping_dict,
# so fail loudly here instead of carrying unlabelled cells forward.
unmapped_labels = df.loc[df['cell_type'].isna(), 'leiden'].astype(str).unique().tolist()
assert not unmapped_labels, f"labels missing from mapping_dict: {unmapped_labels}"

In [10]:
# Number of cells assigned to each mapped CL identifier, most frequent first.
df['cell_type'].value_counts()

cell_type
CL:0000936    1816
CL:0000232    1606
CL:0000837    1605
CL:0000037    1529
CL:0000557    1344
CL:0000050    1246
CL:0000553    1240
CL:0000094     661
CL:0001012     492
CL:0000826      36
CL:0000233      30
Name: count, dtype: int64

In [11]:
# Row/column count of the obs table: the 13 original columns plus the new 'cell_type'.
df.shape

(11605, 14)

CL:0000936 (LMPP) and CL:0000233 (Platelet) are the only leaf nodes; together they account for 1846 cells, which is about 16% of the total sample.

## Load the PyTorch Model

In [17]:
# Layer sizes of the pre-trained classifier; they must match the values the
# saved model was trained with.
# NOTE(review): adata above holds 899 genes while the model expects 19922
# input features -- the matrix presumably has to be aligned to the model's
# feature set before inference; confirm how.
input_dim = 19922                           # width of the input feature vector
hidden_layer_1, hidden_layer_2 = 256, 128   # neurons in the two hidden layers
output_dim = 52                             # number of classes (unique values of y)

In [18]:
class Network2(nn.Module):
    """Fully connected classifier with two batch-normalised hidden layers.

    Architecture: (Linear -> BatchNorm1d -> ReLU) twice, then a Linear output
    layer followed by a softmax over the class dimension.

    Parameters
    ----------
    in_features, hidden_1, hidden_2, n_classes : int, optional
        Layer widths. Any value left as ``None`` falls back to the matching
        notebook-level constant (``input_dim``, ``hidden_layer_1``,
        ``hidden_layer_2``, ``output_dim``), so ``Network2()`` builds exactly
        the same network as before.

    Notes
    -----
    The attribute names (``linear1`` ... ``bn2``) are kept unchanged because a
    saved model / state_dict refers to the layers by these names.
    """

    def __init__(self, in_features=None, hidden_1=None, hidden_2=None, n_classes=None):
        super().__init__()
        # Resolve the defaults lazily so the globals are only needed when the
        # caller does not pass explicit sizes.
        in_features = input_dim if in_features is None else in_features
        hidden_1 = hidden_layer_1 if hidden_1 is None else hidden_1
        hidden_2 = hidden_layer_2 if hidden_2 is None else hidden_2
        n_classes = output_dim if n_classes is None else n_classes

        self.linear1 = nn.Linear(in_features, hidden_1)
        self.linear2 = nn.Linear(hidden_1, hidden_2)
        self.linear3 = nn.Linear(hidden_2, n_classes)
        self.bn1 = nn.BatchNorm1d(hidden_1)
        self.bn2 = nn.BatchNorm1d(hidden_2)

    def forward(self, x):
        """Return class probabilities for a 2-D batch ``x``.

        ``x`` is expected to be ``(batch, in_features)``; the result is
        ``(batch, n_classes)`` with each row summing to 1 (softmax on dim 1).
        """
        # torch.relu / torch.softmax are used instead of F.relu / F.softmax:
        # the notebook never imports torch.nn.functional as F, so the original
        # forward pass raised NameError.
        x = torch.relu(self.bn1(self.linear1(x)))
        x = torch.relu(self.bn2(self.linear2(x)))
        x = self.linear3(x)
        return torch.softmax(x, dim=1)

In [25]:
# Load the trained classifier. `Network2` must already be defined in this
# kernel: a fully pickled model is rebuilt from that class when it is loaded.
#
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source. On recent PyTorch
# releases torch.load defaults to weights_only=True, in which case loading a
# fully pickled model additionally needs weights_only=False.
#
# map_location lets a model trained on GPU be loaded on a CPU-only machine.
checkpoint = torch.load('30_Oct_best_model',map_location=torch.device('cpu'))

if isinstance(checkpoint, nn.Module):
    # The file holds the whole pickled model; use it directly.
    model = checkpoint
else:
    # The file holds a state_dict: build the architecture, then load the weights.
    model = Network2()
    model.load_state_dict(checkpoint)

# Switch to eval mode so BatchNorm uses its running statistics instead of
# per-batch statistics during inference.
model = model.eval()
