# All: fine-tuning and evaluating T5 (SciFive) on MedNLI

## Set up

In [None]:
print("Installing dependencies...")
# Colab-only magic selecting the TF 2.x runtime.
# NOTE(review): this magic may not be supported on current Colab images -- verify, and drop it if it errors.
%tensorflow_version 2.x
# NOTE(review): unpinned install; pin the t5 version this notebook was developed against for reproducibility.
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

In [None]:
ON_CLOUD = True

# TPU slice shape passed to MtfModel when building the model.
TPU_TOPOLOGY = "v3-8"
# Defined up front so later cells do not hit a NameError when ON_CLOUD is False.
# NOTE(review): off-cloud, MtfModel then receives tpu=None -- confirm that is the intended fallback.
TPU_ADDRESS = None

if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # RuntimeError (not BaseException) so ordinary `except Exception` handlers
    # and tooling treat this as a normal error; keep the original cause attached.
    raise RuntimeError(
        'ERROR: Not connected to a TPU runtime; switch the Colab runtime type '
        'to TPU and re-run this notebook.') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# Switch TF into v1 (graph-mode) behavior for the rest of the notebook.
# NOTE(review): presumably required by the Mesh TF based MtfModel used below -- confirm.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Stop TF's logger from forwarding records to the Python root logger
  # (presumably to avoid duplicated log lines in Colab).
  tf.get_logger().propagate = False
  # Let INFO-level records through the Python root logger.
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily change TF's logging verbosity for the duration of a `with` block.

  Args:
    level: Verbosity passed to `tf.logging.set_verbosity` (e.g. 'ERROR').

  Yields:
    None. The previous verbosity is restored on exit -- including when the
    body raises -- so a failed call does not leave logging silenced.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Always restore, otherwise an exception inside the block would leave
    # the temporary level in effect for the rest of the session.
    tf.logging.set_verbosity(og_level)


In [None]:
import gin
import subprocess
# Load the operative gin config before the model is built.
# NOTE(review): per its filename this is the "base" config, and it is loaded
# regardless of MODEL_SIZE chosen later in the notebook.
_operative_gin_path = 'gs://t5_training/t5-data/config/pretrained_models_google_base_operative_config.gin'
gin.parse_config_file(_operative_gin_path)


## Register Tasks

