In [3]:
import pickle
from collections import Counter, defaultdict

from rdkit import Chem
from tqdm import tqdm

from MACCSkeys import smartsPatts

##### Counting how many reactions contain each MACCS key

In [2]:
# Disabled by default: this makes a full pass over the training set. Flip the guard
# to recompute the per-key reaction counts and re-write the pickle.
if False:
    # Maps MACCS key -> number of reactions in which the key's pattern matches
    # the LHS or the RHS (each reaction contributes at most 1 per key).
    rxncounts_per_key = Counter()

    # Keys excluded from matching.
    # NOTE(review): presumably these have no usable SMARTS in `smartsPatts` — confirm.
    maccs_keys_to_skip = {1, 44, 125, 166}

    # Count lines up front so tqdm can show a total; `with` makes sure the handle
    # is closed instead of being left to the garbage collector.
    with open("../data/raw/train.txt", "r") as train_dataset:
        num_rxns = sum(1 for _ in train_dataset)

    # The SMARTS patterns do not depend on the reaction, so compile each one once
    # here rather than once per key per reaction inside the loop.
    maccs_patterns = {
        maccs_key: Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    }

    with open("../data/raw/train.txt", "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            # Each line is expected to be "<lhs SMILES>>><rhs SMILES>".
            lhs, rhs = reaction.split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            for maccs_key, maccs_pattern in maccs_patterns.items():
                if (lhs_mol.HasSubstructMatch(maccs_pattern)
                        or rhs_mol.HasSubstructMatch(maccs_pattern)):
                    rxncounts_per_key[maccs_key] += 1

    # Store data (serialize)
    with open('rxncounts_per_key.pickle', 'wb') as handle:
        pickle.dump(rxncounts_per_key, handle, protocol=pickle.HIGHEST_PROTOCOL)


##### Checking if every reaction has a MACCS key

In [3]:
# Disabled by default: this makes a full pass over the training set (a recorded run
# took roughly 9 minutes while patterns were still recompiled per reaction).
# Flip the guard to recompute the statistics and re-write the pickle.
if False:
    # Number of reactions where BOTH sides match at least one non-skipped MACCS key.
    n_valid_reactions = 0

    # Keys excluded from matching.
    # NOTE(review): presumably these have no usable SMARTS in `smartsPatts` — confirm.
    maccs_keys_to_skip = {1, 44, 125, 166}

    # Count lines up front so tqdm can show a total; `with` makes sure the handle
    # is closed instead of being left to the garbage collector.
    with open("../data/raw/train.txt", "r") as train_dataset:
        num_rxns = sum(1 for _ in train_dataset)

    # The SMARTS patterns do not depend on the reaction, so compile each one once
    # here rather than once per key per reaction inside the loop.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    ]

    with open("../data/raw/train.txt", "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            # Each line is expected to be "<lhs SMILES>>><rhs SMILES>".
            lhs, rhs = reaction.split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            # `any` stops at the first matching pattern, so each side is scanned
            # only until its first hit.
            lhs_found = any(lhs_mol.HasSubstructMatch(pattern) for pattern in maccs_patterns)
            rhs_found = any(rhs_mol.HasSubstructMatch(pattern) for pattern in maccs_patterns)

            if lhs_found and rhs_found:
                n_valid_reactions += 1

    valid_rxn_stats = {
        'n_valid_reactions': n_valid_reactions,
        'total_reactions': num_rxns,
    }

    # Store data (serialize)
    with open('valid_rxn_stats.pickle', 'wb') as handle:
        pickle.dump(valid_rxn_stats, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Number of reactions with a side that matches no MACCS key.
    print(num_rxns - n_valid_reactions) # 0


100%|██████████| 409035/409035 [08:49<00:00, 772.95it/s]


##### Checking if every reaction is a MACCS transformation (only LHS)

In [None]:
# NOTE: this cell is not behind an `if False:` guard, so executing it makes a full
# pass over the training set.
#
# For every reaction it asks: are all atoms whose neighbourhood changes between LHS
# and RHS covered by a greedily chosen, non-overlapping set of MACCS matches on the LHS?
#
# Limitations visible in the logic below:
#   * "changed" means the SET of neighbouring atom-map numbers differs; a change in
#     bond order alone between the same two atoms is not detected.
#   * atoms with no neighbours on the LHS never become keys of the LHS neighbour map,
#     so they are not checked themselves (their new partners are).
#   * atoms sharing an atom-map number (e.g. unmapped atoms, map number 0) are merged
#     into a single entry.

# Reactions in which every changed atom is covered by an accepted MACCS match.
n_assumption_true_rxns = 0

# Keys excluded from matching.
# NOTE(review): presumably these have no usable SMARTS in `smartsPatts` — confirm.
maccs_keys_to_skip = {1, 44, 125, 166}

# Failure attribution, decided by the FIRST offending atom of each failing reaction:
# the atom belonged to a match rejected for overlapping the same key / another key.
same_maccs_invalid_rxns = 0
other_maccs_invalid_rxns = 0

# Count lines up front so tqdm can show a total; `with` makes sure the handle
# is closed instead of being left to the garbage collector.
with open("../data/raw/train.txt", "r") as train_dataset:
    num_rxns = sum(1 for _ in train_dataset)

# The SMARTS patterns do not depend on the reaction, so compile each one once here
# rather than once per key per reaction. A list keeps the iteration order of
# `smartsPatts`, which matters: matches are accepted greedily in key order.
maccs_patterns = [
    Chem.MolFromSmarts(smartsPatts[maccs_key][0])
    for maccs_key in smartsPatts
    if maccs_key not in maccs_keys_to_skip
]

with open("../data/raw/train.txt", "r") as train_dataset:
    for reaction in tqdm(train_dataset, total=num_rxns):

        # Each line is expected to be "<lhs SMILES>>><rhs SMILES>".
        lhs, rhs = reaction.split(">>")
        lhs_mol = Chem.MolFromSmiles(lhs)
        rhs_mol = Chem.MolFromSmiles(rhs)

        # get maccs matched atom maps
        all_maccs_atom_maps_valid = set()      # atoms of every accepted match
        same_maccs_atom_maps_invalid = set()   # atoms of matches rejected: same-key overlap
        other_maccs_atom_maps_invalid = set()  # atoms of matches rejected: other-key overlap

        for maccs_pattern in maccs_patterns:

            # Atoms already claimed by accepted matches of THIS key.
            this_maccs_atom_maps = set()

            for match_tuple in lhs_mol.GetSubstructMatches(maccs_pattern):

                # Look up the atom-map numbers of the match once and reuse them.
                match_atom_maps = [
                    lhs_mol.GetAtomWithIdx(atom_id).GetAtomMapNum()
                    for atom_id in match_tuple
                ]

                # if any overlap with existing same-maccs-key matches, skip this match
                if not this_maccs_atom_maps.isdisjoint(match_atom_maps):
                    same_maccs_atom_maps_invalid.update(match_atom_maps)
                    continue

                # if any overlap with existing any-key matches, skip this match
                if not all_maccs_atom_maps_valid.isdisjoint(match_atom_maps):
                    other_maccs_atom_maps_invalid.update(match_atom_maps)
                    continue

                all_maccs_atom_maps_valid.update(match_atom_maps)
                this_maccs_atom_maps.update(match_atom_maps)

        # find bonds changed atom map numbers:
        # atom-map number -> set of neighbouring atom-map numbers, per side
        lhs_amap_to_nbr_amaps = defaultdict(set)
        rhs_amap_to_nbr_amaps = defaultdict(set)

        for mol, amap_to_nbr_amaps in ((lhs_mol, lhs_amap_to_nbr_amaps),
                                       (rhs_mol, rhs_amap_to_nbr_amaps)):
            for atom in mol.GetAtoms():
                atom_map = atom.GetAtomMapNum()
                for nbr_atom in atom.GetNeighbors():
                    amap_to_nbr_amaps[atom_map].add(nbr_atom.GetAtomMapNum())

        # check if bonds changed atoms are all maccs matched atoms
        rxn_valid = True
        for atom_map, lhs_nbr_amaps in lhs_amap_to_nbr_amaps.items():
            # A map number missing on the RHS compares as an empty neighbour set.
            if lhs_nbr_amaps == rhs_amap_to_nbr_amaps[atom_map]:
                continue
            if atom_map in all_maccs_atom_maps_valid:
                continue

            if atom_map in same_maccs_atom_maps_invalid:
                same_maccs_invalid_rxns += 1
            elif atom_map in other_maccs_atom_maps_invalid:
                other_maccs_invalid_rxns += 1

            # Stop at the first uncovered changed atom; only it is attributed.
            rxn_valid = False
            break

        if rxn_valid:
            n_assumption_true_rxns += 1

rxn_assumptions_stats = {
    "n_assumption_true_rxns": n_assumption_true_rxns,
    "same_maccs_invalid_rxns": same_maccs_invalid_rxns,
    "other_maccs_invalid_rxns": other_maccs_invalid_rxns,
    "num_rxns": num_rxns,
}

# Store data (serialize)
with open('rxn_assumptions_stats.pickle', 'wb') as handle:
    pickle.dump(rxn_assumptions_stats, handle, protocol=pickle.HIGHEST_PROTOCOL)

# 402536 (actual) BUT 306435 (from ordering of MACCS keys and RDKit preferring earlier indices)
print(n_assumption_true_rxns, same_maccs_invalid_rxns, other_maccs_invalid_rxns) # 306435, 1970, 94949


In [7]:
# NEW IDEA 1 (less innovative):
# 1. Allow intersections of substructures. This creates 'blobs'. Don't reduce to simplified form.
# 2. Blobs are combinations (allow all possibilities) of substructures which needn't be any one substructure.
# 3. Then, model LHS as a combination of ALL blobs (just remove all unmatched atoms). Run NERF (PtrNets)
# PROBLEM: It is almost like NERF with preprocessing.

# NEW IDEA 2 (more innovative):
# 1. Allow intersections of substructures. This creates 'blobs'. Don't reduce to simplified form.
# 2. Blobs are combinations (allow all possibilities) of substructures which needn't be any one substructure.
# 3. Using some GNN, predict whether each blob would take part in the reaction.
# 4. Model LHS as a combination of only SELECTED blobs (remove all unmatched atoms and unselected blobs). Run NERF (PtrNets)
# WHY THIS APPROACH? We cannot guarantee that only PAIRS of substructures interact; a reaction may involve TRIPLETS or more.