In [1]:
import pandas as pd
from random import randint
import json
import os

In [2]:
# define dataframe
data_path = "./data/input/intern_screening_dataset.csv"

df = pd.read_csv(data_path)

## Inspect data

In [3]:
# df shape
df.shape

(16406, 2)

In [4]:
# inspect beginning of data
df.head()

,question,answer
0,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...
1,What is (are) Glaucoma ?,The optic nerve is a bundle of more than 1 mil...
2,What is (are) Glaucoma ?,Open-angle glaucoma is the most common form of...
3,Who is at risk for Glaucoma? ?,Anyone can develop glaucoma. Some people are a...
4,How to prevent Glaucoma ?,"At this time, we do not know how to prevent gl..."


In [5]:
# inspect random data span
span = 10

start = randint(0, len(df.index) - span)
df.iloc[start:start + span]

,question,answer
6137,What are the treatments for Congenital contrac...,How might congenital contractural arachnodacty...
6138,What is (are) Geographic tongue ?,Geographic tongue is a condition that causes c...
6139,What are the symptoms of Geographic tongue ?,What are the signs and symptoms of Geographic ...
6140,What causes Geographic tongue ?,What causes geographic tongue? Is it genetic? ...
6141,What are the treatments for Geographic tongue ?,What treatment is available for geographic ton...
6142,What is (are) Synovial Chondromatosis ?,Synovial chondromatosis is a type of non-cance...
6143,What causes Synovial Chondromatosis ?,What causes synovial chondromatosis? The exact...
6144,What are the symptoms of Mucopolysaccharidosis...,What are the signs and symptoms of Mucopolysac...
6145,What is (are) Congenital laryngeal palsy ?,Congenital laryngeal palsy is also known as co...
6146,What are the symptoms of Congenital laryngeal ...,What are the signs and symptoms associated wit...
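
`randint` from the `random` module is unseeded in the cell above, so every run shows a different span. A minimal sketch of a reproducible variant, assuming a dedicated `random.Random` instance with an arbitrary illustrative seed (42); `random_span` and the toy frame are hypothetical stand-ins for the notebook's `df`:

```python
import random

import pandas as pd

# hypothetical toy frame standing in for df
toy = pd.DataFrame({"question": [f"question {i}" for i in range(20)]})
span = 5

def random_span(frame, span, seed):
    """Return `span` contiguous rows starting at a seeded random offset."""
    rng = random.Random(seed)  # independent generator: same seed -> same start
    start = rng.randint(0, len(frame.index) - span)  # randint includes both ends
    return frame.iloc[start:start + span]

window = random_span(toy, span, seed=42)
print(window)
```

Using a separate `Random` instance keeps the seed local to this inspection instead of reseeding the global generator.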


In [6]:
# inspect specific row
selected_row = 0

row = df.iloc[selected_row]
for key, value in row.items():
    print(f"{key}: {value}")

question: What is (are) Glaucoma ?
answer: Glaucoma is a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. The most common form of the disease is open-angle glaucoma. With early treatment, you can often protect your eyes against serious vision loss. (Watch the video to learn more about glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.)  See this graphic for a quick overview of glaucoma, including how many people it affects, whos at risk, what to do if you have it, and how to learn more.  See a glossary of glaucoma terms.


In [7]:
# inspect end of data
df.tail()

,question,answer
16401,What is (are) Diabetic Neuropathies: The Nerve...,Autonomic neuropathy affects the nerves that c...
16402,What is (are) Diabetic Neuropathies: The Nerve...,"Proximal neuropathy, sometimes called lumbosac..."
16403,What is (are) Diabetic Neuropathies: The Nerve...,Focal neuropathy appears suddenly and affects ...
16404,How to prevent Diabetic Neuropathies: The Nerv...,The best way to prevent neuropathy is to keep ...
16405,How to diagnose Diabetic Neuropathies: The Ner...,Doctors diagnose neuropathy on the basis of sy...


In [8]:
# inspect data types and non-null counts
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 16406 entries, 0 to 16405
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   question  16406 non-null  object
 1   answer    16401 non-null  object
dtypes: object(2)
memory usage: 256.5+ KB
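
`df.info()` reports 16406 non-null questions but only 16401 non-null answers, so 5 answers are missing. A short sketch for locating such rows before they are dropped; the toy frame is hypothetical and stands in for `df`:

```python
import pandas as pd

# hypothetical toy frame mirroring the question/answer layout
toy = pd.DataFrame({
    "question": ["What is A ?", "What causes A ?", "What is B ?"],
    "answer": ["A is a condition.", None, "B is a condition."],
})

# per-column count of missing values
missing_per_column = toy.isna().sum()

# rows that a later dropna() call would remove
rows_with_missing = toy[toy.isna().any(axis=1)]

print(missing_per_column.to_dict())
print(rows_with_missing)
```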


In [9]:
# max question length
max_q_len = df["question"].str.len().max()
print(f"Max question length: {int(max_q_len)}")

# max answer length
max_a_len = df["answer"].str.len().max()
print(f"Max answer length: {int(max_a_len)}")

Max question length: 191
Max answer length: 29046
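
`Series.describe()` returns the maximum together with the quartiles computed in the next cells in a single call; for a numeric Series its 25%/50%/75% rows use the same linear interpolation as `Series.quantile`. A sketch on a few questions taken from the dataset preview, used here as a hypothetical stand-in for `df["question"]`:

```python
import pandas as pd

# a few questions from the preview above, standing in for df["question"]
questions = pd.Series([
    "What is (are) Glaucoma ?",
    "Who is at risk for Glaucoma? ?",
    "How to prevent Glaucoma ?",
    "What causes Geographic tongue ?",
])

lengths = questions.str.len()
stats = lengths.describe()  # count, mean, std, min, 25%, 50%, 75%, max
print(stats)
```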


In [10]:
# compute quartile boundaries of string lengths and the number of rows in each quartile range
quartile_counts = {}
quartile_lengths = {}

for col in df.columns:
    # length of each string in the column
    lengths = df[col].str.len()
    
    # compute quartiles
    Q1, Q2, Q3 = lengths.quantile([0.25, 0.50, 0.75])
    
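    # count rows in each half-open range [min, Q1), [Q1, Q2), [Q2, Q3), [Q3, max];
    # NaN lengths (missing values) compare as False and fall into no range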
    count_less_Q1 = (lengths < Q1).sum()
    count_Q1_Q2 = ((lengths >= Q1) & (lengths < Q2)).sum()
    count_Q2_Q3 = ((lengths >= Q2) & (lengths < Q3)).sum()
    count_greater_Q3 = (lengths >= Q3).sum()
    
    quartile_counts[col] = {
        'less_than_Q1': count_less_Q1,
        'Q1_to_Q2': count_Q1_Q2,
        'Q2_to_Q3': count_Q2_Q3,
        'greater_than_Q3': count_greater_Q3
    }

    quartile_lengths[col] = {
        'Q1': Q1,
        'Q2': Q2,
        'Q3': Q3
    }


In [11]:
# show quartile counts
for col in quartile_counts.keys():
    print(f"Counts for column: {col}")
    print(quartile_counts[col])

Counts for column: question
{'less_than_Q1': 3784, 'Q1_to_Q2': 4103, 'Q2_to_Q3': 4417, 'greater_than_Q3': 4102}
Counts for column: answer
{'less_than_Q1': 4097, 'Q1_to_Q2': 4095, 'Q2_to_Q3': 4108, 'greater_than_Q3': 4101}


In [12]:
# show quartile split lengths
for col in quartile_lengths.keys():
    print(f"Quartile lengths for column: {col}")
    print(quartile_lengths[col])

Quartile lengths for column: question
{'Q1': 38.0, 'Q2': 48.0, 'Q3': 60.75}
Quartile lengths for column: answer
{'Q1': 487.0, 'Q2': 889.0, 'Q3': 1588.0}
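
The four groups should each hold roughly 25% of the rows (about 4101 of 16406), yet only 3784 questions are shorter than Q1 = 38. This is most likely caused by ties: many questions share the boundary length, and every one of them lands in the bin that starts at Q1. The same half-open bins can be built in one step with `pd.cut(..., right=False)`; a sketch on hypothetical toy lengths that also contain a tie at Q1:

```python
import numpy as np
import pandas as pd

# hypothetical toy strings (lengths 3, 5, 5, 7, 8, 10, 12, 15) plus one missing value
texts = pd.Series(["a" * n for n in [3, 5, 5, 7, 8, 10, 12, 15]] + [None])
lengths = texts.str.len()  # float Series, NaN for the missing value

q1, q2, q3 = lengths.quantile([0.25, 0.50, 0.75])

# right=False makes every bin half-open [left, right), matching the notebook's
# "< Q1", ">= Q1 & < Q2", ">= Q2 & < Q3" and ">= Q3" conditions
bins = [-np.inf, q1, q2, q3, np.inf]
labels = ["less_than_Q1", "Q1_to_Q2", "Q2_to_Q3", "greater_than_Q3"]
binned = pd.cut(lengths, bins=bins, labels=labels, right=False)

# value_counts drops NaN by default; sort=False keeps the bin order
counts = binned.value_counts(sort=False).to_dict()
print({"Q1": q1, "Q2": q2, "Q3": q3})
print(counts)
```

Here both strings of length 5 equal Q1, so only one value falls below Q1, mirroring the skew seen in the question counts.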


## Data preprocessing

In [13]:
# remove rows with missing answers, then rows whose question already appeared (keep the first answer)
df = df.dropna()
df = df.drop_duplicates(subset=['question'], keep='first')

df.shape

(14976, 2)
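
Deduplicating on `question` alone removes 16401 - 14976 = 1425 rows whose question already appeared, even when their answers differ (the first three rows of the dataset all ask "What is (are) Glaucoma ?" with different answers). If those answers should be kept, one alternative, not used in this notebook, is to join them per question. A sketch on a hypothetical toy frame:

```python
import pandas as pd

# hypothetical toy frame: one repeated question and one missing answer
toy = pd.DataFrame({
    "question": ["What is X ?", "What is X ?", "What causes Y ?", "How to treat Z ?"],
    "answer": ["X is a condition.", "X affects the eyes.", "Y has several causes.", None],
})

clean = toy.dropna()

# rows that drop_duplicates(subset=["question"], keep="first") would discard
n_duplicates = int(clean.duplicated(subset=["question"], keep="first").sum())

# the notebook's approach: keep only the first answer per question
first_only = clean.drop_duplicates(subset=["question"], keep="first")

# alternative: concatenate all answers of a question into one target text
merged = (
    clean.groupby("question", sort=False)["answer"]
    .agg(" ".join)
    .reset_index()
)
print(n_duplicates)
print(merged)
```

Joined answers are longer, so with the 128-token label limit used later in the notebook much of the extra text would be truncated anyway.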

## Fine-tuning

In [14]:
from datasets import Dataset

In [15]:
# convert the DataFrame into a Hugging Face Dataset and split it 80/20 into train/test (seeded)
dataset = Dataset.from_pandas(df)
dataset = dataset.train_test_split(test_size=0.2, seed=0)

dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', '__index_level_0__'],
        num_rows: 11980
    })
    test: Dataset({
        features: ['question', 'answer', '__index_level_0__'],
        num_rows: 2996
    })
})
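
The split sizes follow from `test_size=0.2`: the test share is rounded up and the remaining rows go to training. The arithmetic below reproduces the reported row counts; the round-up rule is inferred from these numbers (it matches scikit-learn's convention) rather than quoted from the `datasets` source.

```python
import math

n_rows = 14976   # rows left after dropna() and drop_duplicates()
test_size = 0.2

n_test = math.ceil(test_size * n_rows)  # 0.2 * 14976 = 2995.2, rounded up
n_train = n_rows - n_test               # remaining rows
print(n_train, n_test)
```

The extra `__index_level_0__` feature appears because the filtered DataFrame keeps its original, now non-contiguous, index. Calling `df.reset_index(drop=True)` before `Dataset.from_pandas`, or passing `preserve_index=False`, would avoid storing it.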


In [16]:
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorForSeq2Seq

In [17]:
# define model, tokenizer and data collator
model_name = "google/flan-t5-base"

tokenizer = T5TokenizerFast.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to("cuda")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [18]:
# define the T5 task prefix prepended to every question
prefix = "Please answer this medical question: "

# define max question (source) and answer (target) lengths in tokens; longer sequences are truncated
max_source_tokens = 128
max_sample_tokens = 128

def preprocess_function(examples):
    """Tokenize prefixed questions as model inputs and answers as labels for T5."""

    inputs = [prefix + example for example in examples["question"]]
    model_inputs = tokenizer(inputs, max_length=max_source_tokens, truncation=True)

    labels = tokenizer(text_target=examples["answer"], max_length=max_sample_tokens, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
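
`preprocess_function` returns unpadded, variable-length token lists; padding happens per batch in `DataCollatorForSeq2Seq`. With its defaults it pads inputs to the longest sequence in the batch and pads `labels` with -100 so padded positions are ignored by the loss; because a model is passed, it also builds `decoder_input_ids` from the labels, which this sketch omits. A minimal pure-Python sketch of the padding step; `pad_batch` is a hypothetical helper, the token IDs are made up, 0 is T5's pad ID and 1 plays the role of T5's end-of-sequence ID:

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad input_ids, attention_mask and labels to the longest sequence in the batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 = real token, 0 = padding the encoder should not attend to
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        # -100 marks label positions that the cross-entropy loss ignores
        n_label_pad = max_label - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch

# hypothetical tokenized examples, as preprocess_function would return them
features = [
    {"input_ids": [10, 11, 12, 1], "attention_mask": [1, 1, 1, 1], "labels": [20, 21, 1]},
    {"input_ids": [13, 1], "attention_mask": [1, 1], "labels": [22, 23, 24, 25, 1]},
]
batch = pad_batch(features)
print(batch)
```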

In [19]:
# preprocess dataset
tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["question", "answer", "__index_level_0__"])


In [20]:
# define training arguments
training_args = TrainingArguments(
    output_dir="./data/results",
    eval_strategy="epoch",  # older transformers releases call this evaluation_strategy
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=2,
    push_to_hub=False
)



In [21]:
# define trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    data_collator=data_collator
)



In [22]:
# train model
metrics = trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,1.8393,1.635396
2,1.7365,1.572415
3,1.6842,1.556254
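
Treating the logged validation loss as a mean per-token cross-entropy in nats (what `T5ForConditionalGeneration` computes when `labels` are given), each value can be read as a perplexity via exp(loss). A small sketch over the validation losses in the table above:

```python
import math

# validation losses copied from the training log above, per epoch
val_losses = {1: 1.635396, 2: 1.572415, 3: 1.556254}

# perplexity = exp(mean cross-entropy per token)
perplexities = {epoch: math.exp(loss) for epoch, loss in val_losses.items()}
for epoch, ppl in perplexities.items():
    print(f"epoch {epoch}: perplexity {ppl:.2f}")
```

Validation perplexity falls from about 5.1 after the first epoch to about 4.7 after the third.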


In [23]:
# store experiment results
experiment = {
    "model_name": model_name,
    "dataset_shape": dataset.shape,
    "max_source_tokens": max_source_tokens,
    "max_sample_tokens": max_sample_tokens,
    "metrics": metrics,
}

experiment

{'model_name': 'google/flan-t5-base',
 'dataset_shape': {'train': (11980, 3), 'test': (2996, 3)},
 'max_source_tokens': 128,
 'max_sample_tokens': 128,
 'metrics': TrainOutput(global_step=4494, training_loss=1.8207663997417565, metrics={'train_runtime': 1645.8838, 'train_samples_per_second': 21.836, 'train_steps_per_second': 2.73, 'total_flos': 1476379649212416.0, 'train_loss': 1.8207663997417565, 'epoch': 3.0})}
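
`global_step=4494` follows from the training arguments: with 11980 training rows and `per_device_train_batch_size=8`, each epoch has ceil(11980 / 8) = 1498 steps (the last batch is partial), and 3 epochs give 4494 steps. The sketch assumes a single device and no gradient accumulation, which matches the arguments used here, and also checks the reported throughput:

```python
import math

n_train = 11980          # rows in the training split
batch_size = 8           # per_device_train_batch_size, single device assumed
epochs = 3               # num_train_epochs
train_runtime = 1645.8838  # seconds, from the TrainOutput above

steps_per_epoch = math.ceil(n_train / batch_size)  # last batch is partial
total_steps = steps_per_epoch * epochs
samples_per_second = n_train * epochs / train_runtime
print(steps_per_epoch, total_steps, round(samples_per_second, 3))
```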

In [24]:
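def save_experiment(experiment, experiments_dir="./data/experiments"):
    """Save experiment results to a JSON file and the fine-tuned model next to it."""

    # create the experiments directory on first use
    os.makedirs(experiments_dir, exist_ok=True)

    # next experiment number = number of existing experiment files + 1
    n_existing = len([name for name in os.listdir(experiments_dir) if name.startswith("experiment_")])
    experiment_id = n_existing + 1

    # save experiment data
    experiment_path = os.path.join(experiments_dir, f"experiment_{experiment_id}.json")
    with open(experiment_path, "w") as f:
        json.dump(experiment, f, indent=4)

    # save the fine-tuned model (global `model`) and its tokenizer under the same number
    model_path = os.path.join(experiments_dir, f"model_{experiment_id}")
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)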
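
`save_experiment` numbers a new run by counting existing experiment files, so deleting an older file would make the next run reuse, and overwrite, an existing number. A hedged alternative is to take the highest number already present plus one; `next_experiment_id` is a hypothetical helper:

```python
import re

def next_experiment_id(filenames):
    """Return 1 + the highest N among names like 'experiment_N.json' (1 if none)."""
    pattern = re.compile(r"^experiment_(\d+)\.json$")
    numbers = [int(m.group(1)) for name in filenames if (m := pattern.match(name))]
    return max(numbers, default=0) + 1

# experiment_2.json was deleted: counting files would give 3 and overwrite experiment_3.json
existing = ["experiment_1.json", "experiment_3.json", "model_1", "model_3"]
print(next_experiment_id(existing))
```

In the notebook this would be called as `next_experiment_id(os.listdir(experiments_dir))`.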

In [25]:
# save experiment and model
save_experiment(experiment)
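
`trainer.train()` returns a `TrainOutput` named tuple, and `json.dump` writes tuples as plain JSON arrays, so the saved file stores `metrics` without its field names. Converting with `_asdict()` keeps them. The sketch uses a stand-in named tuple with the same three fields (`global_step`, `training_loss`, `metrics`) rather than the real `transformers` class:

```python
import json
from typing import NamedTuple

class TrainOutputLike(NamedTuple):
    """Stand-in with the same fields as transformers' TrainOutput."""
    global_step: int
    training_loss: float
    metrics: dict

out = TrainOutputLike(global_step=4494, training_loss=1.8207663997417565,
                      metrics={"train_runtime": 1645.8838, "epoch": 3.0})

as_array = json.loads(json.dumps({"metrics": out}))             # field names lost
as_object = json.loads(json.dumps({"metrics": out._asdict()}))  # field names kept
print(as_array["metrics"])
print(as_object["metrics"])
```

In the notebook this would mean storing `metrics._asdict()` in the `experiment` dictionary.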

## Inference

In [26]:
# inspect the first training example (it was seen during fine-tuning)
dataset["train"][0]

{'question': 'What is (are) X-linked myotubular myopathy ?',
 'answer': 'X-linked myotubular myopathy is a condition that primarily affects muscles used for movement (skeletal muscles) and occurs almost exclusively in males. People with this condition have muscle weakness (myopathy) and decreased muscle tone (hypotonia) that are usually evident at birth.  The muscle problems in X-linked myotubular myopathy impair the development of motor skills such as sitting, standing, and walking. Affected infants may also have difficulties with feeding due to muscle weakness. Individuals with this condition often do not have the muscle strength to breathe on their own and must be supported with a machine to help them breathe (mechanical ventilation). Some affected individuals need breathing assistance only periodically, typically during sleep, while others require it continuously. People with X-linked myotubular myopathy may also have weakness in the muscles that control eye movement (ophthalmopleg

In [27]:
question = dataset["train"][0]["question"]

In [28]:
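# inference analysis A: beam search decoding
input_text = prefix + question
input_ids = tokenizer.encode(input_text, return_tensors="pt")
input_ids = input_ids.to("cuda")
output_ids = model.generate(
    input_ids,
    # max_a_len is a character count (29046), not a token count; cap generation
    # in tokens instead, matching the label length used during fine-tuning
    max_new_tokens=max_sample_tokens,
    num_beams=4,
    early_stopping=True
)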

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

X-linked myotubular myopathy is a condition characterized by abnormalities of the myotubules (myotubules). The myotubules (myotubules) are a group of structures that make up the myotubules. Myotubular myopathy is a condition characterized by abnormalities of the myotubules (myotubules) and abnormalities of the myotubules (myotubules). X-linked myotubular myopathy is
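
The beam-search answer repeats "myotubules (myotubules)" several times. `generate` also accepts `no_repeat_ngram_size`, which forbids any n-gram from occurring twice in the output. A pure-Python sketch of the rule it applies, shown at word level for readability (the real processor works on token IDs); `banned_next_tokens` is a hypothetical helper:

```python
def banned_next_tokens(tokens, n):
    """Tokens that would complete an n-gram already present in `tokens`."""
    if len(tokens) + 1 < n:  # not enough history to form an n-gram yet
        return set()
    prefix = tuple(tokens[len(tokens) - (n - 1):])  # last n-1 generated tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        # an earlier occurrence of the same (n-1)-token prefix...
        if tuple(tokens[i:i + n - 1]) == prefix:
            # ...bans the token that followed it
            banned.add(tokens[i + n - 1])
    return banned

generated = ["the", "myotubules", "(", "myotubules", ")", "the", "myotubules"]
print(banned_next_tokens(generated, 2))
print(banned_next_tokens(generated, 3))
```

In the notebook this would be an extra argument such as `no_repeat_ngram_size=3` in the beam-search `model.generate` call; the value 3 is an illustrative choice, not a tuned one.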


In [29]:
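# inference analysis B: sampling with top-k and nucleus (top-p) filtering
input_text = prefix + question
input_ids = tokenizer.encode(input_text, return_tensors="pt")
input_ids = input_ids.to("cuda")
output_ids = model.generate(
    input_ids,
    # cap generation in tokens (max_a_len is a character count)
    max_new_tokens=max_sample_tokens,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    num_return_sequences=3,
)

# three sequences are sampled; only the first one is printed here
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))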

X-linked myotubular myopathy is a disorder that affects the body and the nerves within the nervous system. It can develop at any age and can cause shortness of breath (difference to the blinking pattern of light) and difficulty breathing. Some people may experience long-term cognitive decline with spasticity or stiffness; in others, a shortness of breath or rapid heartbeat may result. The signs and symptoms of X-linked myotubular myopathy occur when the abnormal elasticity of nerve cells and the ability to connect blood vessels (fibromyalgi
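
With `do_sample=True`, `top_k=50` and `top_p=0.95`, each next token is drawn from a truncated distribution: the 50 most likely tokens are kept, then only the smallest set of those whose cumulative probability reaches 0.95. A pure-Python sketch of that filtering over a hypothetical four-token distribution (smaller `top_k` and `top_p` so the effect is visible), ignoring `repetition_penalty` and working on probabilities rather than logits; `top_k_top_p_filter` is a hypothetical helper:

```python
import random

def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the smallest prefix reaching top_p; renormalize."""
    # rank tokens from most to least likely and keep the top_k
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    total = sum(p for _, p in ranked)
    ranked = [(tok, p / total) for tok, p in ranked]  # renormalize after top-k

    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:  # nucleus reached
            break

    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}

# hypothetical next-token distribution
probs = {"muscle": 0.5, "nerve": 0.3, "eye": 0.15, "blood": 0.05}
filtered = top_k_top_p_filter(probs, top_k=3, top_p=0.8)

# sample one token from the filtered distribution (seeded for reproducibility)
rng = random.Random(0)
next_token = rng.choices(list(filtered), weights=list(filtered.values()), k=1)[0]
print(filtered, next_token)
```

Because the sampled sequences draw from these truncated distributions at every step, the three returned answers differ from run to run, unlike the deterministic beam-search output.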
