In [10]:
import os, sys 
sys.path.append('..')

import random 
import robust as rb 
import torch as ch 
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from nff.io.ase_calcs import NeuralFF, AtomsBatch, EnsembleNFF
from nff.data import Dataset
from nff.train import load_model 

from ase.io import Trajectory, read 


In [11]:
# converting xyz to dataset 
def _xyz_to_dataset(xyz_file):
    """Convert every frame of an XYZ file into a single nff ``Dataset``.

    Variant that stores the per-frame properties as raw NumPy arrays
    (no tensor conversion).

    Parameters
    ----------
    xyz_file : str
        Path handed to ``ase.io.read``; all frames are read (``index=':'``).

    Returns
    -------
    Dataset
        Built with ``units='eV'`` from a props dict with the keys
        ``nxyz``, ``energy``, ``energy_grad``, ``stress`` and ``lattice``,
        each holding one entry per frame.

    Raises
    ------
    KeyError
        If a frame has no ``REF_energy`` info entry or no ``REF_force`` array.
    """
    # read every frame of the xyz file
    atoms_list = read(xyz_file, index=':')

    # one list per property, filled frame by frame
    props = {
        'nxyz': [],
        'energy': [],
        'energy_grad': [],
        'stress': [],
        'lattice': [],
    }

    for atoms in atoms_list:
        # atomic numbers and cartesian positions
        numbers = atoms.get_atomic_numbers()
        positions = atoms.get_positions()

        # lattice vectors of the cell
        props['lattice'].append(atoms.get_cell().array)

        # one row per atom: [Z, x, y, z]
        props['nxyz'].append(np.column_stack((numbers, positions)))

        # reference energy
        props['energy'].append(atoms.info['REF_energy'])

        # reference forces -> energy gradient (note the negative sign)
        forces = atoms.todict()['REF_force']
        props['energy_grad'].append(-forces)

        # Stress is optional: a missing or all-zero entry is stored as None.
        # `.get` avoids the KeyError a plain lookup raises for frames
        # that carry no stress at all.
        stress = atoms.info.get('REF_stress')
        if stress is not None and np.any(stress):
            props['stress'].append(stress)
        else:
            props['stress'].append(None)

    # Build the dataset once, after ALL frames were collected. Doing this
    # inside the loop returned after the first frame and dropped the rest.
    dataset = Dataset(props=props, units='eV')

    return dataset
    

import numpy as np
import torch
from ase.io import read
from nff.data import Dataset


