In [1]:
import os
import sys
from pathlib import Path

# Put the parent of the working directory on the import path; presumably this is
# the repository root that holds the `src` package imported below — confirm.
# Path.cwd() depends on where the kernel was started, so this assumes the kernel
# runs from the notebook's own folder.
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    # Guarded so that re-running this cell does not stack duplicate entries.
    sys.path.insert(0, str(PROJECT_ROOT))

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler


from src.data import read_xyz_trajectory, split_dataset, calculate_average_distances
from src.feature_builders import add_pairwise_distance_features, add_angular_acsf_column

In [2]:
# Dataset root, given relative to the kernel's working directory.
DATA_DIR = "../dataset/"

# Per-atom bookkeeping columns. Not referenced in the cells visible here;
# presumably meant to be dropped before modelling — TODO confirm.
UNDESIRED_COLS = [
    "atom_index",
    "element",
    "atomic_number",
]


In [3]:
# Each trajectory is stored as a structures file plus a matching energies file.
train_structures_path = os.path.join(DATA_DIR, "raw/malonaldehyde_300K/structures.xyz")
train_energies_path = os.path.join(DATA_DIR, "raw/malonaldehyde_300K/energies.txt")
test_structures_path = os.path.join(DATA_DIR, "raw/malonaldehyde_300K-test/structures.xyz")
test_energies_path = os.path.join(DATA_DIR, "raw/malonaldehyde_300K-test/energies.txt")

# Main trajectory and the separate "-test" trajectory, loaded the same way.
dataset = read_xyz_trajectory(train_structures_path, train_energies_path)
test_dataset = read_xyz_trajectory(test_structures_path, test_energies_path)

# Quick look at the per-atom layout: one row per (frame, atom_index).
dataset.head(5)

Unnamed: 0,frame,atom_index,element,atomic_number,x,y,z,energy
0,0,0,O,8,7.518784,12.374943,9.872595,-44957.13136
1,0,1,C,6,7.700793,10.111345,10.143338,-44957.13136
2,0,2,C,6,10.074552,8.697429,9.988165,-44957.13136
3,0,3,C,6,12.301752,9.908006,9.874942,-44957.13136
4,0,4,O,8,12.404117,12.447734,10.120959,-44957.13136


In [4]:
# Build the dist_atom{i}_atom{j} columns for the main and the test frame in turn
# (main first, then test, same order as two separate calls).
dataset_with_distances, test_dataset_with_distances = (
    add_pairwise_distance_features(frame) for frame in (dataset, test_dataset)
)

# Preview: the distance columns repeat on every atom row of a frame.
dataset_with_distances.head(15)

Unnamed: 0,frame,atom_index,element,atomic_number,x,y,z,energy,dist_atom0_atom1,dist_atom0_atom2,...,dist_atom4_atom5,dist_atom4_atom6,dist_atom4_atom7,dist_atom4_atom8,dist_atom5_atom6,dist_atom5_atom7,dist_atom5_atom8,dist_atom6_atom7,dist_atom6_atom8,dist_atom7_atom8
0,0,0,O,8,7.518784,12.374943,9.872595,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
1,0,1,C,6,7.700793,10.111345,10.143338,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
2,0,2,C,6,10.074552,8.697429,9.988165,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
3,0,3,C,6,12.301752,9.908006,9.874942,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
4,0,4,O,8,12.404117,12.447734,10.120959,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
5,0,5,H,1,5.978652,8.981591,10.231535,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
6,0,6,H,1,9.954607,6.692762,9.916954,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
7,0,7,H,1,14.156946,8.958377,9.833734,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
8,0,8,H,1,10.442674,12.790426,9.943805,-44957.13136,2.286986,4.47989,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
9,1,0,O,8,7.64159,12.480086,9.958507,-44958.579728,2.541494,4.424123,...,7.23444,6.169491,3.925506,2.901738,4.630822,8.273043,5.289944,4.813972,6.253325,6.187541


In [5]:
# Summary statistics over the per-atom table. Frame-level columns (energy, dist_*)
# repeat on every atom row of a frame, so their statistics here are computed with
# each frame counted once per atom.
dataset_with_distances.describe()

