In [1]:
import logging
import os
from pathlib import Path
from typing import Any
import torch
import pytorch_lightning as pl
from hydra.utils import instantiate
from omegaconf import DictConfig, ListConfig, OmegaConf
import matplotlib.pyplot as plt
from hydra import compose, initialize

from emg2qwerty import transforms, utils


In [6]:

log = logging.getLogger(__name__)

def _full_session_paths(dataset: ListConfig) -> list[Path]:
    sessions = [session["session"] for session in dataset]
    return [
        Path(config.dataset.root).joinpath(f"{session}.hdf5")
        for session in sessions
    ]

def _build_transform(configs: list[DictConfig]) -> Any:
    return transforms.Compose([instantiate(cfg) for cfg in configs])

def plot_lr_vs_loss(config: DictConfig):
    log.info(f"\nConfig:\n{OmegaConf.to_yaml(config)}")

    pl.seed_everything(config.seed, workers=True)

    log.info(f"Instantiating LightningModule {config.module}")
    module = instantiate(
        config.module,
        optimizer=config.optimizer,
        lr_scheduler=config.lr_scheduler,
        decoder=config.decoder,
        _recursive_=False,
    )

    log.info(f"Instantiating LightningDataModule {config.datamodule}")
    datamodule = instantiate(
        config.datamodule,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        train_sessions=_full_session_paths(config.dataset.train),
        val_sessions=_full_session_paths(config.dataset.val),
        test_sessions=_full_session_paths(config.dataset.test),
        train_transform=_build_transform(config.transforms.train),
        val_transform=_build_transform(config.transforms.val),
        test_transform=_build_transform(config.transforms.test),
        _convert_="object",
    )

    trainer = pl.Trainer(**config.trainer)

    lr_finder = trainer.tuner.lr_find(module, datamodule)

    fig = lr_finder.plot(suggest=True)
    plt.show()


In [None]:
@hydra.main(version_base=None, config_path="../config", config_name="base")

with initialize(version_base=None,config_path="config"):
    config = compose(config_name="base")
    print(OmegaConf.to_yaml(config))
plot_lr_vs_loss(config)


Global seed set to 1501


user: single_user
dataset:
  train:
  - user: 89335547
    session: 2021-06-03-1622765527-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-06-02-1622681518-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-06-04-1622863166-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-07-22-1627003020-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-07-21-1626916256-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-07-22-1627004019-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-06-05-1622885888-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-06-02-1622679967-keystrokes-dca-study@1-0efbe614-9ae6-4131-9192-4398359b4f5f
  - user: 89335547
    session: 2021-06-03-162276439

InterpolationResolutionError: ValueError raised while resolving interpolation: HydraConfig was not set
    full_key: dataset.root
    object_type=dict

In [7]:
import logging
from pathlib import Path
from typing import Sequence
import hydra
from hydra.utils import instantiate, get_original_cwd
from omegaconf import DictConfig, ListConfig, OmegaConf
from hydra import compose, initialize
import os

# Helper to instantiate full paths for dataset sessions
def _full_session_paths(config: DictConfig, dataset: ListConfig) -> list[Path]:
    sessions = [session["session"] for session in dataset]
    return [
        Path(config.dataset.root).joinpath(f"{session}.hdf5")
        for session in sessions
    ]

# Helper to instantiate transforms
def _build_transform(configs: Sequence[DictConfig]):
    return transforms.Compose([instantiate(cfg) for cfg in configs])

# Function to get the datamodule
def get_datamodule(config: DictConfig):
    datamodule = instantiate(
        config.datamodule,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        train_sessions=_full_session_paths(config, config.dataset.train),
        val_sessions=_full_session_paths(config, config.dataset.val),
        test_sessions=_full_session_paths(config, config.dataset.test),
        train_transform=_build_transform(config.transforms.train),
        val_transform=_build_transform(config.transforms.val),
        test_transform=_build_transform(config.transforms.test),
        _convert_="object",
    )
    return datamodule

@hydra.main(version_base=None, config_path="../config", config_name="base")
def main(config: DictConfig):
    log.info(f"\nConfig:\n{OmegaConf.to_yaml(config)}")

    # Add working dir to PYTHONPATH
    working_dir = get_original_cwd()
    python_paths = os.environ.get("PYTHONPATH", "").split(os.pathsep)
    if working_dir not in python_paths:
        python_paths.append(working_dir)
        os.environ["PYTHONPATH"] = os.pathsep.join(python_paths)

    # Get the datamodule
    datamodule = get_datamodule(config)

    # Calculate steps per epoch
    train_dataloader = datamodule.train_dataloader()
    steps_per_epoch = len(train_dataloader.dataset) // config.batch_size
    logging.info(f"Steps per epoch: {steps_per_epoch}")

if __name__ == "__main__":
    # initialize(config_path="../config", job_name="data_discovery")
    # config = compose(config_name="base")
    main()

usage: ipykernel_launcher.py [--help] [--hydra-help] [--version]
                             [--cfg {job,hydra,all}] [--resolve]
                             [--package PACKAGE] [--run] [--multirun]
                             [--shell-completion] [--config-path CONFIG_PATH]
                             [--config-name CONFIG_NAME]
                             [--config-dir CONFIG_DIR]
                             [--experimental-rerun EXPERIMENTAL_RERUN]
                             [--info [{all,config,defaults,defaults-tree,plugins,searchpath}]]
                             [overrides ...]
ipykernel_launcher.py: error: unrecognized arguments: --f=/run/user/0/jupyter/runtime/kernel-v354ea66aff144cda0d1d157922705ccee7ac54d23.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
