# TabNet Model Test

In this notebook, we'll test a training loop for the TabNet model.


In [1]:
import sys
sys.path.append('../src')

from models.lib.neural import *
from models.lib.data import *
from models.lib.train import *

import helper 
from helper import gene_intersection
from pytorch_tabnet.tab_network import TabNet

import torch.nn as nn 
import torch.optim as optim
import torch
from tqdm import tqdm
from torch.utils.data import Subset
from helper import seed_everything

# Fix the random seed up front so repeated runs of this notebook are comparable.
# (seed_everything comes from the project's `helper` module; which libraries it
# actually seeds is not visible from this notebook.)
seed_everything(42)

In [2]:
import pytorch_lightning

First, we'll define our train, val and test sets, then generate the associated DataLoaders and try training.

In [3]:
# INTERIM_DATA_AND_LABEL_FILES_LIST is unpacked through .items(): its keys become
# the data file names and its values the matching label file names.
t = helper.INTERIM_DATA_AND_LABEL_FILES_LIST
datafiles, labelfiles = zip(*t.items())
# Turn the bare file names into paths relative to this notebook's directory.
datafiles = [f'../data/interim/{f}' for f in datafiles]
labelfiles = [f'../data/processed/labels/{f}' for f in labelfiles]
# Reference gene list; its length is used as the models' input_dim further down.
refgenes = gene_intersection()

# Display the resolved paths as a sanity check.
datafiles, labelfiles

(['../data/interim/primary_bhaduri_T.csv',
  '../data/interim/allen_cortex_T.csv',
  '../data/interim/allen_m1_region_T.csv',
  '../data/interim/whole_brain_bhaduri_T.csv'],
 ['../data/processed/labels/primary_bhaduri_labels.csv',
  '../data/processed/labels/allen_cortex_labels.csv',
  '../data/processed/labels/allen_m1_region_labels.csv',
  '../data/processed/labels/whole_brain_bhaduri_labels.csv'])

In [23]:
# Build the train/val/test splits from the first data/label file pair only.
# 'Type' is presumably the label column to classify on, and skip=3 is forwarded
# as-is to generate_single_dataset — TODO confirm both against its signature.
train, val, test = generate_single_dataset(
    datafiles[0],
    labelfiles[0],
    'Type',
    skip=3,
)

In [5]:
# print(len(train[0][0]))

In [6]:
# trainloader_map, _, _ = generate_single_dataloader(
#     datafile=datafiles[0], 
#     labelfile=labelfiles[0], 
#     class_label='Type',
#     skip=3,
#     map_genes=True
# )

# trainloader_nomap, _, _ = generate_single_dataloader(
#     datafile=datafiles[0], 
#     labelfile=labelfiles[0], 
#     class_label='Type',
#     skip=3,
#     map_genes=False
# )

In [7]:
# for i, (X, y) in enumerate(tqdm(trainloader_map)):
#     if i == 200:
#         break

In [8]:
# for i, (X, y) in enumerate(tqdm(trainloader_nomap)):
#     if i == 200:
#         break
#     X = clean_sample(X, refgenes, train.features)
    

In [9]:
# train_loader, val_loader, test_loader = generate_loaders(
#     datafiles,
#     labelfiles,
#     'Type',
#     num_workers=0,
#     collocate=True,
# )

In [10]:
# for X, y in train_loader:
#     print(type(train_loader))

In [11]:
# Size of the reference gene set; this is the input_dim given to the models below.
len(refgenes)

16604

In [12]:
from models.lib.neural import TabNetGeneClassifier

# Standalone TabNet classifier sized to the reference gene set.
# NOTE(review): output_dim=19 is hard-coded — presumably the number of distinct
# 'Type' classes; confirm against the label files.
model = TabNetGeneClassifier(
    input_dim=len(refgenes),
    output_dim=19
)

In [42]:
class SampleLoader(torch.utils.data.DataLoader):
    """DataLoader that runs every batch's features through ``clean_sample``.

    Parameters
    ----------
    refgenes
        Reference gene collection handed to ``clean_sample`` as its second argument.
    *args, **kwargs
        Forwarded unchanged to ``torch.utils.data.DataLoader``. The dataset
        must expose a ``columns`` attribute, which is stored as ``currgenes``.
    """

    def __init__(self, refgenes, *args, **kwargs):
        # The parent constructor has to run first: it is what sets self.dataset.
        super().__init__(*args, **kwargs)
        self.currgenes = self.dataset.columns
        self.refgenes = refgenes

    def __iter__(self):
        """Yield ``(cleaned_features, labels)`` for each batch of the parent loader."""
        for minibatch in super().__iter__():
            cleaned = clean_sample(minibatch[0], self.refgenes, self.currgenes)
            yield cleaned, minibatch[1]

