In [None]:
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm_notebook as tqdm
import gc
import time

from sklearn.model_selection import KFold, GroupKFold
from sklearn.preprocessing import LabelEncoder
import lightgbm as lgb

In [None]:
# Per-atom table; columns used below: molecule_name, atom_index, atom, x, y, z.
STRUCTURES_PATH = '../input/structures.csv'
structures = pd.read_csv(STRUCTURES_PATH)

In [None]:
structures_idx = structures.set_index('molecule_name')
df_idx = structures.set_index('molecule_name')

atoms = ['H', 'C', 'N', 'O', 'F']  # element order of the per-atom feature blocks
num_pickup = 20                    # nearest distances kept per element

# get_dist_matrix slices `xyz` with the offsets in `ssx`, which are built from
# groupby's *sorted* molecule keys. That is only valid when each molecule's rows
# are contiguous and appear in sorted-name order, so stable-sort when they are
# not. The sort is stable, so within-molecule row order is unchanged and still
# matches what structures_idx.loc[molecule] returns.
structures_by_mol = structures
if not structures_by_mol['molecule_name'].is_monotonic_increasing:
    structures_by_mol = structures_by_mol.sort_values('molecule_name', kind='mergesort')

ss = structures_by_mol.groupby('molecule_name').size().cumsum()
ssx = np.zeros(len(ss) + 1, 'int')  # rows ssx[k]:ssx[k+1] of xyz belong to molecule k
ssx[1:] = ss
xyz = structures_by_mol[['x', 'y', 'z']].values
del structures_by_mol


def get_dist_matrix(molecule):
    """Return the pairwise Euclidean distance matrix for the atoms of `molecule`.

    Uses the module-level ``ss`` (whose index gives each molecule's position),
    ``ssx`` (row offsets) and ``xyz`` (coordinate array) to slice out the
    molecule's coordinates.

    Returns
    -------
    numpy.ndarray
        Square array; entry ``[a, b]`` is the distance between the molecule's
        a-th and b-th coordinate rows.
    """
    pos = ss.index.get_loc(molecule)
    first, stop = ssx[pos], ssx[pos + 1]
    coords = xyz[first:stop]
    # Broadcast (n, 1, 3) against (1, n, 3) to get every pairwise offset vector.
    offsets = coords[:, np.newaxis, :] - coords[np.newaxis, :, :]
    return np.sqrt(np.sum(offsets ** 2, axis=-1))


def assign_atoms_index(df, molecule):
    """Return the sorted, unique atom indices of `molecule` to build features for.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame indexed by molecule name. If it has both ``atom_index_0`` and
        ``atom_index_1`` columns (pair-style data), every atom appearing in
        either column is used. Otherwise the ``atom_index`` column
        (structures-style data, as in ``df_idx`` built from structures.csv)
        is used.
    molecule : str
        Label to look up in ``df.index``.

    Returns
    -------
    numpy.ndarray
        Sorted unique atom indices.

    Raises
    ------
    KeyError
        If `molecule` is not in the index, or neither column layout is present.
    """
    # A scalar label lookup is fast on a sorted index. It yields a Series
    # (one row, unique index) or a DataFrame. `in` checks column names in both
    # cases, and np.atleast_1d turns a single-row scalar into an array.
    rows = df.loc[molecule]
    if 'atom_index_0' in rows and 'atom_index_1' in rows:
        idx_values = np.concatenate([np.atleast_1d(rows['atom_index_0']),
                                     np.atleast_1d(rows['atom_index_1'])])
    else:
        idx_values = np.atleast_1d(rows['atom_index'])
    # np.unique returns the distinct values already sorted.
    return np.unique(idx_values)


