In [1]:
### import software packages that we'll use
import os
import pickle
import numpy as np
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from torchsummary import summary
import wandb

In [2]:
#get data
data_path = 'dataset.pkl'
with open(data_path, 'rb') as handle:
    data = pickle.load(handle)

#remove unecessary column
data.drop(labels='cell',axis=1,inplace=True)
data = data.to_numpy()

In [11]:
embeddings = data[:,0]
new = []
for e in embeddings:
    new.append(np.asarray(e))

In [21]:
with open('X.pkl','wb') as handle:
    pickle.dump(new,handle)

In [17]:
with open('y.pkl','wb') as handle:
    pickle.dump(new,handle)

[array([-0.18930584,  0.20833486, -0.04554228, -0.04072357, -0.08874279,
         0.19690122, -0.01202217,  0.31777978, -0.14343381,  0.03894588,
        -0.02063541, -0.06787176,  0.14168671,  0.02597792,  0.08872835,
         0.14837307, -0.14171536, -0.10525542, -0.03470885,  0.08391233,
        -0.05429509, -0.20762084,  0.02723555, -0.07137816, -0.00861356,
         0.10153604,  0.02968225, -0.14749146, -0.14440762,  0.05250481,
        -0.10693254,  0.11649041, -0.14695837, -0.0568473 ,  0.0490993 ,
         0.05954179, -0.03064919, -0.40967634, -0.02195007, -0.17128149,
        -0.4638243 , -0.15289706, -0.03331146, -0.01543703, -0.11540428,
         0.09537826,  0.16973943,  0.05611289, -0.16244477, -0.1314123 ,
        -0.12735651,  0.11263308, -0.02883502, -0.02701937,  0.01163912,
         0.02279525,  0.07580364,  0.01141435, -0.18365163,  0.01741027,
        -0.00928673, -0.07021639, -0.13827209,  0.03561032]),
 array([ 1.1973336 , -0.00281745,  0.01887686,  0.25683272,  0

In [31]:
#define model class
"""
class Model(nn.Module):
     
    def __init__(self, input_size):
        super().__init__()
        
        #fully connected layers:
        self.l1 = nn.Linear(input_size, 16)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        x = self.sigmoid(x)
        return x
"""

In [48]:
#define Pytorch Lightning module (executes during training)
class Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = nn.Linear(64, 16)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        x = self.sigmoid(x)
        return x
    
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x = self(x)
        loss = F.binary_cross_entropy(x, y.reshape(-1,1).float())
        self.log("train_loss",loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = self(x)
        loss = F.binary_cross_entropy(x, y.reshape(-1,1).float())
        self.log("val_loss",loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [49]:
#define Pytorch mappable dataset
class PBMCDataset(Dataset):
    def __init__(self, data):
        self.embeddings = data[:,0]
        self.labels = data[:,1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        
        return torch.tensor(embedding), torch.tensor(label)

In [50]:
#training and validation data split
train_set, val_set = torch.utils.data.random_split(PBMCDataset(data), [2000,834])

In [51]:
#initialize weights and biases logger
wandb_logger = WandbLogger(project='classifier')

In [52]:
#define dataloaders
train_loader = DataLoader(train_set)
valid_loader = DataLoader(val_set)

#instantiate pl.LightningModule
classifier = Classifier()

#train model
trainer = Trainer(logger=wandb_logger)
trainer.fit(classifier, train_loader, valid_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [4]

  | Name    | Type    | Params
------------------------------------
0 | l1      | Linear  | 1.0 K 
1 | relu    | ReLU    | 0     
2 | l2      | Linear  | 17    
3 | sigmoid | Sigmoid | 0     
------------------------------------
1.1 K     Trainable params
0         Non-trainable params
1.1 K     Total params
0.004     Total estimated model params size (MB)


Epoch 0: 100%|█| 2000/2000 [00:06<00:00, 330.96it/s, v_num=h5et
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                      | 0/834 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|         | 0/834 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%| | 1/834 [00:00<00:02, 307.23it/s[A
Validation DataLoader 0:   0%| | 2/834 [00:00<00:02, 390.19it/s[A
Validation DataLoader 0:   0%| | 3/834 [00:00<00:01, 422.61it/s[A
Validation DataLoader 0:   0%| | 4/834 [00:00<00:02, 365.55it/s[A
Validation DataLoader 0:   1%| | 5/834 [00:00<00:02, 338.43it/s[A
Validation DataLoader 0:   1%| | 6/834 [00:00<00:02, 328.15it/s[A
Validation DataLoader 0:   1%| | 7/834 [00:00<00:02, 327.36it/s[A
Validation DataLoader 0:   1%| | 8/834 [00:00<00:02, 325.63it/s[A
Validation DataLoader 0:   1%| | 9/834 [00:00<00:02, 326.95it/s[A
Validation DataLoader 0:   1%| | 10/834 [00:00<00:02, 326.63it/[A
Validation DataLoader 0:   1%| | 11/834 [00:00<00:02, 326.36it/[A
Validation DataLoader 0:   1%| 

Validation DataLoader 0:  29%|▎| 241/834 [00:00<00:01, 356.22it[A
Validation DataLoader 0:  29%|▎| 242/834 [00:00<00:01, 356.09it[A
Validation DataLoader 0:  29%|▎| 243/834 [00:00<00:01, 355.95it[A
Validation DataLoader 0:  29%|▎| 244/834 [00:00<00:01, 355.79it[A
Validation DataLoader 0:  29%|▎| 245/834 [00:00<00:01, 355.66it[A
Validation DataLoader 0:  29%|▎| 246/834 [00:00<00:01, 355.52it[A
Validation DataLoader 0:  30%|▎| 247/834 [00:00<00:01, 355.38it[A
Validation DataLoader 0:  30%|▎| 248/834 [00:00<00:01, 355.24it[A
Validation DataLoader 0:  30%|▎| 249/834 [00:00<00:01, 355.10it[A
Validation DataLoader 0:  30%|▎| 250/834 [00:00<00:01, 354.96it[A
Validation DataLoader 0:  30%|▎| 251/834 [00:00<00:01, 354.83it[A
Validation DataLoader 0:  30%|▎| 252/834 [00:00<00:01, 354.66it[A
Validation DataLoader 0:  30%|▎| 253/834 [00:00<00:01, 354.55it[A
Validation DataLoader 0:  30%|▎| 254/834 [00:00<00:01, 354.39it[A
Validation DataLoader 0:  31%|▎| 255/834 [00:00<00:01, 354.26i

Validation DataLoader 0:  58%|▌| 485/834 [00:01<00:01, 346.91it[A
Validation DataLoader 0:  58%|▌| 486/834 [00:01<00:01, 347.06it[A
Validation DataLoader 0:  58%|▌| 487/834 [00:01<00:00, 347.17it[A
Validation DataLoader 0:  59%|▌| 488/834 [00:01<00:00, 347.36it[A
Validation DataLoader 0:  59%|▌| 489/834 [00:01<00:00, 347.56it[A
Validation DataLoader 0:  59%|▌| 490/834 [00:01<00:00, 347.78it[A
Validation DataLoader 0:  59%|▌| 491/834 [00:01<00:00, 347.84it[A
Validation DataLoader 0:  59%|▌| 492/834 [00:01<00:00, 348.03it[A
Validation DataLoader 0:  59%|▌| 493/834 [00:01<00:00, 348.22it[A
Validation DataLoader 0:  59%|▌| 494/834 [00:01<00:00, 348.43it[A
Validation DataLoader 0:  59%|▌| 495/834 [00:01<00:00, 348.62it[A
Validation DataLoader 0:  59%|▌| 496/834 [00:01<00:00, 348.81it[A
Validation DataLoader 0:  60%|▌| 497/834 [00:01<00:00, 349.01it[A
Validation DataLoader 0:  60%|▌| 498/834 [00:01<00:00, 349.07it[A
Validation DataLoader 0:  60%|▌| 499/834 [00:01<00:00, 349.28i

Validation DataLoader 0:  87%|▊| 729/834 [00:02<00:00, 352.04it[A
Validation DataLoader 0:  88%|▉| 730/834 [00:02<00:00, 351.98it[A
Validation DataLoader 0:  88%|▉| 731/834 [00:02<00:00, 351.93it[A
Validation DataLoader 0:  88%|▉| 732/834 [00:02<00:00, 351.87it[A
Validation DataLoader 0:  88%|▉| 733/834 [00:02<00:00, 351.81it[A
Validation DataLoader 0:  88%|▉| 734/834 [00:02<00:00, 351.75it[A
Validation DataLoader 0:  88%|▉| 735/834 [00:02<00:00, 351.69it[A
Validation DataLoader 0:  88%|▉| 736/834 [00:02<00:00, 351.66it[A
Validation DataLoader 0:  88%|▉| 737/834 [00:02<00:00, 351.59it[A
Validation DataLoader 0:  88%|▉| 738/834 [00:02<00:00, 351.54it[A
Validation DataLoader 0:  89%|▉| 739/834 [00:02<00:00, 351.49it[A
Validation DataLoader 0:  89%|▉| 740/834 [00:02<00:00, 351.46it[A
Validation DataLoader 0:  89%|▉| 741/834 [00:02<00:00, 351.40it[A
Validation DataLoader 0:  89%|▉| 742/834 [00:02<00:00, 351.35it[A
Validation DataLoader 0:  89%|▉| 743/834 [00:02<00:00, 351.29i

Validation DataLoader 0:  16%|▏| 135/834 [00:00<00:01, 433.86it[A
Validation DataLoader 0:  16%|▏| 136/834 [00:00<00:01, 433.26it[A
Validation DataLoader 0:  16%|▏| 137/834 [00:00<00:01, 433.38it[A
Validation DataLoader 0:  17%|▏| 138/834 [00:00<00:01, 432.59it[A
Validation DataLoader 0:  17%|▏| 139/834 [00:00<00:01, 432.65it[A
Validation DataLoader 0:  17%|▏| 140/834 [00:00<00:01, 432.76it[A
Validation DataLoader 0:  17%|▏| 141/834 [00:00<00:01, 432.70it[A
Validation DataLoader 0:  17%|▏| 142/834 [00:00<00:01, 432.79it[A
Validation DataLoader 0:  17%|▏| 143/834 [00:00<00:01, 432.68it[A
Validation DataLoader 0:  17%|▏| 144/834 [00:00<00:01, 432.32it[A
Validation DataLoader 0:  17%|▏| 145/834 [00:00<00:01, 432.50it[A
Validation DataLoader 0:  18%|▏| 146/834 [00:00<00:01, 431.89it[A
Validation DataLoader 0:  18%|▏| 147/834 [00:00<00:01, 432.11it[A
Validation DataLoader 0:  18%|▏| 148/834 [00:00<00:01, 431.60it[A
Validation DataLoader 0:  18%|▏| 149/834 [00:00<00:01, 431.88i

Validation DataLoader 0:  45%|▍| 379/834 [00:00<00:01, 392.30it[A
Validation DataLoader 0:  46%|▍| 380/834 [00:00<00:01, 392.06it[A
Validation DataLoader 0:  46%|▍| 381/834 [00:00<00:01, 391.83it[A
Validation DataLoader 0:  46%|▍| 382/834 [00:00<00:01, 391.62it[A
Validation DataLoader 0:  46%|▍| 383/834 [00:00<00:01, 391.35it[A
Validation DataLoader 0:  46%|▍| 384/834 [00:00<00:01, 391.12it[A
Validation DataLoader 0:  46%|▍| 385/834 [00:00<00:01, 390.87it[A
Validation DataLoader 0:  46%|▍| 386/834 [00:00<00:01, 390.63it[A
Validation DataLoader 0:  46%|▍| 387/834 [00:00<00:01, 390.45it[A
Validation DataLoader 0:  47%|▍| 388/834 [00:00<00:01, 390.17it[A
Validation DataLoader 0:  47%|▍| 389/834 [00:00<00:01, 389.94it[A
Validation DataLoader 0:  47%|▍| 390/834 [00:01<00:01, 389.73it[A
Validation DataLoader 0:  47%|▍| 391/834 [00:01<00:01, 389.50it[A
Validation DataLoader 0:  47%|▍| 392/834 [00:01<00:01, 389.28it[A
Validation DataLoader 0:  47%|▍| 393/834 [00:01<00:01, 389.09i

Validation DataLoader 0:  75%|▋| 623/834 [00:01<00:00, 366.03it[A
Validation DataLoader 0:  75%|▋| 624/834 [00:01<00:00, 365.95it[A
Validation DataLoader 0:  75%|▋| 625/834 [00:01<00:00, 365.88it[A
Validation DataLoader 0:  75%|▊| 626/834 [00:01<00:00, 365.82it[A
Validation DataLoader 0:  75%|▊| 627/834 [00:01<00:00, 365.73it[A
Validation DataLoader 0:  75%|▊| 628/834 [00:01<00:00, 365.65it[A
Validation DataLoader 0:  75%|▊| 629/834 [00:01<00:00, 365.56it[A
Validation DataLoader 0:  76%|▊| 630/834 [00:01<00:00, 365.48it[A
Validation DataLoader 0:  76%|▊| 631/834 [00:01<00:00, 365.39it[A
Validation DataLoader 0:  76%|▊| 632/834 [00:01<00:00, 365.29it[A
Validation DataLoader 0:  76%|▊| 633/834 [00:01<00:00, 365.21it[A
Validation DataLoader 0:  76%|▊| 634/834 [00:01<00:00, 365.15it[A
Validation DataLoader 0:  76%|▊| 635/834 [00:01<00:00, 365.06it[A
Validation DataLoader 0:  76%|▊| 636/834 [00:01<00:00, 364.98it[A
Validation DataLoader 0:  76%|▊| 637/834 [00:01<00:00, 364.79i

Validation DataLoader 0:   3%| | 29/834 [00:00<00:02, 310.56it/[A
Validation DataLoader 0:   4%| | 30/834 [00:00<00:02, 310.50it/[A
Validation DataLoader 0:   4%| | 31/834 [00:00<00:02, 310.82it/[A
Validation DataLoader 0:   4%| | 32/834 [00:00<00:02, 310.98it/[A
Validation DataLoader 0:   4%| | 33/834 [00:00<00:02, 311.15it/[A
Validation DataLoader 0:   4%| | 34/834 [00:00<00:02, 311.67it/[A
Validation DataLoader 0:   4%| | 35/834 [00:00<00:02, 312.07it/[A
Validation DataLoader 0:   4%| | 36/834 [00:00<00:02, 311.64it/[A
Validation DataLoader 0:   4%| | 37/834 [00:00<00:02, 311.88it/[A
Validation DataLoader 0:   5%| | 38/834 [00:00<00:02, 312.17it/[A
Validation DataLoader 0:   5%| | 39/834 [00:00<00:02, 312.69it/[A
Validation DataLoader 0:   5%| | 40/834 [00:00<00:02, 312.71it/[A
Validation DataLoader 0:   5%| | 41/834 [00:00<00:02, 313.04it/[A
Validation DataLoader 0:   5%| | 42/834 [00:00<00:02, 313.47it/[A
Validation DataLoader 0:   5%| | 43/834 [00:00<00:02, 313.44it

Validation DataLoader 0:  33%|▎| 273/834 [00:00<00:01, 318.45it[A
Validation DataLoader 0:  33%|▎| 274/834 [00:00<00:01, 318.22it[A
Validation DataLoader 0:  33%|▎| 275/834 [00:00<00:01, 318.04it[A
Validation DataLoader 0:  33%|▎| 276/834 [00:00<00:01, 317.87it[A
Validation DataLoader 0:  33%|▎| 277/834 [00:00<00:01, 317.82it[A
Validation DataLoader 0:  33%|▎| 278/834 [00:00<00:01, 317.77it[A
Validation DataLoader 0:  33%|▎| 279/834 [00:00<00:01, 317.72it[A
Validation DataLoader 0:  34%|▎| 280/834 [00:00<00:01, 317.66it[A
Validation DataLoader 0:  34%|▎| 281/834 [00:00<00:01, 317.63it[A
Validation DataLoader 0:  34%|▎| 282/834 [00:00<00:01, 317.54it[A
Validation DataLoader 0:  34%|▎| 283/834 [00:00<00:01, 317.48it[A
Validation DataLoader 0:  34%|▎| 284/834 [00:00<00:01, 317.36it[A
Validation DataLoader 0:  34%|▎| 285/834 [00:00<00:01, 317.07it[A
Validation DataLoader 0:  34%|▎| 286/834 [00:00<00:01, 316.85it[A
Validation DataLoader 0:  34%|▎| 287/834 [00:00<00:01, 316.62i

Validation DataLoader 0:  62%|▌| 517/834 [00:01<00:00, 324.47it[A
Validation DataLoader 0:  62%|▌| 518/834 [00:01<00:00, 324.49it[A
Validation DataLoader 0:  62%|▌| 519/834 [00:01<00:00, 324.47it[A
Validation DataLoader 0:  62%|▌| 520/834 [00:01<00:00, 324.46it[A
Validation DataLoader 0:  62%|▌| 521/834 [00:01<00:00, 324.38it[A
Validation DataLoader 0:  63%|▋| 522/834 [00:01<00:00, 324.27it[A
Validation DataLoader 0:  63%|▋| 523/834 [00:01<00:00, 324.18it[A
Validation DataLoader 0:  63%|▋| 524/834 [00:01<00:00, 324.09it[A
Validation DataLoader 0:  63%|▋| 525/834 [00:01<00:00, 324.01it[A
Validation DataLoader 0:  63%|▋| 526/834 [00:01<00:00, 324.13it[A
Validation DataLoader 0:  63%|▋| 527/834 [00:01<00:00, 324.37it[A
Validation DataLoader 0:  63%|▋| 528/834 [00:01<00:00, 324.61it[A
Validation DataLoader 0:  63%|▋| 529/834 [00:01<00:00, 324.84it[A
Validation DataLoader 0:  64%|▋| 530/834 [00:01<00:00, 325.03it[A
Validation DataLoader 0:  64%|▋| 531/834 [00:01<00:00, 325.01i

Validation DataLoader 0:  91%|▉| 761/834 [00:02<00:00, 322.88it[A
Validation DataLoader 0:  91%|▉| 762/834 [00:02<00:00, 322.87it[A
Validation DataLoader 0:  91%|▉| 763/834 [00:02<00:00, 322.86it[A
Validation DataLoader 0:  92%|▉| 764/834 [00:02<00:00, 322.84it[A
Validation DataLoader 0:  92%|▉| 765/834 [00:02<00:00, 322.84it[A
Validation DataLoader 0:  92%|▉| 766/834 [00:02<00:00, 322.82it[A
Validation DataLoader 0:  92%|▉| 767/834 [00:02<00:00, 322.79it[A
Validation DataLoader 0:  92%|▉| 768/834 [00:02<00:00, 322.77it[A
Validation DataLoader 0:  92%|▉| 769/834 [00:02<00:00, 322.76it[A
Validation DataLoader 0:  92%|▉| 770/834 [00:02<00:00, 322.74it[A
Validation DataLoader 0:  92%|▉| 771/834 [00:02<00:00, 322.72it[A
Validation DataLoader 0:  93%|▉| 772/834 [00:02<00:00, 322.71it[A
Validation DataLoader 0:  93%|▉| 773/834 [00:02<00:00, 322.68it[A
Validation DataLoader 0:  93%|▉| 774/834 [00:02<00:00, 322.66it[A
Validation DataLoader 0:  93%|▉| 775/834 [00:02<00:00, 322.63i

Validation DataLoader 0:  20%|▏| 167/834 [00:00<00:02, 315.34it[A
Validation DataLoader 0:  20%|▏| 168/834 [00:00<00:02, 315.69it[A
Validation DataLoader 0:  20%|▏| 169/834 [00:00<00:02, 315.90it[A
Validation DataLoader 0:  20%|▏| 170/834 [00:00<00:02, 316.14it[A
Validation DataLoader 0:  21%|▏| 171/834 [00:00<00:02, 316.40it[A
Validation DataLoader 0:  21%|▏| 172/834 [00:00<00:02, 316.75it[A
Validation DataLoader 0:  21%|▏| 173/834 [00:00<00:02, 316.98it[A
Validation DataLoader 0:  21%|▏| 174/834 [00:00<00:02, 317.24it[A
Validation DataLoader 0:  21%|▏| 175/834 [00:00<00:02, 317.44it[A
Validation DataLoader 0:  21%|▏| 176/834 [00:00<00:02, 317.75it[A
Validation DataLoader 0:  21%|▏| 177/834 [00:00<00:02, 318.03it[A
Validation DataLoader 0:  21%|▏| 178/834 [00:00<00:02, 318.29it[A
Validation DataLoader 0:  21%|▏| 179/834 [00:00<00:02, 318.56it[A
Validation DataLoader 0:  22%|▏| 180/834 [00:00<00:02, 318.77it[A
Validation DataLoader 0:  22%|▏| 181/834 [00:00<00:02, 319.09i

Validation DataLoader 0:  49%|▍| 411/834 [00:01<00:01, 325.26it[A
Validation DataLoader 0:  49%|▍| 412/834 [00:01<00:01, 325.37it[A
Validation DataLoader 0:  50%|▍| 413/834 [00:01<00:01, 325.57it[A
Validation DataLoader 0:  50%|▍| 414/834 [00:01<00:01, 325.71it[A
Validation DataLoader 0:  50%|▍| 415/834 [00:01<00:01, 325.98it[A
Validation DataLoader 0:  50%|▍| 416/834 [00:01<00:01, 326.27it[A
Validation DataLoader 0:  50%|▌| 417/834 [00:01<00:01, 326.54it[A
Validation DataLoader 0:  50%|▌| 418/834 [00:01<00:01, 326.64it[A
Validation DataLoader 0:  50%|▌| 419/834 [00:01<00:01, 326.77it[A
Validation DataLoader 0:  50%|▌| 420/834 [00:01<00:01, 327.03it[A
Validation DataLoader 0:  50%|▌| 421/834 [00:01<00:01, 327.12it[A
Validation DataLoader 0:  51%|▌| 422/834 [00:01<00:01, 327.34it[A
Validation DataLoader 0:  51%|▌| 423/834 [00:01<00:01, 327.59it[A
Validation DataLoader 0:  51%|▌| 424/834 [00:01<00:01, 327.82it[A
Validation DataLoader 0:  51%|▌| 425/834 [00:01<00:01, 328.08i

Validation DataLoader 0:  79%|▊| 655/834 [00:01<00:00, 329.65it[A
Validation DataLoader 0:  79%|▊| 656/834 [00:01<00:00, 329.56it[A
Validation DataLoader 0:  79%|▊| 657/834 [00:01<00:00, 329.52it[A
Validation DataLoader 0:  79%|▊| 658/834 [00:01<00:00, 329.48it[A
Validation DataLoader 0:  79%|▊| 659/834 [00:02<00:00, 329.44it[A
Validation DataLoader 0:  79%|▊| 660/834 [00:02<00:00, 329.40it[A
Validation DataLoader 0:  79%|▊| 661/834 [00:02<00:00, 329.36it[A
Validation DataLoader 0:  79%|▊| 662/834 [00:02<00:00, 329.33it[A
Validation DataLoader 0:  79%|▊| 663/834 [00:02<00:00, 329.30it[A
Validation DataLoader 0:  80%|▊| 664/834 [00:02<00:00, 329.26it[A
Validation DataLoader 0:  80%|▊| 665/834 [00:02<00:00, 329.23it[A
Validation DataLoader 0:  80%|▊| 666/834 [00:02<00:00, 329.19it[A
Validation DataLoader 0:  80%|▊| 667/834 [00:02<00:00, 329.15it[A
Validation DataLoader 0:  80%|▊| 668/834 [00:02<00:00, 329.12it[A
Validation DataLoader 0:  80%|▊| 669/834 [00:02<00:00, 329.09i

Validation DataLoader 0:   7%| | 61/834 [00:00<00:01, 394.79it/[A
Validation DataLoader 0:   7%| | 62/834 [00:00<00:01, 396.38it/[A
Validation DataLoader 0:   8%| | 63/834 [00:00<00:01, 397.75it/[A
Validation DataLoader 0:   8%| | 64/834 [00:00<00:01, 399.19it/[A
Validation DataLoader 0:   8%| | 65/834 [00:00<00:01, 400.83it/[A
Validation DataLoader 0:   8%| | 66/834 [00:00<00:01, 402.28it/[A
Validation DataLoader 0:   8%| | 67/834 [00:00<00:01, 403.56it/[A
Validation DataLoader 0:   8%| | 68/834 [00:00<00:01, 404.93it/[A
Validation DataLoader 0:   8%| | 69/834 [00:00<00:01, 406.39it/[A
Validation DataLoader 0:   8%| | 70/834 [00:00<00:01, 407.69it/[A
Validation DataLoader 0:   9%| | 71/834 [00:00<00:01, 408.25it/[A
Validation DataLoader 0:   9%| | 72/834 [00:00<00:01, 407.00it/[A
Validation DataLoader 0:   9%| | 73/834 [00:00<00:01, 406.16it/[A
Validation DataLoader 0:   9%| | 74/834 [00:00<00:01, 405.13it/[A
Validation DataLoader 0:   9%| | 75/834 [00:00<00:01, 404.20it

Validation DataLoader 0:  37%|▎| 305/834 [00:00<00:01, 341.99it[A
Validation DataLoader 0:  37%|▎| 306/834 [00:00<00:01, 341.92it[A
Validation DataLoader 0:  37%|▎| 307/834 [00:00<00:01, 341.64it[A
Validation DataLoader 0:  37%|▎| 308/834 [00:00<00:01, 341.40it[A
Validation DataLoader 0:  37%|▎| 309/834 [00:00<00:01, 341.22it[A
Validation DataLoader 0:  37%|▎| 310/834 [00:00<00:01, 341.08it[A
Validation DataLoader 0:  37%|▎| 311/834 [00:00<00:01, 340.95it[A
Validation DataLoader 0:  37%|▎| 312/834 [00:00<00:01, 340.83it[A
Validation DataLoader 0:  38%|▍| 313/834 [00:00<00:01, 340.66it[A
Validation DataLoader 0:  38%|▍| 314/834 [00:00<00:01, 340.51it[A
Validation DataLoader 0:  38%|▍| 315/834 [00:00<00:01, 340.39it[A
Validation DataLoader 0:  38%|▍| 316/834 [00:00<00:01, 340.25it[A
Validation DataLoader 0:  38%|▍| 317/834 [00:00<00:01, 340.11it[A
Validation DataLoader 0:  38%|▍| 318/834 [00:00<00:01, 339.99it[A
Validation DataLoader 0:  38%|▍| 319/834 [00:00<00:01, 339.89i

Validation DataLoader 0:  66%|▋| 549/834 [00:01<00:00, 336.87it[A
Validation DataLoader 0:  66%|▋| 550/834 [00:01<00:00, 336.85it[A
Validation DataLoader 0:  66%|▋| 551/834 [00:01<00:00, 336.82it[A
Validation DataLoader 0:  66%|▋| 552/834 [00:01<00:00, 336.78it[A
Validation DataLoader 0:  66%|▋| 553/834 [00:01<00:00, 336.74it[A
Validation DataLoader 0:  66%|▋| 554/834 [00:01<00:00, 336.69it[A
Validation DataLoader 0:  67%|▋| 555/834 [00:01<00:00, 336.66it[A
Validation DataLoader 0:  67%|▋| 556/834 [00:01<00:00, 336.62it[A
Validation DataLoader 0:  67%|▋| 557/834 [00:01<00:00, 336.58it[A
Validation DataLoader 0:  67%|▋| 558/834 [00:01<00:00, 336.54it[A
Validation DataLoader 0:  67%|▋| 559/834 [00:01<00:00, 336.51it[A
Validation DataLoader 0:  67%|▋| 560/834 [00:01<00:00, 336.46it[A
Validation DataLoader 0:  67%|▋| 561/834 [00:01<00:00, 336.43it[A
Validation DataLoader 0:  67%|▋| 562/834 [00:01<00:00, 336.40it[A
Validation DataLoader 0:  68%|▋| 563/834 [00:01<00:00, 336.36i

Validation DataLoader 0:  95%|▉| 793/834 [00:02<00:00, 336.01it[A
Validation DataLoader 0:  95%|▉| 794/834 [00:02<00:00, 335.99it[A
Validation DataLoader 0:  95%|▉| 795/834 [00:02<00:00, 335.94it[A
Validation DataLoader 0:  95%|▉| 796/834 [00:02<00:00, 335.90it[A
Validation DataLoader 0:  96%|▉| 797/834 [00:02<00:00, 335.87it[A
Validation DataLoader 0:  96%|▉| 798/834 [00:02<00:00, 335.84it[A
Validation DataLoader 0:  96%|▉| 799/834 [00:02<00:00, 335.79it[A
Validation DataLoader 0:  96%|▉| 800/834 [00:02<00:00, 335.76it[A
Validation DataLoader 0:  96%|▉| 801/834 [00:02<00:00, 335.73it[A
Validation DataLoader 0:  96%|▉| 802/834 [00:02<00:00, 335.69it[A
Validation DataLoader 0:  96%|▉| 803/834 [00:02<00:00, 335.65it[A
Validation DataLoader 0:  96%|▉| 804/834 [00:02<00:00, 335.61it[A
Validation DataLoader 0:  97%|▉| 805/834 [00:02<00:00, 335.58it[A
Validation DataLoader 0:  97%|▉| 806/834 [00:02<00:00, 335.56it[A
Validation DataLoader 0:  97%|▉| 807/834 [00:02<00:00, 335.56i

Validation DataLoader 0:  24%|▏| 199/834 [00:00<00:01, 378.03it[A
Validation DataLoader 0:  24%|▏| 200/834 [00:00<00:01, 377.46it[A
Validation DataLoader 0:  24%|▏| 201/834 [00:00<00:01, 376.89it[A
Validation DataLoader 0:  24%|▏| 202/834 [00:00<00:01, 376.33it[A
Validation DataLoader 0:  24%|▏| 203/834 [00:00<00:01, 375.82it[A
Validation DataLoader 0:  24%|▏| 204/834 [00:00<00:01, 375.38it[A
Validation DataLoader 0:  25%|▏| 205/834 [00:00<00:01, 374.82it[A
Validation DataLoader 0:  25%|▏| 206/834 [00:00<00:01, 372.04it[A
Validation DataLoader 0:  25%|▏| 207/834 [00:00<00:01, 371.54it[A
Validation DataLoader 0:  25%|▏| 208/834 [00:00<00:01, 371.04it[A
Validation DataLoader 0:  25%|▎| 209/834 [00:00<00:01, 370.36it[A
Validation DataLoader 0:  25%|▎| 210/834 [00:00<00:01, 369.89it[A
Validation DataLoader 0:  25%|▎| 211/834 [00:00<00:01, 369.48it[A
Validation DataLoader 0:  25%|▎| 212/834 [00:00<00:01, 369.09it[A
Validation DataLoader 0:  26%|▎| 213/834 [00:00<00:01, 368.63i

Validation DataLoader 0:  53%|▌| 443/834 [00:01<00:00, 418.48it[A
Validation DataLoader 0:  53%|▌| 444/834 [00:01<00:00, 418.65it[A
Validation DataLoader 0:  53%|▌| 445/834 [00:01<00:00, 418.85it[A
Validation DataLoader 0:  53%|▌| 446/834 [00:01<00:00, 419.03it[A
Validation DataLoader 0:  54%|▌| 447/834 [00:01<00:00, 419.23it[A
Validation DataLoader 0:  54%|▌| 448/834 [00:01<00:00, 419.43it[A
Validation DataLoader 0:  54%|▌| 449/834 [00:01<00:00, 419.66it[A
Validation DataLoader 0:  54%|▌| 450/834 [00:01<00:00, 419.82it[A
Validation DataLoader 0:  54%|▌| 451/834 [00:01<00:00, 419.98it[A
Validation DataLoader 0:  54%|▌| 452/834 [00:01<00:00, 420.17it[A
Validation DataLoader 0:  54%|▌| 453/834 [00:01<00:00, 420.37it[A
Validation DataLoader 0:  54%|▌| 454/834 [00:01<00:00, 420.56it[A
Validation DataLoader 0:  55%|▌| 455/834 [00:01<00:00, 420.60it[A
Validation DataLoader 0:  55%|▌| 456/834 [00:01<00:00, 420.77it[A
Validation DataLoader 0:  55%|▌| 457/834 [00:01<00:00, 420.96i

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 6: 100%|█| 2000/2000 [00:05<00:00, 337.47it/s, v_num=h5et
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                      | 0/834 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|         | 0/834 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%| | 1/834 [00:00<00:02, 353.50it/s[A
Validation DataLoader 0:   0%| | 2/834 [00:00<00:02, 399.76it/s[A
Validation DataLoader 0:   0%| | 3/834 [00:00<00:01, 439.26it/s[A
Validation DataLoader 0:   0%| | 4/834 [00:00<00:01, 457.88it/s[A
Validation DataLoader 0:   1%| | 5/834 [00:00<00:01, 472.85it/s[A
Validation DataLoader 0:   1%| | 6/834 [00:00<00:01, 481.11it/s[A
Validation DataLoader 0:   1%| | 7/834 [00:00<00:01, 486.14it/s[A
Validation DataLoader 0:   1%| | 8/834 [00:00<00:01, 492.71it/s[A
Validation DataLoader 0:   1%| | 9/834 [00:00<00:01, 496.30it/s[A
Validation DataLoader 0:   1%| | 10/834 [00:00<00:01, 499.09it/[A
Validation DataLoader 0:   1%| | 11/834 [00:00<00:01, 502.04it/[A
Validation DataLoader 0:   1%| 

Validation DataLoader 0:  29%|▎| 241/834 [00:00<00:01, 484.18it[A
Validation DataLoader 0:  29%|▎| 242/834 [00:00<00:01, 483.06it[A
Validation DataLoader 0:  29%|▎| 243/834 [00:00<00:01, 482.05it[A
Validation DataLoader 0:  29%|▎| 244/834 [00:00<00:01, 481.00it[A
Validation DataLoader 0:  29%|▎| 245/834 [00:00<00:01, 479.96it[A
Validation DataLoader 0:  29%|▎| 246/834 [00:00<00:01, 478.92it[A
Validation DataLoader 0:  30%|▎| 247/834 [00:00<00:01, 477.97it[A
Validation DataLoader 0:  30%|▎| 248/834 [00:00<00:01, 476.97it[A
Validation DataLoader 0:  30%|▎| 249/834 [00:00<00:01, 475.98it[A
Validation DataLoader 0:  30%|▎| 250/834 [00:00<00:01, 475.04it[A
Validation DataLoader 0:  30%|▎| 251/834 [00:00<00:01, 474.05it[A
Validation DataLoader 0:  30%|▎| 252/834 [00:00<00:01, 473.15it[A
Validation DataLoader 0:  30%|▎| 253/834 [00:00<00:01, 472.17it[A
Validation DataLoader 0:  30%|▎| 254/834 [00:00<00:01, 470.69it[A
Validation DataLoader 0:  31%|▎| 255/834 [00:00<00:01, 469.36i

Validation DataLoader 0:  58%|▌| 485/834 [00:01<00:00, 390.83it[A
Validation DataLoader 0:  58%|▌| 486/834 [00:01<00:00, 390.66it[A
Validation DataLoader 0:  58%|▌| 487/834 [00:01<00:00, 390.48it[A
Validation DataLoader 0:  59%|▌| 488/834 [00:01<00:00, 390.31it[A
Validation DataLoader 0:  59%|▌| 489/834 [00:01<00:00, 390.16it[A
Validation DataLoader 0:  59%|▌| 490/834 [00:01<00:00, 389.98it[A
Validation DataLoader 0:  59%|▌| 491/834 [00:01<00:00, 389.81it[A
Validation DataLoader 0:  59%|▌| 492/834 [00:01<00:00, 389.66it[A
Validation DataLoader 0:  59%|▌| 493/834 [00:01<00:00, 389.49it[A
Validation DataLoader 0:  59%|▌| 494/834 [00:01<00:00, 389.31it[A
Validation DataLoader 0:  59%|▌| 495/834 [00:01<00:00, 389.15it[A
Validation DataLoader 0:  59%|▌| 496/834 [00:01<00:00, 388.98it[A
Validation DataLoader 0:  60%|▌| 497/834 [00:01<00:00, 388.82it[A
Validation DataLoader 0:  60%|▌| 498/834 [00:01<00:00, 388.66it[A
Validation DataLoader 0:  60%|▌| 499/834 [00:01<00:00, 388.51i

In [72]:
#after training, predict with the model
model = Classifier()
#load trained model from checkpoint
model = model.load_from_checkpoint("classifier/ihz4h5et/checkpoints/epoch=5-step=12000.ckpt").cuda()
model.eval()

y_hat = []
for d in range(len(data[:,0])):
    with torch.no_grad():
        pred = model(torch.tensor(data[d,0]).cuda())
        y_hat.append(pred.cpu().numpy()[0])

In [75]:
y_hat

[0.52879834,
 0.6483393,
 0.45728728,
 0.4032106,
 0.0016445473,
 0.042009585,
 0.6578892,
 0.790471,
 0.075764984,
 0.012288238,
 0.0041661374,
 0.35567552,
 0.0039904416,
 0.7432866,
 3.705965e-05,
 0.0019764774,
 0.59204656,
 0.7826386,
 0.3586085,
 0.09210616,
 0.0060447655,
 0.5428829,
 0.39771703,
 0.2300884,
 0.7108333,
 2.9001371e-06,
 0.007382021,
 0.0070047304,
 0.77353716,
 0.0072345617,
 0.43571988,
 0.0042139087,
 0.51941496,
 0.7713114,
 0.8272231,
 0.065245226,
 0.03863523,
 0.006525429,
 0.073173486,
 0.0033161067,
 0.6728593,
 0.41580087,
 0.01721523,
 0.7994674,
 0.002706018,
 0.0009948802,
 0.08782581,
 0.8460949,
 0.5053517,
 2.4307296e-06,
 0.3967664,
 0.011632122,
 0.0688197,
 0.00832421,
 0.0011835509,
 0.732456,
 0.055329926,
 0.6027824,
 0.3395144,
 0.5788254,
 0.040473726,
 1.9964402e-05,
 0.6885024,
 0.7944049,
 0.06885471,
 0.35941705,
 0.8016028,
 0.004677476,
 0.047347393,
 0.00015167957,
 0.74756205,
 0.7749744,
 0.1704944,
 0.003018222,
 0.011591476,
 0.