diff --git a/src/mattersim/forcefield/m3gnet/m3gnet.py b/src/mattersim/forcefield/m3gnet/m3gnet.py index 66a5f96..9c9bb3d 100644 --- a/src/mattersim/forcefield/m3gnet/m3gnet.py +++ b/src/mattersim/forcefield/m3gnet/m3gnet.py @@ -57,7 +57,7 @@ def __init__( in_dim=max_z + 1, out_dims=[units], activation=None, use_bias=False ) self.atom_embedding.apply(self.init_weights_uniform) - self.normalizer = AtomScaling(verbose=False, max_z=max_z) + self.normalizer = AtomScaling(verbose=False, max_z=max_z, device=device) self.max_z = max_z self.device = device self.model_args = {