# TabNet Model Test

In this notebook, we'll test a training loop for the TabNet model 


In [1]:
import sys
sys.path.append('../src')

from models.lib.neural import *
from models.lib.data import *
from models.lib.train import *

import helper 
from helper import gene_intersection
from pytorch_tabnet.tab_network import TabNet

import torch.nn as nn 
import torch.optim as optim
import torch
from tqdm import tqdm
from torch.utils.data import Subset
from helper import seed_everything

seed_everything(42)

First, we'll define our train, val and test sets, then generate the associated DataLoaders and try training.

In [2]:
train, val, test = generate_single_dataset(
    datafile='../data/interim/primary_bhaduri_T.csv',
    labelfile='../data/processed/labels/primary_bhaduri_labels.csv',
    class_label='Type',
    normalize=True,
    skip=3,
)

refgenes = gene_intersection()
type(train)

models.lib.data.GeneExpressionData

In [8]:
train._labeldf.iloc[0, 'Type']

ValueError: Location based indexing can only have [integer, integer slice (START point is INCLUDED, END point is EXCLUDED), listlike of integers, boolean array] types

In [4]:
train[0]

ValueError: Location based indexing can only have [integer, integer slice (START point is INCLUDED, END point is EXCLUDED), listlike of integers, boolean array] types

In [18]:
val._labeldf

Unnamed: 0,cell,Type
58104,58131,4
83193,83225,4
137498,137557,16
85322,85355,4
81579,81611,4
...,...,...
63276,63305,4
7491,7507,4
22384,22405,7
33875,33899,4


In [None]:
train.shape, val.shape, test.shape

In [8]:
t = helper.INTERIM_DATA_AND_LABEL_FILES_LIST
datafiles, labelfiles = zip(*t.items())
datafiles = [f'../data/interim/{f}' for f in datafiles]
labelfiles = [f'../data/processed/labels/{f}' for f in labelfiles]

datafiles, labelfiles

(['../data/interim/primary_bhaduri_T.csv',
  '../data/interim/allen_cortex_T.csv',
  '../data/interim/allen_m1_region_T.csv',
  '../data/interim/whole_brain_bhaduri_T.csv'],
 ['../data/processed/labels/primary_bhaduri_labels.csv',
  '../data/processed/labels/allen_cortex_labels.csv',
  '../data/processed/labels/allen_m1_region_labels.csv',
  '../data/processed/labels/whole_brain_bhaduri_labels.csv'])

In [None]:
fulltrain, fullval, fulltest = generate_loaders(
    datafiles,
    labelfiles,
    'Type',
    num_workers=0,
    collocate=False,
)

In [None]:
fulltrain

Now, we'll subset and define our DataLoaders

In [30]:
fulltrain = Subset(train, range(10))
val = Subset(train, range(10))
test = Subset(test, range(10))

In [31]:
train = DataLoader(train, batch_size=2)
val = DataLoader(val, batch_size=2)
test = DataLoader(test, batch_size=2)

In [32]:
len(train)

74590

In [33]:
sample = next(iter(train))[0]
sample = clean_sample(sample, refgenes, currgenes)
sample

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.1107, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [34]:
next(iter(val))

[tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.1740, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([8, 4])]

In [46]:
model = TabNet(
    input_dim=len(refgenes),
    output_dim=18,
)

model(sample)

(tensor([[ 0.7212, -1.4900, -1.4875, -0.8384,  0.7192, -1.9079,  2.2873, -1.1778,
           1.4394,  0.0852,  2.8979,  1.5542, -0.4073,  2.4720,  0.3535, -1.4971,
           0.7302,  0.3869],
         [ 0.2523, -3.2524,  0.2527, -0.0890, -1.2329,  0.6200,  2.7651,  0.5081,
          -2.0700, -0.4171,  3.0618, -0.4611, -2.4633, -0.4383,  1.4057,  0.1778,
           0.7112,  0.7152]], grad_fn=<MmBackward0>),
 tensor(-8.6592, grad_fn=<DivBackward0>))

In [47]:
from models.lib.neural import TabNetGeneClassifier

model = TabNetGeneClassifier(input_dim=len(refgenes), output_dim=18)

In [48]:
model(sample)

