In [1]:
import polars as pl
import pandas as pd
from pathlib import Path
import random
import numpy as np
from datasets import load_dataset
import pickle as pkl
import re
from utils import DatasetGenerator
# Parent directory
parent_dir = str(Path().resolve().parents[1])

  from .autonotebook import tqdm as notebook_tqdm


# Drugbank to indication

In [2]:
import spacy
from transformers import pipeline
nlp = spacy.load("en_ner_bc5cdr_md")
pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")

  deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]
Device set to use mps:0


In [8]:
def is_abbreviation(text):
    text_clean = text.replace('.', '')
    return text_clean.isupper() and len(text_clean) <= 5    

def abbreviate(entity):
    return ''.join([i[0] for i in entity.split()]).upper()

def intersect_with_preferred_capitalization(list1, list2):
    # Build sets of lower-cased items for each list
    lower1 = {item.lower() for item in list1}
    lower2 = {item.lower() for item in list2}
    common_keys = lower1.intersection(lower2)

    # Create a dictionary mapping each lowercase key to all its variations from both lists
    mapping = {}
    for item in list1 + list2:
        key = item.lower()
        if key in common_keys:
            mapping.setdefault(key, set()).add(item)

    # For each common key, choose the version that starts with an uppercase letter if available
    result = []
    for key, variations in mapping.items():
        # Check if any variation is capitalized
        preferred = next((v for v in variations if v[0].isupper()), None)
        if preferred is None:
            # If none are capitalized, pick an arbitrary version (sorted to be consistent)
            preferred = sorted(variations)[0]
        result.append(preferred)
    return result

def extract_conditions_with_spacy(text):
    doc = nlp(text)
    entities = set(
        ent.text.strip() for ent in doc.ents 
        if ent.label_ == 'DISEASE' and 2 < len(ent.text) < 50
    )
    return list(entities) # if remove abbreviations if entities include full names

def extract_conditions_with_transformers(text):
    res = pipe(text, aggregation_strategy="first")
    entities = set([ent['word'] for ent in res if ent['entity_group'] == 'DISEASE' and 2 < len(ent['word']) < 50])
    return list(entities)

def extract_conditions(text):
    entities_spacy = extract_conditions_with_spacy(text)
    entities_transformers = extract_conditions_with_transformers(text)
    entities = intersect_with_preferred_capitalization(entities_spacy, entities_transformers)
    abbreviations = [abbreviate(entity) for entity in entities]

    output = []
    for entity in entities:
        if any(abbrev.upper() in entity.upper() for abbrev in abbreviations):
            # remove abbreviations if entities include full names
            continue
        output.append(entity)

    return list(set(entities) - set(abbreviations))
def validate_name(text):
    doc = nlp(text + ' ')
    entities = set(
        ent.text.strip() for ent in doc.ents 
        if ent.label_ == 'CHEMICAL' and 2 < len(ent.text) < 50
    )
    if len(entities) == 0:
        return ''
    return list(entities)[0]

In [4]:
text ='Lepirudin is indicated for anticoagulation in adult patients with acute coronary syndromes (ACS) such as unstable angina and acute myocardial infarction without ST elevation. In patients with ACS, lepirudin is intended for use with [aspirin].[L41539] Lepirudin is also indicated for anticoagulation in patients with heparin-induced thrombocytopenia (HIT) and associated thromboembolic disease in order to prevent further thromboembolic complications.[L41539]'
text = text.split('.')[0]
print(extract_conditions(text))
text = 'Fluconazole can be administered in  the treatment of the following fungal infections[L11043]:\r\n\r\n 1) Vaginal yeast infections caused by Candida\r\n 2) Systemic Candida infections\r\n 3) Both esophageal and oropharyngeal candidiasis \r\n 4) Cryptococcal meningitis\r\n 5) UTI (urinary tract infection) by Candida\r\n 6) Peritonitis (inflammation of the peritoneum) caused by Candida\r\n\r\n**A note on fungal infection prophylaxis**\r\n\r\nPatients receiving bone marrow transplantation who are treated with cytotoxic chemotherapy and/or radiation therapy may be predisposed to candida infections, and may receive fluconazole as prophylactic therapy.'
extract_conditions(text)

['acute myocardial infarction', 'unstable angina', 'acute coronary syndromes']


['Vaginal yeast infections',
 'fungal infection',
 'urinary tract infection',
 'Peritonitis',
 'Candida infections']

In [5]:
from wordfreq import zipf_frequency
zipf_frequency('Peritonitis', lang='en')

2.34

