- This script cleans the OC22-2M dataset; the resulting "no-noise" dataset has been uploaded to Zenodo.

- We first use a model pretrained with DeePMD-kit to predict the energies and forces of the OC22 dataset, running `dp test --detail` to obtain the predicted energies and forces. The output file is named "OC22_test.out".

In [None]:
# Per-system error statistics parsed from the `dp test` log:
#   sys[system_name] = {'energy': <energy metric>, 'force': <force metric>}
# (presumably energy RMSE/Natoms and force RMSE, assuming the standard
# `dp test` summary layout -- TODO confirm).
# NOTE: the name `sys` shadows the stdlib module; it is kept because later
# cells read it.
sys = dict()
energy = []  # energy metric of every parsed system, in log order
force = []   # force metric of every parsed system, in log order
with open("OC22_test.out", 'r', encoding='utf-8') as file:
    output = file.readlines()

for i, line in enumerate(output):
    if 'testing system' not in line:
        continue
    # The "test data" line follows the "testing system" header either on the
    # next line or three lines down; the two metric lines sit 4 and 6 lines
    # after it. Both layouts are checked, with a bounds guard so a header at
    # the very end of the log cannot raise IndexError.
    for gap in (1, 3):
        if i + gap < len(output) and 'test data' in output[i + gap]:
            # float() instead of eval(): the log is data and must never be
            # executed as code.
            energy_value = float(output[i + gap + 4].split()[7])
            force_value = float(output[i + gap + 6].split()[7])
            sys[line.split()[8]] = {'energy': energy_value, 'force': force_value}
            energy.append(energy_value)
            force.append(force_value)

In [None]:
# Pick the systems with the largest prediction errors: the 500 worst by
# energy and the 20 worst by force. These cut-offs must be adjusted for your
# own dataset. Alternatively, select by absolute baselines, e.g. energy
# RMSE/atom larger than 0.05 eV or force RMSE larger than 0.5.

# System names ordered from largest to smallest energy error; keep 500.
energy_set = sorted(sys, key=lambda name: sys[name]['energy'], reverse=True)[:500]

# System names ordered from largest to smallest force error; keep 20.
force_set = sorted(sys, key=lambda name: sys[name]['force'], reverse=True)[:20]

In [None]:
import numpy as np
import dpdata as dp


def find_nth_element_e(list_of_lists, loc, file):
    """Locate the frame whose reference energy equals row ``loc`` of an energy table.

    Args:
        list_of_lists: Container of systems to search, e.g. a
            ``dp.MultiSystems``. Each item is iterated frame by frame, and
            each frame is indexed with ``['energies']`` (expected to hold a
            single value, i.e. a one-frame sub-system -- TODO confirm).
        loc (int): Row index into ``file``.
        file (numpy.ndarray): Table loaded from a ``*.e.out`` file. Column 0
            is compared (exactly) with the dataset energies. A 1-D array, which
            is what ``np.loadtxt`` returns for a single-row file, is treated
            as one row.

    Returns:
        tuple[int, int] | None: ``(index_sublist, index_element)`` of the first
        matching frame, or ``None`` when no frame matches.
    """
    # atleast_2d: a one-frame .e.out loads as shape (2,), where `file[loc, 0]`
    # would raise IndexError.
    target = np.atleast_2d(file)[loc, 0]
    for index_sublist, sublist in enumerate(list_of_lists):
        for index_element, element in enumerate(sublist):
            if element['energies'] == target:
                return (index_sublist, index_element)
    return None

def find_nth_element_f(list_of_lists, loc, file):
    """Locate the frame containing the reference force vector in row ``loc`` of a force table.

    Args:
        list_of_lists: Container of systems to search, e.g. a
            ``dp.MultiSystems``. Each item is iterated frame by frame, and
            each frame's ``['forces'][0]`` is read as its per-atom force rows.
        loc (int): Row index into ``file``.
        file (numpy.ndarray): Table loaded from a ``*.f.out`` file. Its first
            three columns are the reference force of one atom. A 1-D array
            (single-row file) is treated as one row.

    Returns:
        tuple[int, int] | None: ``(index_sublist, index_element)`` of the first
        frame in which some atom's force equals that 3-vector exactly, or
        ``None`` when no frame does.
    """
    # Compare whole (fx, fy, fz) rows. The former `target in forces` test is
    # NumPy's `(forces == target).any()`, which accepted a frame as soon as ANY
    # single component coincided and could therefore pick the wrong frame.
    target = np.atleast_2d(file)[loc, :3]
    for index_sublist, sublist in enumerate(list_of_lists):
        for index_element, element in enumerate(sublist):
            frame_forces = np.reshape(element['forces'][0], (-1, 3))
            if np.all(frame_forces == target, axis=1).any():
                return (index_sublist, index_element)
    return None
            
