In [10]:
import logging

import pytest
import torch
from torch.utils.data import Dataset
from dry_torch import Trainer
from dry_torch import StandardLoader
from dry_torch import Experiment
from dry_torch import ModelOptimizer
from dry_torch import CheckpointIO
from dry_torch import LossAndMetricsCalculator
from dry_torch import exceptions
from dry_torch import default_logging


In [11]:
class IndexDataset(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    """Toy map-style dataset whose i-th sample is the pair ``([i], [i])``.

    Input and target are identical one-element float tensors (default torch
    float dtype) holding the sample index.

    Args:
        length: number of samples reported by ``len()``. Defaults to 1600,
            the size the original fixture used.

    Raises:
        ValueError: if ``length`` is negative.
    """

    def __init__(self, length: int = 1600) -> None:
        super().__init__()
        if length < 0:
            raise ValueError(f'length must be non-negative, got {length}')
        self.length = length

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(input, target)`` for ``index``; both hold ``float(index)``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. Torch's
                ``Dataset`` defines no ``__iter__``, so ``for x in ds`` /
                ``list(ds)`` use Python's legacy sequence protocol, which
                stops only on ``IndexError``; without this check iteration
                would never terminate.
        """
        if not 0 <= index < self.length:
            raise IndexError(
                f'index {index} out of range for dataset of length {self.length}'
            )
        value = float(index)
        # Two independent tensors (as before), so in-place edits to one
        # cannot leak into the other.
        return torch.tensor([value]), torch.tensor([value])

    def __len__(self) -> int:
        return self.length

In [12]:
class Linear(torch.nn.Module):
    """Minimal ``torch.nn.Module`` wrapping a single ``torch.nn.Linear`` layer.

    The wrapped layer is registered under the attribute ``linear``, so its
    parameters appear as ``linear.weight`` and ``linear.bias``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        layer = torch.nn.Linear(in_features, out_features)
        self.linear = layer

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped affine layer (``inputs @ W.T + b``) to ``inputs``."""
        projected = self.linear(inputs)
        return projected

In [13]:
def simple_fun(tensor: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
    """Return the element-wise (broadcasting) sum ``tensor + second``.

    NOTE(review): the result keeps the broadcast input shape rather than
    being reduced to a scalar; confirm that is what the loss calculator it
    is passed to expects.
    """
    total = tensor + second
    return total

In [14]:
# Seed torch's global RNG first: the model's parameter initialisation (and
# possibly loader shuffling) draws from it, so without a seed re-runs differ.
SEED = 0
torch.manual_seed(SEED)

Experiment('test_simple_training', config={'answer': 42}).run()
exp_pardir = 'test_experiments'
model = Linear(1, 1)
model_opt = ModelOptimizer(model)
cloned_model_opt = model_opt.clone('cloned_model')
# NOTE(review): the checkpoint wraps the original `model_opt`, but training
# below runs on `cloned_model_opt`; `checkpoint.save()` therefore presumably
# persists the untrained original — confirm this is intended.
checkpoint = CheckpointIO(model_opt, exp_pardir=exp_pardir)
loss_calc = LossAndMetricsCalculator(simple_fun)
dataset = IndexDataset()
loader = StandardLoader(dataset=dataset, batch_size=4)
trainer = Trainer(cloned_model_opt, loss_calc=loss_calc, loader=loader)
trainer.train(2)  # presumably 2 epochs — confirm against Trainer.train
checkpoint.save()



2024-05-27 21:00:43,674
Training cloned_model.
2024-05-27 21:00:44,088
End of training


In [6]:
# NOTE(review): this cell executed as In [6], i.e. *before* the training cell
# (In [14]) even though it sits below it. On Restart & Run All it would run
# after training, so the level would not apply to that run — move it up next
# to the imports.
logger = logging.getLogger('dry_torch')
logger.setLevel(logging.INFO + 5)  # 25: between INFO (20) and WARNING (30)