In [None]:
! uv pip install drytorch

In [1]:
import dataclasses

from drytorch import Experiment as GenericExperiment


@dataclasses.dataclass(frozen=True)
class SimpleConfig:
    """Immutable experiment configuration holding a single setting.

    Being a frozen dataclass, instances cannot be modified after creation
    and compare equal when their fields are equal.
    """

    # Number of samples per batch.
    batch_size: int


class MyExperiment(GenericExperiment[SimpleConfig]):
    """Experiment whose configuration is a ``SimpleConfig`` instance.

    The subclass adds no behavior of its own; it only fixes the generic
    configuration parameter of drytorch's ``Experiment`` to ``SimpleConfig``.
    """


# Configuration reused by every experiment created in this notebook.
my_config = SimpleConfig(batch_size=32)

# ``par_dir`` is presumably the parent folder for the experiment's files
# (TODO confirm against the drytorch documentation).
first_experiment = MyExperiment(
    my_config,
    name='FirstExp',
    par_dir='experiments/',
    tags=[],
)

In [2]:
def implement_experiment() -> None:
    """Stand-in for the body of an experiment run.

    It intentionally does nothing and returns ``None``; replace it with the
    real experiment code.
    """
    return None


# First run: open a run of ``first_experiment`` and remember its id.
with first_experiment.create_run() as run:
    first_id = run.id
    implement_experiment()


# Second run: ``resume=True`` is expected to reopen the run above rather
# than start a new one, so the id must not change.
with first_experiment.create_run(resume=True) as run:
    second_id = run.id
    implement_experiment()

ids_match = not (first_id != second_id)
if not ids_match:
    raise AssertionError('The resumed run should keep the id.')

In [3]:
# A run can also be driven by hand instead of through a ``with`` block.
# ``finally`` makes sure ``stop`` is reached even if the code placed
# between ``start`` and ``stop`` raises.
run = first_experiment.create_run()
run.start()
try:
    pass  # experiment code would go here
finally:
    run.stop()

In [4]:
from drytorch.core import exceptions


def get_batch() -> int:
    """Return ``batch_size`` from the configuration of the active experiment.

    Outside an active run, ``MyExperiment.get_config`` is expected to raise
    instead of returning a configuration.
    """
    active_config = MyExperiment.get_config()
    return active_config.batch_size


# Inside a run the configuration is reachable and must be the one the
# experiment was created with (previously the value was discarded, so only
# the absence of an exception was checked).
with first_experiment.create_run():
    if get_batch() != my_config.batch_size:
        raise AssertionError(
            'Configuration does not match the running experiment.'
        )

# Outside any run, accessing the configuration must fail.
try:
    get_batch()
except (exceptions.AccessOutsideScopeError, exceptions.NoActiveExperimentError):
    pass
else:
    raise AssertionError('Configuration accessed when no run is on.')

In [5]:
from torch import nn

from drytorch import Model
from drytorch.core import exceptions


# A second experiment sharing the same configuration as the first one.
second_experiment = MyExperiment(
    my_config,
    name='SecondExp',
    par_dir='experiments/',
    tags=[],
)
# A bare module; the following checks try to wrap it in ``Model`` in
# different contexts.
module = nn.Linear(1, 1)

# Wrapping the module inside a run of the first experiment succeeds.
with first_experiment.create_run():
    first_model = Model(module)

# Outside any run there is no active experiment, so instantiation must fail.
# The result is deliberately not bound to a name: if the check ever fails,
# no stray model should leak into the notebook namespace.
try:
    Model(module)
except exceptions.NoActiveExperimentError:
    pass
else:
    raise AssertionError('Model instantiated when no experiment is running.')


# Inside a run of another experiment, wrapping the already-registered module
# again must be rejected. As above, the failure-expected result is not bound
# to ``second_model``, which is defined only once it is legitimately created.
with second_experiment.create_run():
    try:
        Model(module)
    except exceptions.ModuleAlreadyRegisteredError:
        pass
    else:
        raise AssertionError('Module registered through two Model instances.')

In [6]:
from drytorch.core import register


# Release the module held by ``first_model`` so it can be wrapped again,
# this time inside a run of the second experiment.
with second_experiment.create_run():
    register.unregister_model(first_model)
    # Re-wrap the same underlying module now that it is free.
    second_model = Model(first_model.module)

In [7]:
import torch

from torch.utils.data import Dataset
from typing_extensions import override

