In [1]:
from scsims import SIMS 
import scanpy as sc
import anndata as an
from torch.utils.data import DataLoader
from scsims.testing import TestAnndatasetMatrix
import torch
from tqdm import tqdm
from scsims import clean_sample 
import numpy as np
import pandas as pd 
from scsims import SIMSClassifier
from scsims.data import AnnDatasetMatrix
import plotly.express as px 
from scsims import DataModule
import os 
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl 

class UploadCallback(pl.callbacks.Callback):
    def __init__(
        self, 
        path: str, 
        desc: str, 
        upload_path='model_checkpoints',
        epochs: int=1,
    ) -> None:
        super().__init__()
        self.path = path 
        self.desc = desc
        self.upload_path = upload_path
        self.epochs = epochs

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch

        if epoch % self.epochs == 0 and epoch > 0: # Save every ten epochs
            checkpoint = f'checkpoint-{epoch}-desc-{self.desc}.ckpt'
            trainer.save_checkpoint(os.path.join(self.path, checkpoint))
            print(f'Saving checkpoint at epoch {epoch}')

In [2]:
labels = pd.read_csv('../data/cd4/Atlas_Annotation_CD4.csv')

In [4]:
labels['categorical_Atlas Annotation'].value_counts()

Tcm         1867
Tem         1712
trTregs     1534
Trm_        1428
Teff        1320
CD4 RPL     1151
eTregs      1112
cTregs       734
Th1 CTL      655
T prolif     511
Tfh          393
Name: categorical_Atlas Annotation, dtype: int64

In [3]:
data = an.read_h5ad('../cd4_atlas/GSE99254_108989_96838_filtered_QC.h5ad')

In [None]:
Data_merge_raw = data.copy()
Data_normalized = Data_merge_raw.copy()
sc.pp.normalize_total(Data_normalized)
Data_log1p = Data_normalized.copy()
sc.pp.log1p(Data_log1p)

In [None]:
Data_log1p.write_h5ad('../cd4_atlas/GSE99254_108989_96838_filtered_QC_LOG_NORM.h5ad')

In [6]:
import plotly.express as px 
sample = (np.asarray(data.X.todense()[0])[0])


In [7]:
sample[np.where(np.array(sample) > 0)[0]]

array([  7., 215.,  11., ..., 126.,  73., 392.], dtype=float32)

In [8]:
# px.histogram(sample[np.where(np.array(sample) > 0)[0]])

In [9]:
# px.histogram(np.log(np.array(sample[np.where(np.array(sample) > 0)[0]]) + 1))

In [10]:
from scsims.data import AnnDatasetFile, AnnDatasetMatrix
from torch.utils.data import DataLoader
from scsims.data import CollateLoader 

to_explain = AnnDatasetMatrix(
    matrix=data.X,
    labels=labels['Atlas Annotation'],
)

to_explain = CollateLoader(to_explain, batch_size=4, num_workers=0)

In [56]:
next(iter(to_explain))[0]

tensor([[  0.,   0.,   0.,  ...,  73., 392.,   0.],
        [  0.,   0.,   0.,  ...,   0.,   0.,  57.],
        [  0.,   0.,   0.,  ...,   0., 195.,   0.],
        [ 11.,   0.,   0.,  ...,   0.,   0.,   8.]])

In [4]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le = le.fit(labels['categorical_Atlas Annotation'])

module = DataModule(
    datafiles=['../cd4_atlas/GSE99254_108989_96838_filtered_QC_LOG_NORM.h5ad'],
    labelfiles=['../data/cd4/Atlas_Annotation_CD4.csv'],
    class_label='Atlas Annotation',
    batch_size=16,
    num_workers=0,
    deterministic=True,
    normalize=True,
)

module.prepare_data()
module.setup()

Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights


In [5]:
wandb_logger = WandbLogger(
    project=f"CD4 Atlas",
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='epoch')

upload_callback = UploadCallback(
    path='checkpoints',
    desc=f'local_cd4_model_8_1_big'
)

early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=20,
)

