In [1]:
import torch
import torch_scatter

from data_helpers import DataPeriodicNeighbors
from e3nn.nn.models.gate_points_2101 import Convolution, Network
from e3nn.o3 import Irreps

from pymatgen.ext.matproj import MPRester
import pymatgen.analysis.magnetism.analyzer as pg
import numpy as np
import pickle

import time

In [2]:
order_list_mp = []
formula_list_mp = []
sites_list = []
structures_list = []
y_values = []
id_list = []
order_encode = {"NM": 0, "AFM": 1, "FM": 2, "FiM": 2}

magnetic_atoms = ['Ga', 'Tm', 'Y', 'Dy', 'Nb', 'Pu', 'Th', 'Er', 'U',
                  'Cr', 'Sc', 'Pr', 'Re', 'Ni', 'Np', 'Nd', 'Yb', 'Ce',
                  'Ti', 'Mo', 'Cu', 'Fe', 'Sm', 'Gd', 'V', 'Co', 'Eu',
                  'Ho', 'Mn', 'Os', 'Tb', 'Ir', 'Pt', 'Rh', 'Ru']

m = MPRester(endpoint=None, include_user_agent=True)
structures = m.query(criteria={"elements": {"$in": magnetic_atoms}, 'blessed_tasks.GGA+U Static': {
                     '$exists': True}}, properties=["material_id", "pretty_formula", "structure", "blessed_tasks", "nsites", "magnetism"])

100%|██████████| 31739/31739 [04:10<00:00, 126.71it/s]


In [17]:
structures[0]['structure'].site_properties

{'magmom': [6.426,
  6.426,
  4.225,
  4.224,
  0.057,
  0.073,
  0.057,
  0.073,
  0.083,
  0.083]}

In [3]:
structures_copy = structures.copy()
for struc in structures_copy:
    if len(struc["structure"]) > 250:
        structures.remove(struc)
        print("MP Structure Deleted")

# %%
id_NM = []
id_FM = []
id_AFM = []
for i in range(len(structures)):
    if structures[i]["ordering"] == 'NM':
        id_NM.append(i)
    elif structures[i]["ordering"] == 'AFM':
        id_AFM.append(i)
    elif structures[i]["ordering"] == 'FM' or structures[i]["ordering"] == 'FiM':
        id_FM.append(i)
np.random.shuffle(id_FM)
np.random.shuffle(id_NM)
np.random.shuffle(id_AFM)
id_AFM, id_AFM_to_delete = np.split(id_AFM, [int(len(id_AFM))])
id_NM, id_NM_to_delete = np.split(id_NM, [int(1.2*len(id_AFM))])
id_FM, id_FM_to_delete = np.split(id_FM, [int(1.2*len(id_AFM))])

structures_mp = [structures[i] for i in id_NM] + [structures[j] for j in id_FM] + [structures[k] for k in id_AFM]
np.random.shuffle(structures_mp)


for structure in structures_mp:
    order_list_mp.append(structure["ordering"])
    structures_list.append(structure["structure"])
    formula_list_mp.append(structure["pretty_formula"])
    id_list.append(structure["material_id"])
    sites_list.append(structure["nsites"])


y_values = [order_encode[order] for order in order_list_mp]

elements = pickle.load(open('element_info.p', 'rb'))

In [13]:
structures[1]