def xyz_to_dataset(xyz_file):
    """Convert every frame of an XYZ file into a tensor-backed nff ``Dataset``.

    Parameters
    ----------
    xyz_file : str
        Path handed to ``ase.io.read``; all frames are read (``index=':'``).

    Returns
    -------
    Dataset
        Built with ``units='eV'`` from a props dict containing

        * ``nxyz``        – list of float32 tensors, one row ``[Z, x, y, z]`` per atom
        * ``energy_grad`` – list of float32 tensors (negated ``REF_force``)
        * ``energy``      – float32 tensor, one value per frame
        * ``stress``      – float32 tensor stacked over frames
        * ``lattice``     – float32 tensor stacked over frames
        * ``num_atoms``   – int64 tensor, one value per frame

    Raises
    ------
    ValueError
        If the file contains no frames.
    KeyError
        If a frame has no ``REF_energy`` info entry or no ``REF_force`` array.
    """
    # Read the XYZ file
    atoms_list = read(xyz_file, index=':')
    if len(atoms_list) == 0:
        # Fail early with a clear message instead of the opaque error
        # torch.stack raises on an empty list further down.
        raise ValueError(f"No structures found in {xyz_file!r}")

    nxyz_list = []
    energy_grad_list = []
    lattice_list = []
    stress_list = []   # tensors, or None where a frame has no usable stress
    energies = []
    atom_counts = []   # number of atoms in each structure

    for atoms in atoms_list:
        # Extract atomic numbers and positions
        numbers = atoms.get_atomic_numbers()
        positions = atoms.get_positions()

        # Get the lattice
        lattice_list.append(torch.tensor(atoms.get_cell().array, dtype=torch.float32))

        # Combine atomic numbers and positions into nxyz format
        nxyz = np.column_stack((numbers, positions))
        nxyz_list.append(torch.tensor(nxyz, dtype=torch.float32))

        # Extract energy
        energies.append(atoms.info['REF_energy'])

        # Extract forces and convert to energy gradient (note the negative sign)
        forces = atoms.todict()['REF_force']
        energy_grad_list.append(torch.tensor(-forces, dtype=torch.float32))

        # Stress is optional: `.get` avoids a KeyError for frames without it.
        # Missing or all-zero entries are filled with a zero placeholder below.
        stress = atoms.info.get('REF_stress')
        if stress is not None and np.any(stress):
            stress_list.append(torch.tensor(np.asarray(stress), dtype=torch.float32))
        else:
            stress_list.append(None)

        atom_counts.append(len(atoms))

    # The zero placeholder must have the same shape as the real stresses,
    # otherwise torch.stack fails on files that mix frames with and without
    # stress. Fall back to six zeros when no frame provides one.
    stress_shape = next((s.shape for s in stress_list if s is not None), (6,))
    stress_list = [
        s if s is not None else torch.zeros(stress_shape, dtype=torch.float32)
        for s in stress_list
    ]

    props = {
        # Per-atom properties stay lists of tensors (frames may differ in size)
        'nxyz': nxyz_list,
        'energy_grad': energy_grad_list,
        # Per-structure properties become a single tensor each
        'energy': torch.tensor(energies, dtype=torch.float32),
        'stress': torch.stack(stress_list),
        'lattice': torch.stack(lattice_list),
        'num_atoms': torch.tensor(atom_counts, dtype=torch.int64),
    }

    # Create the Dataset object
    dataset = Dataset(props=props, units='eV')

    return dataset

In [12]:
# Checkpoint paths of the eight ensemble members (model/seed indices 0-7).
member_paths = [
    f'finished_runs/2024-09-27_vcrtiwzr_fep+vac+neb+perf/model_{i}/vcrtiwzr_all_e1_f50_25s_seed_{i}_compiled.model'
    for i in range(8)
]
# Load each checkpoint on the CPU and keep only the wrapped model object.
models = [NeuralFF.from_file(path, device='cpu').model for path in member_paths]

In [13]:
# Combine the loaded models into one EnsembleNFF calculator (constructed with device='cpu').
ensemble = EnsembleNFF(models,device='cpu')

In [14]:
# Convert all frames of the XYZ file into a Dataset via xyz_to_dataset.
dset = xyz_to_dataset('data/fep_vac_neb_perf_test.xyz')

In [15]:
# Neighbor-list cutoff handed to AtomsBatch in get_atoms
# (presumably in Angstrom — TODO confirm against the model's training cutoff).
CUTOFF = 6

def get_atoms(props):
    """Turn one dataset entry into an ``AtomsBatch`` ready for the ensemble.

    Parameters
    ----------
    props : mapping
        Must provide ``'nxyz'`` (column 0 is used as atomic numbers, the
        remaining columns as positions) and ``'lattice'`` (used as the cell).

    Returns
    -------
    AtomsBatch
        Periodic structure with the module-level ``ensemble`` attached as
        calculator, built with the module-level ``CUTOFF`` and with its
        neighbor list already updated.
    """
    nxyz = props['nxyz']
    # Placeholder targets handed to AtomsBatch as its props.
    placeholder_props = {'energy': 0, 'energy_grad': [], 'stress': []}

    batch = AtomsBatch(
        numbers=nxyz[:, 0],
        positions=nxyz[:, 1:],
        cell=props['lattice'],
        pbc=True,
        cutoff=CUTOFF,
        props=placeholder_props,
        calculator=ensemble,
        device='cpu',
    )
    # Build the neighbor list now; its return value is not needed here.
    batch.update_nbr_list()

    return batch

# Pick the starting structure with a dedicated, seeded RNG so that
# "Restart Kernel -> Run All" always attacks the same configuration
# (an unseeded random.choice selected a different frame on every run).
initial = get_atoms(random.Random(0).choice(dset))

