In [1]:
import numpy as np
import torch
from src.dataloaders import CoLADataModule
from sentence_transformers import SentenceTransformer

from torch import optim, nn, utils, Tensor
from torchmetrics.classification import BinaryAccuracy
import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
available = torch.cuda.is_available()
curr_device = torch.cuda.current_device()
device = torch.device("cuda:0" if available else "cpu") 
device_count = torch.cuda.device_count() 
device_name =  torch.cuda.get_device_name(0)

print(f'Cuda available: {available}')
print(f'Current device: {curr_device}')
print(f'Device: {device}')
print(f'Device count: {device_count}')
print(f'Device name: {device_name}')


Cuda available: True
Current device: 0
Device: cuda:0
Device count: 1
Device name: NVIDIA GeForce RTX 3090


In [3]:
sentence_encoder = SentenceTransformer('all-mpnet-base-v2')
dm = CoLADataModule(data_dir='./glue_data/CoLA/', batch_size=1000, sentence_encoder=sentence_encoder)

#dm.setup(stage='fit')
#dm.setup(stage='validate')
#dm.setup(stage='test')
#
#for i, (x, y) in enumerate(dm.train):
#    if i < 10:
#        print(f'i: x = {x}, y = {y}')

#sentence_encoder.encode(dm.train.x_train.flatten())

In [9]:
# define the LightningModule
class CoLAClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(nn.Linear(768, 768), 
                                     nn.ReLU(),
                                     nn.Linear(768, 1),
                                     nn.Sigmoid())
        self.accuracy = BinaryAccuracy()

    def forward(self, x):
        y = self.network(x)
        return y

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, targets = batch
        y = self.network(x)
        print(x.shape, y.shape)

        loss = nn.functional.binary_cross_entropy(y, targets)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, targets = batch
        y = self.network(x)
        val_loss = nn.functional.binary_cross_entropy(y, targets)
        self.log("test_accuracy", self.accuracy(y, targets))

    def test_step(self, batch, batch_idx):
        # this is the validation loop
        x, targets = batch
        y = self.network(x)
        metric = BinaryAccuracy()
        self.log("test_accuracy", self.accuracy(y, targets))

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

cls = CoLAClassifier()

In [10]:
dm.setup(stage='fit')
dm.setup(stage='validate')
dm.setup(stage='test')

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=20)
trainer.fit(model=cls, datamodule=dm)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type           | Params
--------------------------------------------
0 | network  | Sequential     | 591 K 
1 | accuracy | BinaryAccuracy | 0     
--------------------------------------------
591 K     Trainable params
0         Non-trainable params
591 K     Total params
2.365     Total estimated model params size (MB)


                                                                            

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/11 [00:00<?, ?it/s] torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:   9%|▉         | 1/11 [00:00<00:00, 127.43it/s, loss=0.686, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  18%|█▊        | 2/11 [00:00<00:00, 141.25it/s, loss=0.683, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  27%|██▋       | 3/11 [00:00<00:00, 149.85it/s, loss=0.678, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  36%|███▋      | 4/11 [00:00<00:00, 150.41it/s, loss=0.671, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  45%|████▌     | 5/11 [00:00<00:00, 149.25it/s, loss=0.668, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  55%|█████▍    | 6/11 [00:00<00:00, 150.43it/s, loss=0.661, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  64%|██████▎   | 7/11 [00:00<00:00, 148.02it/s, loss=0.653, v_num=35]torch.Size([1000, 768]) torch.Size([1000, 1])
Epoch 0:  73%|███████▎  | 8/11 [00:00<

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 11/11 [00:00<00:00, 127.80it/s, loss=0.483, v_num=35]


In [11]:
trainer.predict(model=cls, datamodule=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 589.13it/s]


[tensor([[0.6881],
         [0.7564],
         [0.9387],
         [0.7084],
         [0.6955],
         [0.6842],
         [0.7257],
         [0.8523],
         [0.4178],
         [0.8290],
         [0.7476],
         [0.5684],
         [0.5793],
         [0.7689],
         [0.7537],
         [0.7329],
         [0.5245],
         [0.7752],
         [0.3226],
         [0.5683],
         [0.6753],
         [0.9400],
         [0.8018],
         [0.8396],
         [0.6740],
         [0.8224],
         [0.6328],
         [0.4659],
         [0.4450],
         [0.7123],
         [0.7595],
         [0.6831],
         [0.5121],
         [0.5543],
         [0.6204],
         [0.7741],
         [0.6844],
         [0.5232],
         [0.9160],
         [0.6659],
         [0.5258],
         [0.8834],
         [0.7505],
         [0.6454],
         [0.7164],
         [0.4878],
         [0.6649],
         [0.7732],
         [0.5654],
         [0.8847],
         [0.6966],
         [0.6523],
         [0.