In [1]:
import numpy as np
import pandas as pd
import os
pd.set_option("display.max_columns", None)
import math
import re
from tqdm.notebook import tqdm
tqdm.pandas()
import kaleido
from rdkit import Chem
from rdkit.Chem import AllChem
import warnings
warnings.filterwarnings("ignore")
from itertools import combinations
import pickle
import networkx as nx
from IPython.display import clear_output
from plotly.graph_objects import Figure
import plotly.express as px
from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
clear_output()

In [2]:
! ls ../data

clean_data  raw_data


# Read and preprocess data

In [2]:
# --- Activities: keep only the identifier / value columns used downstream. ---
activity = pd.read_parquet('../data/raw_data/activities.prqt')
activity = activity[['activity_chembl_id', 'salt_chembl_id', 'molecule_chembl_id',
                     'target_chembl_id', 'assay_chembl_id', 'document_chembl_id',
                     'standard_type', 'standard_value', 'bao_endpoint']]
# Low-cardinality string columns are stored as categoricals.
activity['standard_type'] = activity['standard_type'].astype('category')
activity['bao_endpoint'] = activity['bao_endpoint'].astype('category')

# --- Documents: the stored 'classification' column is replaced by the one
# from the separate classification file (inner merge on document_chembl_id). ---
document = pd.read_parquet('../data/raw_data/document.prqt').drop('classification', axis=1)
# Author string -> set of names: periods removed, split on ', '.
document['authors'] = document.authors.apply(lambda x: set(x.replace('.', '').split(', ')))
# Postcode string -> set; non-string values (e.g. missing) are left untouched here,
# whereas the same column in `pubdates` below maps them to an empty set.
document['postcodes'] = document.postcodes.apply(lambda x: set(x.split(', ')) if type(x) == str else x)
classification = pd.read_parquet('../data/raw_data/classification_17_10.prqt')
document = document.merge(classification, on='document_chembl_id')

target = pd.read_parquet('../data/raw_data/target.prqt')

# --- Molecules: one row per salt_chembl_id with structures and the
# 'lipinsky' table attached (both inner merges). ---
molecule = pd.read_parquet('../data/raw_data/molecule.prqt')
lipinsky = pd.read_parquet('../data/raw_data/lipinsky.prqt')
structures = pd.read_parquet('../data/raw_data/all_structures.prqt')
molecule = molecule.merge(structures, on='salt_chembl_id').merge(lipinsky, on='salt_chembl_id')


# --- Assays: attach the "Target TYPE" annotation from the spreadsheet.
# The left merge keeps assays without an annotation (their "Target TYPE" is NaN). ---
assay = pd.read_parquet('../data/raw_data/assay.prqt')
types = pd.read_excel('../data/raw_data/types_28_08.xlsx')
# NOTE(review): only the first 27787 rows / 35 columns of the sheet are used;
# presumably the rest is not part of the annotation table -- confirm.
assay = assay.merge(types.iloc[:27787, :35][["ASSAY_ID", "Target TYPE"]], left_on='assay_chembl_id', right_on='ASSAY_ID', how='left')

# --- Publication dates: 'year' and 'month' are re-derived from the 'pubdate'
# sequence (element 0 and element 1); the original 'year' column is discarded. ---
pubdates = pd.read_parquet('../data/raw_data/doc_with_pubdate_17_07.prqt')
pubdates = pubdates.drop('year', axis=1)
pubdates['year'] = pubdates.pubdate.apply(lambda x: x[0])
# When 'pubdate' has fewer than two elements the month defaults to 1.
pubdates['month'] = pubdates.pubdate.apply(lambda x: x[1] if len(x) >= 2 else 1)
pubdates['authors'] = pubdates.authors.apply(lambda x: set(x.replace('.', '').split(', ')))
# Non-string postcodes become an empty set so that set operations work on every row.
pubdates['postcodes'] = pubdates.postcodes.apply(lambda x: set(x.split(', ')) if type(x) == str else set([]))
# --- Pairs of measurements. Each row holds two activities (suffix 1 and 2);
# molecule / target / standard type are taken from side 1 and renamed without suffix. ---
pairs = pd.read_parquet('../data/raw_data/all_pairs.prqt')
# Absolute difference of the two values on a log10 scale.
pairs['delta_log_activity'] = abs(np.log10(pairs.standard_value1) - np.log10(pairs.standard_value2))
pairs = pairs[['activity_chembl_id1', 'activity_chembl_id2',
               'salt_chembl_id1', 'salt_chembl_id2',
               'molecule_chembl_id1', 'target_chembl_id1',
               'assay_chembl_id1', 'assay_chembl_id2',
               'document_chembl_id1', 'document_chembl_id2',
               'standard_type1',
               'standard_value1', 'standard_value2', 'delta_log_activity']].rename(
    columns={'molecule_chembl_id1': 'molecule_chembl_id',
             'target_chembl_id1': 'target_chembl_id',
             'standard_type1': 'standard_type'})
# NOTE: get_table() later reads the pair counts positionally, so
# 'activity_chembl_id1' must stay the first column of this frame.
pairs['log_value1'] = np.log10(pairs.standard_value1)
pairs['log_value2'] = np.log10(pairs.standard_value2)

# Attach publication date and document classification for both sides of each
# pair (all inner merges). The chain is parenthesised instead of using
# backslash continuations: the original ended on a dangling backslash that only
# worked because the next line happened to be blank.
pairs = (
    pairs
    .merge(pubdates[['document_chembl_id', 'year', 'month']].add_suffix('1'),
           on='document_chembl_id1')
    .merge(pubdates[['document_chembl_id', 'year', 'month']].add_suffix('2'),
           on='document_chembl_id2')
    .merge(classification.add_suffix('1'), on='document_chembl_id1')
    .merge(classification.add_suffix('2'), on='document_chembl_id2')
)

# Attach publication year/month to every activity (inner merge: activities whose
# document has no row in `pubdates` are dropped).
activity = activity.merge(pubdates[['document_chembl_id', 'year', 'month']], on='document_chembl_id')

# Manually reviewed assay pairs: keep the first two columns of the rows whose
# third column ('Unnamed: 2') reads "NO ISSUE".
# NOTE(review): later cells merge this frame on 'assay_chembl_id1'/'assay_chembl_id2',
# so those are presumably the names of the first two spreadsheet columns -- confirm.
issue = pd.read_excel('../data/raw_data/Assay-issue.xlsm')
no_issue = issue[issue['Unnamed: 2'] == "NO ISSUE"].iloc[:, :2]

