In [1]:
import torch
import pandas as pd
from torch import nn
import pytorch_lightning as pl
import torch.utils.data as data_utils
import numpy as np
from os import cpu_count
from types import NoneType

In [2]:
import wandb
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, LearningRateFinder
from pytorch_lightning.loggers import WandbLogger
from gensim.models import Word2Vec

In [3]:
torch.set_float32_matmul_precision("medium")

In [4]:
df=pd.read_parquet('./data/books.par')
df.columns

Index(['title', 'genre', 'summary', 'input_ids', 'att_mask', 'label',
       'mapped_inputs'],
      dtype='object')

# dataset

In [5]:
class BookDataset(data_utils.Dataset):
    def __init__(self, df: pd.DataFrame):
        self.input_ids=torch.tensor(df.mapped_inputs, dtype=torch.int32)
        self.att_masks=torch.tensor(df.att_mask, dtype=torch.float32)
        self.y=torch.tensor(df.label.values, dtype=torch.uint8)
        self.no_classes=df.label.nunique()

    def __len__(self):
        return self.y.shape[0]
    
    def __getitem__(self, indexes):
        return self.input_ids[indexes], self.att_masks[indexes], self.y[indexes]
        

In [6]:
df=df.iloc[np.random.permutation(df.shape[0])].reset_index(drop=True)
split=int(df.shape[0]*0.9)

train_df=df.iloc[:split]
val_df=df.iloc[split:].reset_index(drop=True)

In [7]:
_,label_weights=np.unique(train_df.label, return_counts=True)
label_weights=1/label_weights
label_weights=label_weights/np.sum(label_weights)
label_weights

array([0.02354139, 0.03259577, 0.04097754, 0.0342735 , 0.03478504,
       0.02017834, 0.20949194, 0.18279199, 0.21187253, 0.20949194])

In [8]:
train_data=BookDataset(train_df)
val_data=BookDataset(val_df)

train_dataloader=data_utils.DataLoader(train_data, batch_size=32, num_workers=cpu_count(),
                                       shuffle=True, drop_last=True)
val_dataloader=data_utils.DataLoader(val_data, batch_size=32, num_workers=cpu_count())

  self.input_ids=torch.tensor(df.mapped_inputs, dtype=torch.int32)


# pl module

In [9]:
class BookGenreClassifier(pl.LightningModule):
    def __init__(self, model, lr=1e-2, loss=nn.CrossEntropyLoss(), l2=1e-5, lr_dc_step=3, lr_dc=0.1, **kwargs):
        super().__init__()
        self.save_hyperparameters(ignore=['model','loss'])
        if isinstance(loss.weight, NoneType):
            weighted_loss=False
        else: weighted_loss=True
        self.save_hyperparameters({'name':model.name, 
                                   'dropout_p':model.dropout_p,
                                   'weighted_loss':weighted_loss})
        self.lr=lr
        self.loss=loss
        self.model=model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch[:-1]
        y = batch[-1]
        logits=self(x)
        loss=self.loss(logits, y)

        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def evaluate(self, batch, mode=None):
        x = batch[:-1]
        y = batch[-1]
        logits=self(x)

        loss=self.loss(logits, y)

        preds=torch.argmax(logits, axis=1)
        acc=torch.sum(preds==y)/y.shape[0]
        # TODO add more metrics

        if mode:
            self.log(mode+'_loss', loss,  prog_bar=True)
            self.log(mode+'_acc', 100*acc,  prog_bar=True)

    def validation_step(self, batch, *args, **kwargs):
        return self.evaluate(batch, "val")
    def test_step(self, batch, *args, **kwargs):
        return self.evaluate(batch, "test")
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.hparams.l2
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=self.hparams.lr_dc_step,
            factor=self.hparams.lr_dc,
            cooldown=1,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "strict": False,
                "interval": "epoch",
                "frequency": 1,
                "name": "scheduler_lr",
            },
        }

# different models

## dummy example