from drytorch.lib.load import DataLoader
from drytorch.lib.runners import ModelRunner


class MyDataset(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    """Single-sample dataset whose two tensors are both ``torch.ones(1)``.

    It also carries an empty list and a ``None`` attribute that play no role
    in indexing (presumably to exercise metadata extraction on such values).
    """

    def __init__(self) -> None:
        """Set up the placeholder attributes."""
        super().__init__()
        # Placeholders only; ``__getitem__`` does not read them.
        self.empty_container: list[object] = []
        self.none = None

    def __len__(self) -> int:
        """Return the number of samples, which is always one."""
        return 1

    @override
    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
        """Return a pair of fresh one-element tensors; ``index`` is ignored."""
        # Presumably consumed as (inputs, targets) downstream.
        inputs = torch.ones(1)
        targets = torch.ones(1)
        return inputs, targets


one_dataset: Dataset[tuple[torch.Tensor, torch.Tensor]] = MyDataset()

# Resume the previous run of ``second_experiment`` (where ``second_model``
# was registered) so the runner is allowed to use the model.
with second_experiment.create_run(resume=True):
    loader = DataLoader(one_dataset, batch_size=1)
    model_caller = ModelRunner(second_model, loader=loader)
    model_caller()

    ModelRunner:   0%|          | 0/1, 00:00<?

In [8]:
# A new run (no ``resume``) is expected not to know ``second_model``, which
# was registered in an earlier run, so calling the runner must raise.
with second_experiment.create_run():
    loader = DataLoader(one_dataset, batch_size=1)
    model_caller = ModelRunner(second_model, loader=loader)
    try:
        model_caller()
    except exceptions.ModuleNotRegisteredError:
        pass  # expected outcome
    else:
        raise AssertionError('Model not registered in the current run')

In [9]:
import functools
import pprint

from drytorch.core import log_events
from drytorch.core.track import Tracker


class MetadataVisualizer(Tracker):
    """Tracker that pretty-prints registration metadata to standard output.

    Model registrations print the architecture representation; actor
    registrations (such as runners) print their metadata. Every event is
    then forwarded to the base ``Tracker.notify``.
    """

    @functools.singledispatchmethod
    @override
    def notify(self, event: log_events.Event) -> None:
        """Fallback handler: delegate unhandled event types to ``Tracker``."""
        return super().notify(event)

    @notify.register(log_events.ModelRegistrationEvent)
    def _(self, event: log_events.ModelRegistrationEvent) -> None:
        """Print the architecture of the newly registered model."""
        pprint.pp(event.architecture_repr)
        return super().notify(event)

    @notify.register(log_events.ActorRegistrationEvent)
    def _(self, event: log_events.ActorRegistrationEvent) -> None:
        """Print the metadata of the newly registered actor."""
        pprint.pp(event.metadata)
        return super().notify(event)


# A third experiment whose events are also shown on the console.
third_experiment = MyExperiment(
    my_config,
    name='ThirdExp',
    par_dir='experiments/',
    tags=[],
)

metadata_visualizer = MetadataVisualizer()
third_experiment.trackers.subscribe(metadata_visualizer)

In [10]:
# A fresh run (no ``resume``): registering the model notifies the subscribed
# ``MetadataVisualizer``, which prints the module's architecture.
with third_experiment.create_run():  # new run
    third_model = Model(nn.Linear(1, 1))

'Linear(in_features=1, out_features=1, bias=True)'


In [11]:
# Resume the run in which ``third_model`` was registered. The runner's
# registration is reported to the subscribed visualizer, which prints its
# metadata.
with third_experiment.create_run(resume=True):
    loader = DataLoader(one_dataset, batch_size=1)
    model_caller = ModelRunner(third_model, loader=loader)
    model_caller()

{'class': 'ModelRunner',
 'loader': {'class': 'DataLoader',
            'batch_size': 1,
            'dataset': 'MyDataset',
            'dataset_len': 1,
            'sampler': {'class': 'RandomSampler',
                        'data_source': 'range(0, 1)',
                        'replacement': False}},
 'model': {'class': 'Model',
           'checkpoint': 'LocalCheckpoint',
           'epoch': 0,
           'mixed_precision': False,
           'module': {'class': 'Linear',
                      'in_features': 1,
                      'out_features': 1,
                      'training': True}}}


  ModelRunner_2:   0%|          | 0/1, 00:00<?