In [15]:
import math
from datetime import datetime
import json
import torch

from tqdm import tqdm
import datasets
from collections import Counter
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gc

In [None]:
# install transformers
!pip install transformers &> /dev/null
!pip install datasets &> /dev/null

In [39]:
MODEL_NAME = 'large'
USE_SECTION = False
INCLUDE_ID = True

PROMPT_TEMPLATE = """{}\nQuestion: Does this imply that {}?\nOPTIONS:\nEntailment\nContradiction"""

# Load dataset


we use third party dataset loader: [here](https://huggingface.co/datasets/bigbio/sem_eval_2024_task_2/viewer/sem_eval_2024_task_2_source)


In [None]:
# huggingface dataset
annotations = datasets.load_dataset("bigbio/sem_eval_2024_task_2", name="sem_eval_2024_task_2_source")
raw_texts = datasets.load_dataset("bigbio/sem_eval_2024_task_2", name="sem_eval_2024_task_2_ct")['train']

# Data Preprocessing

In [3]:
# convert CRT to premise, hypothesis and label sets
# use_section: controls whether to extract a specific section or full text
# include_id: controls whether to include "section_title" in text when extracting full text

def ctr_to_full_text(ctr, include_id=False):
  """extract full text from ctr
     include_id: whether to include seciton id for each section sentence list
  """
  if include_id:
    intervention = ["intervention:"] + ctr.get('intervention', [])
    eligibility = ["eligibility:"] + ctr.get('eligibility', [])
    adverse_events = ["adverse_events:"] + ctr.get('adverse_events', [])
    results = ["results:"] + ctr.get('results', [])
  else:
    intervention = ctr.get('intervention', [])
    eligibility = ctr.get('eligibility', [])
    adverse_events = ctr.get('adverse_events', [])
    results = ctr.get('results', [])
  return "\n".join(intervention + eligibility + adverse_events + results)

def get_premise_hypothesis(sample, ctrs, use_section=False, include_id=True):
  """get premise, hypothesis, label, type from a train sample
     use_section: whether to export full ctr or section only
     full_text_include_id: when exporting full ctr, whether to include section id inside presmise
  """
  sample_type = sample["type"]
  section_id = sample["section_id"].lower().replace(" ", "_")

  primary_ctr = ctrs[sample["primary_id"]]
  if use_section:
    primary_text = "\n".join(primary_ctr[section_id])
  else:
    primary_text = ctr_to_full_text(primary_ctr, include_id)

  if sample_type == "Comparison":
    secondary_ctr = ctrs[sample["secondary_id"]]
    if use_section:
      secondary_text = "\n".join(secondary_ctr[section_id])
    else:
      secondary_text = ctr_to_full_text(secondary_ctr, include_id)
    premise = (f"Primary trial evidence are {primary_text}\n and Secondary "
               + f"trial evidence are {secondary_text}.")
  else:
    premise = (f"Primary trial evidence are {primary_text}.")

  hypothesis = sample['statement']
  label = sample['label']
  return premise, hypothesis, label, sample_type

def get_premise_hypothesis_by_section(sample, ctrs):
  """get premise, hypothesis, label, type from a train sample
     use_section: whether to export full ctr or section only
     full_text_include_id: when exporting full ctr, whether to include section id inside presmise
  """

  premises = {}

  primary_ctr = ctrs[sample["primary_id"]]
  sample_type = sample["type"]
  if sample_type == "Comparison":
    secondary_ctr = ctrs[sample["secondary_id"]]
  else:
    secondary_ctr = None

  for section_id in ['intervention', 'eligibility', 'results', 'adverse_events']:
    section_id = section_id.lower().replace(" ", "_")
    primary_text = "\n".join(primary_ctr[section_id])
    if sample_type == "Comparison":
      secondary_text = "\n".join(secondary_ctr[section_id])
      premise = (f"Primary trial evidence are {primary_text}\n and Secondary "
                + f"trial evidence are {secondary_text}.")
    else:
      premise = (f"Primary trial evidence are {primary_text}.")
    premises[section_id] = premise

  hypothesis = sample['statement']
  label = sample['label']

  return premises, hypothesis, label, sample_type

In [2]:
def get_prompt(premise, hypothesis):
  # prompt = """{}\nQuestion: Does this imply that {}?\nOPTIONS:\nEntailment\nContradiction"""
  prompt = PROMPT_TEMPLATE
  return prompt.format(premise, hypothesis)

In [4]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

In [13]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [5]:
from main import load_data

In [6]:
annotations, id_to_clinical_trial_record = load_data()

In [7]:
# use model/tokenizer to run a prompt.
# output_max_length: controls output token limit. Default will use max_token generic and is only 20.
# recommend 512 or model max to generate long sentences.
def answer(prompt, output_max_length=None):
  inputs = tokenizer(prompt, return_tensors="pt")
  outputs = model.generate(**inputs, max_new_tokens=output_max_length)
  return tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [8]:
def ans_to_label(result):
  if result == 'yes':
    return 'entailment'
  elif result == 'no':
    return 'contradiction'
  return result

In [None]:
# generate premise chunks with sliding window method
# make sure stride is large enough so we don't run too many sections per sample
# 256 is probably the smallest you should go. max is 512. *4 for word counts.
def get_windowed_premise(premise, hypothesis, stride=4*256):
    # use 4-1 word-token ratio (english)
    window_size = 4*512 - len(hypothesis) - len(PROMPT_TEMPLATE) - 4
    sections = []
    total_length = len(premise)
    start = 0
    while start < total_length:
      end = min(start + window_size, total_length)
      sections.append(premise[start: end])
      start += stride
    return sections

In [None]:
# generate summarization for text. Long texts will be chunked first then joined.
# concatenate: whether to join chunk summaries by concatenation (true) or LLM summarization again (false)
def get_summarization(premise, concatenate=True):
    tmp = []
    total_length = len(premise)
    chunk_size = math.floor(512 * 4 * 0.9) # estimate chunk size. 4-1 word-token ratio (English) is assumed.
    template = "Please summarize the following and include important details as much as possible: {}"

    start = 0
    while start < total_length:
        end = min(start + chunk_size, total_length)
        # output_max_length controls how many words are generated.
        chunk_summary = answer(template.format(premise[start: end]), output_max_length=512)[0].strip()
        tmp.append(chunk_summary)
        start += chunk_size
    if concatenate:
        return '. '.join(tmp), tmp
    else:
        return answer(template.format('. '.join(tmp)), output_max_length=512)[0], tmp

In [40]:
# test

sample = annotations['validation'][1]
premise, hypothesis, label, sample_type = get_premise_hypothesis(sample, id_to_clinical_trial_record,
                                                                 use_section=USE_SECTION, include_id=INCLUDE_ID)
print(label)

Entailment


In [35]:
set([x['section_id'] for x in annotations['validation']])

{'Adverse Events', 'Eligibility', 'Intervention', 'Results'}

# Section ID Extraction

In [11]:
extraction_prompt = "A premise contains four sections: {}\nA hypothesis describes one of the sections: {}\nDetermine the most relevant section from the four options: intervention, results, eligibility, adverse_events"

In [14]:
def test(sample_idx):
    sample = annotations['validation'][sample_idx]
    premise, hypothesis, label, sample_type = get_premise_hypothesis(sample, id_to_clinical_trial_record,
                                                                 use_section=False, include_id=True)
    prompt = extraction_prompt.format(premise, hypothesis)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(**inputs)
    pred = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip().lower()

    if pred == 'interventions':
        pred = 'intervention'

    true_section_id = sample['section_id'].lower().replace(" ", "_")
    return pred, true_section_id

In [44]:
torch.cuda.empty_cache()

In [16]:
result = []
too_large = []
for i, item in enumerate(tqdm(annotations['validation'])):
   try:
      pred, label = test(i) 
   except RuntimeError as e:
      if "out of memory" in str(e):
         too_large.append(i)
         gc.collect()
         torch.cuda.empty_cache()
         continue
      
   result.append((pred, label))

100%|██████████| 200/200 [06:08<00:00,  1.84s/it]


In [25]:
from sklearn.metrics import classification_report

In [33]:
print(classification_report([x[1] for x in result], [x[0] for x in result]))

                precision    recall  f1-score   support

adverse events       0.89      0.63      0.74        52
   eligibility       1.00      0.50      0.67        56
  intervention       0.46      0.81      0.59        36
       results       0.69      0.89      0.78        56

      accuracy                           0.70       200
     macro avg       0.76      0.71      0.69       200
  weighted avg       0.79      0.70      0.70       200



# Try Summarization

In [6]:
sample_section = '''intervention:
INTERVENTION 1:
  Dovitinib
  Dovitinib 500 mg single oral dose for 5 consecutive days, followed by a 2-day rest period (5 days on/2 days off schedule) for every 28 day cycle.
eligibility:
Inclusion Criteria:
  Patients have histological confirmation of breast carcinoma with a clinical diagnosis of IBC based on presence of inflammatory changes in the involved breast, including diffuse erythema and edema (peau d orange), with or without an underlying palpable mass involving the majority of the skin of the breast. Pathological evidence of dermal lymphatic invasion should be noted but is not required for diagnosis.
  Patients have stage IV disease with local or distant relapse
  Patients have negative HER2 expression by IHC (defined as 0 or1+), or fluorescence in situ hybridization (FISH). If HER2 is 2+, negative HER2 expression must be confirmed by FISH.
  Patients are able to swallow and retain oral medication.
  Patients have Eastern Cooperative Oncology Group (ECOG) performance status 0-2.
  Patients have received two or more standard chemotherapies for metastatic disease and have relapsed.
  Patients have ability and willingness to sign written informed consent.
  Patients are 18 years of age or older.
  Female patients of childbearing potential (A female not free from menses > 2 years or not surgically sterilized) must be willing to use highly effective contraception to prevent pregnancy or agree to abstain from heterosexual activity throughout the study. Highly effective contraception, defined as male condom with spermicide, diaphragm with spermicide, intra-uterine device. Highly effective contraception must be used by both sexes during the study and must be continued for 8 weeks after the end of study treatment. Oral, implantable, or injectable contraceptives may be affected by cytochrome P450 interactions, and are therefore not considered effective for this study.
  Female patients of childbearing potential must have negative serum pregnancy test </=14 days prior to starting study treatment.
  If Patients have been treated with anti-vascular endothelial growth factor (VEGF) agents, such as Bevacizumab, last dose must be > 4 weeks.
  Patients have biopsy tissue of the metastatic disease (including chest wall or regional nodes) available (paraffin blocks or up to 20 unstained slides), if no biopsy tissue available, a biopsy (or thoracentesis if patient has pleural effusion only) of the metastatic disease will be performed to confirm the diagnoses.
  Serum total bilirubin must be within Upper Limited Normal (T. Bilirubin upper limit of normal (ULN)=1.0 mg/dl)
  AST and ALT must be < 2.5 x ULN(with or without liver metastases).
Exclusion Criteria:
  Patients are receiving concurrent anti-cancer therapy (chemotherapy, immunotherapy, radiation therapy and biological therapy) while taking study medication.
  Cirrhosis of liver, or known hepatitis B or C infection have hepatic impairment Child-Pugh Score of B or worse.
  Absolute neutrophil count (ANC) < 1.5
  Patients have an active infection and require IV or oral antibiotics.
  Impaired cardiac function or clinically significant cardiac diseases, including any of the following: a) History or presence of serious uncontrolled ventricular arrhythmias or presence of atrial fibrillation; b) Clinically significant resting bradycardia (< 50 beats per minute); c) left ventricular ejection fraction (LVEF) assessed by 2-D echocardiogram (ECHO) < 50% or lower limit of normal (which ever is higher) or multiple gated acquisition scan (MUGA) < 45% or lower limit of normal (which ever is higher). d) Any of the following within 6 months prior to study entry: myocardial infarction (MI), severe/unstable angina, Coronary Artery Bypass Graft (CABG), Congestive Heart Failure (CHF), Cerebrovascular Accident (CVA), Transient Ischemic Attack (TIA), Pulmonary Embolism (PE); e) Uncontrolled hypertension defined by an SBP>150 and/or a diastolic blood pressure (DBP)>100 mm Hg with or without anti-hypertensive medication.
  History of gastrointestinal disorders (medical disorders or extensive surgery) which may interfere with the absorption of the study drug.
  Patients have a concurrent disease or condition that would make them inappropriate for study participation, or any serious medical disorder that would interfere with patients safety.
  Patients with only locally or regionally confined disease without evidence of metastatic disease.'''

In [None]:
concat_summary, _ = get_summarization(sample_section, concatenate=True)
print(concat_summary)

In [None]:
llm_summary, _ = get_summarization(sample_section, concatenate=False)
print(llm_summary)

*Try directly summary an entire section (obviously over token limit)*

In [None]:
answer('''Please summarize the following and include important details as much as possible: Primary trial evidence are intervention:
INTERVENTION 1:
  Dovitinib
  Dovitinib 500 mg single oral dose for 5 consecutive days, followed by a 2-day rest period (5 days on/2 days off schedule) for every 28 day cycle.
eligibility:
Inclusion Criteria:
  Patients have histological confirmation of breast carcinoma with a clinical diagnosis of IBC based on presence of inflammatory changes in the involved breast, including diffuse erythema and edema (peau d orange), with or without an underlying palpable mass involving the majority of the skin of the breast. Pathological evidence of dermal lymphatic invasion should be noted but is not required for diagnosis.
  Patients have stage IV disease with local or distant relapse
  Patients have negative HER2 expression by IHC (defined as 0 or1+), or fluorescence in situ hybridization (FISH). If HER2 is 2+, negative HER2 expression must be confirmed by FISH.
  Patients are able to swallow anad retain oral medication.
  Patients have Eastern Cooperative Oncology Group (ECOG) performance status 0-2.
  Patients have received two or more standard chemotherapies for metastatic disease and have relapsed.
  Patients have ability and willingness to sign written informed consent.
  Patients are 18 years of age or older.
  Female patients of childbearing potential (A female not free from menses > 2 years or not surgically sterilized) must be willing to use highly effective contraception to prevent pregnancy or agree to abstain from heterosexual activity throughout the study. Highly effective contraception, defined as male condom with spermicide, diaphragm with spermicide, intra-uterine device. Highly effective contraception must be used by both sexes during the study and must be continued for 8 weeks after the end of study treatment. Oral, implantable, or injectable contraceptives may be affected by cytochrome P450 interactions, and are therefore not considered effective for this study.
  Female patients of childbearing potential must have negative serum pregnancy test </=14 days prior to starting study treatment.
  If Patients have been treated with anti-vascular endothelial growth factor (VEGF) agents, such as Bevacizumab, last dose must be > 4 weeks.
  Patients have biopsy tissue of the metastatic disease (including chest wall or regional nodes) available (paraffin blocks or up to 20 unstained slides), if no biopsy tissue available, a biopsy (or thoracentesis if patient has pleural effusion only) of the metastatic disease will be performed to confirm the diagnoses.
  Serum total bilirubin must be within Upper Limited Normal (T. Bilirubin upper limit of normal (ULN)=1.0 mg/dl)
  AST and ALT must be < 2.5 x ULN(with or without liver metastases).
Exclusion Criteria:
  Patients are receiving concurrent anti-cancer therapy (chemotherapy, immunotherapy, radiation therapy and biological therapy) while taking study medication.
  Cirrhosis of liver, or known hepatitis B or C infection have hepatic impairment Child-Pugh Score of B or worse.
  Absolute neutrophil count (ANC) < 1.5
  Patients have an active infection and require IV or oral antibiotics.
  Impaired cardiac function or clinically significant cardiac diseases, including any of the following: a) History or presence of serious uncontrolled ventricular arrhythmias or presence of atrial fibrillation; b) Clinically significant resting bradycardia (< 50 beats per minute); c) left ventricular ejection fraction (LVEF) assessed by 2-D echocardiogram (ECHO) < 50% or lower limit of normal (which ever is higher) or multiple gated acquisition scan (MUGA) < 45% or lower limit of normal (which ever is higher). d) Any of the following within 6 months prior to study entry: myocardial infarction (MI), severe/unstable angina, Coronary Artery Bypass Graft (CABG), Congestive Heart Failure (CHF), Cerebrovascular Accident (CVA), Transient Ischemic Attack (TIA), Pulmonary Embolism (PE); e) Uncontrolled hypertension defined by an SBP>150 and/or a diastolic blood pressure (DBP)>100 mm Hg with or without anti-hypertensive medication.
  History of gastrointestinal disorders (medical disorders or extensive surgery) which may interfere with the absorption of the study drug.
  Patients have a concurrent disease or condition that would make them inappropriate for study participation, or any serious medical disorder that would interfere with patients safety.
  Patients with only locally or regionally confined disease without evidence of metastatic disease.
       ''')

**Use summarizaiton technique to run through dev set**

In [None]:
def run_eval(method="base", include_id=INCLUDE_ID):
  logs = [] # record detailed logs/inner loop results
  acc = [] # record [pred, label] pairs

  logs.append(f'method: {method}, include_id: {include_id}')

  for i, instance in enumerate(tqdm(annotations['validation'])):

    logs.append(f'sample: {i}')
    premise, hypothesis, label, sample_type = get_premise_hypothesis(instance, id_to_clinical_trial_record,
                                                                    use_section=False, include_id=include_id)
    if method == 'base':
      prompt_from_sample = get_prompt(premise, hypothesis)
      final_result = ans_to_label(answer(prompt_from_sample)[0].lower())
    elif method == 'sliding_window':
      # stride size should be large enough to generate only a few sections per sample but smaller than 512
      windowed_premises = get_windowed_premise(premise, hypothesis, stride=4*364)
      tmp = []
      for j, _premise in enumerate(windowed_premises):
        _result = ans_to_label(answer(get_prompt(_premise, hypothesis))[0].lower())
        tmp.append(_result)
        logs.append(f'{j}: {_result}')
      counts = Counter(tmp)
      final_result = max(counts, key=counts.get)
    elif method == 'summarize':
      # get summary for each section
      premises_by_section, _, _, _ = get_premise_hypothesis_by_section(instance, id_to_clinical_trial_record)
      final_premise = []
      for section_id, premise in premises_by_section.items():
        section_summary, _ = get_summarization(premise, concatenate=False)
        final_premise.append(section_id + ": " + section_summary)

      final_premise = '\n'.join(final_premise)
      final_result = ans_to_label(answer(get_prompt(final_premise, hypothesis)))
    elif method == 'summarize-concat':
      premises_by_section, _, _, _ = get_premise_hypothesis_by_section(instance, id_to_clinical_trial_record)
      final_premise = []
      for section_id, premise in premises_by_section.items():
        section_summary, _ = get_summarization(premise, concatenate=True)
        final_premise.append(section_id + ": " + section_summary)

      final_premise = '\n'.join(final_premise)
      final_result = ans_to_label(answer(get_prompt(final_premise, hypothesis)))

    acc.append([final_result, label.lower()])
    logs.append(f'final: {final_result}, label: {label}')

  print('acc: ')
  accuracy = sum([1 for x in acc if x[0] == x[1]]) / len(acc)
  print(accuracy)

  timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
  printed = {
    "acc": accuracy,
    "result": acc,
    "logs": logs
  }

  with open(f'result_{timestamp_str}.json', 'w') as f:
      json.dump(printed, f)


# Caching models using gdrive?


https://gist.github.com/tueda/ec0181f0f4c8961d49dc659f79cbfd4a

In [None]:
print(premise)

# Reproduction from SEMEVAL 2023 FlanT5 paper

https://github.com/kamalkraj/NLI4CT

In [None]:
import os
import json

import torch
import tqdm

from transformers import T5Tokenizer, T5ForConditionalGeneration

In [1]:
len('''Patients have histological confirmation of breast carcinoma with a clinical diagnosis of IBC based on presence of inflammatory changes in the involved breast, including diffuse erythema and edema (peau d orange), with or without an underlying palpable mass involving the majority of the skin of the breast. Pathological evidence of dermal lymphatic invasion should be noted but is not required for diagnosis.
  Patients have stage IV disease with local or distant relapse''')

471

In [2]:
471/4

117.75