In [1]:
from MS2LDA.motif_parser import load_m2m_folder
from MS2LDA.Add_On.MassQL.MassQL4MotifDB import load_motifDB, motifDB2motifs
from MS2LDA.utils import retrieve_spec4doc

from MS2LDA.Add_On.Fingerprints.FP_annotation import annotate_motifs as calc_fingerprints

import pickle
import tomotopy as tp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
#from adjustText import adjust_text

from rdkit.Chem import RDKFingerprint
from rdkit.DataStructs import TanimotoSimilarity
import numpy as np
from tqdm import tqdm
from rdkit import DataStructs
from rdkit.DataStructs.cDataStructs import ExplicitBitVect
from rdkit.Chem.Draw import MolsToGridImage
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import MolFromSmarts


from typing import Dict, List, Optional
from rdkit.Chem import MolFromSmiles, rdFMCS, RDKFingerprint
from rdkit.DataStructs import TanimotoSimilarity
import numpy as np
from tqdm import tqdm


In [2]:
# Directory holding the artifacts of one MS2LDA run. Every path below is built
# from it, so pointing the notebook at another run means changing this one line.
DATA_DIR = "/home/ioannis/thesis_data/filtered_pos_output_w1_1"

# Motif database -> motif objects consumed by process_motifs.
motifDB_1, motifDB_2 = load_motifDB(f"{DATA_DIR}/motifset_optimized.json")
motifs = motifDB2motifs(motifDB_2)

# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load doc2spec_map.pkl files produced by a run you trust.
with open(f"{DATA_DIR}/doc2spec_map.pkl", "rb") as f:
    doc2spec_map = pickle.load(f)

# Trained LDA model whose documents and topics are queried further down.
lda_model = tp.LDAModel.load(f"{DATA_DIR}/ms2lda.bin")

In [3]:

# Helper functions: SMILES parsing, fingerprint overlap, pairwise similarity and MCS.

def safe_mol(smiles: Optional[str]):
    """Parse a SMILES string into an RDKit molecule without raising on bad input.

    Args:
        smiles: SMILES string to parse, or ``None``.

    Returns:
        The value returned by ``MolFromSmiles(smiles)``, or ``None`` when
        ``smiles`` is ``None`` or the parser raised an exception.
    """
    if smiles is None:
        return None
    try:
        return MolFromSmiles(smiles)
    except Exception:
        # Deliberately best-effort: an unparsable entry is reported as None.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate, so a long-running cell can still be interrupted.
        return None

def calculate_sos(fp1, fp2):
    """Overlap score of two 0/1 fingerprints relative to the sparser one.

    The score is the number of positions at which both inputs hold ``1``,
    divided by the smaller of the two element sums. Positions beyond the end
    of the shorter input are not compared.

    Args:
        fp1: First fingerprint, a sequence of 0/1 values.
        fp2: Second fingerprint, a sequence of 0/1 values.

    Returns:
        ``0`` when no position is set in both inputs, otherwise the ratio
        described above.
    """
    # The denominator is the on-bit count of whichever fingerprint has fewer.
    denominator = min(sum(fp1), sum(fp2))
    shared_bits = sum(1 for left, right in zip(fp1, fp2) if left == 1 and right == 1)
    return shared_bits / denominator if shared_bits else 0

def compute_pairwise_similarity(mols: List, threshold: float = 0.5) -> float:
    """Fraction of molecule pairs whose Tanimoto similarity reaches ``threshold``.

    Each molecule is fingerprinted with ``RDKFingerprint`` and every unordered
    pair is compared once with ``TanimotoSimilarity``.

    Args:
        mols: Molecules to compare.
        threshold: Minimum similarity for a pair to count as a match.

    Returns:
        Matching pairs divided by all pairs; ``0.0`` when fewer than two
        molecules are given.
    """
    count = len(mols)
    if count < 2:
        return 0.0

    fingerprints = [RDKFingerprint(mol) for mol in mols]
    # One boolean per unordered pair (first < second); at least one pair exists here.
    pair_matches = [
        TanimotoSimilarity(fingerprints[first], fingerprints[second]) >= threshold
        for first in range(count)
        for second in range(first + 1, count)
    ]
    return sum(pair_matches) / len(pair_matches)


