# Modeling Sanity Check: Making sure everything is ok

In this notebook, we'll test our entire network pipeline because SURELY there are bugs.

In [1]:
import dask.dataframe as dd
import pandas as pd 
import torch
import linecache 
import csv
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sys
import torch
import pytorch_lightning as pl
sys.path.append('../src/')

Let's define our custom data class and make sure everything is being streamed in correctly

In [2]:
from models.train_neural_network import GeneExpressionData, GeneClassifier

In [3]:
data = GeneExpressionData(
    filename='../data/processed/primary.csv',
    labelname='../data/processed/meta_primary_labels.csv',
    class_label='Subtype'
)

In [4]:
model = GeneClassifier(
    N_features = data.num_features(),
    N_labels = data.num_labels(),
    weights=data.compute_class_weights(),
    params={
        'width' : 2,
        'layers': 2,
        'epochs': 10,
        'lr': 3e-5,
        'momentum': 1e-4,
        'weight_decay': 1e-4
    }
)

Now that we have our dataset, at least make sure a forward pass is computing correctly, and that our model can at least overfit on a small subset of the training set. Therefore, we'll subset our dataset and create the train and val loaders this way.

In [5]:
from torch.utils.data import Subset

tr_10k = Subset(data, range(10))

In [6]:
def train_test(data):
    train_size = int(0.80 * len(data))
    test_size = len(data) - train_size

    train, test = torch.utils.data.random_split(data, [train_size, test_size])

    traindata = DataLoader(train, batch_size=2, num_workers=0)
    valdata = DataLoader(test, batch_size=2, num_workers=0)
    
    return traindata, valdata

train, test = train_test(tr_10k)

In [7]:
len(train), len(test)

(4, 1)

Even though we'll ultimately be using PyTorch Lightning for GPU training, let's try writing the training loop here so we can debug each step. To do this, we'll need to redefine the optimizer and loss

