In [3]:
import pandas as pd 
import scsims
from sklearn.preprocessing import LabelEncoder
import sys

sys.path.append('../src')
sys.path.append('../tests')
from models.lib.lightning_train import DataModule, generate_trainer
from models.lib.neural import GeneClassifier
from models.lib.data import *
from models.lib.neural import *
from pytorch_lightning.loggers import WandbLogger

Let's do some simple preprocessing for training, and then test our library

In [6]:
le = LabelEncoder()

labels = pd.read_csv('../data/retina/retina_labels.tsv', sep='\t')
labels = labels[labels['CellType'].isin(['retinal rod cell type B', 'retinal rod cell type A'])]
labels = labels[labels['CellType'] != 'unannotated']
labels['class_label'] = le.fit_transform(labels['CellType'])

labels.index.name = 'cell'
# labels = labels.iloc[0:5000, :]
labels.to_csv('../data/retina/retina_labels_numeric.csv')

# label_df = pd.read_csv('../data/retina/retina_labels_numeric.csv', index_col='cell')
# # label_df

# labels.to_csv('../data/retina/retina_labels_numeric.csv')

labels = pd.read_csv('../data/retina/retina_labels_numeric.csv')
labels

Unnamed: 0,cell,index,Expressed Genes,UMI Count,Percent Mitochond.,BroadCellType,CellType,Donor,class_label
0,0,0024369980fd003553cbc9dfe29f7f95,2351,6060.0,4.125413,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,0
1,2,0037f1b36684cb59b84d3585ca55ff69,950,1507.0,0.597213,retinal rod cell,retinal rod cell type B,427c0a62-9baf-42ab-a3a3-f48d10544280,1
2,3,00390952646f52d11a9ab9bba7d6ac51,961,1962.0,7.543323,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,0
3,4,005b3351658380695a5dc46c384d72d7,858,1384.0,0.144509,retinal rod cell,retinal rod cell type B,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,1
4,6,00906832f470fc434a52ac7d678a95bc,532,1054.0,6.451613,retinal rod cell,retinal rod cell type A,427c0a62-9baf-42ab-a3a3-f48d10544280,0
...,...,...,...,...,...,...,...,...,...
12061,19683,ff53554e8720a2302874fcbd21c7b0ed,1910,4604.0,2.845352,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,0
12062,19684,ff61c0c282f41e4a37885b05342441da,1494,3320.0,2.289157,retinal rod cell,retinal rod cell type B,427c0a62-9baf-42ab-a3a3-f48d10544280,1
12063,19689,ffa4633bef82949d2c6ac17b3ddf46e9,1813,4129.0,3.996125,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,0
12064,19691,ffd3fd6119de767f3c3b8c47b2c28bf0,516,896.0,3.794643,retinal rod cell,retinal rod cell type A,427c0a62-9baf-42ab-a3a3-f48d10544280,0


In [7]:
pd.read_csv('../data/mouse/MouseAdultInhibitoryNeurons_labels.csv')

Unnamed: 0,class,numeric_class
0,S-phase_MCM4/H43C,36
1,S-phase_MCM4/H43C,36
2,Ctx_LHX6/SST,9
3,Str_LHX8/CHAT,40
4,Str_LHX8/CHAT,40
...,...,...
141064,S-phase_MCM4/H43C,36
141065,Transition,41
141066,Transition,41
141067,S-phase_MCM4/H43C,36


In [3]:
trainer, _, module = generate_trainer(
    datafiles=['../data/retina/retina_T.csv'],
    labelfiles=['../data/retina/retina_labels_numeric.csv'],
    class_label='class_label',
    index_col='cell',
    batch_size=16,
    num_workers=0,
    skip=3,
    shuffle=True,
    drop_last=True,
    weighted_metrics=False,
    normalize=True,
    weights=total_class_weights(['../data/retina/retina_labels_numeric.csv'], 'class_label'),
    wandb_name='local-retina-model',
    optim_params={
        'optimizer': torch.optim.Adam,
        'lr': 0.2,
        'weight_decay': 0,
    },
    scheduler_params={
        'scheduler': torch.optim.lr_scheduler.StepLR,
        'step_size': 1e-5,
    },
    max_epochs=100,
)

Device is cpu
../data/retina/retina_T.csv exists, continuing...
../data/retina/retina_labels_numeric.csv exists, continuing...



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Model initialized. input_dim = 37475, output_dim = 14. Metrics are dict_keys(['accuracy', 'precision', 'recall']) and weighted_metrics = False


In [8]:

module = DataModule(
    datafiles=['../data/retina/retina_T.csv'],
    labelfiles=['../data/retina/retina_labels_numeric.csv'],
    class_label='class_label',
    index_col='cell',
    batch_size=16,
    num_workers=0,
    skip=3,
    shuffle=True,
    drop_last=True,
    normalize=True,
)

tabnetmodel = TabNetLightning(
    input_dim=module.num_features,
    output_dim=module.num_labels,
    optim_params={
        'optimizer': torch.optim.Adam,
        'lr': 0.2,
        'weight_decay': 0,
    },
    scheduler_params={
        'scheduler': torch.optim.lr_scheduler.StepLR,
        'step_size': 1e-5,
    },
)

wandb_logger = WandbLogger(
    project=f"tabnet-classifer-sweep",
    name='local-retina-model'
)

trainer = pl.Trainer(
    gpus=(1 if torch.cuda.is_available() else 0),
    auto_lr_find=False,
    logger=wandb_logger,
    max_epochs=100,
)


Initializing network
Initializing explain matrix


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(tabnetmodel, datamodule=module)

  rank_zero_warn(

  | Name    | Type   | Params
-----------------------------------
0 | network | TabNet | 2.5 M 
-----------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
9.916     Total estimated model params size (MB)


Creating train/val/test DataLoaders...
Done, continuing to training.
Got here


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
wandb: ERROR Error while calling W&B API: Error 1040: Too many connections (<Response [500]>)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [6]:
trainloader.dataset[0:50]

NameError: name 'trainloader' is not defined

In [None]:
def _test_first_n_samples(n, datafile, labelfile):
    data = GeneExpressionData(
        datafile, 
        labelfile, 
        'class_label', 
        skip=3,
        index_col='cell'
    )
    cols = data.columns
    
    # Generate dict with half precision values to read this into my 16gb memory
    dtype_cols = dict(zip(cols, [np.float32]*len(cols)))
    
    data_df = pd.read_csv(datafile, nrows=2*n, header=1, dtype=dtype_cols) # Might need some extras since numerical index drops some values
    label_df = pd.read_csv(labelfile, nrows=n)

    similar = []
    for i in range(n):
        datasample = data[i][0]

        dfsample = torch.from_numpy(data_df.loc[label_df.loc[i, 'cell'], :].values).float()
        isclose = all(torch.isclose(datasample, dfsample))
        similar.append(isclose)
    
    print(f"First {n=} columns of expression matrix is equal to GeneExpressionData: {all(p for p in similar)}")

    assert (all(p for p in similar))

_test_first_n_samples(100, '../data/retina/retina_T.csv', '../data/retina/retina_labels_numeric.csv')


In [None]:
idk = module.trainloader.dataset._labeldf['cell']
idk