In [1]:
from ase.io import Trajectory
import torch
torch.manual_seed(0)
import numpy as np
from nnmd.features import calculate_sf

In [21]:
# Read every frame of the Li trajectory and stack the atomic positions into one array.
traj = Trajectory('input/Li_crystal_27.traj')
positions = np.array([frame.positions for frame in traj])

# Only frame 0's cell is used for the whole trajectory.
# NOTE(review): this assumes the cell is fixed across frames — confirm for this dataset.
cell = torch.tensor(traj[0].cell.array, dtype=torch.float32)

# float32 leaf tensor with autograd enabled, so quantities built from it can be differentiated.
cartesians = torch.tensor(positions, dtype=torch.float32, requires_grad=True)

In [22]:
# Symmetry-function configuration consumed by nnmd.features.calculate_sf.
# Cutoff radius shared by every function (presumably in Å, ASE's length unit — confirm).
r_cutoff = 3.54

# G2 rows: [r_cutoff, eta, third value]. NOTE(review): presumably eta is the Gaussian
# width parameter and the last value the radial shift — confirm against calculate_sf.
# The etas sweep in ascending order.
params_g2 = [
    [r_cutoff, 0.001, 0.0],
    [r_cutoff, 0.01, 0.0],
    [r_cutoff, 0.03, 0.0],
    [r_cutoff, 0.05, 0.0],
    # Was 0.7: that exactly duplicated the eta=0.7 row below (a redundant feature)
    # and broke the ascending sweep 0.05 -> ? -> 0.1, so 0.07 was intended.
    [r_cutoff, 0.07, 0.0],
    [r_cutoff, 0.1, 0.0],
    [r_cutoff, 0.2, 0.0],
    [r_cutoff, 0.4, 0.0],
    [r_cutoff, 0.5, 0.0],
    [r_cutoff, 0.7, 0.0],
    [r_cutoff, 0.9, 0.0],
    [r_cutoff, 1.0, 0.0],
]
# G4 rows: presumably [r_cutoff, eta, zeta, lambda] — confirm against calculate_sf.
params_g4 = [
    [r_cutoff, 0.01, 3, -1],
    [r_cutoff, 0.02, 3, 1],
]

# One type tag per parameter row, in the same order as `params`.
features = [2] * len(params_g2) + [4] * len(params_g4)
params = params_g2 + params_g4
# Identical parameter rows would only add redundant feature columns; fail fast on them.
assert len({tuple(p) for p in params}) == len(params), "duplicate symmetry-function parameters"
symm_funcs_data = {'features': features, 'params': params}

# Compute symmetry-function values (g) and their derivatives (dg) from the stacked
# positions of all trajectory frames, the frame-0 cell, and the configuration above.
# NOTE(review): cartesians has requires_grad=True, presumably so calculate_sf can
# differentiate g w.r.t. positions — confirm. The recorded run took ~27 s.
g, dg = calculate_sf(cartesians, cell, symm_funcs_data)

100%|██████████| 742/742 [00:27<00:00, 27.35it/s]


In [24]:
# Scale the derivatives dg by the L2 norm of g taken over dim 1.
# expand_as(dg) is kept on purpose: unlike implicit broadcasting in the division, it
# raises when g and dg have incompatible shapes instead of silently producing a larger tensor.
g_norm = torch.linalg.vector_norm(g, dim=1, keepdim=True).unsqueeze(-1).expand_as(dg)
dg_norm = dg / g_norm
print(dg_norm)
# isfinite rejects both NaN (0/0) and ±inf (nonzero/0); an isnan-only check lets inf through.
assert torch.isfinite(dg_norm).all(), "dg_norm has NaN/inf entries (zero norm in g?)"

