This is my ongoing final-year project for my degree in Data Science and Computing. It experiments with different LLMs on multiple-choice question answering in the medical domain, using the MedMCQA dataset from the Hugging Face Hub. Each model is fine-tuned with TensorFlow and Keras on Google Colab, saved to Google Drive, and then evaluated on the test split to measure its accuracy. The project is expected to be done by the end of April 2024.

# Install & Import packages

In [None]:
!pip install transformers
!pip install datasets


Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
Collecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)

In [None]:
import numpy as np
import tensorflow as tf
import os
import math

# Download dataset

Download the MedMCQA dataset from the Hugging Face Hub.


In [None]:
from datasets import load_dataset

raw_datasets = load_dataset("medmcqa")

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Explore dataset

In [None]:
raw_train_dataset = raw_datasets["train"]

print("Sample training data: ")
raw_train_dataset[0]

Sample training data: 


{'id': 'e9ad821a-c438-4965-9f77-760819dfa155',
 'question': 'Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma',
 'opa': 'Hyperplasia',
 'opb': 'Hyperophy',
 'opc': 'Atrophy',
 'opd': 'Dyplasia',
 'cop': 2,
 'choice_type': 'single',
 'exp': 'Chronic urethral obstruction because of urinary calculi, prostatic hyperophy, tumors, normal pregnancy, tumors, uterine prolapse or functional disorders cause hydronephrosis which by definition is used to describe dilatation of renal pelvis and calculus associated with progressive atrophy of the kidney due to obstruction to the outflow of urine Refer Robbins 7yh/9,1012,9/e. P950',
 'subject_name': 'Anatomy',
 'topic_name': 'Urinary tract'}

In [None]:
raw_train_dataset.features

{'id': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None),
 'opa': Value(dtype='string', id=None),
 'opb': Value(dtype='string', id=None),
 'opc': Value(dtype='string', id=None),
 'opd': Value(dtype='string', id=None),
 'cop': ClassLabel(names=['a', 'b', 'c', 'd'], id=None),
 'choice_type': Value(dtype='string', id=None),
 'exp': Value(dtype='string', id=None),
 'subject_name': Value(dtype='string', id=None),
 'topic_name': Value(dtype='string', id=None)}
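
Note that `cop` is a `ClassLabel`, so the integer stored in each example is an index into `names`. A minimal sketch in plain Python, using values copied from the sample above (the helper `correct_option` is hypothetical, written only for illustration):

```python
# ClassLabel names as printed in the features above
COP_NAMES = ["a", "b", "c", "d"]

def correct_option(example):
    """Return the option letter and option text selected by the integer label `cop`."""
    letter = COP_NAMES[example["cop"]]
    return letter, example["op" + letter]

# Values copied from the sample training example above
sample = {
    "opa": "Hyperplasia",
    "opb": "Hyperophy",
    "opc": "Atrophy",
    "opd": "Dyplasia",
    "cop": 2,
}
letter, text = correct_option(sample)
print(letter, text)  # prints: c Atrophy
```

So `cop = 2` means option C is the correct answer for this question.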

# Prepare the dataset for training

Because the answer choices are stored in separate features, we need to concatenate them with their question to form the text input for the model.

In [None]:
def concat_questions_with_answers(example):
  # Overwrite "question" with a prompt that also lists the four options
  example["question"] = """
  {question}.

  Please select one of the following options:
  A. {opa}
  B. {opb}
  C. {opc}
  D. {opd}

  The answer is
  """.format(
      question=example["question"],
      opa=example["opa"],
      opb=example["opb"],
      opc=example["opc"],
      opd=example["opd"]
  )

  return example

raw_datasets = raw_datasets.map(concat_questions_with_answers)

In [None]:
raw_datasets["train"][0]

{'id': 'e9ad821a-c438-4965-9f77-760819dfa155',
 'question': 'Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma. Please select one of the following questions:\n  A. Hyperplasia B. Hyperophy C. Atrophy D. Dyplasia\n  ',
 'opa': 'Hyperplasia',
 'opb': 'Hyperophy',
 'opc': 'Atrophy',
 'opd': 'Dyplasia',
 'cop': 2,
 'choice_type': 'single',
 'exp': 'Chronic urethral obstruction because of urinary calculi, prostatic hyperophy, tumors, normal pregnancy, tumors, uterine prolapse or functional disorders cause hydronephrosis which by definition is used to describe dilatation of renal pelvis and calculus associated with progressive atrophy of the kidney due to obstruction to the outflow of urine Refer Robbins 7yh/9,1012,9/e. P950',
 'subject_name': 'Anatomy',
 'topic_name': 'Urinary tract'}
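
The stored question keeps whatever indentation and line breaks the triple-quoted template contains. As a sketch only, and not the code used in this notebook, the same kind of prompt could be built line by line so that no stray whitespace ends up in the model input. The helper `build_prompt` and the exact wording of the prompt are assumptions for illustration:

```python
def build_prompt(example):
    """Join a question and its four options into one prompt string without extra indentation."""
    lines = [
        example["question"].strip() + ".",
        "",
        "Please select one of the following options:",
        "A. " + example["opa"],
        "B. " + example["opb"],
        "C. " + example["opc"],
        "D. " + example["opd"],
        "",
        "The answer is",
    ]
    return "\n".join(lines)

# Option values copied from the sample above; the question is shortened for readability
sample = {
    "question": "Chronic urethral obstruction can lead to the following change in kidney parenchyma",
    "opa": "Hyperplasia",
    "opb": "Hyperophy",
    "opc": "Atrophy",
    "opd": "Dyplasia",
}
print(build_prompt(sample))
```

A function like this could be passed to `raw_datasets.map` in the same way as `concat_questions_with_answers`, provided it writes the result back into the example and returns it.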

Explore data types of each feature

In [None]:
raw_datasets["train"].features

{'id': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None),
 'opa': Value(dtype='string', id=None),
 'opb': Value(dtype='string', id=None),
 'opc': Value(dtype='string', id=None),
 'opd': Value(dtype='string', id=None),
 'cop': ClassLabel(names=['a', 'b', 'c', 'd'], id=None),
 'choice_type': Value(dtype='string', id=None),
 'exp': Value(dtype='string', id=None),
 'subject_name': Value(dtype='string', id=None),
 'topic_name': Value(dtype='string', id=None)}

Choose the tokenizer that matches the model checkpoint

In [None]:
checkpoint = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenized_questions = tokenizer(raw_datasets["train"]["question"])

In [None]:
tokenized_questions

Tokenize the raw datasets

In [None]:
def tokenize_function(example):
  return tokenizer(example["question"], truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets

Pad each batch dynamically to the length of its longest sequence

In [None]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
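
To make the idea of dynamic padding concrete, here is a rough NumPy sketch of what happens to one batch. It is not the collator's actual implementation. The helper `pad_batch`, the token ids, and the padding id of 0 are all assumptions for illustration:

```python
import numpy as np

def pad_batch(sequences, pad_id=0):
    """Pad variable-length token id lists to the longest sequence in the batch."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    attention_mask = np.zeros((len(sequences), max_len), dtype=np.int64)
    for row, seq in enumerate(sequences):
        input_ids[row, :len(seq)] = seq       # real tokens
        attention_mask[row, :len(seq)] = 1    # 1 marks real tokens, 0 marks padding
    return input_ids, attention_mask

# Made-up token ids, for illustration only
batch = [[2, 15, 27, 3], [2, 8, 3]]
input_ids, attention_mask = pad_batch(batch)
print(input_ids)
print(attention_mask)
```

Because the padded length is chosen per batch, short batches stay short, which saves computation compared with padding every example to one global maximum length.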

Convert the tokenized train, validation and test splits into batched TensorFlow datasets

In [None]:
tf_train_dataset = tokenized_datasets["train"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["cop"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8
)

tf_validation_dataset = tokenized_datasets["validation"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["cop"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=8
)

tf_test_dataset = tokenized_datasets["test"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["cop"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=8
)
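
As a quick sanity check, the number of batches in each dataset follows from the split sizes and `batch_size=8`. The sketch below assumes the last, smaller batch is kept, which agrees with the step counts shown in the training and evaluation progress bars of this notebook:

```python
import math

batch_size = 8
# Split sizes as reported when the dataset was generated
split_sizes = {"train": 182822, "validation": 4183, "test": 6150}

# The last batch of each split may hold fewer than batch_size examples, hence the ceiling
batches_per_split = {name: math.ceil(n / batch_size) for name, n in split_sizes.items()}
print(batches_per_split)  # prints: {'train': 22853, 'validation': 523, 'test': 769}
```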

# For loading saved models

In [None]:
path = 'drive/MyDrive/medmcqa'
os.listdir(path)

['config.json',
 'tf_model.h5',
 'bert-cp',
 'cp-0001.ckpt.index',
 'cp-0001.ckpt.data-00000-of-00001',
 'checkpoint']

In [None]:
from transformers import TFAutoModelForSequenceClassification

saved_model = TFAutoModelForSequenceClassification.from_pretrained(path)
saved_model.summary()

KeyboardInterrupt: ignored

# Build training model

In [None]:
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=4, from_pt=True)

Configure a learning rate schedule that decays linearly from 5e-5 to 0 over the planned training steps

In [None]:
from tensorflow.keras.optimizers.schedules import PolynomialDecay

batch_size = 8
num_epochs = 3

num_train_steps = len(tf_train_dataset) * num_epochs
lr_scheduler = PolynomialDecay(
    initial_learning_rate=5e-5,
    end_learning_rate=0.0,
    decay_steps=num_train_steps
)

from tensorflow.keras.optimizers import Adam

opt = Adam(learning_rate=lr_scheduler)
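
For intuition, the schedule can be reproduced in plain Python. This is a sketch of the documented `PolynomialDecay` formula with its default `power=1.0`, which makes the decay linear. It assumes 22853 batches per epoch, as shown in the training progress bar, so `decay_steps` is 3 * 22853 = 68559:

```python
def polynomial_decay(step, initial_lr=5e-5, end_lr=0.0, decay_steps=68559, power=1.0):
    """Learning rate at a given training step under polynomial decay (linear when power=1)."""
    step = min(step, decay_steps)  # the rate stays at end_lr once decay_steps is reached
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr

for step in (0, 22853, 45706, 68559):
    print(step, polynomial_decay(step))
```

One consequence worth keeping in mind: because `decay_steps` is computed for three epochs, a single call to `fit` with `epochs=1` ends with the learning rate at about two thirds of its initial value rather than at zero.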

Configure persistent storage of model checkpoints on Google Drive

In [None]:
checkpoint_path = "drive/MyDrive/medmcqa/cp-{epoch:04d}.ckpt"

# len(tf_train_dataset) counts batches, not examples, so n_batches is about 1/8 of an epoch
n_batches = len(tf_train_dataset) / batch_size
n_batches = math.ceil(n_batches)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5*n_batches
)
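
When `save_freq` is an integer, `ModelCheckpoint` saves after that many batches. The arithmetic below shows what the configuration above works out to, assuming `len(tf_train_dataset)` is 22853 as in the training progress bar:

```python
import math

batch_size = 8
batches_per_epoch = 22853  # assumed value of len(tf_train_dataset), taken from the progress bar

n_batches = math.ceil(batches_per_epoch / batch_size)
save_freq = 5 * n_batches
fraction_of_epoch = round(save_freq / batches_per_epoch, 3)
print(n_batches, save_freq, fraction_of_epoch)  # prints: 2857 14285 0.625
```

So a checkpoint is written roughly every five eighths of an epoch. If the intent is one checkpoint per epoch, `save_freq="epoch"` is the documented alternative.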

Configure the model for training

In [None]:
from tensorflow.keras.losses import SparseCategoricalCrossentropy

model.compile(
    optimizer=opt,
    loss=SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [None]:
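# Load the pre-saved model. from_pretrained returns a new model, so assign it and compile it again.
model = TFAutoModelForSequenceClassification.from_pretrained("drive/MyDrive/medmcqa/bert-cp")
model.compile(optimizer=opt, loss=SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"])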

Some layers from the model checkpoint at drive/MyDrive/medmcqa/bert-cp were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at drive/MyDrive/medmcqa/bert-cp.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.


<transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification at 0x7bbe8cdfc070>

In [None]:
model.fit(
    tf_train_dataset,
    callbacks=[cp_callback],
    validation_data=tf_validation_dataset,
    epochs=1
)

Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.

    1/22853 [..............................] - ETA: 332:38:35 - loss: 1.3640 - accuracy: 0.3750
    2/22853 [..............................] - ETA: 67:27:52 - loss: 1.4252 - accuracy: 0.2500
    3/22853 [..............................] - ETA: 67:24:23 - loss: 1.4363 - accuracy: 0.2500
    6/22853 [..............................] - ETA: 50:57:34 - loss: 1.4308 - accuracy: 0.2083
    7/22853 [..............................] - ETA: 53:27:11 - loss: 1.4341 - accuracy: 0.1786
    8/22853 [..............................] - ETA: 57:11:15 - loss: 1.4356 - accuracy: 0.1875
   11/22853 [..............................] - ETA: 53:08:24 - loss: 1.4105 - accuracy: 0.2500
   12/22853 [..............................] - ETA: 54:24:18 - loss: 1.4060 - accuracy: 0.2604
   13/22853 [..............................] - ETA: 56:27:20 - loss: 1.3990 - accuracy: 0.2692
   16/22853 [..............................] - ETA: 54:00:46 - loss: 1.4139 - accuracy: 0.2656
   17/22853 [..............................] - ETA: 54:48:53 - loss: 1.4103 - accuracy: 0.2574
   18/22853 [..............................] - ETA: 56:09:20 - loss: 1.4060 - accuracy: 0.2708
   21/22853 [..............................] - ETA: 54:23:09 - loss: 1.3998 - accuracy: 0.2798
   22/22853 [..............................] - ETA: 54:55:32 - loss: 1.3969 - accuracy: 0.2841
   23/22853 [..............................] - ETA: 55:21:40 - loss: 1.3906 - accuracy: 0.2880
   24/22853 [..............................] - ETA: 55:48:24 - loss: 1.3882 - accuracy: 0.2812
   27/22853 [..............................] - ETA: 52:39:12 - loss: 1.3827 - accuracy: 0.2917
   28/22853 [..............................] - ETA: 53:06:18 - loss: 1.3946 - accuracy: 0.2857
   29/22853 [..............................] - ETA: 53:59:44 - loss: 1.3967 - accuracy: 0.2845
   32/22853 [..............................] - ETA: 53:05:02 - loss: 1.3999 - accuracy: 0.2773
   33/22853 [..............................] - ETA: 53:28:37 - loss: 1.4010 - accuracy: 0.2765
   34/22853 [..............................] - ETA: 53:52:46 - loss: 1.4009 - accuracy: 0.2757
   35/22853 [..............................] - ETA: 54:16:03 - loss: 1.3988 - accuracy: 0.2786
   38/22853 [..............................] - ETA: 52:00:13 - loss: 1.4020 - accuracy: 0.2632
   39/22853 [..............................] - ETA: 52:23:45 - loss: 1.3996 - accuracy: 0.2692
   40/22853 [..............................] - ETA: 52:44:01 - loss: 1.3989 - accuracy: 0.2719
   41/22853 [..............................] - ETA: 53:02:59 - loss: 1.3967 - accuracy: 0.2744
   45/22853 [..............................] - ETA: 50:07:41 - loss: 1.3991 - accuracy: 0.2722
   46/22853 [..............................] - ETA: 50:30:16 - loss: 1.3980 - accuracy: 0.2717
   47/22853 [..............................] - ETA: 51:06:01 - loss: 1.3976 - accuracy: 0.2713
   51/22853 [..............................] - ETA: 49:37:23 - loss: 1.3959 - accuracy: 0.2745
   52/22853 [..............................] - ETA: 49:54:40 - loss: 1.3957 - accuracy: 0.2764
   53/22853 [..............................] - ETA: 50:11:38 - loss: 1.3944 - accuracy: 0.2759
   54/22853 [..............................] - ETA: 50:30:35 - loss: 1.3928 - accuracy: 0.2824
   57/22853 [..............................] - ETA: 49:16:22 - loss: 1.3937 - accuracy: 0.2763
   58/22853 [..............................] - ETA: 49:34:35 - loss: 1.3901 - accuracy: 0.2802
   59/22853 [..............................] - ETA: 50:02:59 - loss: 1.3895 - accuracy: 0.2775
   63/22853 [..............................] - ETA: 48:57:27 - loss: 1.3873 - accuracy: 0.2778
   64/22853 [..............................] - ETA: 49:13:04 - loss: 1.3924 - accuracy: 0.2734
   65/22853 [..............................] - ETA: 49:28:11 - loss: 1.3922 - accuracy: 0.2712
   66/22853 [..............................] - ETA: 49:43:59 - loss: 1.3927 - accuracy: 0.2708
   70/22853 [..............................] - ETA: 48:01:44 - loss: 1.3976 - accuracy: 0.2732
   71/22853 [..............................] - ETA: 48:18:08 - loss: 1.3954 - accuracy: 0.2746
   72/22853 [..............................] - ETA: 48:44:39 - loss: 1.3938 - accuracy: 0.2778
   76/22853 [..............................] - ETA: 47:54:18 - loss: 1.3957 - accuracy: 0.2747
   77/22853 [..............................] - ETA: 48:07:10 - loss: 1.3965 - accuracy: 0.2744
   78/22853 [..............................] - ETA: 48:20:36 - loss: 1.3961 - accuracy: 0.2740
   79/22853 [..............................] - ETA: 48:32:47 - loss: 1.3968 - accuracy: 0.2722
   82/22853 [..............................] - ETA: 47:45:17 - loss: 1.3964 - accuracy: 0.2759
   83/22853 [..............................] - ETA: 47:57:57 - loss: 1.3968 - accuracy: 0.2756
   84/22853 [..............................] - ETA: 48:10:23 - loss: 1.3958 - accuracy: 0.2768
   85/22853 [..............................] - ETA: 48:22:03 - loss: 1.3942 - accuracy: 0.2794
   89/22853 [..............................] - ETA: 47:07:00 - loss: 1.3926 - accuracy: 0.2809
   90/22853 [..............................] - ETA: 47:19:22 - loss: 1.3911 - accuracy: 0.2847
   91/22853 [..............................] - ETA: 47:40:13 - loss: 1.3911 - accuracy: 0.2857
   97/22853 [..............................] - ETA: 46:03:07 - loss: 1.3949 - accuracy: 0.2771
   98/22853 [..............................] - ETA: 46:15:05 - loss: 1.3945 - accuracy: 0.2768
   99/22853 [..............................] - ETA: 46:26:57 - loss: 1.3945 - accuracy: 0.2778
  100/22853 [..............................] - ETA: 46:38:00 - loss: 1.3962 - accuracy: 0.2763
  104/22853 [..............................] - ETA: 45:37:32 - loss: 1.3925 - accuracy: 0.2812
  105/22853 [..............................] - ETA: 45:48:12 - loss: 1.3931 - accuracy: 0.2810
  106/22853 [..............................] - ETA: 45:59:22 - loss: 1.3931 - accuracy: 0.2818
  107/22853 [..............................] - ETA: 46:10:20 - loss: 1.3935 - accuracy: 0.2804
  114/22853 [..............................] - ETA: 44:00:31 - loss: 1.3928 - accuracy: 0.2840
  115/22853 [..............................] - ETA: 44:12:25 - loss: 1.3929 - accuracy: 0.2848
  116/22853 [..............................] - ETA: 44:22:50 - loss: 1.3929 - accuracy: 0.2845
  117/22853 [..............................] - ETA: 44:33:06 - loss: 1.3925 - accuracy: 0.2853
  124/22853 [..............................] - ETA: 42:41:06 - loss: 1.3895 - accuracy: 0.2893
  125/22853 [..............................] - ETA: 42:51:54 - loss: 1.3901 - accuracy: 0.2900
  126/22853 [..............................] - ETA: 43:02:18 - loss: 1.3911 - accuracy: 0.2887
  127/22853 [..............................] - ETA: 43:12:23 - loss: 1.3920 - accuracy: 0.2884
  132/22853 [..............................] - ETA: 42:11:45 - loss: 1.3930 - accuracy: 0.2860
  133/22853 [..............................] - ETA: 42:21:38 - loss: 1.3938 - accuracy: 0.2857
  134/22853 [..............................] - ETA: 42:31:21 - loss: 1.3942 - accuracy: 0.2864
  135/22853 [..............................] - ETA: 42:41:40 - loss: 1.3936 - accuracy: 0.2880
  138/22853 [..............................] - ETA: 42:21:18 - loss: 1.3941 - accuracy: 0.2853
  139/22853 [..............................] - ETA: 42:30:46 - loss: 1.3934 - accuracy: 0.2869
  140/22853 [..............................] - ETA: 42:40:04 - loss: 1.3932 - accuracy: 0.2875
  141/22853 [..............................] - ETA: 42:49:10 - loss: 1.3926 - accuracy: 0.2881

Epoch 1: saving model to drive/MyDrive/medmcqa/cp-0001.ckpt


<keras.src.callbacks.History at 0x79bb35443250>

Save the model

In [None]:
model.save_pretrained("drive/MyDrive/medmcqa/bert-cp")

# Evaluate the model on test dataset

In [None]:
loss, acc = model.evaluate(tf_test_dataset)
print("Model accuracy on test dataset: {:5.2f}%".format(acc * 100))

 93/769 [==>...........................] - ETA: 34s - loss: nan - accuracy: 0.0000e+00

KeyboardInterrupt: ignored
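
A `nan` loss together with zero accuracy can appear when some labels fall outside the range the loss expects (0 to 3 here). TensorFlow documents that sparse cross-entropy returns NaN for out-of-range labels on GPU. One thing worth checking, as an assumption rather than a confirmed diagnosis, is whether the test split carries placeholder labels such as -1. The helper `invalid_label_indices` and the labels below are made up for illustration:

```python
import numpy as np

def invalid_label_indices(labels, num_classes=4):
    """Return the positions of labels that are not valid class indices."""
    labels = np.asarray(labels)
    return np.flatnonzero((labels < 0) | (labels >= num_classes))

# Made-up labels for illustration; in the notebook one could pass raw_datasets["test"]["cop"]
example_labels = [2, 0, -1, 3, -1]
print(invalid_label_indices(example_labels))  # prints: [2 4]
```

If the test labels do turn out to be unusable, accuracy would have to be measured on the validation split instead.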

In [None]:
preds = model.predict(tf_validation_dataset)["logits"]
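
The logits have one row per question and one column per option, so the predicted option is the column with the largest value. A minimal sketch of turning logits into option letters and an accuracy figure, using made-up logits and labels (the helper names are hypothetical):

```python
import numpy as np

OPTION_LETTERS = np.array(["a", "b", "c", "d"])

def logits_to_predictions(logits):
    """Pick the highest-scoring class per row and map it to its option letter."""
    class_ids = np.argmax(logits, axis=-1)
    return class_ids, OPTION_LETTERS[class_ids]

def accuracy(class_ids, labels):
    """Fraction of predictions that match the integer labels."""
    return float(np.mean(class_ids == np.asarray(labels)))

# Made-up logits and labels, for illustration only
example_logits = np.array([
    [0.1, 0.2, 1.5, -0.3],
    [2.0, 0.1, 0.0, 0.4],
    [0.0, 0.9, 0.3, 0.2],
])
example_labels = [2, 0, 3]

class_ids, letters = logits_to_predictions(example_logits)
print(letters, accuracy(class_ids, example_labels))
```

In the notebook, `preds` would take the place of `example_logits` and the validation `cop` column would take the place of `example_labels`.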