diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py index 2842f66e4b..bd582cc113 100644 --- a/test/models/test_gp_regression_mixed.py +++ b/test/models/test_gp_regression_mixed.py @@ -201,10 +201,6 @@ def test_condition_on_observations__(self): # check that fantasies of batched model are correct if len(batch_shape) > 0 and test_X.dim() == 2: - state_dict_non_batch = { - key: (val[0] if val.ndim > 1 else val) - for key, val in model.state_dict().items() - } model_kwargs_non_batch = { "train_X": train_X[0], "train_Y": train_Y[0], @@ -213,6 +209,20 @@ def test_condition_on_observations__(self): if observed_noise: model_kwargs_non_batch["train_Yvar"] = train_Yvar[0] model_non_batch = type(model)(**model_kwargs_non_batch) + non_batch_shapes = { + key: val.shape + for key, val in model_non_batch.state_dict().items() + } + state_dict_non_batch = {} + for key, val in model.state_dict().items(): + if key in non_batch_shapes: + expected_shape = non_batch_shapes[key] + if val.ndim > len(expected_shape): + state_dict_non_batch[key] = val[0] + else: + state_dict_non_batch[key] = val + else: + state_dict_non_batch[key] = val model_non_batch.load_state_dict(state_dict_non_batch) model_non_batch.eval() model_non_batch.likelihood.eval()