In [8]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [9]:
for epoch in range(1000):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, sample in enumerate(train, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = sample

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 0: # print every 2000 mini-batches
            print(epoch, running_loss / 100)
            running_loss = 0.0

print('Finished Training')

0 0.033207635879516605
1 0.033177757263183595
2 0.033148369789123534
3 0.03225062370300293
4 0.03222107648849487
5 0.0340830397605896
6 0.03216262102127075
7 0.03299127817153931
8 0.03296057224273682
9 0.0329289984703064
10 0.03204559326171875
11 0.03286097526550293
12 0.031986949443817136
13 0.03346689701080322
14 0.03381027936935425
15 0.03276041030883789
16 0.03271931171417236
17 0.031839971542358396
18 0.032652158737182614
19 0.03264082908630371
20 0.03325051307678223
21 0.0325810718536377
22 0.03255018711090088
23 0.032519524097442624
24 0.031635284423828125
25 0.033479459285736084
26 0.031577076911926266
27 0.03240281581878662
28 0.032374165058135985
29 0.0333594012260437
30 0.03231409549713135
31 0.03143138170242309
32 0.032253971099853514
33 0.03137491703033447
34 0.03134473085403442
35 0.03131597757339478
36 0.03128791570663452
37 0.032105739116668704
38 0.031229381561279298
39 0.03204586505889893
40 0.031882760524749754
41 0.031135034561157227
42 0.03195031642913818
43 0.0310

342 0.022364938259124757
343 0.02233788013458252
344 0.023941285610198974
345 0.022991797924041747
346 0.023881411552429198
347 0.02221726894378662
348 0.02287290573120117
349 0.022845487594604492
350 0.02212477445602417
351 0.02278416633605957
352 0.02370215892791748
353 0.022044708728790285
354 0.022012364864349366
355 0.022666633129119873
356 0.021951615810394287
357 0.021918666362762452
358 0.021891570091247557
359 0.021867823600769044
360 0.02346337080001831
361 0.021800291538238526
362 0.021766881942749023
363 0.021735408306121827
364 0.022417070865631102
365 0.022348012924194336
366 0.021644773483276366
367 0.021613943576812743
368 0.02228957653045654
369 0.021550915241241454
370 0.02315016746520996
371 0.02149378776550293
372 0.022137186527252196
373 0.02210969924926758
374 0.02140225887298584
375 0.021374082565307616
376 0.02296356439590454
377 0.02198413372039795
378 0.021278908252716066
379 0.02191314458847046
380 0.021212582588195802
381 0.021181094646453857
382 0.021818718

676 0.012358992099761964
677 0.012334412336349488
678 0.012309898138046265
679 0.01228545069694519
680 0.012261065244674683
681 0.012236745357513427
682 0.012212493419647218
683 0.012188307046890258
684 0.012164181470870972
685 0.012140127420425416
686 0.012116137742996216
687 0.012092213630676269
688 0.012068355083465576
689 0.012044564485549927
690 0.012020840644836425
691 0.011997182369232178
692 0.011973590850830078
693 0.011950066089630127
694 0.011926610469818116
695 0.011903220415115356
696 0.011879900693893433
697 0.011856645345687866
698 0.011833460330963134
699 0.01181033968925476
700 0.011787290573120118
701 0.011764309406280517
702 0.01174139380455017
703 0.011718544960021973
704 0.0116957688331604
705 0.011673060655593872
706 0.011650418043136596
707 0.011627845764160157
708 0.01160534381866455
709 0.011582911014556885
710 0.011560548543930054
711 0.011538249254226685
712 0.011516021490097046
713 0.011493862867355346
714 0.011471773386001588
715 0.011449754238128662
716 0.

Something is probably wrong, since our cost function isn't converging, even with a very small train/val loader

In [53]:
from typing import *
from torchmetrics import Accuracy

class TEST(pl.LightningModule):
    def __init__(self, 
        N_features: int, 
        N_labels: int, 
        layers=2,
        width=1024,
    ):
        super(TEST, self).__init__()
        layers = layers*[
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.BatchNorm1d(width),
        ]

        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(N_features, width),
            nn.ReLU(),
            *layers,
            nn.Linear(width, N_labels),
        )

        self.accuracy = Accuracy(average='weighted', num_classes=N_labels)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [54]:
model = TEST(
    N_features = data.num_features(),
    N_labels = data.num_labels(),
)

In [55]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