In [10]:
class DummyModel(nn.Module): 
    def __init__(self, in_dim=7031, hid_dim=128, out_dim=10):
        # dummy model as an example, just one hidden layer straight from list of tokens
        super().__init__()
        self.name='DummyModel'
        self.dropout_p=0
        self.l1=nn.Linear(in_dim, hid_dim)
        self.nonlinear=nn.Tanh()
        self.l2=nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        in_ids, att_mask = x
        x=in_ids*att_mask
        x=self.l1(x)
        x=self.nonlinear(x)
        return self.l2(x)

### train

In [11]:
dm_model=BookGenreClassifier(DummyModel(), loss=nn.CrossEntropyLoss(weight=torch.tensor(label_weights, dtype=torch.float32)))

In [12]:
trainer=pl.Trainer(max_epochs=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/szymon/pythonvenvs/rocmwork/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [13]:
trainer.fit(dm_model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | loss  | CrossEntropyLoss | 0      | train
1 | model | DummyModel       | 901 K  | train
---------------------------------------------------
901 K     Trainable params
0         Non-trainable params
901 K     Total params
3.606     Total estimated model params size (MB)


Sanity Checking: |                                                            | 0/? [00:00<?, ?it/s]

Training: |                                                                   | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


# global variables

In [14]:
train_data[0]

(tensor([  302,  1373,  8965,  ..., 36332, 36332, 36332], dtype=torch.int32),
 tensor([1., 1., 1.,  ..., 0., 0., 0.]),
 tensor(0, dtype=torch.uint8))

In [15]:
tokens_set=set()
for tokens in df.mapped_inputs.values:
    tokens_set=tokens_set.union(set(tokens))
vocab_size=len(tokens_set)
vocab_size

36333

# w2v

In [16]:
from gensim.utils import tokenize

df=pd.read_parquet('./data/books.par')
bookwords = []
#for s in df.summary:
#    bookwords.append(list(tokenize(s, lowercase=True)))
for s in df.mapped_inputs:
    bookwords.append([str(i) for i in s])
print(bookwords[:1])

[['279', '1908', '195', '3327', '237', '181', '633', '9021', '1337', '1734', '181', '27437', '11320', '426', '237', '238', '12110', '257', '1608', '209', '11513', '181', '2372', '203', '226', '19617', '13', '1263', '3313', '274', '176', '362', '8167', '6176', '14', '1788', '909', '209', '18106', '341', '2477', '18154', '84', '13', '302', '1373', '5673', '535', '13092', '237', '9147', '12110', '209', '402', '308', '10117', '181', '14988', '245', '279', '1908', '195', '3327', '1768', '599', '13', '12110', '468', '504', '6317', '203', '4116', '696', '205', '181', '2287', '1633', '203', '384', '1129', '535', '258', '6888', '203', '3529', '3187', '257', '17462', '8043', '13', '20972', '11', '181', '4116', '1967', '4018', '18110', '270', '1536', '274', '181', '650', '308', '16290', '203', '181', '13704', '6360', '205', '181', '1950', '13', '13092', '237', '32497', '1378', '335', '176', '1468', '3787', '270', '3789', '35327', '11', '1760', '274', '181', '17215', '256', '17937', '11', '892', '

In [17]:
#w2vmodel = Word2Vec(bookwords, vector_size=100, min_count=0)

In [33]:
class Word2VecSimple(nn.Module):
    def __init__(self, hid_dim, out_dim, vocab_size):
        super().__init__()
        self.name='Word2VecSimple'
        self.dropout_p=0

        w2vmodel = Word2Vec(bookwords, vector_size=hid_dim, min_count=0)
        self.emb = nn.Embedding(vocab_size, hid_dim, padding_idx=-1)

        emb_lst = []
        for v in range(vocab_size):
            emb_lst.append(w2vmodel.wv[str(v)])
        
        emb_mat = np.array(emb_lst)
        self.emb.load_state_dict({'weight': torch.from_numpy(emb_mat)})
        # load embeddings from pretrained word2vec
        #self.emb=nn.Embedding(vocab_size, hid_dim, padding_idx=-1)
        self.out_layer=nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        inputs, _ = x

        w2v_output = self.emb(inputs)
        avg_output = w2v_output.mean(dim=1)
        #print(f"w2v_output:{w2v_output.shape}")
        #print(f"avg_output:{avg_output.shape}")
        return self.out_layer(avg_output)

In [60]:
w2v_simple_model = BookGenreClassifier(Word2VecSimple(256, 10, vocab_size=vocab_size), 
                             loss=nn.CrossEntropyLoss(),#weight=torch.tensor(label_weights, dtype=torch.float32),
                             lr_dc=0.1,
                             lr_dc_step=4,
                             )

In [61]:
wandb_logger = WandbLogger(
        project="ecoNLP", entity="kpuchalskixiv", log_model=True
    )

trainer=pl.Trainer(max_epochs=50,
                   callbacks=[
            EarlyStopping(
                monitor="val_acc", patience=25, mode="max", check_finite=True, check_on_train_epoch_end=False
            ),
            LearningRateMonitor(),
            ModelCheckpoint(monitor="val_acc", mode="max"),
         #   LearningRateFinder(min_lr=1e-4, num_training_steps=1000)
            ],
            logger=wandb_logger,
        )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [62]:
trainer.fit(w2v_simple_model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | loss  | CrossEntropyLoss | 0      | train
1 | model | Word2VecSimple   | 9.3 M  | train
---------------------------------------------------
9.3 M     Trainable params
0         Non-trainable params
9.3 M     Total params
37.215    Total estimated model params size (MB)


Sanity Checking: |                                                            | 0/? [00:00<?, ?it/s]

Training: |                                                                   | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [63]:
wandb.finish()

VBox(children=(Label(value='106.484 MB of 106.484 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
scheduler_lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,██▇▇▆▇▇▆▆▆▆▅▆▅▄▆▄▄▄▅▄▃▄▃▃▄▄▃▃▃▄▂▃▃▁▃▄▃▂▃
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_acc,▁▁▁▁▂▂▃▃▄▄▅▅▅▆▅▆▆▇▆▆▇▇▇▇▇▇▇█████████████
val_loss,██▇▇▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
scheduler_lr,0.01
train_loss,1.02797
trainer/global_step,6499.0
val_acc,55.79399
val_loss,1.31672


# RNN

In [54]:
class SimpleRNN(nn.Module):
    def __init__(self, hid_dim, out_dim, vocab_size, dropout_p=0):
        super().__init__()
        self.name='SimpleRNN'
        self.dropout_p=dropout_p

        # last token states 'end of string' and is repeated multiple time at the end of an input
        # therefore set its embedding to 0 with padding_idx=-1
        self.emb=nn.Embedding(vocab_size, hid_dim, padding_idx=-1) 
        self.model=nn.RNN(input_size=hid_dim, hidden_size=hid_dim, batch_first=True, dropout=dropout_p)
        self.out_layer=nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        inputs, _=x
        rnn_input=self.emb(inputs)
        _, hn=self.model(rnn_input)
        #use last hidden state as sequence representation
        hn=hn.squeeze()
        return self.out_layer(hn)

In [55]:
rnn_model=BookGenreClassifier(SimpleRNN(128, 10, vocab_size=vocab_size), 
                             loss=nn.CrossEntropyLoss(),#weight=torch.tensor(label_weights, dtype=torch.float32),
                             lr_dc=0.5,
                             lr_dc_step=4,
                             )

In [56]:
wandb_logger = WandbLogger(
        project="ecoNLP", entity="kpuchalskixiv", log_model=True
    )

In [57]:
trainer=pl.Trainer(max_epochs=50,
                   callbacks=[
            EarlyStopping(
                monitor="val_acc", patience=25, mode="max", check_finite=True, check_on_train_epoch_end=False
            ),
            LearningRateMonitor(),
            ModelCheckpoint(monitor="val_acc", mode="max"),
         #   LearningRateFinder(min_lr=1e-4, num_training_steps=1000)
            ],
            logger=wandb_logger,
        )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [58]:
trainer.fit(rnn_model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | loss  | CrossEntropyLoss | 0      | train
1 | model | SimpleRNN        | 4.7 M  | train
---------------------------------------------------
4.7 M     Trainable params
0         Non-trainable params
4.7 M     Total params
18.740    Total estimated model params size (MB)


Sanity Checking: |                                                            | 0/? [00:00<?, ?it/s]

Training: |                                                                   | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

/home/szymon/pythonvenvs/rocmwork/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [59]:
wandb.finish()

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
scheduler_lr,███████▃▃▃▃▃▃▃▃▃▁▁▁▁
train_loss,▃▂▄▅▃▄█▄▆▃▅▇▃▄▇▄▃▄▃▅▅▃▂▄▄▃▂▆▃▁▃▅▂▃▂▃▄▂▂▃
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇████
val_acc,▁███▆▁█▄█▆▂██▆██▆██
val_loss,▅▄▄▆▆▅█▂▂▃▁▃▅▂▅▂▁▁▁

0,1
epoch,19.0
scheduler_lr,0.0025
train_loss,2.03964
trainer/global_step,2499.0
val_acc,21.24463
val_loss,2.06619


# LSTM

In [48]:
class SimpleLSTM(nn.Module):
    def __init__(self, hid_dim, out_dim, vocab_size, dropout_p=0):
        super().__init__()
        self.name='SimpleLSTM'
        self.dropout_p=dropout_p
        # last token states 'end of string' and is repeated multiple time at the end of an input
        # therefore set its embedding to 0 with padding_idx=-1
        self.emb=nn.Embedding(vocab_size, hid_dim, padding_idx=-1) 
        self.model=nn.LSTM(input_size=hid_dim, hidden_size=hid_dim, batch_first=True, dropout=dropout_p)
        self.out_layer=nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        inputs, _=x
        rnn_input=self.emb(inputs)

        h0 = torch.randn(1, rnn_input.shape[0], rnn_input.shape[-1], device=rnn_input.device)
        c0 = torch.randn(1, rnn_input.shape[0], rnn_input.shape[-1], device=rnn_input.device)
        _, (hn, _)=self.model(rnn_input, (h0, c0))
        #use last hidden state as sequence representation
        hn=hn.squeeze()
        return self.out_layer(hn)

In [49]:
lstm_model=BookGenreClassifier(SimpleLSTM(128, 10, vocab_size=vocab_size, dropout_p=0.2137), 
                             loss=nn.CrossEntropyLoss(),#weight=torch.tensor(label_weights, dtype=torch.float32),
                             lr_dc=0.5,
                             lr_dc_step=4,
                             )



In [50]:
wandb_logger = WandbLogger(
        project="ecoNLP", entity="kpuchalskixiv", log_model=True
    )

In [51]:
trainer=pl.Trainer(max_epochs=50,
                   callbacks=[
            EarlyStopping(
                monitor="val_acc", patience=25, mode="max", check_finite=True, check_on_train_epoch_end=False
            ),
            LearningRateMonitor(),
            ModelCheckpoint(monitor="val_acc", mode="max"),
         #   LearningRateFinder(min_lr=1e-4, num_training_steps=1000)
            ],
            logger=wandb_logger,
        )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [52]:
trainer.fit(lstm_model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | loss  | CrossEntropyLoss | 0      | train
1 | model | SimpleLSTM       | 4.8 M  | train
---------------------------------------------------
4.8 M     Trainable params
0         Non-trainable params
4.8 M     Total params
19.136    Total estimated model params size (MB)


Sanity Checking: |                                                            | 0/? [00:00<?, ?it/s]

Training: |                                                                   | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                                                 | 0/? [00:00<?, ?it/s]

/home/szymon/pythonvenvs/rocmwork/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [53]:
wandb.finish()

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▃▃▃▃▅▅▅▆▆▆▆████
scheduler_lr,▁▁▁▁▁▁
train_loss,▇█▆▅█▁▆▄▁▄▄▅▅
trainer/global_step,▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇███
val_acc,▁▁▁▁▁
val_loss,▁▂█▂▂

0,1
epoch,4.0
scheduler_lr,0.01
train_loss,2.02803
trainer/global_step,650.0
val_acc,21.24463
val_loss,2.04265


# Simple Transformer