In [2]:
import pickle
from tqdm import tqdm
from collections import Counter, defaultdict
from rdkit import Chem
from MACCSkeys import smartsPatts, maccs_keys_to_skip, ordered_by_occurrence, ordered_by_length_desc, ordered_by_length_asc

##### Counting, for each MACCS key, the number of reactions it occurs in (LHS or RHS)

In [3]:
# Counts, for every (non-skipped) MACCS key, how many training reactions contain
# a match of that key on the LHS or the RHS; the Counter is pickled for reuse.
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_KEY_OCCURRENCE_COUNT = False

if RUN_KEY_OCCURRENCE_COUNT:
    train_path = "../data/raw/train.txt"

    # Compile each SMARTS once, not once per reaction.
    maccs_key_patterns = [
        (maccs_key, Chem.MolFromSmarts(smartsPatts[maccs_key][0]))
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    ]

    rxncounts_per_key = Counter()  # MACCS key -> number of reactions matching it

    # Count lines with a context manager so the handle is closed (the old
    # one-liner leaked it).
    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, rhs = reaction.strip().split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            for maccs_key, maccs_pattern in maccs_key_patterns:
                if (lhs_mol.HasSubstructMatch(maccs_pattern)
                        or rhs_mol.HasSubstructMatch(maccs_pattern)):
                    rxncounts_per_key[maccs_key] += 1

    # Store data (serialize)
    with open('rxncounts_per_key.pickle', 'wb') as handle:
        pickle.dump(rxncounts_per_key, handle, protocol=pickle.HIGHEST_PROTOCOL)


##### Checking whether both sides (LHS and RHS) of every reaction match at least one MACCS key

In [4]:
# Checks that every training reaction has at least one MACCS match on BOTH its
# LHS and its RHS; prints the number of reactions that fail this.
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_MACCS_PRESENCE_CHECK = False

if RUN_MACCS_PRESENCE_CHECK:
    train_path = "../data/raw/train.txt"

    # Compile each SMARTS once, not once per reaction.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    ]

    n_valid_reactions = 0  # reactions where both sides match some MACCS key

    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, rhs = reaction.strip().split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            # any() stops at the first matching pattern on each side.
            lhs_found = any(lhs_mol.HasSubstructMatch(p) for p in maccs_patterns)
            rhs_found = any(rhs_mol.HasSubstructMatch(p) for p in maccs_patterns)

            if lhs_found and rhs_found:
                n_valid_reactions += 1

    print(num_rxns - n_valid_reactions) # 0


##### Checking if every reaction is a MACCS transformation, i.e. every LHS atom whose bonds change lies in a MACCS match (matches searched on the LHS only)

In [5]:
maccs_keys = list(smartsPatts.keys())

# For each reaction, MACCS matches on the LHS are kept greedily in `maccs_keys`
# order, discarding any match that overlaps an already-kept one. The reaction
# "satisfies the assumption" if every LHS atom whose bonded neighbours change
# lies in a kept match. For failing reactions we record whether the first
# uncovered atom was lost to an overlap with the same key or with another key.
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_LHS_COVERAGE_CHECK = False

