In [1]:
import numpy as np
import pandas as pd
import fastai
from tqdm import tqdm_notebook as tqdm
from fastai.tabular import *
import pickle

from multiprocessing import Pool
from sklearn.preprocessing import LabelEncoder, LabelBinarizer, StandardScaler
np.range = (lambda x:(x.min(), x.max()))

In [2]:
# train = pd.read_csv("train.csv")
# test = pd.read_csv('test.csv')
structures = pd.read_csv('structures.csv')

In [3]:
# tmp = train.merge(structures.rename(columns=lambda x: x+'_0'), left_on=['molecule_name', 'atom_index_0'], right_on=['molecule_name_0', 'atom_index_0'])
# joined = tmp.merge(structures.rename(columns=lambda x: x+'_1'), left_on=['molecule_name', 'atom_index_1'], right_on=['molecule_name_1', 'atom_index_1'])
# joined = joined.drop(columns=['molecule_name_0', 'molecule_name_1'])
# joined.head()

In [4]:
def add_structure_features(df):
    df['dist'] = (df['x']**2 + df['y']**2 + df['z']**2).apply(np.sqrt)
    df['v_x'], df['v_y'], df['v_z']  =  df['x']/df['dist'], df['y']/df['dist'], df['z']/df['dist']
    
    df['a_x'], df['a_y'], df['a_z']  =  df['x'].apply(np.abs), df['y'].apply(np.abs), df['z'].apply(np.abs)
    df['s_x'], df['s_y'], df['s_z']  =  df['x'].apply(np.sign), df['y'].apply(np.sign), df['z'].apply(np.sign)
    
    df['min'] = df[['x', 'y', 'z']].apply(np.abs).min(1)
    df['max'] = df[['x', 'y', 'z']].apply(np.abs).max(1)
    
    return df

In [5]:
%time structures = add_structure_features(structures)

CPU times: user 725 ms, sys: 384 ms, total: 1.11 s
Wall time: 938 ms


In [6]:
structures.head()

Unnamed: 0,molecule_name,atom_index,atom,x,y,z,dist,v_x,v_y,v_z,a_x,a_y,a_z,s_x,s_y,s_z,min,max
0,dsgdb9nsd_000001,0,C,-0.012698,1.085804,0.008001,1.085908,-0.011694,0.999904,0.007368,0.012698,1.085804,0.008001,-1.0,1.0,1.0,0.008001,1.085804
1,dsgdb9nsd_000001,1,H,0.00215,-0.006031,0.001976,0.006701,0.3209,-0.900035,0.29489,0.00215,0.006031,0.001976,1.0,-1.0,1.0,0.001976,0.006031
2,dsgdb9nsd_000001,2,H,1.011731,1.463751,0.000277,1.779373,0.568589,0.822622,0.000155,1.011731,1.463751,0.000277,1.0,1.0,1.0,0.000277,1.463751
3,dsgdb9nsd_000001,3,H,-0.540815,1.447527,-0.876644,1.776603,-0.30441,0.814772,-0.493438,0.540815,1.447527,0.876644,-1.0,1.0,-1.0,0.540815,1.447527
4,dsgdb9nsd_000001,4,H,-0.523814,1.437933,0.906397,1.778648,-0.294501,0.808442,0.509599,0.523814,1.437933,0.906397,-1.0,1.0,1.0,0.523814,1.437933


In [7]:
structures.describe().loc['std']

atom_index    5.592487
x             1.655271
y             1.989152
z             1.445870
dist          1.303192
v_x           0.548445
v_y           0.662664
v_z           0.504995
a_x           1.077326
a_y           1.215136
a_z           0.974087
s_x           0.996969
s_y           0.978171
s_z           0.999344
min           0.491218
max           1.117642
Name: std, dtype: float64

In [8]:
for col in ['x', 'y', 'z', 'dist', 'v_x', 'v_y', 'v_z', 'a_x', 'a_y', 'a_z', 'min', 'max']:
    structures[col] = ((structures[col] - structures[col].mean())/structures[col].std()).astype(np.float32)
    
for col in ['s_x', 's_y', 's_z']:
    structures[col] = structures[col].astype(np.float32)

structures.atom_index = structures.atom_index.astype(np.int16)
structures.describe()

Unnamed: 0,atom_index,x,y,z,dist,v_x,v_y,v_z,a_x,a_y,a_z,s_x,s_y,s_z,min,max
count,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0,2358657.0
mean,8.757349,-6.803775e-07,-3.788459e-06,-3.663693e-07,-2.452273e-08,1.185681e-06,-1.664546e-05,3.657462e-07,7.860106e-07,-2.692544e-07,1.494633e-06,0.07780232,-0.2077784,0.03618627,-1.908379e-06,6.010725e-07
std,5.592487,0.9991393,0.9982817,0.9991086,0.9991422,0.9992731,0.9991241,0.9991654,0.9991232,0.9992008,0.9992303,0.998116,0.983276,0.9982372,0.9996186,0.9990478
min,0.0,-5.636408,-4.826277,-6.361002,-2.060529,-1.883654,-1.420496,-2.025684,-1.169818,-1.324815,-1.098794,-1.0,-1.0,-1.0,-0.9922283,-2.037827
25%,4.0,-0.5857057,-0.7502785,-0.6258549,-0.5749148,-0.7537246,-0.8606139,-0.7410564,-0.7860946,-0.8430145,-0.858613,-1.0,-1.0,-1.0,-0.8036761,-0.6367755
50%,9.0,-0.02601123,-0.03511799,-0.03560693,-0.177933,0.01564048,-0.2368678,-0.03102189,-0.2377276,-0.0601735,-0.1862439,1.0,-1.0,1.0,-0.2782271,-0.1613324
75%,13.0,0.6169443,0.8583548,0.6065696,0.6214064,0.810725,1.109337,0.7375404,0.5979137,0.4883163,0.4768986,1.0,1.0,1.0,0.4691264,0.558412
max,28.0,5.610773,5.286521,5.417029,5.97148,1.763019,1.597627,1.934751,7.539006,7.054461,8.278973,1.0,1.0,1.0,8.492958,7.072009