protein_classification = pd.read_parquet('../data/raw_data/protein_classification.prqt')
all_targets = pd.read_parquet('../data/raw_data/all_targets.prqt')

# Running log of dataset sizes; get_systems() appends one row per call.
steps_df = pd.DataFrame(columns=['activity', 'molecule', 'target', 'assay', 'document', 
                                 'ki_values', 'ki_systems', 'ki_pairs', 'ic_values', 'ic_systems', 'ic_pairs'])

# Define functions

In [3]:
def get_systems(df, pairs):
    """Record the current dataset sizes as a new row of the global ``steps_df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Activity table; must provide the id columns counted below plus
        ``standard_type`` and ``standard_value``.
    pairs : pandas.DataFrame
        Pair table with a ``standard_type`` column.

    Returns
    -------
    pandas.DataFrame
        The updated global ``steps_df`` (one more row than before the call).
    """
    global steps_df

    value_columns = ["molecule_chembl_id", "target_chembl_id",
                     "activity_chembl_id", "standard_value"]
    system_columns = ['molecule_chembl_id', 'target_chembl_id']

    ki_values = df.query('standard_type == "Ki"')[value_columns]
    ic_values = df.query('standard_type == "IC50"')[value_columns]

    # A "system" is a distinct (molecule, target) combination.
    stats = {
        'activity': len(df),
        'molecule': len(df.salt_chembl_id.drop_duplicates()),
        'target': len(df.target_chembl_id.drop_duplicates()),
        'assay': len(df.assay_chembl_id.drop_duplicates()),
        'document': len(df.document_chembl_id.drop_duplicates()),
        'ki_values': len(ki_values),
        'ki_systems': len(ki_values[system_columns].drop_duplicates()),
        'ki_pairs': len(pairs.query('standard_type == "Ki"')),
        'ic_values': len(ic_values),
        'ic_systems': len(ic_values[system_columns].drop_duplicates()),
        'ic_pairs': len(pairs.query('standard_type == "IC50"')),
    }

    new_row = pd.DataFrame(stats, index=[0])
    steps_df = pd.concat([steps_df, new_row], ignore_index=True)
    return steps_df


def update_pairs(activity, pairs):
    """Make the activity and pair tables mutually consistent.

    First drops every pair for which either activity id is absent from
    ``activity``; then drops every activity that no longer appears in any
    remaining pair.

    Returns
    -------
    tuple
        ``(activity, pairs)`` after filtering.
    """
    known_ids = activity.activity_chembl_id
    both_known = (
        pairs.activity_chembl_id1.isin(known_ids)
        & pairs.activity_chembl_id2.isin(known_ids)
    )
    pairs = pairs[both_known]

    # Ids that still take part in at least one pair, from either side.
    acts = pd.concat([pairs.activity_chembl_id1,
                      pairs.activity_chembl_id2]).drop_duplicates()
    activity = activity[activity.activity_chembl_id.isin(acts)]

    return activity, pairs


def inchi_key_from_molblock(molblock):
    """Return the InChIKey for a mol block, or ``np.nan`` if it cannot be computed.

    Any ordinary exception raised while parsing the block or generating the
    key is converted to ``np.nan``. ``KeyboardInterrupt`` and ``SystemExit``
    are deliberately *not* swallowed, so a long loop calling this helper can
    still be interrupted from the notebook.

    Note that ``np.nan != np.nan``: two failed conversions never compare equal.
    """
    try:
        mol = Chem.MolFromMolBlock(molblock)
        return Chem.inchi.MolToInchiKey(mol)
    except Exception:
        return np.nan


