# TabNet Model Test

In this notebook, we'll test a training loop for the TabNet model 


In [1]:
import sys
sys.path.append('../src')

from models.lib.neural import *
from models.lib.data import *
from models.lib.train import *

import helper 
from helper import gene_intersection
from pytorch_tabnet.tab_network import TabNet

import torch.nn as nn 
import torch.optim as optim
import torch
from tqdm import tqdm
from torch.utils.data import Subset
from helper import seed_everything

seed_everything(42)

First, we'll define our train, val and test sets, then generate the associated DataLoaders and try training.

In [2]:
t = helper.INTERIM_DATA_AND_LABEL_FILES_LIST
datafiles, labelfiles = zip(*t.items())
datafiles = [f'../data/interim/{f}' for f in datafiles]
labelfiles = [f'../data/processed/labels/{f}' for f in labelfiles]

datafiles, labelfiles

(['../data/interim/primary_bhaduri_T.csv',
  '../data/interim/allen_cortex_T.csv',
  '../data/interim/allen_m1_region_T.csv',
  '../data/interim/whole_brain_bhaduri_T.csv'],
 ['../data/processed/labels/primary_bhaduri_labels.csv',
  '../data/processed/labels/allen_cortex_labels.csv',
  '../data/processed/labels/allen_m1_region_labels.csv',
  '../data/processed/labels/whole_brain_bhaduri_labels.csv'])

In [8]:
train_loader, val_loader, test_loader = generate_loaders(
    datafiles,
    labelfiles,
    'Type',
    num_workers=0,
    collocate=False,
)

refgenes = gene_intersection()

In [15]:
from models.lib.neural import TabNetGeneClassifier

model = TabNetGeneClassifier(
    input_dim=len(refgenes),
    output_dim=18
)

In [18]:
sample = next(iter(train_loader[0]))[0]
sample = clean_sample(sample, refgenes, train_loader[0].dataset.columns)
sample

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.9815, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.0537, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.0578, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [23]:
model(sample)

tensor([[ 1.2157, -0.5378,  0.0325, -0.4296, -0.6017,  1.8330,  0.3227, -0.2993,
          0.3952,  2.5684,  3.7755,  1.5207,  2.0239, -1.0538,  0.9257,  2.2742,
         -1.4293,  0.0628],
        [ 2.0681,  0.3052,  0.7337, -0.6232, -2.1632,  2.6200,  2.9574, -3.0004,
         -1.6573,  0.9772,  0.2851, -1.0471,  2.5345, -0.0687,  0.4001,  3.0233,
         -0.8894, -0.3621],
        [ 0.5813,  0.6007, -0.1726,  0.3551, -1.4737,  2.1677,  2.3966, -1.5421,
         -1.9907,  0.8255,  1.7362,  1.2050,  0.6144, -0.8192, -1.0613,  1.9794,
         -0.6014, -0.9850],
        [ 1.9329,  0.8326,  0.8148, -0.3646, -1.7832,  2.4768,  2.5341, -2.3260,
         -1.3227,  1.2317,  1.5714, -1.0348,  1.5763, -1.4306,  0.0840,  2.3380,
         -0.9355, -0.3554]], grad_fn=<MmBackward0>)

Now, we'll subset and define our DataLoaders

In [13]:
import wandb
from torchmetrics.functional import accuracy

wandb.init()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train_loss = []
val_loss = []
test_loss = []

mod = 10
wandb.watch(model)
for epoch in range(1000):  # loop over the dataset multiple times
    running_loss = 0.0
    epoch_loss = 0.0
    # Train loop
    model.train()
    for train in train_loader:
        for i, data in enumerate(train):
            inputs, labels = data
            # CLEAN INPUTS
            inputs = clean_sample(inputs, refgenes, train.dataset.columns)
            # Forward pass ➡
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass ⬅
            optimizer.zero_grad()
            loss.backward()

            # Step with optimizer
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_loss += loss.item()


            if i % mod == 0: # record every 2000 mini batches 
                metric_results = calculate_metrics(
                    outputs=outputs,
                    labels=labels,
                    append_str='train',
                    num_classes=model.output_dim,
                    subset='weighted_accuracy',
                )

                wandb.log(metric_results)
                running_loss = running_loss / mod
                wandb.log({f"batch_train_loss": loss})

                running_loss = 0.0
            
    wandb.log({f"epoch_train_loss": epoch_loss / len(train)})
    
    model.eval()
    with torch.no_grad(): # save memory but not computing gradients 
        running_loss = 0.0
        epoch_loss = 0.0
        
        for val in val_loader:
            for i, data in enumerate(val):
                inputs, labels = data
                # CLEAN INPUTS
                inputs = clean_sample(inputs, refgenes, val.dataset.columns)
                # Forward pass ➡
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # print statistics
                running_loss += loss.item()
                epoch_loss += loss.item()

                if i % mod == 0: #every 2000 mini batches 
                    running_loss = running_loss / mod
                    wandb.log({"val_loss": loss})
                    running_loss = 0.0

                    metric_results = calculate_metrics(
                        outputs=outputs,
                        labels=labels,
                        num_classes=model.output_dim,
                        subset='weighted_accuracy',
                        append_str='val',
                    )

                wandb.log(metric_results)
    
        wandb.log({f"epoch_val_loss": epoch_loss / len(train)})





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not tuple

In [None]:
import matplotlib.pyplot as plt 

plt.plot(train_loss, label='Train')
plt.plot(val_loss, label='Val')
plt.legend()
plt.show()

In [None]:
subset = ['asdf']
subset = ([subset] if isinstance(subset, str) else subset)
subset