In [9]:
atom_encoder = LabelEncoder()
atom_encoder = atom_encoder.fit(structures.atom)
structures.atom = atom_encoder.transform(structures.atom) + 1
structures.atom = structures.atom.astype(np.int64)

In [10]:
structures.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2358657 entries, 0 to 2358656
Data columns (total 18 columns):
molecule_name    object
atom_index       int16
atom             int64
x                float32
y                float32
z                float32
dist             float32
v_x              float32
v_y              float32
v_z              float32
a_x              float32
a_y              float32
a_z              float32
s_x              float32
s_y              float32
s_z              float32
min              float32
max              float32
dtypes: float32(15), int16(1), int64(1), object(1)
memory usage: 175.5+ MB


In [11]:
def get_mol_df(name, structures=structures):
    return structures[structures.molecule_name == name]

def parse_mol(df):
    name = df.iloc[0].molecule_name
    df = df.copy()
    n_atoms = len(df)
    df.set_index('atom_index', verify_integrity=True, inplace=True)
    df.sort_index(inplace=True)
    ret = []
    ret.append(np.zeros((16,)))
    for row in df.itertuples():
        ret.append(list(row[-16:]))
    for _ in range(len(ret), 30): ret.append(np.zeros(16))
    del df
    return (name, (np.array(ret), n_atoms))

def get_mol(x):
    return parse_mol(get_mol_df(x))

In [12]:
molecule_names = list(set(structures.molecule_name))
len(molecule_names)

130775

In [13]:
get_mol(np.random.choice(molecule_names))[1]

(array([[ 0.      ,  0.      ,  0.      ,  0.      , ...,  0.      ,  0.      ,  0.      ,  0.      ],
        [ 1.      , -1.373255,  0.136474,  1.113322, ..., -1.      ,  1.      , -0.86546 , -0.089265],
        [ 1.      , -0.978196, -0.546964,  1.276868, ..., -1.      ,  1.      ,  1.902073, -0.330502],
        [ 1.      , -0.732991, -0.918591,  0.370305, ..., -1.      ,  1.      ,  0.224803, -0.10471 ],
        ...,
        [ 0.      ,  0.      ,  0.      ,  0.      , ...,  0.      ,  0.      ,  0.      ,  0.      ],
        [ 0.      ,  0.      ,  0.      ,  0.      , ...,  0.      ,  0.      ,  0.      ,  0.      ],
        [ 0.      ,  0.      ,  0.      ,  0.      , ...,  0.      ,  0.      ,  0.      ,  0.      ],
        [ 0.      ,  0.      ,  0.      ,  0.      , ...,  0.      ,  0.      ,  0.      ,  0.      ]]),
 24)

In [14]:
%%time
with Pool(8) as pool:
    molecules = pool.map(get_mol, molecule_names)

molecules = dict(molecules)

print(' ')

 
CPU times: user 1.8 s, sys: 919 ms, total: 2.72 s
Wall time: 51min 22s


In [15]:
for key in molecules:
    molecules[key] = (molecules[key][1], tensor(molecules[key][0]).type(torch.float32))

In [16]:
with open('molecules.pkl', 'wb') as f:
    pickle.dump([molecules, structures.columns.values[-16:].tolist(), atom_encoder], f)

In [17]:
len(molecules)

130775

In [19]:
molecules[np.random.choice(molecule_names)]

(21, tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00],
         [ 5.0000e+00, -8.0519e-02,  8.7572e-01, -1.5017e-01, -9.7341e-01,
          -1.0971e-01,  1.5881e+00, -2.6163e-01, -1.1342e+00, -1.6593e-01,
          -9.3997e-01, -1.0000e+00,  1.0000e+00, -1.0000e+00, -9.1408e-01,
          -7.7823e-01],
         [ 1.0000e+00, -4.9749e-02,  1.7078e-01, -3.7279e-02, -2.0484e+00,
           1.3436e+00,  6.4106e-01,  9.8945e-01, -1.1582e+00, -1.3199e+00,
          -1.0901e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00, -9.8009e-01,
          -2.0270e+00],
         [ 1.0000e+00,  9.3229e-02, -1.0675e-01,  9.5788e-01, -8.5852e-01,
           2.2968e-01, -4.3735e-01,  1.7837e+00, -9.3849e-01, -8.7542e-01,
           3.8709e-01,  1.0000e+00, -1.0000e+00,  1.0000e+00, -4.8490e-01,
          -7.4317e-01],


In [20]:
molecules[np.random.choice(molecule_names)][1].shape

torch.Size([30, 16])

In [21]:
structures.molecule_name.unique().shape

(130775,)