In [None]:
! uv pip install drytorch

In [1]:
import torch
import torchmetrics

from torcheval import metrics as eval_metrics

from drytorch.core import protocols as p


# Two constant 1x1 float tensors whose element-wise difference is 2.
tensor_a = torch.ones((1, 1), dtype=torch.float)
tensor_b = torch.full_like(tensor_a, 3.0)
# The same metric (mean squared error) from two libraries, compared below.
eval_metric = eval_metrics.MeanSquaredError()
torch_metric = torchmetrics.MeanSquaredError()


def is_valid_objective(
    metric: p.ObjectiveProtocol[torch.Tensor, torch.Tensor],
) -> bool:
    """Check at runtime whether ``metric`` satisfies ``ObjectiveProtocol``.

    The parameter annotation also states the expected protocol for static
    type checkers; the return value is the ``isinstance`` result.
    """
    follows_protocol = isinstance(metric, p.ObjectiveProtocol)
    return follows_protocol


# Feed the same batch to both implementations of the mean squared error.
torch_metric.update(tensor_a, tensor_b)
eval_metric.update(tensor_a, tensor_b)

# torchmetrics and torcheval must agree on the accumulated value.
values_agree = torch.isclose(torch_metric.compute(), eval_metric.compute())
if not values_agree:
    raise AssertionError('Metrics values should match.')

# Both metric objects must also satisfy drytorch's ObjectiveProtocol.
if not is_valid_objective(eval_metric) or not is_valid_objective(torch_metric):
    raise AssertionError('These objects should follow the ObjectiveProtocol.')

In [2]:
def is_valid_loss(
    metric: p.LossProtocol[torch.Tensor, torch.Tensor],
) -> bool:
    """Check at runtime whether ``metric`` satisfies ``LossProtocol``.

    The parameter annotation also states the expected protocol for static
    type checkers; the return value is the ``isinstance`` result.
    """
    follows_protocol = isinstance(metric, p.LossProtocol)
    return follows_protocol


# The torchmetrics metric is expected to satisfy the LossProtocol as well.
torch_metric_is_loss = is_valid_loss(torch_metric)
if not torch_metric_is_loss:
    raise AssertionError('This object should also follow the LossProtocol.')

In [3]:
from drytorch.contrib.torchmetrics import from_torchmetrics


# Compose a torchmetrics metric arithmetically, then wrap it for drytorch.
new_metric = 1 + torch_metric
imported_metric = from_torchmetrics(new_metric)
imported_metric.update(tensor_a, tensor_b)

# Expected: the composed value under 'Combined Loss', plus the underlying
# metric under its class name.
expected_metrics_from_torchmetrics = {
    'Combined Loss': torch.tensor(5.0),
    'MeanSquaredError': torch.tensor(4.0),
}
if imported_metric.compute() != expected_metrics_from_torchmetrics:
    raise AssertionError('Metrics values should be as expected.')

In [4]:
from torch.nn.functional import mse_loss as mse_loss_fn  # returns scalar value

from drytorch.lib.objectives import Metric


def mae_loss_fn(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return the Mean Absolute Error (MAE) of each sample in a batch.

    The first dimension is treated as the batch dimension; all remaining
    dimensions are flattened and averaged per sample.

    Args:
        outputs: predictions with at least two dimensions.
        targets: reference values, broadcastable against ``outputs``.

    Returns:
        A 1-D tensor holding one MAE value per sample.
    """
    # Unlike ``mse_loss_fn`` (which reduces to a scalar), keep the batch dim.
    return torch.abs(outputs - targets).flatten(1).mean(1)


mse_metric = Metric(mse_loss_fn, name='MSE', higher_is_better=False)
mae_metric = Metric(mae_loss_fn, name='MAE', higher_is_better=False)

# `|` combines both metrics into one collection that is updated in one call.
metric_collection = mse_metric | mae_metric
metric_collection.update(tensor_a, tensor_b)

expected_metric_collection = {
    'MSE': torch.tensor(4.0),
    'MAE': torch.tensor(2.0),
}
# Compute once and compare against the expected per-metric values.
if metric_collection.compute() != expected_metric_collection:
    raise AssertionError('Metrics values should be as expected.')

In [5]:
from typing_extensions import override

from drytorch.lib.objectives import Objective


class MyMetrics(Objective[torch.Tensor, torch.Tensor]):
    """Objective that computes MSE and MAE from one shared difference.

    ``outputs - targets`` is evaluated once and reused for both values,
    instead of running two separate metric functions.
    """

    @override
    def calculate(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Return per-sample MSE and MAE keyed by 'MSE' and 'MAE'.

        Every dimension after the first is flattened and averaged, so each
        result holds one value per entry of the leading dimension.
        """
        per_sample_diff = (outputs - targets).flatten(1)
        mse = per_sample_diff.pow(2).mean(1)
        mae = per_sample_diff.abs().mean(1)
        return {'MSE': mse, 'MAE': mae}


my_metrics = MyMetrics()
my_metrics.update(tensor_a, tensor_b)

# A single Objective returning both values must match the collection above.
if my_metrics.compute() != expected_metric_collection:
    raise AssertionError('Metrics values should be as before.')

In [6]:
from torch.nn.functional import mse_loss as mse_loss_fn  # returns scalar value

from drytorch.lib.objectives import Loss


mse_loss = Loss(mse_loss_fn, name='MSE')
mae_loss = Loss(mae_loss_fn, name='MAE')

# Losses support arithmetic; the result still reports each component.
composed_loss = mse_loss**2 + 0.5 * mae_loss
composed_loss.update(tensor_a, tensor_b)

# 4.0**2 + 0.5 * 2.0 == 17.0
expected_metrics_from_loss = {
    'Combined Loss': torch.tensor(17.0),
    'MSE': torch.tensor(4.0),
    'MAE': torch.tensor(2.0),
}
if composed_loss.compute() != expected_metrics_from_loss:
    raise AssertionError('Metrics values should be as expected.')

In [7]:
# The composed loss also exposes a human-readable formula string.
expected_formula = '[MSE]^2 + 0.5 x [MAE]'
if composed_loss.formula != expected_formula:
    raise AssertionError('Formula mismatch.')

In [8]:
from drytorch.contrib.torcheval import from_torcheval


# Wrap the torcheval metric with drytorch; presumably this lets its state be
# synced across distributed processes — confirm against from_torcheval docs.
eval_metric_with_sync = from_torcheval(eval_metric)

# As the cell output shows, with a world size of 1 nothing is synced and the
# input metric(s) are returned.
eval_metric_with_sync.sync()

World size is 1, and metric(s) not synced. returning the input metric(s).
