# Classification for Machine Learning

### Ideas:

#### Energy
- run scf calculations on N structures
- train the NN using the descriptor neighbours_spatial_dist_all and the corresponding energy
- get the energy of new structures from the NN using their descriptor neighbours_spatial_dist_all
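
The three steps above can be sketched end to end on synthetic data. This is only a minimal sketch under stated assumptions: the descriptors and energies are random placeholders (not results), the array sizes are arbitrary, and a linear least-squares fit stands in for the NN purely to show how the arrays flow from descriptor to predicted energy.

```python
import numpy as np

rng = np.random.default_rng(0)
n_structures, n_atoms, n_neighbours = 200, 4, 6  # arbitrary sizes for illustration

# Placeholder descriptors: each entry is an atomic number, here Li (3) or Ti (22)
descriptors = rng.choice([3, 22], size=(n_structures, n_atoms, n_neighbours))
# Placeholder energies: a made-up linear function of the descriptor plus noise
energies = -0.01 * descriptors.sum(axis=(1, 2)) + rng.normal(scale=1e-3, size=n_structures)

def to_features(descriptors):
    """Flatten each N_atoms x M_neighbours descriptor and append a bias column."""
    flat = np.asarray(descriptors, dtype=float).reshape(len(descriptors), -1)
    return np.hstack([flat, np.ones((flat.shape[0], 1))])

n_train = 150
X = to_features(descriptors)
# Linear least squares as a stand-in for training the NN on (descriptor, energy) pairs
weights, *_ = np.linalg.lstsq(X[:n_train], energies[:n_train], rcond=None)
# Energies of "new" structures, predicted from their descriptors only
predicted = X[n_train:] @ weights
mae = np.abs(predicted - energies[n_train:]).mean()
print(f"mean absolute error on unseen structures: {mae:.2e}")
```

A real model would replace the least-squares step with the NN and the placeholders with the SCF energies and the neighbours_spatial_dist_all descriptors.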

#### Geometry optimisation:
- optimise the geometry of N structures
- train the NN using the descriptor neighbours_spatial_dist_all and the displacement of the atoms in the structures. This would be an M x 3 array (M = number of atoms) that contains the x, y, z displacements.
- get the displacement of atoms for a new structure
- How do we include the cell? First line added to the neighbours_spatial_dist_all array of the structure (+ padding)?

#### Tests and questions
- How many shells to use?
- Using the non-optimised energy vs using the optimised energy
- For the geometry optimisation part: does starting from an NN-optimised structure make the full convergence quicker?
- If the NN is trained on the energies of the optimised structures, does it predict the energy that the optimised geometry would give?
- Any difference in using fractional or cartesian displacements?
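
For the last question, the two representations are related through the lattice matrix. The sketch below assumes the pymatgen convention that the rows of the lattice matrix are the lattice vectors; the lattice and the displacements are hypothetical placeholder values.

```python
import numpy as np

# Hypothetical lattice (rows are the lattice vectors, in angstrom)
lattice = np.array([[5.0, 0.0, 0.0],
                    [0.0, 5.0, 0.0],
                    [0.0, 0.0, 7.0]])

# Hypothetical cartesian displacements of two atoms (angstrom)
disp_cart = np.array([[0.10, 0.00, -0.07],
                      [0.00, -0.25, 0.14]])

# cart = frac @ lattice, so frac = cart @ inverse(lattice)
disp_frac = disp_cart @ np.linalg.inv(lattice)
# Converting back recovers the cartesian displacements
disp_back = disp_frac @ lattice
print(disp_frac)
```

If the cell relaxes as well, the fractional displacement depends on which lattice (initial or final) is used for the conversion, so the two representations do not carry exactly the same information.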

#### Info for the NN:
- neighbours_spatial_dist_all: descriptor for the structures (list of 2D arrays, N_atoms x M_neighbours, where M_neighbours depends on how many shells are included)
- energy: energy of the optimised structures (list of J energy values, where J is the number of calculations that were run)
- all_disp: total displacement from the initial structures, cell and atoms (list of J 2D arrays, (3 + N_atoms) x 3, where J is the number of calculations that were run; the first 3 rows are the cell vectors and the remaining N_atoms rows are the atom coordinates)
- disp_cell: displacement from the initial cell (list of J 2D arrays, 3 x 3, where J is the number of calculations that were run)
- disp_atoms: displacement of the atoms from the initial structures (list of J 2D arrays, N_atoms x 3, where J is the number of calculations that were run)
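
The relation between the three displacement arrays can be sketched as follows. The values are random placeholders and n_atoms is an arbitrary choice; only the shapes matter here.

```python
import numpy as np

n_atoms = 4  # arbitrary, for illustration only
rng = np.random.default_rng(0)

disp_cell = rng.normal(scale=0.1, size=(3, 3))         # change of the three cell vectors
disp_atoms = rng.normal(scale=0.1, size=(n_atoms, 3))  # x, y, z displacement of each atom

# Cell rows first, atom rows after: shape (3 + N_atoms) x 3
all_disp_single = np.concatenate((disp_cell, disp_atoms), axis=0)

# A flat vector is one possible fixed-length target for the NN
target = all_disp_single.ravel()

# The two parts can be recovered by slicing
cell_back, atoms_back = all_disp_single[:3], all_disp_single[3:]
print(all_disp_single.shape, target.shape)
```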

In [1]:
from crystal_functions.file_readwrite import Crystal_output, write_properties_input, Crystal_input, write_crystal_gui,write_crystal_input, Crystal_density
from crystal_functions.convert import cry_gui2pmg, cry_out2pmg
from crystal_functions.utils import view_pmg
from crystal_functions.calculate import cry_shrink

from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.cif import CifWriter
from ase.visualize import view

from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.core.surface import SlabGenerator
import numpy as np
import pandas as pd
import os
from pathlib import Path
 



### Read the structures from the output files and replace with new_atom
This cell generates the structures list, which contains all 4023 structures as pymatgen Structure objects.

In [2]:
new_atom = 'Li'

cry_output = Crystal_output('data/classification/ml/LTS_CONFCNT_ONLY.out')

cry_output.get_config_analysis()

original_structure = cry_gui2pmg('data/classification/ml/LTS_CONFCNT_ONLY.gui')
structures = []
ti_atoms = []
li_atoms = []
for j,substitutions in enumerate(cry_output.atom_type2):
    new_structure = original_structure.copy()
    for i in substitutions:
        new_structure.replace(i-1,new_atom)
    structures.append(new_structure)
    ti_atoms.append((np.array(cry_output.atom_type1[j])-1).tolist())
    li_atoms.append((np.array(cry_output.atom_type2[j])-1).tolist())

### Shells

In [3]:
# Unique distances from atom 0 to every other atom (rounded to absorb numerical noise): one value per shell
shells = np.unique(np.round(structures[0].distance_matrix[0],decimals=6)).tolist()
shells

[0.0, 2.54275, 3.595992, 4.404172, 5.0855, 5.685762, 6.22844, 7.62825]

## Space group & Irreducible atoms

In [4]:
import sys
sys.path.insert(1,'../../crystal-code-tools/crystal_functions/crystal_functions/')
from calculate import cry_shrink
space_groups = []
n_irr_atoms = []
shrink = []
for i,structure in enumerate(structures):
    symm_analyzer = SpacegroupAnalyzer(structure)
    space_groups.append(symm_analyzer.get_space_group_number())
    symm_structure = symm_analyzer.get_symmetrized_structure()
    n_irr_atoms.append(len(symm_structure.equivalent_indices))
    shrink.append(cry_shrink(symm_analyzer.find_primitive(),spacing=0.2))
space_groups = np.array(space_groups)
#n_irr_atoms = np.array(n_irr_atoms)

In [5]:
space_groups_unique, space_groups_first, space_groups_count = np.unique(space_groups,return_counts=True, return_index=True)

In [6]:
space_groups_dict = {v: np.where(space_groups == v)[0] for v in np.unique(space_groups)}
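
The grouping done in the last two cells can be illustrated on a toy array. The space group numbers below are arbitrary placeholders. Note that np.unique always returns its outputs in the order (values, first indices, counts), whatever the order of the keyword arguments.

```python
import numpy as np

# Placeholder space group numbers, one per structure
space_groups_toy = np.array([225, 12, 225, 1, 12, 225])

unique, first, count = np.unique(space_groups_toy, return_index=True, return_counts=True)
# unique: sorted distinct space groups
# first:  index of the first structure with each space group
# count:  number of structures in each space group

# Map each space group to the indices of all the structures that have it
groups = {int(sg): np.where(space_groups_toy == sg)[0] for sg in unique}
print(unique, first, count)
print(groups)
```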

# Analyse scf results

# The magic happens here

Each element of neighbours_spatial_dist_all (one element per structure) is a 2D array with one row per atom. With max_shell = 3, the columns of each row contain:
- 0 = atomic number of the atom whose neighbours are being analysed
- 1-6 = atomic numbers of the 6 atoms in the first coordination shell
- 7-18 = atomic numbers of the 12 atoms in the second coordination shell
- 19-26 = atomic numbers of the 8 atoms in the third coordination shell

Within each shell, the atomic numbers are ordered following a spatial analysis: they are written starting from the atoms below the central atom and moving upwards.
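
As a small illustration of this layout, one descriptor row can be split back into its shells. The row is the first one of the example output shown further down (27 entries); the shell sizes 1, 6, 12 and 8 are read off that output and are an assumption for any other structure.

```python
import numpy as np

# First row of the example output of neighbours_spatial_dist_all[1]
row = np.array([3] + [16] * 6 + [22, 3, 3, 3, 3, 3, 3, 22, 3, 3, 3, 3] + [16] * 8)

# Central atom, then the first, second and third coordination shells
shell_sizes = [1, 6, 12, 8]
centre, first, second, third = np.split(row, np.cumsum(shell_sizes)[:-1])

# Example query: how many Ti (Z = 22) atoms sit in the second shell of this atom
n_ti_second = int(np.sum(second == 22))
print(centre, first, second, third, n_ti_second)
```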

## Spatial distribution analysis

In [7]:
def cart2sph(x, y, z):
    hxy = np.hypot(x, y)
    r = np.hypot(hxy, z)
    el = np.arctan2(z, hxy)
    az = np.arctan2(y, x)
    if np.around(az,6) ==  np.around(2*np.pi,6) \
    or np.around(az,6) ==  -np.around(2*np.pi,6):
        az = 0.
    if np.around(az,6) < 0.:
        az = np.round(2*np.pi+az,6)
    return [round(az,6), round(el,6), round(r,6)]
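
The ordering used later (elevation first, then azimuth) can be checked on six octahedral neighbours. In this sketch cart2sph_sketch is a condensed stand-in for the function above (it wraps the azimuth with a modulo), and the neighbour positions are hypothetical unit vectors.

```python
import numpy as np

def cart2sph_sketch(x, y, z):
    """Return [azimuth in [0, 2*pi), elevation, radius], rounded to 6 decimals."""
    hxy = np.hypot(x, y)
    r = np.hypot(hxy, z)
    el = np.arctan2(z, hxy)
    az = np.arctan2(y, x) % (2 * np.pi)
    return [round(float(az), 6), round(float(el), 6), round(float(r), 6)]

# Six neighbours at unit distance along +x, -x, +y, -y, +z, -z
neighbours = np.array([[1, 0, 0], [-1, 0, 0], [0, 1, 0],
                       [0, -1, 0], [0, 0, 1], [0, 0, -1]], dtype=float)
sph = np.array([cart2sph_sketch(*v) for v in neighbours])

# Same key as in the descriptor loop: elevation weighted by 10, plus azimuth
order = np.argsort(sph[:, 1] * 10 + sph[:, 0])
print(order)  # bottom atom first, then the equatorial atoms by azimuth, top atom last
```

Since the azimuth spans almost 2π (about 6.28), the weight of 10 only separates elevations that differ by more than roughly 0.63 rad. np.lexsort((sph[:, 0], sph[:, 1])) would be a strict two-key alternative.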

## THIS ONE WORKS (it scales better when a small number of shells is included)

In [8]:
max_shell = 3
centered_sph_coords = []
centered_sph_coords_structure = []
neighbours_spatial_dist = []
neighbours_spatial_dist_all = []
import time
time0 = time.time()
for k,structure in enumerate(structures[0:100]): #TMP only calculate 100 structures
    time0 = time.time()
    neighbours_spatial_dist = []
    
    for j in range(structure.num_sites):
        centered_sph_coords = []
        neighbours_spatial_dist_atom = []
        
        for n in range(max_shell+1):
            atom_indices = np.where(np.round(structure.distance_matrix[j],5) == np.round(shells[n],5))[0].tolist()
            centered_sph_coords = []
            for i in atom_indices:

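                # Periodic image of atom i closest to atom j, in units of the lattice vectors
                translation_vector = structure.sites[j].distance_and_image(structure.sites[i])[1]
                # Convert the image with the full lattice matrix (scaling by lattice.abc is only valid when the lattice matrix is diagonal)
                new_cart_coords = structure.cart_coords[i]+structure.lattice.get_cartesian_coords(translation_vector)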
                centered_cart_coords = new_cart_coords-structure.cart_coords[j] 

                centered_sph_coords.append(cart2sph(centered_cart_coords[0],centered_cart_coords[1],centered_cart_coords[2]))        

            spatial_distribution = np.argsort(np.array(centered_sph_coords)[:,1]*10 +\
                                              np.array(centered_sph_coords)[:,0])


            neighbours_spatial_dist_atom.extend((np.array(structure.atomic_numbers)[np.array(atom_indices)[spatial_distribution]]).tolist())
        neighbours_spatial_dist.append(neighbours_spatial_dist_atom)
    #print(time.time()-time0)
    neighbours_spatial_dist_all.append(neighbours_spatial_dist) 
        

In [9]:
neighbours_spatial_dist_all[1]

[[3,
  16,
  16,
  16,
  16,
  16,
  16,
  22,
  3,
  3,
  3,
  3,
  3,
  3,
  22,
  3,
  3,
  3,
  3,
  16,
  16,
  16,
  16,
  16,
  16,
  16,
  16],
 [3,
  16,
  16,
  16,
  16,
  16,
  16,
  3,
  22,
  3,
  3,
  22,
  3,
  3,
  3,
  3,
  22,
  3,
  3,
  16,
  16,
  16,
  16,
  16,
  16,
  16,
  16],
 [3,
  16,
  16,
  16,
  16,
  16,
  16,
  3,
  3,
  3,
  3,
  3,
  3,
  22,
  3,
  22,
  22,
  3,
  3,
  16,
  16,
  16,
  16,
  16,
  16,
  16,
  16],
 [3,
  16,
  16,
  16,
  16,
  16,
  16,
  22,
  3,
  3,
  22,
  22,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  16,
  16,
  16,
  16,
  16,
  16,
  16,
  16],
 [3,
  16,
  16,
  16,
  16,
  16,
  16,
  3,
  22,
  22,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  22,
  3,
  16,
  16,
  16,
  16,
  16,
  16,
  16,
  16],
 [3,
  16,
  16,
  16,
  16,
  16,
  16,
  3,
  22,
  3,
  3,
  3,
  3,
  22,
  22,
  3,
  3,
  3,
  3,
  16,
  16,
  16,
  16,
  16,
  16,
  16,
  16],
 [3,
  16,
  16,
  16,
  16,
  16,
  16,
  22,
  22,
  3,
  3,
  3,
  3,
  3,
  3,


### Group same structures (none for LTS)

In [11]:
neighbours_spatial_dist_all_sorted = []

for k,structure in enumerate(structures[0:100]):
    # Build one integer key per atom by concatenating the entries of its descriptor row
    sorted_atoms = []
    for i in range(len(neighbours_spatial_dist_all[k])):
        sorted_atoms.append(int(''.join([str(x) for x in neighbours_spatial_dist_all[k][i]])))
    # Reorder the rows of structure k by increasing key, so the result does not depend on the atom order
    neighbours_spatial_dist_all_sorted.append((np.array(neighbours_spatial_dist_all[k])[np.argsort(np.array(sorted_atoms))]).tolist())

In [12]:
neighbours_spatial_dist_id = []

# Only the structures processed above have a descriptor, so loop over those
# (looping over all of structures raises IndexError: list index out of range)
for k in range(len(neighbours_spatial_dist_all)):
    sorted_atoms = []
    for i in range(len(neighbours_spatial_dist_all[k])):
        sorted_atoms.append((''.join([str(x) for x in neighbours_spatial_dist_all[k][i]])))
    neighbours_spatial_dist_id.append(int(''.join(sorted_atoms)))
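
An alternative way to group identical descriptors, sketched here as an assumption rather than as the method used in this notebook, is to use a tuple of sorted rows as the key. Tuples avoid the ambiguity of concatenating digits (3 followed by 16 gives the same string as 31 followed by 6), and sorting the rows makes the key independent of the atom order. The three descriptors below are tiny hypothetical examples, and descriptor_key is a hypothetical helper.

```python
def descriptor_key(descriptor):
    """Order-independent, hashable key for a descriptor (list of rows of atomic numbers)."""
    rows = sorted(tuple(int(z) for z in row) for row in descriptor)
    return tuple(rows)

# Hypothetical descriptors: b has the same rows as a in a different atom order
a = [[3, 16, 22], [22, 16, 3]]
b = [[22, 16, 3], [3, 16, 22]]
c = [[3, 16, 3], [22, 16, 3]]

# Collect the indices of the descriptors that share a key
groups = {}
for index, descriptor in enumerate([a, b, c]):
    groups.setdefault(descriptor_key(descriptor), []).append(index)

duplicates = [members for members in groups.values() if len(members) > 1]
print(duplicates)
```

Two structures with the same key have the same descriptor up to atom ordering. Whether that is enough to call them the same structure depends on how many shells the descriptor includes.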

# Write inputs

## One per space group in reverse

In [None]:
crystal_input = Crystal_input('data/classification/ml/pseudo_input.d12')
with open('data/classification/ml/qsub_input.qsub','r') as file:
    qsub_input = file.readlines()

    
for structure in reversed(np.unique(space_groups,return_counts=True, return_index=True)[1]):
    
    write_crystal_gui('data/classification/ml/pseudo/LTS_%s_PBE.gui'%structure,structures[structure],pseudo_atoms=[22])
    write_crystal_input('data/classification/ml/pseudo/LTS_%s_PBE.d12'%structure,crystal_input=crystal_input)
    queue = qsub_input.copy()
    queue[1] = '#PBS -N %s\n'%structure
    queue.append('/rds/general/user/gmallia/home/CRYSTAL17_cx1/v2.2gnu/runcryP LTS_%s_PBE'%structure)
    with open('data/classification/ml/pseudo/%s.qsub'%structure, 'w') as file:
        file.writelines(queue)
    #print('qsub %s.qsub'%structure)
    

## All space groups organised in folders

In [None]:
crystal_input = Crystal_input('data/classification/ml/crystal_input.d12')

with open('data/classification/ml/qsub_input.qsub_SAVE','r') as file:
    qsub_input = file.readlines()

for space_group in reversed(space_groups_unique[-10:]):
    #print(len(space_groups_dict[space_group]))
    queue = qsub_input.copy()
    queue[1] = '#PBS -N %s\n'%space_group

    Path('data/classification/ml/sg_%s'%space_group).mkdir(parents=True, exist_ok=True)
    for structures_n in space_groups_dict[space_group]:
        crystal_input.scf_block[0][1] = '%s %s \n'%(shrink[structures_n], shrink[structures_n]*2)
        write_crystal_gui('data/classification/ml/sg_%s/LTS_%s_PBE.gui'%(space_group,structures_n),structures[structures_n])
        write_crystal_input('data/classification/ml/sg_%s/LTS_%s_PBE.d12'%(space_group,structures_n),crystal_input=crystal_input)
        
        queue.append('/rds/general/user/gmallia/home/CRYSTAL17_cx1/v2.2gnu/runcryP LTS_%s_PBE\n'%structures_n)
    with open('data/classification/ml/sg_%s/sg_%s.qsub'%(space_group,space_group), 'w') as file:
        file.writelines(queue)
    '''print('cd sg_%s'%space_group)
    print('qsub sg_%s.qsub'%space_group)
    print('cd ..')'''

## Analyse the structures

In [13]:
import sys
sys.path.insert(1,'../../crystal-code-tools/crystal_functions/crystal_functions/')
from convert import cry_gui2pmg

initial_positions = structures[0].cart_coords
initial_cell = structures[0].lattice.matrix
structure_id = []
not_converged = []
energy = []
all_disp = []
output_folder = "data/classification/ml/outputs_optgeom/"
for file in os.listdir(output_folder):
    if file.endswith(".out"):
        structure_n = int(file.split('_')[1])
        output = Crystal_output(os.path.join(output_folder, file))
        if output.converged:
            structure_id.append(structure_n)

            energy.append(output.get_final_energy())
            # Last geometry written during the optimisation (file names sort in step order)
            last_geom = sorted(os.listdir(os.path.join(output_folder,"LTS_%s_PBE.optstory"%structure_n)))[-1]
            pmg_structure = cry_gui2pmg(os.path.join(output_folder,"LTS_%s_PBE.optstory/%s"%(structure_n,last_geom)))
            # Displacements with respect to the initial (unrelaxed) positions and cell
            disp_atoms = pmg_structure.cart_coords - initial_positions
            disp_cell = pmg_structure.lattice.matrix - initial_cell
            # Cell rows first, atom rows after: (3 + N_atoms) x 3
            all_disp.append(np.concatenate((disp_cell,disp_atoms),axis=0))
        else:
            not_converged.append(structure_n)


In [14]:
all_disp

[array([[ 0.00000000e+00, -1.61139691e-01, -1.61139691e-01],
        [-1.61139691e-01,  0.00000000e+00, -1.61139691e-01],
        [-1.61139691e-01, -1.61139691e-01,  0.00000000e+00],
        [-8.54427185e-02,  2.33078659e-18,  2.33078659e-18],
        [ 9.78926259e-02, -2.05319086e-01, -6.74323956e-18],
        [ 5.37132302e-02,  1.39155949e-01,  2.87965275e-18],
        [ 9.78926259e-02, -6.74323956e-18, -2.05319086e-01],
        [-1.07426460e-01, -2.05319086e-01, -2.05319086e-01],
        [-8.04660200e-02,  1.34179250e-01, -1.34179250e-01],
        [ 5.37132302e-02,  2.87965275e-18,  1.39155949e-01],
        [-8.04660200e-02, -1.34179250e-01,  1.34179250e-01],
        [ 1.87892480e-01,  1.34179250e-01,  1.34179250e-01],
        [-8.04660200e-02, -1.34179250e-01, -1.34179250e-01],
        [ 2.67527897e-02, -2.69604405e-02, -1.34179250e-01],
        [ 5.37132302e-02, -1.51382065e-17, -1.39155949e-01],
        [ 2.67527897e-02, -1.34179250e-01, -2.69604405e-02],
        [-1.07426460e-01

# THE END

### SHRINK test

In [None]:
struct = cry_gui2pmg('data/classification/ml/optc021')
inp = Crystal_input('data/classification/ml/crystal_input.d12')

with open('data/classification/ml/qsub_input.qsub_SAVE','r') as file:
    qsub_input = file.readlines()

energy = []
for i in range(4,13):

    energy.append(Crystal_output('data/classification/ml/test_shrink/LTS_%s.out'%i).get_final_energy())

In [None]:
# Energy change between consecutive test_shrink runs (LTS_4.out to LTS_12.out)
for j in range(len(energy)-1):
    print(energy[j]-energy[j+1])
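
The same differences can be used to pick a converged value automatically. This is a minimal sketch that assumes the file index (4 to 12) is the shrinking factor. The energies are placeholder numbers chosen for illustration, not results, and the tolerance is an arbitrary choice.

```python
import numpy as np

shrink_values = np.arange(4, 13)
# Placeholder energies, one per shrink value (NOT computed results)
energies = np.array([-1.000, -1.050, -1.070, -1.0790, -1.07990,
                     -1.07995, -1.07996, -1.07996, -1.07996])

tolerance = 1e-4  # arbitrary convergence threshold, same units as the energies
# Change in energy when going from one shrink value to the next
changes = np.abs(np.diff(energies))
below = changes < tolerance

# First shrink value whose energy changes by less than the tolerance at the next step
converged_shrink = int(shrink_values[np.argmax(below)]) if below.any() else None
print(converged_shrink)
```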

# TESTING

## THIS ONE WORKS (it scales better when a large number of shells is included)
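It doesn't work for the primitive cell. A likely cause is that the periodic image is scaled by structure.lattice.abc, which gives the correct cartesian translation only when the lattice matrix is diagonal.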

In [None]:
max_shell = 2
centered_sph_coords = []
centered_sph_coords_structure = []
neighbours_spatial_dist = []
neighbours_spatial_dist_all = []
import time
time0 = time.time()
for k,structure in enumerate(structures[100:105]):
    print('STRUCTURE',structure)
    time0 = time.time()
    neighbours_spatial_dist = []
    for j in range(structure.num_sites):
        centered_sph_coords = []
        for i in range(structure.num_sites):  
            print(i,j)
            translation_vector = structure.sites[j].distance_and_image(structure.sites[i])[1]
            new_cart_coords = structure.cart_coords[i]+(translation_vector*structure.lattice.abc)
            #print(translation_vector)
            #print(new_cart_coords,structure.cart_coords[i])
            centered_cart_coords = new_cart_coords-structure.cart_coords[j] 
            centered_sph_coords.append(cart2sph(centered_cart_coords[0],centered_cart_coords[1],centered_cart_coords[2]))        
            #print(cart2sph(centered_cart_coords[0],centered_cart_coords[1],centered_cart_coords[2]))
        #centered_sph_coords_all.append(centered_sph_coords)
        centered_sph_coords = np.array(centered_sph_coords)
        #print(centered_sph_coords)

        atoms_shell = []
        for unique in shells:
            atoms_shell.append(np.where(np.round(centered_sph_coords[:,2],5) == np.round(unique,5))[0].tolist())
            
        #print(atoms_shell)
        neighbours_spatial_dist_atom = []
        for shell in range(max_shell+1):
            #a = np.array(centered_sph_coords)[atoms_shell[shell]]
            spatial_distribution = np.argsort(centered_sph_coords[atoms_shell[shell]][:,1]*10 + centered_sph_coords[atoms_shell[shell]][:,0])
            #print(np.array(structure.atomic_numbers)[np.array(atoms_shell[shell])[spatial_distribution]])
            #print(np.array(atoms_shell[shell])[spatial_distribution])
            
            neighbours_spatial_dist_atom.extend((np.array(structure.atomic_numbers)[np.array(atoms_shell[shell])[spatial_distribution]]).tolist())
            #THIS IS THE CONTROL LINE
            #print(centered_sph_coords[np.array(atoms_shell[shell])[spatial_distribution]])
        neighbours_spatial_dist.append(neighbours_spatial_dist_atom)
    neighbours_spatial_dist_all.append(neighbours_spatial_dist)
    print(time.time()-time0)
    #centered_sph_coords_structure = np.array(centered_sph_coords_all)
