# **3. Analysis of eMolecules and MetaNetX alphabets in terms of size and diversity**

In this notebook, we investigate the eMolecules and MetaNetX alphabets. The
calculated metrics are presented in Figure 2 of the manuscript, as well
as in Figures S1 and S3 of the supplementary material.

## Parameters

In [None]:
# This cell is tagged `parameters`
# Input files: gzip-compressed TSV files read through their `SMILES` column
input_metanetx_tsv = "metanetx.train.tsv.gz"
input_emolecules_tsv = "emolecules.train.tsv.gz"
input_chembl_tsv = "smiles_chembl_stereo_500.tsv.gz"

# Directory receiving the SQLite databases and the metric CSV files
outdir = "/data/signature/results"

# Number of pandarallel workers
threads = 14
# Scratch directory (raw chunks, SQLite temporary files)
tempdir = "/data/tmp"

# Fixed parameters
# `nbits` value forwarded to MoleculeSignature when computing fragments
nbits = 2048

# Radii iterated by the metrics section; mirrors the `range(7)` used when
# the per-radius databases are built.
RADIUS = list(range(7))
# Sizes of the ECFP count fingerprints
ECFP_NBITS = [512, 1024, 2048]
# Molecule sources; each one is stored in a table of the same name
SOURCES = ["metanetx", "emolecules", "chembl"]

## Import

In [None]:
import collections
import copy
import csv
import glob
import gzip
import hashlib
import itertools
import json
import multiprocessing
import os
import sqlite3
import tempfile
import time
from typing import Dict, Iterator, List, Optional

import numpy as np
import pandas as pd

from pandarallel import pandarallel
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
from molsig.Signature import MoleculeSignature

# Turn off the RDKit `rdApp.*` log channels.
RDLogger.DisableLog("rdApp.*")
# Start the pandarallel workers used by every `parallel_apply` call below.
pandarallel.initialize(progress_bar=True, nb_workers=threads)

if tempdir:
    # Dedicated sub-directory exported through SQLITE_TMPDIR.
    tempdir_sqlite = os.path.join(tempdir, "sqlite")
    os.makedirs(tempdir, exist_ok=True)
    os.makedirs(tempdir_sqlite, exist_ok=True)
    os.environ.update(SQLITE_TMPDIR=tempdir_sqlite)

## Functions

In [None]:
# Molecules
def read_tsv(path: str) -> Iterator[Dict[str, str]]:
    """Stream the molecules of a gzip-compressed TSV file.

    The file must have a header row containing a `SMILES` column.

    Args:
        path: Path to the gzip-compressed, tab-separated file.

    Yields:
        One `{"id": ..., "smiles": ...}` dict per row with a non-empty SMILES.
        `id` is the zero-based row index as a string; skipped rows still
        consume an index, so ids may have gaps.

    Raises:
        KeyError: If the file has no `SMILES` column.
    """
    with gzip.open(path, "rt", newline="") as csvfile:
        reader = csv.DictReader(csvfile, delimiter="\t")
        for ix, row in enumerate(reader):
            smi = row["SMILES"]
            # Truthiness covers both a missing field (None) and an empty string.
            if smi:
                yield dict(id=str(ix), smiles=smi)

def compute_hash(label: str, size: int = 64):
    """Return the SHAKE-256 hexadecimal digest of `label`.

    `size` is the digest length in bytes, so the returned string contains
    `2 * size` hexadecimal characters. `label` is encoded as UTF-8 first.
    """
    hasher = hashlib.shake_256()
    hasher.update(label.encode("utf8"))
    return hasher.hexdigest(size)

def ecfp(smi: str, radius: int, nbits: int) -> str:
    """Return a hash of the Morgan count fingerprint of a SMILES string.

    Args:
        smi: SMILES of the molecule; it must be parsable by RDKit, since no
            error handling is performed here.
        radius: Radius given to the Morgan fingerprint generator.
        nbits: Size of the fingerprint (`fpSize`).

    Returns:
        The `compute_hash` digest of the serialized count vector.
    """
    fpgen = AllChem.GetMorganGenerator(radius=radius, fpSize=nbits, includeChirality=True)
    fp = fpgen.GetCountFingerprint(AllChem.MolFromSmiles(smi))
    # Counts are separated explicitly: without a separator, multi-digit counts
    # make distinct vectors serialize identically (e.g. [1, 10] and [11, 0]
    # both give "110"), which creates artificial hash collisions.
    secfp = ",".join(str(x) for x in fp.ToList())
    return compute_hash(secfp)

