Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions test/models/test_gp_regression_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def test_condition_on_observations__(self):

# check that fantasies of batched model are correct
if len(batch_shape) > 0 and test_X.dim() == 2:
state_dict_non_batch = {
key: (val[0] if val.ndim > 1 else val)
for key, val in model.state_dict().items()
}
model_kwargs_non_batch = {
"train_X": train_X[0],
"train_Y": train_Y[0],
Expand All @@ -213,6 +209,20 @@ def test_condition_on_observations__(self):
if observed_noise:
model_kwargs_non_batch["train_Yvar"] = train_Yvar[0]
model_non_batch = type(model)(**model_kwargs_non_batch)
non_batch_shapes = {
key: val.shape
for key, val in model_non_batch.state_dict().items()
}
state_dict_non_batch = {}
for key, val in model.state_dict().items():
if key in non_batch_shapes:
expected_shape = non_batch_shapes[key]
if val.ndim > len(expected_shape):
state_dict_non_batch[key] = val[0]
else:
state_dict_non_batch[key] = val
else:
state_dict_non_batch[key] = val
model_non_batch.load_state_dict(state_dict_non_batch)
model_non_batch.eval()
model_non_batch.likelihood.eval()
Expand Down