In [9]:
drugbank = (
    pl.read_csv(
        f"{parent_dir}/datasets/source/drugbank_data_with_indications.csv",
        columns=['Name', 'indication']
    )
    .drop_nulls(subset=['indication'])
    .with_columns([
        pl.col('indication')
          .map_elements(lambda x: extract_conditions(x.split('.')[0]), return_dtype=pl.List(pl.String))
          .alias('disease'),
        pl.col('indication')
          .map_elements(lambda x: x.split('.')[0], return_dtype=pl.String)
          .alias('indication_simple'),
        pl.col('Name').map_elements(lambda x: validate_name(x), return_dtype=pl.String)
          # .str.replace(r"\(.*?\)", "")  
          # .str.replace(r"\[.*?\]", "")  
          # .str.split(",").list.get(0)  
          .str.strip_chars()
          .alias('chemical')
        # pl.col
    ])
    .filter(pl.col('chemical').str.len_chars() > 0)
    # Properly clean each disease name within the list
    .with_columns(
        pl.col('disease').map_elements(
            lambda disease_list: sorted({
                re.sub(r"\(.*?\)|\[.*?\]", "", d)
                  .split(",")[0]
                  .strip("#+ ")
                  .replace('-', ' ')
                  .strip()
                for d in disease_list
            }),
            return_dtype=pl.List(pl.String)
        ).alias('disease_clean')
    )
    .filter(pl.col('disease_clean').list.len() > 0)
)


In [10]:
drugbank.write_json(f"{parent_dir}/datasets/source/drugbank_data_with_indications_clean.json")

In [56]:
drugdict = dict(
    drugbank.select(['Name', 'disease_clean']).iter_rows()
)

In [None]:
class DrugDisease(DatasetGenerator):
    '''
    Class to handle the dataset from DrugBank
    '''
    def apply_template(self, drug: str, condition: str, negated: bool=False):
        if negated:
            return f"{drug} is not indicated for the treatment of {condition}."
        else:
            return f"{drug} is indicated for the treatment of {condition}."
        
        

    def lookup_incorrect(self, key) -> str:
        '''
        Return False condition for a given drug
        '''
        correct = self.source[key]
        choice = random.choice(list(set(self.values) - set(correct)))
        if choice.lower() in [c.lower() for c in correct]:
            return self.lookup_incorrect(key)
        elif abbreviate(choice) in [abbreviate(c) for c in correct] or abbreviate(choice) in [c for c in correct]:
            return self.lookup_incorrect(key)
        elif any(word in c.lower().split() for word in choice.lower().split() for c in correct):
            return self.lookup_incorrect(key)
        return choice
    
db = DrugDisease(drugdict, category='indications')

In [58]:
db.lookup_incorrect('Lepirudin'), db.source['Lepirudin']

('blood loss',
 ['acute coronary syndromes',
  'acute myocardial infarction',
  'unstable angina'])

In [59]:
data = db.generate_full_dataset()
data.write_json(f"{parent_dir}/datasets/source/drug_disease.json")
subsample = db.generate_subsample(n = 5000, seed=42).with_columns(
                 pl.col("correct_object_2").list.join(", ").alias("correct_object_2"))
subsample.write_csv(f"{parent_dir}/datasets/drug_disease_subsample.csv")


In [86]:
def get_rand_frequencies(word):
    freq  = zipf_frequency(word, 'en', minimum=0.0, wordlist='best')
    if freq == 0: 
        random.seed(word)
        if random.random() < 0.1: # keep a small amount of words with 0 frequency
            freq = 1
    return freq

In [87]:
subsample = db.generate_full_dataset().with_columns(
                 pl.col("correct_object_2").list.join(", ").alias("correct_object_2"))
subsample = subsample.with_columns(
    pl.col('object_1').map_elements(lambda x: get_rand_frequencies(x.lower()), return_dtype=float).alias('freq_1'),
    pl.col('object_2').map_elements(lambda x: get_rand_frequencies(x.lower()), return_dtype=float).alias('freq_2'),
).filter((pl.col('freq_1') > 0) & (pl.col('freq_2') > 0))
subsample.write_csv(f"{parent_dir}/datasets/drug_disease_full.csv")
subsample.group_by(['correct', 'negation']).len()

correct,negation,len
bool,bool,u32
False,False,1419
True,False,1522
True,True,1439
False,True,1523


# Synthetic Statements
### Generator for the Drugs and Conditions

In [16]:
from namemaker import NameSet
import namemaker

seed = 'udaxihhexdvxrcsnbacghqtargwuwr'
random.seed(seed)
namemaker_rng = namemaker.get_rng()
namemaker_rng.seed(seed)

