# Modeling Sanity Check: Making sure everything is ok

In this notebook, we'll test our entire network pipeline because SURELY there are bugs.

In [1]:
import dask.dataframe as dd
import pandas as pd 
import torch
import linecache 
import csv
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sys
import torch
import pytorch_lightning as pl
sys.path.append('../src/')
sys.path.append('..')

Let's define our custom data class and make sure everything is being streamed in correctly

In [25]:
from models.lib.data import GeneExpressionData
from models.lib.neural import GeneClassifier

In [26]:
data = GeneExpressionData(
    filename='../data/interim/primary_bhaduri_T.csv',
    labelname='../data/processed/labels/primary_bhaduri_labels.csv',
    class_label='Type',
    skip=3,
)

data.shape

(186476, 19765)

In [27]:
model = GeneClassifier(
    N_features = len(data.columns),
    N_labels = max(data.labels), # Since indexed from zero
    weights=data.class_weights,
    params={
        'width' : 1024,
        'layers': 2,
        'epochs': 10,
        'lr': 3e-5,
        'momentum': 1e-4,
        'weight_decay': 1e-4
    }
)

Model initialized. N_features = 19765, N_labels = 17. Metrics are {'accuracy': <function accuracy at 0x7fc803c50430>, 'precision': <function precision at 0x7fc803c63a60>, 'recall': <function recall at 0x7fc803c63b80>} and weighted_metrics = False


In [28]:
data.labels

array([16,  4,  9, 11,  6,  8,  7,  3, 17])

Now that we have our dataset, at least make sure a forward pass is computing correctly, and that our model can at least overfit on a small subset of the training set. Therefore, we'll subset our dataset and create the train and val loaders this way.

In [29]:
from torch.utils.data import Subset

tr_10k = Subset(data, range(10))

In [31]:
def train_test(data):
    train_size = int(0.80 * len(data))
    test_size = len(data) - train_size

    train, test = torch.utils.data.random_split(data, [train_size, test_size])

    traindata = DataLoader(train, batch_size=4)
    valdata = DataLoader(test, batch_size=4)
    
    return traindata, valdata

train, test = train_test(tr_10k)

In [32]:
len(train), len(test)

(2, 1)

Even though we'll ultimately be using PyTorch Lightning for GPU training, let's try writing the training loop here so we can debug each step. To do this, we'll need to redefine the optimizer and loss

In [33]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [34]:
from tqdm import tqdm 

