In [171]:
from functools import lru_cache
import pandas as pd
import uuid
import os
import requests
from tqdm import tqdm_notebook as tqdm
from pymatgen.core import Structure

from ase.ga.slab_operators import (CutSpliceSlabCrossover,
                                   RandomSlabPermutation,
                                   RandomCompositionMutation)
from ase.ga.offspring_creator import OperationSelector
from ase import io
from ase import Atoms

import argparse
import os
import shutil
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

from cgcnn.cgcnn.data import CIFData
from cgcnn.cgcnn.data import collate_pool
from cgcnn.cgcnn.model import CrystalGraphConvNet

import csv
import functools
import json
import os
import random
import uuid
import warnings
warnings.simplefilter("ignore")
import numpy as np
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler

In [274]:
oclist = [(5, CutSpliceSlabCrossover()),
          (2, RandomSlabPermutation()),
          #(1, RandomCompositionMutation())
          ]
operation_selector = OperationSelector(*zip(*oclist))
def mutate(a,b):
    a.info['confid'] = a.symbols
    b.info['confid'] = b.symbols
    op = operation_selector.get_operator()
    if type(op).__name__ == "CutSpliceSlabCrossover":
        if a.get_positions().shape[0] > b.get_positions().shape[0]:
            del a[b.get_positions().shape[0]:]
        else:
            del b[a.get_positions().shape[0]:]
    return op.get_new_individual((a,b))
mutate(f,m)

(Atoms(symbols='H2O3Se158Sn71', pbc=True, cell=[[26.902, 0.0, 0.0], [-13.450999999999995, 23.297815412608973, 0.0], [0.0, 0.0, 30.673]], spacegroup_kinds=...),
 'CutSpliceSlabCrossover: Parents Sn4Se8Sn2Se6SnSeSn4Se6SnSeSnSe2Sn2Se6SnSe3Sn5Se9SnSe5Sn3Se6SnSe5Sn3Se6SnSe2Sn3Se10SnSe3Sn3Se7SnSe2Sn2Se6Sn4Se8SnSe4Sn4Se9Sn3Se4Sn5Se7Sn2Se6SnSeSnSe3Sn3Se10Sn2SeSn2Se6SnSe3SnSe2SnSe3SnSe CoGeSnV3SeOSe3OSn2Se3O2SeSnSeSn4Se6SnSe4Sn2Se6SnSe3Sn5Se8Sn2Se5Sn3Se6SnSe5Sn3Se6SnSe2Sn3Se9Sn2Se3Sn3Se7SnSe2Sn2Se6Sn4Se8SnSe4SnHSnNH2NHSeNSe3Sn3Se5Sn4Se7Sn2Se6SnSeSnSe3SnHOSe5OSe4Sn2SeSn2Se6SnSe3SnSe2SnSe3SnSe')

In [3]:
class Normalizer(object):
    """Normalize a Tensor and restore it later. """
    def __init__(self, tensor):
        """tensor is taken as a sample to calculate the mean and std"""
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean': self.mean,
                'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']
class AtomInitializer(object):
    """
    Base class for intializing the vector representation for atoms.

    !!! Use one AtomInitializer per dataset !!!
    """
    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}

    def get_atom_fea(self, atom_type):
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = {idx: atom_type for atom_type, idx in
                            self._embedding.items()}

    def state_dict(self):
        return self._embedding

    def decode(self, idx):
        if not hasattr(self, '_decodedict'):
            self._decodedict = {idx: atom_type for atom_type, idx in
                                self._embedding.items()}
        return self._decodedict[idx]
class AtomCustomJSONInitializer(AtomInitializer):
    """
    Initialize atom feature vectors using a JSON file, which is a python
    dictionary mapping from element number to a list representing the
    feature vector of the element.

    Parameters
    ----------

    elem_embedding_file: str
        The path to the .json file
    """
    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as f:
            elem_embedding = json.load(f)
        elem_embedding = {int(key): value for key, value
                          in elem_embedding.items()}
        atom_types = set(elem_embedding.keys())
        super(AtomCustomJSONInitializer, self).__init__(atom_types)
        for key, value in elem_embedding.items():
            self._embedding[key] = np.array(value, dtype=float)