{'material_id': 'mp-1013796',
 'pretty_formula': 'Li3Ni2(GeO4)3',
 'structure': Structure Summary
 Lattice
     abc : 10.263495690955983 10.263495690955983 10.263495690955983
  angles : 109.47122063449069 109.47122063449069 109.47122063449069
  volume : 832.2695819600563
       A : -5.925632 5.925632 5.925632
       B : 5.925632 -5.925632 5.925632
       C : 5.925632 5.925632 -5.925632
 PeriodicSite: Li (4.4442, 5.9256, 2.9628) [0.7500, 0.6250, 0.8750]
 PeriodicSite: Li (7.4070, 0.0000, 2.9628) [0.2500, 0.8750, 0.6250]
 PeriodicSite: Li (2.9628, 7.4070, 0.0000) [0.6250, 0.2500, 0.8750]
 PeriodicSite: Li (2.9628, 4.4442, 5.9256) [0.8750, 0.7500, 0.6250]
 PeriodicSite: Li (0.0000, 2.9628, 7.4070) [0.8750, 0.6250, 0.2500]
 PeriodicSite: Li (5.9256, 2.9628, 4.4442) [0.6250, 0.8750, 0.7500]
 PeriodicSite: Li (1.4814, 0.0000, 2.9628) [0.2500, 0.3750, 0.1250]
 PeriodicSite: Li (-1.4814, 5.9256, 2.9628) [0.7500, 0.1250, 0.3750]
 PeriodicSite: Li (2.9628, -1.4814, 5.9256) [0.3750, 0.7500, 0.125

In [4]:
torch.set_default_dtype(torch.float64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {'len_embed_feat': 64,
          'num_channel_irrep': 32,
          'num_e3nn_layer': 2,
          'max_radius': 5,
          'num_basis': 10,
          'adamw_lr': 0.005,
          'adamw_wd': 0.03,
          'radial_layers': 3
          }

# Used for debugging
identification_tag = "1:1:1.1 Relu wd:0.03 4 Linear"
cost_multiplier = 1.0

print('Length of embedding feature vector: {:3d} \n'.format(params.get('len_embed_feat')) +
      'Number of channels per irreducible representation: {:3d} \n'.format(params.get('num_channel_irrep')) +
      'Number of tensor field convolution layers: {:3d} \n'.format(params.get('num_e3nn_layer')) +
      'Maximum radius: {:3.1f} \n'.format(params.get('max_radius')) +
      'Number of basis: {:3d} \n'.format(params.get('num_basis')) +
      'AdamW optimizer learning rate: {:.4f} \n'.format(params.get('adamw_lr')) +
      'AdamW optimizer weight decay coefficient: {:.4f}'.format(
          params.get('adamw_wd'))
      )

Length of embedding feature vector:  64 
Number of channels per irreducible representation:  32 
Number of tensor field convolution layers:   2 
Maximum radius: 5.0 
Number of basis:  10 
AdamW optimizer learning rate: 0.0050 
AdamW optimizer weight decay coefficient: 0.0300


In [7]:
structures_list

[]

In [5]:
species = set()
count = 0
for struct in structures_list[:]:
    try:
        species = species.union(list(set(map(str, struct.species))))
        count += 1
    except:
        print(count)
        count += 1
        continue
species = sorted(list(species))
print("Distinct atomic species ", len(species))

len_element = 118
atom_types_dim = 3*len_element
embedding_dim = params['len_embed_feat']
lmax = 1
# Roughly the average number (over entire dataset) of nearest neighbors for a given atom
n_norm = 35

Distinct atomic species  0


In [None]:
input = torch.zeros(len(structures_list[0]), 3*len_element)
for j, site in enumerate(structures_list[0]):
    input[j, int(elements[str(site.specie)]['atomic_number'])] = elements[str(site.specie)]['atomic_radius']
    input[j, len_element + int(elements[str(site.specie)]['atomic_number']) + 1] = elements[str(site.specie)]['en_pauling']
    input[j, 2*len_element + int(elements[str(site.specie)]['atomic_number']) + 1] = elements[str(site.specie)]['dipole_polarizability']
dn = (DataPeriodicNeighbors(
    x=input, Rs_in=None,
    pos=torch.tensor(structures_list[0].cart_coords.copy()), lattice=structures_list[0].tensor(struct.lattice.matrix.copy()),
    r_max=params['max_radius'],
    y=(torch.tensor([y_values[i]])).to(torch.long),
    n_norm=n_norm,
    order = structures_list[0]
))

In [None]:
data = []
count = 0
indices_to_delete = []
for i, struct in enumerate(structures_list):
    try:
        print(
            f"Encoding sample {i+1:5d}/{len(structures_list):5d}", end="\r", flush=True)
        input = torch.zeros(len(struct), 3*len_element)
        for j, site in enumerate(struct):
            input[j, int(elements[str(site.specie)]['atomic_number'])] = elements[str(site.specie)]['atomic_radius']
            input[j, len_element + int(elements[str(site.specie)]['atomic_number']) + 1] = elements[str(site.specie)]['en_pauling']
            input[j, 2*len_element + int(elements[str(site.specie)]['atomic_number']) + 1] = elements[str(site.specie)]['dipole_polarizability']
        data.append(DataPeriodicNeighbors(
            x=input, Rs_in=None,
            pos=torch.tensor(struct.cart_coords.copy()), lattice=torch.tensor(struct.lattice.matrix.copy()),
            r_max=params['max_radius'],
            y=(torch.tensor([y_values[i]])).to(torch.long),
            n_norm=n_norm,
            order = struct
        ))

        count += 1
    except Exception as e:
        indices_to_delete.append(i)
        print(f"Error: {count} {e}", end="\n")
        count += 1
        continue