for epoch in range(1000):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, sample in tqdm(enumerate(train)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = sample

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 0: # print every 2000 mini-batches
#             print(epoch, running_loss / 100)
            running_loss = 0.0

print('Finished Training')

2it [00:00, 10.89it/s]
2it [00:00, 21.30it/s]
2it [00:00, 23.01it/s]
2it [00:00, 22.39it/s]
2it [00:00, 22.45it/s]
2it [00:00, 23.05it/s]
2it [00:00, 22.09it/s]
2it [00:00, 23.70it/s]
2it [00:00, 23.56it/s]
2it [00:00, 22.95it/s]
2it [00:00, 23.40it/s]
2it [00:00, 23.46it/s]
2it [00:00, 23.33it/s]
2it [00:00, 23.25it/s]
2it [00:00, 22.13it/s]
2it [00:00, 23.10it/s]
2it [00:00, 23.39it/s]
2it [00:00, 23.41it/s]
2it [00:00, 23.67it/s]
2it [00:00, 23.38it/s]
2it [00:00, 23.21it/s]
2it [00:00, 23.63it/s]
2it [00:00, 23.73it/s]
2it [00:00, 23.35it/s]
2it [00:00, 23.32it/s]
2it [00:00, 23.60it/s]
2it [00:00, 23.45it/s]
2it [00:00, 23.45it/s]
2it [00:00, 23.59it/s]
2it [00:00, 23.72it/s]
2it [00:00, 23.63it/s]
2it [00:00, 23.51it/s]
2it [00:00, 23.36it/s]
2it [00:00, 23.72it/s]
2it [00:00, 23.60it/s]
2it [00:00, 22.94it/s]
2it [00:00, 22.76it/s]
2it [00:00, 23.63it/s]
2it [00:00, 23.44it/s]
2it [00:00, 23.54it/s]
2it [00:00, 23.39it/s]
2it [00:00, 22.65it/s]
2it [00:00, 22.75it/s]
2it [00:00,

2it [00:00, 24.24it/s]
2it [00:00, 24.35it/s]
2it [00:00, 24.27it/s]
2it [00:00, 24.31it/s]
2it [00:00, 24.34it/s]
2it [00:00, 24.30it/s]
2it [00:00, 24.27it/s]
2it [00:00, 24.12it/s]
2it [00:00, 24.23it/s]
2it [00:00, 24.40it/s]
2it [00:00, 24.23it/s]
2it [00:00, 24.30it/s]
2it [00:00, 24.31it/s]
2it [00:00, 24.38it/s]
2it [00:00, 24.40it/s]
2it [00:00, 24.20it/s]
2it [00:00, 24.27it/s]
2it [00:00, 24.32it/s]
2it [00:00, 24.20it/s]
2it [00:00, 24.07it/s]
2it [00:00, 23.92it/s]
2it [00:00, 24.14it/s]
2it [00:00, 24.23it/s]
2it [00:00, 24.34it/s]
2it [00:00, 24.17it/s]
2it [00:00, 24.26it/s]
2it [00:00, 24.19it/s]
2it [00:00, 23.95it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.10it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.35it/s]
2it [00:00, 24.33it/s]
2it [00:00, 24.31it/s]
2it [00:00, 24.41it/s]
2it [00:00, 24.26it/s]
2it [00:00, 24.17it/s]
2it [00:00, 23.23it/s]
2it [00:00, 23.24it/s]
2it [00:00, 23.85it/s]
2it [00:00, 23.99it/s]
2it [00:00, 24.20it/s]
2it [00:00,

2it [00:00, 24.24it/s]
2it [00:00, 24.28it/s]
2it [00:00, 24.08it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.12it/s]
2it [00:00, 24.12it/s]
2it [00:00, 24.19it/s]
2it [00:00, 24.30it/s]
2it [00:00, 24.20it/s]
2it [00:00, 24.01it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.26it/s]
2it [00:00, 24.12it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.17it/s]
2it [00:00, 24.30it/s]
2it [00:00, 24.23it/s]
2it [00:00, 24.13it/s]
2it [00:00, 24.30it/s]
2it [00:00, 24.24it/s]
2it [00:00, 24.20it/s]
2it [00:00, 24.22it/s]
2it [00:00, 24.25it/s]
2it [00:00, 24.21it/s]
2it [00:00, 24.23it/s]
2it [00:00, 24.11it/s]
2it [00:00, 24.22it/s]
2it [00:00, 23.99it/s]
2it [00:00, 24.24it/s]
2it [00:00, 24.08it/s]
2it [00:00, 24.20it/s]
2it [00:00, 24.08it/s]
2it [00:00, 24.10it/s]
2it [00:00, 24.17it/s]
2it [00:00, 24.13it/s]
2it [00:00, 23.26it/s]
2it [00:00, 23.25it/s]
2it [00:00, 23.70it/s]
2it [00:00, 23.49it/s]
2it [00:00, 24.05it/s]
2it [00:00, 23.85it/s]
2it [00:00, 24.22it/s]
2it [00:00,

Finished Training





In [14]:
from typing import *
from torchmetrics import Accuracy
import torch.nn.functional as F

class TEST(pl.LightningModule):
    def __init__(self, 
        N_features: int, 
        N_labels: int, 
        weights: List[torch.Tensor], 
        params: Dict[str, float],
    ):
        super(TEST, self).__init__()

        # Set hyperparameters
        self.width = params['width']
        self.layers = params['layers']
        self.lr = params['lr']
        self.momentum = params['momentum']
        self.weight_decay = params['weight_decay']

        layers = self.layers*[
            nn.Linear(self.width, self.width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.BatchNorm1d(self.width),
        ]

        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(N_features, self.width),
            *layers,
            nn.Linear(self.width, N_labels),
        )

        self.accuracy = Accuracy(average='weighted', num_classes=N_labels)
        self.weights = weights

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.lr, 
            momentum=self.momentum, 
            weight_decay=self.weight_decay,
        )

        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y, weight=self.weights)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)

        self.log("train_loss", loss, logger=True, on_epoch=True)
        self.log("train_accuracy", acc, logger=True, on_epoch=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y, weight=self.weights)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)

        self.log("val_loss", val_loss, logger=True, on_epoch=True)
        self.log("val_accuracy", acc, logger=True, on_epoch=True)

        return val_loss
    
model = TEST(
    N_features = data.num_features(),
    N_labels = data.num_labels(),
    weights=data.compute_class_weights(),
    params={
        'width' : 2,
        'layers': 2,
        'epochs': 10,
        'lr': 0.001,
        'momentum': 0,
        'weight_decay':0,
    }
)

Our cost function is converging on a small subset of the data, which is good! Now let's try this same training routine with PyTorch Lightning to make sure nothing is going awry there.

In [None]:
from pytorch_lightning import Trainer

run = Trainer()
run.fit(model, train, test)