In [16]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
import tensorflow as tf
import tf_keras as keras
import numpy as np



In [17]:
# Load dataset
ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")

# Check structure
print(ds)
print(ds['train'][0])

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 112165
    })
})
{'instruction': "If you are a doctor, please answer the medical questions based on the patient's description.", 'input': 'I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.. After taking panadol and sleep for few hours, i still feel the same.. By the way, if i lay down or sit down, my head do not spin, only when i want to move around then i feel the whole world is spinning.. And it is normal stomach discomfort at the same time? Earlier after i relieved myself, the spinning lessen so i am not sure whether its connected or coincidences.. Thank you doc!', 'output': 'Hi, Thank you for posting your query. The most likely cause for your symptoms is benign paroxysmal positional vertigo (BPPV), a type of peripheral vertigo. I

In [18]:
# Preprocess the dataset


model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess(examples):
    inputs = []
    labels = []
    attention_masks = []
    for i in range(len(examples['input'])):
      input_text = "question: " + examples['input'][i]
      target_text = examples['output'][i]
      tokenized_input = tokenizer(input_text, padding="max_length", truncation=True, max_length=128)
      tokenized_label = tokenizer(target_text, padding="max_length", truncation=True, max_length=128)
      inputs.append(tokenized_input['input_ids'])
      attention_masks.append(tokenized_input['attention_mask'])
      labels.append(tokenized_label['input_ids'])
    return {'input_ids': inputs, 'attention_mask': attention_masks, 'labels': labels}

tokenized_ds = ds["train"].map(preprocess, batched=True)

Map: 100%|██████████| 112165/112165 [04:27<00:00, 419.10 examples/s]


In [19]:
# Fine-tune the model

model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name, use_safetensors=False)

train_dataset = tokenized_ds.shuffle(seed=42).select(range(10000)) 

def prepare_decoder_inputs(examples):
    decoder_input_ids = np.array(examples["labels"])[:, :-1].tolist()
    # Pad decoder_input_ids to the same length as labels
    padded_decoder_input_ids = [ids + [tokenizer.pad_token_id] * (128 - len(ids)) for ids in decoder_input_ids]
    examples["decoder_input_ids"] = padded_decoder_input_ids
    return examples

train_dataset = train_dataset.map(prepare_decoder_inputs, batched=True)


tf_dataset = train_dataset.to_tf_dataset(
    columns=['input_ids', 'attention_mask', 'decoder_input_ids'],
    label_cols=['labels'],
    batch_size=8,
    shuffle=True,
)


optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
# Use ignore_index to exclude padding tokens from loss calculation
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(tf_dataset, epochs=1)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
Map: 100%|██████████| 10000/10000 [00:10<00:00, 968.87 examples/s] 





<tf_keras.src.callbacks.History at 0x12ee2e1b620>

In [20]:
# Save and test the model
model.save_pretrained("healthcare_chatbot_model")
tokenizer.save_pretrained("healthcare_chatbot_model")

# Test response
input_text = "question: I have a sore throat and cough, what should I do?"
inputs = tokenizer(input_text, return_tensors="tf")
output = model.generate(**inputs, max_length=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.



