In [1]:
import os
import pandas as pd
import numpy as np

In [2]:
####################
# Load the data
####################
# 1. Drug-disease dataframe
files = ["complex_disease_train_w_name.csv", "complex_disease_test_w_name.csv", "complex_disease_valid_w_name.csv"]
df = pd.concat([pd.read_csv(f) for f in files]).reset_index(drop=True)

# 2. Retrieve drug names
drug_id2name = {}
for _, row in df.iterrows():
    if row["x_type"] == "drug":
        drug_id2name[row["x_id"]] = row["x_name"]
    if row["y_type"] == "drug":
        drug_id2name[row["y_id"]] = row["y_name"]
print(f"Number of unique drugs: {len(drug_id2name)}")

Number of unique drugs: 2074


In [3]:
# 3. Retrieve drug-reporeted pmids
import requests
from xml.etree import ElementTree as ET
import random
import time
from tqdm import tqdm

def get_pubmed_ids(query, db="pubmed"):
    sleep_time = random.uniform(0.5, 1.0)
    time.sleep(sleep_time)

    url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db={db}&term={query}&rettype=count"
    response = requests.get(url)
    if not response.status_code == 200:
        print(f"[{query}] Error: {response.status_code}")
        return None
    else:
        root = ET.fromstring(response.text)
        total_hits = int(root.find("Count").text)

        # Second, get the pmids
        retmax = 100000
        total_pmids = set()
        for retstart in range(0, total_hits, retmax):
            url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db={db}&term={query}&retstart={retstart}&retmax={retmax}"
            response = requests.get(url)
            if not response.status_code == 200:
                print(f"[{query}] Error: {response.status_code}")
                return None
            else:
                root = ET.fromstring(response.text)
                pmids = [id_elem.text for id_elem in root.findall(".//Id")]
                total_pmids.update(set(pmids))
        
        import requests
from xml.etree import ElementTree as ET
import random
import time

def get_pubmed_ids(query, db="pubmed"):
    sleep_time = random.uniform(0.5, 1.0)
    time.sleep(sleep_time)

    url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db={db}&term={query}&rettype=count"
    response = requests.get(url)
    if not response.status_code == 200:
        print(f"[{query}] Error: {response.status_code}")
        return None
    else:
        root = ET.fromstring(response.text)
        total_hits = int(root.find("Count").text)

        # Second, get the pmids
        retmax = 100000
        total_pmids = set()
        for retstart in range(0, total_hits, retmax):
            url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db={db}&term={query}&retstart={retstart}&retmax={retmax}"
            response = requests.get(url)
            if not response.status_code == 200:
                print(f"[{query}] Error: {response.status_code}")
                return None
            else:
                root = ET.fromstring(response.text)
                pmids = [id_elem.text for id_elem in root.findall(".//Id")]
                total_pmids.update(set(pmids))
        
        total_pmids = sorted(list(total_pmids))
        return total_pmids


In [4]:
drug_pmids = [(drug_id, get_pubmed_ids(drug_name)) for drug_id, drug_name in tqdm(drug_id2name.items())]

100%|██████████| 2074/2074 [1:28:26<00:00,  2.56s/it]  


In [5]:
drug_pmids_pmc = [(drug_id, get_pubmed_ids(drug_name, db="pmc")) for drug_id, drug_name in tqdm(drug_id2name.items())]

 97%|█████████▋| 2002/2074 [1:17:52<02:46,  2.31s/it]

[Albutrepenonacog alfa] Error: 500


100%|██████████| 2074/2074 [1:20:36<00:00,  2.33s/it]


In [14]:
import pickle
with open("05_drug_pmids.pkl", "wb") as f:
    pickle.dump(drug_pmids, f)

with open("05_drug_pmids_pmc.pkl", "wb") as f:
    pickle.dump(drug_pmids_pmc, f)

In [27]:
# Convert PMC to PMID
from tqdm import tqdm
from joblib import Parallel, delayed
import time
import random

def pmc2pmid(pmc_ids: list):
    # random sleep
    sleep_time = random.uniform(0.5, 1.0)

    base_url = 'https://www.ncbi.nlm.nih.gov/pmc/utils/idconv/v1.0/'
    params = {
        "ids": ",".join(["PMC" + pmc_id for pmc_id in pmc_ids]),
        "tool": "MyPMCApp",
        "email": "masato.tsutsui@protein.osaka-u.ac.jp",
        "format": "json"
    }

    response = requests.get(base_url, params=params)
    if not response.status_code == 200:
        print(f"Error: {response.status_code}")
        return None
    else:
        data = response.json()

        mapping = {}
        for record in data.get("records", []):
            pmcid = record.get("pmcid")
            pmid = record.get("pmid")
            mapping[pmcid] = pmid
        return mapping

