In [1]:
import hydra
import pandas as pd
import pytorch_lightning as pl
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning.callbacks import EarlyStopping, GPUStatsMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from src import DATA_DIR, LOGGING_DIR, MODEL_CHECKPOINTS_DIR, TRACK1_DIR
from src.configs import register_configs
from src.configs.train import TrainConfig
from src.data import LenaDataModule
from src.models import LenaTrans
from src.system import LenaSystem
from src.utils.torch import get_embeddings_projections


In [2]:
def get_datamodule(batch_size, num_workers):
    train = pd.read_csv(TRACK1_DIR / "train.csv")
    test = pd.read_csv(TRACK1_DIR / "test.csv")
    features_df = pd.read_csv(DATA_DIR / "features.csv")
    datamodule = LenaDataModule(
        train=train, test=test, features_df=features_df, batch_size=batch_size, num_workers=num_workers
    )

    return datamodule


def train(cfg: TrainConfig, trial=None):
    logger = TensorBoardLogger(
        str(LOGGING_DIR),
        name=cfg.name,
        version=cfg.version,
        log_graph=False,
        default_hp_metric=True,
    )

    checkpoints = ModelCheckpoint(
        dirpath=str(MODEL_CHECKPOINTS_DIR / cfg.name),
        monitor="hp_metric",
        verbose=True,
        mode="max",
        save_top_k=-1,
    )

    early_stopping = EarlyStopping(monitor="Val/f1_score", patience=cfg.patience)
    if trial:
        early_stopping = PyTorchLightningPruningCallback(monitor="Val/f1_score", trial=trial)  # type: ignore

    gpu_monitor = GPUStatsMonitor()

    datamodule = get_datamodule(batch_size=cfg.batch_size, num_workers=cfg.num_workers)

    # trainer
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[gpu_monitor, checkpoints, early_stopping],
        profiler="simple",
        benchmark=True,
        gpus=cfg.gpus,
        max_epochs=cfg.max_epochs
        # enable_pl_optimizer=True,
    )

    embeddings_projections = get_embeddings_projections(
        categorical_features=datamodule.categorical_features, features_df=datamodule.features_df
    )

    model = LenaTrans(
        cat_features=datamodule.categorical_features,
        embeddings_projections=embeddings_projections,
        numerical_features=datamodule.numerical_features,
        station_col_name="hydro_fixed_station_id_categorical",
        day_col_name="day_target_categorical",
        rnn_units=cfg.rnn_units,
        top_classifier_units=cfg.top_classifier_units,
    )

    system = LenaSystem(model=model, alpha=cfg.alpha, gamma=cfg.gamma, lr=cfg.lr, weight_decay=cfg.weight_decay)

    trainer.fit(system, datamodule=datamodule)

    return trainer, system, datamodule

In [3]:
from src.configs.train import TrainConfig

In [4]:
cfg = TrainConfig()

In [None]:
datamodule = train(cfg)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                      | Params
--------------------------------------------------------
0 | model     | LenaTrans                 | 443 K 
1 | criterion | BinaryFocalLossWithLogits | 0     
--------------------------------------------------------
443 K     Trainable params
0         Non-trainable params
443 K     Total params
1.772     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  recall = tps / tps[-1]
  F = 2 / (1 / precision + 1 / recall)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
%debug

> [0;32m/home/dan/.cache/pypoetry/virtualenvs/emergency-hack-xcMZg9e2-py3.8/lib/python3.8/site-packages/torch/nn/modules/sparse.py[0m(137)[0;36m__init__[0;34m()[0m
[0;32m    135 [0;31m        [0mself[0m[0;34m.[0m[0mscale_grad_by_freq[0m [0;34m=[0m [0mscale_grad_by_freq[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    136 [0;31m        [0;32mif[0m [0m_weight[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 137 [0;31m            [0mself[0m[0;34m.[0m[0mweight[0m [0;34m=[0m [0mParameter[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m([0m[0mnum_embeddings[0m[0;34m,[0m [0membedding_dim[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    138 [0;31m            [0mself[0m[0;34m.[0m[0mreset_parameters[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    139 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> u
> [0;32m/home/dan/Projects/emergency

In [16]:
train_dl = datamodule.train_dataloader()

In [17]:
for batch in train_dl:
    pass

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/dan/.cache/pypoetry/virtualenvs/emergency-hack-xcMZg9e2-py3.8/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/dan/.cache/pypoetry/virtualenvs/emergency-hack-xcMZg9e2-py3.8/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/dan/.cache/pypoetry/virtualenvs/emergency-hack-xcMZg9e2-py3.8/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/dan/Projects/emergency_datahack/src/data.py", line 120, in __getitem__
    encoded_station_id = self.full_df[features_mask]["hydro_fixed_station_id_categorical"].values[0]
IndexError: index 0 is out of bounds for axis 0 with size 0


In [18]:
%debug

> [0;32m/home/dan/.cache/pypoetry/virtualenvs/emergency-hack-xcMZg9e2-py3.8/lib/python3.8/site-packages/torch/_utils.py[0m(429)[0;36mreraise[0;34m()[0m
[0;32m    427 [0;31m            [0;31m# have message field[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    428 [0;31m            [0;32mraise[0m [0mself[0m[0;34m.[0m[0mexc_type[0m[0;34m([0m[0mmessage[0m[0;34m=[0m[0mmsg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 429 [0;31m        [0;32mraise[0m [0mself[0m[0;34m.[0m[0mexc_type[0m[0;34m([0m[0mmsg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    430 [0;31m[0;34m[0m[0m
[0m[0;32m    431 [0;31m[0;34m[0m[0m
[0m
ipdb> q
