# 🏥 Health Domain Q&A Chatbot — Capstone Project

**Author:** Loue Sauveur Christian (Chriss)  
**Institution:** African Leadership University  
**Domain:** Healthcare (Medical Q&A Chatbot)  
**Purpose:**  
This project develops a healthcare question-answering chatbot that provides concise, accurate, and safe responses to health-related questions using a generative Transformer model (T5). The chatbot is trained on a large, domain-specific health dataset containing real medical questions and verified answers.

**Relevance & Justification:**  
- Healthcare misinformation is widespread; this chatbot provides reliable information from verified medical sources.  
- It supports users by answering general health questions safely (not replacing doctors).  
- Domain alignment ensures the model focuses strictly on health content, rejecting unrelated queries.


## Imports

In [1]:
# Code cell: imports, constants, seed
import os
import re
import json
from pathlib import Path
from collections import Counter
import random
import logging
import math

import pandas as pd
import numpy as np

# NLP & ML libs
import nltk
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import f1_score, precision_score, recall_score
from datasets import Dataset, DatasetDict
import evaluate  # metrics library

# Transformers & sentence embeddings (TensorFlow)
import transformers
from transformers import (
    AutoTokenizer, 
    TFAutoModelForSeq2SeqLM, 
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
)
from sentence_transformers import SentenceTransformer, util

# Faiss for vector index
import faiss

# TensorFlow
import tensorflow as tf

# Set logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)



## NLTK

In [2]:
# Run if missing packages (uncomment to install)
# !pip install -q nltk transformers datasets sentence-transformers faiss-cpu optuna accelerate rouge_score sacrebleu tensorflow

# NLTK resources
import nltk
nltk.download('punkt')
nltk.download('stopwords')


[nltk_data] Downloading package punkt to /home/lscblack/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/lscblack/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [3]:
# Load merged dataset
DATA_PATH = "../dataset/merged_health_dataset.csv"
df = pd.read_csv(DATA_PATH)

# Quick overview
print("Rows:", len(df))
print(df.columns.tolist())

# Display first few rows (works in Jupyter/Colab)
display(df.head(6))

# Show top 20 topics
display(df['topic'].value_counts().head(20))

# Show data split counts (train/test/val)
display(df['split'].value_counts())


Rows: 16371
['Question', 'Answer', 'topic', 'split']


Unnamed: 0,Question,Answer,topic,split
0,What is (are) Non-Small Cell Lung Cancer ?,Key Points - Non-small cell lung cancer is a d...,cancer,train
1,Who is at risk for Non-Small Cell Lung Cancer? ?,Smoking is the major risk factor for non-small...,cancer,train
2,What are the symptoms of Non-Small Cell Lung C...,Signs of non-small cell lung cancer include a ...,cancer,test
3,How to diagnose Non-Small Cell Lung Cancer ?,Tests that examine the lungs are used to detec...,cancer,train
4,What is the outlook for Non-Small Cell Lung Ca...,Certain factors affect prognosis (chance of re...,cancer,train
5,What are the stages of Non-Small Cell Lung Can...,Key Points - After lung cancer has been diagno...,cancer,train


topic
growth_hormone_receptor          5430
Genetic_and_Rare_Diseases        5388
Diabetes_Digestive_Kidney        1157
Neurological_Disorders_Stroke    1088
Other                             981
SeniorHealth                      769
cancer                            729
Heart_Lung_Blood                  559
Disease_Control_Prevention        270
Name: count, dtype: int64

split
train    13089
test      3282
Name: count, dtype: int64

In [4]:
# Normalize column names first so the column lookups below are reliable
df.columns = [c.strip() for c in df.columns]

# Drop rows missing Question/Answer
df = df.dropna(subset=['Question', 'Answer']).reset_index(drop=True)

# Lowercase and strip topic column
df['topic'] = df['topic'].astype(str).str.lower().str.strip()

# Drop exact duplicates based on Q/A
before = len(df)
df = df.drop_duplicates(subset=['Question', 'Answer'])
logger.info(f"Dropped {before - len(df)} exact duplicate rows")

# Remove rows with extremely short Questions or Answers
df = df[(df['Question'].str.len() > 10) & (df['Answer'].str.len() > 20)].reset_index(drop=True)
logger.info("After dropping very short Q/A: %d rows", len(df))

# Show a preview
df.head()


INFO:__main__:Dropped 13 exact duplicate rows
INFO:__main__:After dropping very short Q/A: 16357 rows


Unnamed: 0,Question,Answer,topic,split
0,What is (are) Non-Small Cell Lung Cancer ?,Key Points - Non-small cell lung cancer is a d...,cancer,train
1,Who is at risk for Non-Small Cell Lung Cancer? ?,Smoking is the major risk factor for non-small...,cancer,train
2,What are the symptoms of Non-Small Cell Lung C...,Signs of non-small cell lung cancer include a ...,cancer,test
3,How to diagnose Non-Small Cell Lung Cancer ?,Tests that examine the lungs are used to detec...,cancer,train
4,What is the outlook for Non-Small Cell Lung Ca...,Certain factors affect prognosis (chance of re...,cancer,train


In [5]:
# Cleaning utilities
import html
from nltk.corpus import stopwords

# Stopwords set
STOPWORDS = set(stopwords.words('english'))

# Text cleaning function
def clean_text(text):
    if not isinstance(text, str):
        return ""
    t = html.unescape(text)                        # unescape HTML entities
    t = re.sub(r'\n+', ' ', t)                     # remove newlines
    t = re.sub(r'\[.*?\]', ' ', t)                 # remove bracketed text
    t = re.sub(r'Key Points[:\s-]*', '', t, flags=re.I)  # remove "Key Points"
    t = re.sub(r'\s+', ' ', t)                     # collapse whitespace
    t = t.strip()
    return t

# Apply cleaning
df['Question_clean'] = df['Question'].apply(clean_text)
df['Answer_clean'] = df['Answer'].apply(clean_text)

# Optional: normalized copies for retrieval / embedding tasks
# (questions are lowercased; answers keep their original case)
df['Question_norm'] = df['Question_clean'].str.lower()
df['Answer_norm'] = df['Answer_clean'].str.strip()

# Preview cleaned data
df[['Question', 'Question_clean', 'Answer_clean']].head(6)


Unnamed: 0,Question,Question_clean,Answer_clean
0,What is (are) Non-Small Cell Lung Cancer ?,What is (are) Non-Small Cell Lung Cancer ?,Non-small cell lung cancer is a disease in whi...
1,Who is at risk for Non-Small Cell Lung Cancer? ?,Who is at risk for Non-Small Cell Lung Cancer? ?,Smoking is the major risk factor for non-small...
2,What are the symptoms of Non-Small Cell Lung C...,What are the symptoms of Non-Small Cell Lung C...,Signs of non-small cell lung cancer include a ...
3,How to diagnose Non-Small Cell Lung Cancer ?,How to diagnose Non-Small Cell Lung Cancer ?,Tests that examine the lungs are used to detec...
4,What is the outlook for Non-Small Cell Lung Ca...,What is the outlook for Non-Small Cell Lung Ca...,Certain factors affect prognosis (chance of re...
5,What are the stages of Non-Small Cell Lung Can...,What are the stages of Non-Small Cell Lung Can...,"After lung cancer has been diagnosed, tests ar..."
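
As a quick sanity check on the cleaning steps, the sketch below repeats the same regular expressions in a standalone function and runs them on a made-up string. The function name `clean_text_demo` and the sample text are illustrative assumptions, not part of the dataset.

```python
import html
import re

def clean_text_demo(text):
    """Standalone copy of the notebook's cleaning steps (illustrative name)."""
    if not isinstance(text, str):
        return ""
    t = html.unescape(text)                               # "&amp;" -> "&"
    t = re.sub(r'\n+', ' ', t)                            # newlines -> space
    t = re.sub(r'\[.*?\]', ' ', t)                        # drop bracketed text
    t = re.sub(r'Key Points[:\s-]*', '', t, flags=re.I)   # drop "Key Points -"
    t = re.sub(r'\s+', ' ', t)                            # collapse whitespace
    return t.strip()

# Made-up sample that exercises every step
raw = "Key Points - Lung cancer &amp; smoking [1]\nare   linked."
cleaned = clean_text_demo(raw)
print(cleaned)  # Lung cancer & smoking are linked.
```

Non-string values such as `NaN` fall through to the empty string, which is why the later length filters and language detection never see a float.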


In [6]:
# Optional: detect language and keep English entries
# pip install langdetect
from langdetect import detect, DetectorFactory

# Set seed for reproducibility
DetectorFactory.seed = SEED

# Function to check if text is English
def is_english(s):
    try:
        return detect(s) == 'en'
    except Exception:
        # langdetect raises on empty or feature-less text; treat as non-English
        return False

# Apply language detection (small datasets; batch for large ones)
df['is_english'] = df['Question_clean'].apply(is_english)

# Report fraction of English entries
print("English fraction:", df['is_english'].mean())

# Keep only English rows
df = df[df['is_english']].reset_index(drop=True)


English fraction: 0.971510668215443


In [7]:
# Normalize topics: strip punctuation, unify synonyms
df['topic'] = df['topic'].str.replace(r'[^a-z0-9_ ]', '', regex=True).str.strip()

# Show top topics
display(df['topic'].value_counts().head(40))

# Optionally bin small topics into 'other' for stratified splitting
topic_counts = df['topic'].value_counts()
rare_threshold = 50  # tune depending on dataset size
rare_topics = topic_counts[topic_counts < rare_threshold].index.tolist()
df['topic_group'] = df['topic'].apply(lambda x: 'other' if x in rare_topics else x)

# Preview the grouped topics
display(df['topic_group'].value_counts().head(40))


topic
growth_hormone_receptor          5303
genetic_and_rare_diseases        5153
diabetes_digestive_kidney        1126
neurological_disorders_stroke    1082
other                             972
seniorhealth                      756
cancer                            712
heart_lung_blood                  529
disease_control_prevention        258
Name: count, dtype: int64

topic_group
growth_hormone_receptor          5303
genetic_and_rare_diseases        5153
diabetes_digestive_kidney        1126
neurological_disorders_stroke    1082
other                             972
seniorhealth                      756
cancer                            712
heart_lung_blood                  529
disease_control_prevention        258
Name: count, dtype: int64
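
In this run the two tables are identical because the smallest topic (258 rows) is well above `rare_threshold = 50`, so nothing is binned. To see what the grouping step does when rare topics exist, here is a minimal pure-Python sketch of the same rule on toy data. The helper name `group_rare_topics` and the toy topics are assumptions made for illustration.

```python
from collections import Counter

def group_rare_topics(topics, rare_threshold=50, other_label="other"):
    """Map every topic with fewer than rare_threshold rows to other_label."""
    counts = Counter(topics)
    return [t if counts[t] >= rare_threshold else other_label for t in topics]

# Toy data: "dental" appears once, so with a threshold of 2 it is binned
toy_topics = ["cancer"] * 3 + ["seniorhealth"] * 2 + ["dental"]
grouped = group_rare_topics(toy_topics, rare_threshold=2)
print(Counter(grouped))
```

Binning matters for the stratified split that follows: `train_test_split` with `stratify=` needs enough rows per class to place at least one in each split.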

In [8]:
OUT_DIR = Path("../models/processed-v7")
OUT_DIR.mkdir(parents=True, exist_ok=True)
clean_path = OUT_DIR / "health_qa_cleaned.csv"
df.to_csv(clean_path, index=False)
logger.info(f"Saved cleaned CSV to {clean_path}")


INFO:__main__:Saved cleaned CSV to ../models/processed-v7/health_qa_cleaned.csv


In [9]:
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# -------------------------
# Stratified train/validation/test split
# -------------------------
train_df, temp_df = train_test_split(
    df, test_size=0.2, stratify=df['topic_group'], random_state=SEED
)
valid_df, test_df = train_test_split(
    temp_df, test_size=0.5, stratify=temp_df['topic_group'], random_state=SEED
)

print("Before oversampling:")
print("Train:", len(train_df), "Valid:", len(valid_df), "Test:", len(test_df))

# -------------------------
# Oversampling to balance classes
# -------------------------
print("\nApplying oversampling to training data...")
topic_counts = train_df['topic_group'].value_counts()
target_size = int(topic_counts.median())  # balance to median class size
print(f"Oversampling target: {target_size} samples per class")

oversampled_dfs = []
for topic in train_df['topic_group'].unique():
    topic_df = train_df[train_df['topic_group'] == topic]
    current_size = len(topic_df)
    
    if current_size < target_size:
        # Oversample minority class
        oversampled_df = resample(
            topic_df,
            replace=True,
            n_samples=target_size,
            random_state=SEED
        )
        oversampled_dfs.append(oversampled_df)
        print(f"  {topic}: {current_size} → {target_size} (oversampled)")
    else:
        # Keep majority class as is
        oversampled_dfs.append(topic_df)
        print(f"  {topic}: {current_size} (unchanged)")

# Combine oversampled data and shuffle
balanced_train_df = pd.concat(oversampled_dfs, ignore_index=True)
balanced_train_df = balanced_train_df.sample(frac=1, random_state=SEED).reset_index(drop=True)

print(f"\nAfter oversampling:")
print("Train:", len(balanced_train_df), "Valid:", len(valid_df), "Test:", len(test_df))

# -------------------------
# Save CSV splits for TF pipeline
# -------------------------
balanced_train_df.to_csv(OUT_DIR / "train.csv", index=False)
valid_df.to_csv(OUT_DIR / "valid.csv", index=False)
test_df.to_csv(OUT_DIR / "test.csv", index=False)

print("Balanced datasets saved!")


Before oversampling:
Train: 12712 Valid: 1589 Test: 1590

Applying oversampling to training data...
Oversampling target: 778 samples per class
  seniorhealth: 605 → 778 (oversampled)
  genetic_and_rare_diseases: 4122 (unchanged)
  diabetes_digestive_kidney: 901 (unchanged)
  growth_hormone_receptor: 4242 (unchanged)
  heart_lung_blood: 423 → 778 (oversampled)
  neurological_disorders_stroke: 865 (unchanged)
  cancer: 570 → 778 (oversampled)
  other: 778 (unchanged)
  disease_control_prevention: 206 → 778 (oversampled)

After oversampling:
Train: 14020 Valid: 1589 Test: 1590
Balanced datasets saved!
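
The numbers in this output can be reproduced by hand. The sketch below copies the per-topic training counts from the log above, takes their median as the target, and raises every smaller class to that target. The helper `oversample_with_replacement` is a hypothetical stand-in for `sklearn.utils.resample(..., replace=True)`, written with the standard library only.

```python
import random
from statistics import median

# Per-topic training counts, copied from the output above
train_counts = {
    "seniorhealth": 605,
    "genetic_and_rare_diseases": 4122,
    "diabetes_digestive_kidney": 901,
    "growth_hormone_receptor": 4242,
    "heart_lung_blood": 423,
    "neurological_disorders_stroke": 865,
    "cancer": 570,
    "other": 778,
    "disease_control_prevention": 206,
}

# Median class size is the oversampling target
target_size = int(median(train_counts.values()))

# Classes below the target are raised to it; larger classes are untouched
balanced_counts = {t: max(n, target_size) for t, n in train_counts.items()}
print(target_size, sum(train_counts.values()), sum(balanced_counts.values()))

def oversample_with_replacement(rows, n_samples, seed=42):
    """Draw n_samples rows with replacement (hypothetical helper)."""
    rng = random.Random(seed)
    return rng.choices(rows, k=n_samples)
```

Two things follow from this. Oversampling only lifts the four smallest classes, so the two largest topics still dominate the training set. Drawing with replacement also means the oversampled classes contain repeated rows, and some original rows may be left out.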


In [10]:
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np

# -------------------------
# Sentence embeddings: for retrieval / similarity search
# -------------------------
model_name = "all-MiniLM-L6-v2"  # small, fast, good for production
embedder = SentenceTransformer(model_name)

# Corpus: one passage per Answer (or chunk if very long)
corpus_df = balanced_train_df[['Answer_clean']].drop_duplicates().reset_index(drop=True)
corpus_texts = corpus_df['Answer_clean'].tolist()
print("Corpus size:", len(corpus_texts))

# Encode in batches
corpus_embeddings = embedder.encode(
    corpus_texts, 
    batch_size=2, 
    show_progress_bar=True, 
    convert_to_numpy=True
)

# -------------------------
# Build FAISS index (cosine similarity)
# -------------------------
d = corpus_embeddings.shape[1]

# Inner product works as cosine similarity if embeddings are normalized
index = faiss.IndexFlatIP(d)

# Normalize embeddings to unit length
faiss.normalize_L2(corpus_embeddings)
index.add(corpus_embeddings)
print("FAISS index size:", index.ntotal)

# -------------------------
# Save index and corpus texts for retrieval
# -------------------------
faiss.write_index(index, str(OUT_DIR / "faiss_index.ivf"))
np.save(OUT_DIR / "corpus_texts.npy", np.array(corpus_texts))
print("Saved FAISS index and corpus.")


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


Corpus size: 11921


Batches: 100%|██████████| 5961/5961 [00:34<00:00, 173.16it/s]


FAISS index size: 11921
Saved FAISS index and corpus.
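
The index uses inner product (`IndexFlatIP`) on L2-normalized vectors, which is what makes its scores cosine similarities. The sketch below checks that equivalence on two small made-up vectors, using plain Python rather than FAISS. All names in it are illustrative.

```python
import math

def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def l2_normalize(v):
    """Scale v to unit length (left unchanged if it is the zero vector)."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v] if norm > 0 else list(v)

def cosine(a, b):
    """Cosine similarity computed directly from the definition."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

a, b = [3.0, 4.0], [4.0, 3.0]
ip_after_norm = dot(l2_normalize(a), l2_normalize(b))
print(ip_after_norm, cosine(a, b))  # the two values agree up to rounding
```

This is also why the query embedding has to be normalized before searching: if only the corpus were normalized, the scores would be scaled by the query's length and would no longer be bounded by 1.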


In [11]:
# -------------------------
# Retrieval function using embeddings + FAISS
# -------------------------
def retrieve_answers(query, k=5):
    # Encode query
    q_emb = embedder.encode([query], convert_to_numpy=True)
    
    # Normalize to unit length for cosine similarity
    faiss.normalize_L2(q_emb)
    
    # Search FAISS index
    D, I = index.search(q_emb, k)
    
    # Collect top-k answers with scores
    results = []
    for idx, score in zip(I[0], D[0]):
        if idx < 0:  # FAISS pads with -1 when fewer than k vectors are available
            continue
        results.append({
            'answer': corpus_texts[int(idx)],
            'score': float(score)
        })
    return results

# -------------------------
# Example usage
# -------------------------
query_example = "What are early symptoms of lung cancer?"
top_answers = retrieve_answers(query_example, k=3)
print(top_answers)


Batches: 100%|██████████| 1/1 [00:00<00:00, 213.35it/s]

[{'answer': "Signs of non-small cell lung cancer include a cough that doesn't go away and shortness of breath. Sometimes lung cancer does not cause any signs or symptoms. It may be found during a chest x-ray done for another condition. Signs and symptoms may be caused by lung cancer or by other conditions. Check with your doctor if you have any of the following: - Chest discomfort or pain. - A cough that doesnt go away or gets worse over time. - Trouble breathing. - Wheezing. - Blood in sputum (mucus coughed up from the lungs). - Hoarseness. - Loss of appetite. - Weight loss for no known reason. - Feeling very tired. - Trouble swallowing. - Swelling in the face and/or veins in the neck.", 'score': 0.7313530445098877}, {'answer': "Lung cancer is one of the most common cancers in the world. It is a leading cause of cancer death in men and women in the United States. Cigarette smoking causes most lung cancers. The more cigarettes you smoke per day and the earlier you started smoking, the 
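
Because the scores are cosine similarities, they can double as a rough in-domain check, in line with the stated goal of rejecting unrelated queries. The sketch below is one possible way to do that and is not taken from the notebook. The default of 0.35 mirrors the `threshold` parameter in the signature of `answer_with_retrieval_and_generate_tf` later on. The helper names and the off-topic score are made up.

```python
def filter_by_score(results, threshold=0.35):
    """Keep only retrieved passages whose similarity reaches the threshold."""
    return [r for r in results if r["score"] >= threshold]

def is_in_domain(results, threshold=0.35):
    """Treat a query as in-domain if at least one passage clears the threshold."""
    return len(filter_by_score(results, threshold)) > 0

# Shaped like the output of retrieve_answers(); the second score is invented
health_hits = [{"answer": "Signs of non-small cell lung cancer include ...", "score": 0.73}]
off_topic_hits = [{"answer": "Lung cancer is one of the most common ...", "score": 0.12}]

print(is_in_domain(health_hits), is_in_domain(off_topic_hits))
```

A suitable threshold depends on the embedding model and corpus, so it would need to be tuned on held-out in-domain and out-of-domain questions.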




In [13]:
from datasets import Dataset, DatasetDict
import tensorflow as tf
from transformers import AutoTokenizer

# -------------------------
# Build HF datasets
# -------------------------
hf_train = Dataset.from_pandas(balanced_train_df[['Question_clean','Answer_clean','topic_group']])
hf_valid = Dataset.from_pandas(valid_df[['Question_clean','Answer_clean','topic_group']])
hf_test = Dataset.from_pandas(test_df[['Question_clean','Answer_clean','topic_group']])

ds = DatasetDict({"train": hf_train, "validation": hf_valid, "test": hf_test})

# -------------------------
# Tokenizer & model
# -------------------------
MODEL_NAME = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

max_input_length = 256
max_target_length = 256

# -------------------------
# Preprocessing function
# -------------------------
def preprocess_function(examples):
    inputs = [f"question: {q} topic: {t}" for q, t in zip(examples['Question_clean'], examples['topic_group'])]
    
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
        return_tensors="tf"
    )

    # Tokenize targets via `text_target`, which replaces the deprecated
    # `as_target_tokenizer()` context manager
    labels = tokenizer(
        text_target=[a if a is not None else "" for a in examples["Answer_clean"]],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
        return_tensors="tf"
    )

    # Attach the target token IDs as labels
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# -------------------------
# Map preprocessing to datasets
# -------------------------
print("Tokenizing datasets...")
tokenized_train = hf_train.map(preprocess_function, batched=True, remove_columns=hf_train.column_names)
tokenized_valid = hf_valid.map(preprocess_function, batched=True, remove_columns=hf_valid.column_names)
tokenized_test = hf_test.map(preprocess_function, batched=True, remove_columns=hf_test.column_names)

# -------------------------
# Convert to TensorFlow datasets
# -------------------------
def to_tf_dataset(tokenized_dataset, batch_size=2, shuffle=True):
    columns = ["input_ids", "attention_mask", "labels"]
    
    # Pass `label_cols` as a single string so the labels come back as a plain
    # tensor rather than a one-key dict in newer `datasets` versions
    tf_ds = tokenized_dataset.to_tf_dataset(
        columns=columns,            # columns to load
        label_cols="labels",        # label column
        shuffle=shuffle,
        batch_size=batch_size,
        collate_fn=None,
    )
    return tf_ds


batch_size = 1
tf_train_ds = to_tf_dataset(tokenized_train, batch_size=batch_size)
tf_valid_ds = to_tf_dataset(tokenized_valid, batch_size=batch_size, shuffle=False)
tf_test_ds = to_tf_dataset(tokenized_test, batch_size=batch_size, shuffle=False)

print("TF datasets ready!")


Tokenizing datasets...


Map: 100%|██████████| 14020/14020 [00:02<00:00, 4794.71 examples/s]
Map: 100%|██████████| 1589/1589 [00:00<00:00, 5149.50 examples/s]
Map: 100%|██████████| 1590/1590 [00:00<00:00, 4876.74 examples/s]
Old behaviour: columns=['a'], labels=['labels'] -> (tf.Tensor, tf.Tensor)  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor)  
New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor})  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) 


TF datasets ready!
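
The preprocessing above pads every label sequence to `max_target_length` with the pad token, so padded positions look like ordinary targets. A common convention with Hugging Face seq2seq models is to replace those positions with -100, which the models' built-in loss skips. The sketch below shows that replacement on plain lists. It is a possible refinement, not what the notebook does, and it assumes a pad token ID of 0, as in T5 tokenizers.

```python
def mask_padding(label_ids, pad_token_id=0, ignore_index=-100):
    """Replace pad positions in each label sequence with ignore_index."""
    return [
        [ignore_index if tok == pad_token_id else tok for tok in seq]
        for seq in label_ids
    ]

# Made-up token IDs: two answers padded to length 5 with pad ID 0
labels = [
    [37, 1712, 1, 0, 0],
    [1233, 1, 0, 0, 0],
]
masked = mask_padding(labels)
print(masked)  # [[37, 1712, 1, -100, -100], [1233, 1, -100, -100, -100]]
```

Any code that decodes labels back to text then has to map -100 back to the pad token first, as `compute_metrics_for_generation` does.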


In [14]:
# Install if not already
# !pip install optuna plotly

import optuna
from optuna.trial import TrialState
import matplotlib.pyplot as plt


In [15]:
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM

# -------------------------
# Model
# -------------------------
model = TFAutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    from_pt=True,         # convert PyTorch weights to TF
    use_safetensors=True  # optional but faster & safer
)


# -------------------------
# Optimizer & learning rate
# -------------------------
learning_rate = 1e-5
optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate, epsilon=1e-08
)

# -------------------------
# Loss function
# -------------------------
# No Keras loss is passed to compile(): Hugging Face TF models then fall back
# to their internal seq2seq loss, which ignores label positions set to -100.
# A plain Keras SparseCategoricalCrossentropy would not handle -100 labels.
# Note: labels here are padded with the pad token, so padding still counts
# toward the loss unless it is replaced with -100 during preprocessing.

# -------------------------
# Metrics
# -------------------------
# Can integrate ROUGE with a custom callback if needed
metrics = []  # leave empty; compute ROUGE offline or via custom callback

# -------------------------
# Compile model
# -------------------------
model.compile(optimizer=optimizer, metrics=metrics)

# -------------------------
# Callbacks
# -------------------------
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",      # or a custom metric
    patience=6,
    min_delta=0.001,
    restore_best_weights=True,
)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=str(OUT_DIR / "t5_health_ckpt"),
    save_weights_only=True,
    save_best_only=True,
    monitor="val_loss"
)

print(f"TensorFlow training setup ready for {MODEL_NAME}")


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFT5ForConditionalGeneration: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
- This IS expected if you are initializing TFT5ForConditionalGeneration from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFT5ForConditionalGeneration from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


TensorFlow training setup ready for google/flan-t5-small
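
The early-stopping settings interact with the number of epochs used in the training cell. The sketch below is a simplified model of patience-based stopping on validation loss, not Keras's actual implementation. It suggests that with `patience=6` the earliest possible stop is epoch 7, so a 4-epoch run would never stop early. The function name and loss values are made up.

```python
def stopped_epoch(val_losses, patience=6, min_delta=0.001):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss      # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1        # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

# Validation loss gets worse after epoch 1 in both runs
print(stopped_epoch([1.0, 1.1, 1.2, 1.3]))   # 4 epochs: None
print(stopped_epoch([1.0] + [1.1] * 6))      # 7 epochs: 7
```

In a short run like this, `restore_best_weights=True` and the best-only checkpoint do the useful work. Note that Keras only restores the best weights when early stopping actually fires. A smaller patience (for example 1 or 2) would be needed for early stopping itself to have any effect within 4 epochs.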


In [16]:
import evaluate
from collections import Counter
import numpy as np

# -------------------------
# Metrics
# -------------------------
bleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")

# -------------------------
# Compute metrics function
# -------------------------
def compute_metrics_for_generation(predictions, labels, tokenizer):
    """
    Compute BLEU, ROUGE, Exact Match, and token-level F1 for generative QA.
    
    Args:
        predictions: token-ID sequences (model outputs) or already-decoded strings
        labels: token-ID sequences (targets) or already-decoded strings
        tokenizer: Hugging Face tokenizer
    
    Returns:
        Dictionary of metrics
    """
    def to_texts(sequences):
        # The callers in this notebook pass decoded strings; use them as-is
        sequences = list(sequences)
        if sequences and isinstance(sequences[0], str):
            return sequences
        # Otherwise map -100 and any out-of-vocabulary ID to the pad token, then decode
        valid = [
            [int(t) if 0 <= int(t) < tokenizer.vocab_size else tokenizer.pad_token_id for t in seq]
            for seq in sequences
        ]
        return tokenizer.batch_decode(valid, skip_special_tokens=True)

    decoded_preds = to_texts(predictions)
    decoded_labels = to_texts(labels)
    
    # BLEU
    try:
        bleu_score = bleu.compute(predictions=decoded_preds, references=[[r] for r in decoded_labels])['score']
    except Exception as exc:
        logger.warning("BLEU computation failed: %s", exc)
        bleu_score = 0.0
    
    # ROUGE
    try:
        rouge_score = rouge.compute(predictions=decoded_preds, references=decoded_labels)
        rouge1 = rouge_score['rouge1']
        rouge2 = rouge_score['rouge2']
        rougel = rouge_score['rougeL']
    except Exception as exc:
        logger.warning("ROUGE computation failed: %s", exc)
        rouge1, rouge2, rougel = 0.0, 0.0, 0.0

    # Exact Match
    try:
        em = np.mean([int(p.strip() == l.strip()) for p, l in zip(decoded_preds, decoded_labels)])
    except Exception as exc:
        logger.warning("Exact-match computation failed: %s", exc)
        em = 0.0

    # Token-level F1
    def token_f1(a, b):
        a_tokens = a.split()
        b_tokens = b.split()
        common = Counter(a_tokens) & Counter(b_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        prec = num_same / max(1, len(a_tokens))
        rec = num_same / max(1, len(b_tokens))
        return 2 * prec * rec / (prec + rec)

    try:
        token_f1_mean = float(np.mean([token_f1(p, l) for p, l in zip(decoded_preds, decoded_labels)]))
    except Exception as exc:
        logger.warning("Token-F1 computation failed: %s", exc)
        token_f1_mean = 0.0

    return {
        "bleu": bleu_score,
        "rouge1": rouge1,
        "rouge2": rouge2,
        "rougel": rougel,
        "exact_match": float(em),
        "token_f1": token_f1_mean
    }

print("TensorFlow-ready compute_metrics function is ready!")


TensorFlow-ready compute_metrics function is ready!
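
To make the token-level F1 concrete, the sketch below repeats the same overlap computation as a standalone function and applies it to one made-up prediction/reference pair. The sentences are invented for illustration.

```python
from collections import Counter

def token_f1(prediction, reference):
    """Whitespace-token F1 between a prediction and a reference string."""
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    # Multiset intersection counts each shared token at most min(count) times
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

pred = "chest pain and cough"
ref = "cough and chest discomfort"
score = token_f1(pred, ref)
print(score)  # 3 of 4 tokens shared on each side, so precision = recall = 0.75
```

The same pair scores 0 on exact match. For long free-text answers like the ones in this dataset, token F1 and ROUGE are therefore likely to be more informative than exact match.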


In [None]:
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM

# -------------------------
# Load model
# -------------------------
model = TFAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# -------------------------
# Optimizer & compile
# -------------------------
learning_rate = 1e-5
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08)
# No Keras loss is passed: the model falls back to its internal seq2seq loss,
# which builds the decoder inputs from the labels and ignores -100 positions
model.compile(optimizer=optimizer)

# -------------------------
# Callbacks
# -------------------------
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=6,
    min_delta=0.001,
    restore_best_weights=True
)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=str(OUT_DIR / "t5_health_final_ckpt"),
    save_weights_only=True,
    save_best_only=True,
    monitor="val_loss"
)

# -------------------------
# Metrics callback (TF-style)
# -------------------------
class MetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self, val_dataset, tokenizer, max_length=128, num_beams=2):
        super().__init__()
        self.val_dataset = val_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_beams = num_beams

    def on_epoch_end(self, epoch, logs=None):
        pred_texts = []
        label_texts = []

        for batch_inputs, batch_labels in self.val_dataset:
            # Labels may arrive as a one-key dict depending on the `datasets` version
            if isinstance(batch_labels, dict):
                batch_labels = batch_labels["labels"]

            # Generate for the whole batch with the model Keras attached to this callback
            y_pred = self.model.generate(
                input_ids=batch_inputs["input_ids"],
                attention_mask=batch_inputs["attention_mask"],
                max_length=self.max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )

            pred_texts.extend(self.tokenizer.batch_decode(y_pred, skip_special_tokens=True))
            label_texts.extend(self.tokenizer.batch_decode(batch_labels, skip_special_tokens=True))

        metrics = compute_metrics_for_generation(pred_texts, label_texts, self.tokenizer)
        print(f"\nEpoch {epoch+1} metrics: {metrics}")

metrics_cb = MetricsCallback(tf_valid_ds, tokenizer)

# -------------------------
# Train model
# -------------------------
epochs = 4
model.fit(
    tf_train_ds,
    validation_data=tf_valid_ds,
    epochs=epochs,
    callbacks=[early_stopping_cb, checkpoint_cb, metrics_cb],
)

# -------------------------
# Save model & tokenizer
# -------------------------
model.save_pretrained(OUT_DIR / "t5_health_final")
tokenizer.save_pretrained(OUT_DIR / "t5_health_final_tokenizer")

print("TensorFlow generative chatbot training complete!")


All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
I0000 00:00:1760625246.102681  225676 service.cc:152] XLA service 0x7a999c345530 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760625246.102706  225676 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 4060 Laptop GPU, Compute Capability 8.9
2025-10-16 16:34:06.109412: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1760625246.129955  225676 cuda_dnn.cc:529] Loaded cuDNN version 91002
I0000 00:00:1760625246.303328  225676 device_compiler.h:188] Compiled cluster



INFO:absl:Sharding callback duration: 181 microseconds


In [None]:
# -------------------------
# Evaluate model on test set
# -------------------------
from tqdm import tqdm

pred_texts = []
label_texts = []

# Loop over TF dataset
for batch_inputs, batch_labels in tqdm(tf_test_ds, desc="Evaluating"):
    batch_size = batch_inputs["input_ids"].shape[0]

    for i in range(batch_size):
        # Prepare single example for generation
        input_ids = tf.expand_dims(batch_inputs["input_ids"][i], 0)
        attention_mask = tf.expand_dims(batch_inputs["attention_mask"][i], 0)

        # Generate prediction
        y_pred = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=128,  # shorter for faster generation
            num_beams=2,     # lower beam count
            early_stopping=True
        )

        # Decode predicted tokens to text
        pred_text = tokenizer.decode(y_pred[0], skip_special_tokens=True)
        label_text = tokenizer.decode(batch_labels[i], skip_special_tokens=True)

        pred_texts.append(pred_text)
        label_texts.append(label_text)

# -------------------------
# Compute metrics
# -------------------------
metrics = compute_metrics_for_generation(pred_texts, label_texts, tokenizer)

print("Evaluation metrics on test set:")
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")


In [None]:
import tensorflow as tf
import math
from tqdm import tqdm

def calc_perplexity_tf(model, tokenizer, dataset, batch_size=2, max_length=256):
    """
    Calculate perplexity for a generative seq2seq model in TensorFlow.

    Args:
        model: TFAutoModelForSeq2SeqLM
        tokenizer: Hugging Face tokenizer
        dataset: Hugging Face Dataset (e.g. hf_test) with 'Question_clean',
            'Answer_clean', and 'topic_group' columns
        batch_size: evaluation batch size
        max_length: maximum sequence length

    Returns:
        float: perplexity
    """
    total_loss = 0.0
    total_tokens = 0.0

    for i in tqdm(range(0, len(dataset), batch_size), desc="Calculating Perplexity"):
        # Slicing a Hugging Face Dataset returns a dict of column lists
        batch = dataset[i:i + batch_size]

        # Encode inputs with the same prompt format used during training
        prompts = [
            f"question: {q} topic: {t}"
            for q, t in zip(batch["Question_clean"], batch["topic_group"])
        ]
        inputs = tokenizer(
            prompts,
            return_tensors="tf",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Encode labels
        labels = tokenizer(
            batch["Answer_clean"],
            return_tensors="tf",
            padding=True,
            truncation=True,
            max_length=max_length
        ).input_ids

        # Mask padding tokens so they do not contribute to the loss
        label_mask = tf.not_equal(labels, tokenizer.pad_token_id)
        masked_labels = tf.where(label_mask, labels, -100)

        # Forward pass
        outputs = model(**inputs, labels=masked_labels, training=False)

        # Reduce the loss to a scalar mean NLL per token, then weight it by the
        # token count so batches with different lengths average correctly
        batch_token_count = float(tf.reduce_sum(tf.cast(label_mask, tf.float32)))
        batch_loss = float(tf.reduce_mean(outputs.loss)) * batch_token_count

        total_loss += batch_loss
        total_tokens += batch_token_count

    avg_nll = total_loss / total_tokens
    perplexity = math.exp(avg_nll)
    return perplexity

# -------------------------
# Example usage
# -------------------------
# ppl = calc_perplexity_tf(model, tokenizer, hf_test)
# print("Perplexity (approx):", ppl)
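
The token weighting in the function above matters because batches contain different numbers of target tokens. A minimal worked sketch of the same arithmetic on made-up numbers (`perplexity_from_batches` and the toy losses are assumptions for illustration, not results from this model):

```python
import math

def perplexity_from_batches(batches):
    """batches: iterable of (mean_loss_per_token, token_count) pairs."""
    total_nll = 0.0
    total_tokens = 0
    for mean_loss, n_tokens in batches:
        # Convert the per-token mean back to a summed negative log-likelihood
        total_nll += mean_loss * n_tokens
        total_tokens += n_tokens
    if total_tokens == 0:
        raise ValueError("no tokens to score")
    return math.exp(total_nll / total_tokens)

# Toy values: a short batch with high loss and a long batch with low loss
toy_batches = [(2.0, 10), (1.0, 30)]
ppl = perplexity_from_batches(toy_batches)

# Averaging the two batch losses without weights would overstate the short batch
unweighted = math.exp((2.0 + 1.0) / 2)
print(f"token-weighted: {ppl:.3f}, unweighted: {unweighted:.3f}")
```

Here the weighted average loss is (2.0 × 10 + 1.0 × 30) / 40 = 1.25, so the perplexity is exp(1.25), which is lower than the exp(1.5) the unweighted mean would give.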


In [None]:
import numpy as np
from sentence_transformers import CrossEncoder

# Load cross-encoder reranker
cross_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker = CrossEncoder(cross_model_name)

def answer_with_retrieval_and_generate_tf(question, top_k=5, gen_model=None, threshold=0.35, tokenizer=None):
    """
    Retrieve top passages using FAISS + cross-encoder reranking, then generate an answer.
    
    Args:
        question (str): User query
        top_k (int): Number of top passages to retrieve
        gen_model: TensorFlow seq2seq model (TFAutoModelForSeq2SeqLM)
        threshold (float): Minimum reranker score to trust context
        tokenizer: HF tokenizer corresponding to gen_model
    
    Returns:
        dict: {answer, score, source (optional)}
    """
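    # 1️⃣ Retrieval: encode question and search FAISS
    q_emb = embedder.encode([question], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, top_k)
    # FAISS pads results with -1 when fewer than top_k neighbours exist; drop those slots
    candidates = [corpus_texts[idx] for idx in I[0] if idx >= 0]

    fallback_answer = "I'm not sure about that. Please consult a health professional or provide more details."
    if not candidates:
        return {"answer": fallback_answer, "score": None}

    # 2️⃣ Reranking with cross-encoder
    pairs = [[question, c] for c in candidates]
    scores = reranker.predict(pairs)
    best_idx = int(np.argmax(scores))
    best_score = float(scores[best_idx])
    best_passage = candidates[best_idx]

    # 3️⃣ Check confidence threshold
    # Reranker scores are model-specific and may be unbounded logits rather than
    # probabilities, so the threshold should be tuned on held-out questions.
    if best_score < threshold:
        return {"answer": fallback_answer, "score": best_score}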

    # 4️⃣ Generation
    if gen_model is not None and tokenizer is not None:
        prompt = f"Answer this health question: {question} [CONTEXT: {best_passage}]"
        inputs = tokenizer(prompt, return_tensors="tf", truncation=True, padding=True, max_length=256)

        # Generate output (beam search combined with sampling, so answers can vary between runs)
        out = gen_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=256,
            num_beams=8,
            early_stopping=True,
            no_repeat_ngram_size=2,
            length_penalty=1.0,
            temperature=0.8,
            do_sample=True
        )

        gen_ans = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
        return {"answer": gen_ans, "score": best_score, "source": best_passage}

    # Fallback: return the best passage
    return {"answer": best_passage, "score": best_score}

# -------------------------
# Example usage
# -------------------------
# `model` and `tokenizer` are the fine-tuned TFAutoModelForSeq2SeqLM and its tokenizer
# from the training cells; both must be passed for the generation step to run.
print(answer_with_retrieval_and_generate_tf("What are symptoms of depression?", top_k=5, gen_model=model, tokenizer=tokenizer))
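
A threshold such as 0.35 is easiest to reason about on a 0 to 1 scale. If the reranker returns raw logits (an assumption here; it depends on the activation configured for the checkpoint), a sigmoid maps them to that scale before comparison. A minimal sketch with hypothetical helper names and made-up scores:

```python
import math

def sigmoid(x: float) -> float:
    """Map an unbounded logit to the (0, 1) interval."""
    return 1.0 / (1.0 + math.exp(-x))

def pick_passage(candidates, logits, threshold=0.35):
    """Return (passage, probability); passage is None when the best score is below threshold."""
    if not candidates:
        return None, 0.0
    best_idx = max(range(len(logits)), key=lambda i: logits[i])
    prob = sigmoid(logits[best_idx])
    passage = candidates[best_idx] if prob >= threshold else None
    return passage, prob

# Made-up logits for illustration, not real reranker output
passage, prob = pick_passage(["passage about sleep", "passage about diet"], [-4.2, 1.3])
rejected, low_prob = pick_passage(["unrelated passage"], [-6.0])
print(passage, round(prob, 3))
print(rejected, round(low_prob, 3))
```

Because the sigmoid is monotonic, it does not change which passage ranks first; it only changes what the threshold means.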


In [None]:
# -------------------------
# Example evaluation suite
# -------------------------
in_domain_examples = [
    "What are common symptoms of anxiety?",
    "How long does antidepressant medication take to work?",
    "I feel hopeless and can't sleep, what should I do?",
]

out_of_domain_examples = [
    "What's the best GPU for gaming?",
    "How do I cook rice?",
    "Explain football offside rule",
]

all_examples = in_domain_examples + out_of_domain_examples
results = []

for question in all_examples:
    res = answer_with_retrieval_and_generate_tf(
        question,
        top_k=5,
        gen_model=model,        # TFAutoModelForSeq2SeqLM
        tokenizer=tokenizer,
        threshold=0.35          # fallback if irrelevant
    )
    results.append((question, res['answer'], res.get('score')))

# -------------------------
# Display results
# -------------------------
for q, a, s in results:
    print(f"Q: {q}")
    print(f"Score: {s:.4f}" if s is not None else "Score: N/A")
    print(f"A: {a[:400]}")  # show first 400 chars
    print("---")
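
To turn the printed examples into a single number per group, one option is to count how often the chatbot falls back to its "not sure" reply. The sketch below assumes the fallback text used in `answer_with_retrieval_and_generate_tf`; `refusal_rate` is a hypothetical helper and the answers are placeholders, not model output.

```python
FALLBACK_PREFIX = "I'm not sure about that."

def refusal_rate(answers):
    """Fraction of answers that are the fallback reply."""
    if not answers:
        return 0.0
    refused = sum(1 for a in answers if a.startswith(FALLBACK_PREFIX))
    return refused / len(answers)

# Placeholder answers standing in for the 'answer' field of each result
toy_in_domain = ["placeholder answer", "placeholder answer", FALLBACK_PREFIX + " Please consult a health professional."]
toy_out_of_domain = [FALLBACK_PREFIX, FALLBACK_PREFIX, "placeholder answer"]

in_rate = refusal_rate(toy_in_domain)
out_rate = refusal_rate(toy_out_of_domain)
print(f"in-domain refusals: {in_rate:.2f}, out-of-domain refusals: {out_rate:.2f}")
```

A well-behaved domain filter would show a low refusal rate on the in-domain questions and a high one on the out-of-domain questions.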


In [None]:
# -------------------------
# 1️⃣ Save SentenceTransformer embedding model
# -------------------------
embedder.save(str(OUT_DIR / "embedder_allMiniLM"))

# -------------------------
# 2️⃣ Save FAISS index and corpus texts
# -------------------------
faiss.write_index(index, str(OUT_DIR / "faiss.index"))
np.save(OUT_DIR / "corpus_texts.npy", np.array(corpus_texts))

# -------------------------
# 3️⃣ Save reranker info
# -------------------------
# The reranker is an unmodified pretrained checkpoint, so storing its model name
# is enough to reload it later; reranker.save(path) would keep an offline copy.
reranker_config = {"model_name": cross_model_name}
with open(OUT_DIR / "reranker_config.json", "w") as f:
    json.dump(reranker_config, f, indent=2)

# -------------------------
# 4️⃣ Save seq2seq model + tokenizer
# -------------------------
model.save_pretrained(OUT_DIR / "t5_health_final")
tokenizer.save_pretrained(OUT_DIR / "t5_health_final_tokenizer")

# -------------------------
# Completion message
# -------------------------
print(f"All artifacts saved under {OUT_DIR}")
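
The saved artifacts can be reloaded along these lines. This is a sketch, assuming the same directory layout as the save cell above; `load_artifacts` is a hypothetical helper, and the heavy libraries are imported inside it so the definition itself has no side effects.

```python
import json
from pathlib import Path

def load_artifacts(out_dir):
    """Reload everything written by the save cell (paths mirror that cell)."""
    # Imported lazily: these libraries are only needed when the function is called
    import faiss
    import numpy as np
    from sentence_transformers import CrossEncoder, SentenceTransformer
    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

    out_dir = Path(out_dir)
    with open(out_dir / "reranker_config.json") as f:
        reranker_name = json.load(f)["model_name"]

    return {
        "embedder": SentenceTransformer(str(out_dir / "embedder_allMiniLM")),
        "index": faiss.read_index(str(out_dir / "faiss.index")),
        "corpus_texts": np.load(out_dir / "corpus_texts.npy").tolist(),
        # The reranker is re-downloaded (or read from the local cache) by name
        "reranker": CrossEncoder(reranker_name),
        "model": TFAutoModelForSeq2SeqLM.from_pretrained(str(out_dir / "t5_health_final")),
        "tokenizer": AutoTokenizer.from_pretrained(str(out_dir / "t5_health_final_tokenizer")),
    }

# Example (not run here): artifacts = load_artifacts(OUT_DIR)
```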
