##


#### Prerequisites for creating the dataframes and files needed for our analysis

Make sure you have the following files in `data/raw`:
  1. [full_database.xml](https://drive.google.com/file/d/1149kYVkazq67e0vuv-_4APyqVX6yyh2p), the XML export of the DrugBank database
  2. [BindingDB_All.tsv](https://www.bindingdb.org/bind/downloads/BindingDB_All_202411_tsv.zip), the TSV export of BindingDB (unzip the downloaded archive)

### Imports

In [None]:
import os
import gc
import re
import requests
import time
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tqdm import tqdm
import pandas as pd
from concurrent.futures import ThreadPoolExecutor

from drugbank_XML_drugparser import DrugParser
from drugbank_bindingdb_merger import DrugBank_BindingDB_Merger
from preprocessing import Preprocessing, CleanNumericAtrributesStrategy, ColumnClean

from imports import *

%matplotlib inline

%load_ext autoreload
%autoreload 2

### Loading the data

In [None]:

# Fast path: reuse the cached merged dataset when it is already on disk.
# NOTE(review): in this branch neither `drugbank` nor `clean_binding_df` is
# created, so cells that reference them only work after the slow path has
# run -- confirm this is intended.
if os.path.exists(MERGED):
    print("Merged dataset exists.\n Loading...")
    merged_df = pd.read_pickle(MERGED)
    print("Merged dataset loaded")
else:
    print("Merged dataset doesn't exists.\n Creating it...")

    # Step 1: DrugBank -- parse the XML only when no parsed CSV is cached.
    if not os.path.exists(DRUGBANK_CSV):
        print("parsed_Drugbank doesn't exists...")
        print("Creating parsed_Drugbank.csv")

        drugparser = DrugParser(DRUGBANK_XML)
        drugparser.parse_drugs()
        drugbank = drugparser.save_parsed_drugs(DRUGBANK_CSV, return_df=True)

        print("parsed_Drugbank.csv is created")
        print("DrugBank XML is parsed. \n Loading it ...")
    else:
        print("parsed_Drugbank exists...")
        print("Loading...")

        drugbank = pd.read_csv(DRUGBANK_CSV, encoding='utf-8')

    # Step 2: BindingDB -- reuse the cleaned pickle when it is present;
    # otherwise the `else` branch that follows builds it from the raw TSV.
    if os.path.exists(BINDINGDB_CLEAN):
        print("BindingDB clean exists...")
        clean_binding_df = pd.read_pickle(BINDINGDB_CLEAN)
    else:
        def load_BindingDB(file_path, cols):
            """Read the tab-separated BindingDB dump, keeping only `cols`.

            Args:
                file_path: Path (or file-like object) handed to `pd.read_csv`.
                cols: Column names to load; every other column is skipped.

            Returns:
                A DataFrame with the requested columns; the first line of
                the file is used as the header.
            """
            read_options = {'sep': '\t', 'header': 0, 'usecols': cols}
            return pd.read_csv(file_path, **read_options)

        def keep_just_numeric(value):
            """Extract the numeric part of a raw cell and return it as a float.

            Temperatures are given as e.g. "12C" or "14.2C", so the unit is
            stripped. Measured values are sometimes given as bounds such as
            ">10000" or "<.01"; in these cases the bound itself is kept, which
            is still better than losing the information by replacing it with NA.

            Args:
                value: Raw cell value (string, number, or anything else).

            Returns:
                The value as a float, or `pd.NA` when no number can be
                extracted (non-numeric text, NaN, unsupported types).
                Signs are dropped from strings, exactly as before.
            """
            # bool is a subclass of int but is never a measurement.
            if isinstance(value, bool):
                return pd.NA

            # Cells that pandas already parsed as numbers must be kept, not
            # discarded; only NaN maps to NA.
            if isinstance(value, (int, float, np.integer, np.floating)):
                return pd.NA if pd.isna(value) else float(value)

            if not isinstance(value, str):
                return pd.NA

            # Scientific notation has to be read as a whole: stripping the
            # non-digit characters would turn "1.2E+5" into 1.25.
            scientific = re.search(r'(?:\d+\.?\d*|\.\d+)[eE][-+]?\d+', value)
            if scientific:
                return float(scientific.group())

            # All non-numeric characters (except the decimal .) are removed.
            cleaned_val = re.sub(r'[^\d.]+', '', value)
            try:
                return float(cleaned_val)
            except ValueError:
                # Nothing numeric left ("", ".") or malformed ("1.2.3").
                return pd.NA

        def parse_int(value):
            """Convert `value` to an int, returning `pd.NA` when that fails.

            Only conversion failures are handled (wrong type, unparsable
            text, NaN, infinity); anything else, such as KeyboardInterrupt,
            propagates instead of being silently swallowed.
            """
            try:
                return int(value)
            except (TypeError, ValueError, OverflowError):
                return pd.NA

        # Declarative list of the BindingDB columns we keep. Each entry is
        # presumably (source column, new column name, optional cleaner) --
        # confirm against `preprocessing.ColumnClean`.
        bdb_preprocessor = Preprocessing(
            [
                # Measurements and assay conditions: reduced to plain floats.
                ColumnClean('Ki (nM)', 'ki', clean=keep_just_numeric),
                ColumnClean('pH', 'ph', clean=keep_just_numeric),
                ColumnClean('Temp (C)', 'temp', clean=keep_just_numeric),
                ColumnClean('IC50 (nM)', 'ic50', clean=keep_just_numeric),
                ColumnClean('Kd (nM)', 'kd', clean=keep_just_numeric),
                ColumnClean('kon (M-1-s-1)', 'kon', clean=keep_just_numeric),

                # These columns are supposed to contain strings.
                ColumnClean('Article DOI', 'doi'),

                # We use these IDs to join bindingDB and drugbank.
                # ('PubChem CID' was previously listed twice; one entry is enough.)
                ColumnClean('PubChem CID', 'pubchem_cid'),
                ColumnClean('ChEBI ID of Ligand', 'chebi_id'),
                ColumnClean('ChEMBL ID of Ligand', 'chembl_id'),
                ColumnClean('DrugBank ID of Ligand', 'drugbank_id'),
                ColumnClean('KEGG ID of Ligand', 'kegg_id'),
                ColumnClean('ZINC ID of Ligand', 'zinc_id'),
                ColumnClean('Ligand SMILES', 'smiles'),
                ColumnClean('Ligand InChI Key', 'inchi_key'),
                ColumnClean('BindingDB MonomerID', 'bindingdb_id', clean=parse_int),
                ColumnClean('UniProt (SwissProt) Primary ID of Target Chain.1', 'swissprot_protein_id'),
            ]
        )

        # Load only the columns the preprocessor needs.
        print("Lodaing Binding DB...")
        raw_binding_df = load_BindingDB(BINDINGDB_RAW, bdb_preprocessor.get_used_old_columns())

        print("Cleaning Binding DB...")
        clean_binding_df = bdb_preprocessor.transform(raw_binding_df)

        # Sanity-check BEFORE caching, so a bad result is never written to
        # disk and silently reused by later runs.
        assert len(clean_binding_df) == len(raw_binding_df), \
            "Cleaning changed the number of BindingDB rows"
        clean_binding_df.to_pickle(BINDINGDB_CLEAN)

        # Drop the raw frame explicitly and trigger a collection.
        del raw_binding_df
        gc.collect()


In [None]:
# Number of rows whose `bindingdb_id` is missing after cleaning.
bindingdb_ids = clean_binding_df['bindingdb_id']
bindingdb_ids.isna().sum()

In [None]:

# Eyeball 20 randomly chosen rows of the cleaned BindingDB table.
clean_binding_df.sample(n=20)

In [None]:
# Show the dtype of every column of the cleaned BindingDB table.
clean_binding_df.dtypes

In [None]:
# Print, for every cleaned BindingDB column, the fraction of missing values.
for column_name in clean_binding_df.columns:
    na_ratio = clean_binding_df[column_name].isnull().mean()
    print(f"NA ratio in {column_name}: {na_ratio:.2f}")

In [None]:

# Merge DrugBank with BindingDB only when no cached merged dataset exists.
# When the cache is present, the loading cell skipped building `drugbank`
# and `clean_binding_df`, so merging unconditionally raised a NameError
# (or redid an expensive merge that was already cached).
if os.path.exists(MERGED):
    if 'merged_df' not in globals():
        # The cell was run on its own: fall back to the cached copy.
        merged_df = pd.read_pickle(MERGED)
    print("Merged dataset is already cached; skipping the merge.")
else:
    print("Creating merged dataset")
    drugbank_binding_merger = DrugBank_BindingDB_Merger()
    merged_df = drugbank_binding_merger.merge(drugbank, clean_binding_df)

    # Cache the result so later runs can take the fast path.
    merged_df.to_pickle(MERGED)

    print("Merged dataset is loaded and saved.")


In [None]:
# Row counts of the merged, DrugBank and cleaned BindingDB tables, in that order.
tuple(len(frame) for frame in (merged_df, drugbank, clean_binding_df))

In [None]:
# How many merged rows were matched on each identifier.
matched_on = merged_df['Matched_On']
matched_on.value_counts()

## DOIs
In order to discover temporal trends, we need to fetch metadata about the articles from which the BindingDB measurements were taken (in particular, the date on which each measurement was published).

In [None]:
# Collect every distinct, non-missing article DOI referenced by BindingDB.
doi_column = clean_binding_df['doi']
unique_dois = doi_column.dropna().unique()
print('Number of unique dois:', len(unique_dois))

In [None]:

def fetch_article_metadata(doi, timeout=10, max_retries=3):
    """Fetch title, abstract and publication date of an article from Crossref.

    Args:
        doi: DOI of the article.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang a worker forever.
        max_retries: How many times a transient failure (HTTP 429 or 5xx)
            is retried, with exponential back-off, before raising.

    Returns:
        A dict with the keys "title", "abstract", "year", "month" and "day"
        (missing pieces are `pd.NA`), or None when Crossref answers with a
        non-transient error status such as 404.

    Raises:
        requests.HTTPError: When the request still fails with 429/5xx after
            all retries. Raising instead of returning None keeps callers
            from recording a temporarily unavailable DOI as "no data".
        requests.RequestException: On connection problems or timeouts.
    """
    from urllib.parse import quote

    # DOIs may contain characters such as '#', '?', '<' or '>' that would
    # otherwise corrupt the URL path.
    url = f"https://api.crossref.org/works/{quote(str(doi), safe='/')}"

    transient_statuses = {429, 500, 502, 503, 504}
    attempts = max(0, max_retries) + 1
    for attempt in range(attempts):
        response = requests.get(url, timeout=timeout)
        if response.status_code not in transient_statuses or attempt == attempts - 1:
            break
        # Back off (1s, 2s, 4s, ...) before retrying a transient failure.
        time.sleep(2 ** attempt)

    if response.status_code in transient_statuses:
        response.raise_for_status()

    if response.status_code != 200:
        print(f"Failed to fetch data for DOI {doi}. Status code: {response.status_code}")
        return None

    data = response.json()["message"]

    title_list = data.get("title") or []
    title = title_list[0] if title_list else pd.NA
    abstract = data.get("abstract", pd.NA)

    # Prefer the print date (the field used originally) and fall back to the
    # other date fields only when it is missing.
    published_date = []
    for date_field in ("published-print", "published-online", "published", "issued"):
        date_parts = (data.get(date_field) or {}).get("date-parts") or [[]]
        candidate = date_parts[0] or []
        # Skip placeholders with no usable year (for example [[None]]).
        if candidate and candidate[0] is not None:
            published_date = [pd.NA if part is None else part for part in candidate[:3]]
            break

    # Pad to exactly (year, month, day); absent components stay NA.
    year, month, day = (published_date + [pd.NA] * 3)[:3]

    return {
        "title": title,
        "abstract": abstract,
        "year": year,
        "month": month,
        "day": day
    }

# Example usage: look up one randomly chosen DOI and show what came back.
doi = np.random.choice(unique_dois)
metadata = fetch_article_metadata(doi)

if metadata:
    for label, key in (("Title:", "title"), ("Abstract:", "abstract")):
        print(label, metadata[key])


In [None]:
# Reuse the (possibly partially filled) DOI metadata table when it is cached;
# otherwise start an empty one with a row per unique DOI.
if os.path.exists(DOI_DF_PATH):
    df_dois = pd.read_pickle(DOI_DF_PATH)
    print('Loaded doi_df from', DOI_DF_PATH)
else:
    print(f'No DOI dataframe found at {DOI_DF_PATH}. Creating a new one...')
    cols = ['fetched', 'title', 'abstract', 'year', 'month', 'day']

    # Create an empty dataframe with doi as index and cols with nans
    df_dois = pd.DataFrame(index=unique_dois, columns=cols)
    # `fetched` records which DOIs have already been requested.
    df_dois['fetched'] = False

    df_dois.index.name = 'doi'

# Preview the table. This has to be the last top-level expression of the
# cell: inside the `else` branch it was evaluated but never displayed.
df_dois.head()
    

In [None]:
import threading

# `fetch_and_update` runs on several worker threads that all touch the shared
# `df_dois` table; the lock serializes those accesses.
_df_dois_lock = threading.Lock()


def fetch_and_update(doi):
    """Fetch the metadata of one DOI and store it in the shared `df_dois`.

    DOIs already marked as fetched are skipped. When no metadata could be
    obtained (`fetch_article_metadata` returned None) the DOI is still
    marked as fetched so it is not requested again. Exceptions raised by
    the fetch propagate and leave the row unmarked.

    Args:
        doi: DOI to look up; must be a label of the `df_dois` index.

    Returns:
        None. The result is written into `df_dois`.
    """
    with _df_dois_lock:
        if df_dois.loc[doi, 'fetched']:
            return None

    article_info = fetch_article_metadata(doi)

    # Sleep for a short, random time after every request, successful or
    # not, to avoid triggering rate limits. Done outside the lock so the
    # other workers are not blocked.
    time.sleep(random.uniform(0.05, .5))

    with _df_dois_lock:
        if article_info is not None:
            for key, value in article_info.items():
                df_dois.loc[doi, key] = value
        df_dois.loc[doi, 'fetched'] = True
    return None

# Fetch the metadata of every DOI that is still missing and cache the table.
if df_dois.fetched.all():
    print('All articles were fetched. Saving df_dois')
    df_dois.to_pickle(DOI_DF_PATH)
else:
    print('Not all articles were fetched. Fetching DOIS')
    try:
        with ThreadPoolExecutor(max_workers=5) as executor:
            # list(...) drains the lazy map so every task runs and any
            # worker exception surfaces here.
            list(tqdm(executor.map(fetch_and_update, df_dois.index), total=len(df_dois.index)))
    except (Exception, KeyboardInterrupt) as error:
        # Best effort: report what went wrong and keep the partial progress,
        # so re-running the cell resumes where it stopped.
        print('Error, saving WIP.')
        print(f'Cause: {error!r}')
    finally:
        # Save on success as well; previously a complete run was only
        # persisted after re-running this cell.
        df_dois.to_pickle(DOI_DF_PATH)
    


In [None]:
# At this point every DOI should have been requested from the API and its
# result stored in the dataframe; stop here if some are still pending.
all_fetched = df_dois['fetched'].all()
assert all_fetched

In [None]:
# Fraction of missing values per column.
# Abstract and day (of month) are not going to be very useful, but for the
# rest we almost always have data.
df_dois.isna().mean()

In [None]:
# Histogram of publication years, one count per article.
# Note that most articles contain more than one measurement.
publication_years = df_dois['year']
publication_years.hist()