In [None]:
# This code enables parallel execution, as described in the original work.
# See the time-complexity analysis in the paper (page 4).
# Because of the parallelization, the code contains additional lines that may make it harder to read.

# Merging-operation learning phase (Phase 1) - see page 3 of the paper

In [1]:
import multiprocessing as mp
from datetime import datetime

from utils import *


def select_motif(stats, min_frequency=0):
    """Pick the next motif to merge, or return None when merging should stop.

    The candidate is the key of ``stats`` with the highest count; ties are broken
    by the lexicographically largest key (``max`` over ``(count, key)`` pairs).

    Args:
        stats: mapping from candidate motif to its current total count.
        min_frequency: minimum count a motif needs in order to be merged.

    Returns:
        The selected motif, or ``None`` if ``stats`` is empty or the best count is
        below ``max(min_frequency, 1)``. A count of 0 always stops merging: a motif
        that no longer occurs anywhere would only add a useless entry to the
        operations file, and with the default ``min_frequency = 0`` a plain
        ``count < min_frequency`` test would never trigger.
    """
    if not stats:
        # max() on an empty mapping would raise ValueError.
        return None
    motif = max(stats, key=lambda x: (stats[x], x))
    if stats[motif] < max(min_frequency, 1):
        return None
    return motif


if __name__ == "__main__":
    # Worker processes are forked so they inherit the parent's state. force=True
    # avoids "RuntimeError: context has already been set" when this cell is re-run
    # or when another cell has already set the start method in this kernel.
    mp.set_start_method('fork', force=True)

    # EXAMPLE
    smiles_list = ['CCCCC', 'CCC', 'CCCCCCCCCCCCC', 'CCCCOC', 'CC', 'C1cccccc1C']

    num_workers = 3
    num_iters = 3000     # K: maximum number of merging operations to learn (page 3)
    min_frequency = 0    # stop when the most frequent motif occurs fewer times than this
    mp_threshold = 1e5   # only merges of motifs at least this frequent use num_workers processes

    # Pair each SMILES with its index; load_batch_mols receives these (index, smiles) pairs.
    indexed_smiles = list(enumerate(smiles_list))
    # Ceiling division -> at most num_workers batches. max(1, ...) keeps the range()
    # step non-zero (range() raises ValueError on step 0) when the input list is empty.
    batch_size = max(1, (len(indexed_smiles) - 1) // num_workers + 1)
    batches = [indexed_smiles[i : i + batch_size] for i in range(0, len(indexed_smiles), batch_size)]

    # Convert SMILES strings to graphs (MolGraph objects; see load_batch_mols in utils.py).
    mols = []
    with mp.Pool(num_workers) as pool:
        for mols_batch in pool.imap(load_batch_mols, batches):
            mols.extend(mols_batch)

    # Identify the initial fragments: pairs of nodes that form a fragment belonging
    # to the molecule (checked via the adjacency matrix).
    # 'stats' keeps track of the total count for each fragment across all molecules.
    # 'indices' records how many of each fragment are present in each molecule.
    stats, indices = get_stats(mols, num_workers)

    # 'with' guarantees the operations file is closed even if a merge raises.
    with open('merging_operation.txt', 'w') as output:
        for i in range(num_iters):  # K iterations (see page 3 of the paper)
            print(f'[{datetime.now()}] Iteration {i}.')
            motif = select_motif(stats, min_frequency)
            if motif is None:
                print(f'No motif has frequency >= {max(min_frequency, 1)}. Stopping.\n')
                break
            frequency = stats[motif]
            print(f'[Iteration {i}] Most frequent motif: {motif}, frequency: {frequency}.\n')

            # Merge the most frequent fragment: 'indices' locates the molecules that
            # contain it, the merges create new fragments, and 'stats'/'indices' are
            # updated accordingly (per the original notes, the merged motif's count is
            # reset to zero so it is not processed again).
            apply_merging_operation(
                motif=motif,
                mols=mols,
                stats=stats,
                indices=indices,
                num_workers=num_workers if frequency >= mp_threshold else 1,
            )
            # Record the merging operation, in the order it was learned.
            output.write(f'{motif}\n')

[2025-01-06 12:36:26.991623] Iteration 0.
[Iteration 0] Most frequent motif: CC, frequency: 27.

[2025-01-06 12:36:26.993123] Iteration 1.
[Iteration 1] Most frequent motif: CCCC, frequency: 8.

[2025-01-06 12:36:26.993677] Iteration 2.
[Iteration 2] Most frequent motif: CCCCCCCC, frequency: 2.

[2025-01-06 12:36:26.993805] Iteration 3.
[Iteration 3] Most frequent motif: CCCCC, frequency: 2.

[2025-01-06 12:36:26.993962] Iteration 4.
[Iteration 4] Most frequent motif: CO, frequency: 1.

[2025-01-06 12:36:26.994037] Iteration 5.
[Iteration 5] Most frequent motif: CCCCOC, frequency: 1.

[2025-01-06 12:36:26.994071] Iteration 6.
[Iteration 6] Most frequent motif: CCCCCCCCCCCCC, frequency: 1.

[2025-01-06 12:36:26.994121] Iteration 7.
[Iteration 7] Most frequent motif: CCC, frequency: 1.

[2025-01-06 12:36:26.994152] Iteration 8.
[Iteration 8] Most frequent motif: CC=CCCC, frequency: 1.

[2025-01-06 12:36:26.994291] Iteration 9.
[Iteration 9] Most frequent motif: CC1=CC=CC=CC1, frequency: 

# Motif-vocabulary Construction Phase (Phase 2) - see page 4 of the paper

In [2]:
# If you run both cells in the same kernel, the second call to mp.set_start_method('fork') raises RuntimeError ("context has already been set"); comment that line out or call it with force=True.

In [3]:
import multiprocessing as mp
import os
import os.path as path
import pickle
from collections import Counter
from datetime import datetime
from functools import partial
from typing import List, Tuple

from tqdm import tqdm

from mol_graph import MolGraph as MG

def apply_operations(batch):
    """Tally the motifs of every molecule in a batch of SMILES strings.

    Each SMILES is parsed into an ``MG`` graph with ``tokenizer='motif'``
    (presumably applying the operations loaded via ``MG.load_operations`` —
    confirm in mol_graph.py). The entries of its ``motifs`` collection are
    added to a running Counter.

    Args:
        batch: iterable of SMILES strings.

    Returns:
        Counter mapping each motif entry to its total count over the batch
        (an empty Counter for an empty batch).
    """
    vocab = Counter()
    for smi in batch:
        mol = MG(smi, tokenizer='motif')
        # In-place update costs O(len(mol.motifs)); `vocab + Counter(...)` would
        # copy the whole accumulated Counter for every molecule.
        vocab.update(mol.motifs)
    return vocab

def sort_vocab(vocab, operations):
    """Order motif-vocabulary entries for writing to disk.

    Entries whose first element ``x`` is not a merging operation come first,
    sorted alphabetically by ``x``. Entries whose ``x`` is a merging operation
    follow, in the order given by ``operations``. Entries that share the same
    ``x`` keep their iteration order in ``vocab`` because ``sorted`` is stable.

    Args:
        vocab: Counter keyed by ``(x, y)`` pairs, as produced by ``apply_operations``.
        operations: sequence of merging operations loaded from Phase 1
            (e.g. ``MG.OPERATIONS``).

    Returns:
        List of ``(x, y, count)`` triples in the order described above.
    """
    op_set = set(operations)
    # A set comprehension de-duplicates x values shared by several (x, y) keys.
    atoms = sorted({x for (x, _) in vocab if x not in op_set})
    # Every x is either an atom or an operation, so every entry has a rank.
    # For repeated keys the later position wins, as with dict(zip(...)).
    rank = {x: i for i, x in enumerate(atoms + list(operations))}
    entries = [(x, y, count) for (x, y), count in vocab.items()]
    return sorted(entries, key=lambda entry: rank[entry[0]])


if __name__ == "__main__":
    # 'fork' lets pool workers inherit the state loaded into MG below. force=True
    # avoids "RuntimeError: context has already been set" when Phase 1 already set
    # the start method in this kernel or this cell is re-run.
    mp.set_start_method('fork', force=True)
    num_workers = 2
    data_set = ['CCCCC', 'CCC', 'CCCCCCCCCCCCC', 'CCCCOC', 'CC', 'C1cccccc1C']
    # Ceiling division -> at most num_workers batches. max(1, ...) keeps the range()
    # step non-zero for an empty data set.
    batch_size = max(1, (len(data_set) - 1) // num_workers + 1)
    batches = [data_set[i : i + batch_size] for i in range(0, len(data_set), batch_size)]
    print(f'Total: {len(data_set)} molecules.\n')
    num_operations = 500  # how many merging operations to load from Phase 1's output

    print('Processing...')
    # Load the merging operations (fragments) found in Phase 1.
    MG.load_operations('merging_operation.txt', num_operations)
    # Each worker converts its SMILES strings to MG graphs (see mol_graph.py) and
    # counts their motifs. 'vocab' accumulates the counts, keyed by (x, y) pairs.
    vocab = Counter()
    with mp.Pool(num_workers, initializer=tqdm.set_lock, initargs=(mp.RLock(),)) as pool:
        for batch_vocab in pool.imap(apply_operations, batches):
            # In-place update avoids rebuilding the accumulated Counter per batch.
            vocab.update(batch_vocab)

    # 'MG.OPERATIONS' contains the merging operations performed during Phase 1.
    with open('vocab.txt', 'w') as f:
        for (x, y, _) in sort_vocab(vocab, MG.OPERATIONS):
            f.write(f"{x} {y}\n")

    print(f"\r[{datetime.now()}] Motif vocabulary construction finished.")


Total: 6 molecules.

Processing...
[2025-01-06 12:36:27.505190] Motif vocabulary construction finished.