drug_NS = NameSet(names = drugdict.keys())
drugs_fake = [drug_NS.make_name(add_to_history=False) for _ in range(500)]
drugs_fake = list(set(drugs_fake))
# Validate
drugs_validated = []
for item in drugs_fake:
    if validate_name(item) != '':
        pass
    else:
        drugs_validated.append(item)
with open(f"{parent_dir}/datasets/source/IDK_drugs_v2.txt", 'w') as f:
    f.write("\n".join(map(str, drugs_validated)))

In [17]:
seed = 'udaxihhexdvxrcsnbacghqtargwuwr'
random.seed(seed)
namemaker_rng = namemaker.get_rng()
namemaker_rng.seed(seed)
condition_NS = NameSet(names = db.values)
conditions_fake = [condition_NS.make_name(add_to_history=False) for _ in range(200)]
conditions_validated = []
for item in conditions_fake:
    if any([item.lower() in c.lower() for c in db.values]):
        pass
    else:
        conditions_validated.append(item)
with open(f"{parent_dir}/datasets/source/IDK_diseases_v2.txt", 'w') as f:
        f.write("\n".join(map(str, conditions_validated)))

### Create Unverifiable dataset

In [14]:
seed = 'udaxihhexdvxrcsnbacghqtargwuwr'
IDK_drugs = pd.read_csv(f"{parent_dir}/datasets/source/IDK_drugs_checked_v2.csv")
IDK_drugs = IDK_drugs[IDK_drugs['Keep'] == 1]
IDK_drugs = IDK_drugs['Name'].tolist()

IDK_conditions = pd.read_csv(f"{parent_dir}/datasets/source/IDK_diseases_checked_v2.csv")
IDK_conditions = IDK_conditions[IDK_conditions['Keep'] == 1]
IDK_conditions = IDK_conditions['Name'].tolist()
random.seed(seed)
fake_indications = {}
for drug in IDK_drugs:
    fake_indications[drug] = random.sample(IDK_conditions, 2)
#fake_indications

In [25]:
db_fake = DrugDisease(fake_indications, is_fake=True, 
                      category='indications')
data_fake = db_fake.generate_full_dataset()
data_fake.write_json(f"{parent_dir}/datasets/source/drug_disease_fake.json")
subsample_fake = db_fake.generate_subsample(n = 1000, seed=42).with_columns(
                 pl.col("correct_object_2").list.join(", ").alias("correct_object_2"))
subsample_fake.write_csv(f"{parent_dir}/datasets/drug_disease_synth_subsample.csv")
subsample_fake

statement,object_1,object_2,correct_object_2,correct,negation,real_object,fake_object,fictional_object,category
str,str,str,str,bool,bool,bool,bool,bool,str
"""Alumil is indicated for the tr…","""Alumil""","""reticers""","""candigemia, reticers""",false,false,false,true,false,"""indications"""
"""Cysternime is not indicated fo…","""Cysternime""","""perebrilepsies""","""perebrilepsies, nonvalvulgaris""",false,true,false,true,false,"""indications"""
"""Neostonicone is not indicated …","""Neostonicone""","""delial brease""","""delial brease, breatory disord…",false,true,false,true,false,"""indications"""
"""Buspium is indicated for the t…","""Buspium""","""perlipidematory loss""","""perlipidematory loss, uronchos…",false,false,false,true,false,"""indications"""
"""Azelanzamide is indicated for …","""Azelanzamide""","""hepathe overampsis""","""hepathe overampsis, acular aci…",false,false,false,true,false,"""indications"""
…,…,…,…,…,…,…,…,…,…
"""Deutetractone is indicated for…","""Deutetractone""","""atori infective Disease""","""atori infective Disease, intri…",false,false,false,true,false,"""indications"""
"""Tramaltolamide is indicated fo…","""Tramaltolamide""","""akine disorders""","""hyperampsies, anal bleepischem…",false,false,false,true,false,"""indications"""
"""Glutalacine is indicated for t…","""Glutalacine""","""sorder cand vomiasis""","""asperpetiformis, sorder cand v…",false,false,false,true,false,"""indications"""
"""Fenose is not indicated for th…","""Fenose""","""dyslipolyneury""","""Onychoticus, dyslipolyneury""",false,true,false,true,false,"""indications"""


In [23]:
subsample_fake.filter(pl.col('real_object')==False).group_by(['negation']).count()

  subsample_fake.filter(pl.col('real_object')==False).group_by(['negation']).count()


negation,count
bool,u32
True,522
False,478