In [16]:
# define the attacker
# Reference energies normalised by the number of atoms of each structure.
# as_tensor accepts a list as well as an existing tensor and, unlike
# ch.tensor, does not emit the copy-construct warning for tensor input.
energies_per_atom = ch.as_tensor(dset.props['energy']) / dset.props['num_atoms']

# NOTE(review): the reference forces (dset.props['energy_grad']) are not
# used; the per-atom energies are passed for both trailing arguments and
# zeros for the first one — confirm this matches PotentialDataset's contract.
energy_dset = rb.PotentialDataset(
    ch.zeros_like(energies_per_atom),
    energies_per_atom,
    energies_per_atom,
)

# Adversarial loss built from the reference energy distribution.
loss_fn = rb.loss.AdvLoss(
    train=energy_dset,
    temperature=20,
)


In [17]:
# Assemble the attacker from the starting structure, the ensemble calculator
# and the adversarial loss.
# NOTE(review): the recorded run of attacker.attack fails because the ensemble
# members are TorchScript-compiled MACE models (ScaleShiftMACE) whose forward()
# only accepts Dict[str, Tensor], while the batch passed to it also holds
# non-tensor entries such as 'mol_nbrs'. rb.schnet.Attacker presumably targets
# SchNet-style models — confirm whether a MACE-compatible path exists.
attacker = rb.schnet.Attacker(
    initial,
    ensemble,
    loss_fn,
    device='cpu',
)

In [18]:
# Run the attack for 10 epochs.
# NOTE(review): in the recorded output this call raises a RuntimeError at the
# first step (the model's forward() rejects the batch dict); see the traceback.
results = attacker.attack(epochs=10)

  0%|          | 0/10 [00:00<?, ?it/s]


