Merge pull request #656 from mv1388/prediction-metadata-combine
Prediction metadata combination support for tensors and np.arrays
mv1388 committed Jul 12, 2022
2 parents 4d497cd + d87fb1a commit 78d9371
Showing 5 changed files with 157 additions and 6 deletions.
3 changes: 2 additions & 1 deletion aitoolbox/torchtrain/model.py
@@ -69,7 +69,8 @@ def get_predictions(self, batch_data, device):
device: device on which the model is making the prediction
Returns:
- np.array, np.array, dict: y_pred.cpu(), y_test.cpu(), metadata
+ torch.Tensor, torch.Tensor, dict: y_pred.cpu(), y_test.cpu(), metadata
+     in the form of dict of lists/torch.Tensors/np.arrays
"""
pass

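To make the updated contract concrete, here is a hedged sketch of a model subclass returning such metadata. The ExampleClassifier, its layer sizes, batch layout, metadata keys and the get_loss() signature are assumptions for illustration (along with the assumption that the abstract base class in this file is TTModel); only the get_predictions() signature and its return form (y_pred, y_true, metadata dict of lists/torch.Tensors/np.arrays) come from the docstring change above.

import torch.nn as nn

from aitoolbox.torchtrain.model import TTModel


class ExampleClassifier(TTModel):
    def __init__(self, input_dim=32, num_classes=5):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def get_loss(self, batch_data, criterion, device):
        x, y = [t.to(device) for t in batch_data]
        return criterion(self.fc(x), y)

    def get_predictions(self, batch_data, device):
        x, y = batch_data
        logits = self.fc(x.to(device))
        y_pred = logits.argmax(dim=1)
        # Metadata values may now be lists, torch.Tensors or np.arrays;
        # the per-batch dicts are later merged by combine_prediction_metadata_batches().
        metadata = {'logits': logits.cpu(), 'sample_ids': list(range(len(y)))}
        return y_pred.cpu(), y, metadata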
4 changes: 4 additions & 0 deletions aitoolbox/torchtrain/train_loop/train_loop.py
@@ -562,6 +562,7 @@ def predict_on_train_set(self, force_prediction=False):
Returns:
(torch.Tensor, torch.Tensor, dict): y_pred, y_true, metadata
+     in the form of dict of lists/torch.Tensors/np.arrays
"""
if not self.prediction_store.has_train_predictions(self.total_iteration_idx) or force_prediction:
predictions = self.predict_with_model(self.train_loader)
@@ -580,6 +581,7 @@ def predict_on_validation_set(self, force_prediction=False):
Returns:
(torch.Tensor, torch.Tensor, dict): y_pred, y_true, metadata
+     in the form of dict of lists/torch.Tensors/np.arrays
"""
if not self.prediction_store.has_val_predictions(self.total_iteration_idx) or force_prediction:
predictions = self.predict_with_model(self.validation_loader)
@@ -598,6 +600,7 @@ def predict_on_test_set(self, force_prediction=False):
Returns:
(torch.Tensor, torch.Tensor, dict): y_pred, y_true, metadata
+     in the form of dict of lists/torch.Tensors/np.arrays
"""
if not self.prediction_store.has_test_predictions(self.total_iteration_idx) or force_prediction:
predictions = self.predict_with_model(self.test_loader)
@@ -616,6 +619,7 @@ def predict_with_model(self, data_loader):
Returns:
(torch.Tensor, torch.Tensor, dict): y_pred, y_true, metadata
+     in the form of dict of lists/torch.Tensors/np.arrays
"""
self.model = self.model.to(self.device)

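On the calling side, the combined result can then be consumed roughly like this (hedged sketch; train_loop stands for an already-constructed and fitted TrainLoop, whose setup is not shown here):

y_pred, y_true, metadata = train_loop.predict_on_validation_set()

# y_pred and y_true are torch.Tensors concatenated over all validation batches;
# metadata is a dict whose values are lists, torch.Tensors or np.arrays.
for name, values in metadata.items():
    print(name, type(values), len(values))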
27 changes: 23 additions & 4 deletions aitoolbox/utils/dict_util.py
@@ -1,13 +1,17 @@
import collections
import copy
+import numpy as np
+import torch

+from aitoolbox.utils.util import flatten_list_of_lists


def combine_prediction_metadata_batches(metadata_list):
"""Combines a list of dicts with the same keys and lists as values into a single dict with concatenated lists
for each corresponding key
"""Combines a list of dicts with the same keys and [lists or torch.Tensors or np.arrays] as values into
a single dict with concatenated [lists or torch.Tensors or np.arrays] for each corresponding key
Args:
metadata_list (list): list of dicts with matching keys and lists for values
metadata_list (list): list of dicts with matching keys and [lists or torch.Tensors or np.arrays] for values
Returns:
dict: combined single dict
@@ -18,7 +22,22 @@ def combine_prediction_metadata_batches(metadata_list):
for meta_el in metadata_batch:
if meta_el not in combined_metadata:
combined_metadata[meta_el] = []
- combined_metadata[meta_el] += metadata_batch[meta_el]

+ combined_metadata[meta_el].append(metadata_batch[meta_el])

+ for meta_el in combined_metadata:
+     metadata_elements_list = combined_metadata[meta_el]
+
+     if isinstance(metadata_elements_list[0], list):
+         combined_metadata[meta_el] = flatten_list_of_lists(metadata_elements_list)
+     elif isinstance(metadata_elements_list[0], torch.Tensor):
+         combined_metadata[meta_el] = torch.cat(metadata_elements_list, dim=0)
+     elif isinstance(metadata_elements_list[0], np.ndarray):
+         combined_metadata[meta_el] = np.concatenate(metadata_elements_list, axis=0)
+     else:
+         raise TypeError(f'Provided metadata element data type is not supported '
+                         f'by this function (type: {type(metadata_elements_list[0])}). '
+                         f'The function supports the following data types: list, torch.Tensor and np.array')

return combined_metadata

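A short usage sketch of the new combination behaviour (the metadata keys, shapes and values below are invented for illustration; the import path and the per-type concatenation follow the diff above):

import numpy as np
import torch

from aitoolbox.utils.dict_util import combine_prediction_metadata_batches

batch_1 = {'probs': torch.rand(4, 3), 'ids': np.arange(4), 'texts': ['a', 'b', 'c', 'd']}
batch_2 = {'probs': torch.rand(4, 3), 'ids': np.arange(4, 8), 'texts': ['e', 'f', 'g', 'h']}

combined = combine_prediction_metadata_batches([batch_1, batch_2])

print(combined['probs'].shape)  # torch.Size([8, 3]) -- tensors are torch.cat-ed along dim 0
print(combined['ids'].shape)    # (8,)               -- np.arrays are np.concatenate-d along axis 0
print(combined['texts'])        # ['a', ..., 'h']    -- lists of per-batch lists are flattened one level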
119 changes: 118 additions & 1 deletion tests/test_utils/test_dict_util.py
@@ -1,4 +1,6 @@
import unittest
+import numpy as np
+import torch

from aitoolbox.utils import dict_util

@@ -32,7 +34,7 @@ def test_combine_metadata_dicts_with_elements_missing(self):
'meta_2_special': list(range(4)) * 2}
)

- def test_combine_metadata_dicts_with_varying_elements(self):
+ def test_combine_metadata_dicts_with_varying_list_elements(self):
metadata_batches = [
{'meta_1': [1, 100, 1000, 10000], 'meta_2': list(range(4)), 'meta_2_special': list(range(2))},
{'meta_1': [1, 100, 1000, 10000], 'meta_2': list(range(4))},
@@ -49,6 +51,121 @@ def test_combine_metadata_dicts_with_varying_elements(self):
'completely_new_meta': ['334', '1000', 'bla']}
)

def test_combine_metadata_dicts_with_torch_tensors(self):
meta_1 = torch.rand(20, 50)
meta_2 = torch.randint(0, 100, (20, 12))

metadata_batches = [
{'meta_1': meta_1, 'meta_2': meta_2},
{'meta_1': meta_1, 'meta_2': meta_2},
{'meta_1': meta_1, 'meta_2': meta_2},
{'meta_1': meta_1, 'meta_2': meta_2}
]

combined_metadata = dict_util.combine_prediction_metadata_batches(metadata_batches)
self.assertEqual(sorted(combined_metadata.keys()), ['meta_1', 'meta_2'])
self.assertEqual(combined_metadata['meta_1'].shape, (meta_1.shape[0] * 4, meta_1.shape[1]))
self.assertEqual(combined_metadata['meta_2'].shape, (meta_2.shape[0] * 4, meta_2.shape[1]))
for vals in combined_metadata.values():
self.assertIsInstance(vals, torch.Tensor)

self.assertEqual(
{k: v.tolist() for k, v in combined_metadata.items()},
{
'meta_1': torch.cat([meta_1, meta_1, meta_1, meta_1]).tolist(),
'meta_2': torch.cat([meta_2, meta_2, meta_2, meta_2]).tolist()
}
)

def test_combine_metadata_dicts_with_numpy_arrays(self):
meta_1 = np.random.rand(20, 50)
meta_2 = np.random.randint(0, 100, (20, 12))

metadata_batches = [
{'meta_1': meta_1, 'meta_2': meta_2},
{'meta_1': meta_1, 'meta_2': meta_2},
{'meta_1': meta_1, 'meta_2': meta_2},
{'meta_1': meta_1, 'meta_2': meta_2}
]

combined_metadata = dict_util.combine_prediction_metadata_batches(metadata_batches)
self.assertEqual(sorted(combined_metadata.keys()), ['meta_1', 'meta_2'])
self.assertEqual(combined_metadata['meta_1'].shape, (meta_1.shape[0] * 4, meta_1.shape[1]))
self.assertEqual(combined_metadata['meta_2'].shape, (meta_2.shape[0] * 4, meta_2.shape[1]))
for vals in combined_metadata.values():
self.assertIsInstance(vals, np.ndarray)

self.assertEqual(
{k: v.tolist() for k, v in combined_metadata.items()},
{
'meta_1': np.concatenate([meta_1, meta_1, meta_1, meta_1]).tolist(),
'meta_2': np.concatenate([meta_2, meta_2, meta_2, meta_2]).tolist()
}
)

def test_combine_metadata_dicts_with_mixed_lists_torch_tensors_np_arrays(self):
meta_torch_1 = torch.rand(20, 50)
meta_torch_2 = torch.randint(0, 100, (20, 12))
meta_np_1 = np.random.rand(20, 25)
meta_np_2 = np.random.randint(0, 200, (20, 34))
meta_list_1 = list(range(20))
meta_list_2 = list(range(19, -1, -1))
meta_list_3 = [
[1, 2, 3],
[4, 5, 3],
[3, 3, 3]
]

metadata_batches = [
{'meta_torch_1': meta_torch_1, 'meta_torch_2': meta_torch_2,
'meta_np_1': meta_np_1, 'meta_np_2': meta_np_2,
'meta_list_1': meta_list_1, 'meta_list_2': meta_list_2, 'meta_list_3': meta_list_3},
{'meta_torch_1': meta_torch_1, 'meta_torch_2': meta_torch_2,
'meta_np_1': meta_np_1, 'meta_np_2': meta_np_2,
'meta_list_1': meta_list_1, 'meta_list_2': meta_list_2, 'meta_list_3': meta_list_3},
{'meta_torch_1': meta_torch_1, 'meta_torch_2': meta_torch_2,
'meta_np_1': meta_np_1, 'meta_np_2': meta_np_2,
'meta_list_1': meta_list_1, 'meta_list_2': meta_list_2, 'meta_list_3': meta_list_3},
{'meta_torch_1': meta_torch_1, 'meta_torch_2': meta_torch_2,
'meta_np_1': meta_np_1, 'meta_np_2': meta_np_2,
'meta_list_1': meta_list_1, 'meta_list_2': meta_list_2, 'meta_list_3': meta_list_3}
]

combined_metadata = dict_util.combine_prediction_metadata_batches(metadata_batches)
self.assertEqual(
sorted(combined_metadata.keys()),
sorted(['meta_torch_1', 'meta_torch_2', 'meta_np_1', 'meta_np_2',
'meta_list_1', 'meta_list_2', 'meta_list_3'])
)
self.assertEqual(combined_metadata['meta_torch_1'].shape, (meta_torch_1.shape[0] * 4, meta_torch_1.shape[1]))
self.assertEqual(combined_metadata['meta_torch_2'].shape, (meta_torch_2.shape[0] * 4, meta_torch_2.shape[1]))
self.assertEqual(combined_metadata['meta_np_1'].shape, (meta_np_1.shape[0] * 4, meta_np_1.shape[1]))
self.assertEqual(combined_metadata['meta_np_2'].shape, (meta_np_2.shape[0] * 4, meta_np_2.shape[1]))
self.assertEqual(len(combined_metadata['meta_list_1']), len(meta_list_1) * 4)
self.assertEqual(len(combined_metadata['meta_list_2']), len(meta_list_2) * 4)
self.assertEqual(len(combined_metadata['meta_list_3']), len(meta_list_3) * 4)

self.assertIsInstance(combined_metadata['meta_torch_1'], torch.Tensor)
self.assertIsInstance(combined_metadata['meta_torch_2'], torch.Tensor)
self.assertIsInstance(combined_metadata['meta_np_1'], np.ndarray)
self.assertIsInstance(combined_metadata['meta_np_2'], np.ndarray)
self.assertIsInstance(combined_metadata['meta_list_1'], list)
self.assertIsInstance(combined_metadata['meta_list_2'], list)
self.assertIsInstance(combined_metadata['meta_list_3'], list)

self.assertEqual(
{k: v.tolist() if not isinstance(v, list) else v for k, v in combined_metadata.items()},
{
'meta_torch_1': torch.cat([meta_torch_1, meta_torch_1, meta_torch_1, meta_torch_1]).tolist(),
'meta_torch_2': torch.cat([meta_torch_2, meta_torch_2, meta_torch_2, meta_torch_2]).tolist(),
'meta_np_1': np.concatenate([meta_np_1, meta_np_1, meta_np_1, meta_np_1]).tolist(),
'meta_np_2': np.concatenate([meta_np_2, meta_np_2, meta_np_2, meta_np_2]).tolist(),
'meta_list_1': meta_list_1 + meta_list_1 + meta_list_1 + meta_list_1,
'meta_list_2': meta_list_2 + meta_list_2 + meta_list_2 + meta_list_2,
'meta_list_3': meta_list_3 + meta_list_3 + meta_list_3 + meta_list_3
}
)


class TestFlattenDict(unittest.TestCase):
def test_flatten_dict(self):
10 changes: 10 additions & 0 deletions tests/test_utils/test_util.py
@@ -98,6 +98,16 @@ def test_flatten_list_of_lists(self):
self.assertEqual(util.flatten_list_of_lists([[1, 2, 3], [4, 5], [3, 3, 3, 3]]),
[1, 2, 3, 4, 5, 3, 3, 3, 3])

self.assertEqual(
util.flatten_list_of_lists(
[
[[1, 2, 3], [4, 5, 3], [3, 3, 3]],
[[10, 2, 3], [40, 5, 3], [30, 3, 3]],
[[100, 2, 3], [400, 5, 3], [300, 3, 3]]
]),
[[1, 2, 3], [4, 5, 3], [3, 3, 3], [10, 2, 3], [40, 5, 3], [30, 3, 3], [100, 2, 3], [400, 5, 3], [300, 3, 3]]
)


class EmptyFunctions:
def __init__(self):
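The diff does not show flatten_list_of_lists itself; a plausible single-level implementation consistent with the assertions above (an assumption for illustration, not necessarily the actual aitoolbox.utils.util code) would be:

def flatten_list_of_lists(nested_list):
    # Flatten exactly one nesting level: [[1, 2], [3]] -> [1, 2, 3].
    # Inner lists one level deeper are kept intact, which is what the
    # new nested-list assertion above relies on.
    return [element for sublist in nested_list for element in sublist]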
