# TabNet Model Test

In this notebook, we'll test a training loop for the TabNet model 


In [15]:
import sys
sys.path.append('../src')

from models.lib.neural import *
from models.lib.data import *
from models.lib.train import train_loop

from helper import gene_intersection
from pytorch_tabnet.tab_network import TabNet

import torch.nn as nn 
import torch.optim as optim
import torch
from tqdm import tqdm
from torch.utils.data import Subset

First, we'll define our train, val and test sets, then generate the associated DataLoaders and try training.

In [52]:
train, val, test = generate_single_dataset(
    datafile='../data/interim/primary_bhaduri_T.csv',
    labelfile='../data/processed/labels/primary_bhaduri_labels.csv',
    class_label='Type',
    normalize=True,
    skip=3,
)

refgenes = gene_intersection()
currgenes = train.dataset.columns

In [53]:
train[0], val[0], test[0]

((tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4),
 (tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4),
 (tensor([0., 0., 0.,  ..., 0., 0., 0.]), 4))

In [54]:
len(train), len(test), len(val)

(149180, 37296, 37296)

Now, we'll subset and define our DataLoaders

In [55]:
train = Subset(train, range(10))
val = Subset(train, range(10))
test = Subset(test, range(10))

In [56]:
train = DataLoader(train, batch_size=2)

In [57]:
for X, y in train:
    print(X)

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0477, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.1793, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0.0000, 0.1265, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


In [58]:
sample = next(iter(train))[0]
sample = clean_sample(sample, refgenes, currgenes)
sample

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.1281, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [59]:
model = TabNet(
    input_dim=len(refgenes),
    output_dim=18,
)

In [68]:
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor

classifier = TabNetClassifier()


Device used : cpu


AttributeError: 'TabNetClassifier' object has no attribute '__attr__'

In [62]:
from torchmetrics.functional import accuracy

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(1000):  # loop over the dataset multiple times
    running_loss = 0.0

    # Train loop
    model.train()
    for i, data in enumerate(tqdm(train, disable=True)):
        inputs, labels = data
        # CLEAN INPUTS
        inputs = clean_sample(inputs, refgenes, currgenes)
        
        # Forward pass
        outputs, mloss = model(inputs)

        loss = criterion(outputs, labels)
        # Backward pass ⬅
        optimizer.zero_grad()
        loss.backward()

        # Step with optimizer
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 0:
            print(f'Epoch {epoch}, running loss is {running_loss/100}')
            running_loss = 0.0

Epoch 0, running loss is 0.003861255943775177
Epoch 1, running loss is 0.0038433924317359926
Epoch 2, running loss is 0.0038303643465042113
Epoch 3, running loss is 0.0038204428553581236
Epoch 4, running loss is 0.003812532424926758
Epoch 5, running loss is 0.0038059064745903015
Epoch 6, running loss is 0.003800087571144104
Epoch 7, running loss is 0.0037947577238082886
Epoch 8, running loss is 0.0037896955013275147
Epoch 9, running loss is 0.0037847691774368286
Epoch 10, running loss is 0.0037798789143562318
Epoch 11, running loss is 0.0037749558687210083
Epoch 12, running loss is 0.0037699565291404724
Epoch 13, running loss is 0.003764839470386505
Epoch 14, running loss is 0.0037595871090888976
Epoch 15, running loss is 0.0037541866302490236
Epoch 16, running loss is 0.003748604357242584
Epoch 17, running loss is 0.0037428417801856993
Epoch 18, running loss is 0.0037368693947792053
Epoch 19, running loss is 0.0037306669354438783
Epoch 20, running loss is 0.003724181056022644
Epoch 21

Epoch 173, running loss is 0.004542267024517059
Epoch 174, running loss is 0.004520441591739655
Epoch 175, running loss is 0.004499209523200989
Epoch 176, running loss is 0.004478539526462555
Epoch 177, running loss is 0.004458394646644592
Epoch 178, running loss is 0.004438740313053131
Epoch 179, running loss is 0.004419521391391754
Epoch 180, running loss is 0.004400703012943268
Epoch 181, running loss is 0.004382257461547852
Epoch 182, running loss is 0.0043641459941864014
Epoch 183, running loss is 0.004346339702606201
Epoch 184, running loss is 0.0043288213014602665
Epoch 185, running loss is 0.004311559796333313
Epoch 186, running loss is 0.004294542074203492
Epoch 187, running loss is 0.004277738332748413
Epoch 188, running loss is 0.004261143505573273
Epoch 189, running loss is 0.00424473375082016
Epoch 190, running loss is 0.004228494763374328
Epoch 191, running loss is 0.004212416112422943
Epoch 192, running loss is 0.004196484982967377
Epoch 193, running loss is 0.0041806918

