# Degeneracy

## Parameters

In [None]:
# This cell is tagged `parameters`: the assignments below are the default run
# settings (presumably overridden by a notebook runner such as papermill — confirm).
# input
# MetaNetX chemical table, read with pd.read_csv(sep="\t"); must provide the
# "#ID" and "SMILES" columns (the MetaNetX cell renames them and drops the rest).
input_metanetx_tsv = "/Users/ggricourt/Documents/signature/metanetx/metanetx.4_4.tsv"
# eMolecules SDF, parsed record by record with Open Babel; every record must
# carry an EMOL_VERSION_ID data field.
input_emolecules_sdf = "/Users/ggricourt/Documents/signature/emolecules/emolecules.2023-07-01.sdf.gz"
# output
# SQLite database holding the `metanetx` and `emolecules` tables; it is deleted
# and rebuilt on every run.
output_database_sql = "database.sql"
# Per source and per column ("sig", "smiles", "ecfp"): histogram
# {group size: number of groups of identical values}, plus the row "total".
output_degeneracy_json = "degeneracy.json"
# Long-format table of the data that is plotted.
output_stats_csv = "degeneracy.csv"
# Degeneracy plot.
output_degeneracy_png = "degeneracy.png"

# Radius used for both the molecule signature and the Morgan/ECFP fingerprint
# (the ECFP series is labelled "ECFP" + str(radius * 2)).
radius = 2
# Forwarded to MoleculeSignature.
nbits = 2048
# NOTE(review): not referenced by any code in this notebook as shown.
allHsExplicit = False
# Signature options, forwarded unchanged to MoleculeSignature.
use_smarts = True
boundary_bonds = False
# Molecules whose exact molecular weight exceeds this are discarded by `sanitize`
# (which also uses it for a cheap SMILES-length pre-filter).
max_molecular_weight = 500
# More MoleculeSignature options.
map_root = True
rooted_smiles = False
all_bits = True
# Number of pandarallel workers used by `parallel_apply`.
threads = 2
# Directory under which the temporary eMolecules chunk files are written;
# None or "" means a new directory created by tempfile.mkdtemp().
tempdir = None

## Import

In [None]:
# missing deps: openbabel seaborn pandarallel ipywidgets
import argparse
import collections
import csv
import glob
import gzip
import json
import math
import multiprocessing
import os
import shutil
import sqlite3
import tempfile
import time
from typing import Dict, Iterator, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from openbabel import pybel
from pandarallel import pandarallel
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import AllChem, Descriptors

from signature.Signature import MoleculeSignature
from signature.signature_alphabet import compatible_alphabets, load_alphabet, merge_alphabets, sanitize_molecule, SignatureAlphabet, signature_from_smiles
from signature.utils import read_csv, read_tsv, read_txt, write_csv

# Silence RDKit log messages on every "rdApp.*" channel.
RDLogger.DisableLog("rdApp.*")
# Enable `Series.parallel_apply` (used by compute_data) with `threads` workers
# and a progress bar.
pandarallel.initialize(nb_workers=threads, progress_bar=True)

## Functions

In [None]:
# SQL
def number_records(cursor, table: str) -> int:
    """Return the number of rows in `table`.

    Falls back to -1 when the COUNT query yields no row at all.
    `table` is interpolated into the SQL text, so it must be a trusted name.
    """
    query = 'SELECT COUNT(*) FROM %s' % (table,)
    cursor.execute(query)
    first_row = cursor.fetchone()
    return -1 if not first_row else first_row[0]

def check_table_exists(cursor, table: str) -> bool:
    """Return True if a table named `table` exists in the connected SQLite database.

    Parameters:
        cursor: sqlite3 cursor.
        table: exact table name to look up.
    """
    # Bind the name as a parameter and use a single-quoted literal for 'table'.
    # Double quotes are identifiers in SQL; SQLite only falls back to treating
    # them as strings as a legacy quirk. Interpolating the name also let a
    # name containing quotes rewrite the WHERE clause.
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,))
    return cursor.fetchone() is not None

