In [1]:
%%capture
!pip install transformers
!pip install evaluate
!pip install rouge

In [2]:
import torch
import json
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
import nltk
import spacy
import string
import evaluate  # Bleu
from torch.utils.data import Dataset, DataLoader, RandomSampler
import pandas as pd
import numpy as np
import transformers
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, T5TokenizerFast

import warnings
warnings.filterwarnings("ignore")

In [None]:
TOKENIZER = T5TokenizerFast.from_pretrained("t5-base")
MODEL = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)
OPTIMIZER = Adam(MODEL.parameters(), lr=0.00001)
Q_LEN = 256   # Question Length
T_LEN = 32    # Target Length
BATCH_SIZE = 4
DEVICE = "cuda:0"

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
%%capture
!git clone https://github.com/LasseRegin/medical-question-answer-data.git

In [None]:
with open('/content/medical-question-answer-data/questionDoctorQAs.json') as f:
    data = json.load(f)

In [None]:
def prepare_data(data):
    articles = []

    for item in data:
        question = item['question']
        answer = item['answer']

        inputs = {"context": "", "question": question, "answer": answer}
        articles.append(inputs)

    return articles


In [None]:
data = prepare_data(data)
data = pd.DataFrame(data)

In [None]:
class QA_Dataset(Dataset):
    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.questions = self.data["question"]
        self.context = self.data["context"]
        self.answer = self.data['answer']

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = self.questions[idx]
        context = self.context[idx]
        answer = self.answer[idx]

        question_tokenized = self.tokenizer(question, context, max_length=self.q_len, padding="max_length",
                                                    truncation=True, pad_to_max_length=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(answer, max_length=self.t_len, padding="max_length",
                                          truncation=True, pad_to_max_length=True, add_special_tokens=True)

        labels = torch.tensor(answer_tokenized["input_ids"], dtype=torch.long)
        labels[labels == 0] = -100

        return {
            "input_ids": torch.tensor(question_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(question_tokenized["attention_mask"], dtype=torch.long),
            "labels": labels,
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long)
        }

In [None]:
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

train_sampler = RandomSampler(train_data.index)
val_sampler = RandomSampler(val_data.index)

qa_dataset = QA_Dataset(TOKENIZER, data, Q_LEN, T_LEN)

train_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

In [None]:
MODEL.to(DEVICE)
OPTIMIZER = torch.optim.AdamW(MODEL.parameters(), lr=0.001)
train_loss = 0
val_loss = 0
train_batch_count = 0
val_batch_count = 0

epochs = 15

for epoch in range(epochs):
    MODEL.train()
    for batch in tqdm(train_loader, desc="Training batches"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

        outputs = MODEL(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )

        OPTIMIZER.zero_grad()
        outputs.loss.backward()
        OPTIMIZER.step()
        train_loss += outputs.loss.item()
        train_batch_count += 1

    # Evaluation
    MODEL.eval()
    for batch in tqdm(val_loader, desc="Validation batches"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

        outputs = MODEL(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )

        val_loss += outputs.loss.item()
        val_batch_count += 1

    print(f"{epoch + 1}/{epochs} -> Train loss: {train_loss / train_batch_count}\tValidation loss: {val_loss/val_batch_count}")


Training batches: 100%|██████████| 1136/1136 [06:14<00:00,  3.03it/s]
Validation batches: 100%|██████████| 284/284 [00:27<00:00, 10.32it/s]


1/15 -> Train loss: 3.5900005135737674	Validation loss: 2.714575819566216


Training batches: 100%|██████████| 1136/1136 [06:13<00:00,  3.04it/s]
Validation batches: 100%|██████████| 284/284 [00:25<00:00, 10.98it/s]


2/15 -> Train loss: 3.270136298113306	Validation loss: 2.4853080399859118


Training batches: 100%|██████████| 1136/1136 [06:11<00:00,  3.05it/s]
Validation batches: 100%|██████████| 284/284 [00:26<00:00, 10.90it/s]


3/15 -> Train loss: 3.032364084197322	Validation loss: 2.2689845171612752


Training batches: 100%|██████████| 1136/1136 [06:11<00:00,  3.06it/s]
Validation batches: 100%|██████████| 284/284 [00:25<00:00, 11.04it/s]


4/15 -> Train loss: 2.8256228385414457	Validation loss: 2.0537691790558084


Training batches: 100%|██████████| 1136/1136 [06:10<00:00,  3.06it/s]
Validation batches: 100%|██████████| 284/284 [00:26<00:00, 10.89it/s]


5/15 -> Train loss: 2.6326295553588532	Validation loss: 1.8593090673567543


Training batches: 100%|██████████| 1136/1136 [06:10<00:00,  3.06it/s]
Validation batches: 100%|██████████| 284/284 [00:25<00:00, 11.05it/s]


6/15 -> Train loss: 2.4515672997602134	Validation loss: 1.6817449469451613


Training batches: 100%|██████████| 1136/1136 [06:11<00:00,  3.06it/s]
Validation batches: 100%|██████████| 284/284 [00:25<00:00, 10.97it/s]


7/15 -> Train loss: 2.2859435630906035	Validation loss: 1.5315556901711334


Training batches: 100%|██████████| 1136/1136 [06:12<00:00,  3.05it/s]
Validation batches: 100%|██████████| 284/284 [00:25<00:00, 11.23it/s]


8/15 -> Train loss: 2.135354980308248	Validation loss: 1.3937909796597883


Training batches: 100%|██████████| 1136/1136 [06:02<00:00,  3.13it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.53it/s]


9/15 -> Train loss: 1.998259013623629	Validation loss: 1.2774844569487844


Training batches: 100%|██████████| 1136/1136 [06:01<00:00,  3.14it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.52it/s]


10/15 -> Train loss: 1.8760835580291673	Validation loss: 1.1784268406442773


Training batches: 100%|██████████| 1136/1136 [06:01<00:00,  3.14it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.53it/s]


11/15 -> Train loss: 1.7673222414801009	Validation loss: 1.093374191405354


Training batches: 100%|██████████| 1136/1136 [06:01<00:00,  3.14it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.57it/s]


12/15 -> Train loss: 1.6708128443326662	Validation loss: 1.0209011296313169


Training batches: 100%|██████████| 1136/1136 [06:00<00:00,  3.15it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.54it/s]


13/15 -> Train loss: 1.5858274443490776	Validation loss: 0.9581088623776628


Training batches: 100%|██████████| 1136/1136 [06:01<00:00,  3.15it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.56it/s]


14/15 -> Train loss: 1.509290379453338	Validation loss: 0.9024228125893075


Training batches: 100%|██████████| 1136/1136 [06:02<00:00,  3.14it/s]
Validation batches: 100%|██████████| 284/284 [00:24<00:00, 11.58it/s]

15/15 -> Train loss: 1.4404055625864878	Validation loss: 0.8537134486954536





In [None]:
MODEL.save_pretrained("qa_model")
TOKENIZER.save_pretrained("qa_tokenizer")

# Saved files
"""('qa_tokenizer/tokenizer_config.json',
 'qa_tokenizer/special_tokens_map.json',
 'qa_tokenizer/spiece.model',
'qa_tokenizer/added_tokens.json',
'qa_tokenizer/tokenizer.json')"""

"('qa_tokenizer/tokenizer_config.json',\n 'qa_tokenizer/special_tokens_map.json',\n 'qa_tokenizer/spiece.model',\n'qa_tokenizer/added_tokens.json',\n'qa_tokenizer/tokenizer.json')"

In [None]:
import zipfile
import os

folders_to_zip = ["/content/qa_model",
           "/content/qa_tokenizer"]

# Define the name of the zip file
zip_filename = '/content/model.zip'

# Create a zip file
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for folder in folders_to_zip:
        for root, dirs, files in os.walk(folder):
            for file in files:
                file_path = os.path.join(root, file)
                # Determine the file name inside the zip file
                relative_path = os.path.relpath(file_path, os.path.dirname(folder))
                zipf.write(file_path, os.path.join(os.path.basename(folder), relative_path))

# Verify that the zip file has been created
if os.path.exists(zip_filename):
    print(f'Zip file created: {zip_filename}')
else:
    print('Failed to create the zip file.')

# You can download the zip file in Google Colab
from google.colab import files
files.download(zip_filename)

Zip file created: /content/model.zip


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
def predict_answer(context, question, ref_answer=None):
    inputs = TOKENIZER(question, context, max_length=Q_LEN, padding="max_length", truncation=True, add_special_tokens=True)

    input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)

    outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask)

    predicted_answer = TOKENIZER.decode(outputs.flatten(), skip_special_tokens=True)

    if ref_answer:
        # Load the Bleu metric
        bleu = evaluate.load("google_bleu")
        score = bleu.compute(predictions=[predicted_answer],
                            references=[ref_answer])

        print("Context: \n", context)
        print("\n")
        print("Question: \n", question)
        return {
            "Reference Answer: ": ref_answer,
            "Predicted Answer: ": predicted_answer,
            "BLEU Score: ": score
        }
    else:
        return predicted_answer

In [None]:
context