<a href="https://colab.research.google.com/github/fedassembly/tensorflow-cert-prep/blob/main/09b_Milestone_Project_2_SkimLit_%F0%9F%93%84%F0%9F%94%A5_Reload_and_test_on_papers_from_the_wild.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import html
import json
import re
from itertools import groupby
from pathlib import Path

import requests
import tensorflow as tf
from bs4 import BeautifulSoup

## Reload model

In [None]:
# Directory holding dev.tfrecord and its .metadata sidecar file.
DATA_DIR = "drive/MyDrive/PubMedProject/data/"

# Sentence role -> class index (the index is the position in the model's output).
LABELS = {
    "BACKGROUND": 0,
    "CONCLUSIONS": 1,
    "METHODS": 2,
    "OBJECTIVE": 3,
    "RESULTS": 4,
}

# Inverse mapping: decode an argmax class index back into its role name.
label_lookup = dict(zip(LABELS.values(), LABELS.keys()))

def tfrecord_to_tensors(example_proto):
  """Decode one serialized example from a TFRecordDataset into (features, label).

  Args:
    example_proto: a scalar string tensor holding one serialized example.

  Returns:
    A ``(features, label)`` tuple. ``features`` is a dict with the scalar
    ``text`` (string), ``line`` and ``total_lines`` (int64) tensors. ``label``
    is a float32 vector of length ``len(LABELS)``.
  """
  spec = dict(
      text=tf.io.FixedLenFeature([], tf.string),
      line=tf.io.FixedLenFeature([], tf.int64),
      total_lines=tf.io.FixedLenFeature([], tf.int64),
      label=tf.io.FixedLenFeature([len(LABELS)], tf.float32),
  )
  parsed = tf.io.parse_single_example(example_proto, spec)
  features = {name: parsed[name] for name in ("text", "line", "total_lines")}
  return features, parsed["label"]

def create_dataset(file_path):
  """Create a parsed tf.data.Dataset from a (possibly sharded) TFRecord file.

  A JSON sidecar with the same stem and a ``.metadata`` suffix must sit next to
  ``file_path``. Its ``"num_examples"`` entry is asserted as the dataset
  cardinality so that ``len()`` and ``.cardinality()`` work. Every file in the
  same directory matching ``<stem>*.tfrecord`` is read.

  Args:
    file_path: path (str or Path) to the TFRecord file, e.g. ``".../dev.tfrecord"``.

  Returns:
    A dataset of ``(features, label)`` pairs produced by ``tfrecord_to_tensors``.

  Raises:
    FileNotFoundError: if the metadata file is missing, or if no TFRecord file
      matches ``<stem>*.tfrecord``.
  """
  input_file = Path(file_path)
  stem = input_file.stem
  with open(input_file.with_suffix(".metadata"), "r", encoding="utf-8") as meta_file:
    n = json.load(meta_file)["num_examples"]
  # Path.glob yields files in filesystem order; sort so shards are always read
  # in the same, reproducible order.
  files = sorted(str(p) for p in input_file.parent.glob(f"{stem}*.tfrecord"))
  if not files:
    raise FileNotFoundError(
        f"No TFRecord files matching '{stem}*.tfrecord' in {input_file.parent}"
    )
  raw_dataset = tf.data.TFRecordDataset(files)
  raw_dataset = raw_dataset.apply(tf.data.experimental.assert_cardinality(n))  # so .cardinality() and len() work
  return raw_dataset.map(tfrecord_to_tensors, num_parallel_calls=tf.data.AUTOTUNE)