def compute_mcs_num_atoms(mols, timeout: int = 600):
    """Find the maximum common substructure (MCS) of a set of molecules.

    The search ignores bond types (``CompareAny``), only accepts complete
    rings, and only lets ring atoms match ring atoms.

    Args:
        mols: Molecules to search.
        timeout: Time budget handed to ``rdFMCS.FindMCS``; the default keeps
            the previously hard-coded value of 600.

    Returns:
        ``(num_atoms, smarts)`` of the MCS, or ``(0, "")`` when ``mols`` is
        empty or the search raised an exception.
    """
    if not mols:
        return 0, ""
    try:
        mcs = rdFMCS.FindMCS(
            mols,
            bondCompare=rdFMCS.BondCompare.CompareAny,
            completeRingsOnly=True,
            ringMatchesRingOnly=True,
            timeout=timeout,
        )
        return int(mcs.numAtoms), mcs.smartsString
    except Exception:
        # Best-effort: a failed search is reported as "no common substructure".
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt
        # through, so interrupting the kernel during a long search actually
        # stops the run instead of silently moving on to the next motif.
        return 0, ""

In [4]:
def _index_docs_by_topic(lda_model, prob_threshold: float) -> Dict[int, List[int]]:
    """Map each topic id to the ids of documents loading on it with >= ``prob_threshold``."""
    docs_by_topic: Dict[int, List[int]] = {}
    for doc_id, doc in enumerate(lda_model.docs):
        for topic_id, prob in doc.get_topics():
            if prob >= prob_threshold:
                docs_by_topic.setdefault(topic_id, []).append(doc_id)
    return docs_by_topic


def _bits_to_bitvect(bits):
    """Copy a sequence of 0/1 values into an ``ExplicitBitVect`` of the same length."""
    bitvect = ExplicitBitVect(len(bits))
    for position, bit in enumerate(bits):
        if bit:
            bitvect.SetBit(position)
    return bitvect