def check_record_exists(cursor, table: str, smiles: str) -> bool:
    """Return True if `table` contains at least one row whose `smiles` equals `smiles`.

    Parameters:
        cursor: sqlite3 cursor.
        table: table to search; interpolated into the SQL text, so it must be a
            trusted name (identifiers cannot be bound as parameters).
        smiles: SMILES string to look for, compared exactly.
    """
    # The SMILES value is bound as a parameter. Interpolating it unquoted made
    # SQLite read e.g. `CCO` as a column name, and SMILES characters such as
    # "(", "=" or "#" produced syntax errors.
    cursor.execute('SELECT smiles FROM %s WHERE smiles=? LIMIT 1' % (table,), (smiles,))
    return cursor.fetchone() is not None

def drop_duplicates_smiles(cursor, table: str) -> int:
    """Keep only the first row (lowest rowid) for each `smiles` value in `table`.

    The table is rebuilt: the kept rows are copied into a staging table
    ``<table>__dedup``, the original is dropped, and the copy is renamed.
    Any pre-existing table with the staging name is dropped first.
    `table` is interpolated into the SQL text, so it must be a trusted name.

    Returns:
        Number of rows removed.
    """
    count_before = number_records(cursor=cursor, table=table)
    # Per-table staging name instead of a shared "new_table" that could collide.
    staging = "%s__dedup" % (table,)
    # Run the rebuild inside a savepoint. If anything fails or is interrupted
    # between DROP and RENAME, the original table is restored rather than lost.
    cursor.execute("SAVEPOINT drop_duplicates_smiles;")
    try:
        cursor.execute("DROP TABLE IF EXISTS %s;" % (staging,))
        cursor.execute(
            "CREATE TABLE %s AS SELECT * FROM %s WHERE rowid IN (SELECT MIN(rowid) FROM %s GROUP BY smiles);"
            % (staging, table, table,)
        )
        cursor.execute("DROP TABLE %s;" % (table,))
        cursor.execute("ALTER TABLE %s RENAME TO %s;" % (staging, table,))
    except BaseException:
        cursor.execute("ROLLBACK TO drop_duplicates_smiles;")
        raise
    finally:
        # Releasing the outermost savepoint commits the rebuild.
        cursor.execute("RELEASE drop_duplicates_smiles;")
    count_after = number_records(cursor=cursor, table=table)
    return count_before - count_after

# Molecules
def read_sdf(path: str) -> Iterator[Dict]:
    """Lazily yield one ``{"id": ..., "smiles": ...}`` dict per record of an SDF file.

    Parameters:
        path: SDF file, read with Open Babel (``pybel.readfile("sdf", path)``).

    Yields:
        dict with ``id`` taken from the record's EMOL_VERSION_ID data field and
        ``smiles`` the SMILES written by Open Babel.

    Raises:
        KeyError: if a record has no EMOL_VERSION_ID data field.
    """
    for mol in pybel.readfile("sdf", path):
        ident = mol.data["EMOL_VERSION_ID"]
        # Open Babel's "smi" writer emits "<smiles>\t<title>\n". Keep only the
        # first field: deleting the tab instead glued the title onto the SMILES
        # (e.g. trailing digits then read as ring-closure bonds).
        smi = mol.write("smi").split("\t", 1)[0].strip()
        yield dict(id=ident, smiles=smi)
        
def sanitize(smiles: str, max_molecular_weight: int) -> Optional[str]:
    """Return the sanitized SMILES of a molecule, or ``pd.NA`` if it is rejected.

    A molecule is rejected when any of these holds:
    - the input is missing (NA) or the string "nan";
    - it has more than one fragment (the SMILES contains ".");
    - the SMILES is longer than ``int(max_molecular_weight / 5)`` characters
      (a cheap pre-filter applied before any parsing);
    - RDKit cannot parse it, or `sanitize_molecule` returns no molecule;
    - its exact molecular weight exceeds `max_molecular_weight`.

    This function does not deduplicate molecules.
    Note: rejection is signalled with ``pd.NA`` (not None) so pandas treats it
    as missing.
    """
    rejected = pd.NA
    if pd.isna(smiles) or smiles == "nan":
        return rejected
    if "." in smiles:
        return rejected  # not in one piece
    if len(smiles) > int(max_molecular_weight / 5):  # Cheap skip
        return rejected
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        # Unparsable SMILES: reject here instead of handing None to
        # sanitize_molecule.
        return rejected
    mol, smiles = sanitize_molecule(parsed, formalCharge=True)
    if mol is None:
        return rejected
    if Descriptors.ExactMolWt(mol) > max_molecular_weight:
        return rejected
    return smiles

