diff --git a/src/progpy/data_models/lstm_model.py b/src/progpy/data_models/lstm_model.py index 853338f7..6231e1fe 100644 --- a/src/progpy/data_models/lstm_model.py +++ b/src/progpy/data_models/lstm_model.py @@ -1,6 +1,7 @@ # Copyright © 2021 United States Government as represented by the Administrator of the # National Aeronautics and Space Administration. All Rights Reserved. +from collections import abc from itertools import chain import matplotlib.pyplot as plt from numbers import Number @@ -476,8 +477,8 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs): raise ValueError(f"layers must be greater than 0, got {params['layers']}") if np.isscalar(params['units']): params['units'] = [params['units'] for _ in range(params['layers'])] - if not isinstance(params['units'], (list, np.ndarray)): - raise TypeError(f"units must be a list of integers, not {type(params['units'])}") + if not isinstance(params['units'], (abc.Sequence, np.ndarray)): + raise TypeError(f"units must be a Sequence (e.g., list or tuple) of integers, not {type(params['units'])}") if len(params['units']) != params['layers']: raise ValueError(f"units must be a list of integers of length {params['layers']}, got {params['units']}") for i in range(params['layers']): @@ -487,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs): raise TypeError(f"dropout must be an float greater than or equal to 0, not {type(params['dropout'])}") if params['dropout'] < 0: raise ValueError(f"dropout must be greater than or equal to 0, got {params['dropout']}") - if not isinstance(params['activation'], (list, np.ndarray)): + if not isinstance(params['activation'], (list, tuple, np.ndarray)): params['activation'] = [params['activation'] for _ in range(params['layers'])] if not np.isscalar(params['validation_split']): raise TypeError(f"validation_split must be an float between 0 and 1, not {type(params['validation_split'])}") diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 235e4b24..e3c18104 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -184,6 +184,7 @@ def future_loading(t, x=None): [future_loading for _ in range(5)], dt=[TIMESTEP, TIMESTEP/2, TIMESTEP/4, TIMESTEP*2, TIMESTEP*4], window=2, + units=(16, ), # Units as tuple epochs=20) # Should get keys from original model