if RUN_LHS_COVERAGE_CHECK:
    train_path = "../data/raw/train.txt"

    # Compiled once, in `maccs_keys` order; that order decides which of two
    # overlapping matches survives.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in maccs_keys
        if maccs_key not in maccs_keys_to_skip
    ]

    n_assumption_true_rxns = 0
    same_maccs_invalid_rxns = 0
    other_maccs_invalid_rxns = 0

    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, rhs = reaction.strip().split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            # get maccs matched atom maps
            all_maccs_atom_maps_valid = set()
            same_maccs_atom_maps_invalid = set()
            other_maccs_atom_maps_invalid = set()

            for maccs_pattern in maccs_patterns:
                this_maccs_atom_maps = set()

                for match_tuple in lhs_mol.GetSubstructMatches(maccs_pattern):
                    match_amaps = [lhs_mol.GetAtomWithIdx(atom_id).GetAtomMapNum() for atom_id in match_tuple]

                    if any(amap in this_maccs_atom_maps for amap in match_amaps):
                        same_maccs_atom_maps_invalid.update(match_amaps)
                        continue  # overlaps an earlier match of the same key

                    if any(amap in all_maccs_atom_maps_valid for amap in match_amaps):
                        other_maccs_atom_maps_invalid.update(match_amaps)
                        continue  # overlaps a kept match of an earlier key

                    all_maccs_atom_maps_valid.update(match_amaps)
                    this_maccs_atom_maps.update(match_amaps)

            # atom map number -> set of bonded neighbours' atom map numbers
            lhs_amap_to_nbr_amaps = defaultdict(set)
            rhs_amap_to_nbr_amaps = defaultdict(set)

            for mol, amap_to_nbr_amaps in ((lhs_mol, lhs_amap_to_nbr_amaps),
                                           (rhs_mol, rhs_amap_to_nbr_amaps)):
                for bond in mol.GetBonds():
                    amap_1 = mol.GetAtomWithIdx(bond.GetBeginAtomIdx()).GetAtomMapNum()
                    amap_2 = mol.GetAtomWithIdx(bond.GetEndAtomIdx()).GetAtomMapNum()
                    amap_to_nbr_amaps[amap_1].add(amap_2)
                    amap_to_nbr_amaps[amap_2].add(amap_1)

            # Bond-free LHS atoms (e.g. single-atom reactants) have no entry yet.
            # Without one, an atom that only GAINS bonds would never be checked.
            # Appended after the bonded atoms so their iteration order is unchanged.
            for atom in lhs_mol.GetAtoms():
                lhs_amap_to_nbr_amaps.setdefault(atom.GetAtomMapNum(), set())

            # Check that every atom whose neighbours change is MACCS-covered.
            # NOTE(review): only neighbour sets are compared, so pure bond-order
            # changes (e.g. C=C -> C-C) are not treated as changes; LHS atoms
            # absent from the RHS get an empty RHS set and so DO count as changed.
            rxn_valid = True
            for atom_map, lhs_nbrs in lhs_amap_to_nbr_amaps.items():
                if lhs_nbrs == rhs_amap_to_nbr_amaps[atom_map]:
                    continue
                if atom_map in all_maccs_atom_maps_valid:
                    continue

                if atom_map in same_maccs_atom_maps_invalid:
                    same_maccs_invalid_rxns += 1
                elif atom_map in other_maccs_atom_maps_invalid:
                    other_maccs_invalid_rxns += 1

                rxn_valid = False
                break

            if rxn_valid:
                n_assumption_true_rxns += 1

    print(n_assumption_true_rxns, same_maccs_invalid_rxns, other_maccs_invalid_rxns)
    # Numbers below were recorded before bond-free LHS atoms were checked; a re-run may differ slightly.
    # 402536 (actual) BUT 306435 (from ordering of MACCS keys and RDKit preferring earlier indices)
    # if NOT sorted - 306435, 1970, 94949
    # if sorted by occ (desc) - 84948, 823, 321109 (BAD)
    # if sorted by len (desc) - 298673, 27308, 77788


##### Checking if every reaction involves only a pair of substructures

In [6]:
maccs_keys = list(smartsPatts.keys())
save_pickle_file = 'n_interacting_substructs_stats.pickle'

# For reactions whose bond-changing LHS atoms are all covered by greedily kept,
# non-overlapping MACCS matches, build a histogram of how many distinct kept
# matches those changing atoms belong to. Results are pickled to `save_pickle_file`.
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_INTERACTING_SUBSTRUCTS_STATS = False

