# TabNet Model Test

In this notebook, we'll test a training loop for the TabNet model 


In [1]:
import sys
sys.path.append('../src')

from models.lib.neural import *
from models.lib.data import *
from models.lib.train import *

from helper import gene_intersection
from pytorch_tabnet.tab_network import TabNet

import torch.nn as nn 
import torch.optim as optim
import torch
from tqdm import tqdm
from torch.utils.data import Subset

First, we'll define our train, val and test sets, then generate the associated DataLoaders and try training.

In [12]:
train, val, test = generate_single_dataset(
    datafile='../data/interim/primary_bhaduri_T.csv',
    labelfile='../data/processed/labels/primary_bhaduri_labels.csv',
    class_label='Type',
    normalize=True,
    skip=3,
)

refgenes = gene_intersection()
currgenes = train.dataset.columns

In [13]:
train[0], val[0], test[0]

((tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4),
 (tensor([0.0929, 0.0000, 0.0000,  ..., 0.0657, 0.0000, 0.0000]), 4),
 (tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4))

In [14]:
len(train), len(test), len(val)

(149180, 37296, 37296)

Now, we'll subset and define our DataLoaders

In [15]:
# train = Subset(train, range(10))
# val = Subset(train, range(10))
# test = Subset(test, range(10))

In [16]:
train = DataLoader(train, batch_size=2)
val = DataLoader(val, batch_size=2)
test = DataLoader(test, batch_size=2)

In [17]:
len(train)

74590

In [18]:
sample = next(iter(train))[0]
sample = clean_sample(sample, refgenes, currgenes)
sample

tensor([[0.1836, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.1571, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [19]:
next(iter(val))

[tensor([[0.0929, 0.0000, 0.0000,  ..., 0.0657, 0.0000, 0.0000],
         [0.1932, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([4, 4])]

In [20]:
model = TabNet(
    input_dim=len(refgenes),
    output_dim=18,
)

In [None]:
import wandb
from torchmetrics.functional import accuracy

wandb.init()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train_loss = []
val_loss = []
test_loss = []

wandb.watch(model)
for epoch in range(1000):  # loop over the dataset multiple times
    running_loss = 0.0

    # Train loop
    model.train()
    for i, data in enumerate(train):
        inputs, labels = data
        # CLEAN INPUTS
        inputs = clean_sample(inputs, refgenes, currgenes)
        # Forward pass ➡
        outputs, M_loss = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass ⬅
        optimizer.zero_grad()
        loss.backward()

        # Step with optimizer
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        metric_results = calculate_metrics(
            outputs=outputs,
            labels=labels,
            append_str='train',
            num_classes=model.output_dim,
            subset='weighted_accuracy'
        )
        wandb.log(metric_results)

        if i % 100 == 0:
            running_loss = running_loss / 100
            wandb.log({"train_loss": loss})
            running_loss = 0.0
            

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val):
            inputs, labels = data
            # CLEAN INPUTS
            inputs = clean_sample(inputs, refgenes, currgenes)
            # Forward pass ➡
            outputs, M_loss = model(inputs)
            loss = criterion(outputs, labels)

            # print statistics
            running_loss += loss.item()
            metric_results = calculate_metrics(
                outputs=outputs,
                labels=labels,
                append_str='val',
                num_classes=model.output_dim,
                subset='weighted_accuracy'
            )
            wandb.log(metric_results)

            if i % 100 == 0:
                running_loss = running_loss / 100
                wandb.log({"val_loss": loss})
                running_loss = 0.0

# model.eval()
# with torch.no_grad():
#     for i, data in enumerate(test):
#         running_loss, _ = _inner_computation(
#             data,
#             model,
#             optimizer,
#             i,
#             running_loss,
#             refgenes,
#             currgenes,
#             'test',
#         )




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_loss,▂█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
weighted_accuracy_train,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,0.00065
weighted_accuracy_train,1.0


In [None]:
import matplotlib.pyplot as plt 

plt.plot(train_loss, label='Train')
plt.plot(val_loss, label='Val')
plt.legend()
plt.show()

In [None]:
subset = ['asdf']
subset = ([subset] if isinstance(subset, str) else subset)
subset