### 256 Final Project

Shibani Likhite {slikhite@ucsd.edu}

Savani Suranglikar {ssuranglikar@ucsd.edu}

Carl Chow {cychow@ucsd.edu}

In [None]:
### INSTALL stuff we need

# some code from here (A2)
# https://colab.research.google.com/github/huggingface/datasets/blob/main/notebooks/Overview.ipynb#scrollTo=7T5AG3BxvSUr

# NOTE: versions are unpinned, so a fresh run may resolve newer releases than
# the recorded run (which installed transformers 4.30.2).
# install HuggingFace transformers
!pip install transformers[torch]
# install datasets
!pip install datasets
# install evaluate (provides the accuracy metric used by the Trainer)
!pip install evaluate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers[torch]
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m72.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers[torch])
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers[torch])
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m114.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers[torch])
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_6

In [None]:
### IMPORT stuff we need

import pandas as pd
import torch
import sys
import numpy as np
import time
import datetime
import evaluate
import random

from dataclasses import dataclass

from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy

from typing import Optional, Union

from datasets import list_datasets, load_dataset, ClassLabel

from pprint import pprint

from transformers import BertForMultipleChoice, BertConfig, BertTokenizer
from transformers import TrainingArguments, Trainer

from torch.optim import AdamW

from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

In [None]:
### GPU

# free unused GPU memory if able
torch.cuda.empty_cache()

# Pick the compute device once. Later cells move tensors with .to(device) and
# load checkpoints with map_location=torch.device(device), so `device` is always
# a torch.device (never a bare string) regardless of which branch runs.
if torch.cuda.is_available():
    # Get the GPU device name.
    device_name = torch.cuda.get_device_name()
    n_gpu = torch.cuda.device_count()
    print(f"Found device: {device_name}, n_gpu: {n_gpu}")
    device = torch.device("cuda")
else:
    # if no GPU, use CPU (but training takes forever)
    print('No GPU, using CPU instead')
    device = torch.device("cpu")

Found device: Tesla T4, n_gpu: 1


In [None]:
### Downloading and loading a dataset (MedMCQA from the HuggingFace hub)
dataset = load_dataset(path='medmcqa')

# Per-split (rows, columns), a blank line, then one raw training record.
print(f"👉 Dataset Size : {dataset.shape}", "\n👉 First item 'dataset['train'][0]' :", sep="\n")
pprint(dataset['train'][0])

Downloading builder script:   0%|          | 0.00/5.35k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.41k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading and preparing dataset medmcqa/default to /root/.cache/huggingface/datasets/medmcqa/default/1.1.0/f2fdfa9ccfbf9d148c0639e6afe3379f3c7e95c4d52d5e68ec1156e5004bd880...


Downloading data:   0%|          | 0.00/55.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Dataset medmcqa downloaded and prepared to /root/.cache/huggingface/datasets/medmcqa/default/1.1.0/f2fdfa9ccfbf9d148c0639e6afe3379f3c7e95c4d52d5e68ec1156e5004bd880. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

👉 Dataset Size : {'train': (182822, 11), 'test': (6150, 11), 'validation': (4183, 11)}

👉 First item 'dataset['train'][0]' :
{'choice_type': 'single',
 'cop': 2,
 'exp': 'Chronic urethral obstruction because of urinary calculi, prostatic '
        'hyperophy, tumors, normal pregnancy, tumors, uterine prolapse or '
        'functional disorders cause hydronephrosis which by definition is used '
        'to describe dilatation of renal pelvis and calculus associated with '
        'progressive atrophy of the kidney due to obstruction to the outflow '
        'of urine Refer Robbins 7yh/9,1012,9/e. P950',
 'id': 'e9ad821a-c438-4965-9f77-760819dfa155',
 'opa': 'Hyperplasia',
 'opb': 'Hyperophy',
 'opc': 'Atrophy',
 'opd': 'Dyplasia',
 'question': 'Chronic urethral obstruction due to benign prismatic hyperplasia '
             'can lead to the following change in kidney parenchyma',
 'subject_name': 'Anatomy',
 'topic_name': 'Urinary tract'}


