In [1]:
"""
multiple functions taken from https://github.com/txie-93/cgcnn/blob/master/cgcnn/data.py
"""

from __future__ import print_function, division

import csv
import functools
import json
import os
import random
import warnings
import pickle
import numpy as np
import torch
from pymatgen.core.structure import Structure


class GaussianDistance(object):
    """
    Gaussian basis expansion of interatomic distances.

    Unit: angstrom
    """
    def __init__(self, dmin, dmax, step, var=None):
        """
        Parameters
        ----------

        dmin: float
          Minimum interatomic distance (first Gaussian centre)
        dmax: float
          Maximum interatomic distance
        step: float
          Spacing between consecutive Gaussian centres
        var: float, optional
          Width used in the exponent; falls back to ``step`` when None
        """
        assert dmin < dmax
        assert dmax - dmin > step
        # Centres are laid out from dmin in increments of step; the upper
        # bound is dmax+step so that dmax itself is normally included.
        centres = np.arange(dmin, dmax+step, step)
        self.filter = centres
        self.var = step if var is None else var

    def expand(self, distances):
        """
        Expand a numpy distance array on the Gaussian basis.

        Parameters
        ----------

        distances: np.array, n-d
          Distances of any shape

        Returns
        -------
        expanded_distance: np.array, (n+1)-d
          Same leading shape as ``distances`` with one extra trailing axis
          of length len(self.filter)
        """
        # Broadcast every distance against every centre via a new last axis.
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-offsets**2 / self.var**2)


class AtomInitializer(object):
    """
    Base class for initializing the vector representation for atoms.

    Holds a mapping ``atom_type -> embedding`` plus a lazily built reverse
    mapping used by ``decode``.

    !!! Use one AtomInitializer per dataset !!!
    """
    def __init__(self, atom_types):
        """
        Parameters
        ----------

        atom_types: iterable
            The atom types this initializer is allowed to serve
        """
        self.atom_types = set(atom_types)
        self._embedding = {}
        # Reverse lookup (embedding value -> atom type); built on first decode.
        self._decodedict = None

    def get_atom_fea(self, atom_type):
        """Return the embedding stored for ``atom_type``."""
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """
        Replace the embedding table with ``state_dict``.

        The reverse lookup is only invalidated here, not rebuilt: building it
        eagerly would raise TypeError for unhashable embedding values such as
        lists or numpy arrays, making it impossible to load such a table.
        """
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = None

    def state_dict(self):
        """Return the embedding table (the same object, not a copy)."""
        return self._embedding

    def decode(self, idx):
        """
        Return the atom type whose embedding value equals ``idx``.

        Requires hashable embedding values; the reverse table is built on the
        first call and reused until ``load_state_dict`` is called again.
        """
        # getattr guards instances created without running __init__.
        if getattr(self, '_decodedict', None) is None:
            self._decodedict = {value: atom_type for atom_type, value in
                                self._embedding.items()}
        return self._decodedict[idx]


class AtomCustomJSONInitializer(AtomInitializer):
    """
    Atom initializer backed by a JSON file. The file holds a dictionary
    whose keys are element numbers and whose values are lists giving the
    feature vector of each element.

    Parameters
    ----------

    elem_embedding_file: str
        The path to the .json file
    """
    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as handle:
            raw_table = json.load(handle)
        # JSON object keys are always strings; turn them back into ints.
        by_number = {int(number): vector
                     for number, vector in raw_table.items()}
        super(AtomCustomJSONInitializer, self).__init__(set(by_number.keys()))
        # Store each feature vector as a float numpy array.
        for number in by_number:
            self._embedding[number] = np.array(by_number[number], dtype=float)




In [2]:
def makedir_if_not_exist(dir_to_make):
    """
    Create ``dir_to_make`` (and any missing parents) unless it already exists.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is
    no window in which another process can create the directory between the
    check and the creation.

    Raises
    ------
    FileExistsError
        If the path exists but is not a directory.
    """
    os.makedirs(dir_to_make, exist_ok=True)