all_pmc_ids = set()
for _, pmc_ids in drug_pmids_pmc:
    if pmc_ids is not None:
        all_pmc_ids.update(pmc_ids)
all_pmc_ids = list(all_pmc_ids)
print(f"Number of unique PMC IDs: {len(all_pmc_ids)}")

bs = 200
pmc2pmid_parallel = Parallel(n_jobs=5)(delayed(pmc2pmid)(all_pmc_ids[i:i+bs]) for i in tqdm(range(0, len(all_pmc_ids), bs)))

Number of unique PMC IDs: 6492520


100%|██████████| 32463/32463 [1:06:48<00:00,  8.10it/s]


In [34]:
pmc2pmid = {}
for mapping in pmc2pmid_parallel:
    if mapping is not None:
        mapping = {k.replace("PMC", ""): v for k, v in mapping.items()}
        pmc2pmid.update(mapping)

In [80]:
drug2pmids= {}
for drug, pmids in drug_pmids:
    if pmids is not None:
        drug2pmids[drug] = pmids.copy()
    else:
        drug2pmids[drug] = []

for drug, pmc_ids in drug_pmids_pmc:
    if pmc_ids is not None:
        pmids = [pmc2pmid[pmc_id] for pmc_id in pmc_ids if pmc2pmid[pmc_id] is not None]
        drug2pmids[drug].extend(pmids)
    else:
        drug2pmids[drug].extend([])

drug2pmids = {drug: sorted(list(set(pmids))) for drug, pmids in drug2pmids.items()}

In [81]:
for drug, pmids in drug2pmids.items():
    if len(pmids) == 0:
        print(f"{drug} has no pmids")

In [82]:
detected_pmids = sorted(list({pmid for pmids in drug2pmids.values() for pmid in pmids}))

pmids2drug = {pmid: [] for pmid in detected_pmids}
for drug, pmids in drug2pmids.items():
    for pmid in pmids:
        pmids2drug[pmid].append(drug)

with open("05_drugbank_pmids.txt", "w") as f:
    for pmid, drugs in pmids2drug.items():
        f.write(f"{pmid}\t{','.join(drugs)}\n")

In [33]:
from collections import Counter

drug_id2pmid_counts = Counter()

with open("05_drugbank_pmids.txt", "r") as f:
    for line in f:
        pmid, drugs = line.strip().split("\t")
        drugs = drugs.split(",")
        drug_id2pmid_counts.update(drugs)

In [132]:
#####################################
# Merge the pmid-drug information with Pubtator3
#####################################
from scipy.sparse import lil_matrix
pubtator3_dir = "/share/pubtator3"

# Define the matrix
num_pmids = len(pmids2drug)
num_drugs = len(drug_id2name)
drug_oh = lil_matrix((num_pmids, num_drugs), dtype=int)

drug2idx = {}
drug_ids_to_use = []
drug_names_to_use = []
for i, (drug_id, drug_name) in enumerate(drug_id2name.items()):
    drug2idx[drug_id] = i
    drug_names_to_use.append(drug_name)
    drug_ids_to_use.append(drug_id)

# Get indices for pmids in pubtator3
indices = []
j = 0

with open(os.path.join(pubtator3_dir, "count_data", "pubtator3_pmids.txt")) as f:
    for i, pmid in enumerate(f.read().splitlines()):
        if pmid in pmids2drug:
            indices.append(i)
            # Update the drug one-hot matrix
            for drug in pmids2drug[pmid]:
                drug_oh[j, drug2idx[drug]] = 1
            j += 1

drug_oh = drug_oh.tocsr()

In [133]:
#########################
### Load Pubtator3 counts
#########################
from scipy.io import mmread
from scipy.sparse import csr_matrix, lil_matrix

count_mtx = mmread(os.path.join(pubtator3_dir, "count_data", "counts.mtx"))
print("count_mtx:", count_mtx.shape)

# All human entrez gene ids
all_human_entrez = set(pd.read_csv(os.path.join(pubtator3_dir, "data", "input", "all_human_entrez.csv"))["ENTREZID"].astype(str))

# Obtrain the indice of GeneID
gene_ind = []
gene_ids = []

with open(os.path.join(pubtator3_dir, "count_data", "vocab.txt")) as f:
    for i, line in enumerate(f.readlines()):
        line = line.strip()
        if line.startswith("Gene"):
            # Take only the first GeneID into account e.g. 'Gene|100000688;571349;327429' --> 'Gene|100000688'
            # Because of this procedure, the vocabs become overlapped.
            line = line.split(";")[0]
            line = line.replace("Gene|", "")
            if line in all_human_entrez:
                gene_ind.append(i)
                gene_ids.append(line)

