In [None]:
import torch

from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gc
from sklearn.metrics import classification_report

from main import load_data

In [None]:
# Run configuration for this notebook.
# NOTE(review): MODEL_NAME presumably names the flan-t5 checkpoint size
# (the model cell loads "google/flan-t5-large") - confirm.
MODEL_NAME = 'large'
# NOTE(review): the next two names mirror the use_section / include_id
# parameters of get_premise_hypothesis, but the test() cell passes literal
# values for those parameters rather than these constants - confirm intent.
USE_SECTION = False
INCLUDE_ID = True

In [None]:
# Convert a CTR (clinical trial record) into premise, hypothesis and label sets.
# use_section: controls whether to extract a specific section or the full text
# include_id: controls whether to include the section title in the text when extracting the full text

def ctr_to_full_text(ctr, include_id=False):
  """Flatten a clinical trial record (CTR) into one newline-joined string.

  Sections are emitted in the fixed order intervention, eligibility,
  adverse_events, results. A section missing from ``ctr`` contributes no
  sentences.

  ctr: mapping from section name to a list of sentence strings
  include_id: when true, every section is preceded by a "<section>:" header
    line; the header is emitted even when the section itself is missing
  """
  per_section = []
  for name in ('intervention', 'eligibility', 'adverse_events', 'results'):
    sentences = ctr.get(name, [])
    per_section.append([f"{name}:"] + sentences if include_id else sentences)

  # Concatenate the per-section lists left to right, keeping section order.
  merged, *remaining = per_section
  for chunk in remaining:
    merged = merged + chunk
  return "\n".join(merged)

def get_premise_hypothesis(sample, ctrs, use_section=False, include_id=True):
  """Build (premise, hypothesis, label, type) for one annotated sample.

  sample: mapping read for "type", "section_id", "primary_id", "statement",
    "label" and, for "Comparison" samples, "secondary_id"
  ctrs: mapping from trial id to its clinical trial record
  use_section: when true the premise holds only the section named by the
    sample's section_id; otherwise the full record text is used
  include_id: forwarded to ctr_to_full_text when the full record is used,
    controlling whether section header lines appear in the premise
  """
  kind = sample["type"]
  # Section ids are lower-cased and spaces become underscores before lookup.
  section_key = sample["section_id"].lower().replace(" ", "_")

  def render(record):
    # Either the single requested section or the whole record as text.
    if not use_section:
      return ctr_to_full_text(record, include_id)
    return "\n".join(record[section_key])

  primary_text = render(ctrs[sample["primary_id"]])

  if kind != "Comparison":
    premise = f"Primary trial evidence are {primary_text}."
  else:
    secondary_text = render(ctrs[sample["secondary_id"]])
    premise = (f"Primary trial evidence are {primary_text}\n and Secondary "
               f"trial evidence are {secondary_text}.")

  return premise, sample['statement'], sample['label'], kind

def get_premise_hypothesis_by_section(sample, ctrs):
  """Build one premise per CTR section for a single annotated sample.

  sample: mapping read for "primary_id", "type", "statement", "label" and,
    for "Comparison" samples, "secondary_id"
  ctrs: mapping from trial id to its clinical trial record; each record
    must contain all four sections

  Returns (premises, hypothesis, label, type), where premises maps each of
  'intervention', 'eligibility', 'results' and 'adverse_events' to the
  premise text built from that section alone.
  """
  primary = ctrs[sample["primary_id"]]
  kind = sample["type"]
  is_comparison = kind == "Comparison"
  # The secondary record is only looked up for comparison samples.
  secondary = ctrs[sample["secondary_id"]] if is_comparison else None

  def build(section):
    primary_text = "\n".join(primary[section])
    if not is_comparison:
      return f"Primary trial evidence are {primary_text}."
    secondary_text = "\n".join(secondary[section])
    return (f"Primary trial evidence are {primary_text}\n and Secondary "
            f"trial evidence are {secondary_text}.")

  section_order = ('intervention', 'eligibility', 'results', 'adverse_events')
  premises = {section: build(section) for section in section_order}

  return premises, sample['statement'], sample['label'], kind

In [None]:
# Build the checkpoint id once from the MODEL_NAME config cell so the model
# and tokenizer cannot drift apart; with MODEL_NAME = 'large' this resolves
# to "google/flan-t5-large".
checkpoint = f"google/flan-t5-{MODEL_NAME}"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Use the first CUDA device when one is available, otherwise stay on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [None]:
# Load the annotations and the trial-record lookup through main.load_data().
# Later cells index annotations['validation'] by position and pass
# id_to_clinical_trial_record as the `ctrs` mapping, which is keyed by
# sample["primary_id"] / sample["secondary_id"].
annotations, id_to_clinical_trial_record = load_data()

In [None]:
# Prompt template with two positional slots, filled by test() as
# .format(premise, hypothesis); it asks the model to name one of the four
# section ids listed at the end of the template.
extraction_prompt = "A premise contains four sections: {}\nA hypothesis describes one of the sections: {}\nDetermine the most relevant section from the four options: intervention, results, eligibility, adverse_events"

In [None]:
def test(sample_idx):
    """Predict which CTR section a validation statement refers to.

    Builds the full-text premise (with section header lines) for
    annotations['validation'][sample_idx], fills extraction_prompt with the
    premise and the statement, and lets the model generate a section name.

    sample_idx: position of the sample inside annotations['validation']

    Returns (pred, true_section_id). Both values are lower-cased with spaces
    replaced by underscores, so that an output such as "adverse events" is
    comparable with the "adverse_events" label.
    """
    sample = annotations['validation'][sample_idx]
    premise, hypothesis, _, _ = get_premise_hypothesis(
        sample, id_to_clinical_trial_record, use_section=False, include_id=True)
    prompt = extraction_prompt.format(premise, hypothesis)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # Normalise the prediction exactly like the ground-truth section id below.
    pred = decoded.strip().lower().replace(" ", "_")

    # observed variations
    if pred == 'interventions':
        pred = 'intervention'

    true_section_id = sample['section_id'].lower().replace(" ", "_")
    return pred, true_section_id

In [None]:
# Run the section-prediction test over every validation sample.
# result    - (pred, label) pairs for samples that were processed
# too_large - indices of samples skipped because inference ran out of memory
result = []
too_large = []
for i, _sample in enumerate(tqdm(annotations['validation'])):
    ran_out_of_memory = False
    try:
        pred, label = test(i)
    except RuntimeError as e:
        # Only out-of-memory failures are skipped. Any other RuntimeError is
        # re-raised instead of falling through and recording a stale
        # (pred, label) pair left over from the previous iteration.
        if "out of memory" not in str(e):
            raise
        ran_out_of_memory = True

    if ran_out_of_memory:
        # Clean up outside the except clause: while the handler is active the
        # exception's traceback still references the frames that own the
        # tensors, so the memory could not be released there.
        too_large.append(i)
        gc.collect()
        torch.cuda.empty_cache()
        continue

    result.append((pred, label))

In [None]:
#TODO: handle "unknown"
# Each entry of result is (pred, label); sklearn expects y_true first.
y_true = [label for _, label in result]
y_pred = [pred for pred, _ in result]
print(classification_report(y_true, y_pred))