Unnamed: 0,frame,atom_index,atomic_number,x,y,z,energy,dist_atom0_atom1,dist_atom0_atom2,dist_atom0_atom3,...,dist_atom4_atom5,dist_atom4_atom6,dist_atom4_atom7,dist_atom4_atom8,dist_atom5_atom6,dist_atom5_atom7,dist_atom5_atom8,dist_atom6_atom7,dist_atom6_atom8,dist_atom7_atom8
count,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,...,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0,18000.0
mean,999.5,4.0,4.222222,9.999315,10.108905,10.000284,-44954.025865,2.38957,4.373372,5.330546,...,7.393351,6.17649,3.862536,2.510268,4.659481,8.183469,5.721918,4.651462,6.209107,5.725951
std,577.366235,2.582061,2.973213,2.506095,1.953081,0.199428,3.498275,0.107073,0.07737,0.109125,...,0.118645,0.099565,0.080223,0.618378,0.150401,0.123008,0.477438,0.152091,0.16868,0.470435
min,0.0,0.0,1.0,5.619141,6.493508,8.612501,-44961.645867,2.157486,4.060451,4.994395,...,7.016014,5.839437,3.528568,1.732437,4.080919,7.658558,4.750638,3.972508,5.367654,4.825919
25%,499.75,2.0,1.0,7.649643,8.817278,9.917763,-44956.414952,2.291222,4.320331,5.258319,...,7.310474,6.10936,3.810053,1.905361,4.557604,8.103176,5.276631,4.552669,6.097964,5.281281
50%,999.5,4.0,6.0,10.0,9.940645,9.999885,-44954.543949,2.379955,4.375665,5.328083,...,7.3918,6.174962,3.861777,2.691694,4.662809,8.184614,5.513225,4.651978,6.210345,5.682999
75%,1499.25,6.0,6.0,12.350928,12.367379,10.081152,-44952.0598,2.488899,4.424931,5.400755,...,7.472166,6.245507,3.915797,3.092004,4.75927,8.263735,6.162866,4.754855,6.31177,6.14777
max,1999.0,8.0,8.0,14.394796,13.537416,11.301125,-44928.723831,2.648162,4.659231,5.766183,...,7.799959,6.478957,4.166902,4.006819,5.261263,8.660023,6.91777,5.245058,6.887478,7.031228


In [6]:
# Names of all pairwise-distance feature columns (those prefixed with "dist_").
dist_column_mask = dataset_with_distances.columns.str.startswith("dist_")
feature_cols = dataset_with_distances.columns[dist_column_mask].tolist()

# Collapse the per-atom table to one row per frame. Every atom row of a frame
# carries the same energy and distance values, so first() picks a representative.
first_per_frame = (
    dataset_with_distances
    .groupby("frame")[feature_cols + ["energy"]]
    .first()
)

# Feature matrix (indexed by frame) and the matching energy target.
X = first_per_frame[feature_cols]
y = first_per_frame["energy"]


In [7]:
# Preview the per-frame feature matrix (index: frame, columns: dist_* distances).
X.head()

Unnamed: 0_level_0,dist_atom0_atom1,dist_atom0_atom2,dist_atom0_atom3,dist_atom0_atom4,dist_atom0_atom5,dist_atom0_atom6,dist_atom0_atom7,dist_atom0_atom8,dist_atom1_atom2,dist_atom1_atom3,...,dist_atom4_atom5,dist_atom4_atom6,dist_atom4_atom7,dist_atom4_atom8,dist_atom5_atom6,dist_atom5_atom7,dist_atom5_atom8,dist_atom6_atom7,dist_atom6_atom8,dist_atom7_atom8
frame,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,2.286986,4.47989,5.381688,4.892184,3.743752,6.182425,7.465898,2.954121,2.767304,4.613264,...,7.301573,6.257909,3.915421,1.99902,4.598469,8.187996,5.87516,4.774892,6.117225,5.33784
1,2.541494,4.424123,5.285867,4.668866,3.930339,6.270297,7.466435,1.859295,2.542266,4.58684,...,7.23444,6.169491,3.925506,2.901738,4.630822,8.273043,5.289944,4.813972,6.253325,6.187541
2,2.181502,4.283794,5.227888,4.936085,3.777824,6.134635,7.244768,3.130554,2.683973,4.564393,...,7.443128,6.345885,3.700328,1.889644,4.581375,8.105309,6.128065,4.798581,6.233883,5.063855
3,2.525948,4.414954,5.196246,4.897908,3.896338,6.369278,7.396955,1.949384,2.557639,4.61518,...,7.524088,6.243987,3.943691,3.051327,4.987738,8.233797,5.389498,4.414316,6.339805,6.143072
4,2.560887,4.394723,5.300318,5.024531,3.820987,6.247345,7.245486,1.857584,2.560666,4.712852,...,7.548933,6.263631,3.697895,3.302998,4.81481,8.130064,5.276012,4.556979,6.329595,6.103932


In [8]:
# Summary statistics of the distance features with one row per frame.
X.describe()

