In [1]:
import scanpy as sc
import anndata as an
import os
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder 
import torch
from sys import getsizeof
import json 

import os, sys
sys.path.append('../src')
from models.lib.data import *
from models.lib.lightning_train import DataModule, generate_trainer
# Input paths for the mouse atlas. Both lists are handed to generate_trainer
# further down as its `datafiles` / `labelfiles` arguments.
datafiles = [
    '../data/mouse/MouseAdultInhibitoryNeurons.h5ad',
]
labelfiles = [
    '../data/mouse/Adult Inhibitory Neurons in Mouse_labels.tsv',
]

In [2]:
# Load the mouse atlas from its .h5ad file; the next cell reads the `class`
# column of `mouse_atlas.obs` to build the label table.
mouse_atlas = sc.read_h5ad('../data/mouse/MouseAdultInhibitoryNeurons.h5ad')

In [13]:
# Build the label table: one row per cell (indexed as in `mouse_atlas.obs`),
# holding the original `class` string and an integer code for it.
le = LabelEncoder()
atlas_labels = mouse_atlas.obs['class'].to_frame()
# Each distinct class string is mapped to one integer in `numeric_class`.
atlas_labels['numeric_class'] = le.fit_transform(atlas_labels['class'])
atlas_labels

Unnamed: 0,class,numeric_class
AAAAAAAAAAAAAA_p25-27_Amygdala_SAMN08730984,S-phase_MCM4/H43C,36
GAAACCCAATCTCG_p25-27_Amygdala_SAMN08730984,S-phase_MCM4/H43C,36
GACCAAACTGCCTC_p25-27_Amygdala_SAMN08730984,Ctx_LHX6/SST,9
GCGTAATGGACGGA_p25-27_Amygdala_SAMN08730984,Str_LHX8/CHAT,40
CTTGAGGACAGAAA_p25-27_Amygdala_SAMN08730984,Str_LHX8/CHAT,40
...,...,...
ACTACGGACAGAAA_e12.0_ForebrainVentral_SRR11947650_e12.0_ForebrainVentral_SRR11947650,S-phase_MCM4/H43C,36
GGGACCTGGCGAAG_e12.0_ForebrainVentral_SRR11947650_e12.0_ForebrainVentral_SRR11947650,Transition,41
CAACAGACTGGTTG_e12.0_ForebrainVentral_SRR11947650_e12.0_ForebrainVentral_SRR11947650,Transition,41
AAAGATCTCTTCCG_e12.0_ForebrainVentral_SRR11947650_e12.0_ForebrainVentral_SRR11947650,S-phase_MCM4/H43C,36


In [15]:
# Persist the label table. index=False leaves the cell-barcode index out of
# the file, so only the `class` and `numeric_class` columns are written.
# NOTE(review): this CSV path differs from the .tsv path stored in
# `labelfiles` and used for training below -- confirm which file is intended.
atlas_labels.to_csv('../data/mouse/MouseAdultInhibitoryNeurons_labels.csv', index=False)

# Read the file straight back to inspect what was written; the rows now carry
# a default integer index instead of the barcodes.
pd.read_csv('../data/mouse/MouseAdultInhibitoryNeurons_labels.csv')

Unnamed: 0,class,numeric_class
0,S-phase_MCM4/H43C,36
1,S-phase_MCM4/H43C,36
2,Ctx_LHX6/SST,9
3,Str_LHX8/CHAT,40
4,Str_LHX8/CHAT,40
...,...,...
141064,S-phase_MCM4/H43C,36
141065,Transition,41
141066,Transition,41
141067,S-phase_MCM4/H43C,36


In [2]:
# Reference gene list shared by the atlas and the Mo dataset, cached on disk
# in mousegenes.txt (one gene per line) so both .h5ad files only have to be
# opened when the cache is missing.
if os.path.isfile('mousegenes.txt'):
    # Cache hit: reuse the stored list. Note that mouse_atlas, mo_data, g1
    # and g2 are NOT assigned on this path.
    with open('mousegenes.txt', 'r') as f:
        refgenes = f.read().splitlines()
else:
    # Cache miss: load both datasets and take their variable (gene) names.
    mouse_atlas = sc.read_h5ad('../data/mouse/MouseAdultInhibitoryNeurons.h5ad')
    mo_data = an.read_h5ad('../data/mouse/Mo_PV_paper_TDTomato_mouseonly.h5ad')

    g1 = mo_data.var.index.values
    g2 = mouse_atlas.var.index.values

    # Genes present in both datasets, in sorted order.
    refgenes = sorted(set(g1) & set(g2))

    with open('mousegenes.txt', 'w') as f:
        for gene in refgenes:
            f.write(f'{gene}\n')

In [3]:
# Number of reference genes; the recorded training run below reports the same
# value as the model's input_dim.
len(refgenes)

25163

# Mouse model training with generate_trainer test

In [4]:
# Show the data file list that is passed to generate_trainer in the next cell.
datafiles

['../data/mouse/MouseAdultInhibitoryNeurons.h5ad']

In [5]:
# Build the trainer, model and data module in one call, then fit.
# NOTE(review): the output recorded for this cell ends in
#   ValueError: The least populated class in y has only 1 member ...
# The class counts printed just before the error list classes 40, 25 and 10
# with a single row each. Presumably a stratified split inside
# generate_trainer cannot handle singleton classes in the `subset` rows --
# confirm against the library code before re-running.
trainer, model, module = generate_trainer(
    datafiles=datafiles,
    labelfiles=labelfiles,
    class_label='numeric_class',
    drop_last=True,
    shuffle=True,
    batch_size=4,
    num_workers=0,
    refgenes=refgenes,
    weighted_metrics=True,
    optim_params={
        'optimizer': torch.optim.SGD,
        'lr': 3e-4,
    },
    sep='\t',
    max_epochs=100,
    collocate=False,
    # Every 100th integer below 100000, i.e. 1000 indices.
    subset=list(range(0, 100000, 100))
)

