### Data Processing ###

In [None]:
# Mount Google Drive inside the Colab runtime; the corpus paths used later
# in this notebook live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import xml.etree.ElementTree as ET

def parse_drugbank_corpus(drugbank_dir):
    """
    Parse every ``*.xml`` file found directly inside ``drugbank_dir``.

    Files are visited in sorted filename order so the returned list is
    reproducible across runs and machines (``os.listdir`` makes no ordering
    guarantee, which would otherwise change downstream seeded shuffles).

    Args:
        drugbank_dir (str): Directory containing the corpus XML files.

    Returns:
        list[dict]: One dict per ``<sentence>`` element with the keys
            ``sentence_id``, ``sentence_text``, ``entities`` (list of
            ``{"id", "text", "char_offset"}`` dicts in document order) and
            ``ddis`` (one dict per ``<pair>`` whose ``ddi`` attribute is
            ``"true"``, holding ``drug1``, ``drug2``, ``interaction_type``
            plus the entity ids ``drug1_id`` / ``drug2_id``).

    Raises:
        KeyError: If an entity lacks ``id``/``text``, a pair lacks
            ``ddi``/``e1``/``e2``, or a pair references an entity id that is
            not declared in the same sentence.
        xml.etree.ElementTree.ParseError: If a file is not well-formed XML.
    """
    data = []

    for filename in sorted(os.listdir(drugbank_dir)):
        if not filename.endswith(".xml"):
            continue

        root = ET.parse(os.path.join(drugbank_dir, filename)).getroot()

        for sentence in root.iter("sentence"):
            # Entities keyed by id so <pair> elements can be resolved to text.
            entities = {}
            for entity in sentence.iter("entity"):
                ent_id = entity.attrib["id"]
                entities[ent_id] = {
                    "id": ent_id,
                    "text": entity.attrib["text"],
                    "char_offset": entity.attrib.get("charOffset", "")
                }

            ddilist = []
            for pair in sentence.iter("pair"):
                # Only positive pairs are kept; negatives are re-derived later
                # from the entity list.
                if pair.attrib["ddi"] != "true":
                    continue
                e1 = pair.attrib["e1"]
                e2 = pair.attrib["e2"]
                ddilist.append({
                    "drug1": entities[e1]["text"],
                    "drug2": entities[e2]["text"],
                    "interaction_type": pair.attrib.get("type", ""),
                    # Ids disambiguate repeated mentions of the same drug name.
                    "drug1_id": e1,
                    "drug2_id": e2
                })

            data.append({
                "sentence_id": sentence.attrib.get("id"),
                "sentence_text": sentence.attrib.get("text"),
                "entities": list(entities.values()),
                "ddis": ddilist
            })

    return data

# Parse the DrugBank portion of the corpus for both splits.
# The directories are absolute Google Drive paths, so this cell only works
# once Drive has been mounted under /content/drive.
drugbank_dir_train = "/content/drive/MyDrive/w266 Final Project/Train/DrugBank"
drugbank_dir_test = "/content/drive/MyDrive/w266 Final Project/Test/Test for DDI Extraction task/DrugBank"
# "unfiltered": every sentence is kept here, regardless of how many entities it has.
unfiltered_drugbank_sentences_train = parse_drugbank_corpus(drugbank_dir_train)
unfiltered_drugbank_sentences_test = parse_drugbank_corpus(drugbank_dir_test)

# Report how many sentences each split produced.
print(f"Parsed {len(unfiltered_drugbank_sentences_train)} sentences from DrugBank train.")
print(f"Parsed {len(unfiltered_drugbank_sentences_test)} sentences from DrugBank test.")

Parsed 5675 sentences from DrugBank train.
Parsed 973 sentences from DrugBank test.


In [None]:
# A candidate drug pair needs at least two entity mentions, so sentences
# with fewer than two entities are dropped from both splits.
drugbank_sentences_train = [
    sent for sent in unfiltered_drugbank_sentences_train
    if len(sent['entities']) > 1
]
drugbank_sentences_test = [
    sent for sent in unfiltered_drugbank_sentences_test
    if len(sent['entities']) > 1
]

print(f"{len(drugbank_sentences_train)} filtered sentences from DrugBank Train.")
print(f"{len(drugbank_sentences_test)} filtered sentences from DrugBank Test.")

3256 filtered sentences from DrugBank Train.
620 filtered sentences from DrugBank Test.


In [None]:
from itertools import combinations
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix

def generate_drug_pairs(entities):
  """Return all 2-combinations of the entities' ``text`` values.

  Args:
    entities: iterable of dicts, each providing a ``"text"`` key.

  Returns:
    list of ``(text_a, text_b)`` tuples, in the order produced by
    ``itertools.combinations`` over the input order. Repeated texts are
    kept, so a name mentioned twice can be paired with itself.
  """
  mention_texts = [mention["text"] for mention in entities]
  return list(combinations(mention_texts, 2))