# Signature
def compute_signature(smi: str, radius: int):
    """Return a hash of the molecule signature of a SMILES string.

    The signature is built by `MoleculeSignature` with `nbits=0` and
    `use_stereo=False`, then serialized with `to_string()` and hashed.

    Args:
        smi: SMILES of the molecule.
        radius: Radius given to `MoleculeSignature`.

    Returns:
        The `compute_hash` digest of the signature string, or None when the
        molecule cannot be parsed or its signature cannot be computed.
    """
    try:
        ms = MoleculeSignature(
            Chem.MolFromSmiles(smi),
            radius=radius,
            nbits=0,
            use_stereo=False,
        )
        return compute_hash(ms.to_string())
    except Exception:
        # Best effort: a molecule that fails is reported as None and filtered
        # out by the caller. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt and SystemExit propagate.
        return None

def compute_fragment(smi: str, radius: int, nbits: int):
    """Return the hashed fragments of the molecule signature of a SMILES string.

    Each element of `MoleculeSignature.to_list()` is hashed with
    `compute_hash`; the digests are joined with commas.

    Args:
        smi: SMILES of the molecule.
        radius: Radius given to `MoleculeSignature`.
        nbits: `nbits` value given to `MoleculeSignature`.

    Returns:
        A comma-separated string of fragment digests, or None when the
        molecule cannot be parsed or its signature cannot be computed.
    """
    try:
        ms = MoleculeSignature(
            Chem.MolFromSmiles(smi),
            radius=radius,
            nbits=nbits,
        )
        return ",".join(compute_hash(x) for x in ms.to_list())
    except Exception:
        # Best effort: a molecule that fails is reported as None and filtered
        # out by the caller. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt and SystemExit propagate.
        return None

def compute_data(df: pd.DataFrame, nbits: int, radius: int) -> pd.DataFrame:
    """Compute fragments, signature and ECFP hashes for a frame of molecules.

    Args:
        df: Frame with a `smiles` column. It is not modified.
        nbits: `nbits` value forwarded to `compute_fragment`.
        radius: Radius used for the fragments, the signature and the ECFPs.

    Returns:
        A new frame, deduplicated on `smiles`, restricted to the molecules for
        which both fragment and signature were computed, with the additional
        columns `fragment`, `sig` and one `ecfp_<n>` per entry of ECFP_NBITS.
    """
    # Work on a copy so that the caller's frame is left untouched.
    df = df.drop_duplicates(subset=["smiles"]).copy()

    print("\t", "Compute fragments")
    df["fragment"] = df["smiles"].parallel_apply(compute_fragment, args=(radius, nbits))

    print("\t", "Compute signature")
    # compute_signature only takes (smi, radius): passing `nbits` as well
    # raises a TypeError.
    df["sig"] = df["smiles"].parallel_apply(compute_signature, args=(radius,))

    # Drop molecules for which either computation failed (returned None).
    # The explicit copy avoids assigning new columns on a filtered view.
    df = df.dropna(subset=["fragment", "sig"]).copy()

    for ecfp_nbits in ECFP_NBITS:
        print("\t", "Ecfp nbits:", ecfp_nbits)
        df[f"ecfp_{ecfp_nbits}"] = df["smiles"].parallel_apply(ecfp, args=(radius, ecfp_nbits))

    return df

## Create databases of signatures/fragments

In [None]:
def _dump_chunk(records: List[Dict], origin: str, tempdir_compute: str, chunk_ix: int) -> None:
    """Write one chunk of records to `<origin>.<chunk_ix>.raw.csv.gz`."""
    filename = ".".join([origin, str(chunk_ix), "raw", "csv", "gz"])
    pd.DataFrame(records).to_csv(os.path.join(tempdir_compute, filename), index=False)


def _split_input(origin: str, path: str, tempdir_compute: str, max_size: int = 1_000_000) -> int:
    """Split an input TSV into chunk files of at most `max_size` records.

    Returns the number of chunk files written.
    """
    # Remove chunks left by a previous run: a shorter input would otherwise
    # leave stale files that match the glob used when computing.
    for stale in glob.glob(os.path.join(tempdir_compute, f"{origin}.*.raw.csv.gz")):
        os.remove(stale)

    records: List[Dict] = []
    nb_chunks = 0
    for rec in read_tsv(path=path):
        records.append(rec)
        if len(records) >= max_size:
            _dump_chunk(records, origin, tempdir_compute, nb_chunks)
            records = []
            nb_chunks += 1
    # Only write the remainder when there is one: an empty chunk produces a
    # file that pd.read_csv cannot parse.
    if records:
        _dump_chunk(records, origin, tempdir_compute, nb_chunks)
        nb_chunks += 1
    return nb_chunks