RuntimeError: forward() Expected a value of type 'Dict[str, Tensor]' for argument 'data' but instead found type 'dict'.
Position: 1
Value: {'energy': tensor(0.), 'energy_grad': tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]), 'stress': tensor([0., 0., 0., 0., 0., 0.]), 'num_atoms': tensor([124]), 'nbr_list': tensor([[  0,   1],
        [  0,   3],
        [  0,   4],
        ...,
        [121, 122],
        [121, 123],
        [122, 123]]), 'offsets': tensor(indices=tensor([[    0,     0,     0,  ..., 12900, 12900, 12900],
                       [    0,     1,     2,  ...,     0,     1,     2]]),
       values=tensor([ 6.3086, 10.8680, -8.8094,  ...,  6.3086, 10.8680,
                      -8.8094]),
       size=(12904, 3), nnz=22692, layout=torch.sparse_coo), 'cell': tensor([[ 1.2612e+01, -3.6813e-02, -4.4531e+00],
        [-6.3037e+00,  1.0905e+01, -4.3563e+00],
        [-1.2380e-02,  9.7350e-02,  1.3324e+01]]), 'lattice': tensor([[ 1.2612e+01, -3.6813e-02, -4.4531e+00],
        [-6.3037e+00,  1.0905e+01, -4.3563e+00],
        [-1.2380e-02,  9.7350e-02,  1.3324e+01]], dtype=torch.float64), 'nxyz': tensor([[ 4.0000e+01,  7.4146e+00,  8.7504e+00, -2.7836e+00],
        [ 4.0000e+01,  2.4671e+00,  1.4621e-01,  7.0905e+00],
        [ 4.0000e+01,  1.2476e+00,  2.3217e+00,  1.1514e+01],
        [ 4.0000e+01, -1.3462e+00,  6.5356e+00,  4.3810e+00],
        [ 4.0000e+01, -2.4706e+00,  8.6695e+00,  8.2488e-01],
        [ 4.0000e+01,  3.7842e+00,  2.2125e+00,  5.8137e-02],
        [ 4.0000e+01, -2.5721e-02,  8.6543e+00, -2.4998e+00],
        [ 4.0000e+01,  1.9152e-01,  8.8615e+00,  5.3019e+00],
        [ 4.0000e+01,  5.0605e+00,  4.4237e+00,  8.8908e+00],
        [ 4.0000e+01,  4.9347e+00,  8.7714e+00,  9.8331e-01],
        [ 2.2000e+01,  6.1132e+00,  1.0829e+01, -6.1018e+00],
        [ 2.2000e+01, -1.2548e+00,  2.3091e+00,  9.7359e+00],
        [ 2.2000e+01,  8.7753e+00,  6.4054e+00, -1.7335e+00],
        [ 2.2000e+01, -2.1754e-02,  4.2954e+00,  2.6812e+00],
        [ 2.2000e+01, -8.2604e-02,  4.4695e+00,  8.0474e+00],
        [ 2.2000e+01,  5.2318e+00,  1.4455e-01,  6.1814e+00],
        [ 2.2000e+01,  3.9508e+00,  2.1815e+00,  1.0501e+01],
        [ 2.2000e+01,  1.2982e+00,  6.5365e+00,  9.8642e-01],
        [ 2.2000e+01,  7.6722e+00,  1.4307e-01, -2.5230e+00],
        [ 2.2000e+01,  7.5634e+00,  1.4962e-01,  2.5938e+00],
        [ 2.2000e+01,  7.4822e+00,  1.4709e-01,  7.8721e+00],
        [ 2.2000e+01,  5.0839e+00,  4.4377e+00, -1.5865e+00],
        [ 2.2000e+01,  3.7637e+00,  6.5361e+00,  2.7470e+00],
        [ 2.2000e+01,  2.9994e+00,  9.3758e+00,  4.8114e+00],
        [ 2.2000e+01,  8.8663e+00,  2.1270e+00, -4.3513e+00],
        [ 2.2000e+01,  8.7727e+00,  2.2294e+00,  6.1476e+00],
        [ 2.2000e+01,  7.5389e+00,  4.3442e+00,  7.1980e-02],
        [ 2.2000e+01,  5.1560e+00,  8.6788e+00,  3.7544e+00],
        [ 2.3000e+01,  1.2409e+01,  1.8510e-01,  8.6290e-01],
        [ 2.3000e+01,  1.2452e+01,  8.3813e-02,  3.6087e+00],
        [ 2.3000e+01,  6.3092e+00,  1.0939e+01,  1.9948e+00],
        [ 2.3000e+01,  1.1303e+01,  2.2009e+00, -5.3179e+00],
        [ 2.3000e+01,  1.1277e+01,  2.1472e+00, -4.8324e-02],
        [ 2.3000e+01,  1.1282e+01,  2.1815e+00,  2.6107e+00],
        [ 2.3000e+01,  9.9991e+00,  4.4390e+00,  7.1078e+00],
        [ 2.3000e+01,  1.0019e+01,  4.1195e+00, -9.1255e-01],
        [ 2.3000e+01,  1.0034e+01,  4.3643e+00,  1.8698e+00],
        [ 2.3000e+01,  8.7705e+00,  6.6091e+00,  6.1481e+00],
        [ 2.3000e+01,  8.7144e+00,  6.2790e+00, -4.5387e+00],
        [ 2.3000e+01,  8.7719e+00,  6.5710e+00,  1.1257e+00],
        [ 2.3000e+01,  7.5162e+00,  8.6071e+00, -5.5526e+00],
        [ 2.3000e+01, -4.9842e+00,  8.7528e+00,  7.3109e+00],
        [ 2.3000e+01, -3.8060e+00,  1.0816e+01,  8.0936e+00],
        [ 2.3000e+01, -3.8484e+00,  1.0858e+01, -2.6079e+00],
        [ 2.3000e+01,  2.4066e+00,  9.9451e-02,  4.1944e+00],
        [ 2.3000e+01, -3.7534e+00,  1.0883e+01,  5.5210e+00],
        [ 2.3000e+01,  1.2828e+00,  2.1490e+00,  3.5953e+00],
        [ 2.3000e+01,  1.1600e+00,  2.3502e+00,  6.2059e+00],
        [ 2.3000e+01,  1.3622e+00,  2.2789e+00,  8.7462e+00],
        [ 2.3000e+01, -7.5781e-02,  4.5331e+00, -2.5778e+00],
        [ 2.3000e+01, -1.3146e-01,  4.3342e+00,  4.8638e-03],
        [ 2.3000e+01, -1.3475e+00,  6.5144e+00,  9.7532e+00],
        [ 2.3000e+01, -1.3790e+00,  6.4477e+00, -9.5860e-01],
        [ 2.3000e+01, -1.3640e+00,  6.2347e+00,  1.5263e+00],
        [ 2.3000e+01, -1.3943e+00,  6.6350e+00,  7.2270e+00],
        [ 2.3000e+01, -2.6133e+00,  8.7412e+00,  8.9831e+00],
        [ 2.3000e+01, -2.6961e+00,  8.6969e+00, -1.8769e+00],
        [ 2.3000e+01, -2.5017e+00,  8.9338e+00,  3.8203e+00],
        [ 2.3000e+01, -2.5667e+00,  8.7896e+00,  6.3988e+00],
        [ 2.3000e+01, -1.2519e+00,  1.0824e+01, -6.0595e+00],
        [ 2.3000e+01,  5.0808e+00,  1.8648e-02,  8.8138e-01],
        [ 2.3000e+01,  4.9476e+00,  1.3566e-01,  3.4415e+00],
        [ 2.3000e+01,  3.7722e+00,  2.2114e+00,  2.7638e+00],
        [ 2.3000e+01,  3.8700e+00,  2.2503e+00,  7.8596e+00],
        [ 2.3000e+01,  2.3455e+00,  4.6263e+00,  9.7133e+00],
        [ 2.3000e+01,  2.4062e+00,  4.5797e+00, -7.7441e-01],
        [ 2.3000e+01,  2.5784e+00,  4.3002e+00,  1.9229e+00],
        [ 2.3000e+01,  2.5234e+00,  4.4887e+00,  7.0368e+00],
        [ 2.3000e+01,  1.3030e+00,  6.4104e+00, -1.7962e+00],
        [ 2.3000e+01,  1.4267e+00,  6.4792e+00,  3.7847e+00],
        [ 2.3000e+01,  1.2748e+00,  6.5028e+00,  6.4278e+00],
        [ 2.3000e+01,  1.4666e-01,  8.8371e+00,  1.4248e-01],
        [ 2.3000e+01,  2.9344e-02,  8.7019e+00,  2.5317e+00],
        [ 2.3000e+01,  1.4088e+00,  1.0888e+01,  8.4981e-01],
        [ 2.3000e+01,  6.4077e+00,  2.0303e+00, -9.0608e-01],
        [ 2.3000e+01,  6.2521e+00,  2.2974e+00,  1.7827e+00],
        [ 2.3000e+01,  6.3490e+00,  2.2838e+00,  4.4197e+00],
        [ 2.3000e+01,  6.3132e+00,  2.4440e+00,  7.1404e+00],
        [ 2.3000e+01,  5.0498e+00,  4.3563e+00,  3.5782e+00],
        [ 2.3000e+01,  5.0688e+00,  4.4946e+00,  6.1092e+00],
        [ 2.3000e+01,  3.6740e+00,  6.6287e+00,  8.0330e+00],
        [ 2.3000e+01,  3.7270e+00,  6.5517e+00, -2.6775e+00],
        [ 2.3000e+01,  3.8088e+00,  6.3555e+00, -5.3497e-02],
        [ 2.3000e+01,  3.7833e+00,  6.6091e+00,  5.4973e+00],
        [ 2.3000e+01,  2.5713e+00,  8.7025e+00, -5.9411e+00],
        [ 2.3000e+01,  2.5017e+00,  8.7175e+00, -8.9264e-01],
        [ 2.3000e+01,  2.3606e+00,  8.9141e+00,  2.0308e+00],
        [ 2.3000e+01,  1.0046e+01,  5.2505e-02,  1.6647e+00],
        [ 2.3000e+01,  9.9669e+00,  1.5159e-01,  4.3249e+00],
        [ 2.3000e+01,  1.0157e+01,  1.8834e-01,  7.1288e+00],
        [ 2.3000e+01,  8.8438e+00,  2.2589e+00, -1.6157e+00],
        [ 2.3000e+01,  8.8027e+00,  2.1762e+00,  9.3300e-01],
        [ 2.3000e+01,  8.8793e+00,  2.3602e+00,  3.5169e+00],
        [ 2.3000e+01,  7.6104e+00,  4.2379e+00, -2.7372e+00],
        [ 2.3000e+01,  7.6143e+00,  4.3794e+00,  2.7575e+00],
        [ 2.3000e+01,  7.5744e+00,  4.4722e+00,  5.3377e+00],
        [ 2.3000e+01,  6.2976e+00,  6.5272e+00,  7.1292e+00],
        [ 2.3000e+01,  6.3079e+00,  6.4506e+00, -3.4636e+00],
        [ 2.3000e+01,  6.2489e+00,  6.6003e+00, -8.0817e-01],
        [ 2.3000e+01,  6.3716e+00,  6.5278e+00,  4.5388e+00],
        [ 2.3000e+01,  5.1207e+00,  8.5235e+00, -6.8987e+00],
        [ 2.3000e+01,  4.8895e+00,  8.7183e+00, -4.4129e+00],
        [ 2.4000e+01,  1.1208e+01,  2.2071e+00, -2.7342e+00],
        [ 2.4000e+01,  1.0076e+01,  4.2684e+00, -3.5428e+00],
        [ 2.4000e+01,  9.9965e+00,  4.4821e+00,  4.4442e+00],
        [ 2.4000e+01,  8.8223e+00,  6.4946e+00,  3.6211e+00],
        [ 2.4000e+01,  7.5334e+00,  8.7286e+00,  5.3763e+00],
        [ 2.4000e+01,  5.0608e+00,  2.0514e-01,  8.9415e+00],
        [ 2.4000e+01,  2.5168e+00,  4.4337e+00,  4.4770e+00],
        [ 2.4000e+01,  1.2502e+00,  6.6235e+00,  8.9006e+00],
        [ 2.4000e+01,  7.6151e+00,  1.1651e-03,  2.3352e-03],
        [ 2.4000e+01,  6.3463e+00,  2.1420e+00,  9.8058e+00],
        [ 2.4000e+01,  5.0582e+00,  4.4144e+00,  1.0811e+00],
        [ 2.4000e+01,  2.5542e+00,  8.7086e+00, -3.5664e+00],
        [ 2.4000e+01,  9.9372e+00,  2.1468e-02, -7.4761e-01],
        [ 2.4000e+01,  7.6382e+00,  4.4388e+00,  7.9015e+00],
        [ 7.4000e+01, -5.0086e+00,  8.7917e+00,  4.6004e+00],
        [ 7.4000e+01,  1.0800e+00,  2.1437e+00,  1.0248e+00],
        [ 7.4000e+01, -1.0935e-02,  4.3574e+00,  5.3976e+00],
        [ 7.4000e+01,  3.7836e+00,  2.3327e+00,  5.2800e+00],
        [ 7.4000e+01,  2.5233e-02,  8.6950e+00, -5.2400e+00],
        [ 7.4000e+01,  6.1931e+00,  1.0933e+01,  4.5458e+00],
        [ 7.4000e+01,  6.3392e+00,  6.4710e+00,  1.8929e+00],
        [ 7.4000e+01,  4.8821e+00,  8.7523e+00, -1.7877e+00]],
       grad_fn=<AddBackward0>), 'mol_nbrs': ([tensor([[  0,   0],
        [  0,   0],
        [  0,   0],
        ...,
        [123, 123],
        [123, 123],
        [123, 123]])], [tensor([[  0,   0],
        [  0,   1],
        [  0,   2],
        ...,
        [123, 121],
        [123, 122],
        [123, 123]])], [tensor([40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 22, 22, 22, 22, 22, 22, 22, 22,
        22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24,
        24, 24, 24, 24, 24, 24, 24, 24, 74, 74, 74, 74, 74, 74, 74, 74])], [343], [tensor([[-2.3394e-02,  1.8397e-01,  2.5179e+01],
        [-4.6789e-02,  3.6793e-01,  5.0357e+01],
        [-7.0183e-02,  5.5190e-01,  7.5536e+01],
        ...,
        [ 7.0183e-02, -5.5190e-01, -7.5536e+01],
        [ 4.6789e-02, -3.6793e-01, -5.0357e+01],
        [ 2.3394e-02, -1.8397e-01, -2.5179e+01]], dtype=torch.float64)], [tensor([False,  True,  True,  ...,  True,  True,  True])]), 'mol_idx': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0]), 'delta': tensor([[ 0.0109,  0.0035, -0.0055],
        [-0.0260, -0.0031,  0.0036],
        [-0.0020, -0.0011, -0.0016],
        [-0.0015,  0.0105,  0.0004],
        [-0.0106,  0.0028, -0.0033],
        [-0.0082, -0.0040,  0.0126],
        [-0.0019, -0.0115,  0.0094],
        [ 0.0105,  0.0082,  0.0085],
        [ 0.0155, -0.0135, -0.0047],
        [ 0.0364, -0.0011,  0.0003],
        [ 0.0112, -0.0164,  0.0088],
        [ 0.0111, -0.0052, -0.0031],
        [ 0.0008, -0.0015,  0.0137],
        [-0.0062, -0.0053, -0.0048],
        [-0.0119, -0.0037, -0.0190],
        [ 0.0018, -0.0105,  0.0069],
        [ 0.0007, -0.0069,  0.0081],
        [-0.0013, -0.0153, -0.0056],
        [-0.0200, -0.0036,  0.0079],
        [-0.0142, -0.0077, -0.0085],
        [ 0.0026,  0.0069,  0.0057],
        [-0.0077,  0.0213, -0.0050],
        [ 0.0074,  0.0017,  0.0002],
        [-0.0003,  0.0073,  0.0005],
        [-0.0153, -0.0111,  0.0106],
        [-0.0002, -0.0047, -0.0145],
        [-0.0004, -0.0053,  0.0149],
        [-0.0084,  0.0144, -0.0016],
        [ 0.0013, -0.0039, -0.0091],
        [-0.0132, -0.0169,  0.0059],
        [-0.0010,  0.0216,  0.0062],
        [ 0.0005, -0.0002, -0.0045],
        [ 0.0015,  0.0088, -0.0195],
        [-0.0062, -0.0068,  0.0104],
        [-0.0059,  0.0076, -0.0012],
        [-0.0137, -0.0138,  0.0057],
        [-0.0154, -0.0019,  0.0023],
        [-0.0120,  0.0024, -0.0053],
        [-0.0110, -0.0075, -0.0067],
        [ 0.0013,  0.0009, -0.0040],
        [ 0.0018, -0.0083,  0.0044],
        [-0.0165,  0.0121,  0.0165],
        [-0.0148, -0.0042, -0.0318],
        [-0.0069,  0.0092,  0.0064],
        [ 0.0005,  0.0006,  0.0026],
        [ 0.0069, -0.0008, -0.0086],
        [ 0.0271, -0.0188, -0.0096],
        [-0.0034,  0.0140,  0.0037],
        [-0.0118, -0.0002,  0.0080],
        [ 0.0159, -0.0013,  0.0080],
        [-0.0130, -0.0084, -0.0033],
        [-0.0026,  0.0085,  0.0028],
        [ 0.0029,  0.0046, -0.0196],
        [-0.0068, -0.0158, -0.0031],
        [-0.0083, -0.0005, -0.0135],
        [-0.0332, -0.0114,  0.0066],
        [ 0.0085, -0.0180, -0.0026],
        [-0.0008, -0.0083, -0.0099],
        [-0.0160,  0.0019,  0.0090],
        [ 0.0066, -0.0043,  0.0117],
        [ 0.0016,  0.0022, -0.0091],
        [-0.0007,  0.0054,  0.0162],
        [-0.0061,  0.0110,  0.0141],
        [-0.0032,  0.0023,  0.0005],
        [ 0.0034,  0.0128,  0.0029],
        [-0.0072,  0.0072,  0.0091],
        [-0.0012, -0.0108, -0.0117],
        [ 0.0019,  0.0023, -0.0133],
        [ 0.0032, -0.0041, -0.0044],
        [-0.0063,  0.0099,  0.0021],
        [ 0.0072, -0.0059,  0.0160],
        [-0.0101, -0.0019,  0.0035],
        [-0.0151,  0.0153,  0.0030],
        [ 0.0174, -0.0071, -0.0176],
        [-0.0058, -0.0043, -0.0091],
        [-0.0176, -0.0201, -0.0063],
        [ 0.0050,  0.0005,  0.0056],
        [-0.0152, -0.0013,  0.0209],
        [-0.0167, -0.0212, -0.0094],
        [-0.0134,  0.0040, -0.0067],
        [-0.0028, -0.0175,  0.0036],
        [ 0.0053,  0.0009,  0.0065],
        [-0.0010, -0.0119, -0.0052],
        [ 0.0207, -0.0131,  0.0243],
        [-0.0066,  0.0105, -0.0191],
        [ 0.0016,  0.0088,  0.0188],
        [-0.0113, -0.0032, -0.0016],
        [-0.0088,  0.0159, -0.0113],
        [-0.0145,  0.0005, -0.0200],
        [-0.0062, -0.0061, -0.0049],
        [-0.0115,  0.0101,  0.0075],
        [ 0.0044,  0.0102, -0.0126],
        [ 0.0076,  0.0001,  0.0009],
        [-0.0037, -0.0005,  0.0037],
        [ 0.0147, -0.0138,  0.0034],
        [-0.0033,  0.0063,  0.0064],
        [-0.0153, -0.0061,  0.0042],
        [-0.0003,  0.0026,  0.0022],
        [-0.0236, -0.0104,  0.0117],
        [-0.0039, -0.0106, -0.0039],
        [ 0.0186,  0.0069,  0.0105],
        [-0.0055, -0.0010, -0.0156],
        [-0.0041,  0.0093, -0.0100],
        [ 0.0033,  0.0150,  0.0087],
        [ 0.0113, -0.0056, -0.0012],
        [-0.0229, -0.0337, -0.0178],
        [ 0.0200,  0.0007, -0.0183],
        [ 0.0084, -0.0072,  0.0025],
        [-0.0075, -0.0175, -0.0136],
        [ 0.0125, -0.0038, -0.0005],
        [-0.0005, -0.0003,  0.0021],
        [-0.0145, -0.0241,  0.0201],
        [ 0.0045,  0.0034, -0.0053],
        [ 0.0173,  0.0019, -0.0013],
        [ 0.0089,  0.0135,  0.0144],
        [ 0.0059, -0.0202, -0.0080],
        [-0.0060, -0.0037,  0.0029],
        [-0.0065, -0.0050, -0.0020],
        [-0.0163,  0.0006,  0.0031],
        [ 0.0075,  0.0068,  0.0136],
        [ 0.0091, -0.0066,  0.0014],
        [-0.0007,  0.0136,  0.0045],
        [-0.0062,  0.0004,  0.0155],
        [-0.0122,  0.0038, -0.0028]], requires_grad=True)}