Epoch 343, running loss is 0.002503548264503479
Epoch 344, running loss is 0.0024954354763031005
Epoch 345, running loss is 0.002487339377403259
Epoch 346, running loss is 0.0024792774021625517
Epoch 347, running loss is 0.00247123658657074
Epoch 348, running loss is 0.00246323361992836
Epoch 349, running loss is 0.00245527982711792
Epoch 350, running loss is 0.002447372376918793
Epoch 351, running loss is 0.002439487725496292
Epoch 352, running loss is 0.0024316222965717316
Epoch 353, running loss is 0.002423783242702484
Epoch 354, running loss is 0.0024159781634807585
Epoch 355, running loss is 0.002408214807510376
Epoch 356, running loss is 0.002400512248277664
Epoch 357, running loss is 0.0023928342759609224
Epoch 358, running loss is 0.00238517090678215
Epoch 359, running loss is 0.0023775303363800047
Epoch 360, running loss is 0.002368471026420593
Epoch 361, running loss is 0.0023588427901268007
Epoch 362, running loss is 0.002356565296649933
Epoch 363, running loss is 0.00236211

Epoch 512, running loss is 0.0014962412416934967
Epoch 513, running loss is 0.001492053121328354
Epoch 514, running loss is 0.0014878879487514496
Epoch 515, running loss is 0.0014837396144866944
Epoch 516, running loss is 0.0014796112477779388
Epoch 517, running loss is 0.0014755068719387054
Epoch 518, running loss is 0.0014714168012142181
Epoch 519, running loss is 0.001467348039150238
Epoch 520, running loss is 0.00146329864859581
Epoch 521, running loss is 0.0014592680335044862
Epoch 522, running loss is 0.0014552555978298188
Epoch 523, running loss is 0.0014512614905834197
Epoch 524, running loss is 0.0014472846686840058
Epoch 525, running loss is 0.0014433269202709198
Epoch 526, running loss is 0.0014393885433673858
Epoch 527, running loss is 0.0014354661107063293
Epoch 528, running loss is 0.0014315642416477203
Epoch 529, running loss is 0.001427675187587738
Epoch 530, running loss is 0.0014238092303276062
Epoch 531, running loss is 0.0014199548959732055
Epoch 532, running loss i

Epoch 681, running loss is 0.0009777218103408814
Epoch 682, running loss is 0.000975324735045433
Epoch 683, running loss is 0.0009729307889938354
Epoch 684, running loss is 0.0009705465286970139
Epoch 685, running loss is 0.0009681697934865952
Epoch 686, running loss is 0.0009657973796129227
Epoch 687, running loss is 0.0009634366631507873
Epoch 688, running loss is 0.0009610813111066818
Epoch 689, running loss is 0.0009587308019399643
Epoch 690, running loss is 0.0009563906490802765
Epoch 691, running loss is 0.0009540561586618423
Epoch 692, running loss is 0.0009517272561788559
Epoch 693, running loss is 0.0009494056552648545
Epoch 694, running loss is 0.0009470938891172409
Epoch 695, running loss is 0.0009447860717773437
Epoch 696, running loss is 0.0009424856305122375
Epoch 697, running loss is 0.0009401946514844894
Epoch 698, running loss is 0.0009379084408283234
Epoch 699, running loss is 0.0009356294572353363
Epoch 700, running loss is 0.0009333591908216477
Epoch 701, running lo

Epoch 849, running loss is 0.0006177312880754471
Epoch 850, running loss is 0.0006159341707825661
Epoch 851, running loss is 0.000614144317805767
Epoch 852, running loss is 0.0006123654916882514
Epoch 853, running loss is 0.0006105886399745942
Epoch 854, running loss is 0.0006088180840015411
Epoch 855, running loss is 0.000607055276632309
Epoch 856, running loss is 0.0006053056940436364
Epoch 857, running loss is 0.0006035559624433518
Epoch 858, running loss is 0.0006018146499991417
Epoch 859, running loss is 0.000600077472627163
Epoch 860, running loss is 0.0005983518436551094
Epoch 861, running loss is 0.0005966341122984886
Epoch 862, running loss is 0.0005949189886450768
Epoch 863, running loss is 0.0005932118743658066
Epoch 864, running loss is 0.0005915088206529617
Epoch 865, running loss is 0.0005898163840174675
Epoch 866, running loss is 0.0005881313607096672
Epoch 867, running loss is 0.0005864521488547325
Epoch 868, running loss is 0.0005847766250371933
Epoch 869, running loss