def compute_file(conn, origin: str, path: str, tempdir: str, radius: int):
    """Compute signatures for one source and append them to its table.

    Args:
        conn: Open SQLite connection to the database of this radius.
        origin: Source name ("metanetx", "emolecules" or "chembl"); it is used
            as table name and as sub-directory of `tempdir`.
        path: Path to the gzip-compressed input TSV of the source.
        tempdir: Scratch directory holding the raw chunks.
        radius: Radius forwarded to `compute_data`. The input file is split
            into chunks only when `radius == 0`; other radii reuse the chunks.
    """
    print("Compute:", origin)
    tempdir_compute = os.path.join(tempdir, origin)
    os.makedirs(tempdir_compute, exist_ok=True)
    if radius == 0:
        # Parse and split file for memory purposes
        _split_input(origin=origin, path=path, tempdir_compute=tempdir_compute)

    # Compute signature. Chunks are sorted so that rows are inserted in the
    # same order on every run (glob order is arbitrary).
    pattern = os.path.join(tempdir_compute, f"{origin}.*.raw.csv.gz")
    for ix, filename in enumerate(sorted(glob.glob(pattern))):
        print("\t", "Batch:", ix)
        df = pd.read_csv(filename)
        df = compute_data(
            df=df,
            nbits=nbits,
            radius=radius,
        )
        if df.empty:
            continue

        df.to_sql(origin, conn, if_exists="append", index=False, chunksize=int(1e4))
        del df


os.makedirs(outdir, exist_ok=True)
for radius in range(7):
    print("Radius:", radius)
    output_database_sql = os.path.join(outdir, f"radius-{radius}.sql")

    conn = sqlite3.connect(output_database_sql)
    try:
        # Compute
        for origin, path in zip(
            ["metanetx", "emolecules", "chembl"],
            [input_metanetx_tsv, input_emolecules_tsv, input_chembl_tsv],
        ):
            compute_file(conn=conn, origin=origin, path=path, tempdir=tempdir, radius=radius)
    finally:
        # Release the database even when a source fails.
        conn.close()

## Create metrics

In [None]:
# Define colnames expected in tables
COLNAMES = ["smiles", "sig"] + [f"ecfp_{ecfp_nbits}" for ecfp_nbits in ECFP_NBITS]

# Alpha-diversity metrics requested from scikit-bio
METRICS = [
    "ace",
    "chao1",
    "dominance",
    "gini_index",
    "hill",
    "inv_simpson",
    "observed_features",
    "pielou_e",
    "simpson",
    "simpson_d",
    "simpson_e",
    "shannon",
]