Declaration: forward(__torch__.mace.modules.models.ScaleShiftMACE self, Dict(str, Tensor) data, bool training=False, bool compute_force=True, bool compute_virials=False, bool compute_stress=False, bool compute_displacement=False, bool compute_hessian=False) -> Dict(str, Tensor?)
Cast error details: Unable to cast ([tensor([[  0,   0],
        [  0,   0],
        [  0,   0],
        ...,
        [123, 123],
        [123, 123],
        [123, 123]])], [tensor([[  0,   0],
        [  0,   1],
        [  0,   2],
        ...,
        [123, 121],
        [123, 122],
        [123, 123]])], [tensor([40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 22, 22, 22, 22, 22, 22, 22, 22,
        22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24,
        24, 24, 24, 24, 24, 24, 24, 24, 74, 74, 74, 74, 74, 74, 74, 74])], [343], [tensor([[-2.3394e-02,  1.8397e-01,  2.5179e+01],
        [-4.6789e-02,  3.6793e-01,  5.0357e+01],
        [-7.0183e-02,  5.5190e-01,  7.5536e+01],
        ...,
        [ 7.0183e-02, -5.5190e-01, -7.5536e+01],
        [ 4.6789e-02, -3.6793e-01, -5.0357e+01],
        [ 2.3394e-02, -1.8397e-01, -2.5179e+01]], dtype=torch.float64)], [tensor([False,  True,  True,  ...,  True,  True,  True])]) to Tensor