for k, epoch in enumerate(range(100000)):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, sample in enumerate(train, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = sample

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 0: # print 
            print(epoch, running_loss / 100)
            running_loss = 0.0

print('Finished Training')

0 0.03293801784515381
1 0.044853515625
2 0.03757031202316284
3 0.03323987007141113
4 0.03644559383392334
5 0.03299209117889404
6 0.03433607578277588
7 0.03255105972290039
8 0.03372215032577515
9 0.031756925582885745
10 0.037689692974090576
11 0.03528842687606812
12 0.03214954376220703
13 0.036893095970153805
14 0.03301583766937256
15 0.03560672998428345
16 0.028677701950073242
17 0.03517402648925781
18 0.03218102931976318
19 0.030492112636566163
20 0.032309136390686034
21 0.0330605411529541
22 0.02682661056518555
23 0.030409250259399414
24 0.03410139799118042
25 0.03389711856842041
26 0.02957784175872803
27 0.03209702491760254
28 0.031348743438720704
29 0.029141297340393068
30 0.030556657314300538
31 0.027296061515808104
32 0.028746213912963867
33 0.02895714044570923
34 0.028872203826904298
35 0.034731788635253905
36 0.03449397563934326
37 0.03328230381011963
38 0.02993844747543335
39 0.0232157564163208
40 0.03222908973693848
41 0.030746750831604004
42 0.030971498489379884
43 0.0281599

334 0.0018039412796497345
335 0.0033444014191627504
336 0.0009096585214138031
337 0.0023892217874526976
338 0.0017290091514587402
339 0.0017833843827247619
340 0.003009878695011139
341 0.003635408878326416
342 0.0015830060839653016
343 0.0031024092435836793
344 0.0037590292096138
345 0.000868474766612053
346 0.0023150569200515747
347 0.0016024169325828553
348 0.0019004002213478088
349 0.00287430077791214
350 0.0020556378364562987
351 0.0005921824648976326
352 0.0028123024106025698
353 0.0017230865359306336
354 0.001544734239578247
355 0.009020256996154784
356 0.005112286210060119
357 0.0016440656781196595
358 0.0016253237426280976
359 0.0006382230669260025
360 0.00703906238079071
361 0.001916084587574005
362 0.0017935961484909057
363 0.0007887876033782959
364 0.0008697514981031418
365 0.001658921241760254
366 0.0028033825755119325
367 0.0011922550201416016
368 0.0016173022985458375
369 0.0025128039717674257
370 0.0033849018812179567
371 0.0006813155859708786
372 0.0021956060826778413
3

654 0.00032979365438222884
655 0.0007357978820800781
656 0.0013643364608287812
657 0.003940569758415222
658 0.0005246769264340401
659 0.0011932282894849778
660 0.0014382383227348327
661 0.00022619422525167466
662 0.0017169933021068573
663 0.0008332099020481109
664 0.0004922162741422653
665 0.0005457992851734162
666 0.0009837263822555541
667 0.00047552458941936494
668 0.0002455732598900795
669 0.0008425668627023697
670 0.0004967732727527618
671 0.0011488518118858337
672 0.00040221504867076873
673 0.0005657364428043366
674 0.00040243905037641527
675 0.00043736826628446577
676 0.002789103388786316
677 0.00029478123411536217
678 0.0004341602697968483
679 0.0003037339448928833
680 0.0013181675970554351
681 0.0026980641484260558
682 0.0009571529924869537
683 0.0004125187546014786
684 0.0008973445743322373
685 0.00014202876016497612
686 0.0009159649908542633
687 0.0006589300185441971
688 0.0003310748934745789
689 0.001100209653377533
690 0.00023646429181098939
691 0.0002707640454173088
692 0.

968 0.0003068033792078495
969 0.00039302222430706026
970 0.00028149206191301345
971 0.00025759508833289146
972 0.0006992003321647644
973 0.0009348057210445404
974 0.0028185334801673887
975 0.0003106812387704849
976 0.0003751048073172569
977 6.461147218942642e-05
978 0.00021969547495245934
979 0.0018908849358558655
980 0.0007457595318555831
981 0.0005686936527490616
982 0.00034806262701749804
983 0.0008245741575956345
984 0.0002885007858276367
985 0.0012433966249227523
986 0.0003360601514577866
987 0.0007772120088338852
988 0.0007297714054584503
989 0.00011577671393752098
990 0.00027639318257570265
991 0.00027832437306642534
992 0.000514371171593666
993 0.0010178948193788528
994 0.00014750075526535512
995 0.0008123144507408142
996 0.00228926345705986
997 0.0005485118925571442
998 0.000935322642326355
999 0.00028572577983140943
1000 0.000125264972448349
1001 0.0006585969030857087
1002 0.0001810486987233162
1003 0.00014375759288668633
1004 0.00030718024820089343
1005 0.0002127713523805141

1270 0.0002994387038052082
1271 0.00022375214844942094
1272 0.001194804459810257
1273 0.00021220080554485321
1274 0.0004663629457354546
1275 9.665937162935734e-05
1276 0.000565032884478569
1277 0.0014268334209918976
1278 0.0005478523671627045
1279 0.0001955735683441162
1280 9.373882785439492e-05
1281 0.00019211944192647933
1282 0.0006180404126644134
1283 0.0002745657972991467
1284 0.0003673451766371727
1285 0.00027897635474801065
1286 0.0004050028324127197
1287 5.9923185035586356e-05
1288 0.0007238012552261353
1289 0.00010703170672059059
1290 5.219021812081337e-05
1291 4.16113156825304e-05
1292 0.00026006920263171195
1293 0.00019521251320838928
1294 0.0005247578769922256
1295 0.00010364471934735775
1296 0.0003485424444079399
1297 0.00043405979871749876
1298 0.00031775910407304763
1299 0.00011190858669579028
1300 0.0002537919953465462
1301 0.0002312827855348587
1302 8.973518386483192e-05
1303 0.0009563077241182327
1304 0.000250594187527895
1305 0.0002292044460773468
1306 0.0001002573594

1570 9.488415904343128e-05
1571 0.00011682363227009773
1572 0.0006081727147102356
1573 0.00011517653241753578
1574 0.00018295347690582274
1575 0.00010339384898543358
1576 8.362132124602795e-05
1577 0.00016482561826705933
1578 0.0002208838239312172
1579 0.0006018109992146492
1580 4.166674800217152e-05
1581 0.00015972012653946876
1582 0.00032294902950525284
1583 0.00019807728007435797
1584 0.00019946688786149025
1585 0.00028497375547885894
1586 0.00019999269396066665
1587 0.0003419845551252365
1588 0.0024483875930309297
1589 0.00019292308017611502
1590 0.0014789804816246032
1591 2.205904107540846e-05
1592 0.00018093740567564964
1593 0.00013511696830391883
1594 4.0442701429128644e-05
1595 0.00021089304238557815
1596 6.329326890408993e-05
1597 0.0002633197046816349
1598 3.302226308733225e-05
1599 0.0009076887369155884
1600 0.00018970929086208342
1601 0.000267583392560482
1602 6.998820696026087e-05
1603 0.0001181710883975029
1604 9.731241501867771e-05
1605 4.3081846088171004e-05
1606 0.0001

1870 8.781210519373417e-05
1871 4.266222473233938e-05
1872 0.0001573888212442398
1873 7.891922257840633e-05
1874 1.0034278966486454e-05
1875 7.410405669361353e-05
1876 8.256983011960983e-05
1877 0.00023286443203687668
1878 9.596183896064759e-05
1879 8.05390253663063e-05
1880 0.0002377861738204956
1881 0.00016492962837219238
1882 0.0005151508376002312
1883 2.897696802392602e-05
1884 4.1497391648590565e-05
1885 0.00036750223487615587
1886 8.039483800530434e-05
1887 0.0003295392543077469
1888 9.426104836165905e-05
1889 0.0003289288654923439
1890 0.0002563205733895302
1891 7.650138344615698e-05
1892 0.00011977646499872208
1893 0.0001060137152671814
1894 4.380539525300264e-05
1895 4.314201883971691e-05
1896 6.410415749996901e-05
1897 7.77190551161766e-05
1898 0.0005048053711652756
1899 8.424109779298306e-05
1900 8.89990571886301e-05
1901 3.555564675480127e-05
1902 0.00014939095824956895
1903 4.7866972163319585e-05
1904 9.103397838771344e-05
1905 0.0001223661098629236
1906 0.0002248023264110

2170 4.32915473356843e-05
2171 0.00030889036133885384
2172 0.00017669633030891418
2173 0.00034781888127326965
2174 0.00013798493891954423
2175 0.0006785459071397781
2176 0.00033357657492160797
2177 0.00012610074132680894
2178 4.222318064421415e-05
2179 0.00019812945276498795
2180 9.245615452527999e-05
2181 0.0005343940481543541
2182 0.00046576883643865584
2183 0.0002366589568555355
2184 0.00019763505086302756
2185 0.00040843844413757324
2186 0.00023893551900982857
2187 0.0003120744228363037
2188 0.00012733682990074158
2189 2.7445873711258174e-05
2190 5.867312196642161e-05
2191 2.8558718040585517e-05
2192 8.324584923684597e-05
2193 0.00014055250212550164
2194 0.0005516988039016723
2195 0.00025952834635972977
2196 0.00012722698040306567
2197 1.940750051289797e-05
2198 0.00019880468025803565
2199 1.025803852826357e-05
2200 0.00016048602759838104
2201 6.259850226342679e-05
2202 9.254885837435723e-05
2203 8.197073824703694e-05
2204 4.171667154878378e-05
2205 0.00022171400487422944
2206 4.01

2470 0.00010411121882498265
2471 0.00022225094959139825
2472 4.2435112409293655e-05
2473 0.0002197606861591339
2474 2.4101906456053256e-05
2475 9.512573014944792e-06
2476 7.304858416318894e-05
2477 0.000147769870236516
2478 2.3935793433338403e-05
2479 0.00023142991587519647
2480 4.9240649677813056e-05
2481 6.86139240860939e-05
2482 0.00010248444974422455
2483 2.9600099660456182e-05
2484 0.0002461232803761959
2485 2.5277873501181602e-05
2486 0.00012746474705636502
2487 0.0001555562112480402
2488 9.446617215871811e-05
2489 0.0007089269161224365
2490 8.45193862915039e-05
2491 8.775735273957253e-05
2492 0.00018990227952599524
2493 0.0002200971730053425
2494 8.4619065746665e-05
2495 3.61988740041852e-05
2496 0.0001977251097559929
2497 0.00021631479263305663
2498 8.395661599934101e-05
2499 5.418640561401844e-05
2500 4.0050526149570945e-05
2501 9.025178849697113e-05
2502 1.9716841634362937e-05
2503 5.2562509663403036e-05
2504 0.00013145660050213338
2505 7.32879899442196e-05
2506 5.59956580400

2772 0.0002326289564371109
2773 0.00011655551381409168
2774 8.447294123470784e-05
2775 0.00020934116095304488
2776 3.8669498171657326e-05
2777 0.0011669948697090148
2778 1.547182793729007e-05
2779 5.642538424581289e-05
2780 3.082570154219866e-05
2781 5.8453865349292756e-05
2782 7.027099840342999e-05
2783 6.254801526665687e-05
2784 3.899852745234966e-05
2785 2.4116458371281622e-05
2786 0.0002160458080470562
2787 5.479747429490089e-05
2788 2.6889685541391374e-05
2789 0.00015770168974995613
2790 3.576228162273764e-05
2791 9.551437571644783e-05
2792 9.834994561970235e-06
2793 0.0003740246593952179
2794 0.00023625604808330536
2795 3.7623164243996146e-05
2796 2.252334728837013e-05
2797 0.0006682906299829483
2798 1.2108276132494212e-05
2799 6.145913153886796e-05
2800 6.262449081987143e-05
2801 2.7605004142969846e-05
2802 8.250798098742962e-05
2803 1.6584928380325438e-05
2804 0.00025825059041380884
2805 0.0001434264425188303
2806 0.0002147473581135273
2807 4.800977185368538e-05
2808 5.75623475

3072 0.00015754887834191323
3073 1.9717251416295766e-05
3074 1.8291444284841418e-05
3075 0.0003021707385778427
3076 6.339550949633122e-05
3077 3.5340089816600084e-05
3078 0.00013832819648087024
3079 2.816846128553152e-05
3080 7.953240536153317e-05
3081 0.00023167245090007781
3082 6.1031635850667955e-05
3083 2.115731593221426e-05
3084 2.0623120944947005e-05
3085 3.803930943831801e-05
3086 1.6645288560539484e-05
3087 2.2310316562652588e-05
3088 0.0002319420874118805
3089 9.048041887581348e-05
3090 0.0003476947173476219
3091 0.00020413300022482873
3092 0.0001385125145316124
3093 5.859133787453175e-05
3094 1.9377223215997218e-05
3095 0.0003276212513446808
3096 0.00020389823243021965
3097 0.00014341222122311593
3098 9.262688457965851e-05
3099 0.00012519467622041703
3100 1.7102627316489817e-05
3101 8.375128731131553e-05
3102 5.4398258216679094e-05
3103 0.00018883951008319855
3104 0.00024835042655467986
3105 1.452817814424634e-05
3106 5.257799755781889e-05
3107 0.0003585170209407806
3108 2.11

3372 5.470859818160534e-05
3373 0.0006352733075618744
3374 5.456239450722933e-05
3375 0.00015276104211807252
3376 0.0001401587389409542
3377 0.00020475754514336585
3378 6.651911418884992e-05
3379 1.8505601910874248e-05
3380 4.388853441923857e-05
3381 4.612283315509558e-05
3382 6.741321180015802e-05
3383 3.467342350631952e-05
3384 1.3844312634319068e-05
3385 8.002257905900478e-05
3386 3.3236690796911716e-05
3387 6.26273825764656e-05
3388 0.00022938944399356843
3389 5.3877509199082855e-05
3390 1.9215308129787444e-05
3391 7.42340600118041e-05
3392 0.00019129492342472075
3393 1.4613269595429301e-05
3394 6.292944308370352e-05
3395 3.162039443850517e-05
3396 1.6748636262491346e-05
3397 0.00012596446089446544
3398 3.317772876471281e-05
3399 7.382409647107125e-05
3400 4.4750929810106754e-05
3401 4.9720574170351026e-05
3402 1.56102841719985e-05
3403 8.710768073797226e-05
3404 3.1032392289489504e-05
3405 0.00011441010981798172
3406 1.7205122858285904e-05
3407 5.943384487181902e-05
3408 1.4745111

KeyboardInterrupt: 

In [18]:
import json

with open('../src/genes.json', 'r') as f:
    genes = json.load(f)


JSONDecodeError: Expecting value: line 1 column 13 (char 12)

In [66]:
%%time 

with open('../src/genes.txt') as f:
    arr = eval(f.read())

CPU times: user 25.2 ms, sys: 3.92 ms, total: 29.2 ms
Wall time: 29.6 ms


In [75]:
import json

with open('test.json', 'w') as f:
    json.dump(arr, f)

In [78]:
%%time
with open('test.json', 'r') as f:
    arr = json.load(f)

CPU times: user 3.18 ms, sys: 1.45 ms, total: 4.63 ms
Wall time: 4.06 ms


In [77]:
print(arr)

['ttyh2', 'ankfy1', 'golm1', 'rp11-582j16.5', 'rp11-148l24.1', 'tsn', 'lrig3', 'eya2', 'fst', 'ccser1', 'prmt9', 'kiaa1958', 'mtif2', 'mt2a', 'acap2', 'kiaa0825', 'psme2', 'c11orf84', 'dvl1', 'c9orf72', 'slc25a20', 'gstm4', 'tyms', 'rspo1', 'enah', 'rad9a', 'arhgef40', 'tph1', 'fezf2', 'c1galt1', 'naaladl1', 'casc15', 'polr2j3', 'cep68', 'scn2b', 'asnsd1', 'c20orf194', 'hist1h3b', 'ngfrap1', 'fitm2', 'ptgis', 'linc01186', 'riok3', 'b3gnt8', 'umps', 'hdgf', 'svbp', 'helz', 'psme1', 'txn', 'acot13', 'mmel1', 'marc2', 'cdc42', 'fam118a', 'insig2', 'ccnt1', 'atp2a3', 'serpine2', 'scn3b', 'tunar', 'znf681', 'itgb1bp1', 'ffar4', 'slc24a4', 'ccdc14', 'gyg1', 'syt2', 'ccdc130', 'hist1h3j', 'tppp', 'znf267', 'kdsr', 'ac113404.1', 'ndufb8', 'rp11-96b5.3', 'c5orf56', 'atp8b2', 'c6orf141', 'timm23b', 'satb2', 'ppara', 'pdpr', 'trmt6', 'htra2', 'hlf', 'crebzf', 'tmem248', 'brcc3', 'c3orf14', 'znf551', 'serpinh1', 'podnl1', 'igfbp7', 'rarg', 'wls', 'rraga', 'rnf26', 'rasal2-as1', 'nck1-as1', 'cdk20'