def ecfp(smi: str, radius: int, fp_size: int = 2048) -> str:
    """Return the Morgan count fingerprint of `smi` as a "-"-joined string of counts.

    Parameters:
        smi: SMILES of the molecule (expected to be parseable by RDKit).
        radius: Morgan radius.
        fp_size: fingerprint length. The default 2048 matches the previously
            hard-coded value.

    Returns:
        The ``fp_size`` per-bit counts, e.g. "0-2-0-...-1".
    """
    generator = AllChem.GetMorganGenerator(radius=radius, fpSize=fp_size)
    # Count fingerprint (not a bit vector) so that bit multiplicities are kept.
    counts = generator.GetCountFingerprint(AllChem.MolFromSmiles(smi))
    return "-".join(map(str, counts.ToList()))

# Signature
def compute_sig(smi: str, radius: int, use_smarts: bool, boundary_bonds: bool, map_root: bool, rooted_smiles: bool, nbits: int, all_bits: bool):
    """Return the string form of the MoleculeSignature of `smi`, without neighbors.

    Every setting is forwarded unchanged, by keyword, to `MoleculeSignature`.
    """
    settings = dict(
        radius=radius,
        use_smarts=use_smarts,
        boundary_bonds=boundary_bonds,
        map_root=map_root,
        rooted_smiles=rooted_smiles,
        nbits=nbits,
        all_bits=all_bits,
    )
    molecule = Chem.MolFromSmiles(smi)
    signature = MoleculeSignature(molecule, **settings)
    return signature.to_string(neighbors=False)
    
def compute_data(df: pd.DataFrame, max_molecular_weight: int, radius: int, use_smarts: bool, boundary_bonds: bool, map_root: bool, rooted_smiles: bool, nbits: int, all_bits: bool) -> pd.DataFrame:
    """Sanitize the ``smiles`` column of `df` and add ``sig`` and ``ecfp`` columns.

    Steps:
    1. drop rows with a repeated ``smiles`` value (first occurrence kept);
    2. replace ``smiles`` by `sanitize`'s output and drop the rows it rejected (NA);
    3. ``sig``: `compute_sig` with the given signature settings;
    4. ``ecfp``: `ecfp` with `radius`.
    Steps 2-4 run with ``Series.parallel_apply`` and print their timings.

    Returns:
        A new DataFrame; the input `df` is left unmodified.
    """
    # Non-inplace + copy: the caller's frame is untouched, and the result is an
    # independent frame, so the column assignments below are not on a view.
    df = df.drop_duplicates(subset=["smiles"]).copy()

    print("\t", "Sanitize")
    start_time = time.time()
    df["smiles"] = df["smiles"].parallel_apply(sanitize, args=(max_molecular_weight,),)
    print("\t", "End sanitize", time.time() - start_time)
    # `.copy()` so adding columns to the filtered frame does not trigger
    # SettingWithCopyWarning (chained assignment on a slice).
    df = df[df["smiles"].notna()].copy()

    print("\t", "Compute signature")
    start_time = time.time()
    df["sig"] = df["smiles"].parallel_apply(compute_sig, args=(radius, use_smarts, boundary_bonds, map_root, rooted_smiles, nbits, all_bits,),)
    print("\t", "End signature", time.time() - start_time)

    print("\t", "Ecfp")
    start_time = time.time()
    df["ecfp"] = df["smiles"].parallel_apply(ecfp, args=(radius,))
    print("\t", "End ecfp", time.time() - start_time)

    return df

                       

## Create database

In [None]:
# Start from an empty database: remove any file left by a previous run, then
# open the SQLite connection and cursor used by the following cells.
if os.path.exists(output_database_sql):
    os.remove(output_database_sql)
conn = sqlite3.connect(output_database_sql)
cursor = conn.cursor()