Unnamed: 0,dist_atom0_atom1,dist_atom0_atom2,dist_atom0_atom3,dist_atom0_atom4,dist_atom0_atom5,dist_atom0_atom6,dist_atom0_atom7,dist_atom0_atom8,dist_atom1_atom2,dist_atom1_atom3,...,dist_atom4_atom5,dist_atom4_atom6,dist_atom4_atom7,dist_atom4_atom8,dist_atom5_atom6,dist_atom5_atom7,dist_atom5_atom8,dist_atom6_atom7,dist_atom6_atom8,dist_atom7_atom8
count,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,...,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0
mean,2.38957,4.373372,5.330546,4.880893,3.859644,6.175057,7.388619,2.503042,2.649167,4.67025,...,7.393351,6.17649,3.862536,2.510268,4.659481,8.183469,5.721918,4.651462,6.209107,5.725951
std,0.107096,0.077387,0.10915,0.139141,0.081617,0.098247,0.120373,0.623359,0.132734,0.069208,...,0.118671,0.099587,0.080241,0.618516,0.150434,0.123036,0.477544,0.152125,0.168718,0.470539
min,2.157486,4.060451,4.994395,4.501471,3.472227,5.759453,7.026854,1.753825,2.400911,4.432059,...,7.016014,5.839437,3.528568,1.732437,4.080919,7.658558,4.750638,3.972508,5.367654,4.825919
25%,2.291222,4.320331,5.258319,4.784564,3.808417,6.114658,7.307808,1.89841,2.524752,4.62326,...,7.310474,6.10936,3.810053,1.905361,4.557604,8.103176,5.276631,4.552669,6.097964,5.281281
50%,2.379955,4.375665,5.328083,4.868738,3.859711,6.176705,7.384708,2.036637,2.619366,4.671356,...,7.3918,6.174962,3.861777,2.691694,4.662809,8.184614,5.513225,4.651978,6.210345,5.682999
75%,2.488899,4.424931,5.400755,4.965233,3.913634,6.242714,7.465515,3.091124,2.776574,4.716398,...,7.472166,6.245507,3.915797,3.092004,4.75927,8.263735,6.162866,4.754855,6.31177,6.14777
max,2.648162,4.659231,5.766183,5.516239,4.16172,6.500246,7.815424,3.952831,2.922267,4.902543,...,7.799959,6.478957,4.166902,4.006819,5.261263,8.660023,6.91777,5.245058,6.887478,7.031228


In [9]:
# List the feature names; each is dist_atom{i}_atom{j} for an atom pair with i < j.
X.columns

Index(['dist_atom0_atom1', 'dist_atom0_atom2', 'dist_atom0_atom3',
       'dist_atom0_atom4', 'dist_atom0_atom5', 'dist_atom0_atom6',
       'dist_atom0_atom7', 'dist_atom0_atom8', 'dist_atom1_atom2',
       'dist_atom1_atom3', 'dist_atom1_atom4', 'dist_atom1_atom5',
       'dist_atom1_atom6', 'dist_atom1_atom7', 'dist_atom1_atom8',
       'dist_atom2_atom3', 'dist_atom2_atom4', 'dist_atom2_atom5',
       'dist_atom2_atom6', 'dist_atom2_atom7', 'dist_atom2_atom8',
       'dist_atom3_atom4', 'dist_atom3_atom5', 'dist_atom3_atom6',
       'dist_atom3_atom7', 'dist_atom3_atom8', 'dist_atom4_atom5',
       'dist_atom4_atom6', 'dist_atom4_atom7', 'dist_atom4_atom8',
       'dist_atom5_atom6', 'dist_atom5_atom7', 'dist_atom5_atom8',
       'dist_atom6_atom7', 'dist_atom6_atom8', 'dist_atom7_atom8'],
      dtype='object')

##### High-dimensional neural network potentials (HDNNPs)

In [10]:
# Atom-by-atom distance matrix returned by calculate_average_distances for `dataset`.
# The displayed array is symmetric with zeros on the diagonal, and its entries match
# the per-frame means of the dist_* columns; presumably an average over all frames —
# confirm against src.data.
average_atom_to_atom_distances = calculate_average_distances(dataset=dataset)
average_atom_to_atom_distances