def inchi_key_from_smiles(smiles):
    """Return the InChIKey for a SMILES string, or ``np.nan`` if it cannot be computed.

    Any ordinary exception raised while parsing the SMILES or generating the
    key is converted to ``np.nan``. ``KeyboardInterrupt`` and ``SystemExit``
    are deliberately *not* swallowed, so a long loop calling this helper can
    still be interrupted from the notebook.

    Note that ``np.nan != np.nan``: two failed conversions never compare equal.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        return Chem.inchi.MolToInchiKey(mol)
    except Exception:
        return np.nan


def is_radical(molblock):
    """Tell whether any atom of the molecule carries radical electrons.

    Returns ``True``/``False`` when the mol block can be processed and
    ``np.nan`` when an ordinary exception occurs. ``np.nan`` is truthy, so a
    caller writing ``if is_radical(...)`` treats unprocessable molecules the
    same way as radicals.

    ``KeyboardInterrupt`` and ``SystemExit`` are not swallowed, so a long loop
    calling this helper can still be interrupted from the notebook.
    """
    try:
        mol = Chem.MolFromMolBlock(molblock)
        return any(atom.GetNumRadicalElectrons() > 0 for atom in mol.GetAtoms())
    except Exception:
        return np.nan


def filter_molecule(molecule):
    """Drop molecules with unreliable structure records.

    A molecule is removed when any of the following holds:

    * ``is_radical`` returns a truthy value for its ``molfile`` (this includes
      the NaN returned when the mol block cannot be processed);
    * the InChIKey derived from ``canonical_smiles`` differs from the one
      derived from ``molfile`` (two failed conversions also count as different,
      since NaN != NaN);
    * ``chirality == -1`` while ``nstereo > 0``.

    Returns the filtered ``molecule`` frame.
    """
    # Both names are referenced from the query string below via '@'.
    radicals = set()
    diff_inchi = set()

    records = zip(molecule.molecule_chembl_id, molecule.molfile, molecule.canonical_smiles)
    for mol_id, molfile, smiles in tqdm(records):
        if is_radical(molfile):
            radicals.add(mol_id)
        if inchi_key_from_smiles(smiles) != inchi_key_from_molblock(molfile):
            diff_inchi.add(mol_id)

    query = '''
    (not ((chirality == -1) and (nstereo > 0)))
    and molecule_chembl_id not in @diff_inchi
    and molecule_chembl_id not in @radicals'''.replace('\n', ' ')

    return molecule.query(query)


def filter_assay(assay):
    """Remove assays whose target type and BAO format are both on the exclusion lists.

    A row is *excluded* when its ``"Target TYPE"`` is one of the listed types
    **and** its ``bao_format`` is ``BAO_0000219`` or ``BAO_0000366``. The
    filter then works on ids: every row whose ``assay_chembl_id`` occurs on at
    least one non-excluded row is kept.
    """
    excluded_target_types = [
        "Enzyme",
        "Transcription factor",
        "Nuclear Receptor",
        "Other Protein",
        "Transcription Factor",
    ]
    excluded_formats = ["BAO_0000219", "BAO_0000366"]

    is_excluded = (
        assay["Target TYPE"].isin(excluded_target_types)
        & assay["bao_format"].isin(excluded_formats)
    )
    kept_ids = assay[~is_excluded].assay_chembl_id

    return assay[assay["assay_chembl_id"].isin(kept_ids)]


def initial_filtration(activity, pairs, molecule, assay, target, document, diff_doc=True):
    """Apply the common filtering steps and return consistent ``(activity, pairs)``.

    Steps, in order:

    1. if ``diff_doc`` is true, keep only pairs whose two measurements come
       from different documents;
    2. filter assays (``filter_assay``) and molecules (``filter_molecule``);
    3. keep activities whose molecule, target, document and assay ids are all
       present in the corresponding (filtered) tables;
    4. reconcile activities and pairs with ``update_pairs``.
    """
    if diff_doc:
        pairs = pairs[pairs.document_chembl_id1 != pairs.document_chembl_id2]

    assay = filter_assay(assay)
    molecule = filter_molecule(molecule)

    has_all_references = (
        activity.molecule_chembl_id.isin(molecule.molecule_chembl_id)
        & activity.target_chembl_id.isin(target.target_chembl_id)
        & activity.document_chembl_id.isin(document.document_chembl_id)
        & activity.assay_chembl_id.isin(assay.assay_chembl_id)
    )
    activity = activity[has_all_references]

    return update_pairs(activity, pairs)

# Basic data filtering (common filtering workflow)

In [4]:
# Run the common filtering workflow. `activity` and `pairs` are rebound to the
# filtered tables, so the unfiltered versions are no longer available afterwards.
activity, pairs = initial_filtration(activity, pairs, molecule, assay, target, document)

0it [00:00, ?it/s]

# t-SNE molecule variety visualization

In [5]:
# Shared Morgan fingerprint generator (radius 2, 2048-bit vectors).
morgan_gen = GetMorganGenerator(radius=2, fpSize=2048)

def fp_from_smiles(smiles):
    """Return the Morgan fingerprint of a SMILES string as a NumPy array.

    Uses the module-level ``morgan_gen``. No validation is performed on the
    molecule parsed from ``smiles``.
    """
    return np.array(morgan_gen.GetFingerprint(Chem.MolFromSmiles(smiles)))

def _tsne_embedding(chem_rep_list, random_state=None):
    """Project feature vectors to 2-D: PCA pre-reduction followed by t-SNE.

    The PCA keeps up to 200 components, capped by the number of samples and
    features so that small inputs do not raise. ``random_state`` is forwarded
    to both estimators; the default ``None`` keeps the run unseeded.
    """
    features = np.asarray(chem_rep_list)
    n_components = min([200, *features.shape])
    reduced = PCA(n_components=n_components, random_state=random_state).fit_transform(features)
    # NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter`;
    # confirm against the installed version before upgrading.
    tsne = TSNE(n_components=2, verbose=0, perplexity=30, n_iter=5000,
                random_state=random_state)
    embedding = tsne.fit_transform(reduced)
    return pd.DataFrame(embedding, columns=["Component 1", "Component 2"])


def tsne_df(chem_rep_list, random_state=None):
    """Public wrapper around ``_tsne_embedding``.

    Returns a DataFrame with columns "Component 1" and "Component 2", one row
    per input vector. Kept as a thin wrapper because the notebook later reuses
    the name ``tsne_df`` for a DataFrame; internal callers use
    ``_tsne_embedding`` so they keep working after that rebinding.
    """
    return _tsne_embedding(chem_rep_list, random_state=random_state)


def graph_chemical_space(df: pd.DataFrame, color_col: str = None, graph_title: str = None,
                         random_state=None):
    """Compute t-SNE coordinates of the molecules in ``df`` and attach metadata.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``canonical_smiles`` and ``salt_chembl_id``.
    color_col : str, optional
        Extra column of ``df`` to carry over (used for colouring). When
        ``None`` only the two id columns are attached.
    graph_title : str, optional
        Accepted for backward compatibility; not used.
    random_state : optional
        Forwarded to PCA and t-SNE.

    Returns
    -------
    pandas.DataFrame
        "Component 1", "Component 2", then the metadata columns, row-aligned
        with ``df`` by position.
    """
    mol_list = [fp_from_smiles(smiles) for smiles in df["canonical_smiles"]]
    rd_df = _tsne_embedding(mol_list, random_state=random_state)

    meta_columns = ['canonical_smiles', 'salt_chembl_id']
    if color_col is not None:
        meta_columns.append(color_col)

    # The embedding has a fresh 0..n-1 index; align the metadata by position
    # rather than by whatever index `df` happens to carry.
    meta = df[meta_columns].reset_index(drop=True)
    return pd.concat([rd_df, meta], axis=1)

def draw_tsne(rd_df, name='', col='num_ro5_violations', save=False):
    """Scatter-plot a 2-D embedding coloured by ``col``.

    Rows containing any missing value are dropped before plotting. Both axes
    are fixed to [-200, 200] and the colour scale to (0, 4). When ``save`` is
    true the figure is written as ``<name>.svg`` and ``<name>.png`` in the
    working directory, with ``name`` lower-cased and spaces replaced by
    underscores.

    Returns the plotly figure.
    """
    complete_rows = rd_df.dropna()
    fig1 = px.scatter(
        complete_rows,
        x='Component 1',
        y='Component 2',
        color=col,
        title=name,
        width=850,
        height=750,
        range_color=(0, 4),
        color_continuous_scale='Turbo',
    )

    fig1.update_traces(marker=dict(size=3))
    fig1.update_layout(plot_bgcolor='white', xaxis_title="", yaxis_title="")

    if col == 'num_ro5_violations':
        colorbar = dict(
            title="Lipinski's rule violations",
            tickvals=[0, 1, 2, 3, 4],
            ticktext=[0, 1, 2, 3, 4],
        )
        fig1.update_layout(coloraxis_colorbar=colorbar)

    # Same bare style for both axes: no grid, no zero line, no tick labels.
    fig1.update_yaxes(showgrid=False, zeroline=False, range=[-200, 200], showticklabels=False)
    fig1.update_xaxes(showgrid=False, zeroline=False, range=[-200, 200], showticklabels=False)

    if save:
        file_stem = name.lower().replace(' ', '_')
        for extension in ('svg', 'png'):
            fig1.write_image(f"{file_stem}.{extension}")

    return fig1


# Embed every molecule that still has at least one activity after the common
# filtering steps. The index is reset so rows line up with the embedding.
# NOTE: from here on the name `tsne_df` refers to this DataFrame (t-SNE
# coordinates plus metadata), no longer to the function of the same name
# defined above; later cells read it as a DataFrame.
# `graph_title` is accepted by graph_chemical_space but not used in its body.
tsne_df = graph_chemical_space(molecule.query('salt_chembl_id in @activity.salt_chembl_id').reset_index(drop=True),
                                   color_col='num_ro5_violations', 
                                   graph_title='t-SNE on molecular fingerprints')

# save=True also writes the figure as .svg and .png into the working directory.
draw_tsne(tsne_df, name='t-SNE on molecular fingerprints after common steps', save=True)

# Donut plot for target visualization

In [6]:
def get_count_dict(target_hierarchy):
    """Count targets per protein class, rolled up through the class hierarchy.

    Reads the module-level ``protein_classification`` table (columns
    ``protein_class_id``, ``parent_id``, ``class_level``). Classes are visited
    from the deepest level upwards; each class total is its own number of rows
    in ``target_hierarchy`` plus whatever its already-visited children added,
    and that total is then added to its parent's entry.

    Parameters
    ----------
    target_hierarchy : pandas.DataFrame
        Must contain a ``protein_class_id`` column; one row per target/class link.

    Returns
    -------
    dict
        ``{class_id: total}``. Parent ids (including the placeholder used for
        root classes) appear as keys as well.
    """
    # One pass over the targets instead of one DataFrame scan per class.
    own_counts = target_hierarchy["protein_class_id"].value_counts().to_dict()

    count_dict = dict()
    ordered = protein_classification.sort_values("class_level", ascending=False)

    for prot_id, parent_id in zip(ordered["protein_class_id"], ordered["parent_id"]):
        subtotal = count_dict.get(prot_id, 0) + int(own_counts.get(prot_id, 0))
        count_dict[prot_id] = subtotal
        count_dict[parent_id] = count_dict.get(parent_id, 0) + subtotal

    return count_dict


def plot_donut(target_hierarchy, name: str, how="normal", show=False):
    """Draw a sunburst ("donut") chart of targets per protein class and save it as PNG.

    Relies on the module-level ``parent_dict`` (class id -> parent id or "")
    and ``name_dict`` (class id -> unique label).

    Parameters
    ----------
    target_hierarchy : pandas.DataFrame
        Passed to ``get_count_dict``.
    name : str
        Figure title; also the output file name (lower-cased, spaces replaced
        by underscores, ``.png`` appended).
    how : str
        ``"log"`` colours the segments by log10 of the count; any other value
        colours by the raw count. Segment sizes always use the raw count.
    show : bool
        When true the figure is also displayed.

    Returns
    -------
    None
    """
    count_dict = get_count_dict(target_hierarchy)
    # The entry for class id 0 is overwritten with zero before plotting.
    count_dict[0] = 0

    class_ids = list(parent_dict)
    counts = [count_dict[class_id] for class_id in class_ids]

    if how == "log":
        # Non-positive counts would give -inf; they are mapped to the bottom of the scale.
        colors = [np.log10(count) if count > 0 else 0.0 for count in counts]
    else:
        colors = list(counts)

    target_dict = {
        "target": [name_dict[class_id] for class_id in class_ids],
        "parent": [
            name_dict[int(parent)] if (parent != "" and not np.isnan(parent)) else ""
            for parent in parent_dict.values()
        ],
        "count": counts,
        "color": colors,
    }

    fig = px.sunburst(
        target_dict,
        names="target",
        parents="parent",
        values="count",
        # Always colour by the "color" column so that it matches range_color;
        # the log variant previously coloured by the raw count.
        color="color",
        title=name,
        width=1000,
        height=1000,
        maxdepth=4,
        color_continuous_scale="Turbo",
        range_color=[0, max(target_dict["color"])],
    )

    fig.update_layout(font=dict(size=25))
    fig.write_image(f"{name.lower().replace(' ', '_')}.png")

    if show:
        fig.show()

In [7]:
# Labels used by the sunburst chart must be unique. The n-th repeat of a
# short name gets n trailing spaces (one space for the first repeat), so
# repeated names stay visually identical but distinct as strings.
pref_names = []
seen_counts = dict()

for short_name in protein_classification["short_name"]:
    repeats = seen_counts.get(short_name, 0)
    pref_names.append(short_name + " " * repeats if repeats else short_name)
    seen_counts[short_name] = repeats + 1

# Kept for backward compatibility: the set of all short names encountered.
dups = set(seen_counts)

protein_classification["new_pref_name"] = pref_names
# Root classes have no parent: NaN is replaced by "". Values that are already ""
# are passed through so the cell can be re-run without np.isnan("") raising.
protein_classification["parent_id"] = protein_classification.parent_id.apply(
    lambda x: "" if (x == "" or np.isnan(x)) else int(x)
)

parent_dict = dict(zip(protein_classification.protein_class_id, protein_classification.parent_id))
name_dict = dict(zip(protein_classification.protein_class_id, protein_classification.new_pref_name))

plot_donut(all_targets.query('target_chembl_id in @activity.target_chembl_id'), name="Donut plot for targets after common steps", show=False, how="normal")

# Citation search based on correlations

In [8]:
def chlog(x):
    """Snap a correlation coefficient to exactly 1 or -1 when it is within 1e-5 of it.

    Values in between (and NaN) are returned unchanged.
    """
    if x > 0.99999:
        return 1
    if x < -0.99999:
        return -1
    return x

def calculate_pearson(counts_table, id_column, value1, value2, label):
    """Pearson correlation between two value columns for every id pair.

    Only rows with ``standard_value1 != standard_value2`` are used, and only
    id pairs that retain more than one such row.

    Parameters
    ----------
    counts_table : pandas.DataFrame
        Pair table with ``<id_column>1`` / ``<id_column>2`` and the value columns.
    id_column : str
        Id prefix to group by (the suffixes ``1`` and ``2`` are appended).
    value1, value2 : str
        Columns to correlate.
    label : str
        Name of the resulting correlation column.

    Returns
    -------
    pandas.DataFrame
        Columns ``<id_column>1``, ``<id_column>2``, ``label``.
    """
    keys = [f'{id_column}1', f'{id_column}2']

    unequal = counts_table.query('standard_value1 != standard_value2')
    repeated = unequal.groupby(keys).filter(lambda grp: len(grp) > 1)

    # corr() yields a 2x2 matrix per id pair; after unstack() the second
    # column is the off-diagonal (value1 vs value2) coefficient.
    corr_matrices = repeated.groupby(keys)[[value1, value2]].corr()
    pearson = corr_matrices.unstack().iloc[:, 1].reset_index()

    pearson.columns = keys + [label]
    return pearson

def process_table(table, counts_table, id_column):
    """Attach four Pearson coefficients (raw/log value combinations) to ``table``.

    Added columns: ``pearson_neq`` (raw vs raw), ``pearson_neq_1_log`` (log of
    side 1 vs raw side 2), ``pearson_neq_2_log`` (raw side 1 vs log of side 2)
    and ``pearson_neq_both_log``. The three log variants are passed through
    ``chlog``; ``pearson_neq`` is left as computed.

    Returns the merged table with duplicate rows removed.
    """
    keys = [f'{id_column}1', f'{id_column}2']
    correlation_specs = [
        ('standard_value1', 'standard_value2', 'pearson_neq'),
        ('log_value1', 'standard_value2', 'pearson_neq_1_log'),
        ('standard_value1', 'log_value2', 'pearson_neq_2_log'),
        ('log_value1', 'log_value2', 'pearson_neq_both_log'),
    ]

    for first_value, second_value, label in correlation_specs:
        coefficients = calculate_pearson(counts_table, id_column, first_value, second_value, label)
        table = table.merge(coefficients, on=keys, how='left')
    table = table.drop_duplicates()

    for label in ('pearson_neq_1_log', 'pearson_neq_2_log', 'pearson_neq_both_log'):
        table[label] = table[label].apply(chlog)

    return table

def get_table(pairs, id_column):
    """Summarise measurement pairs per (id1, id2) combination.

    Only combinations with at least three pairs are considered. For each one
    the table holds the number of pairs (``n_pairs``), of pairs with different
    values (``neq_pairs``) and with equal values (``eq_pairs``), the
    publication year/month of both sides (plus both document ids when grouping
    by assay) and the correlation columns added by ``process_table``.

    Returns
    -------
    tuple
        ``(table, counts)`` where ``counts`` are the underlying pair rows;
        both with duplicate rows removed.
    """
    keys = [f'{id_column}1', f'{id_column}2']

    counts = pairs.groupby(keys).filter(lambda grp: len(grp) >= 3)

    def pair_counts(frame, label):
        # After count()/reset_index() the third column is the count of the
        # first non-key column, which is expected to be 'activity_chembl_id1'.
        per_group = frame.groupby(keys).count().reset_index().iloc[:, :3]
        return per_group.rename(columns={'activity_chembl_id1': label})

    table = (
        pair_counts(counts, 'n_pairs')
        .merge(pair_counts(counts.query('standard_value1 != standard_value2'), 'neq_pairs'),
               on=keys, how='left')
        .merge(pair_counts(counts.query('standard_value1 == standard_value2'), 'eq_pairs'),
               on=keys, how='left')
    )

    date_columns = keys + ['year1', 'year2', 'month1', 'month2']
    if id_column == 'assay_chembl_id':
        date_columns = date_columns + ['document_chembl_id1', 'document_chembl_id2']
    table = table.merge(counts[date_columns], on=keys).drop_duplicates()

    # Combinations without (un)equal pairs come out of the left merges as NaN.
    for count_column in ('neq_pairs', 'eq_pairs'):
        table[count_column] = table[count_column].fillna(0)

    table = process_table(table, counts, id_column)

    return table.drop_duplicates(), counts.drop_duplicates()


# Assays flagged by the correlation-based search.
#   cited_assays: the later assay of a pair looks like a copy of the earlier one.
#   error_assays: subset where the values look copied with a unit/log error.
cited_assays = set([])
error_assays = set([])

# Step 1 -- single pairs whose values differ by exactly 1, 2 or 3 log units, by
# roughly 3 log units, or where one value equals the log10 of the other.
# The "earlier" side is decided by year, then month, then numeric document id
# (str.strip('CHEMBL') removes those letters from both ends, leaving the digits).
# NOTE(review): the document-id comparison is OR-ed in unconditionally, so it
# also applies when year1 > year2 -- confirm whether it was meant only as a
# tie-break for equal year and month. The same predicate is repeated below.
for index, row in pairs\
.query('(delta_log_activity in (1, 2, 3)) or (2.98 <= delta_log_activity <= 3.02) or (standard_value1 == log_value2) or (standard_value2 == log_value1)').iterrows():
    if (row.year1 < row.year2) or (row.year1 == row.year2 and row.month1 < row.month2) or (int(row.document_chembl_id1.strip('CHEMBL')) < int(row.document_chembl_id2.strip('CHEMBL'))):
        cited_assays.add(row.assay_chembl_id2)
        error_assays.add(row.assay_chembl_id2)
    else:
        cited_assays.add(row.assay_chembl_id1)
        error_assays.add(row.assay_chembl_id1)

assay_table, assay_counts = get_table(pairs, 'assay_chembl_id')

# Step 2 -- assay pairs with non-positive correlation of the unequal values.
# Flagged when at least half of the unequal rows share values across the two
# sides, or when the correlation is almost exactly -1.
# NOTE(review): this uses true division (`/ 2`) while the document-level loop
# in the next cell uses floor division (`// 2`) -- confirm which is intended.
for index, row in assay_table.query('pearson_neq <= 0').iterrows():
    a = assay_counts.query('assay_chembl_id1 == @row.assay_chembl_id1 and assay_chembl_id2 == @row.assay_chembl_id2 and standard_value1 != standard_value2')
    if (len(set(a.standard_value1) & set(a.standard_value2)) >= len(a.standard_value1) / 2
       ) or (row.neq_pairs == 3 and row.pearson_neq < -0.9999) or (row.neq_pairs >= 4 and row.pearson_neq < -0.999):
        if (row.year1 < row.year2) or (row.year1 == row.year2 and row.month1 < row.month2) or (int(row.document_chembl_id1.strip('CHEMBL')) < int(row.document_chembl_id2.strip('CHEMBL'))):
            cited_assays.add(row.assay_chembl_id2)
            error_assays.add(row.assay_chembl_id2)
        else:
            cited_assays.add(row.assay_chembl_id1)
            error_assays.add(row.assay_chembl_id1) 
        
# Step 3 -- assay pairs that look like plain citations (marked as cited only):
# no unequal pairs at all, one or two unequal pairs next to at least two equal
# ones, a near-perfect positive correlation, or any log-variant correlation
# that chlog() snapped to exactly +/-1.
for index, row in assay_table.iterrows():
    if row.neq_pairs == 0 or ( (1 <= row.neq_pairs <= 2) and (row.eq_pairs >= 2) 
                             ) or ( row.neq_pairs == 3 and row.pearson_neq > 0.9999
                                  ) or ( row.neq_pairs >= 4 and row.pearson_neq > 0.999
                                       ) or ( row.neq_pairs >= 3 and row.pearson_neq_1_log in (-1, 1) 
                                            ) or ( row.neq_pairs >= 3 and row.pearson_neq_2_log in (-1, 1)  
                                                 ) or ( row.neq_pairs >= 3 and row.pearson_neq_both_log in (-1, 1) ):
        
        if (row.year1 < row.year2) or (row.year1 == row.year2 and row.month1 < row.month2) or (int(row.document_chembl_id1.strip('CHEMBL')) < int(row.document_chembl_id2.strip('CHEMBL'))):
            cited_assays.add(row.assay_chembl_id2)
        else:
            cited_assays.add(row.assay_chembl_id1)
    
# Expose the flags on the summary table for inspection.
assay_table['cited_assay_1'] = assay_table.assay_chembl_id1.isin(cited_assays)
assay_table['cited_assay_2'] = assay_table.assay_chembl_id2.isin(cited_assays)
assay_table['error_assay_1'] = assay_table.assay_chembl_id1.isin(error_assays)
assay_table['error_assay_2'] = assay_table.assay_chembl_id2.isin(error_assays)

In [9]:
# Same correlation-based search as for assays, at the document level.
# (There is no document-level counterpart of the single-pair scan.)
document_table, document_counts = get_table(pairs, id_column='document_chembl_id')

cited_docs = set([])
error_docs = set([])

# Document pairs with non-positive correlation of the unequal values: flagged
# as cited *and* erroneous when the two sides share at least half of the
# values (floor division here) or the correlation is almost exactly -1.
# NOTE(review): the document-id comparison in the "earlier" predicate is OR-ed
# in unconditionally, so it also applies when year1 > year2 -- confirm intent.
for index, row in tqdm(document_table.query('pearson_neq <= 0').iterrows()):
    a = document_counts.query('document_chembl_id1 == @row.document_chembl_id1 and document_chembl_id2 == @row.document_chembl_id2 and standard_value1 != standard_value2')
    if (len(set(a.standard_value1) & set(a.standard_value2)) >= len(a.standard_value1) // 2
       ) or (row.neq_pairs == 3 and row.pearson_neq < -0.9999) or (row.neq_pairs >= 4 and row.pearson_neq < -0.999):
        if (row.year1 < row.year2) or (row.year1 == row.year2 and row.month1 < row.month2) or (int(row.document_chembl_id1.strip('CHEMBL')) < int(row.document_chembl_id2.strip('CHEMBL'))):
            cited_docs.add(row.document_chembl_id2)
            error_docs.add(row.document_chembl_id2)
        else:
            cited_docs.add(row.document_chembl_id1)
            error_docs.add(row.document_chembl_id1) 

# Document pairs that look like plain citations (marked as cited only); the
# criteria mirror the assay-level loop.
for index, row in document_table.iterrows():
    if row.neq_pairs == 0 or ( (1 <= row.neq_pairs <= 2) and (row.eq_pairs >= 2) 
                             ) or ( row.neq_pairs == 3 and row.pearson_neq > 0.9999
                                  ) or ( row.neq_pairs >= 4 and row.pearson_neq > 0.999
                                       ) or ( row.neq_pairs >= 3 and row.pearson_neq_1_log in (-1, 1) 
                                            ) or ( row.neq_pairs >= 3 and row.pearson_neq_2_log in (-1, 1)  
                                                 ) or ( row.neq_pairs >= 3 and row.pearson_neq_both_log in (-1, 1) ):
        
        if (row.year1 < row.year2) or (row.year1 == row.year2 and row.month1 < row.month2) or (int(row.document_chembl_id1.strip('CHEMBL')) < int(row.document_chembl_id2.strip('CHEMBL'))):
            cited_docs.add(row.document_chembl_id2)
        else:
            cited_docs.add(row.document_chembl_id1)

# Expose the document-level flags on the summary table. The error columns are
# named 'error_document_*' (they were previously labelled 'error_assay_*'
# although they are derived from document ids and error_docs).
document_table['cited_document_1'] = document_table.document_chembl_id1.isin(cited_docs)
document_table['cited_document_2'] = document_table.document_chembl_id2.isin(cited_docs)
document_table['error_document_1'] = document_table.document_chembl_id1.isin(error_docs)
document_table['error_document_2'] = document_table.document_chembl_id2.isin(error_docs)

0it [00:00, ?it/s]

# Citation search based on connected components of the measurement graph

In [10]:
def is_citation(value1: float, value2: float) -> bool:
    """Tell whether two values are close enough to be treated as a citation.

    True when the absolute difference is at most 0.02, or lies in
    [2.98, 3.02]. The caller decides which scale the values are on.
    """
    gap = abs(value1 - value2)
    nearly_equal = 0 <= gap <= 0.02
    off_by_three = 2.98 <= gap <= 3.02
    return nearly_equal or off_by_three


def is_later(meas1: tuple[str, int, int], meas2: tuple[str, int, int]) -> bool:
    """Tell whether ``meas2`` was published after ``meas1``.

    Each measurement is an ``(id, year, month)`` tuple. The comparison uses
    the year first, then the month, and finally the id as a tie-break for
    identical dates.

    Previously the id tie-break was also reached when the months matched but
    ``year2`` was *earlier* than ``year1``, which could report an older
    measurement as later; the year is now always decisive.
    """
    id1, year1, month1 = meas1
    id2, year2, month2 = meas2

    if year2 != year1:
        return year2 > year1
    if month2 != month1:
        return month2 > month1
    return id2 > id1


def build_graph(measurements: pd.DataFrame) -> nx.DiGraph:
    """Build the citation graph for the measurements of one system.

    Columns are read by position: 0 = measurement id, 4 = value, 5 = year,
    6 = month. Nodes are the row positions ``0..N-1``. An edge ``j -> i`` is
    added when row ``j`` is later than row ``i`` (``is_later``) and their
    values satisfy ``is_citation``.

    The needed columns are extracted once up front, rather than doing a
    positional DataFrame lookup for every cell inside the pairwise loop.
    """
    ids = measurements.iloc[:, 0].tolist()
    values = measurements.iloc[:, 4].tolist()
    years = measurements.iloc[:, 5].tolist()
    months = measurements.iloc[:, 6].tolist()
    stamps = list(zip(ids, years, months))

    N = len(measurements)
    G = nx.DiGraph()
    G.add_nodes_from(range(N))

    for i in range(N):
        for j in range(N):
            if (
                i != j
                and is_later(stamps[i], stamps[j])
                and is_citation(values[i], values[j])
            ):
                G.add_edge(j, i)
    return G


def find_originals_and_citations(
    measurements: pd.DataFrame,
) -> tuple[list, dict]:
    """Split the measurements of one system into originals and citations.

    Measurements are grouped into weakly connected components of the graph
    from ``build_graph``. In each component the earliest row (sorted by year,
    month, then ``activity_chembl_id``) is the original; every other row is
    mapped to it.

    Returns
    -------
    tuple
        ``(originals, citations_map)``: the list of original ids (first
        column of ``measurements``) and a dict ``{citing id: original id}``.
    """
    measurements = measurements.reset_index(drop=True)
    graph = build_graph(measurements)

    originals = []
    citations_map = {}

    for component in list(nx.weakly_connected_components(graph)):
        positions = list(component)

        if len(positions) == 1:
            # An isolated measurement is its own original.
            originals.append(measurements.iloc[positions[0], 0])
            continue

        ordered_ids = (
            measurements.iloc[positions, :]
            .sort_values(by=["year", "month", "activity_chembl_id"])
            .iloc[:, 0]
        )
        original = ordered_ids.iloc[0]
        originals.append(original)
        for citing_id in ordered_ids.iloc[1:]:
            citations_map[citing_id] = original

    return originals, citations_map

# Run the graph-based search separately for every (molecule, target, type)
# system. The column order matters: build_graph() reads id, value, year and
# month by position (0, 4, 5, 6).
all_originals = []
all_citations_map = {}

measurement_columns = ['activity_chembl_id', 'molecule_chembl_id', 'target_chembl_id',
                       'standard_type', 'standard_value', 'year', 'month']
system_keys = ['molecule_chembl_id', 'target_chembl_id', 'standard_type']

for _, system_measurements in tqdm(activity[measurement_columns].groupby(system_keys)):
    system_originals, system_citations = find_originals_and_citations(system_measurements)
    all_originals.extend(system_originals)
    all_citations_map.update(system_citations)

  0%|          | 0/10581 [00:00<?, ?it/s]

In [11]:
# Flag every activity that is not an original, and record which original it cites.
all_originals = set(all_originals)
activity_ids = activity.activity_chembl_id
activity['approx_cited_activity'] = activity_ids.apply(
    lambda activity_id: activity_id not in all_originals
)
# NaN for activities that do not cite anything.
activity['original_activity'] = activity_ids.apply(
    lambda activity_id: all_citations_map.get(activity_id, np.nan)
)

# Final filtering

In [12]:
# Number of activities citing each original; attached as 'num_citations'
# (NaN for activities that nobody cites).
citation_counts = (
    activity.dropna(subset='original_activity')
    .groupby('original_activity')
    .agg({'activity_chembl_id': 'count'})
    .reset_index(drop=False)
    .rename(columns={'original_activity': 'activity_chembl_id',
                     'activity_chembl_id': 'num_citations'})
)
activity = activity.merge(citation_counts, on='activity_chembl_id', how='left')

# Attach the author and postcode sets of both documents to every pair.
affiliation_columns = ['document_chembl_id', 'authors', 'postcodes']
pairs = (
    pairs
    .merge(pubdates[affiliation_columns].add_suffix('1'), on='document_chembl_id1')
    .merge(pubdates[affiliation_columns].add_suffix('2'), on='document_chembl_id2')
)

# Element-wise set intersections; the expression below relies on pandas
# turning them into booleans (non-empty intersection -> True). A pair is
# independent when the two documents share neither a postcode nor an author.
shared_postcodes = pairs.postcodes1 & pairs.postcodes2
shared_authors = pairs.authors1 & pairs.authors2
pairs['INDEPENDENT'] = ~(shared_postcodes | shared_authors)

# Assemble the final pair table: selected pair columns plus the per-activity
# citation flag of both sides (inner merges).
final_pair_columns = [
    'activity_chembl_id1', 'activity_chembl_id2', 'salt_chembl_id1',
    'salt_chembl_id2', 'molecule_chembl_id', 'target_chembl_id',
    'assay_chembl_id1', 'assay_chembl_id2', 'document_chembl_id1',
    'document_chembl_id2', 'standard_type', 'standard_value1',
    'standard_value2', 'delta_log_activity', 'year1', 'month1',
    'authors1', 'classification1', 'postcodes1', 'year2', 'month2', 'authors2',
    'classification2', 'postcodes2', 'INDEPENDENT',
]
activity_flags = activity[['activity_chembl_id', 'approx_cited_activity']]

final_pairs = (
    pairs[final_pair_columns]
    .merge(activity_flags.add_suffix('1'), on='activity_chembl_id1')
    .merge(activity_flags.add_suffix('2'), on='activity_chembl_id2')
)

# Assay- and document-level flags for both sides, added in a fixed column order:
# cited_assay_1/2, cited_document_1/2, error_assay_1/2, error_document_1/2.
flag_sources = [
    ('cited_assay', 'assay_chembl_id', cited_assays),
    ('cited_document', 'document_chembl_id', cited_docs),
    ('error_assay', 'assay_chembl_id', error_assays),
    ('error_document', 'document_chembl_id', error_docs),
]
for flag_name, id_prefix, flagged_ids in flag_sources:
    for side in ('1', '2'):
        final_pairs[f'{flag_name}_{side}'] = final_pairs[f'{id_prefix}{side}'].isin(flagged_ids)

# Manually reviewed assay pairs marked "NO ISSUE".
no_issue = no_issue.drop_duplicates().reset_index(drop=True)
no_issue['NO_ISSUE'] = [True] * len(no_issue)

assay_pair_keys = ['assay_chembl_id1', 'assay_chembl_id2']
assay_pairs = (
    final_pairs[assay_pair_keys + ['INDEPENDENT']]
    .drop_duplicates()
    .merge(no_issue, on=assay_pair_keys, how='left')
    .reset_index(drop=True)
)

# Pairs missing from the review sheet come out of the left merge as NaN -> False.
assay_pairs['NO_ISSUE'] = assay_pairs.NO_ISSUE.apply(lambda flag: flag if flag == True else False)
# Reviewed as fine although the two documents share authors or postcodes.
assay_pairs['not_ind_no_iss'] = (~assay_pairs.INDEPENDENT) & assay_pairs.NO_ISSUE

final_pairs = final_pairs.merge(
    assay_pairs[assay_pair_keys + ['NO_ISSUE', 'not_ind_no_iss']],
    on=assay_pair_keys,
    how='left',
)

# Map each unordered assay pair (sorted tuple) to its 'not_ind_no_iss' flag.
# If the same unordered pair occurs more than once, the first row wins.
assay_pair_dict = dict()
for index, row in tqdm(assay_pairs.iterrows()):
    key = sorted((row.assay_chembl_id1, row.assay_chembl_id2))
    key = tuple(key)
    if key not in assay_pair_dict:
        assay_pair_dict[key] = row.not_ind_no_iss

# Undirected adjacency: assay id -> set of assay ids it is paired with.
assay_graph = dict()

for key in tqdm(assay_pair_dict):
    if key[0] in assay_graph:
        assay_graph[key[0]].add(key[1])
    else:
        assay_graph[key[0]] = set([key[1]])
        
    if key[1] in assay_graph:
        assay_graph[key[1]].add(key[0])
    else:
        assay_graph[key[1]] = set([key[0]])
        
        
# Work list: every pair that is flagged at the start.
assay_stack = [key for key, value in assay_pair_dict.items() if value == True]
n_it = 0

# For each initially flagged pair (A, B) and each assay C paired with both:
# if neither (A, C) nor (B, C) is flagged yet, flag both. Because the dict is
# updated while the list is processed, the outcome depends on processing order.
while assay_stack:
    # Progress report every 10000 iterations: iteration count, then stack size.
    if n_it % 10000 == 0:
        print(n_it)
        print(len(assay_stack))
        print()
    curr_pair = assay_stack.pop()
    intersection = assay_graph[curr_pair[0]] & assay_graph[curr_pair[1]]
    for element in intersection:
        first_pair = tuple(sorted((curr_pair[0], element)))
        second_pair = tuple(sorted((curr_pair[1], element)))
        if (not assay_pair_dict[first_pair]) and (not assay_pair_dict[second_pair]):
            assay_pair_dict[first_pair] = True
            assay_pair_dict[second_pair] = True
            # Newly flagged pairs are NOT pushed back (the appends below are
            # disabled), so propagation goes exactly one step from the initial set.
            #assay_stack.append(first_pair)
            #assay_stack.append(second_pair)
    n_it += 1
    
# Look up the propagated flag of every row's (unordered) assay pair.
new_independent = [
    assay_pair_dict[tuple(sorted(assay_pair))]
    for assay_pair in zip(final_pairs.assay_chembl_id1, final_pairs.assay_chembl_id2)
]
final_pairs['not_ind_no_iss_graph'] = new_independent

def _join_members(members):
    """Serialise a set of strings as a deterministic ', '-separated string.

    Members are sorted so that equal sets always give the same text (a plain
    join over a set depends on iteration order). Values that are already
    strings are returned unchanged, which makes re-running the cell harmless.
    """
    if isinstance(members, str):
        return members
    return ', '.join(sorted(members))


# Sets are not hashable, so they are turned into strings before drop_duplicates().
for set_column in ('authors1', 'authors2', 'postcodes1', 'postcodes2'):
    final_pairs[set_column] = final_pairs[set_column].apply(_join_members)
final_pairs = final_pairs.drop_duplicates().reset_index(drop=True)


# Carry the assay/document flags over to the activity table.
for flag_name, id_column, flagged_ids in [
    ('cited_assay', 'assay_chembl_id', cited_assays),
    ('cited_document', 'document_chembl_id', cited_docs),
    ('error_assay', 'assay_chembl_id', error_assays),
    ('error_document', 'document_chembl_id', error_docs),
]:
    activity[flag_name] = activity[id_column].isin(flagged_ids)

# Number of activities reported by each assay, attached to every activity row.
acts_per_assay = (
    activity.groupby('assay_chembl_id')
    .activity_chembl_id.count()
    .reset_index()
    .rename(columns={'activity_chembl_id': 'acts_per_assay'})
)
activity = activity.merge(acts_per_assay, on='assay_chembl_id')

0it [00:00, ?it/s]

  0%|          | 0/272926 [00:00<?, ?it/s]

0
2221



In [13]:
# Final selection: keep a pair only if the two documents are independent,
# neither activity / assay / document is flagged as cited or erroneous, the
# assay pair was manually reviewed as "NO ISSUE", and the graph-propagated
# 'not_ind_no_iss_graph' flag is not set.
# NOTE(review): INDEPENDENT == True already implies not_ind_no_iss == False, so
# the last condition only removes pairs flagged by the propagation step.
query = '''
INDEPENDENT == True and 
approx_cited_activity1 == False and
approx_cited_activity2 == False and 
cited_assay_1 == False and
cited_assay_2 == False and
cited_document_1 == False and
cited_document_2 == False and
error_assay_1 == False and
error_assay_2 == False and
error_document_1 == False and
error_document_2 == False and
NO_ISSUE == True and
not_ind_no_iss_graph == False
'''.replace('\n', ' ')

# update_pairs() also drops activities that no longer take part in any
# surviving pair. `final_pairs` is rebound to the filtered table.
final_activity, final_pairs = update_pairs(activity, final_pairs.query(query))

In [14]:
# Re-draw both overview plots for the final data set. `tsne_df` is the
# DataFrame of coordinates computed earlier: the embedding is reused and only
# subset to the surviving molecules, so points keep their positions.
# The t-SNE figure is saved to disk (save=True) but not displayed, because it
# is not the last expression of the cell.
draw_tsne(tsne_df.query('salt_chembl_id in @final_activity.salt_chembl_id'), name='t-SNE on molecular fingerprints after review and citations filtering', save=True)
plot_donut(all_targets.query('target_chembl_id in @final_activity.target_chembl_id'), name="Donut plot for targets after review and citations filtering", show=False, how="normal")