In [None]:
# MetaNetX: load the chemical table, keep only the identifier and SMILES columns,
# compute sanitized SMILES / signature / ECFP, and store them in `metanetx`.
print("MetaNetX")
start = time.time()
df_metanetx = (
    pd.read_csv(input_metanetx_tsv, sep="\t")
    .rename(columns={"#ID": "id", "SMILES": "smiles"})
    .drop(columns=["name", "reference", "formula", "charge", "mass", "InChI", "InChIKey"])
)
df_metanetx = compute_data(
    df=df_metanetx,
    max_molecular_weight=max_molecular_weight,
    radius=radius,
    use_smarts=use_smarts,
    boundary_bonds=boundary_bonds,
    map_root=map_root,
    rooted_smiles=rooted_smiles,
    nbits=nbits,
    all_bits=all_bits,
)
time_compute = time.time() - start

# Write to SQLite, replacing any previous table, and time the write separately.
start = time.time()
df_metanetx.to_sql('metanetx', conn, if_exists='replace', index=False)
time_sql = time.time() - start
print("Time compute :", time_compute, "Time sql :", time_sql)

# Sanitized SMILES can still collide; keep the first row of each.
count = drop_duplicates_smiles(cursor, table="metanetx")
print("Remove duplicates:", count)

In [None]:
# emolecules
# The eMolecules SDF is processed in two passes to bound memory. First it is
# split into gzip CSV chunks of at most `max_size` records. Then each chunk goes
# through `compute_data` and is appended to the `emolecules` table.
print("eMolecules")

# Init
max_size = 500_000  # records per chunk
# Always work in a fresh sub-directory: of `tempdir` when one is given,
# otherwise of the system temp directory. Cleanup then removes only what this
# cell created, never a user-supplied directory, and chunk files left by an
# earlier run can never be picked up.
workdir = tempfile.mkdtemp(dir=tempdir or None)


def _write_chunk(records: List[Dict], index: int) -> str:
    """Write `records` as chunk number `index` in `workdir` and return its path."""
    path = os.path.join(workdir, "emolecules." + str(index) + ".raw.csv.gz")
    pd.DataFrame(records).to_csv(path, index=False)
    return path


try:
    # Parse and split emolecules for memory purposes. Chunk paths are kept in
    # creation order so rows are inserted in file order. glob() order is
    # arbitrary, which made the row kept by duplicate removal non-deterministic.
    chunk_paths = []
    data = []
    for rec in read_sdf(path=input_emolecules_sdf):
        data.append(rec)
        if len(data) >= max_size:
            chunk_paths.append(_write_chunk(data, len(chunk_paths)))
            data = []
    if data:
        chunk_paths.append(_write_chunk(data, len(chunk_paths)))
    if not chunk_paths:
        raise ValueError("No records read from %s" % (input_emolecules_sdf,))

    # Compute signature, chunk by chunk, appending to the same table
    for filename in chunk_paths:
        df = compute_data(
            df=pd.read_csv(filename),
            max_molecular_weight=max_molecular_weight,
            radius=radius,
            use_smarts=use_smarts,
            boundary_bonds=boundary_bonds,
            map_root=map_root,
            rooted_smiles=rooted_smiles,
            nbits=nbits,
            all_bits=all_bits
        )
        df.to_sql("emolecules", conn, if_exists="append", index=False)
finally:
    # Clean up (also on error): removes only the directory created above.
    shutil.rmtree(workdir, ignore_errors=True)

# Drop duplicates (across all chunks)
count = drop_duplicates_smiles(cursor, table="emolecules")
print("Remove duplicates:", count)

In [None]:
# Persist any pending transaction, then close. sqlite3's Connection.close()
# does not commit, so uncommitted changes would otherwise be discarded.
conn.commit()
conn.close()

## Compute

In [None]:
# Re-open the database built in the previous section (that connection was closed).
conn = sqlite3.connect(output_database_sql)
cursor = conn.cursor()