# Dev split, batched and prefetched for evaluation.
dev_ds = (
    create_dataset(DATA_DIR + "dev.tfrecord")
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Set the global precision policy before deserializing the model
# (presumably the policy the model was trained with -- confirm).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

MODEL_PATH = "drive/MyDrive/PubMedProject/model/pubmed_abstract.tf"
model = tf.keras.saving.load_model(MODEL_PATH)

In [None]:
# Sanity check that the reloaded model scores on the dev split; left as the
# cell's final expression so the notebook displays the returned values.
model.evaluate(x=dev_ds)



[0.5067348480224609, 0.873593270778656]

## Test the model on abstracts from the wild

In [None]:
import string
from IPython.display import display, HTML


def clean_abstract(abstract):
  """Normalize raw abstract text into one-sentence-per-line model input.

  Steps, in order:
    1. Drop non-ASCII characters.
    2. Collapse all whitespace (including newlines from the page HTML) into
       single spaces, so the only line breaks are the ones inserted in step 3.
    3. Split sentences: a period followed (after optional whitespace) by an
       ASCII uppercase letter or by the end of the text becomes " ." plus a
       newline. This is the same boundary rule that enhance_pubmed_abstract's
       display-text split uses.
    4. Replace every number (``12``, ``1,000``, ``0.05``, ``.5``) with ``@``.
    5. Surround all punctuation except ``@``, ``.`` and ``-`` with spaces.
    6. Collapse runs of the same whitespace character.

  Args:
    abstract: raw abstract text.

  Returns:
    The cleaned text, one sentence per line. Split it with ``str.splitlines()``.
  """
  abstract = abstract.encode("ascii", errors="ignore").decode()
  abstract = " ".join(abstract.split())
  # Also consume the whitespace after the period, so the final period and
  # periods followed by several spaces are separated the same way as the rest.
  abstract = re.sub(r"\.(?=\s*(?:[A-Z]|$))\s*", " .\n", abstract)
  abstract = re.sub(r"(\d+(?:,\d{3})*(?:\.\d+)?|\.\d+)", "@", abstract)
  # Sorted so the character class (and therefore the pattern) is deterministic.
  punctuation = "".join(sorted(set(string.punctuation) - set("@.-")))
  abstract = re.sub(f"([{re.escape(punctuation)}])", r" \1 ", abstract)
  abstract = re.sub(r"(\s)\1+", r"\1", abstract)
  return abstract

class PaperNotRCT(Exception):
    """Raised when a PubMed paper is not tagged "Randomized Controlled Trial".

    Attributes:
      types: the paper's publication types joined as ``A', 'B'`` (including a
        trailing quote), ready to be embedded in the message.
      message: the full error message, also passed to ``Exception``.
    """

    def __init__(self, types):
        joined = "', '".join(types)
        self.types = joined + "'"
        self.message = f"Paper must be 'Randomized Controlled Trial', but it's '{joined}'"
        super().__init__(self.message)

_SENTENCE_END_PATTERN = re.compile(r"(\.(?=\s*(?:[A-Z]|$)))")


def _split_sentences(text):
  """Split `text` into stripped sentences that keep their final period.

  A period ends a sentence when it is followed, after optional whitespace, by
  an ASCII uppercase letter or by the end of the text. Text after the last
  such period (e.g. an abstract that does not end in a period) is kept as a
  final sentence rather than dropped.
  """
  # The pattern has one capturing group, so re.split returns
  # [sentence, ".", sentence, ".", ..., remainder].
  pieces = _SENTENCE_END_PATTERN.split(text)
  sentences = [pieces[i].strip() + pieces[i + 1] for i in range(0, len(pieces) - 1, 2)]
  remainder = pieces[-1].strip()
  if remainder:
    sentences.append(remainder)
  return sentences


def _group_into_sections(labels, sentences):
  """Merge runs of consecutive sentences sharing a label into (label, text) pairs."""
  return [
      (label, " ".join(sentence for _, sentence in run))
      for label, run in groupby(zip(labels, sentences), key=lambda pair: pair[0])
  ]


def enhance_pubmed_abstract(paper_id, *, timeout=30):
  """Fetch a PubMed abstract, label each sentence with the model, and display it.

  Only papers whose PubMed publication types include "Randomized Controlled
  Trial" are processed. If the abstract on PubMed is already divided into
  titled sub-sections, it is shown unchanged and the enhanced version is "N/A".

  Args:
    paper_id: PubMed ID (int or str), appended to https://pubmed.ncbi.nlm.nih.gov/.
    timeout: seconds to wait for the PubMed HTTP request.

  Returns:
    None. The original and enhanced abstracts are rendered with IPython display.

  Raises:
    requests.HTTPError: if PubMed responds with an error status.
    PaperNotRCT: if the paper is not tagged as a Randomized Controlled Trial.
    ValueError: if the page has no English abstract or the abstract is empty.
    AssertionError: if the display sentences and the cleaned model inputs do
      not line up one-to-one.
  """
  response = requests.get(f"https://pubmed.ncbi.nlm.nih.gov/{paper_id}", timeout=timeout)
  response.raise_for_status()
  doc = BeautifulSoup(response.text, "html.parser")

  # A missing publication-types section counts as "no types".
  types_section = doc.find(id="publication-types")
  publication_types = (
      [b.text.strip() for b in types_section.find_all("button")]
      if types_section is not None
      else []
  )
  if "Randomized Controlled Trial" not in publication_types:
    raise PaperNotRCT(publication_types)

  abstract = doc.find(id="eng-abstract")
  if abstract is None:
    raise ValueError(f"PubMed paper {paper_id} has no English abstract")
  if abstract.find_all("strong", {"class": "sub-title"}):
    # Already structured by PubMed: show the original HTML, nothing to enhance.
    display(HTML("<h2>Original:</h2>"))
    display(HTML("".join(str(tag) for tag in abstract.find_all("p"))))
    display(HTML("<h2>Enhanced:</h2>"))
    display(HTML("<p>N/A</p>"))
    return None

  abstract = abstract.text.strip()
  abstract_split = _split_sentences(abstract)  # human-readable sentences for display
  abstract_split_clean = clean_abstract(abstract).splitlines()  # model inputs
  if not abstract_split:
    raise ValueError(f"PubMed paper {paper_id} has an empty abstract")
  assert len(abstract_split) == len(abstract_split_clean), "Number of lines must be the same"

  n_lines = len(abstract_split_clean)
  inputs = {
      "text": tf.constant(abstract_split_clean, dtype=tf.string),
      "line": tf.constant(range(1, n_lines + 1), dtype=tf.int64),
      "total_lines": tf.constant(n_lines, shape=(n_lines,), dtype=tf.int64),
  }
  pred_probas = model.predict(inputs, verbose=0)
  pred_max = tf.argmax(pred_probas, axis=1).numpy()
  labels = [label_lookup[p] for p in pred_max]
  sections = _group_into_sections(labels, abstract_split)

  # The abstract is plain text; escape it so characters like "<" or "&" are
  # shown literally instead of being parsed as HTML.
  enhanced_abstract = "".join(
      f"<p><b>{html.escape(label.title())}:</b> {html.escape(text)}</p>"
      for label, text in sections
  )
  display(HTML("<h2>Original:</h2>"))
  display(HTML(html.escape(abstract)))
  display(HTML("<h2>Enhanced:</h2>"))
  display(HTML(enhanced_abstract))

In [None]:
# Run the enhancer on a PubMed abstract, by ID.
enhance_pubmed_abstract(paper_id=31471173)

In [None]:
# Run the enhancer on another PubMed abstract, by ID.
enhance_pubmed_abstract(paper_id=24633056)

In [None]:
# This paper is not tagged as an RCT, so the call raises PaperNotRCT.
enhance_pubmed_abstract(paper_id=29902546)

PaperNotRCT: ignored