array([[0.        , 2.3895697 , 4.37337171, 5.33054565, 4.8808929 ,
        3.85964417, 6.1750571 , 7.3886192 , 2.50304211],
       [2.3895697 , 0.        , 2.64916651, 4.67025009, 5.33403769,
        2.08227611, 4.04133967, 6.51698899, 3.75402536],
       [4.37337171, 2.64916651, 0.        , 2.65034721, 4.37695522,
        4.10350588, 2.05258281, 4.10097705, 4.17062049],
       [5.33054565, 4.67025009, 2.65034721, 0.        , 2.3915295 ,
        6.52000396, 4.03954408, 2.08167301, 3.75755387],
       [4.8808929 , 5.33403769, 4.37695522, 2.3915295 , 0.        ,
        7.39335121, 6.17648994, 3.86253637, 2.51026813],
       [3.85964417, 2.08227611, 4.10350588, 6.52000396, 7.39335121,
        0.        , 4.65948094, 8.18346898, 5.72191773],
       [6.1750571 , 4.04133967, 2.05258281, 4.03954408, 6.17648994,
        4.65948094, 0.        , 4.65146214, 6.20910729],
       [7.3886192 , 6.51698899, 4.10097705, 2.08167301, 3.86253637,
        8.18346898, 4.65146214, 0.        , 5.725951  ],


In [11]:
# Range of the mean distance matrix. The minimum printed here is the 0.0 on the
# matrix diagonal, so only the maximum says anything about the pair distances.
print(average_atom_to_atom_distances.min(), average_atom_to_atom_distances.max())

0.0 8.183468984060516


In [12]:
# Candidate cutoff radius: a quarter of the largest mean pair distance.
# NOTE(review): the divisor 4 is not motivated anywhere in the notebook, and the
# later add_angular_acsf_column call passes a literal cutoff instead of this
# variable — confirm which value is intended.
R_cutoff = average_atom_to_atom_distances.max() / 4
print(R_cutoff)

2.045867246015129


In [13]:
# Look at the per-atom table again (coordinates and energy only at this point).
dataset.head(10)

Unnamed: 0,frame,atom_index,element,atomic_number,x,y,z,energy
0,0,0,O,8,7.518784,12.374943,9.872595,-44957.13136
1,0,1,C,6,7.700793,10.111345,10.143338,-44957.13136
2,0,2,C,6,10.074552,8.697429,9.988165,-44957.13136
3,0,3,C,6,12.301752,9.908006,9.874942,-44957.13136
4,0,4,O,8,12.404117,12.447734,10.120959,-44957.13136
5,0,5,H,1,5.978652,8.981591,10.231535,-44957.13136
6,0,6,H,1,9.954607,6.692762,9.916954,-44957.13136
7,0,7,H,1,14.156946,8.958377,9.833734,-44957.13136
8,0,8,H,1,10.442674,12.790426,9.943805,-44957.13136
9,1,0,O,8,7.64159,12.480086,9.958507,-44958.579728


In [43]:
# NOTE(review): the cutoff used here is a fixed 4.0, not the R_cutoff value derived
# from the mean distance matrix earlier in the notebook — confirm which is intended.
ACSF_R_CUTOFF = 4.0
ACSF_COLUMN = "G_ang_full"

# NOTE(review): this rebinds `dataset`, so earlier cells that displayed it now show
# a stale view; whether re-running this cell is idempotent depends on
# add_angular_acsf_column — verify.
dataset = add_angular_acsf_column(dataset, R_cutoff=ACSF_R_CUTOFF, col_name=ACSF_COLUMN)


In [44]:
# Check that the G_ang_full column was added, with one value per atom row.
dataset.head(25)

Unnamed: 0,frame,atom_index,element,atomic_number,x,y,z,energy,G_ang_full
0,0,0,O,8,7.518784,12.374943,9.872595,-44957.13136,0.003312
1,0,1,C,6,7.700793,10.111345,10.143338,-44957.13136,0.001213
2,0,2,C,6,10.074552,8.697429,9.988165,-44957.13136,3e-06
3,0,3,C,6,12.301752,9.908006,9.874942,-44957.13136,0.011653
4,0,4,O,8,12.404117,12.447734,10.120959,-44957.13136,0.005832
5,0,5,H,1,5.978652,8.981591,10.231535,-44957.13136,0.003058
6,0,6,H,1,9.954607,6.692762,9.916954,-44957.13136,1.3e-05
7,0,7,H,1,14.156946,8.958377,9.833734,-44957.13136,0.000236
8,0,8,H,1,10.442674,12.790426,9.943805,-44957.13136,0.011048
9,1,0,O,8,7.64159,12.480086,9.958507,-44958.579728,0.006644


In [45]:
# Distribution of the angular feature across all atom rows.
dataset['G_ang_full'].describe()

count    18000.000000
mean         0.004636
std          0.006552
min          0.000000
25%          0.000172
50%          0.000992
75%          0.008098
max          0.057762
Name: G_ang_full, dtype: float64

In [46]:
# Counts of exact values. For a continuous feature this is mainly useful for
# spotting how many atom rows are exactly 0.0 (presumably atoms with no neighbour
# pair inside the cutoff — confirm in add_angular_acsf_column).
dataset['G_ang_full'].value_counts()

0.000000e+00    1280
2.946314e-07       2
1.835218e-02       2
6.260960e-08       2
8.834553e-03       2
                ... 
1.123702e-03       1
5.447959e-05       1
2.551052e-03       1
1.116183e-02       1
4.888527e-04       1
Name: G_ang_full, Length: 16712, dtype: int64