diff --git a/tests/conftest.py b/tests/conftest.py
index ae4035d..57836b4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,7 +27,7 @@ def rng(request):
 @pytest.fixture(scope='session', params=[
     *product(
         [Linear],
-        prodict(in_features=[16], out_features=[16], bias=[True, False]),
+        prodict(in_features=[16], out_features=[15], bias=[True, False]),
     ),
     *product(
         [Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
@@ -36,7 +36,7 @@ def rng(request):
 ])
 def module_linear(rng, request):
     module_type, kwargs = request.param
-    return module_type(**kwargs).eval()
+    return module_type(**kwargs).to(torch.float64).eval()
 
 
 @pytest.fixture(scope='session')
@@ -47,8 +47,8 @@ def module_batchnorm(module_linear):
         ((Conv3d, ConvTranspose3d), BatchNorm3d),
     ]
     feature_index_map = [
-        ((Linear, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d), 1),
-        ((Conv1d, Conv2d, Conv3d), 0),
+        ((ConvTranspose1d, ConvTranspose2d, ConvTranspose3d), 1),
+        ((Linear, Conv1d, Conv2d, Conv3d), 0),
     ]
 
     batchnorm_type = None
@@ -67,12 +67,12 @@ def module_batchnorm(module_linear):
     if feature_index is None:
         raise RuntimeError('No feature index for linear layer found.')
 
-    return batchnorm_type(num_features=module_linear.weight.shape[feature_index]).eval()
+    return batchnorm_type(num_features=module_linear.weight.shape[feature_index]).to(torch.float64).eval()
 
 
 @pytest.fixture(scope='session')
 def data_input(rng, module_linear):
-    shape = (16,)
+    shape = (4,)
     setups = [
         (Conv1d, 1, 1),
         (ConvTranspose1d, 0, 1),
@@ -86,6 +86,6 @@ def data_input(rng, module_linear):
     else:
         for module_type, dim, ndims in setups:
             if isinstance(module_linear, module_type):
-                shape += (module_linear.weight.shape[dim],) + (16,) * ndims
+                shape += (module_linear.weight.shape[dim],) + (4,) * ndims
 
-    return torch.empty(*shape).normal_(generator=rng)
+    return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)
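Note on the `feature_index_map` change: in PyTorch, `Linear` and the `Conv*` modules keep their output-feature count in dimension 0 of `weight`, while the `ConvTranspose*` modules keep it in dimension 1, so the old map read the wrong dimension for `Linear`. Switching `out_features` to 15 also makes the `Linear` weight non-square, so a wrong index now shows up as a `num_features` mismatch instead of passing silently. The following standalone sanity check (plain `torch.nn`, not part of this patch) illustrates the weight layouts the corrected map relies on:

from torch.nn import Conv1d, ConvTranspose1d, Linear

# Linear.weight has shape (out_features, in_features): output features at dim 0.
assert Linear(in_features=16, out_features=15).weight.shape == (15, 16)

# Conv1d.weight has shape (out_channels, in_channels, kernel_size): dim 0 as well.
assert Conv1d(in_channels=16, out_channels=15, kernel_size=3).weight.shape == (15, 16, 3)

# ConvTranspose1d.weight has shape (in_channels, out_channels, kernel_size): dim 1.
assert ConvTranspose1d(in_channels=16, out_channels=15, kernel_size=3).weight.shape == (16, 15, 3)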