In [3]:
import os
from pathlib import Path

import lightning as L
import torch
from lightning.pytorch.loggers import TensorBoardLogger

from seq_loader.lightning import LitLSTM
from seq_loader.names_loader import TDTSDataset, collate_padded_batch_fn

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
# Root directory holding the TBTS parquet files. Set the TBTS_DATA_DIR environment
# variable to point elsewhere instead of editing the hard-coded default.
data_path = Path(os.environ.get("TBTS_DATA_DIR", "/home/jovyan/data/tbts_data/"))
# Shared file stem: "<stem>.parquet" holds the sequences, "<stem>-labels.parquet" the labels.
filename = "tbts-names"
# Categorical column fed to the embedding layer — presumably the per-character tokens
# of each name (TODO confirm against the parquet schema).
_NAME_CHARS_COL = "name_chars"

tdts_dataset = TDTSDataset(
    tbts_path=data_path / f"{filename}.parquet",
    # NOTE(review): rows sharing a value in column "G" presumably form one sequence;
    # confirm in seq_loader.names_loader.TDTSDataset.
    tbts_groupby_column="G",
    cat_columns=[_NAME_CHARS_COL],
    labels_path=data_path / f"{filename}-labels.parquet",
)

In [6]:
# Every categorical column gets an embedding of the same width.
embedding_dim = 4
col_embedding_dims = {col_name: embedding_dim for col_name in tdts_dataset.cat_columns}

In [8]:
# Turn the per-column widths into the embedding parameters LitLSTM expects.
# NOTE(review): the printed model shows Embedding(88, 4, padding_idx=0) for "name_chars",
# so the vocabulary size (88) appears to come from the dataset — confirm in TDTSDataset.
embedding_dims = tdts_dataset.cat_col_embeddings_params(
    col_embedding_dims=col_embedding_dims,
)

In [9]:
# NOTE: `n_calsses` (sic) is the attribute name as spelled in seq_loader's TDTSDataset —
# this cell runs with it. Correct the spelling in the library (ideally keeping an alias)
# rather than only here.
model = LitLSTM(
    embedding_dims=embedding_dims,
    h_size=8,
    n_classes=tdts_dataset.n_calsses,
    lr=1e-3,
)

In [10]:
# Display the module tree. From the printed repr: name_chars Embedding(88, 4) -> LSTM(4, 8)
# -> Linear(16, 18), i.e. 18 output classes, with CrossEntropyLoss and MulticlassAccuracy.
# NOTE(review): h2o takes 16 = 2 * h_size inputs — presumably TBDTSLstm concatenates two
# hidden-size vectors before the output layer; confirm in seq_loader.
model

LitLSTM(
  (lstm): TBDTSLstm(
    (encoder): CatColumnsDataEncoder(
      (embeddings): ModuleDict(
        (name_chars): Embedding(88, 4, padding_idx=0)
      )
    )
    (lstm): LSTM(4, 8, batch_first=True)
    (h2o): Linear(in_features=16, out_features=18, bias=True)
  )
  (loss): CrossEntropyLoss()
  (accuracy): MulticlassAccuracy()
)

In [11]:
# Mini-batch iterator: 10 sequences per batch, reshuffled every epoch.
# collate_padded_batch_fn (seq_loader) assembles each batch — presumably padding the
# variable-length sequences to a common length (the embedding uses padding_idx=0); confirm.
loader = torch.utils.data.DataLoader(
    tdts_dataset,
    batch_size=10,
    shuffle=True,
    collate_fn=collate_padded_batch_fn,
)

In [12]:
# Build the TensorBoard logger here and hand it to the Trainer. A logger that is created
# but never passed via `logger=` records nothing, and Lightning falls back to its default.
# "tb_logs" is a relative path, so it resolves against the kernel's working directory,
# not against `default_root_dir`.
logger = TensorBoardLogger("tb_logs", name="last_names")
trainer = L.Trainer(
    default_root_dir=data_path / "lit_lstm_names_log",
    accelerator="auto",
    devices=1,
    max_epochs=200,
    logger=logger,
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [13]:
# TensorBoard logger writing runs under tb_logs/last_names/ (relative to the kernel's
# working directory). It records nothing until passed to `L.Trainer(logger=...)`.
logger = TensorBoardLogger(save_dir="tb_logs", name="last_names")

In [19]:
# Working directory of the kernel: relative paths used above (e.g. "tb_logs") resolve
# against it. Pure-Python replacement for the `! pwd` shell escape.
notebook_cwd = Path.cwd()
notebook_cwd

/home/jovyan/work/src/rnn/ssl-e-seq/demo-dataloader/nbs