trainer = pl.Trainer(
    gpus=(1 if torch.cuda.is_available() else 0),
    auto_lr_find=False,
    logger=wandb_logger,
    max_epochs=500,
    gradient_clip_val=0.5,
    callbacks=[
        lr_callback, 
        upload_callback,
        early_stopping_callback,
    ]
)


model = SIMSClassifier(
    input_dim=module.num_features,
    output_dim=module.num_labels,
    weights=module.weights,
    n_d=1000,
    n_a=1000,
    n_steps=4,
)

trainer.fit(model, datamodule=module)
trainer.test(model, datamodule=module)

[34m[1mwandb[0m: Currently logged in as: [33mjlehrer1[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Initializing network
Initializing explain matrix
Creating train/val/test DataLoaders...



  | Name    | Type   | Params
-----------------------------------
0 | network | TabNet | 245 M 
-----------------------------------
245 M     Trainable params
0         Non-trainable params
245 M     Total params
981.754   Total estimated model params size (MB)


Done, continuing to training.
Calculating weights


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  precision = tp / (tp + fp)
  recall = tp / (tp + fn)
  f1s = 2*(precision * recall) / (precision + recall)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [5]:
model = SIMSClassifier.load_from_checkpoint(
    'checkpoints/checkpoint-29-desc-local_cd4_model_7_22_22_big.ckpt',
    input_dim=module.num_features,
    output_dim=module.num_labels,
    weights=module.weights,
    n_d=500,
    n_a=500,
    n_steps=4,
)

Initializing network
Initializing explain matrix


In [10]:
explain = model.explain(module.testloader)

100%|██████████| 125/125 [01:38<00:00,  1.27it/s]


In [6]:
preds = model.predict(module.testloader)

100%|█████████| 125/125 [00:00<00:00, 426.47it/s]


ValueError: Length of values (1987) does not match length of index (3)

In [11]:
preds = []
labels = []
for X in tqdm(module.testloader):
    # Some dataloaders will have labels, handle this case 
    if len(X) == 2:
        data, label = X 
        labels.extend(label.numpy())
    else:
        data = X 

    res, _ = model(data)
    res = np.argmax(res.detach(), axis=1)
    preds.extend(res.numpy())

final = pd.DataFrame()
final['predicted_label'] = preds

if labels != []: final['actual_label'] = labels 


100%|██████████| 125/125 [01:02<00:00,  2.02it/s]


In [13]:
final

Unnamed: 0,predicted_label,actual_label
0,7,7
1,10,9
2,3,3
3,0,0
4,8,8
...,...,...
1982,10,10
1983,8,2
1984,10,10
1985,3,0


In [25]:
from sklearn.metrics import accuracy_score

accuracy_score(final['predicted_label'], final['actual_label'])

0.7096124811273277

In [48]:
from anndata import AnnData

indices = module.testloader.dataset.split 

obj = AnnData(
    X=explain[0],
)

obj.obs['preds'] = le.inverse_transform(total)
obj.obs['actual'] = le.inverse_transform(explain[1])
obj.obs['indices'] = indices
obj.obs['barcode'] = labels.loc[:, 'Barcode'].iloc[indices.values].values

  obj = AnnData(


In [50]:
obj.write_h5ad('test_explain_75_acc_model.h5ad')

In [45]:
labels.loc[:, 'Barcode'].iloc[indices.values]

5007          TTY64.0913-1
11288    TTR188.20180123-2
2184         TTH147.0508-0
935          TTH115.0508-0
797         PTS3.72.0508-0
               ...        
10897     TTY50.20161012-2
3235          PTS45.0617-1
11791    TTR157.20170825-2
9968      TTH48.20180123-2
9087      NTR53.20180123-2
Name: Barcode, Length: 1987, dtype: object

In [93]:
explain.write('explain_matrix_full_dataset.h5ad')

In [11]:
np.savetxt("../train_explain_matrix_88_accurate_7_14_22_FULL.csv", matrix[0], delimiter=",")

In [29]:
pd.Series(le.inverse_transform(matrix[1])).to_csv('../train_labels_88_accurate_7_14_22.csv', index=False)

pd.read_csv('../train_labels_88_accurate_7_14_22.csv')

Unnamed: 0,0
0,cTregs
1,cTregs
2,Teff
3,CD4 RPL
4,Trm_
...,...
7941,Teff
7942,Teff
7943,trTregs
7944,CD4 RPL


In [64]:
matrix[0].tofile('test.npy')

In [24]:
np.load('test.npy')

TypeError: Mismatch between array dtype ('object') and format specifier ('%.18e')

In [7]:
data = an.read_h5ad('../data/cd4/cd4_train.h5ad')[0:1000]
test = an.read_h5ad('../data/cd4/test.h5ad')[::10]

In [4]:
refgenes = list(set(data.var.index).intersection(test.var.index))
currgenes = data.var.index

indices = np.intersect1d(currgenes, refgenes, return_indices=True)[1]

In [6]:
model = SIMSClassifier.load_from_checkpoint(
    '../checkpoints/checkpoint-20-desc-cd4_atlas_intersection_None.ckpt',
    input_dim=len(refgenes),
    output_dim=cd4_model.datamodule.output_dim,
)

Initializing network
Initializing explain matrix


In [14]:
from scsims.data import CollateLoader 
from scsims.data import AnnDatasetMatrix 

traindataset = AnnDatasetMatrix(
    data.X[0:1000], 
    cd4_model.label_encoder.transform(data.obs['cluster'][0:1000].values),
)

trainloader = CollateLoader(
    dataset=traindataset, 
    batch_size=4, 
    num_workers=0, 
    refgenes=refgenes, 
    currgenes=currgenes,
)

In [16]:
explain_mtx = model.explain(trainloader, normalize=True)


  0%|                                                                                                                                                               | 0/250 [00:00<?, ?it/s][A
  0%|▌                                                                                                                                                      | 1/250 [00:00<00:26,  9.23it/s][A
  2%|██▍                                                                                                                                                    | 4/250 [00:00<00:13, 18.38it/s][A
  3%|████▏                                                                                                                                                  | 7/250 [00:00<00:11, 20.87it/s][A
  4%|██████                                                                                                                                                | 10/250 [00:00<00:10, 21.97it/s][A
  5%|███████▊                          

 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                            | 176/250 [00:12<00:07,  9.60it/s][A
 71%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                           | 177/250 [00:13<00:07,  9.47it/s][A
 71%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                           | 178/250 [00:13<00:07,  9.29it/s][A
 72%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                          | 179/250 [00:13<00:07,  9.37it/s][A
 72%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                         | 180/250 [00:13<00:07,  9.26it/s][A
 72%|███████████████████████████████████

In [27]:
agg = pd.DataFrame(explain_mtx[0], columns=currgenes[indices])
agg['CellType'] = data.obs['cluster'][0:1000].values
agg = agg.groupby('CellType').sum()

agg.head(5)

Unnamed: 0_level_0,A1BG,A2M,AAAS,AACS,AAGAB,AAK1,AAMDC,AAMP,AAR2,AARS,...,ZSWIM8,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
CellType,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CD4_C1-Naive,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CD4_C2-Tcm,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CD4_C3-Tem,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CD4_C4-CD69,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CD4_C5-ISG15,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [29]:
agg.T

CellType,CD4_C1-Naive,CD4_C2-Tcm,CD4_C3-Tem,CD4_C4-CD69,CD4_C5-ISG15,CD4_C6-RPL,CD4_C7-Th1-like,CD4_C8-Treg,CD4_C9-Prolif.,CD4_CD4_CD4_XCL1
A1BG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
A2M,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAAS,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AACS,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAGAB,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...
ZXDB,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ZXDC,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ZYG11B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ZYX,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [35]:
aggT = agg.T

for idx in aggT.columns:
    mask = aggT[idx]
    mask = mask.sort_values(ascending=False)
    px.bar(
        mask[0:50]
    ).show()

A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C1-Naive, Length: 10381, dtype: float64
B2M         11.836211
ACTB         9.703895
FTH1         8.768082
HLA-DRB1     7.891657
RPL14        6.532100
              ...    
GRAMD1A      0.000000
GRAMD1C      0.000000
GRAMD4       0.000000
GRAP         0.000000
ZZEF1        0.000000
Name: CD4_C1-Naive, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C2-Tcm, Length: 10381, dtype: float64
ACTB       11.520334
FTH1       10.204775
TPT1        7.526892
TCP1        7.198909
S100A13     6.887030
             ...    
GPX7        0.000000
GRAMD1A     0.000000
GRAMD1B     0.000000
GRAMD1C     0.000000
ZZEF1       0.000000
Name: CD4_C2-Tcm, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C3-Tem, Length: 10381, dtype: float64
ACTB        11.016292
HLA-DRB1     6.472397
IL7R         6.240834
IFITM1       5.679183
B2M          5.468706
              ...    
GRAP2        0.000000
GRASP        0.000000
GRB2         0.000000
GRHPR        0.000000
ZZEF1        0.000000
Name: CD4_C3-Tem, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C4-CD69, Length: 10381, dtype: float64
FOS       39.059869
RNF167    14.280225
TROVE2     9.483388
ADCK2      7.742420
ACTB       5.937043
            ...    
GPR65      0.000000
GPR68      0.000000
GPR82      0.000000
GPR89A     0.000000
ZZEF1      0.000000
Name: CD4_C4-CD69, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C5-ISG15, Length: 10381, dtype: float64
PLAG1     7.300617
XAF1      7.115205
ADCK2     3.862940
STAT1     3.570174
IFI44     2.784242
            ...   
GPR160    0.000000
GPR171    0.000000
GPR174    0.000000
GPR18     0.000000
ZZEF1     0.000000
Name: CD4_C5-ISG15, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C6-RPL, Length: 10381, dtype: float64
NDUFS7    3.870878
RAB1B     3.293878
B2M       2.685536
RPL14     2.512998
EPSTI1    2.234386
            ...   
GPR18     0.000000
GPR180    0.000000
GPR183    0.000000
GPR19     0.000000
ZZEF1     0.000000
Name: CD4_C6-RPL, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C7-Th1-like, Length: 10381, dtype: float64
RPS6KA3    26.215332
IFI44      13.617716
UBE2B      11.094547
JAKMIP1     8.402340
TPT1        3.842732
             ...    
GPR137B     0.000000
GPR155      0.000000
GPR157      0.000000
GPR160      0.000000
ZZEF1       0.000000
Name: CD4_C7-Th1-like, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C8-Treg, Length: 10381, dtype: float64
BUB3       87.758435
THADA      67.517044
ADCK2      42.727594
PTPMT1     29.770023
IFI44      29.765173
             ...    
GRAMD1A     0.000000
GRAMD1C     0.000000
GRAMD4      0.000000
GRAP        0.000000
ZZEF1       0.000000
Name: CD4_C8-Treg, Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_C9-Prolif., Length: 10381, dtype: float64
HMGB2     1.858178
ACTB      1.730955
EEF1A1    0.714988
HLA-A     0.630918
RAB1B     0.615410
            ...   
GPR157    0.000000
GPR160    0.000000
GPR171    0.000000
GPR174    0.000000
ZZEF1     0.000000
Name: CD4_C9-Prolif., Length: 10381, dtype: float64


A1BG      0.0
A2M       0.0
AAAS      0.0
AACS      0.0
AAGAB     0.0
         ... 
ZXDB      0.0
ZXDC      0.0
ZYG11B    0.0
ZYX       0.0
ZZEF1     0.0
Name: CD4_CD4_CD4_XCL1, Length: 10381, dtype: float64
ARL4A      2.781468
RXRB       0.446710
ACTB       0.442732
FXYD5      0.303230
TPT1       0.230515
             ...   
GPR132     0.000000
GPR137     0.000000
GPR137B    0.000000
GPR155     0.000000
ZZEF1      0.000000
Name: CD4_CD4_CD4_XCL1, Length: 10381, dtype: float64