### MedNLI

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load raw MedNLI examples for one split as a tf.data.Dataset.

    Args:
      split: 'train' reads train.tsv; every other value (including
        'validation') reads test.tsv.
      shuffle_files: accepted for the T5 dataset_fn interface; ignored.

    Returns:
      A dataset of {"input": ..., "target": ...} dicts, one per
      tab-separated line of the chosen file.
    """
    del shuffle_files  # Unused: each split is a single file read in order.
    source_tsv = (
        'gs://scifive/finetune/mednli/train.tsv' if split == 'train'
        else 'gs://scifive/finetune/mednli/test.tsv')
    lines = tf.data.TextLineDataset([source_tsv])

    # Each line is "<input>\t<target>"; quote characters are kept literally.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    fields = lines.map(parse_line,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # (input, target) tuple -> {"input": ..., "target": ...} dict.
    return fields.map(lambda inp, tgt: {"input": inp, "target": tgt})

def ner_preprocessor(ds, prefix="mednli: "):
  """Turn raw MedNLI examples into T5 text-to-text features.

  The "ner" name is kept because the task registration refers to it; the
  preprocessor itself is the MedNLI one.

  Args:
    ds: Dataset of {"input": ..., "target": ...} string dicts (presumably
      the output of `dumping_dataset`).
    prefix: Task prefix prepended to every input. Defaults to "mednli: ",
      the value previously hard-coded.

  Returns:
    Dataset of {"inputs": prefix + input, "targets": target} dicts.
  """
  def normalize_text(text):
    # Identity hook: no normalization is applied at the moment.
    return text

  def to_inputs_and_targets(ex):
    """Map {"input": ..., "target": ...} -> {"inputs": prefix + input, "targets": target}."""
    return {
        "inputs": tf.strings.join([prefix, normalize_text(ex["input"])]),
        "targets": normalize_text(ex["target"]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Sanity check: show a few raw (pre-preprocessing) examples from the train split.
# The message previously said "validation" although the train split is read.
print("A few raw train examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration so this cell can be re-run.
t5.data.TaskRegistry.remove('mednli')
t5.data.TaskRegistry.add(
    "mednli",
    splits=["train", "validation"],
    # Raw {"input", "target"} examples read from the MedNLI TSV files.
    dataset_fn=dumping_dataset,
    # Prefixes inputs with "mednli: " and renames fields to inputs/targets.
    text_preprocessor=[ner_preprocessor],
    # Targets are compared as-is: no postprocessing function is configured.
    # output_features is not passed, so the registry's default is used.
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
)

## Mixtures

In [None]:
# (Re-)register the "all_bioT5" mixture; removing first lets the cell re-run.
# It holds only the MedNLI task at default rate 1.0.
# NOTE(review): fine-tuning targets "mednli" directly, so this mixture is not
# referenced elsewhere in this notebook.
t5.data.MixtureRegistry.remove("all_bioT5")
t5.data.MixtureRegistry.add("all_bioT5", ["mednli"], default_rate=1.0)

## Define Model

In [None]:
# !gsutil -m rm -r {MODEL_DIR}

In [None]:
# Model size: selects the checkpoint subdirectory and the per-size settings below.
MODEL_SIZE = "base"

# Pretrained checkpoint to fine-tune from. Alternatives that were tried:
#   gs://t5-data/pretrained_models           (original T5, wiki + books)
#   gs://t5_training/models/bio/pmc_v1
#   gs://t5_training/models/bio/pubmed_v2
BASE_PRETRAINED_DIR = "gs://t5_training/models/export_models/bio/pmc_v4_1200k"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)

# Output directory for fine-tuned checkpoints (alternative: .../bio/re_v2).
# Built from a separate base constant so re-running this cell does not keep
# appending MODEL_SIZE to MODEL_DIR.
BASE_MODEL_DIR = "gs://t5_training/models/bio/mednli_pmc_v4"
MODEL_DIR = os.path.join(BASE_MODEL_DIR, MODEL_SIZE)

# Per-size (model_parallelism, train_batch_size, keep_checkpoint_max).
# NOTE(review): these were originally sized for a v2-8 TPU and a ~5GB
# checkpoint budget, while TPU_TOPOLOGY is "v3-8" -- confirm they still fit.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128*2, 8),
    "large": (8, 64*2, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)

# Mesh TensorFlow Transformer model, the implementation behind the T5 paper.
model = t5.models.MtfModel(
    # TPU placement and loop granularity.
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    iterations_per_loop=100,
    # Training shape: token budget per feature, batch size, learning rate.
    sequence_length={"inputs": 256, "targets": 15},
    batch_size=train_batch_size,
    learning_rate_schedule=0.001,
    # Checkpointing: location, frequency, and retention limit.
    model_dir=MODEL_DIR,
    save_checkpoints_steps=1000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
)


## Finetune

In [None]:
# Number of fine-tuning steps run on top of the pretrained checkpoint.
FINETUNE_STEPS = 45000

model.finetune(
    mixture_or_task_name="mednli",
    finetune_steps=FINETUNE_STEPS,
    pretrained_model_dir=PRETRAINED_DIR,
)

## Predict

In [None]:
# One [data directory, task name] pair per task to predict on.
tasks = [["mednli", "mednli"]]
# Subdirectory (per model) under which prediction files are written.
output_dir = "mednli_pmc_v4"

In [None]:
import tensorflow.compat.v1 as tf
# question_1 = "Emerin is a nuclear membrane protein which is missing or defective in Emery-Dreifuss muscular dystrophy (EDMD). It is one member of a family of lamina-associated proteins which includes LAP1, LAP2 and lamin B receptor (LBR). A panel of 16 monoclonal antibodies (mAbs) has been mapped to six specific sites throughout the emerin molecule using phage-displayed peptide libraries and has been used to localize emerin in human and rabbit heart. Several mAbs against different emerin epitopes did not recognize intercalated discs in the heart, though they recognized cardiomyocyte nuclei strongly, both at the rim and in intranuclear spots or channels. A polyclonal rabbit antiserum against emerin did recognize both nuclear membrane and intercalated discs but, after affinity purification against a pure-emerin band on a western blot, it stained only the nuclear membrane. These results would not be expected if immunostaining at intercalated discs were due to a product of the emerin gene and, therefore, cast some doubt upon the hypothesis that cardiac defects in EDMD are caused by absence of emerin from intercalated discs. Although emerin was abundant in the membranes of cardiomyocyte nuclei, it was absent from many non-myocyte cells in the heart. This distribution of emerin was similar to that of lamin A, a candidate gene for an autosomal form of EDMD. In contrast, lamin B1 was absent from cardiomyocyte nuclei, showing that lamin B1 is not essential for localization of emerin to the nuclear lamina. Lamin B1 is also almost completely absent from skeletal muscle nuclei. In EDMD, the additional absence of lamin B1 from heart and skeletal muscle nuclei which already lack emerin may offer an alternative explanation of why these tissues are particularly affected.." 
# question_2 = "Molecular analysis of the APC gene in 205 families: extended genotype-phenotype correlations in FAP and evidence for the role of APC amino acid changes in colorectal cancer predisposition." 
# question_3 = "Who are the 4 members of The Beatles?" 
# question_4 = "How many teeth do humans have?"

# questions = [question_2]

def _checkpoint_step(path):
  """Return the checkpoint step appended to a prediction file path.

  Prediction files are named "<output_file>-<step>"; a path whose final
  "-" suffix is not all digits maps to -1 so it never wins as "latest".
  """
  suffix = path.rsplit("-", 1)[-1]
  return int(suffix) if suffix.isdigit() else -1


for data_dir, task in tasks:
  input_file = task + '_predict_input.txt'
  output_file = task + '_predict_output.txt'

  # Inputs are read from, and predictions written under, the bio_data bucket.
  # NOTE(review): the Scoring section copies predictions from
  # gs://scifive/finetune/... instead -- confirm which bucket is intended.
  predict_inputs_path = os.path.join('gs://t5_training/t5-data/bio_data', data_dir, input_file)
  predict_outputs_path = os.path.join('gs://t5_training/t5-data/bio_data', data_dir, output_dir, MODEL_SIZE, output_file)

  # Silence logging so that only the model's outputs are shown.
  with tf_verbosity_level('ERROR'):
    # Smaller batch for inference.
    # NOTE(review): this assignment persists on `model` after the loop.
    model.batch_size = 8
    model.predict(
        input_file=predict_inputs_path,
        output_file=predict_outputs_path,
        # temperature=0: select the most probable output token at each step.
        temperature=0,
    )

  # The output filename gets the checkpoint step appended, so glob for all of
  # them and pick the highest step numerically (a lexicographic sort would
  # rank "-999000" above "-1245000").
  prediction_files = tf.io.gfile.glob(predict_outputs_path + "*")
  if not prediction_files:
    raise FileNotFoundError(
        "No prediction files found matching %s*" % predict_outputs_path)
  latest_prediction_file = max(prediction_files, key=_checkpoint_step)
  print("Predicted task : " + task)
  print("\nPredictions using checkpoint %s:\n" % latest_prediction_file.split("-")[-1])

## Scoring

In [None]:
# [data directory, task name] pairs to score; redefined here, presumably so the
# Scoring section can run without re-running the prediction cells.
tasks = [["mednli", "mednli"]]

In [None]:
# Copy each task's prediction files and gold labels into the current directory.
# NOTE(review): predictions are copied from gs://scifive/finetune/..., but the
# prediction loop writes them under gs://t5_training/t5-data/bio_data/... --
# confirm the files were mirrored, or point both cells at the same bucket.
for task in tasks:
  # task[0] is the data directory, task[1] the task name.
  !gsutil -m cp gs://scifive/finetune/{task[0]}/{output_dir}/{MODEL_SIZE}/{task[1]}_predict_output.txt-* . 
  !gsutil cp gs://scifive/finetune/{task[0]}/{task[1]}_actual_output.txt . 

In [None]:
from sklearn.metrics import f1_score, accuracy_score, classification_report, recall_score, precision_score, precision_recall_fscore_support
import numpy as np
import re
import os

In [None]:
def convert_RE_labels(filename):
    """Read one label per line from a UTF-8 text file, normalized for comparison.

    Every line is stripped of surrounding whitespace and upper-cased, so
    predicted and gold labels compare case-insensitively.

    Args:
      filename: Path to the label file.

    Returns:
      list[str]: the normalized labels, in file order.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        return [raw.strip().upper() for raw in handle]

