# Toy Medical Question Answering GPT

In this exercise, we will use the minGPT library to train a basic medical question-answering (Q&A) model from scratch. Our goal is a model that answers medical questions, trained on the MedQA dataset. The resulting model will not come close to the accuracy required for clinical use, but the exercise offers hands-on experience with the training process for Q&A models and with the handling of domain-specific data.

In [None]:
from datasets import load_dataset, load_dataset_builder, get_dataset_split_names
from torch.utils.data import Dataset

from mingpt.bpe import BPETokenizer
from mingpt.model import GPT
from mingpt.trainer import Trainer

import torch

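# Use the first GPU when one is available, otherwise fall back to the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'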

## 1. MedQA Dataset

The MedQA dataset, introduced by Jin et al. (2021), is a large-scale dataset for medical question answering. It consists of multiple-choice questions collected from professional medical board exams, each paired with its correct answer. The questions cover a broad range of medical knowledge, which makes the dataset a useful resource for training and evaluating models in the medical domain, and its structured question-and-answer format is convenient for building training examples.

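Jin, D., Pan, E., Oufattole, N., Weng, W.-H., Fang, H., & Szolovits, P. (2021). What Disease Does This Patient Have? A Large-Scale Open Domain Question Answering Dataset from Medical Exams. *Applied Sciences*, 11(14), 6421. [Link to Paper](https://arxiv.org/abs/2009.13081)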

In [None]:
# We will use Huggingface datasets to get our working version

DATASET_NAME = "bigbio/med_qa"
DATASET_CONFIG = "med_qa_en_source"
ds_builder = load_dataset_builder(DATASET_NAME, DATASET_CONFIG)

In [None]:
# Print the summary 
print(ds_builder.info.description)

In [None]:
train_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split='train')

Let's look at an example from the dataset.

In [None]:
idx = 1

sample = train_ds[idx]
print(f"""Sample {idx}:

Question:
{sample['question']}

Answer:
{sample['answer']}
""")

## 2. Preprocessing

To feed text into the GPT model, we must first tokenize it, that is, convert the text into a sequence of integers.

In [None]:
# Create the tokenizer

bpe_tokenizer = BPETokenizer()

In [None]:
# Tokenize the sample question

bpe_tokenizer(train_ds[idx]['question'])

In [None]:
# Let's look at how the token IDs and the text pieces map to each other
results = bpe_tokenizer.encoder.encode_and_show_work(train_ds[0]['question'])
# print(train_ds[0]['question'])
print("ID\t|\tTOKEN")
print("------------------")
for token, bpe_id in zip(results['tokens'], results['bpe_idx']):
    print(f"{bpe_id}\t|\t{token}")
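
The tokens printed above are produced by byte-pair encoding (BPE), which builds its vocabulary by repeatedly merging the most frequent pair of adjacent symbols. The following is a minimal sketch of that merge loop on a toy string. It is not minGPT's implementation (GPT-2's tokenizer works on bytes and ships with a pre-learned merge table), and the helper names `most_frequent_pair` and `merge_pair` are made up for this illustration.

```python
from collections import Counter

def most_frequent_pair(tokens):
    """Return the adjacent pair of symbols that occurs most often."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return pairs.most_common(1)[0][0]

def merge_pair(tokens, pair):
    """Replace every non-overlapping occurrence of `pair` with one merged symbol."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

text = "aaabdaaabac"
tokens = list(text)  # start from single characters (hypothetical toy input)
for _ in range(3):   # three merge rounds, an arbitrary choice for the demo
    tokens = merge_pair(tokens, most_frequent_pair(tokens))

print(tokens)
```

Each merge round shortens the sequence while leaving the underlying text unchanged, which is why frequent fragments end up as single tokens.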

In [None]:
# We will use this function to turn each sample into a single training string
# and encode it as token IDs. Our model will learn to predict the most likely
# answer string conditioned on the input question.

def encode_examples(example):
    # No indentation inside the string, so that it matches the
    # "\n\nAnswer:" prompt format used at generation time
    training_sentence = f"{example['question']}\n\nAnswer: {example['answer']}\n"
    return bpe_tokenizer(training_sentence)[0]

In [None]:
# Our samples have different lengths.
# For simplicity, we keep only samples that are at least 129 tokens long
# and truncate each one to its last 129 tokens, so the answer is always retained.
# 129 tokens give 128 input/target pairs once the sequence is shifted by one.

tokenized_examples = [encode_examples(ex) for ex in train_ds]

tokenized_train = [ex[-129:] for ex in tokenized_examples if len(ex) >= 129]

In [None]:
# This is an idiomatic torch map style dataset wrapper around our data

class SimpleMedQADataset(Dataset):
    def __init__(self, tokenized_examples):
        self.tokenized_examples = tokenized_examples
        
    def __len__(self):
        return len(self.tokenized_examples)
    
    def __getitem__(self, idx):
        return self.tokenized_examples[idx][:-1], self.tokenized_examples[idx][1:]
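
To make the slicing in `__getitem__` concrete, here is a small sketch on plain Python lists rather than tensors. The function `make_next_token_pair` and the toy IDs are hypothetical and only mirror the truncate-then-shift logic used in this notebook: the inputs are every token except the last, and the targets are the same sequence shifted one position to the left.

```python
def make_next_token_pair(token_ids, window):
    """Keep the last `window` tokens and split them into (inputs, targets)."""
    if len(token_ids) < window:
        return None  # too short: the notebook drops such examples
    clipped = token_ids[-window:]
    # The target at position i is the token that follows the input at position i
    return clipped[:-1], clipped[1:]

toy_ids = [10, 11, 12, 13, 14, 15, 16]
toy_inputs, toy_targets = make_next_token_pair(toy_ids, window=5)
print(toy_inputs)   # [12, 13, 14, 15]
print(toy_targets)  # [13, 14, 15, 16]

too_short = make_next_token_pair([1, 2, 3], window=5)
print(too_short)    # None
```

With `window=129`, as in the notebook, each example yields 128 input tokens and 128 targets.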

In [None]:
train_dataset = SimpleMedQADataset(tokenized_train)

## 3. Model

In [None]:
# Just like in the last exercise, we first put our model's hyperparameters
# in a config - we will train the smallest GPT-2 variant (about 124M parameters) from scratch

model_config = GPT.get_default_config()
model_config.model_type = 'gpt2'
model_config.vocab_size = 50257
model_config.block_size = 256
model = GPT(model_config)
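
As a rough sanity check on the model size, the parameter count can be estimated by hand. This sketch assumes the `'gpt2'` preset corresponds to 12 layers with an embedding width of 768 and the standard GPT-2 block layout (attention plus a 4x-wide MLP, all with biases, and two LayerNorms per block). The function `count_gpt_params` is written for this illustration, and the count leaves out the output projection to the vocabulary.

```python
def count_gpt_params(n_layer, n_embd, vocab_size, block_size):
    """Estimate the parameter count of a GPT-2 style transformer (no output head)."""
    d = n_embd
    embeddings = vocab_size * d + block_size * d   # token + position embeddings
    attention = (d * 3 * d + 3 * d) + (d * d + d)  # qkv projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)    # expand to 4d, then project back to d
    layer_norms = 2 * (2 * d)                      # two LayerNorms per block (weight + bias)
    per_block = attention + mlp + layer_norms
    final_norm = 2 * d                             # LayerNorm after the last block
    return embeddings + n_layer * per_block + final_norm

# n_layer and n_embd are assumptions about the 'gpt2' preset;
# vocab_size and block_size are the values set in the config above
n_params = count_gpt_params(n_layer=12, n_embd=768, vocab_size=50257, block_size=256)
print(f"{n_params / 1e6:.2f}M parameters (excluding the output head)")
```

Under these assumptions, the token embedding table alone accounts for roughly a third of the total, which is typical for small GPT-2 style models with a large vocabulary.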

In [None]:
# Here we will use the built-in Trainer class of the minGPT library to encapsulate our training loop

train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # a fairly high learning rate, chosen for a short demo run
train_config.max_iters = 2000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

In [None]:
# Here we provide a callback to occasionally log our training progress to output

def batch_end_callback(trainer):
    if trainer.iter_num % 10 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

# Add the callback to be called at the end of each batch        
trainer.set_callback('on_batch_end', batch_end_callback)

# Here we start the trainer
trainer.run()

In [None]:
# This will take a while to get to reasonable loss - we will come back after a while to check in on this
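
One way to judge whether the loss is "reasonable" is to compare it with the loss of a model that guesses uniformly over the vocabulary. For cross-entropy measured in nats, that baseline is the natural log of the vocabulary size, so a freshly initialized model should start close to it. The sketch below is a worked calculation, not part of the training code, and the loss value of 4.0 is an arbitrary illustrative number rather than a result from this run.

```python
import math

vocab_size = 50257  # same value as in the model config

# Cross-entropy of a uniform guess: -log(1 / vocab_size) = log(vocab_size)
uniform_loss = math.log(vocab_size)
print(f"loss of a uniform guess: {uniform_loss:.2f}")

# Perplexity is exp(loss): the effective number of tokens
# the model is choosing between at each step
example_loss = 4.0  # illustrative value only
print(f"perplexity at loss {example_loss}: {math.exp(example_loss):.1f}")
```

A training loss well below the uniform baseline indicates that the model has at least picked up token frequencies and common phrases.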

## 4. Check Model Generation Quality

**DISCLAIMER:** We have trained a small, toy model from scratch(!) on a relatively small dataset (9K examples) so DO NOT use this model for any diagnostic purposes!  This exercise is only intended to demonstrate how to use PyTorch directly for training a model from scratch.

Given the small model and dataset, we should not expect a model with particularly strong performance on this complex, knowledge-intensive clinical reasoning dataset.

What is reasonable to expect, though, is "medically flavored" nonsense. We should see strings of words that look vaguely medical or clinical, and certain familiar phrases may appear. However, there is very little in the way of a logical thought process.

Larger models with more training data accomplish this mimicry with much greater fidelity. At the end of the day, however, they share the same fundamental architecture as this model and most of the same training recipe.

In [None]:
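idx = 200

# Prompt the model with the question followed by the answer cue
inputs = bpe_tokenizer(train_ds[idx]['question'] + "\n\nAnswer:").to(device)
model.eval()
outputs = model.generate(inputs, max_new_tokens=20, top_k=100)

# Token 198 is the newline character in the GPT-2 BPE vocabulary.
# Find the first newline among the generated tokens, which marks the end of the answer line.
new_tokens = outputs[0][inputs.shape[1]:].tolist()
try:
    offset = new_tokens.index(198)
except ValueError:
    # No newline was generated: keep all new tokens
    offset = len(new_tokens)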

In [None]:
print(train_ds[idx]['question'])
# Decode only the newly generated tokens, cut off at `offset`
generated = outputs[0][len(inputs[0]):]
bpe_tokenizer.decode(generated[:offset])
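
The `top_k` argument passed to `generate` restricts each step to the highest-scoring candidate tokens. The sketch below shows the idea in pure Python on a toy list of logits. It is an illustration under simplifying assumptions rather than minGPT's code, and `sample_top_k` and the toy values are hypothetical.

```python
import math
import random

def sample_top_k(logits, k, rng):
    """Sample a token index from the k highest-scoring entries of `logits`."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the surviving logits only (subtract the max for stability)
    peak = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - peak) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]

toy_logits = [2.0, 0.5, -1.0, 3.0, 0.0]  # one score per token in a 5-token vocabulary
rng = random.Random(0)                   # fixed seed so the demo is repeatable

samples = [sample_top_k(toy_logits, k=2, rng=rng) for _ in range(100)]
greedy = max(range(len(toy_logits)), key=lambda i: toy_logits[i])

print(sorted(set(samples)))  # only indices from the top-2 set can appear
print(greedy)                # greedy decoding always picks the single best index
```

As far as I can tell, minGPT's `generate` picks the most likely token at each step unless `do_sample=True` is passed, so with the call above `top_k` narrows the candidates but the output is still deterministic. Passing `do_sample=True` would be the way to get varied generations.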

In [None]:
train_ds[idx]['answer']
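
Comparing the generated text with the reference answer by eye works for a handful of samples. For a slightly more systematic check, a normalized exact-match comparison is a simple option. The helpers below are a sketch written for this notebook (`normalize_answer` and `exact_match` are not part of minGPT or the dataset library), and the example strings are made up.

```python
import re
import string

def normalize_answer(text):
    """Lowercase, drop punctuation, and collapse whitespace."""
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    return re.sub(r"\s+", " ", text).strip()

def exact_match(prediction, reference):
    """True when both strings are identical after normalization."""
    return normalize_answer(prediction) == normalize_answer(reference)

# Made-up strings, for illustration only
print(exact_match("  Nitrofurantoin.\n", "nitrofurantoin"))  # True
print(exact_match("Ampicillin", "Nitrofurantoin"))           # False
```

For a toy model like this one, the match rate should be expected to be very low, so reading a few generations side by side with their references remains the more informative check.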