# all_set holds every system that is among the worst offenders in energy,
# force, or both. dict.fromkeys de-duplicates while keeping a deterministic
# order (energy_set first, then force-only systems). Set iteration order of
# strings changes between runs, and clean_dataset(i) pairs all_set[i] with
# the files yanzhen_{i+1}.e.out / yanzhen_{i+1}.f.out.
# NOTE(review): confirm this order matches how the yanzhen_* files were generated.
all_set = list(dict.fromkeys(energy_set + force_set))


def _remove_frames(system, rows, finder, table, name):
    """Delete from ``system`` the frame that ``finder`` locates for each row index in ``rows``.

    Rows that cannot be mapped back to a frame are reported and skipped.
    """
    for row in rows:
        try:
            location = finder(system, row, table)
        except (IndexError, ValueError):
            location = None
        if location is None:
            print(f"system {name} can't be search, please check the systems")
            continue
        m, n = location
        if len(system[m]) == 1:
            # Last frame of this sub-system: drop the whole sub-system from the
            # MultiSystems dict, which is keyed by formula.
            system.systems.pop(system[m].formula)
        else:
            # `dell` is a frame-deletion method the authors patched into
            # dpdata's System; it is not part of upstream dpdata.
            system[m].dell(n)


def clean_dataset(i, *, energy_threshold=5.0, force_threshold=1.0):
    """Remove high-error frames from system ``all_set[i]`` and overwrite it on disk.

    The per-frame comparison tables ``yanzhen_{i+1}.e.out`` and
    ``yanzhen_{i+1}.f.out`` are read for the system. Which ones are used
    depends on the module-level ``energy_set`` / ``force_set``. A system found
    in only one of them is cleaned on that criterion alone; otherwise both
    criteria are applied, energy first.

    * Energy outliers: rows where ``|col1 - col0| > energy_threshold``. Column
      0 is matched against the dataset energies.
    * Force outliers: atom rows where every component of
      ``|cols 0-2 - cols 3-5|`` exceeds ``force_threshold``. Columns 0-2 are
      matched against the dataset forces; columns 3-5 are the prediction.

    Each outlier row is mapped back to a frame with ``find_nth_element_e`` /
    ``find_nth_element_f`` and that frame is deleted. The result is written
    back to ``all_set[i]`` in ``deepmd/npy/mixed`` format.

    Args:
        i (int): 0-based index into ``all_set``; it also selects the
            ``yanzhen_{i+1}`` tables.
        energy_threshold (float): Absolute energy-error cut-off (default 5.0).
        force_threshold (float): Per-component force-error cut-off (default 1.0).
    """
    name = all_set[i]
    check_energy = name in energy_set or name not in force_set
    check_force = name in force_set or name not in energy_set
    prefix = f'yanzhen_{i + 1}'

    e_table = f_table = None
    energy_rows = force_rows = ()
    if check_energy:
        print(name, i)
        # atleast_2d: a single-frame .e.out loads as a 1-D array. The old
        # `except Exception as e` fallback indexed the exception object and
        # also unbound the table itself, so that case always crashed.
        e_table = np.atleast_2d(np.loadtxt(f'{prefix}.e.out'))
        energy_rows = np.flatnonzero(
            np.abs(e_table[:, 1] - e_table[:, 0]) > energy_threshold)
    if check_force:
        f_table = np.atleast_2d(np.loadtxt(f'{prefix}.f.out'))
        force_error = np.abs(f_table[:, :3] - f_table[:, 3:6])
        force_rows = np.flatnonzero(np.all(force_error > force_threshold, axis=1))

    # Always load the system being cleaned. The old combined branch loaded
    # energy_set[i] and then overwrote all_set[i] with it.
    system = dp.MultiSystems().load_systems_from_file(file_name=name, fmt="deepmd/npy/mixed")
    if check_energy:
        _remove_frames(system, energy_rows, find_nth_element_e, e_table, name)
    if check_force:
        _remove_frames(system, force_rows, find_nth_element_f, f_table, name)
    print(system)
    dp.MultiSystems(system).to_deepmd_npy_mixed(name)

# Clean every selected system in turn; the position also selects the
# matching yanzhen_{position+1} comparison tables.
for position, _ in enumerate(all_set):
    clean_dataset(position)