gene_ind = np.array(gene_ind)
gene_ids = np.array(gene_ids)

print("genes for analysis:", len(gene_ids))
print("genes for analysis (unique):", len(np.unique(gene_ids)))

count_mtx: (30338029, 11736685)
genes for analysis: 190337
genes for analysis (unique): 154839


In [134]:
counts2use_gene_related = count_mtx.tocsr()[indices, :][:, gene_ind].copy()
counts2use_cpd_related = drug_oh[:len(indices), :].copy()

rows_to_use = np.where(counts2use_gene_related.sum(axis=1)>0)[0]
counts2use_gene_related = counts2use_gene_related[rows_to_use]
counts2use_cpd_related = counts2use_cpd_related[rows_to_use]
print(f"The analysis will target {len(rows_to_use):,} documents.")

The analysis will target 3,485,523 documents.


In [143]:
# First, aggregate counts using indice with same GeneID
from scipy.sparse import coo_matrix
unique_genes, indices = np.unique(gene_ids, return_inverse=True)

coo = coo_matrix(counts2use_gene_related)

row = coo.row
col = indices[coo.col]
data = coo.data

counts2use_gene_related_aggregated = coo_matrix((data, (row, col)), shape=(coo.shape[0], len(unique_genes))).tocsr()

# Convert elements: if greater than 0, set to 1; otherwise, set to 0
counts2use_gene_related_aggregated = (counts2use_gene_related_aggregated > 0).astype(int)
print("counts2use_gene_related_aggregated:", counts2use_gene_related_aggregated.shape)


# co-occurrence matrix
cpd_gene_co_occurrence = counts2use_cpd_related.T.dot(counts2use_gene_related_aggregated)
print("cpd_gene_co_occurrence:", cpd_gene_co_occurrence.shape)

counts2use_gene_related_aggregated: (3485523, 154839)
cpd_gene_co_occurrence: (2074, 154839)


In [136]:
import scipy.sparse as sp
sp.save_npz("05_cpd_gene_co_occurrence.npz", cpd_gene_co_occurrence)

In [144]:
cooccurrence_counts = pd.DataFrame({"drug_id": drug_ids_to_use, "drug_name": drug_names_to_use,
                                    "count": cpd_gene_co_occurrence.sum(axis=1).A1})
cooccurrence_counts.sort_values("count", ascending=False)

Unnamed: 0,drug_id,drug_name,count
1560,DB00030,Insulin human,7394498
801,DB00133,Serine,6457136
944,DB01082,Streptomycin,5275975
1818,DB00052,Somatotropin,5088948
949,DB00123,L-Lysine,4136409
...,...,...,...
1666,DB13727,Azapetine,1
1231,DB13695,Penthienate,1
555,DB14123,Racementhol,1
2034,DB13396,Neocitrullamon,0


In [19]:
import scipy.sparse as sp
cpd_gene_co_occurrence = sp.load_npz("05_cpd_gene_co_occurrence.npz")

In [21]:
# normalize
cpd_gene_co_occurrence = np.log1p(cpd_gene_co_occurrence / cpd_gene_co_occurrence.sum(axis=1) * 1e6)


divide by zero encountered in divide



In [22]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold

def calc_cpd_features(cpd_gene_co_occurrence, threshold=0.01, n_components=200):
    print("Original number of features:", cpd_gene_co_occurrence.shape[1])
    # 1. Log-transformed
    # data = np.log1p(cpd_gene_co_occurrence)
    data = cpd_gene_co_occurrence.copy()

    # 2. Remove genes with low variance
    selector = VarianceThreshold(threshold)
    data = selector.fit_transform(data)
    print("Number of features after removing low variance features:", data.shape[1])

    # 3. Standardize
    scaler = StandardScaler()
    data = scaler.fit_transform(data.A)

    # 4. PCA
    pca = PCA(n_components=n_components, random_state=0)
    data = pca.fit_transform(data)
    print("Number of features after PCA:", data.shape[1])

    return data

cpd_features = calc_cpd_features(cpd_gene_co_occurrence, threshold=0.01, n_components=200)

Original number of features: 154839
Number of features after removing low variance features: 27149
Number of features after PCA: 200


In [57]:
feature_df = pd.DataFrame(cpd_features, columns=[f"PCA_{i}" for i in range(cpd_features.shape[1])])
feature_df.insert(0, "drug_id", drug_ids_to_use)
feature_df.insert(1, "drug_name", drug_names_to_use)
feature_df.to_csv("05_cpd_features.csv", index=False)