In [22]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import sys
sys.path.append('../')

In [24]:
import os
import logging
import shutil
from datetime import datetime

import numpy as np
import random

from models import t5_utils, t5_fp16_utils, bert_utils, bart_utils

from tqdm.auto import tqdm

from transformers import (
    get_scheduler,
    DataCollatorForSeq2Seq,
    default_data_collator,
)

from accelerate import Accelerator

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer

from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import Dataset


In [21]:
from datasets import load_dataset
oqa = load_dataset("m-rousseau/oqa-v1")



  0%|          | 0/3 [00:00<?, ?it/s]

In [87]:
from train import EvaluateBERT
import tomli

cfg_toml = '''
[mode]
runmode = "bert-evaluate"

[model]
name = "pubmedbert-base-presft"
max_seq_len = 512
checkpoint = "m-rousseau/pubmedbert-presft-oqa"
bitfit = false

[dataset]
repository = "m-rousseau/oqa-v1"

[tokenizer]
checkpoint = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
append_special_tokens = false

[hyperparameters]
validation_batch_size = 8
seed = 0

[misc]
save_dir = "./"
'''

config = tomli.loads(cfg_toml)
runmode = config["mode"]["runmode"]
model_config = config["model"]
dataset_config = config["dataset"]
tokenizer_config = config["tokenizer"]
hyperparameter_config = config["hyperparameters"]
misc_config = config["misc"]

In [97]:
bert_config = bert_utils.BERTCFG(
    name=model_config["name"],
    model_checkpoint=model_config["checkpoint"],
    tokenizer_checkpoint=tokenizer_config["checkpoint"],
    checkpoint_savedir=misc_config["save_dir"],
    max_length=model_config["max_seq_len"],
    seed=hyperparameter_config["seed"],
    bitfit=model_config["bitfit"],
    append_special_token=tokenizer_config["append_special_tokens"],
)

config = bert_utils.setup_evaluate_oqa(dataset_config["repository"], bert_config)

evaluater = EvaluateBERT(config)
f1, em = evaluater()



  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] :: 2023-06-21 08:36:45,190 @ bert-utils :: test dataset loaded
[INFO] :: 2023-06-21 08:36:46,886 @ bert-utils :: model and tokenizer initialized
[INFO] :: 2023-06-21 08:36:46,915 @ bert-utils :: Test dataset processed and tokenized : n = 139
[INFO] :: 2023-06-21 08:36:47,081 @ eval_logger :: Test dataloader created


  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/114 [00:00<?, ?it/s]

[INFO] :: 2023-06-21 08:40:36,527 @ eval_logger :: 
************************************************** 
Evaluation results 
model : pubmedbert-base-presft 
 > F1 = 76.31436227856749 
 > EM = 51.75438596491228 
**************************************************


In [89]:
# evaluate (same configuration simply pass new checkpoint)
bert_config = bert_utils.BERTCFG(
    name=model_config["name"],
    model_checkpoint=model_config["checkpoint"],
    tokenizer_checkpoint=tokenizer_config["checkpoint"],
    checkpoint_savedir=misc_config["save_dir"],
    max_length=model_config["max_seq_len"],
    seed=hyperparameter_config["seed"],
    bitfit=model_config["bitfit"],
    append_special_token=tokenizer_config["append_special_tokens"],
)
bert_config.name = "pubmedbert-presft-evaluate"

bert_config.model, bert_config.tokenizer = bert_utils.bert_init(
    bert_config.model_checkpoint, bert_config.tokenizer_checkpoint)

bert_config.test_dataset = oqa["validation"]
bert_config.test_batches = bert_utils.prepare_inputs(
        oqa["validation"],
        bert_config.tokenizer,
        stride=bert_config.stride,
        max_len=bert_config.max_length,
        padding=bert_config.padding,
        subset="eval",
)
print(bert_config)
evaluater = EvaluateBERT(bert_config)

seqs, samps = evaluater(return_answers=True, multiple_answers=True)

