## 03: Fine-Tune LLM for Medical FAQ Generation

### Overview
This notebook fine-tunes a `google/flan-t5-base` model on the preprocessed Medical FAQ dataset (`preprocessed_medquad.parquet`, 16,359 rows) to generate accurate answers for healthcare questions. It uses Hugging Face Transformers for training, tokenization, and evaluation, preparing the model for downstream RAG and multilingual FAQ applications.

### Purpose
- Convert Spark DataFrame to Hugging Face Dataset for fine-tuning.
- Tokenize QA pairs and train the model with Seq2SeqTrainer.
- Save the fine-tuned model and demonstrate predictions.

### Business Value
- Creates a domain-specific LLM for telehealth FAQs, enabling accurate, context-aware responses to reduce clinician workload.
- Supports multilingual expansion (e.g., Spanish translations), improving patient accessibility.
- Accelerates FAQ generation, potentially saving healthcare providers 60% in response time (per Azure economics).

### Technical Approach
- **Input**: `preprocessed_medquad.parquet` with lemmatized questions and original answers.
- **Model**: `google/flan-t5-base` (Seq2Seq, 250M parameters) fine-tuned for 3 epochs.
- **Training**: 90/10 train/validation split, max_length=128 input/256 output, batch_size=3.
- **Output**: Fine-tuned model at `/dbfs/mnt/faqdata/finetuned_llm_prototype` and demo predictions.
- **Runtime**: ~31 hours of training on CPU in this run (Databricks Community Edition); a GPU, if available, should shorten this considerably.

### Prerequisites
- Preprocessed data from 02_preprocess_medquad.ipynb.
- Hugging Face Transformers and Datasets installed.
---

## Import Libraries and Load Preprocessed Data
Load the preprocessed Parquet dataset (16,359 rows) from Azure Blob Storage using Spark, verifying row count and schema for fine-tuning readiness.

In [None]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, pipeline
from transformers.trainer_callback import TrainerCallback
import warnings
import logging
import os
from dotenv import load_dotenv
from IPython.display import display, HTML
import json
import pandas as pd
import numpy as np
import time
import torch
import transformers
import random
import evaluate

# Suppress warnings and logs
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Load environment variables
load_dotenv()

True

In [None]:
# Load the saved parquet file from Azure Blob
df = spark.read.parquet(f"/dbfs{os.getenv('MOUNT_PT')}/preprocessed_medquad.parquet")
print(f"Loaded Parquet with {df.count()} rows")

print("First 5 rows:")
df.show(5, truncate=30)

Loaded Parquet with 16359 rows
First 5 rows:
+-----------------+------------------------------+------------------------------+------------------------------+
|           source|                    focus_area|                lemma_question|               original_answer|
+-----------------+------------------------------+------------------------------+------------------------------+
|  NIHSeniorHealth|Peripheral Arterial Disease...|[symptom, peripheral, arter...|People who have P.A.D. may ...|
|        CancerGov|  Adult Acute Myeloid Leukemia|[risk, adult, acute, myeloi...|Smoking, previous chemother...|
|        CancerGov|Myelodysplastic/ Myeloproli...|[treatment, myelodysplastic...|Because myelodysplastic /my...|
|        CancerGov|                   Skin Cancer|[research, clinical, trial,...|New types of treatment are ...|
|MPlusHealthTopics|           Childbirth Problems|         [childbirth, problem]|While childbirth usually go...|
+-----------------+------------------------------+------------------------------+------------------------------+
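
The cells in this notebook build every path as `f"/dbfs{os.getenv('MOUNT_PT')}/..."`. If `MOUNT_PT` is missing from the environment, that expression silently yields a path starting with `/dbfsNone`. A minimal sketch of a guard follows; the helper name `dbfs_path` and its `environ` parameter are hypothetical and not part of the notebook, and the mount value `/mnt/faqdata` is taken from the save path printed later in this notebook.

```python
import os

def dbfs_path(relative, env_var="MOUNT_PT", environ=None):
    """Build a /dbfs path under the mount point, failing early if the variable is unset.

    Hypothetical helper for illustration; `environ` defaults to os.environ.
    """
    environ = os.environ if environ is None else environ
    mount = environ.get(env_var)
    if not mount:
        raise RuntimeError(f"Environment variable {env_var} is not set")
    # Normalize slashes so "/mnt/faqdata" and "mnt/faqdata/" give the same result
    return "/dbfs/" + mount.strip("/") + "/" + relative.lstrip("/")

# Example with an explicit mapping instead of the real environment
print(dbfs_path("preprocessed_medquad.parquet", environ={"MOUNT_PT": "/mnt/faqdata"}))
```

Failing early here is a design choice: a missing variable then surfaces as a clear error rather than as a "path does not exist" failure later on.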

---
## Convert Data to Hugging Face 🤗 Dataset
### Purpose
Convert the Spark DataFrame to a Hugging Face Dataset, format QA pairs (e.g., "question: [lemma_question] answer:" + original_answer), and prepare for tokenization and training.

- This standardizes data for LLM fine-tuning, enabling accurate FAQ responses in telehealth.
- Handles variable-length arrays (lemma_question) for consistent input formatting.

### Technical Details
- `df.toPandas()` + `Dataset.from_pandas`: Converts the Spark DataFrame to pandas, then to an HF Dataset (16,359 rows).
- `format_qa`: Creates `input_text` (e.g., "question: symptom peripheral arterial disease pad answer:") and `target_text` (original_answer).
- Sample: First row shown to verify QA pairs.

In [None]:
pandas_df = df.toPandas()
dataset = Dataset.from_pandas(pandas_df)
print(f"Converted dataset with {dataset.num_rows} rows")
print("Sample of first row:")
print(dataset[:1])

Converted dataset with 16359 rows
Sample of first row:
{'source': ['NIHSeniorHealth'], 'focus_area': ['Peripheral Arterial Disease (P.A.D.)'], 'lemma_question': [['symptom', 'peripheral', 'arterial', 'disease', 'pad']], 'original_answer': ['People who have P.A.D. may have symptoms when walking or climbing stairs. These may include pain, numbness, aching, or heaviness in the leg muscles. Symptoms may also include cramping in the affected leg(s) and in the buttocks, thighs, calves, and feet. Some possible signs of P.A.D. include - weak or absent pulses in the legs or feet  - sores or wounds on the toes, feet, or legs that heal slowly  - a pale or bluish color to the skin  - poor nail growth on the toes and decreased hair growth on the legs  - erectile dysfunction, especially among men who have diabetes. weak or absent pulses in the legs or feet sores or wounds on the toes, feet, or legs that heal slowly a pale or bluish color to the skin poor nail growth on the toes and decreased hair gr

In [None]:
# Prepare data as question-answer pairs
def format_qa(examples):
    input_texts = ["question: " + " ".join(q or []) + " answer:" for q in examples["lemma_question"]]
    target_texts = examples["original_answer"]
    return {"input_text": input_texts, "target_text": target_texts}

dataset = dataset.map(format_qa, batched=True)
print("Data formatted for fine-tuning:")
print(dataset[:1])

Data formatted for fine-tuning:
{'source': ['NIHSeniorHealth'], 'focus_area': ['Peripheral Arterial Disease (P.A.D.)'], 'lemma_question': [['symptom', 'peripheral', 'arterial', 'disease', 'pad']], 'original_answer': ['People who have P.A.D. may have symptoms when walking or climbing stairs. These may include pain, numbness, aching, or heaviness in the leg muscles. Symptoms may also include cramping in the affected leg(s) and in the buttocks, thighs, calves, and feet. Some possible signs of P.A.D. include - weak or absent pulses in the legs or feet  - sores or wounds on the toes, feet, or legs that heal slowly  - a pale or bluish color to the skin  - poor nail growth on the toes and decreased hair growth on the legs  - erectile dysfunction, especially among men who have diabetes. weak or absent pulses in the legs or feet sores or wounds on the toes, feet, or legs that heal slowly a pale or bluish color to the skin poor nail growth on the toes and decreased hair growth on the legs erecti

In [None]:
dataset

Dataset({
    features: ['source', 'focus_area', 'lemma_question', 'original_answer', 'input_text', 'target_text'],
    num_rows: 16359
})
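
To see how `format_qa` treats missing or empty `lemma_question` arrays, here is a small standalone check on a toy batch. The toy rows are invented for illustration; the function body mirrors the notebook cell.

```python
def format_qa(examples):
    # Same logic as the notebook cell: join lemma tokens, treat a missing list as empty
    input_texts = ["question: " + " ".join(q or []) + " answer:" for q in examples["lemma_question"]]
    return {"input_text": input_texts, "target_text": examples["original_answer"]}

# Toy batch (hypothetical rows): one normal question, one None, one empty list
toy_batch = {
    "lemma_question": [["childbirth", "problem"], None, []],
    "original_answer": ["a1", "a2", "a3"],
}

formatted = format_qa(toy_batch)
for text in formatted["input_text"]:
    print(repr(text))

# Rows whose prompt carries no question content; these may be worth filtering out
empty_rows = [i for i, q in enumerate(toy_batch["lemma_question"]) if not q]
print("rows with empty questions:", empty_rows)
```

The `q or []` guard avoids a `TypeError` on `None`, but such rows still produce the bare prompt `'question:  answer:'`, which gives the model nothing to condition on.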

---
## Initialize Model and Tokenizer

### Purpose
Load `google/flan-t5-base` model and tokenizer, preparing for tokenization of QA pairs with max_length limits for training efficiency.

- Uses a lightweight Seq2Seq model (250M parameters) for fast fine-tuning on CPU, suitable for healthcare FAQ generation without high compute costs.
- Tokenization ensures consistent input/output lengths, improving model performance on medical text.

### Technical Details
- Model: `google/flan-t5-base` (pretrained for instruction-following tasks like QA).
- Tokenizer: `max_length=128` for input, `max_length=256` for target (to handle long answers).
- Padding/truncation enabled for batching.

In [None]:
# Initialize model and tokenizer
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
print("Model and tokenizer initialized successfully.")

Model and tokenizer initialized successfully.


---
## Tokenize Dataset

### Purpose
Tokenize input (`input_text`) and target (`target_text`) with padding/truncation, replacing pad tokens with -100 in labels for loss computation, preparing data for Seq2Seq training.

- Converts text to model-ready tensors, enabling efficient fine-tuning for accurate, multilingual FAQ responses.
- Handles variable-length medical terms (e.g., lemma_question arrays) without data loss.

### Technical Details
- `tokenize_function`: Tokenizes input (max_length=128) and labels (max_length=256).
- Labels: Pad token IDs are replaced with -100 so they are ignored in the loss; targets are tokenized inside `tokenizer.as_target_tokenizer()`.
- Batched mapping for efficiency.
- Output: Tokenized dataset with `input_ids`, `attention_mask`, `labels`.
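
The label-masking step can be illustrated without loading a tokenizer. The sketch below assumes a pad token ID of 0 (the value T5 tokenizers use) and invented token IDs. The helper `pad_and_mask` is hypothetical and simplifies truncation to a plain slice.

```python
PAD_ID = 0           # assumed pad token ID (T5 tokenizers use 0)
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def pad_and_mask(token_ids, max_length, pad_id=PAD_ID):
    """Truncate or pad to max_length, then replace padding with IGNORE_INDEX."""
    clipped = token_ids[:max_length]
    padded = clipped + [pad_id] * (max_length - len(clipped))
    return [IGNORE_INDEX if t == pad_id else t for t in padded]

# Invented token IDs for a short answer, padded to length 8
labels = pad_and_mask([71, 1712, 13, 1], max_length=8)
print(labels)  # [71, 1712, 13, 1, -100, -100, -100, -100]

# Only the non-masked positions contribute to the training loss
print(sum(t != IGNORE_INDEX for t in labels), "of", len(labels), "positions count toward the loss")
```

Without this masking, short answers padded to 256 tokens would be dominated by pad positions, and the model would be rewarded mostly for predicting padding.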

In [None]:
# Tokenize data
def tokenize_function(examples):
    inputs = tokenizer(examples["input_text"], padding="max_length", truncation=True, max_length=128)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["target_text"], padding="max_length", truncation=True, max_length=256)["input_ids"]
    
    pad_id = tokenizer.pad_token_id
    labels = [[-100 if t == pad_id else t for t in seq] for seq in labels]
    inputs["labels"] = labels
    return inputs

tokenized_dataset = dataset.map(tokenize_function, batched=True)
print("Tokenize stage: Dataset tokenized successfully!!")

Tokenize stage: Dataset tokenized successfully!!


---
## Split Dataset and Configure Training

### Purpose
Split tokenized dataset into 90/10 train/validation sets and define training arguments (e.g., epochs, batch size) for fine-tuning with Seq2SeqTrainer.

- Trains a specialized LLM for medical FAQs, improving response accuracy and reducing hallucination risks in telehealth.
- Balanced split ensures reliable validation metrics for model evaluation.

### Technical Details
- Split: 90% train (14,723 rows), 10% validation (1,636 rows).
- Training Args: 3 epochs, batch_size=3 (CPU-friendly), save/eval every epoch, best checkpoint by eval_loss reloaded at the end (`load_best_model_at_end=True`; no early-stopping callback is configured).
- Trainer: Seq2SeqTrainer with `predict_with_generate=True` for generation metrics.
- Output: Trainer configured for fine-tuning.
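
The split sizes and the number of optimizer steps follow from simple arithmetic. The sketch below assumes the validation size is rounded up, a single device, and no gradient accumulation; these are assumptions about the run, not values read from the trainer.

```python
import math

n_rows = 16359     # rows in the tokenized dataset
test_size = 0.1    # validation fraction passed to train_test_split
batch_size = 3     # per_device_train_batch_size
epochs = 3

n_eval = math.ceil(test_size * n_rows)             # assumed rounding: ceiling
n_train = n_rows - n_eval
steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch still counts as a step
total_steps = steps_per_epoch * epochs

print(n_train, n_eval, steps_per_epoch, total_steps)  # 14723 1636 4908 14724
```

With nearly fifteen thousand optimizer steps at batch size 3, the long CPU runtime reported for training is unsurprising.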

In [None]:
# Clear old files to free space (only 2GB free). Run only before training!
# dbutils.fs.rm(f"dbfs:{os.getenv('MOUNT_PT')}/results", recurse=True)
# dbutils.fs.rm(f"dbfs:{os.getenv('MOUNT_PT')}/logs", recurse=True)
# dbutils.fs.rm(f"dbfs:{os.getenv('MOUNT_PT')}/finetuned_llm_prototype", recurse=True)
# print("Clearing stage: Cleared old model files in DBFS!")

# Split dataset: 90% train, 10% validation
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]
print("Dataset split into train and evaluation sets.")

Dataset split into train and evaluation sets.


---
## Fine-Tune the Model

### Purpose
Train the `flan-t5-base` model on the tokenized QA pairs using Seq2SeqTrainer, monitoring loss and saving the fine-tuned model.

- Produces a domain-specific LLM for healthcare FAQs, enabling accurate answers for patient queries and supporting multilingual telehealth.
- 3 epochs balance training time and performance, optimizing for CPU resources.

### Technical Details
- Trainer: Fits on train/validation datasets, logs every epoch.
- Loss: eval_loss ~1.589 after epoch 3; the checkpoint with the lowest eval_loss is reloaded at the end of training.
- Save: Model saved as `finetuned_llm_prototype`.
- Runtime: ~31 hours (1 day 7 hours) on CPU.

In [1]:
# Training params (write directly to DBFS so it lands in Blob)
training_args = Seq2SeqTrainingArguments(
    output_dir=f"/dbfs{os.getenv('MOUNT_PT')}/results",
    num_train_epochs=3,
    per_device_train_batch_size=3,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_strategy="epoch",  # Logs at end of each epoch
    logging_steps=1,  # Only used when logging_strategy="steps"
    save_total_limit=1,
    logging_dir=f"/dbfs{os.getenv('MOUNT_PT')}/logs",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    predict_with_generate=True,
    save_safetensors=False,
    ddp_find_unused_parameters=False
)

# Setup trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
print("Trainer configured successfully.")

Trainer configured successfully.


In [None]:
try:
    # Train
    trainer.train()
    print("Training stage: Model fine-tuning is successful!!!")

    # Manual save directly in DBFS
    safe_model_path = f"/dbfs{os.getenv('MOUNT_PT')}/finetuned_llm_prototype"
    tokenizer.save_pretrained(safe_model_path)
    trainer.save_model(safe_model_path)
    print(f"Saving stage: Model and tokenizer saved at {safe_model_path}") #see the path

except Exception as err:
    print("Error during model initialization or training:", err)
    raise

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,2.0446,1.681276
2,1.8083,1.608796
3,1.7407,1.589226


Training stage: Model fine-tuning is successful!!!
Saving stage: Model and tokenizer saved at /dbfs/mnt/faqdata/finetuned_llm_prototype
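
As a quick reading of the training log above, the sketch below picks the epoch with the lowest validation loss (the criterion set by `metric_for_best_model="eval_loss"` and `greater_is_better=False`) and converts the losses to perplexity. It treats the reported loss as mean per-token cross-entropy, which is an assumption about how the trainer averages it.

```python
import math

# (epoch, training loss, validation loss) copied from the training log above
history = [(1, 2.0446, 1.681276), (2, 1.8083, 1.608796), (3, 1.7407, 1.589226)]

# Lowest validation loss wins, mirroring greater_is_better=False
best_epoch, _, best_val = min(history, key=lambda row: row[2])
first_val = history[0][2]
rel_drop = (first_val - best_val) / first_val

print(f"best epoch: {best_epoch}")
print(f"validation loss drop from epoch 1: {rel_drop:.1%}")
print(f"validation perplexity: {math.exp(first_val):.2f} -> {math.exp(best_val):.2f}")
```

Validation loss is still decreasing at epoch 3, so the best checkpoint is simply the last one. More epochs might help, though the log alone cannot confirm that.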


---
## Load Fine-Tuned Model and Tokenizer for Testing

### Purpose
Load the fine-tuned `google/flan-t5-base` model and tokenizer from Azure Blob Storage for generating predictions and evaluations.

- Enables immediate use of the fine-tuned model for FAQ generation, supporting telehealth applications with accurate, context-aware responses.
- Leverages cloud storage for model persistence, ensuring scalability and accessibility.

### Technical Details
- Model: `AutoModelForSeq2SeqLM` with `AutoTokenizer`.
- Device: GPU if `torch.cuda.is_available()`, otherwise CPU (this run used CPU).
- Pipeline: `text2text-generation` with generation parameters (max_new_tokens=256, num_beams=4).

In [None]:
# Load trained model and the tokenizer
save_dir = f"/dbfs{os.getenv('MOUNT_PT')}/finetuned_llm_prototype"

tokenizer = AutoTokenizer.from_pretrained(save_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(save_dir)
_ = model.to("cuda" if torch.cuda.is_available() else "cpu")

gen = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

GEN_KW = dict(
    max_new_tokens=256,
    num_beams=4,
    early_stopping=True,
    no_repeat_ngram_size=3,
    repetition_penalty=1.15,
    length_penalty=0.9,
)

Device set to use cpu
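
`no_repeat_ngram_size=3` in `GEN_KW` stops the generator from emitting any 3-gram of token IDs twice. The following sketch is a word-level approximation for checking text after the fact; it is not how the library enforces the constraint during decoding, and the helper `repeated_ngrams` is hypothetical.

```python
def repeated_ngrams(tokens, n):
    """Return the set of n-grams that occur more than once in tokens."""
    seen, repeated = set(), set()
    for i in range(len(tokens) - n + 1):
        gram = tuple(tokens[i:i + n])
        if gram in seen:
            repeated.add(gram)
        seen.add(gram)
    return repeated

# Word-level example built from a phrase that repeats in the reference answers
text = "weak or absent pulses in the legs weak or absent pulses in the feet"
repeats = repeated_ngrams(text.split(), 3)
for gram in sorted(repeats):
    print(" ".join(gram))
```

A check like this can flag references that contain duplicated passages, as the first dataset row does. A generator with this constraint cannot reproduce that kind of repetition verbatim.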


## Generate Prediction on Dataset Samples
Generate a prediction for a sample question from the dataset to qualitatively assess the fine-tuned model’s performance.

- Validates the model’s ability to produce relevant medical answers, critical for reliable telehealth FAQs.
- Highlights potential issues (e.g., hallucinations) for refinement in the RAG pipeline.

### Technical Details
- Input: First row of dataset (e.g., "question: symptom peripheral arterial disease pad answer:").
- Generation: Uses `pipeline` with max_new_tokens=256, num_beams=4, early_stopping.
- Output: Predicted answer compared to reference (`original_answer`).
- Note: Hallucination observed (e.g., the prediction refers to "PPD" and lists unrelated findings such as hypoplasia of the corpus callosum), indicating need for RAG grounding.

In [None]:
# Question from original (pre-tokenization) HF dataset
sample = dataset[0]  # non-tokenized 'dataset'
inp = f"question: {' '.join(sample['lemma_question'])} answer:"
pred = gen(inp, **GEN_KW)[0]["generated_text"]
print("INPUT:\n", inp)
print(f"\nPREDICTION:\n{pred}...")
print("\nREFERENCE ANSWER:\n", sample["original_answer"])


INPUT:
 question: symptom peripheral arterial disease pad answer:

PREDICTION:
What are the signs and symptoms of Peripheral arterial disease (PAD)? The Human Phenotype Ontology provides the following list of signs and symptom reports for PPD. If the information is available, the table below includes how often the symptom is seen in people with this condition. You can use the MedlinePlus Medical Dictionary to look up the definitions for these medical terms. Signs and Symptoms Approximate number of patients (when available) Abnormality of the arteries - Autosomal recessive inheritance - Arthralgia - Atherosclerosis - Hypertension - Muscular hypotonia - Hypoplasia of the corpus callosum - Decreased blood flow to the kidneys - Increased urination - Reduced blood pressure - The Mental Health Association of America (MHA) has collected information on how often a sign or symptom occurs in a condition. Much of this information comes from Orphanet, a European rare disease database. The frequenc

---
## Evaluate Model with ROUGE Metrics

Evaluate the fine-tuned model on 500 random validation samples using ROUGE metrics, saving predictions and scores for analysis.

- Quantifies model performance (e.g., ROUGE-1: 0.316, ROUGE-L: 0.248), guiding improvements for telehealth FAQ accuracy.
- Saves evaluation data (`eval_preds.csv`) for traceability and further analysis (e.g., verdict assignment).

### Technical Details
- Sample: 500 random rows from `eval_dataset` (input_text, target_text).
- Metrics: ROUGE-1, ROUGE-2, ROUGE-L, ROUGE-Lsum via `evaluate` library.
- Output: `eval_preds.csv` (input, reference, predicted) and `rouge.json`.
- Runtime: ~ 1 hour 33 minutes for 500 predictions.
- Examples: Displays first 3 of 500 evaluation samples (input_text, predicted_answer, reference_answer).
- Truncation: Inputs to 250 chars, predictions/references to 400 chars for readability.
- Issues: Example 2 (CKD nutrition) shows hallucination, suggesting RAG or keyword rules needed.
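
To make the ROUGE numbers easier to interpret, here is a simplified pure-Python sketch of ROUGE-1 and ROUGE-L F1 for one prediction/reference pair. It uses lowercase whitespace tokenization only, so its values will not match the `evaluate` implementation exactly. The example strings are shortened paraphrases, not rows from `eval_preds.csv`.

```python
from collections import Counter

def f1(overlap, pred_len, ref_len):
    """F1 from an overlap count and the two sequence lengths."""
    if overlap == 0:
        return 0.0
    precision, recall = overlap / pred_len, overlap / ref_len
    return 2 * precision * recall / (precision + recall)

def rouge1_f1(pred, ref):
    # Unigram overlap, counting each word at most as often as it appears in both texts
    p, r = pred.lower().split(), ref.lower().split()
    overlap = sum((Counter(p) & Counter(r)).values())
    return f1(overlap, len(p), len(r))

def lcs_length(a, b):
    # Longest common subsequence via dynamic programming, one row at a time
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rougeL_f1(pred, ref):
    p, r = pred.lower().split(), ref.lower().split()
    return f1(lcs_length(p, r), len(p), len(r))

pred = "the condition is inherited in an autosomal dominant pattern"
ref = "spastic paraplegia is inherited in an autosomal dominant pattern"
print(round(rouge1_f1(pred, ref), 3), round(rougeL_f1(pred, ref), 3))

# Same words in reverse order: ROUGE-1 stays at 1.0 while ROUGE-L drops
print(rouge1_f1("pattern dominant autosomal", "autosomal dominant pattern"),
      round(rougeL_f1("pattern dominant autosomal", "autosomal dominant pattern"), 3))
```

ROUGE-L rewards shared word order, which is why it is lower than ROUGE-1 in the scores reported here. Neither metric checks medical correctness, so a fluent but wrong answer can still score well.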

In [None]:
assert "input_text" in eval_dataset.column_names and "target_text" in eval_dataset.column_names
N = min(500, len(eval_dataset))
idxs = random.sample(range(len(eval_dataset)), k=N)
inputs = [eval_dataset[i]["input_text"] for i in idxs]
refs   = [eval_dataset[i]["target_text"] for i in idxs]

preds = [gen(x, **GEN_KW)[0]["generated_text"] for x in inputs]

rouge = evaluate.load("rouge")
scores = rouge.compute(predictions=preds, references=refs)

eval_df = pd.DataFrame({"input_text": inputs, "reference_answer": refs, "predicted_answer": preds})
eval_csv = f"/dbfs{os.getenv('MOUNT_PT')}/results/eval_preds.csv"
eval_df.to_csv(eval_csv, index=False)

with open(f"/dbfs{os.getenv('MOUNT_PT')}/results/rouge.json","w") as f:
    json.dump(scores, f, indent=2)

print("Saved: eval_preds.csv")
print("ROUGE:", scores)

Saved: eval_preds.csv
ROUGE: {'rouge1': 0.31576639961927666, 'rouge2': 0.17178996510205768, 'rougeL': 0.24845837693786155, 'rougeLsum': 0.25341844016035764}


In [None]:
for i in range(3):
    print(f"\n=== Example {i+1} ===")
    print("INPUT:", inputs[i][:250])
    print("\nPREDICTION:", preds[i][:400], "...")
    print("\nREFERENCE ANSWER:", refs[i][:400], "...")



=== Example 1 ===
INPUT: question: spastic paraplegia type 8 inherited answer:

PREDICTION: This condition is inherited in an autosomal dominant pattern, which means one copy of the altered gene in each cell is sufficient to cause the disorder. In most cases, an affected person inherits the mutation from one affected parent. ...

REFERENCE ANSWER: Spastic paraplegia type 8 is inherited in an autosomal dominant pattern, which means one copy of the altered gene in each cell is sufficient to cause the disorder.  In most cases, an affected person inherits the mutation from one affected parent. Other cases result from new mutations in the gene and occur in people with no history of the disorder in their family. ...

=== Example 2 ===
INPUT: question: nutrition early chronic kidney disease adult answer:

PREDICTION: Key Points - Nutrition is an important part of a healthy diet. - People who have chronic kidney disease (CKD) should eat a balanced diet that includes whole grains, fruits, vege

In [None]:
display(eval_df.head(10))
print("ROUGE:", scores)

Unnamed: 0,input_text,reference_answer,predicted_answer
0,question: spastic paraplegia type 8 inherited ...,Spastic paraplegia type 8 is inherited in an a...,This condition is inherited in an autosomal do...
1,question: nutrition early chronic kidney disea...,"As blood pressure rises, the risk of damage to...",Key Points - Nutrition is an important part of...
2,question: sprengel deformity answer:,Sprengel deformity is a congenital condition c...,Sprengel deformity is a condition that affects...
3,question: outlook spinal cord injury answer:,Spinal cord injuries are classified as either ...,The prognosis for people with spinal cord inju...
4,question: mitochondrial neurogastrointestinal ...,Mitochondrial neurogastrointestinal encephalop...,Mitochondrial neurogastrointestinal encephalop...
5,question: congenital varicella syndrome answer:,Congenital varicella syndrome is an extremely ...,Congenital varicella syndrome (CVS) is a rare ...
6,question: tinnitus answer:,Tinnitus is often described as a ringing in th...,Tinnitus is a condition that affects the nervo...
7,question: need know diarrhea answer:,"- Diarrhea is frequent, loose, and watery bowe...",Diarrhea is a disease in which the body's dige...
8,question: risk leukemia answer:,"For the most part, no one knows why some peopl...",Key Points - There are different types of leuk...
9,question: treatment acquired pure red cell apl...,How might acquired pure red cell aplasia be tr...,These resources address the diagnosis or manag...


ROUGE: {'rouge1': 0.31576639961927666, 'rouge2': 0.17178996510205768, 'rougeL': 0.24845837693786155, 'rougeLsum': 0.25341844016035764}


---
## Predict on New Questions
Test the fine-tuned model on an unseen question ("symptoms of continuous headache").

### Technical Details
- Input: "question: what are symptoms of continous headache? answer:" (the prompt is misspelled "continous" as entered in the cell).
- Generation: Uses `model.generate` with max_length=256, num_beams=8, no_repeat_ngram_size=2.
- Issue: Prediction shows hallucination (e.g., conjunctivitis references), reinforcing need for RAG grounding.

In [None]:
new_question = "question: what are symptoms of continous headache? answer:"
inputs = tokenizer(new_question, return_tensors="pt", max_length=128, truncation=True).to(model.device)  # keep tensors on the model's device
outputs = model.generate(**inputs, max_length=256, num_beams=8, early_stopping=True, no_repeat_ngram_size=2)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("New Question Input:")
print(new_question)
print("\nPredicted response:")
print(response)

New Question Input:
question: what are symptoms of continous headache? answer:

Predicted response:
What are the signs and symptoms of continous headache? Constant headaches can be caused by a variety of factors, including the following: - The severity of the headache varies from person to person. Some people may have no symptoms at all, while others may not have any symptoms. For example, if you have an X-linked recessive disorder, your symptoms may be milder than usual. Symptoms may include headache pain or tingling in the hands and feet, nausea, vomiting, and irritability. The Human Phenotype Ontology provides detailed information on the diagnosis and management of conjunctivitis. If the information is available, the table below includes how often the symptom is seen in people with this condition. You can use the MedlinePlus Medical Dictionary to look up the definitions for these medical terms. Signs of constipation and diarrhea may also be present in some cases.


---
## Demo Model Predictions

### Purpose
Demonstrates the fine-tuned model's performance on dataset samples and new queries, verifying QA generation quality.

- Validates the fine-tuned model’s ability to generate relevant medical answers, crucial for telehealth FAQ reliability.
- Provides qualitative evaluation beyond metrics, showcasing real-world applicability.

### Technical Details
- `demo_prediction`: Generates answers for dataset row 0 and new query (e.g., "symptoms of continuous headache").
- Uses `max_length=256`, `num_beams=8` for generation.
- Compares prediction to reference for qualitative assessment.

In [2]:
# Demo: Showcase model predictions
def demo_prediction(question, is_dataset=True):
    if is_dataset:
        sample = dataset[0]  # Use first dataset row
        input_text = f"question: {' '.join(sample['lemma_question'])} answer:"
        ref = sample["original_answer"]
    else:
        input_text = f"question: {question} answer:"
        ref = "No reference available"
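    # Tokenize the prompt built above. Tokenizing the global `new_question` here
    # made the recorded run below return the headache answer for the PAD input.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=128, truncation=True).to(model.device)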
    outputs = model.generate(**inputs, max_length=256, num_beams=8, early_stopping=True, no_repeat_ngram_size=2)
    pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"INPUT: {input_text}\n")
    print(f"PREDICTION: {pred}...\n")
    print(f"REFERENCE: {ref[:256]}...\n")

print("### Dataset Question Demo")
demo_prediction(None, is_dataset=True)

### Dataset Question Demo
INPUT: question: symptom peripheral arterial disease pad answer:

PREDICTION: What are the signs and symptoms of continous headache? Constant headaches can be caused by a variety of factors, including the following: - The severity of the headache varies from person to person. Some people may have no symptoms at all, while others may not have any symptoms. For example, if you have an X-linked recessive disorder, your symptoms may be milder than usual. Symptoms may include headache pain or tingling in the hands and feet, nausea, vomiting, and irritability. The Human Phenotype Ontology provides detailed information on the diagnosis and management of conjunctivitis. If the information is available, the table below includes how often the symptom is seen in people with this condition. You can use the MedlinePlus Medical Dictionary to look up the definitions for these medical terms. Signs of constipation and diarrhea may also be present in some cases....

REFERENCE: 

---
## Conclusion

This notebook fine-tuned `google/flan-t5-base` on 16,359 preprocessed Medical FAQ pairs, reaching an eval_loss of ~1.589 after 3 epochs, and evaluated 500 validation samples with ROUGE (ROUGE-1: 0.316, ROUGE-L: 0.248). The model was saved as `finetuned_llm_prototype`, and predictions were generated for dataset samples and a new question ("symptoms of continuous headache"). While some answers were accurate (e.g., spastic paraplegia inheritance), hallucinations were observed (e.g., unrelated findings listed for PAD, conjunctivitis references in the headache answer, incorrect CKD nutrition), indicating the need for RAG grounding.

### Key Results
- **Dataset**: 16,359 QA pairs tokenized, split 90/10 (14,723 train, 1,636 validation).
- **Training**: 3 epochs, batch_size=3, eval_loss from 1.681 to 1.589.
- **Evaluation**: 500 samples, ROUGE-1: 0.316, ROUGE-2: 0.172, ROUGE-L: 0.248, ROUGE-Lsum: 0.253.
- **Outputs**: Fine-tuned model, `eval_preds.csv`, `rouge.json`.
- **Demo**: Predictions showed strengths (correct inheritance) and weaknesses (hallucinations).
- **Runtime**: ~31 hours training, ~1 hour 33 minutes evaluation on Databricks Community Edition (CPU).

### Business Impact
- Produces a domain-specific LLM for telehealth FAQs, reducing clinician time on common queries.
- Evaluation data supports iterative improvements, enhancing FAQ reliability for patient care.
- Sets foundation for multilingual support and RAG, promoting equitable healthcare access.

### Next Steps
- **RAG Pipeline**: Build retrieval with LangChain and FAISS to ground predictions in dataset context (04_langchain_RAG.ipynb).
- **Verdict Assignment**: Add TF-IDF similarity and keyword rules to `eval_preds.csv` for `Correct`/`Incorrect` verdicts.
- **Multilingual Expansion**: Translate 331 `Correct` answers to Spanish/Telugu using GCP Translation API.
- **Deployment**: Develop Streamlit app for interactive FAQ queries, deployable on AWS SageMaker.
- **Portfolio Polish**: Update GitHub README with model performance, ROUGE scores, and screenshot of `eval_preds.csv`.
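
The Verdict Assignment step could start from something like the sketch below: a pure-Python TF-IDF cosine similarity between a prediction and its reference, mapped to `Correct`/`Incorrect`. The function names, the smoothed IDF formula, and the 0.5 threshold are all assumptions for illustration. The planned implementation may differ and would also need the keyword rules.

```python
import math
from collections import Counter

def tfidf_vectors(docs):
    """Return one {term: weight} dict per document using a smoothed IDF."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    doc_freq = Counter(term for toks in tokenized for term in set(toks))
    idf = {t: math.log((1 + n_docs) / (1 + df)) + 1 for t, df in doc_freq.items()}
    return [{t: c * idf[t] for t, c in Counter(toks).items()} for toks in tokenized]

def cosine(u, v):
    """Cosine similarity between two sparse {term: weight} vectors."""
    dot = sum(w * v.get(t, 0.0) for t, w in u.items())
    norm = math.sqrt(sum(w * w for w in u.values())) * math.sqrt(sum(w * w for w in v.values()))
    return dot / norm if norm else 0.0

def verdict(pred, ref, threshold=0.5):  # threshold is a hypothetical cut-off
    vec_pred, vec_ref = tfidf_vectors([pred, ref])
    return "Correct" if cosine(vec_pred, vec_ref) >= threshold else "Incorrect"

print(verdict("inherited in an autosomal dominant pattern",
              "it is inherited in an autosomal dominant pattern"))
print(verdict("nutrition is an important part of a healthy diet",
              "as blood pressure rises the risk of kidney damage increases"))
```

Lexical similarity shares the blind spot of ROUGE: an answer can reuse the reference vocabulary and still be clinically wrong, so a threshold like this is best treated as a first-pass filter.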

This step showcases LLM fine-tuning with Hugging Face, cloud integration, and evaluation for healthcare applications.