In [43]:
# Smoke-test SampleLoader on the training split.
# NOTE: this rebinds `test`, which until now held the test split returned by
# generate_single_dataset.
test = SampleLoader(refgenes=refgenes, dataset=train, batch_size=11, num_workers=0)
# Pull a single batch to eyeball the (features, labels) tuple.
next(iter(test))

(tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.9815, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [1.0537, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.7632, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [1.0228, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([ 6,  4,  4,  4, 16,  4,  4,  4, 16, 16,  4]))

In [27]:
# Shape of the feature tensor for one batch; the second dimension should match len(refgenes).
next(iter(test))[0].shape

torch.Size([11, 16604])

In [48]:
from functools import partial 

def custom_collate(sample):
    """Collate a list of ``(features, label)`` pairs into one batch.

    Parameters
    ----------
    sample
        Sequence of items whose element 0 is a feature tensor and whose
        element 1 is the label.

    Returns
    -------
    tuple
        ``(data, labels)`` where ``data`` is the feature tensors stacked along
        a new leading dimension and ``labels`` is a tensor built from the
        labels in the same order.
    """
    features = []
    targets = []
    for item in sample:
        features.append(item[0])
        targets.append(item[1])
    return torch.stack(features), torch.tensor(targets)

# Same smoke test with a plain DataLoader and the one-argument custom_collate
# defined just above; this rebinds `test` again.
# NOTE(review): custom_collate is redefined further down with extra required
# parameters, so this cell only works if it runs before that redefinition.
test = DataLoader(dataset=train, collate_fn=custom_collate, batch_size=4)

[(tensor([0., 0., 0.,  ..., 0., 0., 0.]), 6), (tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4), (tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4), (tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4)]


(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([6, 4, 4, 4]))

## PyTorch-Lightning compatible TabNet architecture 

In [16]:
from models.lib.neural import GeneClassifier
from models.lib.neural import TabNetGeneClassifier

# Backbone network that GeneClassifier wraps below; same dimensions as `model` above.
base_model = TabNetGeneClassifier(
    input_dim=len(refgenes),
    output_dim=19,
)

In [17]:
# Confirm the backbone exposes input_dim; the GeneClassifier cell below reads it.
base_model.input_dim

16604

In [18]:
# Wrap the TabNet backbone in GeneClassifier, taking the dimensions from the
# backbone itself so the two cannot drift apart. This is the object handed to
# Trainer.fit further down.
classifier = GeneClassifier(
    input_dim=base_model.input_dim,
    output_dim=base_model.output_dim,
    model=base_model,
)

Model initialized. input_dim = 16604, output_dim = 19. Metrics are dict_keys(['accuracy', 'precision', 'recall']) and weighted_metrics = False


In [20]:
import pytorch_lightning as pl 
from typing import *

class GeneDataModule(pl.LightningDataModule):
    """LightningDataModule serving one train/val/test DataLoader triple per data file.

    Each ``(datafile, labelfile)`` pair is turned into three loaders by
    ``generate_single_dataloader``; the ``*_dataloader`` hooks return the
    resulting lists (one entry per file pair, in input order).

    Parameters
    ----------
    datafiles
        Paths of the expression data files.
    labelfiles
        Paths of the label files, positionally matched with ``datafiles``.
    class_label
        Passed to ``generate_single_dataloader`` as ``class_label``.
    batch_size, num_workers
        Stored on the instance.
        NOTE(review): they are not forwarded to ``generate_single_dataloader``
        (its signature is not visible here) — confirm whether they should be.
    *args, **kwargs
        Extra arguments forwarded verbatim to ``generate_single_dataloader``.
        Keyword extras are also exposed as instance attributes.

    Raises
    ------
    ValueError
        If ``datafiles`` and ``labelfiles`` differ in length.
    """

    def __init__(
        self,
        datafiles: List[str],
        labelfiles: List[str],
        class_label: str,
        batch_size: int=16,
        num_workers=32,
        *args,
        **kwargs,
    ):
        super().__init__()

        # zip() in setup() would silently drop the unmatched tail, so fail early.
        if len(datafiles) != len(labelfiles):
            raise ValueError(
                f'datafiles and labelfiles must have the same length, '
                f'got {len(datafiles)} and {len(labelfiles)}'
            )

        # Expose keyword extras as attributes. Done before the explicit
        # assignments below so that those always win on a name clash.
        self.__dict__.update(**kwargs)

        self.datafiles = datafiles
        self.labelfiles = labelfiles
        self.class_label = class_label

        self.num_workers = num_workers
        self.batch_size = batch_size

        self.trainloaders = []
        self.valloaders = []
        self.testloaders = []

        self.args = args
        self.kwargs = kwargs

    def prepare_data(self):
        """Placeholder hook; nothing is downloaded yet."""
        # Download data from S3 here
        pass

    def setup(self, stage: Optional[str] = None):
        """(Re)build the loader lists, one entry per data/label file pair.

        ``stage`` is accepted for Lightning compatibility but not used: all
        three splits are built on every call. The lists are rebuilt from
        scratch so that calling ``setup`` more than once (Lightning invokes it
        per stage) does not accumulate duplicate loaders.
        """
        trainloaders, valloaders, testloaders = [], [], []

        for datafile, labelfile in zip(self.datafiles, self.labelfiles):
            # NOTE(review): positional extras are unpacked alongside keyword
            # arguments; a non-empty self.args may collide with the keywords
            # below depending on generate_single_dataloader's signature.
            train_loader, val_loader, test_loader = generate_single_dataloader(
                datafile=datafile,
                labelfile=labelfile,
                class_label=self.class_label,
                *self.args,
                **self.kwargs,
            )

            trainloaders.append(train_loader)
            valloaders.append(val_loader)
            testloaders.append(test_loader)

        self.trainloaders = trainloaders
        self.valloaders = valloaders
        self.testloaders = testloaders

    def train_dataloader(self):
        """Return the list of training loaders built by ``setup``."""
        return self.trainloaders

    def val_dataloader(self):
        """Return the list of validation loaders built by ``setup``."""
        return self.valloaders

    def test_dataloader(self):
        """Return the list of test loaders built by ``setup``."""
        return self.testloaders

In [21]:
# Data module over every data/label file pair.
# skip=3 and normalize=True are not named parameters of GeneDataModule: they land
# in **kwargs and are forwarded to generate_single_dataloader during setup().
# batch_size and num_workers are left at the class defaults (16 and 32).
module = GeneDataModule(
    datafiles, 
    labelfiles, 
    'Type', 
    skip=3, 
    normalize=True
)

In [57]:
import functools 

def custom_collate(sample, refgenes, currgenes):
    """Collate ``(features, label)`` pairs into a batch and clean the features.

    Parameters
    ----------
    sample
        Sequence of items whose element 0 is a feature tensor and whose
        element 1 is the label.
    refgenes, currgenes
        Passed straight through to ``clean_sample`` together with the stacked
        feature tensor.

    Returns
    -------
    tuple
        ``(data, labels)``: the output of ``clean_sample`` on the stacked
        features, and a tensor of the labels in the same order.
    """
    stacked = torch.stack([pair[0] for pair in sample])
    data = clean_sample(stacked, refgenes, currgenes)
    labels = torch.tensor([pair[1] for pair in sample])
    return data, labels


class CollateLoader(torch.utils.data.DataLoader):
    """DataLoader whose ``collate_fn`` is ``custom_collate`` bound to fixed gene lists."""

    def __init__(self, refgenes, currgenes, *args, **kwargs):
        # functools.partial (rather than a closure) keeps the bound gene lists
        # inspectable as keyword arguments on the resulting collate_fn.
        bound_collate = functools.partial(
            custom_collate,
            refgenes=refgenes,
            currgenes=currgenes,
        )
        super().__init__(*args, collate_fn=bound_collate, **kwargs)
        

In [60]:
# Smoke-test CollateLoader on the training split. train.columns is passed as
# currgenes (presumably the dataset's own gene names — confirm). Rebinds `test`
# once more.
test = CollateLoader(dataset=train, refgenes=refgenes, currgenes=train.columns, batch_size=4)
# Shape of one batch of features; expected to be (batch_size, len(refgenes)).
next(iter(test))[0].shape

torch.Size([4, 16604])

In [22]:
from pytorch_lightning import Trainer

# Fit with a default Trainer (no arguments), using the loaders built by GeneDataModule.
# NOTE(review): as recorded in this cell's output, the run fails during the
# validation sanity check with
#   "running_mean should contain 19765 elements not 16604".
# The model was built with input_dim=len(refgenes)=16604 but receives batches with
# 19765 features. Presumably the data module's loaders do not map samples onto
# refgenes the way SampleLoader/CollateLoader do — confirm and fix before relying
# on this cell.
trainer = Trainer()
trainer.fit(classifier, datamodule=module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type                 | Params
-----------------------------------------------
0 | model | TabNetGeneClassifier | 1.1 M 
-----------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.407     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


RuntimeError: running_mean should contain 19765 elements not 16604

Now, we'll subset and define our DataLoaders

In [None]:
import wandb
from torchmetrics.functional import accuracy

# Manual training / validation loop over lists of DataLoaders, logged to
# Weights & Biases. Expects `model`, `train_loader`, `val_loader`, `refgenes`,
# `clean_sample` and `calculate_metrics` to already be defined in the kernel.
wandb.init()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Per-epoch mean losses; train_loss and val_loss are filled below and plotted
# in a later cell. test_loss is kept for that cell's namespace but not filled here.
train_loss = []
val_loss = []
test_loss = []

mod = 10  # log metrics every `mod` minibatches
wandb.watch(model)
for epoch in range(1):  # number of passes over the loaders
    epoch_loss = 0.0
    num_batches = 0  # minibatches actually processed this epoch, across all loaders

    # Train loop
    model.train()
    # The loop variable is `loader`, not `train`, so the dataset split named
    # `train` from earlier in the notebook is not clobbered.
    for idx, loader in enumerate(train_loader):
        print(f'On loader {idx = }')
        for i, data in enumerate(tqdm(loader)):
            print(f'On minibatch {i = }/10')
            if i == 10:  # smoke test: only the first 10 minibatches per loader
                break
            inputs, labels = data
            # CLEAN INPUTS
            inputs = clean_sample(inputs, refgenes, loader.dataset.columns)
            # Forward pass ➡
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass ⬅
            optimizer.zero_grad()
            loss.backward()

            # Step with optimizer
            optimizer.step()

            # Accumulate statistics as plain floats (no graph references kept alive)
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1

            if i % mod == 0:  # record every `mod` minibatches
                metric_results = calculate_metrics(
                    outputs=outputs,
                    labels=labels,
                    append_str='train',
                    num_classes=model.output_dim,
                    subset='weighted_accuracy',
                )

                wandb.log(metric_results)
                wandb.log({"batch_train_loss": batch_loss})

    # Mean over the minibatches that were really processed. The loaders are
    # truncated at 10 batches each, so len(loader) would be the wrong denominator.
    train_loss.append(epoch_loss / max(num_batches, 1))
    wandb.log({"epoch_train_loss": train_loss[-1]})

    model.eval()
    with torch.no_grad():  # save memory by not computing gradients
        epoch_loss = 0.0
        num_batches = 0

        for idx, loader in enumerate(val_loader):
            print(f'On loader {idx = }')
            for i, data in enumerate(loader):
                if i == 10:
                    break
                inputs, labels = data
                # CLEAN INPUTS
                inputs = clean_sample(inputs, refgenes, loader.dataset.columns)
                # Forward pass ➡
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                batch_loss = loss.item()
                epoch_loss += batch_loss
                num_batches += 1

                if i % mod == 0:  # record every `mod` minibatches
                    wandb.log({"val_loss": batch_loss})

                    metric_results = calculate_metrics(
                        outputs=outputs,
                        labels=labels,
                        num_classes=model.output_dim,
                        subset='weighted_accuracy',
                        append_str='val',
                    )

                    # Logged only when freshly computed, matching the train loop.
                    wandb.log(metric_results)

        val_loss.append(epoch_loss / max(num_batches, 1))
        wandb.log({"epoch_val_loss": val_loss[-1]})


In [None]:
# Display the model's repr.
model

In [None]:
def test_loop(
    model,
    testloaders,
    refgenes,
    criterion,
    mod,
):
    model.eval()
    
    with torch.no_grad():
        for idx, test in enumerate(testloaders):
            print(f'On {idx = }')
            running_loss = 0.0
            for i, data in enumerate(test):
                print(f'minibatch {i = }')
                if i == 10:
                    break
                inputs, labels = data
                # CLEAN INPUTS
                inputs = clean_sample(inputs, refgenes, test.dataset.columns)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # print statistics
                running_loss += loss.item()
                if i % mod == 0: #every 2000 mini batches 
                    running_loss = running_loss / mod
                    wandb.log({"test_loss": loss})
                    running_loss = 0.0

                    metric_results = calculate_metrics(
                        outputs=outputs,
                        labels=labels,
                        num_classes=model.output_dim,
                        subset='weighted_accuracy',
                        append_str='test',
                    )

                    wandb.log(metric_results)


In [None]:
# Run the evaluation with the criterion and logging interval from the training cell.
# NOTE(review): in the visible notebook, test_loader is only assigned inside a
# commented-out cell, so on a fresh kernel this raises NameError unless that
# cell is restored.
test_loop(model, test_loader, refgenes, criterion, mod)

In [None]:
import matplotlib.pyplot as plt 

# Plot whatever has been accumulated in the train_loss and val_loss lists
# created by the training cell.
plt.plot(train_loss, label='Train')
plt.plot(val_loss, label='Val')
plt.legend()
plt.show()

In [None]:
# Re-display the label file paths.
labelfiles

In [None]:
from numpy import memmap

In [None]:
# NOTE(review): memmap maps the file's raw bytes as float64 values. The target
# here is a .csv (text) file, so the resulting array would not hold the parsed
# table — confirm this is intentional. `np` must also already be in scope; the
# cell above only does `from numpy import memmap`.
f = memmap('../data/interim/allen_cortex_T.csv', dtype=np.float64, mode='r')