class GaussianDistance(object):
    """
    Expands the distance by Gaussian basis.

    Unit: angstrom
    """
    def __init__(self, dmin, dmax, step, var=None):
        """
        Parameters
        ----------

        dmin: float
          Minimum interatomic distance
        dmax: float
          Maximum interatomic distance
        step: float
          Step size for the Gaussian filter
        """
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax+step, step)
        if var is None:
            var = step
        self.var = var

    def expand(self, distances):
        """
        Apply Gaussian disntance filter to a numpy distance array

        Parameters
        ----------

        distance: np.array shape n-d array
          A distance matrix of any shape

        Returns
        -------
        expanded_distance: shape (n+1)-d array
          Expanded distance matrix with the last dimension of length
          len(self.filter)
        """
        return np.exp(-(distances[..., np.newaxis] - self.filter)**2 /
                      self.var**2)

atom_init_file = "atom_init.json"
ari = AtomCustomJSONInitializer(atom_init_file)
gdf = GaussianDistance(dmin=0, dmax=8, step=0.2)
def toTensor(crystal):
    max_num_nbr = 12
    radius = 8
    atom_fea = np.vstack([ari.get_atom_fea(crystal[i].specie.number)
                          for i in range(len(crystal))])
    atom_fea = torch.Tensor(atom_fea)
    all_nbrs = crystal.get_all_neighbors(radius, include_index=True)
    all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
    nbr_fea_idx, nbr_fea = [], []
    for nbr in all_nbrs:
        if len(nbr) < max_num_nbr:
            warnings.warn('{} not find enough neighbors to build graph. '
                          'If it happens frequently, consider increase '
                          'radius.'.format(cif_id))
            nbr_fea_idx.append(list(map(lambda x: x[2], nbr)) +
                               [0] * (max_num_nbr - len(nbr)))
            nbr_fea.append(list(map(lambda x: x[1], nbr)) +
                           [radius + 1.] * (max_num_nbr -
                                                 len(nbr)))
        else:
            nbr_fea_idx.append(list(map(lambda x: x[2],
                                        nbr[:max_num_nbr])))
            nbr_fea.append(list(map(lambda x: x[1],
                                    nbr[:max_num_nbr])))
    nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
    nbr_fea = gdf.expand(nbr_fea)
    atom_fea = torch.Tensor(atom_fea)
    nbr_fea = torch.Tensor(nbr_fea)
    nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
    return (atom_fea, nbr_fea, nbr_fea_idx), torch.Tensor([float(0)]), 0

In [230]:
structures,_,_ = toTensor(Structure.from_file("tes.cif"))
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

fara_checkpoint = torch.load("trained/faradaic/model_best.pth.tar",
    map_location=lambda storage, loc: storage)
fara_args = argparse.Namespace(**fara_checkpoint['args'])

fara_normalizer = Normalizer(torch.zeros(3))
fara_normalizer.load_state_dict(fara_checkpoint['normalizer'])

fara_model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=fara_args.atom_fea_len,
        n_conv=fara_args.n_conv,
        h_fea_len=fara_args.h_fea_len,
        n_h=fara_args.n_h,
        classification=False)

volt_checkpoint = torch.load("trained/volts/model_best.pth.tar",
    map_location=lambda storage, loc: storage)
volt_args = argparse.Namespace(**volt_checkpoint['args'])

volt_normalizer = Normalizer(torch.zeros(3))
volt_normalizer.load_state_dict(volt_checkpoint['normalizer'])

volt_model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=fara_args.atom_fea_len,
        n_conv=volt_args.n_conv,
        h_fea_len=volt_args.h_fea_len,
        n_h=volt_args.n_h,
        classification=False)

energy_checkpoint = torch.load("trained/free_energy/model_best.pth.tar",
    map_location=lambda storage, loc: storage)
energy_args = argparse.Namespace(**energy_checkpoint['args'])

energy_normalizer = Normalizer(torch.zeros(3))
energy_normalizer.load_state_dict(energy_checkpoint['normalizer'])

energy_model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=energy_args.atom_fea_len,
        n_conv=energy_args.n_conv,
        h_fea_len=energy_args.h_fea_len,
        n_h=energy_args.n_h,
        classification=False)