def get_pickup_dist_matrix(df, molecule):
    """Build per-element nearest-neighbour distance features for atoms of `molecule`.

    For each atom index returned by ``assign_atoms_index(df, molecule)``:
      1. sort the distances to every *other* atom of the molecule ascending;
      2. split them by element, in module-level ``atoms`` order;
      3. keep the first ``num_pickup`` per element, zero-padding shorter lists.

    Parameters
    ----------
    df : pandas.DataFrame
        Passed through to ``assign_atoms_index``.
    molecule : str
        Molecule name.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(assigned_idxs), len(atoms) * num_pickup)``, one row per
        assigned atom index, with element blocks laid out in ``atoms`` order.
    """
    n_features = len(atoms) * num_pickup
    assigned_idxs = assign_atoms_index(df, molecule)
    dist_mat = get_dist_matrix(molecule)

    # Loop-invariant lookups, done once per molecule. np.atleast_1d keeps an
    # array even if .loc returns a single row.
    mol_rows = structures_idx.loc[molecule]
    atoms_mol = np.atleast_1d(mol_rows['atom'])
    atoms_mol_idx = np.atleast_1d(mol_rows['atom_index'])

    feature_rows = []
    for idx in assigned_idxs:
        # Exclude the atom itself. dist_mat is indexed by `idx` positionally,
        # which assumes atom_index equals row position within the molecule.
        others = atoms_mol_idx != idx
        other_atoms = atoms_mol[others]
        other_dist = dist_mat[idx][others]

        order = np.argsort(other_dist)
        sorted_atoms = other_atoms[order]
        sorted_dist = other_dist[order]

        target_matrix = np.zeros([len(atoms), num_pickup])
        for i, atom in enumerate(atoms):
            # Truncate to num_pickup; the remaining slots stay 0 as padding.
            pickup_dist = sorted_dist[sorted_atoms == atom][:num_pickup]
            target_matrix[i, :len(pickup_dist)] = pickup_dist
        feature_rows.append(target_matrix.reshape(-1))

    if not feature_rows:
        return np.zeros([0, n_features])
    # Stack once instead of re-stacking a growing array on every iteration.
    return np.vstack(feature_rows)


mols = structures['molecule_name'].unique()

# Gather per-molecule pieces in lists and concatenate once after the loop.
# Re-stacking the growing arrays on every iteration (as np.hstack/np.vstack in
# the loop did) is quadratic in the total number of atoms. Each list starts
# with an empty, correctly typed seed, so concatenation also works when
# `mols` is empty.
name_chunks = [np.empty([0], dtype=str)]
idx_chunks = [np.zeros([0], dtype=np.int32)]
dist_chunks = [np.zeros([0, num_pickup * len(atoms)])]

for mol in tqdm(mols):
    assigned_idxs = assign_atoms_index(df_idx, mol)
    dist_mat_mole = get_pickup_dist_matrix(df_idx, mol)  # one row per assigned index

    name_chunks.append(np.repeat(mol, len(assigned_idxs)))
    idx_chunks.append(assigned_idxs)
    dist_chunks.append(dist_mat_mole)

molecule_names = np.concatenate(name_chunks)
atoms_idx = np.concatenate(idx_chunks)
dist_mat = np.vstack(dist_chunks)


# Feature columns: dist_<element>_<k> is the k-th nearest atom of that element,
# in the same (atoms, num_pickup) order the feature rows were flattened in.
col_name_list = ['dist_{}_{}'.format(atom, i)
                 for atom in atoms
                 for i in range(num_pickup)]

se_mol = pd.Series(molecule_names, name='molecule_name')
se_atom_idx = pd.Series(atoms_idx, name='atom_index')
# Zeros are the padding left when an atom has fewer than num_pickup neighbours
# of an element; mark them as missing. Replacing on the feature frame itself
# (not a hard-coded 'dist_H_0':'dist_F_19' slice) keeps this correct if
# `atoms` or `num_pickup` change.
dist = pd.DataFrame(dist_mat, columns=col_name_list).replace(0, np.nan)
dist_result = pd.concat([se_mol, se_atom_idx, dist], axis=1)

In [None]:
# Persist the per-atom distance feature table.
pd.to_pickle(dist_result, 'dist_each_other.pkl')