tensor([[[[ 7.9721e-07,  1.1548e-06,  1.1474e-06],
          [ 6.5565e-07,  1.1548e-06,  1.0878e-06],
          [ 6.9290e-07,  9.9093e-07,  9.9838e-07],
          ...,
          [ 1.8554e-10,  6.1118e-10,  6.1118e-10],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

         [[-7.4385e-08, -7.4506e-09,  7.7486e-07],
          [-1.6511e-07, -4.4703e-08,  7.1526e-07],
          [-1.0429e-07, -4.4703e-08,  6.4075e-07],
          ...,
          [-1.9645e-10, -1.8190e-10,  1.8190e-10],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

         [[-2.2991e-07, -7.6741e-07,  9.5367e-07],
          [-2.7572e-07, -7.3016e-07,  8.4192e-07],
          [-2.8284e-07, -6.9290e-07,  7.6741e-07],
          ...,
          [-5.1659e-10, -4.0018e-10,  4.0745e-10],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

         ...,

         [[ 2.83

In [15]:
# Reload previously saved symmetry-function values and derivatives.
# weights_only=True restricts unpickling to tensors/plain containers, so a tampered file
# cannot execute code. It should also avoid the torch.load warning echoed in this cell's output.
# NOTE(review): this overwrites the g/dg computed by calculate_sf earlier in the notebook,
# and the saved dg displayed later contains NaN rows — do not feed it into the
# normalization step without recomputing or cleaning it.
g = torch.load('g_Li.pt', weights_only=True)
dg = torch.load('dg_Li.pt', weights_only=True)

  g = torch.load('g_Li.pt')
  dg = torch.load('dg_Li.pt')


In [16]:
# Inspect the symmetry-function values loaded from g_Li.pt (torch prints a summarized view).
g

tensor([[[0.1891, 0.1894, 0.1901,  ..., 0.1925, 0.1854, 0.1854],
         [0.1937, 0.1937, 0.1937,  ..., 0.1924, 0.1950, 0.1949],
         [0.1915, 0.1915, 0.1914,  ..., 0.1924, 0.1915, 0.1914],
         ...,
         [0.1878, 0.1881, 0.1886,  ..., 0.1925, 0.1846, 0.1839],
         [0.1997, 0.1993, 0.1986,  ..., 0.1924, 0.2029, 0.2031],
         [0.1849, 0.1851, 0.1854,  ..., 0.1924, 0.1800, 0.1785]],

        [[0.1918, 0.1918, 0.1917,  ..., 0.1924, 0.1898, 0.1887],
         [0.1880, 0.1881, 0.1882,  ..., 0.1924, 0.1845, 0.1832],
         [0.1664, 0.1674, 0.1693,  ..., 0.1924, 0.1495, 0.1459],
         ...,
         [0.2022, 0.2017, 0.2008,  ..., 0.1924, 0.2084, 0.2091],
         [0.1946, 0.1946, 0.1946,  ..., 0.1925, 0.1963, 0.1974],
         [0.1841, 0.1843, 0.1847,  ..., 0.1924, 0.1776, 0.1757]],

        [[0.1827, 0.1831, 0.1838,  ..., 0.1924, 0.1772, 0.1760],
         [0.1820, 0.1824, 0.1832,  ..., 0.1924, 0.1738, 0.1716],
         [0.1912, 0.1911, 0.1909,  ..., 0.1924, 0.1901, 0.

In [17]:
# Inspect the derivatives loaded from dg_Li.pt.
# NOTE(review): the output below contains whole rows of NaN, so this saved dg is
# not safe to normalize or train on as-is.
dg

tensor([[[[-1.9561e-01,  4.5362e-01, -3.5022e-01],
          [-1.9657e-01,  4.4291e-01, -3.3154e-01],
          [-1.9725e-01,  4.2223e-01, -2.9750e-01],
          ...,
          [-1.5845e-03,  2.0733e-03, -8.4657e-04],
          [        nan,         nan,         nan],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

         [[-2.7786e-01,  1.6788e-01,  3.9921e-02],
          [-2.7381e-01,  1.5210e-01,  4.4911e-02],
          [-2.6525e-01,  1.2394e-01,  5.2893e-02],
          ...,
          [-1.0255e-03, -2.5956e-04,  1.2385e-04],
          [        nan,         nan,         nan],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

         [[-2.3657e-01, -1.3993e-01, -4.7634e-01],
          [-2.3136e-01, -1.3100e-01, -4.5748e-01],
          [-2.2076e-01, -1.1478e-01, -4.2180e-01],
          ...,
          [-5.1466e-04, -6.8897e-05, -6.0978e-04],
          [        nan,         nan,         nan],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

         ...,

         [[ 1.97