In [None]:
# Checkpoint step whose prediction file is scored.
# NOTE(review): presumably the 1,200,000 pretraining steps of pmc_v4_1200k plus
# FINETUNE_STEPS (45,000) -- update if either changes.
checkpoint = 1245000
# NOTE(review): these accumulators are not used by this cell.
total_f1 = 0
total_precision = 0
total_recall = 0
anchor_pred_labels = []
anchor_actual_labels = []
for task in tasks:
    t = task[1]

    pred_file = os.path.join('/content/', t + '_predict_output.txt-%s' % checkpoint)
    actual_file = os.path.join('/content/', t + '_actual_output.txt')

    pred_labels = convert_RE_labels(pred_file)
    actual_labels = convert_RE_labels(actual_file)
    print("Report %s:" % t, classification_report(actual_labels, pred_labels, digits=4))

    results = dict()
    results["accuracy"] = accuracy_score(y_true=actual_labels, y_pred=pred_labels)
    # Same (sorted) label order sklearn uses for its per-class arrays.
    labels = sorted(set(actual_labels) | set(pred_labels))
    if len(labels) == 2:
        # Binary task: the second sorted label is treated as the positive class.
        p, r, f, _ = precision_recall_fscore_support(
            y_true=actual_labels, y_pred=pred_labels, labels=labels)
        results["f1 score"] = f[1]
        results["recall"] = r[1]
        results["precision"] = p[1]
        # Recall of the negative class.
        results["specificity"] = r[0]
    else:
        # Multi-class (or single-class) labels: indices 0/1 are not
        # negative/positive here, so report macro averages instead.
        p, r, f, _ = precision_recall_fscore_support(
            y_true=actual_labels, y_pred=pred_labels, average='macro')
        results["macro f1"] = f
        results["macro recall"] = r
        results["macro precision"] = p
    print(t, results)