Merge pull request #681 from mv1388/torch_metrics_support
Torch metrics support
mv1388 committed Jul 8, 2022
2 parents e570771 + b5a1409 commit 245d94f
Showing 7 changed files with 180 additions and 4 deletions.
33 changes: 33 additions & 0 deletions aitoolbox/experiment/result_package/torch_metrics_packages.py
@@ -0,0 +1,33 @@
from aitoolbox.experiment.result_package.abstract_result_packages import AbstractResultPackage


class TorchMetricsPackage(AbstractResultPackage):
    def __init__(self, torch_metrics):
        """Torch Metrics result package wrapper

        https://github.com/Lightning-AI/metrics

        Args:
            torch_metrics (torchmetrics.Metric or torchmetrics.MetricCollection): single torchmetrics metric object or
                a collection of such metrics wrapped inside the MetricCollection
        """
        AbstractResultPackage.__init__(self, pkg_name='Torch Metrics', np_array=False)

        self.metric = torch_metrics

    def prepare_results_dict(self):
        metric_result = self.metric(self.y_predicted, self.y_true)

        if not isinstance(metric_result, dict):
            metric_result = {self.metric.__class__.__name__: metric_result}

        # Add suffix PTLMetrics to indicate that we are using PyTorch Lightning metrics instead of aitoolbox metric
        metric_result = {f'{k}_PTLMetrics': v for k, v in metric_result.items()}

        return metric_result

    def metric_compute(self):
        return self.metric.compute()

    def metric_reset(self):
        self.metric.reset()
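For orientation, a minimal usage sketch of the new package (not part of this commit): the metric choices, the torchmetrics constructor arguments, and the empty hyperparameter dict below are illustrative assumptions; it presumes a torchmetrics version whose Accuracy/Precision/Recall accept the task/num_classes arguments shown.

import torchmetrics

from aitoolbox.experiment.result_package.torch_metrics_packages import TorchMetricsPackage
from aitoolbox.torchtrain.callbacks.performance_eval import ModelPerformanceEvaluation

# Wrap a single torchmetrics metric object ...
accuracy_package = TorchMetricsPackage(torchmetrics.Accuracy(task='multiclass', num_classes=10))

# ... or several metrics wrapped inside a MetricCollection
collection_package = TorchMetricsPackage(torchmetrics.MetricCollection([
    torchmetrics.Precision(task='multiclass', num_classes=10),
    torchmetrics.Recall(task='multiclass', num_classes=10)
]))

# The package plugs into the existing performance evaluation callback,
# mirroring the example scripts changed further down in this commit
callbacks = [ModelPerformanceEvaluation(accuracy_package, {},
                                        on_train_data=True, on_val_data=True)]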
15 changes: 15 additions & 0 deletions aitoolbox/torchtrain/callbacks/performance_eval.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
from aitoolbox.cloud import s3_available_options, gcs_available_options
from aitoolbox.experiment.local_save.local_results_save import BaseLocalResultsSaver
from aitoolbox.experiment.result_reporting.report_generator import TrainingHistoryPlotter, TrainingHistoryWriter
from aitoolbox.experiment.result_package.torch_metrics_packages import TorchMetricsPackage


class ModelPerformanceEvaluation(AbstractCallback):
@@ -67,6 +68,13 @@ def on_epoch_end(self):
            print(f'Skipping performance evaluation on this epoch ({self.train_loop_obj.epoch}). '
                  f'Evaluating every {self.eval_frequency} epochs.')

        if isinstance(self.result_package, TorchMetricsPackage):
            if self.on_train_data:
                self.train_result_package.metric_reset()

            if self.on_val_data:
                self.result_package.metric_reset()

    def evaluate_model_performance(self, prefix=''):
        """Calculate performance based on the provided result packages
@@ -129,6 +137,13 @@ def on_train_loop_registration(self):
                self.train_loop_obj.experiment_timestamp,
                self.train_loop_obj.local_model_result_folder_path)

        if isinstance(self.result_package, TorchMetricsPackage):
            if self.on_train_data:
                self.train_result_package.metric.to(self.train_loop_obj.device)

            if self.on_val_data:
                self.result_package.metric.to(self.train_loop_obj.device)


class ModelPerformancePrintReport(AbstractCallback):
    def __init__(self, metrics, on_each_epoch=True, report_frequency=None,
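Why the two new hooks above are needed: a torchmetrics.Metric is stateful — each forward call accumulates state on its device, compute() aggregates that state, and reset() clears it — so the callback resets the wrapped metric at the end of every epoch and moves it onto the training device at registration. A rough lifecycle sketch under those assumptions (illustrative tensors, not code from this commit; assumes a torchmetrics version accepting the task/num_classes arguments):

import torch
import torchmetrics

metric = torchmetrics.Accuracy(task='multiclass', num_classes=3)
metric.to('cpu')                       # analogous to metric.to(self.train_loop_obj.device) above

for _ in range(2):                     # two "batches" within one epoch
    preds = torch.tensor([0, 1, 2, 1])
    target = torch.tensor([0, 1, 1, 1])
    metric(preds, target)              # forward call updates the accumulated state

epoch_accuracy = metric.compute()      # aggregate over the whole epoch, like metric_compute()
metric.reset()                         # clear the state before the next epoch, like metric_reset()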
@@ -100,7 +100,7 @@ def get_predictions(self, model, batch_data, device):

model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
criterion = F.nll_loss
criterion = nn.NLLLoss()

callbacks = [ModelPerformanceEvaluation(ClassificationResultPackage(), args.__dict__,
                                        on_train_data=True, on_val_data=True),
@@ -97,7 +97,7 @@ def get_predictions(self, batch_data, device):

model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
criterion = F.nll_loss
criterion = nn.NLLLoss()

callbacks = [ModelPerformanceEvaluation(ClassificationResultPackage(), args.__dict__,
                                        on_train_data=True, on_val_data=True),
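A side note on the criterion change repeated in both example scripts above: nn.NLLLoss() is the module form of the same negative log-likelihood loss as the functional F.nll_loss, so the numerical result is unchanged. A quick illustrative check (not part of this commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

log_probs = torch.log_softmax(torch.randn(4, 10), dim=1)
targets = torch.tensor([1, 0, 3, 9])

# Both forms compute the same mean negative log-likelihood
assert torch.allclose(nn.NLLLoss()(log_probs, targets), F.nll_loss(log_probs, targets))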
@@ -0,0 +1,41 @@
import unittest

from tests.utils import DummyTorchMetrics

from aitoolbox.experiment.result_package.torch_metrics_packages import TorchMetricsPackage


class TestTorchMetricsPackage(unittest.TestCase):
    def test_prepare_results_dict_formatting_float_result(self):
        metric = DummyTorchMetrics(return_float=True)
        result_package = TorchMetricsPackage(metric)

        result_package.prepare_result_package([], [])
        result_dict = result_package.get_results()

        self.assertEqual(result_dict, {'DummyTorchMetrics_PTLMetrics': 123.4})

    def test_prepare_results_dict_formatting_dict_result(self):
        metric = DummyTorchMetrics(return_float=False)
        result_package = TorchMetricsPackage(metric)

        result_package.prepare_result_package([], [])
        result_dict = result_package.get_results()

        self.assertEqual(result_dict, {'metric_1_PTLMetrics': 12.34, 'metric_2_PTLMetrics': 56.78})

    def test_metric_compute(self):
        metric = DummyTorchMetrics()
        result_package = TorchMetricsPackage(metric)

        for i in range(100):
            result_package.metric_compute()
            self.assertEqual(metric.compute_ctr, i + 1)

    def test_metric_reset(self):
        metric = DummyTorchMetrics()
        result_package = TorchMetricsPackage(metric)

        for i in range(100):
            result_package.metric_reset()
            self.assertEqual(metric.reset_ctr, i + 1)
@@ -8,6 +8,7 @@
    ModelTrainHistoryFileWriter, MetricHistoryRename
from aitoolbox.torchtrain.train_loop import TrainLoop, TrainLoopCheckpoint
from aitoolbox.experiment.training_history import TrainingHistory
from aitoolbox.experiment.result_package.torch_metrics_packages import TorchMetricsPackage


THIS_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -184,6 +185,68 @@ def test_store_evaluated_metrics_to_history_multi_epoch_simulation(self):
                         {'loss': [], 'accumulated_loss': [], 'val_loss': [], 'val_dummy': [111.0, 123.0, 135.0],
                          'val_extended_dummy': [1323123.44, 1323135.44, 1323147.44]})

    def test_torch_metrics_result_package_performance_eval_epoch_end_reset_val_data(self):
        metric = DummyTorchMetrics(return_float=True)
        result_package = TorchMetricsPackage(metric)

        callback = ModelPerformanceEvaluation(
            result_package, {},
            on_each_epoch=False, on_train_data=False, on_val_data=True
        )

        for i in range(100):
            callback.on_epoch_end()
            self.assertEqual(metric.reset_ctr, i + 1)
            self.assertEqual(callback.result_package.metric.reset_ctr, i + 1)

    def test_torch_metrics_result_package_performance_eval_epoch_end_reset_train_val_data(self):
        metric = DummyTorchMetrics(return_float=True)
        result_package = TorchMetricsPackage(metric)

        callback = ModelPerformanceEvaluation(
            result_package, {},
            on_each_epoch=False, on_train_data=True, on_val_data=True
        )
        self.assertEqual(callback.result_package, result_package)
        self.assertNotEqual(result_package, callback.train_result_package)

        for i in range(100):
            callback.on_epoch_end()
            self.assertEqual(metric.reset_ctr, i + 1)
            self.assertEqual(callback.result_package.metric.reset_ctr, i + 1)
            self.assertEqual(callback.train_result_package.metric.reset_ctr, i + 1)

    def test_torch_metrics_result_package_performance_eval_tl_registration_move_to_device_val_data(self):
        metric = DummyTorchMetrics(return_float=True)
        result_package = TorchMetricsPackage(metric)

        callback = ModelPerformanceEvaluation(
            result_package, {},
            on_each_epoch=False, on_train_data=False, on_val_data=True,
            if_available_output_to_project_dir=False
        )
        callback.train_loop_obj = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)

        callback.on_train_loop_registration()
        self.assertEqual(metric.to_result, 'cpu_1')
        self.assertEqual(callback.result_package.metric.to_result, 'cpu_1')

    def test_torch_metrics_result_package_performance_eval_tl_registration_move_to_device_train_val_data(self):
        metric = DummyTorchMetrics(return_float=True)
        result_package = TorchMetricsPackage(metric)

        callback = ModelPerformanceEvaluation(
            result_package, {},
            on_each_epoch=False, on_train_data=True, on_val_data=True,
            if_available_output_to_project_dir=False
        )
        callback.train_loop_obj = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)

        callback.on_train_loop_registration()
        self.assertEqual(metric.to_result, 'cpu_1')
        self.assertEqual(callback.result_package.metric.to_result, 'cpu_1')
        self.assertEqual(callback.train_result_package.metric.to_result, 'cpu_1')


class TestModelTrainHistoryFileWriter(unittest.TestCase):
    def test_execute_callback(self):
28 changes: 26 additions & 2 deletions tests/utils.py
@@ -369,7 +369,7 @@ def __init__(self, additional_results):

    def list_additional_results_dump_paths(self):
        return self.additional_results


class DummyResultPackageExtendVariable(DummyResultPackageExtend):
    def __init__(self, result_d):
@@ -378,7 +378,31 @@ def __init__(self, result_d):

    def prepare_results_dict(self):
        return self.result_d


class DummyTorchMetrics:
    def __init__(self, return_float=True):
        self.return_float = return_float

        self.compute_ctr = 0
        self.reset_ctr = 0
        self.to_result = None

    def compute(self):
        self.compute_ctr += 1

    def reset(self):
        self.reset_ctr += 1

    def to(self, device):
        self.to_result = f'{device}_1'

    def __call__(self, *args, **kwargs):
        if self.return_float:
            return 123.4
        else:
            return {'metric_1': 12.34, 'metric_2': 56.78}


class DummyAbstractBaseMetric(AbstractBaseMetric):
    def __init__(self, val):
