From bdff48f26a487eff989767c27a469167995346ca Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Thu, 13 Jul 2023 11:25:20 -0700 Subject: [PATCH 1/4] Allow tuple units --- src/progpy/data_models/lstm_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/progpy/data_models/lstm_model.py b/src/progpy/data_models/lstm_model.py index 58ec924e..f4496851 100644 --- a/src/progpy/data_models/lstm_model.py +++ b/src/progpy/data_models/lstm_model.py @@ -476,7 +476,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs): raise ValueError(f"layers must be greater than 0, got {params['layers']}") if np.isscalar(params['units']): params['units'] = [params['units'] for _ in range(params['layers'])] - if not isinstance(params['units'], (list, np.ndarray)): + if not isinstance(params['units'], (list, np.ndarray, tuple)): raise TypeError(f"units must be a list of integers, not {type(params['units'])}") if len(params['units']) != params['layers']: raise ValueError(f"units must be a list of integers of length {params['layers']}, got {params['units']}") From c0a49260f028834e8bb85e9b6e49af256f632e71 Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Thu, 13 Jul 2023 11:31:48 -0700 Subject: [PATCH 2/4] Update test --- tests/test_data_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 235e4b24..e3c18104 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -184,6 +184,7 @@ def future_loading(t, x=None): [future_loading for _ in range(5)], dt=[TIMESTEP, TIMESTEP/2, TIMESTEP/4, TIMESTEP*2, TIMESTEP*4], window=2, + units=(16, ), # Units as tuple epochs=20) # Should get keys from original model From d9d532f014c2771eadfb36c5c84edc270c0894ef Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Thu, 13 Jul 2023 12:34:14 -0700 Subject: [PATCH 3/4] Changed list to sequence --- src/progpy/data_models/lstm_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/progpy/data_models/lstm_model.py b/src/progpy/data_models/lstm_model.py index f4496851..219fdc1b 100644 --- a/src/progpy/data_models/lstm_model.py +++ b/src/progpy/data_models/lstm_model.py @@ -1,6 +1,7 @@ # Copyright © 2021 United States Government as represented by the Administrator of the # National Aeronautics and Space Administration. All Rights Reserved. +from collections import abc from itertools import chain import matplotlib.pyplot as plt from numbers import Number @@ -476,8 +477,8 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs): raise ValueError(f"layers must be greater than 0, got {params['layers']}") if np.isscalar(params['units']): params['units'] = [params['units'] for _ in range(params['layers'])] - if not isinstance(params['units'], (list, np.ndarray, tuple)): - raise TypeError(f"units must be a list of integers, not {type(params['units'])}") + if not isinstance(params['units'], (abc.Sequence, np.ndarray)): + raise TypeError(f"units must be a Sequence (e.g., list or tuple) of integers, not {type(params['units'])}") if len(params['units']) != params['layers']: raise ValueError(f"units must be a list of integers of length {params['layers']}, got {params['units']}") for i in range(params['layers']): @@ -487,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs): raise TypeError(f"dropout must be an float greater than or equal to 0, not {type(params['dropout'])}") if params['dropout'] < 0: raise ValueError(f"dropout must be greater than or equal to 0, got {params['dropout']}") - if not isinstance(params['activation'], (list, np.ndarray)): + if not isinstance(params['activation'], (abc.Sequence, np.ndarray)): params['activation'] = [params['activation'] for _ in range(params['layers'])] if not np.isscalar(params['validation_split']): raise TypeError(f"validation_split must be an float between 0 and 1, not {type(params['validation_split'])}") From 0c33d996b295b8b894ca7a3d7766ddc40dffaaec Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Thu, 13 Jul 2023 16:26:11 -0700 Subject: [PATCH 4/4] Fix activation bug with string --- src/progpy/data_models/lstm_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/progpy/data_models/lstm_model.py b/src/progpy/data_models/lstm_model.py index 219fdc1b..fdf61487 100644 --- a/src/progpy/data_models/lstm_model.py +++ b/src/progpy/data_models/lstm_model.py @@ -488,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs): raise TypeError(f"dropout must be an float greater than or equal to 0, not {type(params['dropout'])}") if params['dropout'] < 0: raise ValueError(f"dropout must be greater than or equal to 0, got {params['dropout']}") - if not isinstance(params['activation'], (abc.Sequence, np.ndarray)): + if not isinstance(params['activation'], (list, tuple, np.ndarray)): params['activation'] = [params['activation'] for _ in range(params['layers'])] if not np.isscalar(params['validation_split']): raise TypeError(f"validation_split must be an float between 0 and 1, not {type(params['validation_split'])}")