In [None]:
# Degeneracy statistics. For each source table and each representation column,
# count how many distinct values are shared by exactly k rows, for every k.
# COLS maps column name -> display label used in the graph.
COLS = {
    "sig": "Signature",
    "smiles": "SMILES",
    "ecfp": "ECFP" + str(radius * 2)
}
stats = {}
for source in ["metanetx", "emolecules"]:
    stats[source] = {}
    # Total number of rows in the table, computed once per source.
    stats[source]["total"] = number_records(cursor=cursor, table=source)
    for col in COLS:
        print("Analyze:", col)

        print("Query - start")
        start = time.time()
        # Build the histogram inside SQLite. The inner query gives the size of
        # each group of identical values; the outer one counts groups per size.
        # This avoids pulling one row per distinct value into Python.
        cursor.execute(
            "SELECT n, COUNT(*) FROM (SELECT COUNT(*) AS n FROM %s GROUP BY %s) GROUP BY n ORDER BY n;"
            % (source, col,)
        )
        # key: occurrences (group size), value: number of groups of that size
        stats[source][col] = dict(cursor.fetchall())
        print("Query - end", time.time() - start)

# Save (JSON turns the integer keys into strings)
with open(output_degeneracy_json, "w") as fd:
    json.dump(stats, fd)

In [None]:
# Close the connection used for the statistics queries (only SELECTs were run).
conn.close()

## Graph

In [None]:
def create_df(data: Dict, src: str) -> pd.DataFrame:
    """Turn one source's degeneracy histograms into a long-format DataFrame.

    Parameters:
        data: mapping column name (keys of the global `COLS`, e.g. "sig") ->
            histogram {group size: number of groups}, plus a "total" entry that
            is ignored. Group sizes may be strings (as loaded from JSON).
        src: source name, stored in the "src" column.

    Returns:
        DataFrame with columns duplicate (int group size), count (int number of
        groups, 0 where a size is absent for that label), label (display label
        from `COLS`), count_log (natural log of count, 0 when count is 0) and
        src. The SMILES series is excluded. `data` is not modified.
    """
    # Copy without "total" (a scalar, not a histogram) instead of `del`, so the
    # caller's dict is left intact and the function can be called twice.
    histograms = {key: value for key, value in data.items() if key != "total"}
    df_raw = pd.DataFrame(histograms)

    # Rename cols to their display labels
    df_raw.rename(columns=COLS, inplace=True)

    # One (count, label) frame per representation, stacked vertically. Collect
    # them and concatenate once rather than growing from an empty DataFrame.
    frames = []
    for col in df_raw.columns:
        df_col = df_raw[col].to_frame()
        df_col["label"] = col
        df_col.rename(columns={col: "count"}, inplace=True)
        frames.append(df_col)
    df = pd.concat(frames)

    # Create "duplicate" from the index, which holds the group sizes
    df.reset_index(inplace=True)
    df.rename(columns={"index": "duplicate"}, inplace=True)

    # Fmt: a size missing from one histogram is NaN after the outer join -> 0
    df.fillna(0, inplace=True)
    df["duplicate"] = df["duplicate"].astype(int)
    df["count"] = df["count"].astype(int)
    df["count_log"] = df["count"].apply(lambda x: math.log(x) if x > 0 else 0)

    df["src"] = src

    # Exclude the SMILES series. Labels were renamed above, so compare against
    # the display label; comparing against "smiles" never matched.
    df = df[df["label"] != COLS["smiles"]]
    return df

In [None]:
# Read the statistics saved by the Compute section; `with` closes the file
# handle, which the bare open(...) call left open.
with open(output_degeneracy_json) as fd:
    data = json.load(fd)

# Load data: one long-format frame per source, stacked and ordered by group size
df_metanetx = create_df(data=data["metanetx"], src="metanetx")
df_emolecules = create_df(data=data["emolecules"], src="emolecules")

df = pd.concat([df_metanetx, df_emolecules])
df.sort_values("duplicate", inplace=True)
df.reset_index(inplace=True, drop=True)

# Graph: count_log against duplicate, one hue per source, one line style per label
sns.set_theme(style="darkgrid")
ax = sns.lineplot(data=df, x="duplicate", y="count_log", hue="src", style="label")
ax.set(xlabel="Duplicate", ylabel="Occurence (log)")

df.to_csv(output_stats_csv, index=False)
plt.savefig(output_degeneracy_png)