if RUN_INTERACTING_SUBSTRUCTS_STATS:
    train_path = "../data/raw/train.txt"

    # Compiled once, in `maccs_keys` order; that order decides which of two
    # overlapping matches survives.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in maccs_keys
        if maccs_key not in maccs_keys_to_skip
    ]

    n_assumption_true_rxns = 0
    n_interacting_substructs_in_rxns = Counter()  # n distinct matches touched -> n reactions

    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, rhs = reaction.strip().split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            # Keep non-overlapping matches and label each kept one with an index.
            n_kept_matches = 0
            all_maccs_atom_maps_valid = set()
            all_maccs_atom_maps_to_match_idx = dict()

            for maccs_pattern in maccs_patterns:
                for match_tuple in lhs_mol.GetSubstructMatches(maccs_pattern):
                    match_amaps = [lhs_mol.GetAtomWithIdx(atom_id).GetAtomMapNum() for atom_id in match_tuple]

                    if any(amap in all_maccs_atom_maps_valid for amap in match_amaps):
                        continue  # overlaps an already-kept match

                    all_maccs_atom_maps_valid.update(match_amaps)
                    for amap in match_amaps:
                        all_maccs_atom_maps_to_match_idx[amap] = n_kept_matches
                    n_kept_matches += 1

            # atom map number -> set of bonded neighbours' atom map numbers
            lhs_amap_to_nbr_amaps = defaultdict(set)
            rhs_amap_to_nbr_amaps = defaultdict(set)

            for mol, amap_to_nbr_amaps in ((lhs_mol, lhs_amap_to_nbr_amaps),
                                           (rhs_mol, rhs_amap_to_nbr_amaps)):
                for bond in mol.GetBonds():
                    amap_1 = mol.GetAtomWithIdx(bond.GetBeginAtomIdx()).GetAtomMapNum()
                    amap_2 = mol.GetAtomWithIdx(bond.GetEndAtomIdx()).GetAtomMapNum()
                    amap_to_nbr_amaps[amap_1].add(amap_2)
                    amap_to_nbr_amaps[amap_2].add(amap_1)

            # Bond-free LHS atoms would otherwise never be checked, even if they gain bonds.
            for atom in lhs_mol.GetAtoms():
                lhs_amap_to_nbr_amaps.setdefault(atom.GetAtomMapNum(), set())

            # NOTE(review): neighbour-set comparison ignores pure bond-order changes.
            rxn_valid = True
            seen_match_idxs = set()
            for atom_map, lhs_nbrs in lhs_amap_to_nbr_amaps.items():
                if lhs_nbrs == rhs_amap_to_nbr_amaps[atom_map]:
                    continue
                if atom_map not in all_maccs_atom_maps_valid:
                    rxn_valid = False
                    break
                seen_match_idxs.add(all_maccs_atom_maps_to_match_idx[atom_map])

            if rxn_valid:
                n_assumption_true_rxns += 1
                n_interacting_substructs_in_rxns[len(seen_match_idxs)] += 1

    n_interacting_substructs_stats = {
        "n_assumption_true_rxns": n_assumption_true_rxns,
        "n_interacting_substructs_in_rxns": n_interacting_substructs_in_rxns,
        "num_rxns": num_rxns,
    }

    # Store data (serialize)
    with open(save_pickle_file, 'wb') as handle:
        pickle.dump(n_interacting_substructs_stats, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Print the histogram itself (the stats dict would repeat the other two numbers).
    print(n_assumption_true_rxns, n_interacting_substructs_in_rxns, num_rxns)

##### Results

Q. In how many reactions are the changing LHS atoms ALL covered in MACCS? <br>
A. 402536 out of 409035 = 98% [THIS IS GOOD]

Q. Once overlapping matches are discarded, in how many reactions are the changing LHS atoms still ALL covered? <br>
A. 306435 out of 409035 = 75% (missed 1970 due to same MACCS, and 94949 due to other MACCS) [THIS WILL BE HANDLED USING DUMMY ATOMS]

Q. Among the covered 306435 reactions, how many are substructure-pair ONLY reactions? <br>
A. Only 28630 out of 306435 = 9% [THIS IS OKAY BECAUSE WE WILL USE MULTIPLE PAIRS]

-----
##### ALTERNATE IDEA: Model NERF using 'blobs' (Problem = minimal novelty)
---> 'Blob' = combination of one or more substructures (with or without overlaps). <br>
---> This reduces the complexity of NERF. Problem = the reduction in complexity is minimal.

##### Checking the avg %age of LHS atoms per reaction that are not part of ANY MACCS substructure

In [9]:
# Averages over all training reactions: (a) the percentage of LHS atoms covered
# by no MACCS match at all (overlaps allowed) and (b) the number of MACCS
# matches found on the LHS.
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_UNMATCHED_ATOM_STATS = False

if RUN_UNMATCHED_ATOM_STATS:
    train_path = "../data/raw/train.txt"

    # Compile each SMARTS once, not once per reaction.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    ]

    avg_perc_unmatched_atoms = 0
    avg_num_substructs_per_rxn = 0

    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, _ = reaction.strip().split(">>")  # RHS is not needed here
            lhs_mol = Chem.MolFromSmiles(lhs)

            num_substructs_rxn = 0
            all_matched_atoms = set()  # LHS atom indices inside any match

            for maccs_pattern in maccs_patterns:
                match_tuples_list = lhs_mol.GetSubstructMatches(maccs_pattern)
                num_substructs_rxn += len(match_tuples_list)
                for match_tuple in match_tuples_list:
                    all_matched_atoms.update(match_tuple)

            total_num_atoms = lhs_mol.GetNumAtoms()
            if total_num_atoms:  # an empty LHS contributes 0% instead of dividing by zero
                avg_perc_unmatched_atoms += 100 * ((total_num_atoms - len(all_matched_atoms)) / total_num_atoms)
            avg_num_substructs_per_rxn += num_substructs_rxn

    avg_perc_unmatched_atoms /= num_rxns
    avg_num_substructs_per_rxn /= num_rxns

    print(avg_perc_unmatched_atoms, avg_num_substructs_per_rxn) # 0.1267754347743418, 212.77733201315291


##### Checking avg %age of atoms that will NOT take part in rxn (unchanged blobs + unmatched)

In [10]:
class DisjointSet:
    """Union-find over hashable items, using union by rank and path compression.

    It also tracks the number of items in each set, stored on the set's root.
    """

    def __init__(self):
        self.parent = {}     # item -> parent item; a root is its own parent
        self.rank = {}       # root -> upper bound on the depth of its tree
        self.n_objects = {}  # root -> number of items in its set

    def makeSet(self, universe):
        """Make every item of `universe` a singleton set (re-initialising known items)."""
        for item in universe:
            self.parent[item] = item
            self.rank[item] = 0
            self.n_objects[item] = 1

    def Find(self, k):
        """Return the root of `k`'s set, repointing every node on the path at that root."""
        # First pass: walk up to the root.
        root = k
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: path compression.
        node = k
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def getSizeOf(self, k):
        """Return how many items share a set with `k` (including `k`)."""
        return self.n_objects[self.Find(k)]

    def Union(self, a, b):
        """Merge the sets of `a` and `b`; a no-op if they are already joined."""
        root_a, root_b = self.Find(a), self.Find(b)
        if root_a == root_b:
            return

        # The shallower tree hangs under the deeper one. On a tie, `a`'s root
        # goes under `b`'s root, and `b`'s rank grows by one.
        if self.rank[root_a] > self.rank[root_b]:
            child, new_root = root_b, root_a
        else:
            child, new_root = root_a, root_b
            if self.rank[root_a] == self.rank[root_b]:
                self.rank[root_b] += 1

        self.parent[child] = new_root
        self.n_objects[new_root] += self.n_objects[child]

In [11]:
# Groups LHS atoms into "blobs": connected unions of all MACCS matches, with
# overlaps allowed, built with DisjointSet over atom map numbers. For
# reactions whose bond-changing atoms are all MACCS-covered, it averages:
# - the % of LHS atoms outside every blob that contains a changing atom,
# - the number of such involved blobs,
# - the total number of blobs seen.
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_BLOB_INVOLVEMENT_STATS = False

if RUN_BLOB_INVOLVEMENT_STATS:
    train_path = "../data/raw/train.txt"

    # Compile each SMARTS once, not once per reaction.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    ]

    num_valid_rxns = 0
    avg_num_total_blobs = 0
    avg_num_involved_blobs = 0
    avg_perc_not_involved_atoms = 0

    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, rhs = reaction.strip().split(">>")
            lhs_mol = Chem.MolFromSmiles(lhs)
            rhs_mol = Chem.MolFromSmiles(rhs)

            # NOTE(review): sets are keyed by atom map number, which assumes
            # map numbers are unique per LHS atom. Duplicates (e.g. several
            # unmapped atoms) collapse into one element. Confirm this for the dataset.
            UF = DisjointSet()
            UF.makeSet([atom.GetAtomMapNum() for atom in lhs_mol.GetAtoms()])

            all_maccs_atom_maps_valid = set()

            for maccs_pattern in maccs_patterns:
                for match_tuple in lhs_mol.GetSubstructMatches(maccs_pattern):
                    match_amaps = [lhs_mol.GetAtomWithIdx(atom_id).GetAtomMapNum() for atom_id in match_tuple]
                    # Chain-union consecutive atoms so the whole match ends up in one blob.
                    # Pairing via zip replaces the old `if prev_atom_map:` test, which
                    # wrongly skipped the union whenever the previous map number was 0.
                    for prev_amap, amap in zip(match_amaps, match_amaps[1:]):
                        UF.Union(amap, prev_amap)
                    all_maccs_atom_maps_valid.update(match_amaps)

            # atom map number -> set of bonded neighbours' atom map numbers
            lhs_amap_to_nbr_amaps = defaultdict(set)
            rhs_amap_to_nbr_amaps = defaultdict(set)

            for mol, amap_to_nbr_amaps in ((lhs_mol, lhs_amap_to_nbr_amaps),
                                           (rhs_mol, rhs_amap_to_nbr_amaps)):
                for bond in mol.GetBonds():
                    amap_1 = mol.GetAtomWithIdx(bond.GetBeginAtomIdx()).GetAtomMapNum()
                    amap_2 = mol.GetAtomWithIdx(bond.GetEndAtomIdx()).GetAtomMapNum()
                    amap_to_nbr_amaps[amap_1].add(amap_2)
                    amap_to_nbr_amaps[amap_2].add(amap_1)

            # Add bond-free LHS atoms, so they are checked for gained bonds and
            # their singleton blobs count towards the total.
            for atom in lhs_mol.GetAtoms():
                lhs_amap_to_nbr_amaps.setdefault(atom.GetAtomMapNum(), set())

            # NOTE(review): neighbour-set comparison ignores pure bond-order changes.
            rxn_valid = True
            seen_all_blob_idxs = set()
            seen_involved_blob_idxs = set()
            num_involved_atoms = 0

            for atom_map, lhs_nbrs in lhs_amap_to_nbr_amaps.items():

                blob_idx = UF.Find(atom_map)
                seen_all_blob_idxs.add(blob_idx)

                if lhs_nbrs == rhs_amap_to_nbr_amaps[atom_map]:
                    continue
                if atom_map not in all_maccs_atom_maps_valid:
                    rxn_valid = False
                    break

                # Count each involved blob's atoms only once.
                if blob_idx not in seen_involved_blob_idxs:
                    seen_involved_blob_idxs.add(blob_idx)
                    num_involved_atoms += UF.getSizeOf(blob_idx)

            if rxn_valid:
                num_valid_rxns += 1
                avg_num_involved_blobs += len(seen_involved_blob_idxs)
                avg_num_total_blobs += len(seen_all_blob_idxs)

                total_num_atoms = lhs_mol.GetNumAtoms()
                perc_not_involved_atoms = 100 * ((total_num_atoms - num_involved_atoms) / total_num_atoms)
                avg_perc_not_involved_atoms += perc_not_involved_atoms

    # Guard: with no valid reactions the sums stay 0 instead of raising ZeroDivisionError.
    if num_valid_rxns:
        avg_perc_not_involved_atoms /= num_valid_rxns
        avg_num_involved_blobs /= num_valid_rxns
        avg_num_total_blobs /= num_valid_rxns

    print(avg_perc_not_involved_atoms, avg_num_involved_blobs, avg_num_total_blobs)


------

##### Checking avg num atoms and avg num substructures

In [None]:
# Averages over all training reactions: the number of LHS atoms and the number
# of MACCS matches found on the LHS (overlaps allowed).
# Expensive: one full pass over the training set. Flip to True to recompute.
RUN_ATOM_AND_SUBSTRUCT_COUNTS = False

if RUN_ATOM_AND_SUBSTRUCT_COUNTS:
    train_path = "../data/raw/train.txt"

    # Compile each SMARTS once, not once per reaction.
    maccs_patterns = [
        Chem.MolFromSmarts(smartsPatts[maccs_key][0])
        for maccs_key in smartsPatts
        if maccs_key not in maccs_keys_to_skip
    ]

    avg_num_atoms = 0
    avg_num_substructures = 0

    with open(train_path, "r") as train_file:
        num_rxns = sum(1 for _ in train_file)

    with open(train_path, "r") as train_dataset:
        for reaction in tqdm(train_dataset, total=num_rxns):

            lhs, _ = reaction.strip().split(">>")  # RHS is not needed here
            lhs_mol = Chem.MolFromSmiles(lhs)
            avg_num_atoms += lhs_mol.GetNumAtoms()
            avg_num_substructures += sum(
                len(lhs_mol.GetSubstructMatches(maccs_pattern)) for maccs_pattern in maccs_patterns
            )

    avg_num_atoms /= num_rxns
    avg_num_substructures /= num_rxns
    print(avg_num_atoms, avg_num_substructures) # 39.68779933257545 212.77733201315291 (too much overlap)