In [None]:
### Visualize some random samples from the dataset
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    """Render `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    ClassLabel columns are displayed by label name instead of integer id.

    Args:
        dataset: a dataset split supporting len(), indexing by a list of row
            indices, and a `.features` mapping of column -> feature type.
        num_examples: number of distinct rows to show; must not exceed len(dataset).
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices directly: no retry loop and no
    # linear membership scans, even when num_examples is close to len(dataset).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            # transform runs immediately, so `typ` is the current column's feature.
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

dict_items([('Anatomy', 14560), ('Biochemistry', 8282), ('Surgery', 16862), ('Ophthalmology', 6932), ('Physiology', 8830), ('Social & Preventive Medicine', 11882), ('Gynaecology & Obstetrics', 10013), ('Anaesthesia', 3172), ('Psychiatry', 4442), ('Microbiology', 11314), ('Medicine', 17887), ('Pharmacology', 13758), ('Dental', 8938), ('ENT', 4919), ('Forensic Medicine', 5900), ('Pediatrics', 8037), ('Orthopaedics', 2999), ('Radiology', 4395), ('Pathology', 14884), ('Skin', 1771), ('Unknown', 3045)])

In [None]:
### GET / FILTER certain subject questions for finetuning

def _only_subject(subject):
    """Return a row predicate that keeps rows whose 'subject_name' equals `subject`."""
    return lambda row: row['subject_name'] == subject

# One filtered DatasetDict (all splits) per subject of interest.
ds_anatomy = dataset.filter(_only_subject('Anatomy'))
ds_surgery = dataset.filter(_only_subject('Surgery'))
ds_medicine = dataset.filter(_only_subject('Medicine'))
# can add more if needed, e.g. dataset.filter(_only_subject('Pathology'))

# Peek at a few anatomy training rows.
show_random_elements(ds_anatomy['train'])



Unnamed: 0,id,question,opa,opb,opc,opd,cop,choice_type,exp,subject_name,topic_name
0,e24c14a9-8ccc-433c-a75e-6137accfc9b9,Referred otalgia from base of tongue or oropharynx is carried by nerve?,Cranial nerve V,Cranial nerve VII,Cranial nerve IX,Cranial nerve X,c,single,"The Jacobson nerve, tympanic branch of glossopharyngeal nerve (cranial nerve IX) directly innervates the ear but also has pharyngeal, lingual, and tonsillar branches to supply the posterior one-third poion of the tongue, tonsillar fossa, pharynx, eustachian tube, and parapharyngeal and retropharyngeal spaces. So any pathology involving those areas can lead to referred otalgia. Must know: Referred Otalgia: the source of the pain does not reside within the ear but, rather it originates from a source distant from the ear hence it is called as ""referred otalgia"". Any pathology residing within the sensory net of cranial nerves V, VII, IX, and X and upper cervical nerves C2 and C3 can potentially cause referred otalgia.",Anatomy,
1,67ac3aee-71ba-4fcc-91ca-0f16c3668dcc,The gastroduodenal aery is derived from:,Celiac aery,Hepatic aery,Splenic aery,Cystic aery,b,single,B i.e. Hepatic aery,Anatomy,
2,8ad77e33-b91d-4026-99d9-1bc09d7b26e1,A 35 year old woman suffers severe chest trauma. She is unconscious and her blood pressure is substantially decreased. She has sustained a tear in one of the pulmonary veins at the point at which the vein enters the hea. Into which of the following spaces is the patient hemorrhaging?,Between the epicardium and the parietal pericardium,Between the parietal pericardium and the fibrous pericardium,Between the fibrous pericardium and the parietal pleura,Between the myocardium and the epicardium,a,multi,"The pericardial space is located between the epicardium (also known as the visceral pericardium) and the parietal pericardium. A tear of a blood vessel immediately outside of the hea will cause bleeding into the pericardial space. This accumulation of blood in the pericardial space causes increased pressure on the hea, which restricts filling of the hea during diastole (cardiac tamponade). This reduced filling results in reduced cardiac output and reduced blood pressure. The region between the fibrous pericardium and the parietal pleura is outside of the pericardial space. It is pa of the mediastinum and it is in this region in which structures such as the vagus nerve and the phrenic nerve are found. The epicardium is fused to the myocardium and is the outer layer of the hea wall. There is no space between the epicardium and the myocardium. The parietal pericardium and the fibrous pericardium are fused into a single layer that forms the outer wall of the pericardial space. There is no space between the parietal pericardium and the fibrous pericardium.",Anatomy,
3,df976d65-9ca6-4357-a499-b14f43c91f1d,A 14 week post natal women presents with fluctuant breast swelling. What should be the treatment,Incision and drainage,Continue breast feeding with antibiotics,Analgesics,Repeated aspiration under antibiotic cover,d,single,"Bacterial mastitis Most commonly associated with lactation in majority of cases Causative organism--mostly S.aureus. Ascending infection from a sore and cracked nipple may initiate the mastitis Or lactiferous ducts will first become blocked by epithelial debris leading to stasis. Once within the ampulla of the duct, staphylococcus cause clotting of milk and within this clot organisms multiply. Clinical features: The affected breast or more usually a segment of it presents the classical signs of acute inflammation. Early on this is a generalised cellulitis but later an abscess will form Treatment: During cellulitis stage--patient should be treated with an appropriate antibiotic, such as flucloxacillin or coamoxiclav.Feeding from the affected side may continue if the patient can manage. Suppo of the breast, local heat and analgesia will help to relieve pain If an antibiotic is used in the presence of undrained pus, an antibioma may form. This is a large sterile, brawny oedematous swelling that takes many weeks to resolve. At present advice is repeated aspirations under antibiotic cover (if necessary ultrasound for localisation) be performed. This often allows resolution without the need for an incision and will also allow the patient to continue breast feeding. Presence of pus can be confirmed by needle aspiration and the pus should be sent for bacteriological culture.",Anatomy,Endocrinology and breast
4,0af8efbf-f50f-43c4-a82d-6d7617aa71ed,"Carpel tunnel syndrome is caused by all, EXCEPT?",Amylodosis,Hypothyroidism,Addisson's disease,Diabetes mellitus,c,multi,"Carpal tunnel syndrome (tardy median palsy) is the result of compression of the median nerve within the carpal tunnel. Conditions leading to CTS are: pregnancy, history of repetitive use of hands, following injury of the wrist, diabetes mellitus, rheumatoid ahritis, inflammatory tenosynovitis, myxedema, localized amyloidosis in localized kidney disease, sarcoidosis, leukemia, acromegaly and hyperparathyroidism. Patients usually presents with pain, burning, and tingling in the distribution of the median nerve. On examination, weakness or atrophy, especially of the thenar eminence is noted. Tinel sign and phalen sign is positive. Tinel sign is tingling or shock-like pain elicited by tapping the volar surface of the wrist; Phalen sign is pain or paresthesia in the distribution of the median nerve when the patient flexes both wrists to 90 degrees for 60 seconds.",Anatomy,


({'train': (14560, 11), 'test': (259, 11), 'validation': (234, 11)},
 {'train': (16862, 11), 'test': (501, 11), 'validation': (369, 11)},
 {'train': (17887, 11), 'test': (372, 11), 'validation': (295, 11)})

In [None]:
### Choose our Model

# Checkpoint to fine-tune. Other options considered:
#   'bert-base-uncased'
#   'dmis-lab/biobert-base-cased-v1.2'
model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'

# Multiple-choice model built from the pretrained checkpoint.
model = BertForMultipleChoice.from_pretrained(model_name)

# Move to the GPU when available (nn.Module.cuda() returns the module itself).
model = model.cuda() if torch.cuda.is_available() else model

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForMultipleChoice: ['cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were 

In [None]:
### functions to preprocess (tokenize) data for training

# dataset columns holding the four answer options, in label order (cop 0..3)
opts = ['opa', 'opb', 'opc', 'opd']

# BertTokenizer is always the slow (pure-Python) tokenizer class, so a
# `use_fast` flag has no effect here; use BertTokenizerFast/AutoTokenizer for a
# fast one. Without an explicit model_max_length this checkpoint's tokenizer
# has no maximum length, which makes `truncation=True` silently do nothing;
# cap it at the model's position-embedding limit so over-long inputs are
# truncated instead of overflowing the model.
tokenizer = BertTokenizer.from_pretrained(
    model_name,
    do_lower_case=True,
    model_max_length=model.config.max_position_embeddings,
)

max_len = 256 # character (not token) cap on each option(+explanation) text
explanations = False # whether to always append explanations when training
use_random = True # whether to randomly choose to add explanations

# preprocess data (meant for training; the encode cell maps it over every split)
def preprocess_function(examples, p=0.5):
    """Tokenize a batch of MedMCQA rows into multiple-choice features.

    Intended for `Dataset.map(..., batched=True)`, where `examples` is a dict of
    column -> list. Each question is paired with every option column in `opts`;
    the (question, option) pairs are tokenized together and then regrouped so
    each tokenizer field holds one list of len(opts) encodings per example.

    Explanations: if the global `explanations` flag is set, every example gets
    its 'exp' text appended to each option; otherwise, if `use_random` is set,
    each example independently gets it with probability `p`. A missing
    (None/empty) explanation adds nothing. The option(+explanation) text is cut
    to `max_len` characters before tokenizing.

    Args:
        examples: batch dict with 'id', 'question', the `opts` columns, 'exp', 'cop'.
        p: per-example probability of appending the explanation when `use_random` is on.

    Returns:
        dict of tokenizer fields (each a list of len(opts)-long lists) plus
        'labels' copied from 'cop'.
    """
    num_choices = len(opts)
    num_examples = len(examples['id'])

    # Repeat each question once per option (flat, question-major order).
    first_sentences = [q for q in examples['question'] for _ in range(num_choices)]

    second_sentences = []
    for i in range(num_examples):
        # Decide per example: a single draw per call would give every row of
        # a map() batch the same choice.
        add_exp = explanations or (use_random and random.random() < p)
        exp = examples['exp'][i] if add_exp else None
        # Skip missing explanations instead of appending the literal text "None".
        suffix = f" {exp}" if exp else ""
        second_sentences.extend(f" {examples[op][i]}{suffix}"[:max_len] for op in opts)

    # Tokenize as (question, option) pairs.
    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Un-flatten back to one group of num_choices encodings per example.
    ex_dict = {
        k: [v[i:i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }
    ex_dict['labels'] = examples['cop']
    return ex_dict

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

In [None]:
### Processing Functions for Inference

# tokenizer for MC here (use only for inference)
def tokenize_inference(entries):
    """Tokenize MedMCQA rows one question at a time for `acc_vis_inference`.

    Each row is encoded as len(opts) (question, " <option>") text pairs. This is
    the same pairing and character cap that `preprocess_function` uses for rows
    without an explanation, so [SEP] placement and token_type_ids match what the
    model saw during fine-tuning.

    Args:
        entries: iterable of dataset rows with 'question', the `opts` columns and
            'cop' (index of the correct option).

    Returns:
        (input_dict_list, label_t_list): one tokenizer output per row (PyTorch
        tensors, padded to the longest choice within that row) and one label
        tensor of shape (1,) per row.
    """
    input_dict_list = []
    label_t_list = []

    for q in entries:
        # Question as the first segment, each option as the second segment.
        questions = [q['question']] * len(opts)
        options = [f" {q[op]}"[:max_len] for op in opts]

        input_dict = tokenizer(questions, options, return_tensors="pt", padding=True, truncation=True)
        label_t = torch.tensor(q['cop']).unsqueeze(0)

        input_dict_list.append(input_dict)
        label_t_list.append(label_t)

    return input_dict_list, label_t_list


# prediction accuracy function (use only for inference)
def acc_vis_inference(model, input_dict_list, label_t_list, device=None):
    """Score a multiple-choice model one question at a time.

    Args:
        model: multiple-choice model whose output has `.logits` (one score per
            option); it is switched to eval mode.
        input_dict_list: per-question encodings as built by `tokenize_inference`
            (each supports `.to(device)` and `.items()`; values are unbatched and
            a batch dimension of 1 is added here).
        label_t_list: per-question label tensors of shape (1,).
        device: where to run. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        (accuracy, correct, incorrect): accuracy is the fraction answered
        correctly (0.0 for empty input); correct / incorrect are lists of
        (question index, logits) with the logits moved to the CPU.

    Raises:
        ValueError: if the two input lists have different lengths.
    """
    if len(input_dict_list) != len(label_t_list):
        raise ValueError(
            f"got {len(input_dict_list)} inputs but {len(label_t_list)} labels"
        )
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = []
    incorrect = []
    # eval mode (disables dropout)
    model.eval()

    # No autograd graph is needed for inference.
    with torch.no_grad():
        for i, (enc, label_t) in enumerate(zip(input_dict_list, label_t_list)):
            batch = {k: v.unsqueeze(0) for k, v in enc.to(device).items()}
            output_i = model(**batch, labels=label_t.to(device))
            # Keep results on the CPU so GPU memory is not held by the lists.
            logits = output_i.logits.cpu()
            # the highest-scoring option is the model's prediction
            pred = int(logits.argmax())

            if pred == int(label_t.item()):
                correct.append((i, logits))
            else:
                incorrect.append((i, logits))

    total = len(correct) + len(incorrect)
    return (len(correct) / total if total else 0.0), correct, incorrect

In [None]:
### test out preprocessing

# Sanity check of preprocess_function on a few rows. Off by default to save
# memory; set to True to run it.
RUN_PREPROCESS_CHECK = False

if RUN_PREPROCESS_CHECK:
    # preprocess_function has no `explanations` kwarg; with the global
    # `explanations` flag False, p=0.0 yields option-only (no explanation) inputs.
    features = preprocess_function(ds_anatomy['train'][:5], p=0.0)
    print(features)

    # decode every choice of one example as a sanity check
    idx = 3
    print(ds_anatomy['train'][idx])
    for item in [tokenizer.decode(features["input_ids"][idx][i]) for i in range(len(opts))]:
        print(item)

'\nfeatures = preprocess_function(ds_anatomy[\'train\'][:5], explanations=False)\nprint(features)\n\n# decode as sanity check\nidx = 3\nprint(ds_anatomy[\'train\'][idx])\nfor item in [tokenizer.decode(features["input_ids"][idx][i]) for i in range(3)]:\n    print(item)\n'

In [None]:
### ENCODE the datasets using the Preprocessor (Tokenizer)

# Other options:
#   whole dataset: ds_enc = dataset.map(preprocess_function, batched=True)
#   anatomy only:  ds_anatomy_enc = ds_anatomy.map(preprocess_function, batched=True)

# Medicine questions, every split (train / validation / test). preprocess_function
# may append explanations at random, and it is applied to validation/test as well.
ds_medicine_enc = ds_medicine.map(function=preprocess_function, batched=True)

###
### BASELINE
###

# Accuracy on the validation set with the untrained BERT model (no fine-tuning),
# using model_name = 'bert-base-uncased':
#
# input_dict_list, label_t_list = tokenize_inference(ds_anatomy['validation'])
# print(input_dict_list[0], label_t_list[0])
# print(acc_vis_inference(model, input_dict_list, label_t_list)[0])
#
# 0.24786324786324787 (basically random)

Map:   0%|          | 0/17887 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/372 [00:00<?, ? examples/s]

Map:   0%|          | 0/295 [00:00<?, ? examples/s]

In [None]:
### Necessary Helper Functions for Trainer() class

# collator for batching
@dataclass
class DataCollatorForMultipleChoice:
    """
    Pads multiple-choice features on the fly and regroups them into model inputs.

    Each incoming feature holds, per tokenizer field (e.g. input_ids), one entry
    per answer choice, plus a "label" or "labels" entry. All choices of all
    features are flattened so the tokenizer can pad them together, then reshaped
    to (batch_size, num_choices, seq_len) tensors; labels become an int64 tensor
    under "labels".

    Note: the label entry is popped from each incoming feature dict.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_key = "labels" if "label" not in features[0].keys() else "label"
        labels = [feat.pop(label_key) for feat in features]

        n_examples = len(features)
        n_choices = len(features[0]["input_ids"])

        # One dict per (example, choice), example-major order.
        per_choice = [
            {field: values[choice] for field, values in feat.items()}
            for feat in features
            for choice in range(n_choices)
        ]

        padded = self.tokenizer.pad(
            per_choice,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch = {field: tensor.view(n_examples, n_choices, -1) for field, tensor in padded.items()}
        batch['labels'] = torch.tensor(labels, dtype=torch.int64)
        return batch

# load the accuracy metric used by compute_metrics. It has its own name because
# the inference cell later rebinds `accuracy` to a float; `accuracy` is kept as
# an alias for existing references.
accuracy_metric = evaluate.load('accuracy')
accuracy = accuracy_metric

# metrics function for trainer class below
def compute_metrics(eval_pred):
    """Accuracy of the highest-scoring choice, for Trainer evaluation.

    Args:
        eval_pred: (predictions, labels) as passed by the Trainer; predictions
            are assumed to be (num_examples, num_choices) logits.

    Returns:
        dict produced by the evaluate accuracy metric (e.g. {'accuracy': ...}).
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    # Uses accuracy_metric (not `accuracy`) so rebinding `accuracy` elsewhere
    # cannot break evaluation.
    return accuracy_metric.compute(predictions=predictions, references=labels)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [None]:
### Hyperparameters !

# batch size (reduce if CUDA out of memory error)
bs = 4
# Set use_loaded = True to load the saved weights instead of training
use_loaded = True

# Saved state_dict location. Load and save share one naming scheme (model name
# with '/' -> '-', then -ft-{bs}-{max_len}), so changing model_name, bs or
# max_len keeps them pointed at the same file.
model_dict_name = model_name.replace('/', '-')
model_path = f'./{model_dict_name}-ft-{bs}-{max_len}'

training_args = TrainingArguments(
    output_dir = f"test_model-{bs}-{max_len}",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    # metric_for_best_model is not set, so the Trainer's default criterion
    # (eval loss, per the TrainingArguments docs; confirm for your version)
    # selects the checkpoint restored at the end.
    load_best_model_at_end = True,
    learning_rate = 3e-5,
    per_device_train_batch_size = bs,
    per_device_eval_batch_size = bs,
    num_train_epochs = 6,
    weight_decay = 0.01,
)

# The encoded splits come from preprocess_function, so validation rows may also
# carry explanations (chosen at random). The Trainer's eval accuracy is
# therefore not directly comparable to explanation-free inference.
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = ds_medicine_enc['train'],
    eval_dataset = ds_medicine_enc['validation'],
    tokenizer = tokenizer,
    data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics = compute_metrics,
)

if not use_loaded:
    # train, then save the state_dict so we don't retrain each time
    # (save_pretrained() would also work)
    trainer.train()
    torch.save(model.state_dict(), model_path)

else:
    # load an already trained model. torch.load unpickles the file, so only
    # load checkpoints you produced or trust.
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    model.eval()



Epoch,Training Loss,Validation Loss,Accuracy
1,1.112,0.714184,0.718644
2,1.0089,0.741076,0.701695
3,0.8576,0.762329,0.705085
4,0.6803,0.891719,0.711864
5,0.559,1.190821,0.701695


In [None]:
### Inference

# Score the fine-tuned model on a validation split.
# (Test-set predictions have to be submitted through the MedMCQA repo:
#  https://github.com/MedMCQA/MedMCQA)

# Which split to score; the visualisation cell indexes into this same object.
# Experiment: the model from the training cell (trained on ds_medicine, or
# whatever weights were loaded from model_path) evaluated on surgery
# questions. Does it beat random guessing (0.25 with four options)?
validation_set = ds_surgery['validation']

# Tokenize every question's options, then score them.
input_dict_list, label_t_list = tokenize_inference(validation_set)

# NOTE: this rebinds the global `accuracy` to a float (the metrics cell had
# bound it to the evaluate accuracy metric).
accuracy, cor, inc = acc_vis_inference(model, input_dict_list, label_t_list)


In [None]:
### Visualize Correct and Incorrect Samples from Inference

# note that these are the accuracies without the explanation!
# otherwise it might be considered cheating

print('Accuracy :', accuracy, '\n')

# Show up to this many samples per group; slicing avoids an IndexError when a
# group has fewer entries.
n_show = 5

print('Sample Correct Questions : \n')
for idx, logits in cor[:n_show]:
    pprint(validation_set[idx])
    # logits
    print(logits, '\n')

print('\n Sample Incorrect Questions : \n')
for idx, logits in inc[:n_show]:
    pprint(validation_set[idx])
    # logits
    print(logits, '\n')

Accuracy : 0.3116531165311653 

Sample Correct Questions : 

{'choice_type': 'single',
 'cop': 0,
 'exp': None,
 'id': '9603526f-8c7d-4618-963d-be8a05c28a94',
 'opa': 'Wait & watch',
 'opb': 'Antral pack',
 'opc': 'Titanium Mesh',
 'opd': 'Glass bead mesh',
 'question': 'In a patient with fresh blow out fracture of the orbit, best '
             'immediate management is',
 'subject_name': 'Surgery',
 'topic_name': None}
tensor([[-0.7803, -0.8486, -0.7988, -0.8263]], device='cuda:0') 

{'choice_type': 'single',
 'cop': 3,
 'exp': None,
 'id': '2639d0ba-ef15-4ba4-92fe-ee27b5758fbf',
 'opa': 'MR',
 'opb': 'ASD',
 'opc': 'MS',
 'opd': 'CABG',
 'question': 'Which of these conditions  does not require SABE prophylaxis',
 'subject_name': 'Surgery',
 'topic_name': None}
tensor([[-1.1552, -1.1121, -1.1427, -1.0844]], device='cuda:0') 

{'choice_type': 'single',
 'cop': 1,
 'exp': None,
 'id': '99ebfb54-f46f-4053-8ef1-3c931d657bdb',
 'opa': '10% ethanol',
 'opb': '10% formalin',
 'opc': 'Hydroge

In [None]:
###
### Some Test Runs with different hyperparameters
###

## On ds_anatomy

# epoch=3, learning_rate=5e-5, bs=4, max_len=256, weight_decay=0.01
# 0.31196581196581197 for pubmedbert when finetuned only on Anatomy questions (pubmed-ft state_dict)

# epoch=5, learning_rate=1e-5, bs=4, max_len=64, weight_decay=0.01
# 0.3034188034188034 for pubmedbert when finetuned only on Anatomy questions (pubmed-ft state_dict)

# epoch=6, learning_rate=3e-5, bs=4, max_len=256, weight_decay=0.01 (this one overfits maybe)
# 0.32905982905982906

# epoch=6, learning_rate=1e-5, bs=4, max_len=100, weight_decay=0.01, no explanations
# didn't finish this one

# epoch=5, learning_rate=1e-5, bs=4, max_len=256, weight_decay=0.01 (really bad, probably overfit)
# 0.2564102564102564

# epoch=5, learning_rate=3e-5, bs=4, max_len=256, weight_decay=0.01 (using process_random)
#

# using pubmedbert trained on ds_medicine
# 0.29914529914529914

## On ds_surgery

# using pubmed-ft-4-256 trained on ds_anatomy ("transfer learning" but not really)
# 0.3062330623306233 (still pretty good)

# using pubmedbert trained on ds_medicine
# 0.3116531165311653

## On ds_skin

# using pubmed-ft-4-256 trained on ds_anatomy ("transfer learning" but not really)
# 0.5294117647058824

## On ds_medicine

# epoch=5, learning_rate=3e-5, bs=4, max_len=256, weight_decay=0.01 (using process_random)
# 0.26101694915254237

# using pubmed-ft-4-256 trained on ds_anatomy ("transfer learning" but not really)
# 0.31864406779661014

## On ds_psychiatry
#0.30

# epoch=5, learning_rate=3e-5, bs=4, max_len=256, weight_decay=0.01 (using process_random)
# 0.31

# epoch=5, learning_rate=3e-5, bs=4, max_len=256, weight_decay=0.01, no explanations
# 0.25

# epoch=5, learning_rate=1e-5, bs=4, max_len=256, weight_decay=0.01 (using process_random) - pathology
# 0.3115727002967359

# using pubmed-ft-4-256 trained on ds_anatomy ("transfer learning" but not really)
# 0.375 (still pretty good)

# using pubmedbert trained on ds_medicine
# 0.25