tensor([[ 0.8933, -0.3937,  0.0965,  0.8868,  1.7275, -1.8121, -0.5918, -1.8454,
         -0.3578, -1.0836, -0.6907, -0.6712,  0.3328, -1.2917,  0.3095, -1.0449,
          0.7997, -0.5279],
        [-0.6963,  0.4636, -0.7859,  2.6002,  0.2334, -0.1801,  0.8615, -0.6257,
         -3.6933, -1.7297, -1.4980, -1.5532,  0.8441, -0.4745, -1.1023, -2.3745,
          1.9691,  0.0044]], grad_fn=<MmBackward0>)

In [51]:
import wandb
from torchmetrics.functional import accuracy

wandb.init()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train_loss = []
val_loss = []
test_loss = []

mod = 10
wandb.watch(model)
for epoch in range(1000):  # loop over the dataset multiple times
    running_loss = 0.0
    epoch_loss = 0.0
    # Train loop
    model.train()
    for i, data in enumerate(train):
        inputs, labels = data
        # CLEAN INPUTS
        inputs = clean_sample(inputs, refgenes, currgenes)
        # Forward pass ➡
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass ⬅
        optimizer.zero_grad()
        loss.backward()

        # Step with optimizer
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        epoch_loss += loss.item()
        

        if i % mod == 0: # record every 2000 mini batches 
            metric_results = calculate_metrics(
                outputs=outputs,
                labels=labels,
                append_str='train',
                num_classes=model.output_dim,
                subset='weighted_accuracy',
            )

            wandb.log(metric_results)
            running_loss = running_loss / mod
            wandb.log({f"batch_train_loss": loss})
            
            running_loss = 0.0
            
    wandb.log({f"epoch_train_loss": epoch_loss / len(train)})
    
    model.eval()
    with torch.no_grad(): # save memory but not computing gradients 
        running_loss = 0.0
        epoch_loss = 0.0
        
        for i, data in enumerate(val):
            inputs, labels = data
            # CLEAN INPUTS
            inputs = clean_sample(inputs, refgenes, currgenes)
            # Forward pass ➡
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # print statistics
            running_loss += loss.item()
            epoch_loss += loss.item()
            
            if i % mod == 0: #every 2000 mini batches 
                running_loss = running_loss / mod
                wandb.log({"val_loss": loss})
                running_loss = 0.0
                
                metric_results = calculate_metrics(
                    outputs=outputs,
                    labels=labels,
                    num_classes=model.output_dim,
                    subset='weighted_accuracy',
                    append_str='val',
                )
            
            wandb.log(metric_results)
    
        wandb.log({f"epoch_val_loss": epoch_loss / len(train)})





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt 

plt.plot(train_loss, label='Train')
plt.plot(val_loss, label='Val')
plt.legend()
plt.show()

In [None]:
subset = ['asdf']
subset = ([subset] if isinstance(subset, str) else subset)
subset

In [9]:
from sklearn.model_selection import train_test_split 

class_label = 'Type'
current_labels = pd.read_csv(labelfiles[0]).loc[:, class_label]

# Make stratified split on labels
trainsplit, valsplit = train_test_split(current_labels, stratify=current_labels)
trainsplit, testsplit = train_test_split(trainsplit, stratify=trainsplit)

In [12]:
trainsplit.index

Int64Index([174325,  18203,  27173,  36480, 159231, 115763, 186094, 148146,
             90320, 121851,
            ...
            109835,  66503,  77671, 122667, 113471, 109234,  44183, 173813,
             25342, 114073],
           dtype='int64', length=104892)

In [13]:
valsplit.index

Int64Index([181994,  99209, 122906,  57260, 158699, 176820,  91588, 163894,
            185239,  96336,
            ...
            129980,  81327,  77781, 175941, 106817, 185070,  81234,  44742,
            138899,  19378],
           dtype='int64', length=46619)

In [14]:
testsplit.index

Int64Index([161887,  37156, 183727,   1018,  45962, 164500, 102501,  52386,
            150553, 160443,
            ...
            135993, 165830, 122385, 144447,  82285,  23968,  82232, 170058,
            141393, 124560],
           dtype='int64', length=34965)

In [66]:
t = [1]
l = *t

SyntaxError: can't use starred expression here (2466760160.py, line 2)