trainer.fit(model, datamodule=module)

Device is cpu
../data/mouse/MouseAdultInhibitoryNeurons.h5ad exists, continuing...
../data/mouse/Adult Inhibitory Neurons in Mouse_labels.tsv exists, continuing...



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Model initialized. input_dim = 25163, output_dim = 50. Metrics are dict_keys(['accuracy', 'precision', 'recall']) and weighted_metrics = True
Creating train/val/test DataLoaders...
0     75
9     73
27    73
5     65
36    60
8     46
31    44
38    37
39    35
16    34
6     32
20    29
19    28
11    28
2     25
12    24
44    23
17    22
22    21
43    20
48    16
4     16
41    14
42    14
1     14
18    11
15    11
14    11
45    10
3      9
46     8
34     8
28     7
24     7
33     7
35     7
47     6
29     6
37     6
7      4
23     4
13     3
21     2
26     2
40     1
25     1
10     1
Name: numeric_class, dtype: int64


ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

# Mouse model training code using DataModule from lib.lightning_train

In [None]:
# labelfile = '../data/mouse/Adult Inhibitory Neurons in Mouse_labels.tsv'
# label_df = pd.read_csv(labelfile, sep='\t')


# current_labels = label_df.loc[:, 'numeric_class']

In [None]:
# Test-size fraction for the train/val/test splits sketched in the
# commented-out lines below; no live code in this cell reads it.
test_prop=0.2
# trainsplit, valsplit = train_test_split(current_labels, stratify=current_labels, test_size=test_prop)
# trainsplit, testsplit = train_test_split(trainsplit, stratify=trainsplit, test_size=test_prop)


# trainsplit.index

In [None]:
# Construct the DataModule directly (instead of through generate_trainer)
# for the same atlas data/label files. This rebinds `module`.
# NOTE(review): `g2` is only assigned in the cache-miss branch of the cell
# that builds `refgenes`; when mousegenes.txt already exists, a fresh kernel
# reaches this cell with `g2` undefined and raises NameError.
module = DataModule(
    datafiles=['../data/mouse/MouseAdultInhibitoryNeurons.h5ad'],
    labelfiles=['../data/mouse/Adult Inhibitory Neurons in Mouse_labels.tsv'],
    class_label='numeric_class',
    sep='\t',
    collocate=False,
    batch_size=4,
    num_workers=0,
    refgenes=refgenes,
    currgenes=g2,
)

module

In [None]:
# Run the DataModule's setup step before a dataloader is requested from it.
module.setup()

In [None]:
# Fetch the training dataloader from the module.
train = module.train_dataloader()

# Pull one batch from it and print it as a quick sanity check.
print(next(iter(train)))

In [None]:
# Hand-built loaders: shuffled for training, unshuffled for validation.
# NOTE(review): `atlas_train` and `atlas_val` are not assigned in any cell
# visible in this notebook, and `DataLoader` is not imported by name
# (presumably it arrives via `from models.lib.data import *` -- confirm).
# This also rebinds `train` from the previous cell.
train = DataLoader(atlas_train, batch_size=4, num_workers=0, drop_last=True, shuffle=True)
val = DataLoader(atlas_val, batch_size=4, num_workers=0, drop_last=True, shuffle=False)

In [None]:
# from models.lib.neural import GeneClassifier
# from pytorch_lightning import Trainer 
# from pytorch_lightning.loggers import WandbLogger

# model = GeneClassifier(
#     input_dim=34430, 
#     output_dim=50,
#     optim_params={
#         'optimizer': torch.optim.Adam,
#         'lr': 3e-4,
#     }
# )

# wandb_logger = WandbLogger(project='Mouse Classifier', name='Tabnet with Metrics')
# trainer = Trainer(logger=wandb_logger)

# trainer.fit(model, train, val)

In [None]:
def gene_intersection(
    files
):
    """
    Return the sorted column names that every CSV in `files` has in common.

    Only the header of each file is parsed: pandas is told to take the
    file's second line as the header (`header=1`) and to read at most one
    data row. Each column name is normalized by keeping the text before the
    first '|', upper-casing it and stripping surrounding whitespace, so the
    comparison across files ignores case and padding.

    Parameters:
    files: Iterable of CSV paths (anything `pd.read_csv` accepts)

    Returns:
    Sorted list of normalized names present in all files, or an empty list
    when `files` is empty.
    """
    cols = []
    for file in files:
        # Read from the loop variable; the previous code referenced an
        # undefined name (`fpath`) and failed with NameError.
        header = pd.read_csv(file, nrows=1, header=1).columns
        cols.append({x.split('|')[0].upper().strip() for x in header})

    # set.intersection() needs at least one set; no files means no genes.
    if not cols:
        return []

    return sorted(set.intersection(*cols))

In [None]:
# Build a dataset for the single atlas data/label file pair; the result is
# only displayed as the cell output, not stored.
# NOTE(review): `generate_single_dataset` is not imported by name --
# presumably it comes from `from models.lib.data import *`; confirm.
generate_single_dataset(
    datafile='../data/mouse/MouseAdultInhibitoryNeurons.h5ad',
    labelfile='../data/mouse/Adult Inhibitory Neurons in Mouse_labels.tsv',
    class_label='numeric_class',
    sep='\t'
)