In [1]:
import json
import torch
import feather
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from itertools import product
from collections import defaultdict
from rdkit.Chem import AllChem as Chem
from tqdm import tqdm_notebook as tqdm

In [2]:
settings = json.load(open('SETTINGS.json'))['PREPARE']
DATA_PATH = Path(settings['INPUTS'])
OUTPUTS = Path(settings['OUTPUTS'])
MULLIKEN_TRAIN = Path(settings['MULLIKEN_TRAIN'])
MULLIKEN_TEST = Path(settings['MULLIKEN_TEST'])

## XYZ to SDF

In [3]:
chunksout = OUTPUTS / 'sdfchunks'
chunksout.mkdir(exist_ok=True, parents=True)

In [4]:
xyzs = DATA_PATH / 'xyz'

In [5]:
for i in tqdm(range(14)):
    ifmt = f'{i:02}'
    glob = f'{xyzs}/dsgdb9nsd_{ifmt}'
    !obabel -i xyz {glob}* -o sdf -O {chunksout}/sdfchunk-{ifmt}.sdf 2>/dev/null
!cat {chunksout}/* > {OUTPUTS / 'structures.sdf'}

HBox(children=(IntProgress(value=0, max=14), HTML(value='')))




## Fix OBabel SDF

### SDF fix utils

In [6]:
def mol_name(mol):
    return mol.GetProp('_Name')

def read_sdf(path):
    mols = []
    sup = Chem.SDMolSupplier(path, sanitize=False, removeHs=False)
    for mol in sup:
        if mol is None:
            continue
        mol.SetProp('_Name', mol.GetProp('_Name')[-20:-4])
        mols.append(mol)
    return mols

VALENCES = {'C': 4, 'N': 3, 'O': 2, 'H': 1, 'F': 1}

def calc_valence(atom):
    """ valence property returns wrong results on broken molecules """
    val = 0
    for b in atom.GetBonds():
        val += int(b.GetBondTypeAsDouble())
    return val

def calc_formal_charge(atom):
    elem = atom.GetSymbol()
    val = calc_valence(atom)
    return val - VALENCES[elem]

BOND_TYPES = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE]

def inc_bond_order(bt, inc=1):
    idx = BOND_TYPES.index(bt) + inc
    if not 0 <= idx < len(BOND_TYPES):
        raise ValueError(f"Can't increment {bt} with {inc}")
    return BOND_TYPES[idx]

def dec_bond_order(bt, dec=1):
    return inc_bond_order(bt, -dec)

def methane(coords):
    m = Chem.AddHs(Chem.MolFromSmiles('C'))
    c = Chem.Conformer()
    for ai in coords.atom_index:
        c.SetAtomPosition(ai, coords.loc[ai, ['x', 'y', 'z']].values)
    m.AddConformer(c)
    m.SetProp('_Name', 'dsgdb9nsd_000001')
    return m

### Fixer functions

In [7]:
def clear_radicals(rwmol):
    edits = []
    for a in rwmol.GetAtoms():
        if a.GetNumRadicalElectrons():
            edits.append({
                'atom_index_0': a.GetIdx(),
                'atom_0': a.GetSymbol()
            })
            a.SetNumRadicalElectrons(0)
            a.SetNoImplicit(True)
    return edits

def fix_inc_bond_order(rwmol):
    edits = []
    for bond in rwmol.GetBonds():
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        cmp1 = calc_formal_charge(a1)
        cmp2 = calc_formal_charge(a2)
        if cmp1 < 0 and cmp2 < 0:
            edits.append({
                'atom_index_0': a1.GetIdx(),
                'atom_0': a1.GetSymbol(),
                'atom_index_1': a2.GetIdx(),
                'atom_1': a2.GetSymbol(),
                'bond_index_0': bond.GetIdx()
            })
            bond.SetBondType(inc_bond_order(bond.GetBondType(), -max(cmp1, cmp2)))
    return edits

def fix_dec_bond_order(rwmol):
    edits = []
    for bond in rwmol.GetBonds():
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        cmp1 = calc_formal_charge(a1)
        cmp2 = calc_formal_charge(a2)
        if cmp1 > 0 and cmp2 > 0:
            edits.append({
                'atom_index_0': a1.GetIdx(),
                'atom_0': a1.GetSymbol(),
                'atom_index_1': a2.GetIdx(),
                'atom_1': a2.GetSymbol(),
                'bond_index_0': bond.GetIdx()
            })
            bond.SetBondType(dec_bond_order(bond.GetBondType(), min(cmp1, cmp2)))
    return edits

def fix_swap_bonds_d2(rwmol):
    edits = []
    for atom in rwmol.GetAtoms():
        if calc_formal_charge(atom):
            continue
        
        extra_valence = []
        missing_valence = []
        
        for bond in atom.GetBonds():
            other = bond.GetOtherAtom(atom)
            if bond.GetBondType() != Chem.BondType.SINGLE and calc_formal_charge(other) > 0:
                extra_valence.append(other)
            if bond.GetBondType() != Chem.BondType.TRIPLE and calc_formal_charge(other) < 0 and other.GetFormalCharge() == 0:
                missing_valence.append(other)
        
        if len(extra_valence) == 1 and len(missing_valence) == 1:
            edit = {
                'atom_index_0': extra_valence[0].GetIdx(),
                'atom_0': extra_valence[0].GetSymbol(),
                'atom_index_1': missing_valence[0].GetIdx(),
                'atom_1': missing_valence[0].GetSymbol()
            }
            b = rwmol.GetBondBetweenAtoms(atom.GetIdx(), extra_valence[0].GetIdx())
            b.SetBondType(dec_bond_order(b.GetBondType()))
            edit['bond_index_0'] = b.GetIdx()
            b = rwmol.GetBondBetweenAtoms(atom.GetIdx(), missing_valence[0].GetIdx())
            b.SetBondType(inc_bond_order(b.GetBondType()))
            edit['bond_index_1'] = b.GetIdx()
            edits.append(edit)

    return edits

def find_path(path_atoms, path_bonds, start, end):
    if start.GetIdx() == end.GetIdx():
        return path_bonds
    bonds = start.GetBonds()
    prev_bond = path_bonds[-1]
    for b in bonds:
        if b.GetBondType() != prev_bond.GetBondType():
            other = b.GetOtherAtom(start)
            if other not in path_atoms:
                path = find_path(path_atoms + [start], path_bonds + [b], other, end)
                if path:
                    return path
    
    return None
    

def fix_swap_bonds_path(rwmol):
    edits = []
    extra_valence = None
    missing_valence = None
    for a in rwmol.GetAtoms():
        if calc_formal_charge(a) > 0:
            if extra_valence is None:
                extra_valence = a
            else:
                return []
        if calc_formal_charge(a) < 0 and a.GetFormalCharge() == 0:
            if missing_valence is None:
                missing_valence = a
            else:
                return []
    
    path = None
    if extra_valence is not None and missing_valence is not None:
        for b in extra_valence.GetBonds():
            if b.GetBondType() == Chem.BondType.DOUBLE:
                path = find_path([extra_valence], [b], b.GetOtherAtom(extra_valence), missing_valence)
                if path:
                    break
    if path:
        edit = {
            'atom_index_0': extra_valence.GetIdx(),
            'atom_0': extra_valence.GetSymbol(),
            'atom_index_1': missing_valence.GetIdx(),
            'atom_1': missing_valence.GetSymbol()
        }
        for i, b in enumerate(path):
            b.SetBondType(Chem.BondType.SINGLE if b.GetBondType() == Chem.BondType.DOUBLE else Chem.BondType.DOUBLE)
            edit[f'bond_index_{i}'] = b.GetIdx()
        edits.append(edit)
    
    return edits

def fix_multi_plus(rwmol):
    edits = []
    extra_valence = []
    for a in rwmol.GetAtoms():
        if calc_formal_charge(a) > 0:
            extra_valence.append(a)
    
    if len(extra_valence) == 2:
        a1, a2 = extra_valence
        for b1, b2 in product(a1.GetBonds(), a2.GetBonds()):
            if b1.GetBondType() == Chem.BondType.DOUBLE and b2.GetBondType() == Chem.BondType.DOUBLE:
                o1 = b1.GetOtherAtom(a1)
                o2 = b2.GetOtherAtom(a2)
                b12 = rwmol.GetBondBetweenAtoms(o1.GetIdx(), o2.GetIdx())
                if b12 is not None and b12.GetBondType() == Chem.BondType.SINGLE:
                    edits.append({
                        'atom_index_0': a1.GetIdx(),
                        'atom_0': a1.GetSymbol(),
                        'atom_index_1': a2.GetIdx(),
                        'atom_1': a2.GetSymbol(),
                        'bond_index_0': b1.GetIdx(),
                        'bond_index_1': b12.GetIdx(),
                        'bond_index_2': b2.GetIdx()
                        
                    })
                    b1.SetBondType(Chem.BondType.SINGLE)
                    b2.SetBondType(Chem.BondType.SINGLE)
                    b12.SetBondType(Chem.BondType.DOUBLE)
                    break
    
    return edits
    
def fix_multi_minus(rwmol):
    edits = []
    missing_valence = []
    for a in rwmol.GetAtoms():
        if calc_formal_charge(a) < 0:
            missing_valence.append(a)
    
    if len(missing_valence) == 2:
        a1, a2 = missing_valence
        for b1, b2 in product(a1.GetBonds(), a2.GetBonds()):
            if b1.GetBondType() == Chem.BondType.SINGLE and b2.GetBondType() == Chem.BondType.SINGLE:
                o1 = b1.GetOtherAtom(a1)
                o2 = b2.GetOtherAtom(a2)
                b12 = rwmol.GetBondBetweenAtoms(o1.GetIdx(), o2.GetIdx())
                if b12 is not None and b12.GetBondType() == Chem.BondType.DOUBLE:
                    edits.append({
                        'atom_index_0': a1.GetIdx(),
                        'atom_0': a1.GetSymbol(),
                        'atom_index_1': a2.GetIdx(),
                        'atom_1': a2.GetSymbol(),
                        'bond_index_0': b1.GetIdx(),
                        'bond_index_1': b12.GetIdx(),
                        'bond_index_2': b2.GetIdx()
                    })
                    b1.SetBondType(Chem.BondType.DOUBLE)
                    b2.SetBondType(Chem.BondType.DOUBLE)
                    b12.SetBondType(Chem.BondType.SINGLE)
                    break
    
    return edits

def fix_c5_to_n(rwmol):
    edits = []
    plus = []
    for a in rwmol.GetAtoms():
        if a.GetSymbol() == 'N' and calc_formal_charge(a) > 0:
            plus.append(a)
    
    if len(plus) == 0:
        return []
    
    c_to_n = []
    for a in rwmol.GetAtoms():
        if a.GetSymbol() == 'C':
            if calc_valence(a) == 5:
                double_n = [(b, b.GetOtherAtom(a)) for b in a.GetBonds() 
                            if b.GetBondType() == Chem.BondType.DOUBLE 
                            and b.GetOtherAtom(a).GetSymbol() == 'N'
                            and calc_formal_charge(b.GetOtherAtom(a)) == 0]
                if len(double_n) == 1:
                    b, n = double_n[0]
                    c_to_n.append((a, n, b))
    
    if len(c_to_n) != len(plus):
        return []
    
    for c, n, b in c_to_n:
        edits.append({
            'atom_index_0': c.GetIdx(),
            'atom_0': c.GetSymbol(),
            'atom_index_1': n.GetIdx(),
            'atom_1': n.GetSymbol(),
            'bond_index_0': b.GetIdx()
        })
        b.SetBondType(Chem.BondType.SINGLE)
    return edits

def fix_c3_to_n(rwmol):
    edits = []
    minus = []
    for a in rwmol.GetAtoms():
        if a.GetSymbol() == 'N' and calc_formal_charge(a) < 0:
            minus.append(a)
    
    if len(minus) == 0:
        return []
    
    c_to_n = []
    for a in rwmol.GetAtoms():
        if a.GetSymbol() == 'C':
            if calc_valence(a) == 3:
                double_n = [(b, b.GetOtherAtom(a)) for b in a.GetBonds() 
                            if b.GetBondType() == Chem.BondType.SINGLE 
                            and b.GetOtherAtom(a).GetSymbol() == 'N'
                            and calc_formal_charge(b.GetOtherAtom(a)) == 0]
                if len(double_n) == 1:
                    b, n = double_n[0]
                    c_to_n.append((a, n, b))
    
    if len(c_to_n) != len(minus):
        return []
    
    for c, n, b in c_to_n:
        edits.append({
            'atom_index_0': c.GetIdx(),
            'atom_0': c.GetSymbol(),
            'atom_index_1': n.GetIdx(),
            'atom_1': n.GetSymbol(),
            'bond_index_0': b.GetIdx()
        })
        b.SetBondType(Chem.BondType.DOUBLE)
    return edits

def fix_balance_charge(rwmol):
    edits = []
    minus = []
    plus = []
    
    for a in rwmol.GetAtoms():
        cmp = calc_formal_charge(a)
        if cmp < 0:
            minus.append(a)
            
        if cmp > 0:
            plus.append(a)
    
    if len(plus) != 1 and len(minus) != 1:
        return []
    
    edits.append({
        'atom_index_0': minus[0].GetIdx(),
        'atom_0': minus[0].GetSymbol(),
        'atom_index_1': plus[0].GetIdx(),
        'atom_1': plus[0].GetSymbol()
    })
    minus[0].SetFormalCharge(-1)
    plus[0].SetFormalCharge(1)    
    
    return edits

NITRO_PATTERN = Chem.MolFromSmarts("N(=O)~O")
def fix_n_plus_n_ring(rwmol):
    if rwmol.HasSubstructMatch(NITRO_PATTERN):
        return []
    non_ring_n_plus = []
    ring_plus = []
    ring_n = []
    for a in rwmol.GetAtoms():
        if a.GetSymbol() == 'N' and calc_formal_charge(a) == 1 and not a.IsInRing():
            non_ring_n_plus.append(a)
        if a.IsInRing() and calc_formal_charge(a) == 1:
            ring_plus.append(a)
        if a.GetSymbol() == 'N' and a.IsInRing() and len([b for b in a.GetBonds() if b.GetBondType() == Chem.BondType.DOUBLE]) == 1:
            ring_n.append(a)
    if len(non_ring_n_plus) != 1 or len(ring_plus) != 1 or len(ring_n) < 1:
        return []
    
    positions = rwmol.GetConformer(0).GetPositions()
    nplus_pos = positions[non_ring_n_plus[0].GetIdx()]
    nring_pos = np.asarray([positions[a.GetIdx()] for a in ring_n])
    dists = ((nring_pos - nplus_pos) ** 2).sum(axis=1)
    other = ring_n[np.argmin(dists)]
    non_ring_n_plus[0].SetFormalCharge(1)
    other.SetFormalCharge(-1)
    bond = [b for b in other.GetBonds() if b.GetBondType() == Chem.BondType.DOUBLE][0]
    bond.SetBondType(Chem.BondType.SINGLE)
    return [{
        'atom_index_0': non_ring_n_plus[0].GetIdx(),
        'atom_0': non_ring_n_plus[0].GetSymbol(),
        'atom_index_1': other.GetIdx(),
        'atom_1': other.GetSymbol(),
        'bond_index_0': bond.GetIdx()
    }]

CYANO_PATT = Chem.MolFromSmarts("N#CC")
def fix_tripple_n_to_2_plus(rwmol):
    if not rwmol.HasSubstructMatch(CYANO_PATT):
        return []
    plus = []
    for a in rwmol.GetAtoms():
        if a.GetSymbol() == 'N' and calc_formal_charge(a) == 1:
            plus.append(a)
    if len(plus) != 2:
        return []

    for a in plus:
        a.SetFormalCharge(1)

    return [{
        'atom_index_0': plus[0].GetIdx(),
        'atom_0': plus[0].GetSymbol(),
        'atom_index_1': plus[1].GetIdx(),
        'atom_1': plus[1].GetSymbol()
    }]

### Fix SDF

In [8]:
def fix_sdf():
    mols = read_sdf(str(OUTPUTS / 'structures.sdf'))
    rwmols = [Chem.RWMol(mol) for mol in mols]
    edits = []
    for m in tqdm(rwmols):
        try:
            Chem.SanitizeMol(m)
        except Exception as e:
            for f in (clear_radicals, fix_inc_bond_order, fix_dec_bond_order,
                      fix_n_plus_n_ring, fix_swap_bonds_d2,
                      fix_multi_plus, fix_multi_minus, fix_swap_bonds_path,
                      fix_c3_to_n, fix_c5_to_n, fix_balance_charge, fix_tripple_n_to_2_plus):
                eds = f(m)
                for edit in eds:
                    edit['molecule_name'] = mol_name(m)
                    edit['edit'] = f.__name__
                edits.extend(eds)
    
    edits_df = pd.DataFrame(edits)
    edits_df.to_csv(OUTPUTS / 'edits.csv')
    
    fixed_mols = [m.GetMol() for m in rwmols]
    
#     structuresdf = pd.read_csv(DATA_PATH / 'structures.csv')
#     coords = structuresdf[structuresdf.molecule_name == 'dsgdb9nsd_000001']
#     fixed_mols.append(methane(coords))
    
    w = Chem.SDWriter(str(OUTPUTS / 'fixed_structures.sdf'))
    for mol in tqdm(fixed_mols):
        w.write(mol)

fix_sdf()

HBox(children=(IntProgress(value=0, max=130775), HTML(value='')))




HBox(children=(IntProgress(value=0, max=130775), HTML(value='')))




## Angles and distances

In [9]:
def relative_coords():
    sup = Chem.SDMolSupplier(str(OUTPUTS / 'fixed_structures.sdf'), removeHs=False)
    
    # This only occasionaly works ???
    # mols = list(iter(sup))
    mols = []
    for m in sup:
        mols.append(m)
    
    print(len(mols))
    
    structuresdf = pd.read_csv(DATA_PATH / 'structures.csv')
    
    # Add methane, that for some reason wasn't loaded from sdf
    #coords = structuresdf[structuresdf.molecule_name == 'dsgdb9nsd_000001']
    #mols.append(methane(coords))
    
    mm = {mol_name(mol): mol for mol in mols}
    conformers = defaultdict(Chem.Conformer)
    structnp = structuresdf.values

    for row in structnp:
        c = conformers[row[0]]
        c.SetAtomPosition(int(row[1]), row[3:])
    
    structures_relative = []

    for mol in tqdm(mols):
        name = mol_name(mol)
        dm = Chem.GetDistanceMatrix(mol)
        c = conformers[name]
        for i in range(mol.GetNumAtoms()):
            for j in range(i):
                bdist = dm[i, j]
                if (bdist > 3):
                    continue
                a_i = mol.GetAtomWithIdx(i)
                a_j = mol.GetAtomWithIdx(i)
                dist_i_j = (np.asarray(c.GetAtomPosition(i) - c.GetAtomPosition(j)) ** 2).sum() ** 0.5
                angle = 0
                dihedral = 0

                path = Chem.GetShortestPath(mol, i, j)
                if bdist == 2:
                    angle = Chem.GetAngleRad(c, *path)
                if bdist == 3:
                    dihedral = Chem.GetDihedralRad(c, *path)

                bond = mol.GetBondBetweenAtoms(i, j)
                bond_type = str(bond.GetBondType()) if bond is not None else ''

                structures_relative.append(dict(
                    molecule_name=name,
                    atom_index_0=i,
                    atom_index_1=j,
                    distance=dist_i_j,
                    graph_distance=bdist,
                    angle=angle,
                    dihedral=dihedral,
                    bond_type=bond_type
                ))
    structures_relative_df = pd.DataFrame(structures_relative)
    structures_relative_df.angle[structures_relative_df.graph_distance != 2] = None
    structures_relative_df.dihedral[structures_relative_df.graph_distance != 3] = None
    structures_relative_df.to_feather(OUTPUTS / 'structures_relative.feather')

relative_coords()

130775


HBox(children=(IntProgress(value=0, max=130775), HTML(value='')))




A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


## Edge and Node properties

### Molecules

In [10]:
def mol_props():
    structs = feather.read_dataframe(OUTPUTS / 'structures_relative.feather')
    train = pd.read_csv(DATA_PATH / '/train.csv')
    test = pd.read_csv(DATA_PATH / '/test.csv')
    
    mols = np.sort(structs.molecule_name.unique())
    mol_df = pd.DataFrame(data={'molecule_name': mols, 'molecule_id': range(len(mols))})
    mol_ids = dict(zip(mol_df.molecule_name, mol_df.molecule_id))
    
    train_mols = set(train.molecule_name.map(mol_ids.__getitem__).unique())
    test_mols = set(test.molecule_name.map(mol_ids.__getitem__).unique())

    mol_df['is_train'] = mol_df.molecule_id.map(train_mols.__contains__)
    mol_df['is_test'] = mol_df.molecule_id.map(test_mols.__contains__)
    
    mol_df['molecule_name'] = mol_df.molecule_name.astype('category')
    mol_df.to_feather(OUTPUTS / 'mols.feather')
    return mol_df

### Nodes

In [11]:
atom_types = dict(zip('HCNOF', range(5)))
atom_nums = dict(zip('HCNOF', [1, 6, 7, 8, 9]))
atom_eneg = dict(zip('HCNOF', [2.20, 2.55, 3.04, 3.44, 3.98]))
atom_eaff = dict(zip('HCNOF', [72.8, 153.9, 7, 141, 328]))
atom_val = dict(zip('HCNOF', [1, 4, 5, 2, 1]))
atom_1st_ion = dict(zip('HCNOF', [1312.0, 1086.5, 1402.3, 1313.9, 1681.0]))

In [12]:
def node_props(mols=None):
    structs = pd.read_csv(DATA_PATH / '/structures.csv')
    if mols is None:
        mols = feather.read_dataframe(OUTPUTS / 'mols.feather')
    mol_ids = dict(zip(mols.molecule_name, mols.molecule_id))
    
    structs['molecule_id'] = structs.molecule_name.map(mol_ids.__getitem__)
    structs['atom_type'] = structs.atom.map(atom_types.__getitem__)
    structs['atom_num'] = structs.atom.map(atom_nums.__getitem__)
    structs['atom_eneg'] = structs.atom.map(atom_eneg.__getitem__)
    structs['atom_eaff'] = structs.atom.map(atom_eaff.__getitem__)
    structs['atom_val'] = structs.atom.map(atom_val.__getitem__)
    structs['atom_1st_ion'] = structs.atom.map(atom_1st_ion.__getitem__)
    
    structs['molecule_name'] = structs.molecule_name.astype('category')
    structs['atom'] = structs.atom.astype('category')
    
    structs = structs.merge(mols, on=['molecule_id'], suffixes=('', '_y'))
    del structs['molecule_name_y']
    
    nodes = structs[[
    'molecule_name',
    'molecule_id',
    'atom_index',
    'atom',
    'atom_type',
    'atom_num',
    'atom_eneg',
    'atom_eaff',
    'atom_val',
    'atom_1st_ion',
    'x', 'y', 'z',
    'is_train',
    'is_test']]
    
    nodes.to_feather(OUTPUTS / 'nodes.feather')

### Edges

In [13]:
edge_type_ids = dict(zip(['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC', '2JUMP', '3JUMP'], range(6)))

def get_edge_type(t):
    bond, gdist = t
    return {
        1.: bond,
        2.: '2JUMP',
        3.: '3JUMP',
    }[gdist]

def edge_props(mols=None):
    if mols is None:
        mols = feather.read_dataframe(OUTPUTS / 'mols.feather')
    mol_ids = dict(zip(mols.molecule_name, mols.molecule_id))
    edges = feather.read_dataframe('data/structures_relative.feather')
    
    edges['edge_type'] = edges[['bond_type', 'graph_distance']].apply(get_edge_type, axis=1)
    edges['edge_type_id'] = edges.edge_type.map(edge_type_ids.__getitem__)
    edges['molecule_id'] = edges.molecule_name.map(mol_ids.__getitem__)
    
    edges = edges[[
        'molecule_name',
        'molecule_id',
        'atom_index_0',
        'atom_index_1',
        'edge_type',
        'edge_type_id',
        'graph_distance',
        'distance',
        'angle',
        'dihedral']]
    
    edges['graph_distance'] = edges.graph_distance.map(int)
    
    train = pd.read_csv(DATA_PATH / '/train.csv')
    test = pd.read_csv(DATA_PATH / '/test.csv')
    
    train['molecule_id'] = train.molecule_name.map(mol_ids.__getitem__)
    test['molecule_id'] = test.molecule_name.map(mol_ids.__getitem__)
    
    atom_index_0 = train[['atom_index_0', 'atom_index_1']].max(axis=1)
    atom_index_1 = train[['atom_index_0', 'atom_index_1']].min(axis=1)
    train['atom_index_0'] = atom_index_0
    train['atom_index_1'] = atom_index_1

    atom_index_0 = test[['atom_index_0', 'atom_index_1']].max(axis=1)
    atom_index_1 = test[['atom_index_0', 'atom_index_1']].min(axis=1)
    test['atom_index_0'] = atom_index_0
    test['atom_index_1'] = atom_index_1
    
    edges = edges.merge(train, how='outer', on=['molecule_id', 'atom_index_0', 'atom_index_1'], suffixes=('', '_y'))
    edges.rename(columns={'id': 'train_id', 'type': 'coupling_type'}, inplace=True)
    edges['is_train'] = ~edges.scalar_coupling_constant.isnull()

    del edges['molecule_name_y']
    
    edges = edges.merge(test, how='outer', on=['molecule_id', 'atom_index_0', 'atom_index_1'], suffixes=('', '_y'))
    edges.rename(columns={'id': 'test_id', 'type': 'test_coupling_type'}, inplace=True)
    
    edges.loc[~edges.test_coupling_type.isnull(), 'coupling_type'] = edges.test_coupling_type
    edges['is_test'] = ~edges.test_id.isnull()

    del edges['test_coupling_type']
    
    del edges['molecule_name_y']
    
    edges = edges.merge(mols, on=['molecule_id'], suffixes=('', '_y'))
    del edges['molecule_name_y']
    
    edges.rename(columns={'is_train_y': 'is_mol_train', 'is_test_y': 'is_mol_test'}, inplace=True)
    
    coupling_type_id = dict(zip(np.sort(edges.coupling_type.dropna().unique()), range(8)))
    coupling_type_id[np.NaN] = -1

    edges['coupling_type_id'] = edges.coupling_type.map(coupling_type_id.__getitem__)
    
    edges = edges[[
        'molecule_name',
        'molecule_id',
        'atom_index_0',
        'atom_index_1',
        'edge_type',
        'edge_type_id',
        'graph_distance',
        'distance',
        'angle',
        'dihedral',
        'is_train',
        'is_test',
        'is_mol_train',
        'is_mol_test',
        'train_id',
        'test_id',
        'coupling_type',
        'coupling_type_id',
        'scalar_coupling_constant'
    ]]
    
    edges['molecule_name'] = edges.molecule_name.astype('category')
    edges['edge_type'] = edges.edge_type.astype('category')
    edges['coupling_type'] = edges.coupling_type.astype('category')
    
    edges.to_feather(OUTPUTS / 'edges.feather')
    

In [14]:
mol_props()

Unnamed: 0,molecule_name,molecule_id,is_train,is_test
0,dsgdb9nsd_000001,0,True,False
1,dsgdb9nsd_000002,1,True,False
2,dsgdb9nsd_000003,2,True,False
3,dsgdb9nsd_000004,3,False,True
4,dsgdb9nsd_000005,4,True,False
...,...,...,...,...
130770,dsgdb9nsd_133881,130770,True,False
130771,dsgdb9nsd_133882,130771,True,False
130772,dsgdb9nsd_133883,130772,False,True
130773,dsgdb9nsd_133884,130773,True,False


In [15]:
node_props()

In [16]:
edge_props()

## Make Graphs

In [17]:
valence_enc = {1: 10, 2: 11, 4: 12, 5: 13}

def unique_normalizer(series):
    uq = series.unique()
    return (series - np.mean(uq)) / np.std(uq)

def classic_normalizer(series):
    return (series - series.mean()) / series.std()

def create_node_dict(mol):
    return {
        'nodes': len(mol),
        'ndata': {
            'type': torch.as_tensor(mol.atom_type.values, dtype=torch.int64),
            'valence': torch.as_tensor(mol.atom_val_enc.values, dtype=torch.int64),
            'el_neg': torch.as_tensor(mol.atom_eneg_n.values, dtype=torch.float32),
            'el_aff': torch.as_tensor(mol.atom_eaff_n.values, dtype=torch.float32),
            '1st_ion': torch.as_tensor(mol.atom_1st_ion_n.values, dtype=torch.float32),
#             'xyz': torch.as_tensor(mol[['x', 'y', 'z']].values, dtype=torch.float32),
#             'xyz_n': torch.as_tensor(mol[['x_n', 'y_n', 'z_n']].values, dtype=torch.float32)
        }
    }

def create_edge_dict(mol):
    return {
        'src': torch.as_tensor(mol.atom_index_0.values, dtype=torch.int64),
        'dst': torch.as_tensor(mol.atom_index_1.values, dtype=torch.int64),
        'edata': {
            'type': torch.as_tensor(mol.edge_type_id.values, dtype=torch.int64),
            'distance': torch.as_tensor(mol.distance_n.values, dtype=torch.float32),
            'angle': torch.as_tensor(mol.angle_n.values, dtype=torch.float32),
            'dihedral': torch.as_tensor(mol.dihedral_n.values, dtype=torch.float32),
            'coupling_type': torch.as_tensor(mol.coupling_type_id.values, dtype=torch.int64),
            'coupling': torch.as_tensor(mol.scalar_coupling_constant.values, dtype=torch.float32),
#             'fc': torch.as_tensor(mol.fc.values, dtype=torch.float64),
#             'sd': torch.as_tensor(mol.sd.values, dtype=torch.float64),
#             'pso': torch.as_tensor(mol.pso.values, dtype=torch.float64),
#             'dso': torch.as_tensor(mol.dso.values, dtype=torch.float64),
#             'contrib_sum': torch.as_tensor(mol.contrib_sum.values, dtype=torch.float64),
            'train_id': torch.as_tensor(mol.train_id.values, dtype=torch.int64),
            'test_id': torch.as_tensor(mol.test_id.values, dtype=torch.int64),
        }
    }

def make_graphs(nodes, edges, mulliken=None):
    nodes['atom_val_enc'] = nodes.atom_val.map(valence_enc.__getitem__)
    
    normalize_cols = ['atom_eneg', 'atom_eaff', 'atom_1st_ion']
    for col in normalize_cols:
        nodes[col + '_n'] = unique_normalizer(nodes[col])

#     normalize_cols = ['x', 'y', 'z']
#     xyz = nodes[['x', 'y', 'z']].values.ravel()
#     xyz_mean = xyz.mean()
#     xyz_std = xyz.std()
#     for col in normalize_cols:
#         nodes[col + '_n'] = (nodes[col] - xyz_mean) / xyz_std
    
    ndata = nodes.sort_values(['molecule_id', 'atom_index']).groupby(by='molecule_id').apply(create_node_dict)
    if mulliken:
        assert len(ndata) == len(mulliken)
        for n, m in zip(ndata, mulliken):
            assert len(m) == n['nodes']
            n['ndata']['mulliken'] = m
    
    normalize_cols = ['distance', 'angle', 'dihedral']
    for col in normalize_cols:
        edges[col + '_n'] = classic_normalizer(edges[col])
    
    fill_cols = ['angle_n', 'dihedral_n']
    for col in fill_cols:
        edges[col].fillna(0., inplace=True)
    
    edata = edges.sort_values(
    ['molecule_id', 'atom_index_0', 'atom_index_1']).groupby(by='molecule_id').apply(create_edge_dict)
    
    gdata = [dict(nd, **ed) for nd, ed in zip(ndata, edata)]
    return gdata

def make_gdata():
    nodes = feather.read_dataframe('data/nodes.feather')
    edges = feather.read_dataframe('data/edges.feather')
    
    edges['train_id'] = edges.train_id.fillna(-1).astype(np.int64)
    edges['test_id'] = edges.test_id.fillna(-1).astype(np.int64)
    
    mulliken_train = torch.load(MULLIKEN_TRAIN)
    mulliken_test = torch.load(MULLIKEN_TEST)
    
    torch.save(
        make_graphs(nodes[nodes.is_train], edges[edges.is_mol_train], mulliken_train),
        OUTPUTS / 'train_gdata.torch')
    torch.save(
        make_graphs(nodes[nodes.is_test], edges[edges.is_mol_test], mulliken_test),
        OUTPUTS / 'test_gdata.torch')

In [18]:
make_gdata()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._u