In [1]:
from typing import List

from atomate.qchem.database import QChemCalcDb
from monty.serialization import dumpfn, loadfn
from mrnet.core.mol_entry import MoleculeEntry, MoleculeEntryError

from rxnrep.dataset.electrolyte_utils import (
    check_species,
    check_connectivity,
    check_bond_species,
    check_bond_length,
    check_num_bonds,
    check_bad_rdkit_molecule,
    remove_high_energy_mol_entries,
)

In [2]:
def get_db_num_entries(db_file):
    """
    Count the documents stored in the database collection.

    Args:
        db_file: path to the database file handed to `QChemCalcDb.from_db_file`.

    Returns:
        The value of `count_documents({})` on the collection, i.e. the count
        obtained with an empty filter.
    """
    collection = QChemCalcDb.from_db_file(db_file, admin=True).collection
    total = collection.count_documents({})
    return total

In [3]:
def query_db_entries(db_file, num_entries: int = None, environment: str = None):
    """
    Query the molecule document database to pull all the molecules from molecule builder.

    Args:
        db_file: path to a json file storing credentials of the database
        num_entries (int): the number of entries to query, if `None`, get all.
        environment: query value for the environment key in the db. e.g. `smd_thf`. if `None`, ignore it.

    Returns:
        A list of molecule document entries.
    """

    # A limit of 0 asks the cursor for every matching document (pymongo
    # semantics; assumes `db.collection` is a pymongo collection).
    limit = 0 if num_entries is None else num_entries
    query = {} if environment is None else {"environment": environment}

    db = QChemCalcDb.from_db_file(db_file, admin=True)
    cursor = db.collection.find(query, no_cursor_timeout=True).limit(limit)

    # The cursor is opened with `no_cursor_timeout=True`, so nothing reaps it
    # if we forget to close it. Close it even when iteration raises part-way
    # through (network error, interrupt, ...).
    try:
        return list(cursor)
    finally:
        cursor.close()

In [4]:
def _filter_entries(entries, check, verbose=False, **check_kwargs):
    """
    Keep the entries that pass `check`.

    Shared implementation of all the `filter_*` functions below, which differ
    only in the check they apply.

    Args:
        entries: iterable of molecule entries.
        check: callable invoked as `check(entry, **check_kwargs)` and returning a
            `(fail, comment)` pair.
        verbose: if `True`, print the `entry_id` and the comment of every entry
            that fails the check.
        **check_kwargs: extra keyword arguments forwarded to `check`.

    Returns:
        A new list holding, in their original order, the entries for which
        `check` did not report a failure.
    """
    succeeded = []
    for m in entries:
        fail, comment = check(m, **check_kwargs)
        if not fail:
            succeeded.append(m)
        elif verbose:
            print(m.entry_id, comment)

    return succeeded


def filter_connectivity(entries, verbose=False):
    """
    remove mols having atoms not connected to others
    """
    return _filter_entries(entries, check_connectivity, verbose)


def filter_species(entries, not_allowed_species=None, verbose=False):
    """
    remove mols with specific species

    Args:
        entries: molecule entries to filter.
        not_allowed_species: species passed to `check_species`; defaults to
            `["P"]` when `None`.
        verbose: if `True`, print the id and comment of each removed entry.
    """
    not_allowed_species = ["P"] if not_allowed_species is None else not_allowed_species

    return _filter_entries(
        entries, check_species, verbose, species=not_allowed_species
    )


def filter_bond_species(entries, verbose=False):
    """
    remove mols with specific bond between species, e.g. Li-H
    """
    return _filter_entries(entries, check_bond_species, verbose)


def filter_bond_length(entries, verbose=False):
    """
    remove mols with larger bond length
    """
    return _filter_entries(entries, check_bond_length, verbose)


def filter_num_bonds(entries, verbose=False):
    """
    remove mols with unexpected number of bonds (e.g. more than 4 bonds for carbon),
    without considering metal species
    """
    return _filter_entries(entries, check_num_bonds, verbose)


def filter_bad_rdkit_mol(entries, verbose=False):
    """
    remove mols that cannot be converted to rdkit mol
    """
    return _filter_entries(entries, check_bad_rdkit_molecule, verbose)


def filter_mol_entries(
    entries: List[MoleculeEntry], verbose=False
) -> List[MoleculeEntry]:
    """
    Run the molecule entries through every filter and report the surviving count
    after each stage.

    Args:
        entries: molecule entries to clean up.
        verbose: forwarded to each `filter_*` stage; when `True` the stages print
            the id and reason for every entry they drop.

    Returns:
        The entries that survive all stages.
    """
    print("Number of starting entries:", len(entries))

    # Among molecules with the same isomorphism and charge, keep only the one
    # with the lowest free energy.
    entries = remove_high_energy_mol_entries(entries)
    print("Number of entries after removing isomorphic ones:", len(entries))

    # (label printed in the progress line, filter applied), in execution order.
    # The last label is reproduced exactly as the notebook has always printed it.
    stages = (
        ("filter_connectivity", filter_connectivity),
        ("filter_species", filter_species),
        ("filter_num_bonds", filter_num_bonds),
        ("filter_bond_species", filter_bond_species),
        ("filter_bond_length", filter_bond_length),
        ("filter_bad_rdit_mol", filter_bad_rdkit_mol),
    )
    for label, stage in stages:
        entries = stage(entries, verbose=verbose)
        print(f"Number of entries after ({label}):", len(entries))

    return entries

### Query db to get molecule documents 

In [5]:
# Database file passed to `QChemCalcDb.from_db_file` (machine-specific absolute path).
db_file = "/Users/mjwen/Applications/db_access/sam_db/sam_db_mol_builder.json"
# Bare last expression so the notebook displays the total number of documents.
get_db_num_entries(db_file)

16718

In [6]:
# Limit the query to 2000 documents; this value also ends up in the dump file name.
num_entries = 2000
# num_entries = None  # uncomment to query every document instead
mol_docs = query_db_entries(db_file, num_entries)

### Convert to molecule entries, filter, and dump to file

In [7]:
# Convert each queried document to a MoleculeEntry, counting the ones that
# raise MoleculeEntryError instead of aborting the whole run.
num_failed = 0
mol_entries = []
for doc in mol_docs:
    try:
        entry = MoleculeEntry.from_molecule_document(doc)
    except MoleculeEntryError:
        num_failed += 1
    else:
        mol_entries.append(entry)
print("Number of mol doc failed to be converted to mol entry:", num_failed)

# Drop the `bad` molecules; per-entry reasons are suppressed (verbose=False).
mol_entries = filter_mol_entries(mol_entries, verbose=False)

# Serialize the surviving entries; the file name records the query size.
fname = f"/Users/mjwen/Documents/Dataset/electrolyte/mol_entries_n{num_entries}.json"
dumpfn(mol_entries, fname)

Number of mol doc failed to be converted to mol entry: 5
Number of starting entries: 1995
Number of entries after removing isomorphic ones: 1135
Number of entries after (filter_connectivity): 1135
Number of entries after (filter_species): 1130
Number of entries after (filter_num_bonds): 1113
Number of entries after (filter_bond_species): 1112
Number of entries after (filter_bond_length): 1092


RDKit ERROR: [22:16:14] Explicit valence for atom # 3 C, 5, is greater than permitted


Number of entries after (filter_bad_rdit_mol): 1091