# Build depth: every decade from 10 to 1e7 plus its half, sorted, without the
# largest value (1e7).
depths = np.logspace(1, 7, num=7, base=10, dtype=int)
depths = np.sort(np.append(depths, depths // 2))
depths = depths[:-1]


def compute_degeneracy(cursor, source: str, colname: str, radius: int):
    """Count how many values of `colname` are shared by 1, 2, 3... records.

    Args:
        cursor: Cursor on the database of the current radius.
        source: Table to query.
        colname: Column whose values are grouped.
        radius: Unused; kept for backward compatibility.

    Returns:
        A dict mapping a group size to the number of groups of that size.
    """
    print("Query - start:", source, colname)
    start = time.time()
    collision = collections.Counter()  # key: occurences, value: number of occurences
    cursor.execute(f"SELECT COUNT(*) FROM {source} GROUP BY {colname}")
    # presumably fetch_in_batches yields the rows of the executed query
    for res in fetch_in_batches(cursor=cursor):
        collision[res[0]] += 1
    print("Query - end", time.time() - start)
    return dict(collision)


def _sample_taxons(cursor, source: str, total_record: int, depth: int, seed: int) -> List[int]:
    """Count the fragments found in `depth` records drawn from `source`.

    Records are drawn without replacement with a generator seeded by `seed`
    (all records are taken when `depth == total_record`).

    Returns:
        The number of occurrences of each distinct fragment.

    Raises:
        RuntimeError: If the table yields fewer records than expected.
    """
    if total_record == depth:
        indices = list(range(total_record))
    else:
        rng = np.random.default_rng(seed)
        indices = np.sort(rng.choice(total_record, size=depth, replace=False, shuffle=False))

    counter = collections.Counter()
    cur_index = 0
    cursor.execute(f"SELECT fragment FROM {source}")
    # `indices` is sorted, so one forward pass over the table is enough.
    for ix, res in enumerate(fetch_in_batches(cursor=cursor)):
        if cur_index >= depth:
            break
        if ix == indices[cur_index]:
            counter.update(res[0].split(","))
            cur_index += 1
    if cur_index != depth:
        raise RuntimeError(f"Expected {depth} records from {source}, got {cur_index}")
    return list(counter.values())


for radius in RADIUS:
    print("Radius:", radius)

    # Define files
    # NOTE(review): the databases are created as `radius-<r>.sql` in the
    # section above, while `radius-<r>.db` is expected here -- confirm the
    # intermediate step that produces the `.db` files.
    output_database_db = os.path.join(outdir, f"radius-{radius}.db")
    output_degeneracy_csv = os.path.join(outdir, f"degeneracy.radius-{radius}.csv")
    output_diversity_csv = os.path.join(outdir, f"diversity.radius-{radius}.csv")
    output_alphabet_csv = os.path.join(outdir, f"alphabet.radius-{radius}.csv")

    if not os.path.isfile(output_database_db):
        print(f"Output database: {output_database_db} not found -> Skip")
        continue
    if os.path.isfile(output_degeneracy_csv) and os.path.isfile(output_diversity_csv) and os.path.isfile(output_alphabet_csv):
        continue

    # A single connection serves the three sections and is always released.
    conn = sqlite3.connect(output_database_db)
    try:
        cursor = conn.cursor()

        # Total records
        print("Total records")
        total_records = {}
        for source in SOURCES:
            total_records[source] = number_records(cursor=cursor, table=source)

        # Degeneracy
        if not os.path.isfile(output_degeneracy_csv):
            print("Degeneracy")
            datas = []
            for source, colname in itertools.product(SOURCES, COLNAMES):
                print("Compute for:", source, " - ", colname)
                data_deg = compute_degeneracy(cursor=cursor, source=source, colname=colname, radius=radius)
                for occurence, value in data_deg.items():
                    datas.append(
                        dict(item=colname, source=source, kind="occurence", occurence=occurence, value=value, radius=radius)
                    )
                datas.append(dict(item=colname, source=source, kind="total", value=total_records[source], radius=radius))
            pd.DataFrame(datas).to_csv(output_degeneracy_csv, index=False)

        # Diversity
        if not os.path.isfile(output_diversity_csv):
            print("Diversity")
            datas = []
            # The position in the product doubles as random seed, so each
            # (batch, source, depth) combination draws its own sample.
            for seed, (batch, source, depth) in enumerate(itertools.product(range(10), SOURCES, depths)):
                start = time.time()
                data = dict(batch=batch, source=source, depth=depth, seed=seed, radius=radius)

                print("\t", "Deal with:", data)
                total_record = total_records[source]
                if total_record < depth:
                    # Not enough records to sample at this depth.
                    continue

                # Build taxons
                print("\t", "Build taxons")
                taxons = _sample_taxons(
                    cursor=cursor, source=source, total_record=total_record, depth=depth, seed=seed
                )

                # Compute Alpha div
                print("\t", "Compute alpha div")
                for metric in METRICS:
                    value = None
                    try:
                        value = skbio.diversity.alpha_diversity(metric, taxons).iloc[0]
                    except Exception as err:
                        # Keep the row with an empty value, but report why the
                        # metric could not be computed instead of hiding it.
                        print("\t", "Metric failed:", metric, "-", repr(err))
                    datas.append(dict(data, metric=metric, value=value))

                print("\t", "End, time:", time.time() - start)

                # Aggressive clean up
                del taxons

            # Save
            pd.DataFrame(datas).to_csv(output_diversity_csv, index=False)

        # Alphabet
        if not os.path.isfile(output_alphabet_csv):
            print("Alphabet")
            frames = []
            for source in SOURCES:
                alphabets = set()
                datas = []
                cursor.execute(f"SELECT fragment FROM {source}")
                for res in fetch_in_batches(cursor=cursor):
                    fragments = res[0].split(",")
                    nb_added = 0
                    for fragment in fragments:
                        if fragment in alphabets:
                            continue
                        alphabets.add(fragment)
                        nb_added += 1
                    is_added = int(nb_added > 0)
                    datas.append(dict(is_added=is_added, nb_fragment_added=nb_added, nb_fragment=len(fragments), source=source))
                # One frame per source: rebuilding `datas` for each source and
                # saving after the loop would keep the last source only.
                frames.append(pd.DataFrame(datas))

            # `step` restarts at 0 for each source.
            df = pd.concat(frames)
            df.index.name = "step"

            df.to_csv(output_alphabet_csv)
    finally:
        conn.close()