In [None]:
# Helper to parse a charOffset string into a tuple of integers
def parse_offset(offset_str):
    """
    Convert a ``charOffset`` value into a ``(start, end)`` pair of ints.

    Accepts the plain ``"start-end"`` form as well as the multi-fragment form
    ``"s1-e1;s2-e2"`` (fragments separated by ``;``, as used for mentions
    that are not contiguous in the sentence). For multi-fragment input the
    enclosing span is returned: the smallest start and the largest end.

    Args:
        offset_str (str): Offset string such as ``"158-165"``.

    Returns:
        tuple[int, int]: ``(start, end)`` exactly as encoded in the string
        (no adjustment of the end index is made).

    Raises:
        ValueError: If the string is empty or any fragment is not of the
            form ``<int>-<int>``.
    """
    fragments = [frag for frag in offset_str.split(';') if frag.strip()]
    if not fragments:
        raise ValueError(f"empty char offset: {offset_str!r}")

    starts = []
    ends = []
    for frag in fragments:
        # Unpacking raises ValueError when a fragment is not "start-end".
        start, end = frag.split('-')
        starts.append(int(start))
        ends.append(int(end))

    return min(starts), max(ends)

In [None]:
# Preview only the first few parsed sentences; displaying the whole list
# floods the cell output and bloats the saved notebook.
drugbank_sentences_train[:3]

