## Parameters

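In this notebook, we build one SQLite database per signature radius containing hashed molecular fragments, signatures and ECFP fingerprints for the MetaNetX and eMolecules molecule sets. We then compute degeneracy, diversity and alphabet-growth metrics from these databases. The cell below holds the notebook parameters.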

In [None]:
# This cell is tagged `parameters`
# Input: tab-separated files (optionally gzipped) with a `SMILES` column
input_metanetx_tsv = "metanetx.train.tsv.gz"
input_emolecules_tsv = "emolecules.tsv.gz"

# Output: directory receiving the per-radius databases and metric CSVs
outdir = "."

# Systems
threads = 12  # number of pandarallel workers
workdir = "/data/tmp"  # scratch directory; set to None to use a fresh system temp dir

# Fixed parameters
nbits = 2048
# NOTE(review): allHsExplicit is not referenced by any of the cells shown here.
allHsExplicit = False
boundary_bonds = False
map_root = True
rooted_smiles = False

RADIUS = list(range(7))  # 0, 1, ..., 6
ECFP_NBITS = [2 ** exponent for exponent in (9, 10, 11)]  # 512, 1024, 2048
SOURCES = ["metanetx", "emolecules"]

## Import

In [None]:
# missing deps: seaborn pandarallel ipywidgets scikit-bio
import collections
import csv
import glob
import gzip
import hashlib
import itertools
import json
import os
import shutil
import sqlite3
import tempfile
import time
from typing import Dict, Iterator, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import skbio
from pandarallel import pandarallel
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import AllChem, Descriptors
from signature.Signature import MoleculeSignature
from signature.signature_alphabet import SignatureAlphabet, signature_from_smiles


# Disable every RDKit `rdApp.*` log channel.
RDLogger.DisableLog("rdApp.*")
pandarallel.initialize(nb_workers=threads, progress_bar=True)

# Resolve the scratch directory: reuse the configured one, or create a fresh temp dir.
if workdir is not None:
    os.makedirs(workdir, exist_ok=True)
else:
    workdir = tempfile.mkdtemp()

# Give SQLite its own temporary-file directory inside the scratch space.
workdir_sqlite = os.path.join(workdir, "sqlite")
os.makedirs(workdir_sqlite, exist_ok=True)
os.environ["SQLITE_TMPDIR"] = workdir_sqlite

## Functions

In [None]:
# SQL
def number_records(cursor, table: str) -> int:
    """Return the number of rows in ``table``, or -1 if the count query yields no row.

    ``table`` is interpolated directly into the SQL text (identifiers cannot be
    bound as parameters), so it must be a trusted table name.
    """
    query = f"SELECT COUNT(*) FROM {table}"
    cursor.execute(query)
    row = cursor.fetchone()
    return row[0] if row else -1

def fetch_in_batches(cursor, batch_size=1000):
    """Lazily yield the rows of ``cursor``'s pending result set one at a time.

    Rows are pulled with ``cursor.fetchmany(batch_size)`` so at most one batch is
    held in memory; iteration stops at the first empty batch.
    """
    batch = cursor.fetchmany(batch_size)
    while batch:
        yield from batch
        batch = cursor.fetchmany(batch_size)

def file_size(path:str):
    """Return the size of the file at ``path`` in mebibytes (bytes / 1024**2), as a float."""
    return os.path.getsize(path) / 1024 ** 2

def vacuum(source_path: str, tempdir: str):
    """Rebuild the SQLite database at ``source_path`` to reclaim unused space.

    A compacted copy is written to a scratch directory created inside
    ``tempdir``. That copy is then copied back over ``source_path``, and the
    size difference (in MiB) is printed. The scratch directory is always removed,
    and both connections are always closed, even on error.

    Args:
        source_path: Path of the SQLite database to compact (overwritten).
        tempdir: Directory in which the temporary compacted copy is created.
    """
    print("Start Vacuum")
    file_in_size = file_size(path=source_path)

    with tempfile.TemporaryDirectory(dir=tempdir) as scratch:
        target_path = os.path.join(scratch, "vacuum.db")
        source_conn = sqlite3.connect(source_path)
        try:
            if sqlite3.sqlite_version_info >= (3, 27, 0):
                # VACUUM INTO (SQLite >= 3.27) writes a compacted copy natively,
                # far faster than replaying an SQL text dump through Python.
                source_conn.execute("VACUUM INTO ?", (target_path,))
            else:
                # Fallback for old SQLite builds: replay the SQL dump into a new file.
                target_conn = sqlite3.connect(target_path)
                try:
                    with target_conn:
                        for statement in source_conn.iterdump():
                            if statement.strip():
                                target_conn.execute(statement)
                finally:
                    target_conn.close()
        finally:
            source_conn.close()

        file_out_size = file_size(path=target_path)
        shutil.copyfile(target_path, source_path)

    print("Gain of size (Mb):", file_in_size - file_out_size)

# Molecules
def read_tsv(path: str) -> Iterator[Dict[str, str]]:
    """Yield ``{"id": ..., "smiles": ...}`` records from a tab-separated file.

    ``path`` is read as gzip-compressed text when it ends with ``.gz`` and as
    plain text otherwise. The file must have a ``SMILES`` header column.
    ``id`` is the 0-based data-row index as a string, counted over all rows, so
    skipped rows leave gaps. Rows whose SMILES is empty or missing are skipped.
    The file is closed even if the caller stops iterating early.

    Raises:
        KeyError: if the file has at least one row but no ``SMILES`` column.
    """
    # Fix: the non-gz branch previously passed `newline==""` (a comparison on an
    # undefined name), which raised NameError for every uncompressed input.
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt", newline="") as csvfile:
        reader = csv.DictReader(csvfile, delimiter="\t")
        for ix, row in enumerate(reader):
            smi = row["SMILES"]
            if smi:  # DictReader gives "" for empty cells and None for missing ones
                yield dict(id=str(ix), smiles=smi)

def compute_hash(label: str, size: int = 64):
    """Return the SHAKE-256 digest of ``label`` (UTF-8 encoded) as a hex string.

    ``size`` is the digest length in bytes, so the result has ``2 * size`` hex characters.
    """
    payload = label.encode("utf8")
    digest = hashlib.shake_256(payload)
    return digest.hexdigest(size)

def ecfp(smi: str, radius: int, nbits: int) -> str:
    """Return a hash of the chirality-aware count Morgan fingerprint (ECFP) of ``smi``.

    Args:
        smi: SMILES string of the molecule.
        radius: Morgan radius.
        nbits: Fingerprint length (``fpSize``).

    Returns:
        ``compute_hash`` of the comma-separated list of per-bit counts.
    """
    fpgen = AllChem.GetMorganGenerator(radius=radius, fpSize=nbits, includeChirality=True)
    fp = fpgen.GetCountFingerprint(AllChem.MolFromSmiles(smi))
    # Separate counts with "," so the serialisation is unambiguous. Without a
    # separator, counts [1, 12] and [11, 2] both became "112", producing spurious
    # hash collisions that inflated ECFP degeneracy.
    secfp = ",".join(str(count) for count in fp.ToList())
    return compute_hash(secfp)

# Signature
def compute_sig(smi: str, alphabet):
    """Hash the molecular signature of ``smi`` computed against ``alphabet``.

    ``signature_from_smiles`` must return exactly three items; only the first
    is hashed with ``compute_hash``.
    """
    signature, _unused_second, _unused_third = signature_from_smiles(smi, alphabet)
    return compute_hash(signature)

def compute_fragment(smi: str, radius: int, boundary_bonds: bool, map_root: bool, rooted_smiles: bool, nbits: int):
    """Return the comma-joined hashes of the entries of ``MoleculeSignature(...).to_list()`` for ``smi``.

    All keyword options are forwarded unchanged to ``MoleculeSignature``; each
    entry is hashed with ``compute_hash``, in list order.
    """
    molecule = Chem.MolFromSmiles(smi)
    signature = MoleculeSignature(
        molecule,
        radius=radius,
        boundary_bonds=boundary_bonds,
        map_root=map_root,
        rooted_smiles=rooted_smiles,
        nbits=nbits,
    )
    hashed_entries = (compute_hash(entry) for entry in signature.to_list())
    return ",".join(hashed_entries)

# Compute data for fragment, signature, ECFP
def compute_data(df: pd.DataFrame, boundary_bonds: bool, map_root: bool, rooted_smiles: bool, nbits: int, radius: int) -> pd.DataFrame:
    """Return ``df`` de-duplicated on ``smiles`` with hashed descriptor columns added.

    Added columns (computed with pandarallel's ``parallel_apply``):
      * ``fragment``: ``compute_fragment`` output (comma-separated hashes);
      * ``sig``: ``compute_sig`` using ``SignatureAlphabet(radius=radius, nBits=nbits)``;
      * ``ecfp_<n>``: ``ecfp`` at ``radius`` for every ``n`` in ``ECFP_NBITS``.

    The caller's frame is not modified. If no rows remain, the empty frame is
    returned without the descriptor columns.
    """
    # Work on a copy: the previous in-place drop_duplicates mutated the caller's
    # frame. `.copy()` also detaches the result so the column assignments below
    # do not trigger pandas' SettingWithCopyWarning.
    df = df.drop_duplicates(subset=["smiles"]).copy()
    if df.empty:
        # Nothing to compute; avoid dispatching empty work to the parallel workers.
        return df

    print("\t", "Compute fragments")
    df["fragment"] = df["smiles"].parallel_apply(compute_fragment, args=(radius, boundary_bonds, map_root, rooted_smiles, nbits,))

    print("\t", "Compute signature")
    alphabet = SignatureAlphabet(radius=radius, nBits=nbits)
    df["sig"] = df["smiles"].parallel_apply(compute_sig, args=(alphabet,))

    for ecfp_nbits in ECFP_NBITS:
        print("\t", "Ecfp nbits:", ecfp_nbits)
        df[f"ecfp_{ecfp_nbits}"] = df["smiles"].parallel_apply(ecfp, args=(radius, ecfp_nbits))

    return df

## Create databases of signatures/fragments

In [None]:
def _chunk_index(filename: str) -> int:
    """Return N from a scratch chunk file named ``<origin>.N.raw.csv.gz``."""
    return int(os.path.basename(filename).split(".")[-4])


def compute_file(conn, origin: str, path: str, tempdir: str, radius: int, split: bool = True, cleanup: bool = True):
    """Compute fragments, signatures and ECFP hashes for one input file and store them in ``conn``.

    Rows are written to a table named ``origin``. Any existing table of that
    name is dropped first, so re-running the notebook does not append duplicates.

    Args:
        conn: Open sqlite3 connection to the per-radius output database.
        origin: Source label ("emolecules" | "metanetx"); used as the table name
            and as the name of the scratch sub-directory.
        path: Input TSV (optionally gzipped) with a ``SMILES`` column.
        tempdir: Parent directory of the scratch sub-directory holding the split input.
        radius: Radius passed to ``compute_data``.
        split: When True, (re)split ``path`` into gzipped CSV chunks of at most
            1e6 records. Later calls may pass False to reuse existing chunks.
        cleanup: When True, delete the scratch chunks at the end of the call.
    """
    print("Compute:", origin)
    tempdir_compute = os.path.join(tempdir, origin)
    max_size = int(1e6)

    if split:
        # Start from an empty scratch directory so chunks left by an earlier run
        # (possibly from a larger input) are not picked up by the glob below.
        shutil.rmtree(tempdir_compute, ignore_errors=True)
        os.makedirs(tempdir_compute, exist_ok=True)
        # Parse and split file for memory purposes. An empty trailing chunk is never
        # written (the old loop wrote one whenever the record count was a multiple of 1e6).
        records = read_tsv(path=path)
        for current_file in itertools.count():
            chunk = list(itertools.islice(records, max_size))
            if not chunk:
                break
            pd.DataFrame(chunk).to_csv(
                os.path.join(tempdir_compute, f"{origin}.{current_file}.raw.csv.gz"),
                index=False,
            )
    else:
        os.makedirs(tempdir_compute, exist_ok=True)

    # Recreate the table: appending to a table left by a previous run would duplicate rows.
    cursor = conn.cursor()
    cursor.execute(f"DROP TABLE IF EXISTS {origin}")
    conn.commit()
    cursor.close()

    # Compute signature. Chunks are processed in numeric order so the row order of
    # the table (which the diversity sampling indexes into) is reproducible.
    pattern = os.path.join(tempdir_compute, f"{origin}.*.raw.csv.gz")
    for ix, filename in enumerate(sorted(glob.glob(pattern), key=_chunk_index)):
        print("\t", "Batch:", ix)
        df = compute_data(
            df=pd.read_csv(filename),
            boundary_bonds=boundary_bonds,
            map_root=map_root,
            rooted_smiles=rooted_smiles,
            nbits=nbits,
            radius=radius,
        )
        if not df.empty:
            df.to_sql(origin, conn, if_exists="append", index=False, chunksize=int(1e4))
        del df

    # Clean up
    if cleanup:
        shutil.rmtree(tempdir_compute, ignore_errors=True)


# Fixes vs. the previous cell: the loop variable was misspelled (`step_radiux`), leaving
# `step_radius` undefined, and the last-radius test used `len(RADIUS - 1)` (TypeError).
for step_radius, radius in enumerate(RADIUS):
    print("Radius:", radius)
    output_database_db = os.path.join(outdir, f"radius-{radius}.db")
    conn = sqlite3.connect(output_database_db)
    try:
        for origin, input_path in (("metanetx", input_metanetx_tsv), ("emolecules", input_emolecules_tsv)):
            compute_file(
                conn=conn,
                origin=origin,
                path=input_path,
                tempdir=workdir,
                radius=radius,
                # Inputs are split once (first radius) and the chunks reused for the
                # other radii; they are removed after the last radius.
                split=step_radius == 0,
                cleanup=step_radius == len(RADIUS) - 1,
            )
    finally:
        conn.close()

    vacuum(source_path=output_database_db, tempdir=workdir)

## Create metrics

In [None]:
# Define colnames expected in tables
COLNAMES = ["smiles", "sig"] + [f"ecfp_{ecfp_nbits}" for ecfp_nbits in ECFP_NBITS]

for radius in RADIUS
    print("Radius:", radius)

    # Define files
    output_database_db = os.path.join(outdir, f"radius-{radius}.db")
    output_degeneracy_csv = os.path.join(outdir, f"degeneracy.radius-{radius}.csv")
    output_diversity_csv = os.path.join(outdir, f"diversity.radius-{radius}.csv")
    output_alphabet_csv = os.path.join(outdir, f"alphabet.radius-{radius}.csv")

    if not os.path.isfile(output_database_db):
        print(f"Output database: {output_database_db} not found -> Skip")
        continue
    if os.path.isfile(output_degeneracy_csv) and os.path.isfile(output_diversity_csv) and os.path.isfile(output_alphabet_csv):
        continue

    # Total records
    print("Total records")
    conn = sqlite3.connect(output_database_db)
    cursor = conn.cursor()
    total_records = {}
    for source in SOURCES:
        total_records[source] = number_records(cursor=cursor, table=source)
    conn.close()

    # Degeneracy
    if not os.path.isfile(output_degeneracy_csv):
        print("Degeneracy")
        conn = sqlite3.connect(output_database_db)
        cursor = conn.cursor()

        def compute_degeneracy(cursor, source: str, colname: str, radius: int):
            print("Query - start:", source, colname)
            start = time.time()
            collision = collections.Counter() # key: occurences, value: number of occurences
            table_name = source
            cursor.execute(f"SELECT COUNT(*) FROM {table_name} GROUP BY {colname}")
            for res in fetch_in_batches(cursor=cursor):
                collision.update([res[0]])
            print("Query - end", time.time() - start)
            return dict(collision)

        datas = []
        for source, colname in itertools.product(SOURCES, COLNAMES):
            print("Compute for:", source, " - ", colname)
            data_deg = compute_degeneracy(cursor=cursor, source=source, colname=colname, radius=radius)
            for occurence, value in data_deg.items():
                data = dict(item=colname, source=source, kind="occurence", occurence=occurence, value=value, radius=radius)
                datas.append(data)
            datas.append(dict(item=colname, source=source, kind="total", value=total_records[source], radius=radius))
        df = pd.DataFrame(datas)
        df.to_csv(output_degeneracy_csv, index=False)

        conn.close()

    # Diversity
    if not os.path.isfile(output_diversity_csv):
        print("Diversity")
        conn = sqlite3.connect(output_database_sql)
        cursor = conn.cursor()

        # Build depth
        METRICS = [
            "ace",
            "chao1",
            "dominance",
            "gini_index",
            "hill",
            "inv_simpson",
            "observed_features",
            "pielou_e",
            "simpson",
            "simpson_d",
            "simpson_e",
            "shannon",
        ]

        # Build depth
        depths = np.logspace(1, 7, num=7, base=10, dtype=int)
        for depth in depths[::-1]:
            depths = np.append(depths, int(depth / 2))
        depths = np.sort(depths)
        depths = depths[:-1] # rm 10e6

        datas = []
        for seed, (batch, source, depth) in enumerate(itertools.product(range(10), SOURCES, depths)):

            start = time.time()
            data = dict(batch=batch, source=source, depth=depth, seed=seed, radius=radius)

            print("\t", "Deal with:", data)
            # Select Indices
            total_record = total_records[source]
            if total_record < depth:
                continue
            if total_record == depth:
                indices = list(range(total_record))
            else:
                rng = np.random.default_rng(seed)
                indices = np.sort(rng.choice(total_record, size=depth, replace=False, shuffle=False))
            cur_index = 0

            # Build taxons
            print("\t", "Build taxons")
            counter = collections.Counter()
            cursor.execute(f"SELECT fragment FROM {source}")
            for ix, res in enumerate(fetch_in_batches(cursor=cursor)):
                if cur_index >= depth:
                    break
                if ix == indices[cur_index]:
                    counter.update(res[0].split(","))
                    cur_index += 1
            assert cur_index == depth
            taxons = list(counter.values())

            # Aggressive clean up
            del counter

            # Compute Alpha div
            print("\t", "Compute alpha div")
            for metric in METRICS:
                value = None
                try:
                    value = skbio.diversity.alpha_diversity(metric, taxons).iloc[0]
                except:
                    pass
                data_metric = copy.deepcopy(data)
                data_metric["metric"] = metric
                data_metric["value"] = value
                datas.append(data_metric)

            print("\t", "End, time:", time.time() - start)

            # Aggressive clean up
            del taxons

        df = pd.DataFrame(datas)
        # Save
        df.to_csv(output_diversity_csv, index=False)

        conn.close()

    # Alphabet
    if not os.path.isfile(output_alphabet_csv):
        print("Alphabet")
        conn = sqlite3.connect(output_database_db)
        cursor = conn.cursor()

        for source in SOURCES:
            alphabets = set()
            cursor.execute(f"SELECT fragment FROM {source}")
            datas = []
            for res in fetch_in_batches(cursor=cursor), total=total_records[source]:
                fragments = res[0].split(",")
                nb_added = 0
                for fragment in fragments:
                    if fragment in alphabets:
                        continue
                    alphabets.add(fragment)
                    nb_added += 1
                is_added = int(nb_added > 0)
                datas.append(dict(is_added=is_added, nb_fragment_added=nb_added, nb_fragment=len(fragments), source=source))
        df = pd.DataFrame(datas)
        df.index.name = "step"

        df.to_csv(output_alphabet_csv)

        conn.close()