[INFO] :: 2023-06-20 20:04:04,634 @ bert-utils :: Test dataset processed and tokenized : n = 121
[INFO] :: 2023-06-20 20:04:04,696 @ eval_logger :: Test dataloader created



BERT-like model configuration
************************************
        Name : pubmedbert-presft-evaluate
        Model checkpoint : m-rousseau/pubmedbert-presft-oqa
        Tokenizer checkpoint : microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
        Max sequence length : 512
        Hyperparameters :
                bitfit=False,
                lr=2e-05,
                lr_scheduler=False,
                num_epochs=12,
                batch_size=4
************************************
        


  0%|          | 0/31 [00:00<?, ?it/s]

  0%|          | 0/92 [00:00<?, ?it/s]

[INFO] :: 2023-06-20 20:07:12,068 @ eval_logger :: 
************************************************** 
Evaluation results 
model : pubmedbert-presft-evaluate 
 > F1 = 81.36523674581863 
 > EM = 53.26086956521739 
**************************************************


In [2]:
import re
import string
from collections import Counter

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

In [5]:
#"a vertical extrusive force" "extrusive force"
#``3.5 to 7 times'' & ``the stiffnesses of stainless steel wires were significantly greater than those of the nickel-titanium wires by 3.5 to 7 times'' & \\
#``mri is the preferred imaging technique'' & ``mri''  &\\

f1_score("mri", "mri is the preferred imaging technique")

0.33333333333333337

In [43]:
# a) compute the max_f1 per id and get the mean f1
oqa["validation"][1]["id"]

# b) compare with default -- highest logit f1
samps

{'id': ['addd66d0-c48f-4447-a3f8-8dc65a99e276',
  '62f16e43-1abd-4940-ad88-9fdc801804b1',
  '22c313f9-50d9-46aa-86e2-3c7d4bdec018',
  'fd3dde24-9155-4cd3-9519-a2a7a9583bce',
  'abfa396b-1b92-4b80-bfce-dc01f320138a',
  '1280d2e0-c2f0-42da-86af-b69d3c350a5b',
  'b7820340-732d-4790-bed6-9c0038b60cbf',
  'c3476eb7-296f-4bf4-b900-9d2454488de0',
  'fc714e90-bed2-43fa-9930-7ea0bf12ab99',
  'db244642-ea9e-4170-b9b5-4e95217c591d',
  'c49a2285-2eac-443a-b646-d34ad8395769',
  'f7302341-03d3-4d3d-b2c7-22cf7f548504',
  '0fdb4c86-50c2-4cd0-ba50-18b04618348a',
  '858f2028-0477-4b5a-9565-f0b7411fca09',
  'db2e9e06-97c1-454b-a0f4-9e994a8ce352',
  '98b0350c-8729-4c2e-9f69-dd1365c8817d',
  '60d0daaa-b33a-4532-8ddf-172d04ae06d2',
  '1c638f04-d62f-40ca-a486-cb97ff04838b',
  'bf6cf487-893f-47e0-ac7a-a36cf03486a1',
  '2e3ecbba-7380-49ab-a0bb-55ad30c5a1c1',
  '2277ca27-35dc-4891-9dc2-130a2d823537',
  '2bb9c2cd-6939-4e5a-bdb5-9ad924bf95a2',
  '2471d4dd-5f9b-4e33-81c5-f034050216a1',
  'ef14135e-7ea5-419e-9909-f

In [94]:
max_f1s = []
for i in range(len(oqa["validation"])):
    ground_truth = oqa["validation"][i]["answers"]["text"][0]
    top_k = sorted(samps["samples"][i], key=lambda x: x[1], reverse=True)[:12]
    sampled_answers = [x[0] for x in top_k]
    f1_list = []
    
    for s in sampled_answers:
        f1_list.append(f1_score(s, ground_truth))

    f1_list.sort(reverse=True)
    max_f1s.append(f1_list[0])
    #import numpy as np
np.mean(max_f1s) # Woooooooo!!!!!

0.9628436559872703

In [75]:
import numpy as np
np.mean(max_f1s) # Woooooooo!!!!! 90+ (which reflects the actual accuracy of the model)

0.9170560751065272