def process_motifs(motifs, lda_model, doc2spec_map, prob_threshold: float = 0.5, sim_threshold: float = 0.5, sos_cal = True) -> dict:
    """Score annotated motifs against the molecules of the documents they explain.

    Motifs with at most one ``auto_annotation`` entry, or whose ``id`` cannot
    be parsed as ``<prefix>_<integer>``, are skipped. For every remaining
    motif the function records:

    * the size and SMARTS of the maximum common substructure of its annotations,
    * the number of entries in ``motif.peaks.mz`` (0 if that attribute is absent),
    * the candidate molecules: parsable ``smiles`` of every LDA document that
      lists the motif's integer id as a topic with probability
      >= ``prob_threshold``,
    * intra similarity: fraction of candidate pairs with Tanimoto similarity
      >= ``sim_threshold`` (see ``compute_pairwise_similarity``),
    * inter similarity: mean similarity between the motif's representative
      fingerprint and each candidate.

    Args:
        motifs: Iterable of motif objects offering ``.get(key, default)`` and ``.peaks``.
        lda_model: Model exposing ``docs``, each with ``get_topics()``.
        doc2spec_map: Lookup passed through to ``retrieve_spec4doc``.
        prob_threshold: Minimum topic probability for a document to be a candidate.
        sim_threshold: Tanimoto cut-off used for the intra similarity.
        sos_cal: If true, inter similarity uses ``calculate_sos``; otherwise Tanimoto.

    Returns:
        Dict with index-aligned lists ``motif_ids``, ``num_atoms``,
        ``len_frag_loss``, ``mcs_smarts``, ``intra_sims``, ``inter_sims`` and
        the mapping ``molecules_by_motif`` (motif id -> candidate molecules).

    Raises:
        ValueError: If the representative fingerprint and the candidate MACCS
            fingerprints differ in length.
    """
    # Local import so this definition is self-contained; rdkit is already a
    # dependency of this notebook.
    from rdkit.Chem import MACCSkeys

    motif_ids = []
    num_atoms = []
    len_frag_loss = []
    mcs_smarts = []
    intra_sims = []
    inter_sims = []
    molecules_by_motif = {}

    # Scan the LDA documents once instead of once per motif: which documents
    # pass the probability threshold for a topic does not depend on the motif.
    docs_by_topic = _index_docs_by_topic(lda_model, prob_threshold)

    for motif in tqdm(motifs):
        annotation = motif.get("auto_annotation", [])
        if len(annotation) <= 1:
            continue

        try:
            motif_id = int(motif.get("id").split("_")[1])
        except (AttributeError, IndexError, ValueError):
            # Missing id, no "_" separator, or a non-numeric suffix: skip the motif.
            continue

        motif_ids.append(motif_id)
        len_frag_loss.append(len(getattr(motif.peaks, "mz", [])))

        # ---------- MCS of the annotation molecules ----------
        ann_mols = [m for m in (safe_mol(s) for s in annotation) if m is not None]
        n_atoms, smarts = compute_mcs_num_atoms(ann_mols)
        num_atoms.append(n_atoms)
        mcs_smarts.append(smarts)

        # ---------- Motif representative fingerprint (MACCS) ----------
        rep_bits = calc_fingerprints([annotation], fp_type="maccs", threshold=0.9)[0]
        rep_fp = _bits_to_bitvect(rep_bits)

        # ---------- Collect candidate molecules ----------
        candidate_mols = []
        for doc_id in docs_by_topic.get(motif_id, ()):
            spec = retrieve_spec4doc(doc2spec_map, lda_model, doc_id)
            mol = safe_mol(spec.get("smiles"))
            if mol is not None:
                candidate_mols.append(mol)

        molecules_by_motif[motif_id] = candidate_mols

        # ---------- Intra similarity (candidates vs candidates) ----------
        intra_sims.append(compute_pairwise_similarity(candidate_mols, threshold=sim_threshold))

        # ---------- Inter similarity (motif representative vs candidates) ----------
        if not candidate_mols:
            inter_sims.append(0.0)
            continue

        # The representative is a MACCS fingerprint, so candidates must be
        # fingerprinted the same way; comparing it with RDKFingerprint bits
        # would pair up unrelated bit positions.
        # NOTE(review): assumes calc_fingerprints(fp_type="maccs") uses the same
        # bit layout as RDKit's GenMACCSKeys - the length check below catches a
        # size mismatch; confirm the layout against MS2LDA.
        candidate_fps = [MACCSkeys.GenMACCSKeys(m) for m in candidate_mols]
        if candidate_fps[0].GetNumBits() != rep_fp.GetNumBits():
            raise ValueError(
                f"Motif {motif_id}: representative fingerprint has {rep_fp.GetNumBits()} bits "
                f"but candidate MACCS fingerprints have {candidate_fps[0].GetNumBits()} bits"
            )

        if sos_cal:
            # calculate_sos relies on sum() and zip(), so hand it plain lists.
            rep_bit_list = list(rep_fp)
            sims = [calculate_sos(rep_bit_list, list(fp)) for fp in candidate_fps]
        else:
            sims = [TanimotoSimilarity(rep_fp, fp) for fp in candidate_fps]
        inter_sims.append(float(np.mean(sims)))

    return {
        "motif_ids": motif_ids,
        "num_atoms": num_atoms,
        "len_frag_loss": len_frag_loss,
        "mcs_smarts": mcs_smarts,
        "molecules_by_motif": molecules_by_motif,
        "intra_sims": intra_sims,
        "inter_sims": inter_sims
    }


In [None]:
# Score every annotated motif, using the overlap score (calculate_sos) for the
# motif-vs-candidates similarity.
results_sos = process_motifs(
    motifs,
    lda_model,
    doc2spec_map,
    sos_cal=True,
)

 38%|███▊      | 199/522 [29:45<18:49,  3.50s/it]    