# Analyze DPA labels of an MD simulation with Ovito

In [2]:
import numpy as np
import ase.io as ase
from torch_nl import compute_neighborlist, ase2data

## Add DPA labels to XYZ file to visualize in Ovito

In [7]:
## add dpa labels to xyz file ##
name = "results/dpa_inter/20ns_interval_100ps_64_ACE_torch_Z_3.5_isHalo_False_maxk_50_rmax_6.0_lmax_3_nmax_2.npy"

# trajectory dimensions: frames, atoms that carry a DPA label, atoms per frame
N_FRAMES = 200
N_LABELLED = 16080
N_ATOMS_TOTAL = 21120

labels_dpa = np.load(name).reshape(N_FRAMES, N_LABELLED)

# every atom past the first N_LABELLED (non-Li environments) is marked -1
labels = np.full((N_FRAMES, N_ATOMS_TOTAL), -1.0)
labels[:, :N_LABELLED] = labels_dpa

# load every frame of the trajectory
atoms = ase.read('20ns_interval_100ps.xyz', ':')

# attach the per-atom labels and a 0..N-1 atom index as custom properties
atom_ids = np.arange(0, N_ATOMS_TOTAL, 1)
for frame_no, frame in enumerate(atoms):
    frame.set_array('labels', labels[frame_no])
    frame.set_array('index', atom_ids)

# write the labelled trajectory next to the other visualisation files
label_stem = name.split('/')[-1].split('.npy')[0]
ase.write(f"results/visual/{label_stem}.xyz", atoms)

## Code for Figure 4

In [None]:
def add_column_xyz(input_file, output_file, arr, col_name, index='199:200'):
    """
    Add a per-atom column of values to selected frames of an XYZ file.

    The frames selected by ``index`` are read from ``input_file``. The i-th
    selected frame receives ``arr[i]`` as the custom property ``col_name``,
    plus an ``index`` property holding 0..N-1, where N is that frame's atom
    count. Only the selected frames are written to ``output_file``.

    Parameters:
        input_file (str): Path to the XYZ file to add the column to.
        output_file (str): Path of the new XYZ file with the added column.
        arr (numpy.ndarray): Values to add; ``arr[i]`` must provide one value
            per atom of the i-th *selected* frame (not the i-th frame of the file).
        col_name (str): Name of the new column.
        index (str | int | slice): Frame selection passed to ``ase.io.read``.
            Defaults to ``'199:200'`` (only the frame at index 199), which was
            previously hard-coded.
    """
    frames = ase.read(input_file, index)
    # ase.io.read returns a single Atoms object for an integer index; wrap it
    # so that enumerate() walks frames rather than the atoms of one frame
    if not isinstance(frames, list):
        frames = [frames]

    for i, frame in enumerate(frames):
        frame.set_array(col_name, arr[i])
        # size the index column from the frame itself rather than assuming 21120 atoms
        frame.set_array('index', np.arange(len(frame)))

    # save the modified frames with the new column(s)
    ase.write(output_file, frames)

def define_clusters(labels, input_xyz_file, output_xyz_file,  mapping, start, stop, interval, relevant_labels):
    """
    Flag the neighbours of atoms with selected labels and store the flags as XYZ columns.

    For every processed frame ``idx`` in ``range(start, stop, interval)`` and every
    label ``L`` in ``relevant_labels``, column ``"is<L>"`` is 1 for each atom that
    appears as a neighbour (``mapping[1]``) of an atom labelled ``L``
    (``mapping[0]``), and 0 otherwise. An ``index`` column (0..N-1) is also added.

    All processed frames are read from ``input_xyz_file`` once, before anything is
    written, and written to ``output_xyz_file`` once at the end. Input and output
    may therefore be the same path. Previously the file was rewritten inside the
    frame loop, which truncated it after the first iteration.

    Parameters:
        labels (array-like, frames x atoms): Label of every atom; row k belongs to
            the k-th processed frame.
        input_xyz_file (str): Path to the .xyz file whose frames receive the columns.
        output_xyz_file (str): Path of the .xyz file to write (original attributes
            plus the new columns).
        mapping (array, 2 x n_pairs): Neighbour pairs (centre row 0, neighbour row 1).
            Indices are expected to be offset by ``frame_index * n_atoms``, as the
            offset arithmetic below assumes.
        start (int): First frame index of the XYZ file to process.
        stop (int): Stop (exclusive) frame index of the XYZ file.
        interval (int): Step between processed frames.
        relevant_labels (sequence): Labels whose neighbours are flagged.
    """
    labels = np.asarray(labels)
    n_atoms = labels.shape[1]  # atoms per frame (21120 for this trajectory)
    frame_ids = np.arange(start, stop, interval)

    # read everything before writing so output_xyz_file == input_xyz_file is safe
    frames = ase.read(input_xyz_file, f"{start}:{stop}:{interval}")
    if not isinstance(frames, list):
        frames = [frames]

    processed = []
    for label_frame, idx, frame in zip(labels, frame_ids, frames):
        print(idx)
        offset = idx * n_atoms
        for relevant_label in relevant_labels:
            # global indices of the atoms carrying this label in this frame
            centres = np.where(label_frame == relevant_label)[0] + offset
            # neighbours of those atoms, mapped back to per-frame atom indices
            neighbour_ids = mapping[1][np.isin(mapping[0], centres)] - offset
            flags = np.zeros(n_atoms)
            flags[neighbour_ids] = 1
            # flags of frame idx go onto frame idx itself (no frame/flag mismatch)
            frame.set_array("is" + str(relevant_label), flags)
        frame.set_array('index', np.arange(len(frame)))
        processed.append(frame)

    ase.write(output_xyz_file, processed)

In [None]:
# flag the neighbours of the clusters of interest (currently only cluster 351)
visual_stem = name.split('/')[-1].split('.npy')[0]
input_xyz_file = f"results/visual/{visual_stem}.xyz"
output_xyz_file = input_xyz_file  # annotate the labelled trajectory in place

relevant_labels = [351]
mapping = np.load("results/dpa_inter/21120_atoms_rmax_6.0_lmax_3_nmax_2.npy")

define_clusters(labels, input_xyz_file, output_xyz_file, mapping,
                start=0, stop=200, interval=1,
                relevant_labels=relevant_labels)

In [13]:
# find the neighbours of a chosen atom in frame index 199 (the 200th frame)
frames = ase.read('20ns_interval_100ps.xyz', '199:200')
pos, cell, pbc, batch, n_atoms = ase2data(frames)

cutoff = 6  # neighbour cutoff radius, in the units of the xyz positions

mapping, batch_mapping, shifts_idx = compute_neighborlist(
    cutoff, pos, cell, pbc, batch, self_interaction=False)

# print neighbour indices to select in Ovito
# candidate particles used for the figure:
# elec: 11827
# LiCl: 3779
particle = 3779
neighs = mapping[1][np.where(mapping[0] == particle)[0]]

# Build one Ovito expression that selects every neighbour. join() avoids the
# dangling trailing "||" and the quadratic += loop, and the variable is not
# named `str`, so the builtin stays usable in other cells (e.g. define_clusters).
selection = " || ".join(f"index == {neigh}" for neigh in neighs)
print(selection)