In [3]:
def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` and close the file handle afterwards."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _nearest_site_indices(crystal, center_idx, cutoff, max_nbrs):
    """Return the crystal site indices of the closest neighbors of one site."""
    # include_index=True mirrors the get_all_neighbors call in pickle_dset:
    # each entry is then indexable as (site, distance, site index, ...),
    # where the index refers to the site inside the unit cell even when the
    # neighbor itself is a periodic image.
    nbrs = crystal.get_neighbors(crystal[center_idx], cutoff,
                                 include_index=True)
    nbrs = sorted(nbrs, key=lambda x: x[1])
    return [int(nbr[2]) for nbr in nbrs[:max_nbrs]]


def pickle_dset(savedir, root_dir, max_num_nbr=6, radius=8, dmin=0, step=0.2,
                random_seed=123, o_nbr_dist=3, max_o_nbrs=6):
    """
    Read a CGCNN-style dataset and pickle the per-crystal graph inputs, so
    that they can be loaded elsewhere (e.g. on AWS) without pymatgen.

    inputs as in https://github.com/txie-93/
    root_dir
    ├── id_prop.csv
    ├── atom_init.json
    ├── id0.cif
    ├── id1.cif
    ├── ...

    id_prop.csv: a CSV file in which every row must have exactly four
    columns. The first column is a unique ID for each crystal, the second
    is the value of the target property, and the third is the integer index
    of the site around which the "relevant" neighbors are collected
    (presumably an oxygen atom, given the ``o_`` naming -- confirm against
    the data). The fourth column is read but ignored.

    atom_init.json: a JSON file that stores the initialization vector for each
    element.

    ID.cif: a CIF file that records the crystal structure, where ID is the
    unique ID for the crystal.

    For the crystal at (shuffled) position ``k`` the following files are
    written below ``savedir``:

    atom_feas/atom_fea_k.p              float tensor of per-atom features
    nbr_feas/nbr_fea_k.p                float tensor of Gaussian-expanded
                                        neighbor distances
    nbr_fea_indices/nbr_fea_idx_k.p     long tensor of neighbor site indices
    rel_indices/rel_index_k.p           list: the central site index followed
                                        by the site indices of its nearest
                                        neighbors
    cif_ids/cif_id_k.p                  the crystal ID string
    targets/target_k.p                  one-element float tensor

    Parameters
    ----------

    savedir: str
        The path to the directory we will create which will contain the
        datasets to be passed into AWS.
    root_dir: str
        The path to the root directory of the dataset with atom_init.json
        and id_prop.csv
    max_num_nbr: int
        The maximum number of neighbors while constructing the crystal graph
    radius: float
        The cutoff radius for searching neighbors
    dmin: float
        The minimum distance for constructing GaussianDistance
    step: float
        The step size for constructing GaussianDistance
    random_seed: int
        Random seed for shuffling the dataset (seeds the global ``random``
        module)
    o_nbr_dist: float
        The cutoff radius for collecting neighbors of the central site
    max_o_nbrs: int
        The maximum number of nearest neighbors of the central site that
        are recorded; fewer are recorded (without padding) when fewer lie
        within ``o_nbr_dist``

    Returns
    -------

    nothing, just pickles
    """
    assert os.path.exists(root_dir), 'root_dir does not exist!'

    makedir_if_not_exist(savedir)
    for subdir in ("atom_feas", "nbr_feas", "nbr_fea_indices",
                   "rel_indices", "cif_ids", "targets"):
        makedir_if_not_exist(os.path.join(savedir, subdir))

    assert os.path.exists(savedir), 'pickling directory savedir does not exist!'

    id_prop_file = os.path.join(root_dir, 'id_prop.csv')
    assert os.path.exists(id_prop_file), 'id_prop.csv does not exist!'
    with open(id_prop_file) as f:
        id_prop_data = list(csv.reader(f))

    random.seed(random_seed)
    random.shuffle(id_prop_data)
    atom_init_file = os.path.join(root_dir, 'atom_init.json')
    assert os.path.exists(atom_init_file), 'atom_init.json does not exist!'
    ari = AtomCustomJSONInitializer(atom_init_file)
    gdf = GaussianDistance(dmin=dmin, dmax=radius, step=step)

    # Everything below is computed once here with csv and pymatgen and then
    # pickled, so the consumer of savedir needs neither.
    for count, (cif_id, target, o_atom_idx, _) in enumerate(id_prop_data):
        if count % 1000 == 0:
            print(count)
        o_atom_idx = int(o_atom_idx)
        crystal = Structure.from_file(os.path.join(root_dir,
                                                   cif_id+'.cif'))

        # NOTE: the neighbors are recorded by their site index in the
        # crystal. Ranking positions inside the distance-sorted neighbor
        # list are NOT site indices and can exceed len(crystal).
        rel_inds = [o_atom_idx] + _nearest_site_indices(
            crystal, o_atom_idx, o_nbr_dist, max_o_nbrs)

        atom_fea = np.vstack([ari.get_atom_fea(site.specie.number)
                              for site in crystal])
        atom_fea = torch.Tensor(atom_fea)

        all_nbrs = crystal.get_all_neighbors(radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < max_num_nbr:
                warnings.warn('{} not find enough neighbors to build graph. '
                              'If it happens frequently, consider increase '
                              'radius.'.format(cif_id))
            kept = nbr[:max_num_nbr]
            n_pad = max_num_nbr - len(kept)
            # Missing neighbors are padded with site index 0 and a distance
            # just beyond the cutoff radius.
            nbr_fea_idx.append([x[2] for x in kept] + [0] * n_pad)
            nbr_fea.append([x[1] for x in kept] + [radius + 1.] * n_pad)
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = torch.Tensor(gdf.expand(nbr_fea))
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])

        outputs = (
            ("atom_feas", "atom_fea_{}.p", atom_fea),
            ("targets", "target_{}.p", target),
            ("cif_ids", "cif_id_{}.p", cif_id),
            ("nbr_feas", "nbr_fea_{}.p", nbr_fea),
            ("nbr_fea_indices", "nbr_fea_idx_{}.p", nbr_fea_idx),
            ("rel_indices", "rel_index_{}.p", rel_inds),
        )
        for subdir, pattern, obj in outputs:
            _dump_pickle(obj, os.path.join(savedir, subdir,
                                           pattern.format(count)))


        

Run the code as pickle_dset(pickle_dir, data_dir), where pickle_dir is the output directory for the pickles.
data_dir is a directory laid out like the cgcnn example data directory: it contains one <cif id>.cif file per crystal together with atom_init.json and id_prop.csv.