In [5]:
@lru_cache(maxsize=None)
def predict(tensor, model, norm):
    return norm.denorm(model(*collate_pool([tensor])[0]).data)

In [347]:
@lru_cache(maxsize=None)
def fitness(fp, verbose=False):
    mol = Structure.from_file(fp)
    mol_tensor = toTensor(mol)
    free_energy = predict(mol_tensor, energy_model, energy_normalizer)
    fara_eff = predict(mol_tensor, fara_model, fara_normalizer)
    volt = predict(mol_tensor, volt_model, volt_normalizer)
    
    if verbose:
        print(f"free_energy:{free_energy} faradaic_efficiency:{fara_eff} volt_diff:{volt}")
        print(fara_eff/5-free_energy*5-abs(volt)/2)
    return fara_eff/5-free_energy*5-abs(volt)/2

In [343]:
df = pd.read_csv("fitness_5.csv")
#df = df[["CIF filename"]]
df

Unnamed: 0.1,Unnamed: 0,CIF filename,fitness
0,0,01435c55fe9f44c5a62fd3702daf534a,26.350632
1,1,95414098502f47e1b0180c38f92669bf,26.350624
2,2,4107c2090efa46848559ba82425b6156,26.346018
3,3,87cb85d2d45043e789e09274e4cd4ce3,26.345366
4,4,b9edffdc001f42a9955a3c5a8a434773,26.345366
...,...,...,...
1045,1045,da610552fedf4a79b22bdf5be1839550,26.336510
1046,1046,9530d29b57954abaadb99eea1af3fac6,26.345562
1047,1047,f71bba321c354e17ad8044835fca0066,26.344635
1048,1048,960428c7b0174f45a8d517f68685482b,26.334309


In [348]:
fitness("csd_mofs/01435c55fe9f44c5a62fd3702daf534a.cif", verbose=True)
fitness("csd_mofs/515966273fd940f98bdc5cb4c6a5010e.cif", verbose=True)

free_energy:tensor([[-2.5278]]) faradaic_efficiency:tensor([[87.4354]]) volt_diff:tensor([[7.5507]])
tensor([[26.3506]])
free_energy:tensor([[-2.5177]]) faradaic_efficiency:tensor([[87.2165]]) volt_diff:tensor([[7.7221]])
tensor([[26.1708]])


tensor([[26.1708]])

In [None]:
df["fitness"] = None
for index,cif in tqdm(df.iterrows(), total=df.shape[0]):
    try:
        cif["fitness"] = float(fitness(f"csd_mofs/{cif['CIF filename']}.cif"))
    except:
        cif["fitness"] = None

HBox(children=(IntProgress(value=0, max=6279), HTML(value='')))

In [319]:
df = df[["CIF filename","fitness"]]
df.to_csv("fitness_5.csv")

In [329]:
df = df[df["fitness"] > float("-inf")]

In [321]:
df.sort_values(by=['fitness'], ascending=False, inplace=True)

In [349]:
io.read("csd_mofs/515966273fd940f98bdc5cb4c6a5010e.cif")

Atoms(symbols='C96H48Fe4N8Ni2O20', pbc=True, cell=[6.8484, 32.9205, 16.5679], spacegroup_kinds=...)

In [350]:
pop_skim = 50

top = df[:pop_skim]
next_gen = df[:pop_skim]

In [307]:
for i in tqdm(range(1000)):
    m = io.read(f"csd_mofs/{top.sample().iloc[0]['CIF filename']}.cif", index="0")
    f = io.read(f"csd_mofs/{top.sample().iloc[0]['CIF filename']}.cif", index="0")
    child = mutate(m,f)
    uid = uuid.uuid4().hex
    next_gen = next_gen.append({"CIF filename":uid,"fitness":None},ignore_index = True)
    io.write(f"csd_mofs/{uid}.cif",child[0])

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

In [309]:
next_gen.to_csv("gen_5.csv")

In [310]:
df = next_gen

In [351]:
df = pd.read_csv("gen_1.csv")

In [360]:
df = pd.concat([df,top])[["CIF filename", "fitness"]]