[{'sentence_id': 'DDI-DrugBank.d289.s6',
  'sentence_text': 'Hormonal Contraceptives, Including Oral, Injectable, Transdermal, and Implantable Contraceptives: An interaction study demonstrated that co-administration of bosentan and the oral hormonal contraceptive Ortho-Novum produced average decreases of norethindrone and ethinyl estradiol levels of 14% and 31%, respectively.',
  'entities': [{'id': 'DDI-DrugBank.d289.s6.e0',
    'text': 'Hormonal Contraceptives',
    'char_offset': '0-22'},
   {'id': 'DDI-DrugBank.d289.s6.e1',
    'text': 'Contraceptives',
    'char_offset': '82-95'},
   {'id': 'DDI-DrugBank.d289.s6.e2',
    'text': 'bosentan',
    'char_offset': '158-165'},
   {'id': 'DDI-DrugBank.d289.s6.e3',
    'text': 'hormonal contraceptive',
    'char_offset': '180-201'},
   {'id': 'DDI-DrugBank.d289.s6.e4',
    'text': 'Ortho-Novum',
    'char_offset': '203-213'},
   {'id': 'DDI-DrugBank.d289.s6.e5',
    'text': 'norethindrone',
    'char_offset': '245-257'},
   {'id': 'DDI-Dr

In [None]:
train_formatted_data = []  # one record per candidate drug pair, drug names included

for s in drugbank_sentences_train:
  # Gold interactions of this sentence, keyed by the unordered, lower-cased
  # pair of drug names. setdefault keeps the first annotation when the same
  # name pair is annotated more than once.
  # Limitation: matching is by name, so repeated mentions of one drug share a label.
  gold_labels = {}
  for ddi in s.get('ddis', []):
    pair_key = frozenset((ddi['drug1'].lower(), ddi['drug2'].lower()))
    gold_labels.setdefault(pair_key, ddi['interaction_type'])

  # generate_drug_pairs already yields (name_a, name_b) in sentence order, so
  # the names are taken straight from the pair. Re-discovering them by
  # scanning entity texts picked the same drug twice whenever it was
  # mentioned more than once before its partner, and the offset parsing that
  # went with it aborted the rest of the sentence on a ValueError.
  for drug1, drug2 in generate_drug_pairs(s['entities']):
    # A frozenset key requires both names to match: an (a, a) interaction
    # no longer labels an (a, b) pair.
    label = gold_labels.get(frozenset((drug1.lower(), drug2.lower())), "false")

    train_formatted_data.append({
        "sentence": s['sentence_text'],
        "labels": label,
        "drug1": drug1,
        "drug2": drug2
    })

In [None]:
# Preview only the first few records; displaying the whole list floods the
# cell output and bloats the saved notebook.
train_formatted_data[:3]

[{'sentence': 'Hormonal Contraceptives, Including Oral, Injectable, Transdermal, and Implantable Contraceptives: An interaction study demonstrated that co-administration of bosentan and the oral hormonal contraceptive Ortho-Novum produced average decreases of norethindrone and ethinyl estradiol levels of 14% and 31%, respectively.',
  'labels': 'false',
  'drug1': 'Hormonal Contraceptives',
  'drug2': 'Contraceptives'},
 {'sentence': 'Hormonal Contraceptives, Including Oral, Injectable, Transdermal, and Implantable Contraceptives: An interaction study demonstrated that co-administration of bosentan and the oral hormonal contraceptive Ortho-Novum produced average decreases of norethindrone and ethinyl estradiol levels of 14% and 31%, respectively.',
  'labels': 'false',
  'drug1': 'Hormonal Contraceptives',
  'drug2': 'bosentan'},
 {'sentence': 'Hormonal Contraceptives, Including Oral, Injectable, Transdermal, and Implantable Contraceptives: An interaction study demonstrated that co-adm

In [None]:
test_formatted_data = []  # one record per candidate drug pair, drug names included

for s in drugbank_sentences_test:
  # Gold interactions of this sentence, keyed by the unordered, lower-cased
  # pair of drug names. setdefault keeps the first annotation when the same
  # name pair is annotated more than once.
  # Limitation: matching is by name, so repeated mentions of one drug share a label.
  gold_labels = {}
  for ddi in s.get('ddis', []):
    pair_key = frozenset((ddi['drug1'].lower(), ddi['drug2'].lower()))
    gold_labels.setdefault(pair_key, ddi['interaction_type'])

  # generate_drug_pairs already yields (name_a, name_b) in sentence order, so
  # the names are taken straight from the pair. Re-discovering them by
  # scanning entity texts picked the same drug twice whenever it was
  # mentioned more than once before its partner, and the offset parsing that
  # went with it aborted the rest of the sentence on a ValueError.
  for drug1, drug2 in generate_drug_pairs(s['entities']):
    # A frozenset key requires both names to match: an (a, a) interaction
    # no longer labels an (a, b) pair.
    label = gold_labels.get(frozenset((drug1.lower(), drug2.lower())), "false")

    test_formatted_data.append({
        "sentence": s['sentence_text'],
        "labels": label,
        "drug1": drug1,
        "drug2": drug2
    })

In [None]:
#classes are imbalanced

# One row per candidate drug pair; columns come from the record keys
# ("sentence", "labels", "drug1", "drug2").
df_train = pd.DataFrame(train_formatted_data)
df_test = pd.DataFrame(test_formatted_data)


In [None]:
# Keep only rows whose label is present and non-empty.
# astype(bool) removes empty strings (and None), but NaN is truthy, so the
# explicit notna() is what removes missing values.
df_train = df_train[df_train['labels'].astype(bool) & df_train['labels'].notna()]
df_test = df_test[df_test['labels'].astype(bool) & df_test['labels'].notna()]

In [None]:
# Shuffle both frames with a fixed seed (reproducible row order) and
# renumber the index from 0.
df_train = (
    df_train
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
df_test = (
    df_test
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

In [None]:
# Optional down-sampling for quick smoke runs of the (slow) inference loop.
# Set SAMPLE_SIZE = None to keep every row.
# NOTE(review): the evaluation output recorded further down covers 5176 test
# rows, i.e. it was produced without this subsample — use None to reproduce it.
SAMPLE_SIZE = 30

if SAMPLE_SIZE is not None:
    # min(...) avoids the ValueError DataFrame.sample raises when n is larger
    # than the number of rows.
    df_train = df_train.sample(n=min(SAMPLE_SIZE, len(df_train)), random_state=42).reset_index(drop=True)
    df_test = df_test.sample(n=min(SAMPLE_SIZE, len(df_test)), random_state=42).reset_index(drop=True)

In [None]:
import pandas as pd
# Global pandas display option: do not truncate long cell values (sentences).
pd.set_option('display.max_colwidth', None) # Display full content of columns

# Last expression of the cell: renders the training frame.
df_train

Unnamed: 0,sentence,labels,drug1,drug2
0,"Therefore, co-administration of clozapine with other drugs that are metabolized by this isozyme, including antidepressants, phenothiazines, carbamazepine, and Type 1C antiarrhythmics (e.g., propafenone, flecainide and encainide), or that inhibit this enzyme (e.g., quinidine), should be approached with caution.",false,Type 1C antiarrhythmics,encainide
1,"ACE Inhibitors and Angiotensin II Receptor Antagonists (Congestive Heart Failure Post-Myocardial Infarction)- In EPHESUS, 3020 (91%) patients receiving INSPRA 25 to 50 mg also received ACE inhibitors or angiotensin II receptor antagonists (ACEI/ARB).",false,ACE Inhibitors,ACEI
2,"Amphotericin, Foscarnet, and Aminoglycosides: Drugs such as amphotericin, foscarnet, and aminoglycosides may increase the risk of developing peripheral neuropathy or other HIVID-associated adverse events by interfering with the renal clearance of zalcitabine (thereby raising systemic exposure).",false,Foscarnet,Aminoglycosides
3,"Substances that are potent inhibitors of CYP3A4 activity (eg, ketoconazole and itraconazole) decrease gefitinib metabolism and increase gefitinib plasma concentrations.",mechanism,ketoconazole,gefitinib
4,"Antacids, Sucralfate, Metal Cations, Multivitamins Quinolones form chelates with alkaline earth and transition metal cations.",false,Multivitamins,Quinolones
5,"Drugs that reportedly may increase oral anticoagulant response, ie, increased prothrombin response, in man include:alcohol*;allopurinol;aminosalicylic acid;amiodarone;anabolic steroids;antibiotics;bromelains;chloral hydrate*;chlorpropamide;chymotrypsin;cimetidine;cinchophen;clofibrate;dextran;dextrothyroxine;diazoxide;dietary deficiencies;diflunisal;disulfiram;drugs affecting blood elements;ethacrynic acid;fenoprofen;glucagon;hepatotoxic drugs;ibuprofen;indomethacin;influenza virus vaccine;inhalation anesthetics;mefenamic acid;methyldopa;methylphenidate;metronidazole;miconazole;monoamine oxidase inhibitors;nalidixic acid;naproxen;oxolinic acid;oxyphenbutazone;pentoxifylline;phenylbutazone;phenyramidol;phenytoin;prolonged hot weather;prolonged narcotics;pyrazolones;quinidine;quinine;ranitidine*;salicylates;sulfinpyrazone;sulfonamides, long acting;sulindac;thyroid drugs;tolbutamide;triclofos sodium;trimethoprim/sulfamethoxazole;unreliable prothrombin time determinations;warfarin sodium overdosage.",false,phenyramidol,prolonged narcotics
6,"Uricosuric drugs, such as probenecid and sulfinpyrazone, can inhibit renal tubular secretion of nitrofurantoin.",mechanism,Uricosuric drugs,nitrofurantoin
7,"Agents that are CYP3A4 inhibitors that have been found, or are expected, to increase plasma levels of EQUETROTM are the following: Acetazolamide, azole antifungals, cimetidine, clarithromycin(1), dalfopristin, danazol, delavirdine, diltiazem, erythromycin(1), fluoxetine, fluvoxamine, grapefruit juice, isoniazid, itraconazole, ketoconazole, loratadine, nefazodone, niacinamide, nicotinamide, protease inhibitors, propoxyphene, quinine, quinupristin, troleandomycin, valproate(1), verapamil, zileuton.",false,Acetazolamide,zileuton
8,"Agents that have been found, or are expected to have decreased plasma levels in the presence of EQUETROTM due to induction of CYP enzymes are the following: Acetaminophen, alprazolam, amitriptyline, bupropion, buspirone, citalopram, clobazam, clonazepam, clozapine, cyclosporin, delavirdine, desipramine, diazepam, dicumarol, doxycycline, ethosuximide, felbamate, felodipine, glucocorticoids, haloperidol, itraconazole, lamotrigine, levothyroxine, lorazepam, methadone, midazolam, mirtazapine, nortriptyline, olanzapine, oral contraceptives(3), oxcarbazepine, Phenytoin(4), praziquantel, protease inhibitors, quetiapine, risperidone, theophylline, topiramate, tiagabine, tramadol, triazolam, valproate, warfarin(5) , ziprasidone, and zonisamide.",false,levothyroxine,praziquantel
9,"Drugs that reportedly may increase oral anticoagulant response, ie, increased prothrombin response, in man include:alcohol*;allopurinol;aminosalicylic acid;amiodarone;anabolic steroids;antibiotics;bromelains;chloral hydrate*;chlorpropamide;chymotrypsin;cimetidine;cinchophen;clofibrate;dextran;dextrothyroxine;diazoxide;dietary deficiencies;diflunisal;disulfiram;drugs affecting blood elements;ethacrynic acid;fenoprofen;glucagon;hepatotoxic drugs;ibuprofen;indomethacin;influenza virus vaccine;inhalation anesthetics;mefenamic acid;methyldopa;methylphenidate;metronidazole;miconazole;monoamine oxidase inhibitors;nalidixic acid;naproxen;oxolinic acid;oxyphenbutazone;pentoxifylline;phenylbutazone;phenyramidol;phenytoin;prolonged hot weather;prolonged narcotics;pyrazolones;quinidine;quinine;ranitidine*;salicylates;sulfinpyrazone;sulfonamides, long acting;sulindac;thyroid drugs;tolbutamide;triclofos sodium;trimethoprim/sulfamethoxazole;unreliable prothrombin time determinations;warfarin sodium overdosage.",false,miconazole,warfarin sodium


### Llama Set Up ###

In [None]:
# Install the inference dependencies quietly (-q); no versions are pinned.
!pip install -q transformers accelerate bitsandbytes

# Interactive Hugging Face login widget.
# NOTE(review): presumably required because the Llama-2 checkpoint is gated — confirm.
from huggingface_hub import notebook_login
print("Please log in to your Hugging Face account:")
notebook_login()


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m53.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m58.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


# Hugging Face Hub id of the checkpoint used for all prompts below.
model_id = "meta-llama/Llama-2-7b-chat-hf"

# Load the weights in 4-bit precision and run computations in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

print("\nLoading model and tokenizer... This may take a few minutes.")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" leaves weight placement to the library rather than
# naming a device explicitly.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

print("✅ Model loaded successfully!")



Loading model and tokenizer... This may take a few minutes.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

✅ Model loaded successfully!


In [None]:
def create_ddi_prompt_2(sentence, drug1, drug2):
    """Build a plain (no chat tags) classification prompt for one drug pair.

    Args:
        sentence: Sentence text, inserted inside double quotes.
        drug1, drug2: Drug names, inserted inside single quotes.

    Returns:
        str: The prompt, ending with an open "Interaction type:" line.

    NOTE(review): the line "Return false for the interaction type." reads as
    an unconditional instruction — confirm whether a condition is missing.
    """
    return f"""
    Given the sentence, identify the interaction type between '{drug1}' and '{drug2}' using only the sentence given.
    The possible interaction types are: advise, effect, mechanism, int, or false.
    Return false for the interaction type.

    Sentence: "{sentence}"

    Interaction type:
    """



In [None]:
def create_ddi_prompt_3(sentence, drug1, drug2):
    """Build a Llama-2 chat prompt asking the model to echo the two drugs.

    Compared with the earlier draft of this prompt:
    - the system block is now closed with ``<</SYS>>`` (it was opened twice
      and never closed),
    - the instruction is terminated with ``[/INST]``,
    - the literal ``<s>`` was removed because the tokenizer adds the
      beginning-of-sequence token itself by default,
    - typos in the instruction text were corrected.

    Args:
        sentence: Sentence text to analyse.
        drug1, drug2: Names of the two drugs under consideration.

    Returns:
        str: The formatted prompt.
    """
    return f"""[INST] <<SYS>>
You are an efficient robot that extracts interactions between two drugs in the given sentence.
First you need to identify what you're analyzing. 1. Only return the two drugs you're analyzing in a list format.

Sentence: {sentence}
Drug 1: {drug1}
Drug 2: {drug2}
<</SYS>>

My question is: can you return the two drugs you're analyzing? [/INST]"""



In [None]:
def create_ddi_prompt_4(sentence, drug1, drug2):
    """Build a Llama-2 chat prompt: echo the two drugs, then type the interaction.

    Compared with the earlier draft of this prompt:
    - the system block is now closed with ``<</SYS>>`` (it was opened twice
      and never closed),
    - the instruction is terminated with ``[/INST]``,
    - the literal ``<s>`` was removed because the tokenizer adds the
      beginning-of-sequence token itself by default,
    - typos in the instruction text were corrected.

    Args:
        sentence: Sentence text to analyse.
        drug1, drug2: Names of the two drugs under consideration.

    Returns:
        str: The formatted prompt.
    """
    return f"""[INST] <<SYS>>
You are an efficient robot that extracts interactions between two entities in the given sentence. No questions.
1. First you need to identify what you're analyzing. Return the two drugs you're analyzing in a list format.
2. Don't use outside information. Only use the information given. Identify the interaction using the following types: advise, effect, mechanism, int, or false.
Sentence: {sentence}
Drug 1: {drug1}
Drug 2: {drug2}
<</SYS>>

My question is: can you return the two drugs you're analyzing? [/INST]"""



In [None]:
def create_ddi_prompt_5(sentence, drug1, drug2):
    """Build an ``[INST]`` prompt asking for the primary subject drug.

    Args:
        sentence: Sentence text to analyse.
        drug1, drug2: Names of the two drugs listed under the sentence.

    Returns:
        str: The formatted prompt (includes a ``Label: ?`` placeholder line).
    """
    return f"""[INST]

    Identify the primary subject drug in the sentence.


    Sentence:{sentence}
    Drug 1: {drug1}
    Drug 2: {drug2}
    Label: ?

    [/INST]
    """



In [None]:
def create_ddi_prompt_6(sentence, drug1, drug2):
    """Build a Yes/No prompt that also requests a one-sentence reasoning.

    Args:
        sentence: Sentence text, inserted inside double quotes.
        drug1, drug2: Names of the two drugs whose interaction is queried.

    Returns:
        str: Prompt asking for ``ANSWER: [Yes/No]`` and ``REASONING: ...``.

    NOTE(review): the ``<<SYS>>`` block precedes ``[INST]`` here; the Llama-2
    chat reference format nests it inside ``[INST]`` — confirm this is intended.
    """
    return f"""<<SYS>>
    You are a text analysis assistant that identifies explicit drug-drug interactions. Be precise and literal.
    <</SYS>>

    [INST]
    Analyze this sentence to determine if it explicitly mentions an interaction between {drug1} and {drug2}.

    Sentence: "{sentence}"

    Instructions:
    - Look for direct statements that {drug1} and {drug2} interact with each other
    - Ignore mentions where drugs are in separate lists or contexts
    - Do not infer or assume interactions

    Respond in this exact format:
    ANSWER: [Yes/No]
    REASONING: [One sentence explaining your answer]
    [/INST]"""



In [None]:
def create_ddi_prompt_7(sentence, drug1, drug2):
    """Build a strict Yes/No prompt for one drug pair.

    Rule 2 now tells the model to answer "No" (it previously said "False"),
    so the rules agree with the closing "Answer with only "Yes" or "No""
    instruction.

    Args:
        sentence: Sentence text to analyse.
        drug1, drug2: Names of the two drugs whose interaction is queried.

    Returns:
        str: The formatted prompt.

    NOTE(review): the ``<<SYS>>`` block precedes ``[INST]`` here; the Llama-2
    chat reference format nests it inside ``[INST]`` — confirm this is intended.
    """
    return f"""<<SYS>>
    You are a precise text analysis assistant. Your task is to identify explicit drug-drug interactions mentioned in sentences. Be extremely literal and do not make inferences.
    <</SYS>>

    [INST]
    Task: Determine if the given sentence explicitly states that {drug1} and {drug2} interact with each other.

    Rules:
    1. Only answer "Yes" if the sentence directly states that {drug1} interacts with {drug2} or vice versa
    2. Answer "No" if the drugs are only mentioned in separate contexts, lists, or different mechanisms
    3. Do not infer interactions from indirect information
    4. Do not add any explanations beyond the required answer

    Sentence: {sentence}

    Question: Does this sentence explicitly state an interaction between {drug1} and {drug2}?

    Answer with only "Yes" or "No":
    [/INST]"""





In [None]:
def create_ddi_prompt_step_1(sentence, drug1, drug2):
    """Build the first-stage Yes/No prompt: is an interaction stated at all?

    Args:
        sentence: Sentence text to analyse.
        drug1, drug2: Names of the two drugs whose interaction is queried.

    Returns:
        str: Prompt instructing the model to answer only "Yes" or "No".

    NOTE(review): the ``<<SYS>>`` block precedes ``[INST]`` here; the Llama-2
    chat reference format nests it inside ``[INST]`` — confirm this is intended.
    """
    return f"""<<SYS>>
You are a precise text analysis assistant. Your task is to identify explicit drug-drug interactions mentioned in sentences. Be extremely literal and do not make inferences.
<</SYS>>

[INST]
Task: Determine if the given sentence explicitly states that {drug1} and {drug2} interact with each other.

Critical Rules:
1. Only answer "Yes" if the sentence directly states that {drug1} interacts with {drug2} or vice versa
2. Answer "No" if:
   - The drugs are mentioned in the same list or category
   - One drug is an example of the other drug's class
   - The drugs are mentioned in separate contexts
   - The sentence describes both drugs interacting with a third substance
3. Do not infer interactions from indirect information
4. Pay attention to parenthetical examples (e.g., drug1, drug2) - these show classification, not interaction

Sentence: {sentence}

Question: Does this sentence explicitly state an interaction between {drug1} and {drug2}?

Answer with only "Yes" or "No":
[/INST]"""

In [None]:
def create_ddi_prompt_with_types(sentence, drug1, drug2):
    """Build a zero-shot prompt that classifies the interaction type.

    The prompt lists four categories (advise, effect, mechanism, int) with a
    short description each and asks for exactly one of them.

    Args:
        sentence: Sentence text, inserted inside double quotes.
        drug1, drug2: Names of the two drugs whose interaction is classified.

    Returns:
        str: The formatted prompt.

    NOTE(review): the ``<<SYS>>`` block precedes ``[INST]`` here; the Llama-2
    chat reference format nests it inside ``[INST]`` — confirm this is intended.
    """
    return f"""<<SYS>>
You are a precise text analysis assistant. Classify drug interactions based only on explicit text mentioning both {drug1} and {drug2}. Be extremely literal and do not make inferences.
<</SYS>>

[INST]
Sentence: "{sentence}"

Task: Classify the interaction between {drug1} and {drug2} based on the exact words that describe what happens when they're used together.

Choose ONE category:
- advise: Clinical recommendations. Sentence likely to contain "should not/should/recommend/avoid/contraindicated/not be considered"
- effect: A described changed based pharmacodynamics or pharmacokinetics. Patients experience "causes nausea/toxicity/bleeding/seizures"
- mechanism: sentence explains pharmacokinetic/pharmacodynamic changes "elevated/decreased concentrations/levels", "prolonged/shortened half-life", "blocks/inhibits [enzyme/pathway]"
- int: sentence mentions interaction but gives no specific details

Critical: Look at the exact phrase in the sentence. What category does it match?

Answer: [advise/effect/mechanism/int]
[/INST]"""

In [None]:
def create_ddi_prompt_step_2(sentence, drug1, drug2):
    """Build the second-stage few-shot prompt that classifies the interaction type.

    The prompt describes four categories (advise, effect, mechanism, int),
    shows seven worked examples and then asks for the category of the given
    sentence and drug pair.

    Args:
        sentence: Sentence text, inserted inside double quotes.
        drug1, drug2: Names of the two drugs whose interaction is classified.

    Returns:
        str: The formatted prompt.

    NOTE(review): the ``<<SYS>>`` block precedes ``[INST]`` here; the Llama-2
    chat reference format nests it inside ``[INST]`` — confirm this is intended.
    """
    return f"""<<SYS>>
You are a precise text analysis assistant. Classify drug interactions based only on explicit text mentioning both {drug1} and {drug2}. Be extremely literal and do not make inferences.
<</SYS>>

[INST]
Task: Classify the interaction between two drugs based on the exact words that describe what happens when they're used together.

Choose ONE category:
- advise: Clinical recommendations. Sentence likely to contain "should not/should/recommend/avoid/contraindicated/not be considered"
- effect: A described changed based pharmacodynamics or pharmacokinetics. Patients experience "causes nausea/toxicity/bleeding/seizures"
- mechanism: sentence explains pharmacokinetic/pharmacodynamic changes "elevated/decreased concentrations/levels", "prolonged/shortened half-life", "blocks/inhibits [enzyme/pathway]"
- int: sentence mentions interaction but gives no specific details

Examples:

Sentence: "Warfarin should not be administered with aspirin due to increased bleeding risk."
Drugs: warfarin, aspirin
Answer: advise

Sentence: "Concurrent use of fluoxetine and tramadol causes serotonin syndrome in patients."
Drugs: fluoxetine, tramadol
Answer: effect

Sentence: "Ketoconazole inhibits CYP3A4 metabolism of simvastatin, leading to elevated plasma concentrations."
Drugs: ketoconazole, simvastatin
Answer: mechanism

Sentence: "There is a potential interaction between metformin and contrast dye."
Drugs: metformin, contrast dye
Answer: int

Sentence: "Patients taking digoxin are recommended to avoid concurrent furosemide therapy."
Drugs: digoxin, furosemide
Answer: advise

Sentence: "Combining alcohol with benzodiazepines results in respiratory depression."
Drugs: alcohol, benzodiazepines
Answer: effect

Sentence: "Rifampin decreases the half-life of oral contraceptives by inducing hepatic enzymes."
Drugs: rifampin, oral contraceptives
Answer: mechanism

Now classify this sentence:
Sentence: "{sentence}"
Drugs: {drug1}, {drug2}

Critical: Look at the exact phrase in the sentence. What category does it match based on the examples above?

Answer: [advise/effect/mechanism/int]
[/INST]"""

In [None]:
!pip install -q pandas scikit-learn

import torch
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import notebook_login
from tqdm.auto import tqdm # For a nice progress bar

### Inference ###

In [None]:
import re

def extract_keyword(sentence):
    """
    Extracts the target keyword that occurs first in the text, matching whole words only.

    The earliest occurrence in the text wins, regardless of the order in
    which the keywords are listed. A model answer such as
    "advise, because the mechanism is unclear" therefore yields "advise".

    Args:
        sentence (str): The input string to search. Empty or None input
            yields None.

    Returns:
        str: The matched keyword in lower case ("int", "mechanism", "advise"
        or "effect"), or None if no keyword is present as a whole word.
    """
    keywords = ["int", "mechanism", "advise", "effect"]

    if not sentence:
        return None

    # One alternation pattern: re.search returns the leftmost match, i.e. the
    # keyword that appears first in the text. \b keeps "int" from matching
    # inside words such as "interaction".
    pattern = r"\b(" + "|".join(re.escape(keyword) for keyword in keywords) + r")\b"
    match = re.search(pattern, sentence, re.IGNORECASE)
    if match is None:
        return None

    return match.group(1).lower()

In [None]:

label_names = ['advise', 'effect', 'mechanism', 'int', 'false']
label_map = {name: i for i, name in enumerate(label_names)}
id_to_label = {i: name for i, name in enumerate(label_names)}

# Generation budgets. Both stages only need a short answer; without a limit
# generation runs until the model emits end-of-sequence.
STEP1_MAX_NEW_TOKENS = 16   # "Yes" / "No"
STEP2_MAX_NEW_TOKENS = 64   # one category word, possibly with a short preamble


def _generate_completion(prompt, max_new_tokens):
    """Greedy-decode a completion for `prompt` and return only the new text."""
    # model.device instead of a hard-coded "cuda": the model is loaded with
    # device_map="auto", so its device is not fixed by this notebook.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        do_sample=False,
        num_beams=1,
        max_new_tokens=max_new_tokens
    )
    # The output sequence starts with the prompt tokens; cut by token count
    # rather than by character length of the decoded prompt, which is only
    # correct when decode(encode(prompt)) reproduces the prompt exactly.
    prompt_token_count = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_token_count:], skip_special_tokens=True).strip()


def _is_negative_answer(answer):
    """True when the answer starts with "No" (optionally after "Answer:")."""
    # Accepts "No", "No.", "no, the sentence ..." etc.; an exact == "no"
    # comparison sent every explained "No, ..." answer on to step 2.
    return re.match(r"\W*(?:answer\s*:\s*)?no\b", answer, re.IGNORECASE) is not None


print("\n Running inference on the test set...")

true_labels = []
predicted_labels = []
misclassifications = []
all_results = []


for index, row in tqdm(df_test.iterrows(), total=df_test.shape[0]):
    sentence = row['sentence']
    drug1 = row['drug1']
    drug2 = row['drug2']
    true_label_str = row['labels']

    # Step 1: does the sentence state an interaction between the pair at all?
    step1_answer = _generate_completion(
        create_ddi_prompt_step_1(sentence, drug1, drug2), STEP1_MAX_NEW_TOKENS
    )

    if _is_negative_answer(step1_answer):
        interaction_extracted = "false"
    else:
        # Step 2: classify the interaction type. If no category keyword can
        # be found in the answer, fall back to "false".
        step2_answer = _generate_completion(
            create_ddi_prompt_step_2(sentence, drug1, drug2), STEP2_MAX_NEW_TOKENS
        )
        interaction_extracted = extract_keyword(step2_answer) or "false"

    true_label_id = label_map[true_label_str]
    true_labels.append(true_label_id)

    predicted_label_id = label_map.get(interaction_extracted, label_map['false'])
    predicted_labels.append(predicted_label_id)

    result = {
        'index': index,
        'sentence': sentence,
        'true_label': true_label_str,
        # Stored as the label actually scored (never None), so the error
        # analysis agrees with the confusion matrix.
        'predicted_label': id_to_label[predicted_label_id],
        'correct': true_label_id == predicted_label_id
    }
    all_results.append(result)

    if true_label_id != predicted_label_id:
        misclassifications.append(result)







🧪 Running inference on the test set...


  0%|          | 0/5176 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperatur

In [None]:
print("\n" + "="*25)
print("Classification Report")
print("="*25)
# labels= is passed explicitly so the rows always line up with target_names,
# even when a class never occurs in true_labels/predicted_labels (for example
# on a small debug sample). Without it the label set is inferred from the data
# and can disagree with the five names. zero_division=0 reports 0 for such
# empty classes instead of emitting warnings.
report = classification_report(
    true_labels,
    predicted_labels,
    labels=list(label_map.values()),
    target_names=label_names,
    digits=4,
    zero_division=0
)
print(report)

print("\n" + "="*25)
print("Confusion Matrix")
print("="*25)
# Rows are true labels, columns are predicted labels, both in label_names order.
cm = confusion_matrix(true_labels, predicted_labels, labels=list(label_map.values()))
# Display confusion matrix with labels for clarity
cm_df = pd.DataFrame(cm, index=label_names, columns=label_names)
print(cm_df)



Classification Report
              precision    recall  f1-score   support

      advise     0.1273    0.6711    0.2140       228
      effect     0.1389    0.0147    0.0266       340
   mechanism     0.2252    0.1536    0.1827       384
         int     0.0370    0.0086    0.0140       116
       false     0.9090    0.8074    0.8552      4108

    accuracy                         0.6830      5176
   macro avg     0.2875    0.3311    0.2585      5176
weighted avg     0.7537    0.6830    0.7038      5176


Confusion Matrix
           advise  effect  mechanism  int  false
advise        153       1         26    1     47
effect        200       5         39    4     92
mechanism     202      16         59    6    101
int            17       0          6    1     92
false         630      14        132   15   3317


In [None]:
# DETAILED MISCLASSIFICATION ANALYSIS
# ==========================================
# Prints three things:
#   1. every "true → predicted" error pattern with up to 3 example sentences,
#   2. per-label counts of correct predictions with up to 2 example sentences,
#   3. an overall error-rate summary line.
# Reads `misclassifications`, `all_results` and `label_names` from earlier cells.
from collections import defaultdict

print("\n" + "="*50)
print("MISCLASSIFICATION ANALYSIS")
print("="*50)

# Group misclassifications by true -> predicted pattern
# (patterns are reported in first-seen order, not by frequency).
error_patterns = defaultdict(list)
for mistake in misclassifications:
    pattern = f"{mistake['true_label']} → {mistake['predicted_label']}"
    error_patterns[pattern].append(mistake)

# Show each error pattern with examples
for pattern, mistakes in error_patterns.items():
    print(f"\n🔍 ERROR PATTERN: {pattern} ({len(mistakes)} cases)")
    print("-" * 60)

    # Show up to 3 examples per pattern
    for i, mistake in enumerate(mistakes[:3]):
        print(f"\nExample {i+1}:")
        print(f"Sentence: {mistake['sentence']}")
        print(f"Expected: {mistake['true_label']} | Got: {mistake['predicted_label']}")

    if len(mistakes) > 3:
        print(f"\n... and {len(mistakes) - 3} more cases of this pattern")

# CORRECT CLASSIFICATIONS BY CATEGORY (for comparison)
print("\n" + "="*50)
print("CORRECT CLASSIFICATIONS BY CATEGORY")
print("="*50)

correct_by_label = defaultdict(list)
for result in all_results:
    if result['correct']:
        correct_by_label[result['true_label']].append(result)

for label in label_names:
    # defaultdict yields an empty list for labels with no correct prediction.
    correct_cases = correct_by_label[label]
    print(f"\n✅ {label.upper()} - {len(correct_cases)} correct predictions")

    # Show 1-2 examples of correct predictions
    for i, case in enumerate(correct_cases[:2]):
        print(f"  Example {i+1}: {case['sentence']}")

# Guard the rate against an empty evaluation run: dividing by
# len(all_results) directly raises ZeroDivisionError when nothing was evaluated.
n_errors = len(misclassifications)
n_total = len(all_results)
error_rate = n_errors / n_total * 100 if n_total else 0.0
print(f"\n📈 Summary: {n_errors} errors out of {n_total} total ({error_rate:.1f}% error rate)")


MISCLASSIFICATION ANALYSIS

🔍 ERROR PATTERN: advise → false (47 cases)
------------------------------------------------------------

Example 1:
Sentence: Because there is a theoretical basis that these effects may be additive, use of ergotamine-containing or ergot-type medications (like dihydroergotamine or methysergide) and sumatriptan within 24 hours of each other should be avoided. 
Expected: advise | Got: false

Example 2:
Sentence: Because Matulane exhibits some monoamine oxidase inhibitory activity, sympathomimetic drugs, tricyclic antidepressant drugs (e.g., amitriptyline HCl, imipramine HCl) and other drugs and foods with known high tyramine content, such as wine, yogurt, ripe cheese and bananas, should be avoided. 
Expected: advise | Got: false

Example 3:
Sentence: Because Matulane exhibits some monoamine oxidase inhibitory activity, sympathomimetic drugs, tricyclic antidepressant drugs (e.g., amitriptyline HCl, imipramine HCl) and other drugs and foods with known high tyram