In [None]:
# Health Q&A Chatbot - TensorFlow Optimized Notebook (SIMPLIFIED VERSION)
# Optimized for 6GB GPU, 32GB RAM, 24 cores

# Install required packages
!pip install -q tensorflow sentence-transformers faiss-cpu pandas nltk scikit-learn matplotlib

import os
import re
import json
import math
import logging
from pathlib import Path
from collections import Counter
from typing import List, Dict, Tuple, Optional

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Embedding, LSTM, Bidirectional, Concatenate, GlobalMaxPooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# NLP
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import html

# Vector search
from sentence_transformers import SentenceTransformer
import faiss

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Turn on memory growth for every visible GPU (helps avoid GPU memory issues)
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    # Only report success when at least one GPU was actually configured
    if gpus:
        print("GPU memory growth enabled")
except RuntimeError as e:
    print(f"GPU configuration error: {e}")

# Environment summary
print(f"TensorFlow version: {tf.__version__}")
print(f"Num GPUs Available: {len(gpus)}")

# Shared settings for the data processor, the generative model and retrieval
# (SIMPLIFIED profile, sized for the hardware noted at the top of the notebook)
CONFIG = {
    # --- training ---
    "batch_size": 8,               # batch size passed to model.fit
    "max_sequence_length": 128,    # pad length for question/answer sequences
    "embedding_dim": 128,          # output_dim of both Embedding layers
    "lstm_units": 128,             # units in the encoder and decoder LSTMs
    "dense_units": 64,             # width of the hidden Dense layer
    "learning_rate": 1e-4,         # Adam learning rate
    "max_vocab_size": 20000,       # num_words cap given to the Keras Tokenizers
    # --- retrieval ---
    "max_answers": 3,              # k used when answering a question
    "similarity_threshold": 0.3,   # minimum similarity score for a retrieved answer
    # --- schedule ---
    "epochs": 15,                  # Reduced for quick training
}

# Root logging at INFO so progress messages from the classes below are shown
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class HealthDataProcessor:
    """Clean health Q&A text and turn it into padded integer sequences.

    The processor owns two Keras ``Tokenizer`` objects (one for questions, one
    for answers).  They are ``None`` until :meth:`build_tokenizers` is called.
    """

    # Characters stripped by both tokenizers (shared so the two stay in sync).
    _TOKEN_FILTERS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'

    def __init__(self, max_sequence_length=128, max_vocab_size=20000):
        self.max_sequence_length = max_sequence_length
        self.max_vocab_size = max_vocab_size
        self.question_tokenizer = None
        self.answer_tokenizer = None
        self.stop_words = set(stopwords.words('english'))

    def clean_text(self, text: str) -> str:
        """Return a normalised, single-line version of ``text``.

        Non-string input (e.g. NaN or None) yields an empty string.  The steps
        are: HTML-unescape, replace newlines and ``[...]`` spans with a space,
        drop a "Key Points" heading together with its trailing separators,
        collapse whitespace runs and strip the ends.
        """
        if not isinstance(text, str):
            return ""

        text = html.unescape(text)
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'\[.*?\]', ' ', text)
        # The class must be [:\s-] (colon, whitespace, hyphen).  The previous
        # pattern used "\\s" inside a raw string, which matched a literal
        # backslash and the letter "s" instead of whitespace.
        text = re.sub(r'Key Points[:\s-]*', '', text, flags=re.I)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy of ``df`` ready for indexing/training.

        Expects ``Question`` and ``Answer`` columns; an optional ``topic``
        column is lower-cased and stripped of characters outside
        ``[a-z0-9_ ]``.  Adds ``Question_clean`` / ``Answer_clean``, drops rows
        with missing or very short text, and removes duplicate Q/A pairs.
        The input frame is not modified.
        """
        logger.info("Preprocessing dataframe...")

        # Work on a copy to avoid SettingWithCopyWarning and caller surprises
        df = df.copy()

        # Drop rows where either side of the pair is missing
        initial_count = len(df)
        df = df.dropna(subset=['Question', 'Answer']).reset_index(drop=True)
        logger.info(f"Dropped {initial_count - len(df)} rows with missing values")

        df['Question_clean'] = df['Question'].apply(self.clean_text)
        df['Answer_clean'] = df['Answer'].apply(self.clean_text)

        # Keep only pairs with a question > 10 chars and an answer > 20 chars
        long_enough = (df['Question_clean'].str.len() > 10) & (df['Answer_clean'].str.len() > 20)
        df = df[long_enough].reset_index(drop=True)
        logger.info(f"After length filtering: {len(df)} rows")

        if 'topic' in df.columns:
            df['topic'] = df['topic'].astype(str).str.lower().str.strip()
            df['topic'] = df['topic'].str.replace(r'[^a-z0-9_ ]', '', regex=True)

        df = df.drop_duplicates(subset=['Question_clean', 'Answer_clean']).reset_index(drop=True)
        logger.info(f"After deduplication: {len(df)} rows")

        return df

    def build_tokenizers(self, questions: List[str], answers: List[str]):
        """Fit separate tokenizers on the question and answer texts."""
        logger.info("Building tokenizers...")

        self.question_tokenizer = Tokenizer(
            num_words=self.max_vocab_size,
            oov_token='<OOV>',
            filters=self._TOKEN_FILTERS
        )
        self.answer_tokenizer = Tokenizer(
            num_words=self.max_vocab_size,
            oov_token='<OOV>',
            filters=self._TOKEN_FILTERS
        )

        self.question_tokenizer.fit_on_texts(questions)
        self.answer_tokenizer.fit_on_texts(answers)

        logger.info(f"Question vocab size: {len(self.question_tokenizer.word_index)}")
        logger.info(f"Answer vocab size: {len(self.answer_tokenizer.word_index)}")

    def texts_to_sequences(self, questions: List[str], answers: List[str]) -> Tuple[np.ndarray, np.ndarray]:
        """Convert texts to integer sequences post-padded to ``max_sequence_length``.

        Returns:
            ``(X, y)``: padded question and answer id arrays.

        Raises:
            ValueError: if :meth:`build_tokenizers` has not been called yet.
        """
        if self.question_tokenizer is None or self.answer_tokenizer is None:
            raise ValueError("Tokenizers not built. Call build_tokenizers first.")

        question_sequences = self.question_tokenizer.texts_to_sequences(questions)
        answer_sequences = self.answer_tokenizer.texts_to_sequences(answers)

        # NOTE(review): only `padding` is set; truncation of over-long texts
        # uses the pad_sequences default - confirm that is the intended side.
        X = pad_sequences(question_sequences, maxlen=self.max_sequence_length, padding='post')
        y = pad_sequences(answer_sequences, maxlen=self.max_sequence_length, padding='post')

        return X, y

class SimpleHealthModel:
    """Simplified LSTM encoder-decoder for health Q&A - NO ATTENTION.

    The encoder reads the padded question ids and hands its final LSTM state to
    the decoder.  The decoder is trained with teacher forcing: it receives the
    answer shifted one step to the right and predicts the answer token at every
    time step.

    ``data_processor`` must be assigned (with fitted tokenizers) before
    :meth:`train` is asked to build the model.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.model = None
        self.data_processor = None

    @staticmethod
    def _teacher_forcing_inputs(targets: np.ndarray, start_id: int) -> np.ndarray:
        """Return ``targets`` shifted one step right with ``start_id`` in column 0.

        ``targets`` is a 2-D id array (samples x time steps); it is not
        modified.  The last column of ``targets`` is dropped by the shift.
        """
        shifted = np.zeros_like(targets)
        shifted[:, 0] = start_id
        shifted[:, 1:] = targets[:, :-1]
        return shifted

    def build_model(self, question_vocab_size: int, answer_vocab_size: int) -> "Model":
        """Build and compile the SIMPLIFIED seq2seq model - NO ATTENTION.

        Args:
            question_vocab_size: number of entries in the question word index.
            answer_vocab_size: number of entries in the answer word index.

        Returns:
            A compiled Keras model taking ``[encoder_inputs, decoder_inputs]``
            and producing a softmax over answer ids for every time step.
        """
        logger.info("Building SIMPLIFIED model architecture...")

        seq_len = self.config['max_sequence_length']

        # ===== SIMPLIFIED ENCODER =====
        encoder_inputs = Input(shape=(seq_len,), name='encoder_inputs')
        encoder_embedding = Embedding(
            input_dim=question_vocab_size + 1,  # +1 for the padding id 0
            output_dim=self.config['embedding_dim'],
            mask_zero=True,
            name='encoder_embedding'
        )(encoder_inputs)

        encoder_lstm = LSTM(
            self.config['lstm_units'],
            return_sequences=False,  # only the final state is needed
            return_state=True,
            dropout=0.2,
            recurrent_dropout=0.1,
            name='encoder_lstm'
        )
        _, state_h, state_c = encoder_lstm(encoder_embedding)
        encoder_states = [state_h, state_c]

        # ===== SIMPLIFIED DECODER =====
        decoder_inputs = Input(shape=(seq_len,), name='decoder_inputs')
        decoder_embedding = Embedding(
            # +1 for padding id 0 and +1 for the start-of-answer id
            # (answer_vocab_size + 1) that train() feeds at step 0.
            input_dim=answer_vocab_size + 2,
            output_dim=self.config['embedding_dim'],
            mask_zero=True,
            name='decoder_embedding'
        )(decoder_inputs)

        decoder_lstm = LSTM(
            self.config['lstm_units'],
            return_sequences=True,   # one output per time step
            return_state=False,
            dropout=0.2,
            recurrent_dropout=0.1,
            name='decoder_lstm'
        )
        decoder_outputs = decoder_lstm(decoder_embedding, initial_state=encoder_states)

        # ===== OUTPUT =====
        # Dense layers act on the last axis, so the time axis is kept and the
        # prediction has one distribution per answer position.  (Pooling over
        # time here would collapse the output to one vector per sample, which
        # cannot be scored against a per-position target sequence.)
        dense1 = Dense(self.config['dense_units'], activation='relu', name='dense_1')(decoder_outputs)
        # answer_vocab_size + 1 classes: ids 0..answer_vocab_size are all valid
        # targets (0 is padding and doubles as "end of answer").
        outputs = Dense(answer_vocab_size + 1, activation='softmax', name='outputs')(dense1)

        model = Model(
            inputs=[encoder_inputs, decoder_inputs],
            outputs=outputs,
            name='simple_health_qa_model'
        )

        # Adam with gradient-norm clipping
        optimizer = Adam(
            learning_rate=self.config['learning_rate'],
            clipnorm=1.0
        )

        model.compile(
            optimizer=optimizer,
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        logger.info("Simplified model built successfully")
        return model

    def train(self, X: np.ndarray, y: np.ndarray, validation_data: Tuple, epochs: int = 15):
        """Train the model with teacher forcing and return the Keras history.

        Args:
            X: padded question ids, one row per sample.
            y: padded answer ids, same number of rows as ``X``.
            validation_data: ``(val_X, val_y)`` in the same format.
            epochs: maximum number of epochs (early stopping may end sooner).

        Raises:
            ValueError: if the model has to be built but ``data_processor``
                has not been assigned.
        """
        logger.info("Starting simplified model training...")

        # Build model if not already built
        if self.model is None:
            if self.data_processor is None:
                raise ValueError(
                    "data_processor must be set (with fitted tokenizers) before training."
                )
            question_vocab_size = len(self.data_processor.question_tokenizer.word_index)
            answer_vocab_size = len(self.data_processor.answer_tokenizer.word_index)
            self.model = self.build_model(question_vocab_size, answer_vocab_size)

        print("\nSimplified Model Architecture:")
        self.model.summary()

        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=3,
                restore_best_weights=True,
                verbose=1
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=2,
                min_lr=1e-7,
                verbose=1
            )
        ]

        # The start id is the last row of the decoder embedding table (see
        # build_model), so it never collides with a real answer token.
        start_id = self.model.get_layer('decoder_embedding').input_dim - 1

        # Teacher forcing: at step t the decoder sees the answer token from
        # step t-1 and must predict the token at step t.  Padding (id 0) in the
        # decoder input is masked by the embedding (mask_zero=True).
        # NOTE(review): Keras is expected to leave masked steps out of the
        # loss/accuracy - confirm on the installed Keras version.
        decoder_input_data = self._teacher_forcing_inputs(y, start_id)

        val_X, val_y = validation_data
        val_decoder_input = self._teacher_forcing_inputs(val_y, start_id)

        history = self.model.fit(
            [X, decoder_input_data],
            y,
            batch_size=self.config['batch_size'],
            epochs=epochs,
            validation_data=([val_X, val_decoder_input], val_y),
            callbacks=callbacks,
            verbose=1,
            shuffle=True
        )

        logger.info("Simplified model training completed")
        return history

class RetrievalSystem:
    """Semantic answer retrieval: sentence embeddings searched with a FAISS index.

    Embeddings are L2-normalised and stored in an inner-product index, so the
    returned scores are cosine similarities.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.embedder = SentenceTransformer(model_name)
        self.index = None
        self.corpus_texts = []

    def build_index(self, answers: List[str], batch_size: int = 16):
        """Embed ``answers`` and (re)build the FAISS index over them.

        Args:
            answers: texts to make searchable.
            batch_size: encoding batch size (smaller uses less memory).

        Raises:
            ValueError: if ``answers`` is empty.
        """
        # Copy so later changes to the caller's list cannot desynchronise the
        # stored texts from the index rows.
        texts = list(answers)
        if not texts:
            raise ValueError("Cannot build an index from an empty list of answers.")

        logger.info("Building FAISS index...")

        # Encode in batches to manage memory
        corpus_embeddings = self.embedder.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        # FAISS works on contiguous float32 matrices
        corpus_embeddings = np.ascontiguousarray(corpus_embeddings, dtype=np.float32)

        # Normalize (in place) so inner product equals cosine similarity
        faiss.normalize_L2(corpus_embeddings)

        index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
        index.add(corpus_embeddings)

        # Publish texts and index together, only after both were built
        self.corpus_texts = texts
        self.index = index

        logger.info(f"FAISS index built with {self.index.ntotal} entries")

    def retrieve(self, query: str, k: int = 5, threshold: float = 0.3) -> List[Dict]:
        """Return up to ``k`` answers whose similarity to ``query`` is >= ``threshold``.

        Returns:
            A list of ``{'answer': str, 'score': float}`` dicts in the order
            FAISS returned them (best match first).

        Raises:
            ValueError: if no index has been built or loaded.
        """
        if self.index is None:
            raise ValueError("Index not built. Call build_index first.")

        query_embedding = self.embedder.encode([query], convert_to_numpy=True)
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        faiss.normalize_L2(query_embedding)

        scores, ids = self.index.search(query_embedding, k)

        results = []
        for idx, score in zip(ids[0], scores[0]):
            # FAISS fills missing neighbours with id -1 when the index holds
            # fewer than k vectors; never let that index corpus_texts[-1].
            if idx < 0:
                continue
            if score >= threshold:
                results.append({
                    'answer': self.corpus_texts[int(idx)],
                    'score': float(score)
                })

        return results

class HealthChatbot:
    """Main Health Chatbot Class - Focused on Retrieval.

    Combines a ``HealthDataProcessor`` (text cleaning / tokenising), a
    ``SimpleHealthModel`` (optional generative model) and a ``RetrievalSystem``
    (FAISS lookup).  Answers are always served by the retrieval system.
    """

    def __init__(self, config: Dict = None):
        # Private copy: load_models() updates the config in place and must not
        # leak those changes into the module-level CONFIG or the caller's dict.
        self.config = dict(config) if config else dict(CONFIG)
        self.data_processor = HealthDataProcessor(
            max_sequence_length=self.config['max_sequence_length'],
            max_vocab_size=self.config['max_vocab_size']
        )
        self.generative_model = SimpleHealthModel(self.config)  # Use simplified model
        self.retrieval_system = RetrievalSystem()
        self.is_trained = False

    def load_and_preprocess_data(self, data_path: str) -> pd.DataFrame:
        """Read the CSV at ``data_path``, print a short summary and return the cleaned frame."""
        logger.info(f"Loading data from {data_path}")
        df = pd.read_csv(data_path)

        # Basic data info
        print(f"Dataset shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")

        if 'topic' in df.columns:
            print("\nTopic distribution:")
            print(df['topic'].value_counts().head(10))

        return self.data_processor.preprocess_dataframe(df)

    def prepare_training_splits(self, df: pd.DataFrame, test_size: float = 0.2, val_size: float = 0.1):
        """Split ``df`` into train / validation / test frames.

        Args:
            df: preprocessed data; frames longer than 3000 rows are randomly
                subsampled to 3000 first (fixed seed) to keep training quick.
            test_size: fraction of the (subsampled) data used for testing.
            val_size: fraction of the (subsampled) data used for validation.

        Returns:
            ``(train_df, val_df, test_df)``.

        Raises:
            ValueError: if the fractions are not positive or leave no training data.
        """
        from sklearn.model_selection import train_test_split

        holdout = test_size + val_size
        if test_size <= 0 or val_size <= 0 or holdout >= 1:
            raise ValueError(
                "test_size and val_size must be positive and sum to less than 1"
            )

        # Use smaller subset for quick training
        if len(df) > 3000:
            df = df.sample(n=3000, random_state=42)
            print(f"Using subset of {len(df)} samples for faster training")

        # First split: carve off validation + test together so that both
        # fractions are measured against the full data set.
        train_df, temp_df = train_test_split(
            df, test_size=holdout, random_state=42
        )

        # Second split: the test share *within* the held-out part.
        val_df, test_df = train_test_split(
            temp_df, test_size=test_size / holdout, random_state=42
        )

        logger.info(f"Training samples: {len(train_df)}")
        logger.info(f"Validation samples: {len(val_df)}")
        logger.info(f"Test samples: {len(test_df)}")

        return train_df, val_df, test_df

    def train_generative_model(self, train_df: pd.DataFrame, val_df: pd.DataFrame, epochs: int = 15):
        """Fit tokenizers on the training split and train the generative model.

        Returns the history object produced by ``SimpleHealthModel.train`` and
        sets ``is_trained`` on success.
        """
        logger.info("Training SIMPLIFIED generative model...")

        train_questions = train_df['Question_clean'].tolist()
        train_answers = train_df['Answer_clean'].tolist()
        val_questions = val_df['Question_clean'].tolist()
        val_answers = val_df['Answer_clean'].tolist()

        # Tokenizers are fitted on training text only
        self.data_processor.build_tokenizers(train_questions, train_answers)
        self.generative_model.data_processor = self.data_processor

        X_train, y_train = self.data_processor.texts_to_sequences(train_questions, train_answers)
        X_val, y_val = self.data_processor.texts_to_sequences(val_questions, val_answers)

        print(f"Training data shape: {X_train.shape}, {y_train.shape}")
        print(f"Validation data shape: {X_val.shape}, {y_val.shape}")

        history = self.generative_model.train(X_train, y_train, (X_val, y_val), epochs=epochs)

        self.is_trained = True
        return history

    def build_retrieval_system(self, df: pd.DataFrame):
        """Index the unique cleaned answers of ``df`` in the retrieval system."""
        logger.info("Building retrieval system...")
        answers = df['Answer_clean'].drop_duplicates().tolist()
        print(f"Building index with {len(answers)} unique answers")
        self.retrieval_system.build_index(answers, batch_size=16)

    def generate_answer(self, question: str, use_retrieval: bool = True) -> Dict:
        """Answer ``question`` from the retrieval system.

        ``use_retrieval`` is accepted for interface compatibility but is
        currently ignored: retrieval is always used.

        Returns:
            A dict with ``answer``, ``score``, ``method`` (``"retrieval"`` or
            ``"fallback"``) and ``source``.
        """
        retrieved = self.retrieval_system.retrieve(
            question,
            k=self.config['max_answers'],
            threshold=self.config['similarity_threshold']
        )

        if not retrieved:
            return {
                "answer": "I'm not sure about that specific medical question. Please consult a healthcare professional for accurate medical advice.",
                "score": 0.0,
                "method": "fallback",
                "source": None
            }

        # Results arrive best-first, so the first one is the top match
        best_result = retrieved[0]

        return {
            "answer": best_result['answer'],
            "score": best_result['score'],
            "method": "retrieval",
            "source": "medical knowledge base"
        }

    def evaluate_retrieval(self, test_df: pd.DataFrame, k: int = 5) -> Dict:
        """Measure exact-match recall@k on (at most) the first 50 rows of ``test_df``.

        A row counts as a hit when its ``Answer_clean`` text is among the
        answers retrieved for its ``Question_clean``.

        Returns:
            ``{"recall@k": float, "evaluated_samples": int}``; recall is 0.0
            when there is nothing to evaluate.
        """
        total = min(50, len(test_df))  # Smaller sample for quick evaluation
        if total == 0:
            # Nothing to score; also avoids dividing by zero below.
            return {"recall@k": 0.0, "evaluated_samples": 0}

        logger.info("Evaluating retrieval system...")

        sample = test_df.head(total)
        hits = 0
        for question, true_answer in zip(sample['Question_clean'], sample['Answer_clean']):
            retrieved = self.retrieval_system.retrieve(question, k=k)
            # Exact-match check against the retrieved answer texts
            if any(r['answer'] == true_answer for r in retrieved):
                hits += 1

        return {"recall@k": hits / total, "evaluated_samples": total}

    def plot_training_history(self, history):
        """Plot training/validation loss and accuracy side by side."""
        if not history:
            print("No training history to plot")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot loss
        ax1.plot(history.history['loss'], label='Training Loss')
        ax1.plot(history.history['val_loss'], label='Validation Loss')
        ax1.set_title('Model Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot accuracy
        ax2.plot(history.history['accuracy'], label='Training Accuracy')
        ax2.plot(history.history['val_accuracy'], label='Validation Accuracy')
        ax2.set_title('Model Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def save_models(self, model_dir: str = "health_chatbot_models"):
        """Write every available component (model, tokenizers, index, config) to ``model_dir``."""
        model_dir = Path(model_dir)
        model_dir.mkdir(parents=True, exist_ok=True)

        # Explicit None checks: the truth value of model/index objects is not
        # a reliable "exists" test.
        if self.generative_model.model is not None and self.is_trained:
            self.generative_model.model.save(model_dir / "generative_model.h5")
            print("✓ Generative model saved")

        # Both tokenizers are created together, so one check covers both
        if self.data_processor.question_tokenizer is not None:
            with open(model_dir / "question_tokenizer.json", 'w', encoding='utf-8') as f:
                f.write(self.data_processor.question_tokenizer.to_json())
            with open(model_dir / "answer_tokenizer.json", 'w', encoding='utf-8') as f:
                f.write(self.data_processor.answer_tokenizer.to_json())
            print("✓ Tokenizers saved")

        if self.retrieval_system.index is not None:
            faiss.write_index(self.retrieval_system.index, str(model_dir / "faiss_index.bin"))
            np.save(model_dir / "corpus_texts.npy", np.array(self.retrieval_system.corpus_texts))
            print("✓ Retrieval system saved")

        with open(model_dir / "config.json", 'w', encoding='utf-8') as f:
            json.dump(self.config, f, indent=2)

        print(f"✓ All models saved to {model_dir}")

    def load_models(self, model_dir: str = "health_chatbot_models"):
        """Load whatever components exist in ``model_dir``.

        Errors are reported on stdout rather than raised (best-effort load).

        Returns:
            ``True`` if loading finished without an error, ``False`` otherwise.
        """
        model_dir = Path(model_dir)

        try:
            with open(model_dir / "config.json", 'r', encoding='utf-8') as f:
                self.config.update(json.load(f))

            if (model_dir / "question_tokenizer.json").exists():
                with open(model_dir / "question_tokenizer.json", 'r', encoding='utf-8') as f:
                    self.data_processor.question_tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(f.read())
            if (model_dir / "answer_tokenizer.json").exists():
                with open(model_dir / "answer_tokenizer.json", 'r', encoding='utf-8') as f:
                    self.data_processor.answer_tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(f.read())

            if (model_dir / "generative_model.h5").exists():
                self.generative_model.model = tf.keras.models.load_model(
                    model_dir / "generative_model.h5"
                )
                self.is_trained = True

            if (model_dir / "faiss_index.bin").exists():
                self.retrieval_system.index = faiss.read_index(str(model_dir / "faiss_index.bin"))
                # save_models() stores a plain string array, so pickle support
                # is not needed; keeping it off avoids running pickled code
                # from a tampered file.
                self.retrieval_system.corpus_texts = np.load(
                    model_dir / "corpus_texts.npy",
                    allow_pickle=False
                ).tolist()

            print("✓ Models loaded successfully")
            return True

        except Exception as e:
            print(f"✗ Error loading models: {e}")
            return False

# =============================================================================
# MAIN EXECUTION - RETRIEVAL-FOCUSED APPROACH
# =============================================================================

def main():
    """Run the full retrieval-focused pipeline end to end.

    Steps: load and clean the dataset, build and smoke-test the retrieval
    index, optionally train the generative model (failures there are reported
    and skipped), evaluate retrieval on the test split, save all artefacts and
    finally start an interactive chat loop.  Any unexpected error is logged and
    printed rather than raised.
    """

    print("🚀 Health Q&A Chatbot - Retrieval-Focused Version")
    print("=" * 60)

    # Initialize chatbot
    chatbot = HealthChatbot()

    try:
        # Load and preprocess data
        print("\n📊 Loading and preprocessing data...")
        df = chatbot.load_and_preprocess_data("../dataset/merged_health_dataset.csv")

        # Display sample data
        print("\nSample processed data:")
        print(df[['Question_clean', 'Answer_clean']].head(2))

        # Build retrieval system (PRIMARY COMPONENT)
        print("\n🔍 Building retrieval system...")
        chatbot.build_retrieval_system(df)

        # Test retrieval system immediately
        print("\n🧪 Testing retrieval system...")
        test_questions = [
            "What are symptoms of depression?",
            "How to improve sleep quality?",
            "What foods are good for heart health?"
        ]

        for question in test_questions:
            result = chatbot.retrieval_system.retrieve(question)
            print(f"\nQuery: '{question}'")
            print(f"Found {len(result)} results:")
            for i, res in enumerate(result[:2]):
                print(f"  {i+1}. Score: {res['score']:.3f}")
                print(f"     {res['answer'][:80]}...")

        # Try simplified generative training (optional).  test_df stays None
        # if the split itself fails, which the evaluation step checks below.
        test_df = None
        try:
            print("\n🧠 Attempting simplified generative model training...")
            train_df, val_df, test_df = chatbot.prepare_training_splits(df)
            history = chatbot.train_generative_model(train_df, val_df, epochs=chatbot.config['epochs'])
            chatbot.plot_training_history(history)
        except Exception as e:
            print(f"⚠️  Generative training skipped: {e}")
            print("🎯 Continuing with retrieval-only system (more reliable)")

        # Evaluate retrieval system
        print("\n📊 Evaluating retrieval system...")
        if test_df is not None:
            eval_k = 5
            retrieval_metrics = chatbot.evaluate_retrieval(test_df, k=eval_k)
            # Label the metric with k (not with the number of evaluated rows)
            print(
                f"Retrieval Recall@{eval_k}: {retrieval_metrics['recall@k']:.3f} "
                f"({retrieval_metrics['evaluated_samples']} samples)"
            )

        # Save models
        print("\n💾 Saving models...")
        chatbot.save_models()

        # Test the complete chatbot
        print("\n🤖 Testing Health Chatbot:")
        print("=" * 50)

        test_questions = [
            "What are common symptoms of anxiety?",
            "How can I improve my sleep quality?",
            "What foods are good for heart health?",
            "How to manage stress effectively?",
            "What are the benefits of regular exercise?"
        ]

        for i, question in enumerate(test_questions, 1):
            result = chatbot.generate_answer(question)
            print(f"\n{i}. Q: {question}")
            print(f"   A: {result['answer'][:120]}...")
            print(f"   Method: {result['method']}, Score: {result.get('score', 0):.3f}")

        # Interactive testing
        print("\n" + "=" * 50)
        print("💬 Interactive Chat Mode - RETRIEVAL SYSTEM")
        print("Type 'quit' to exit")
        print("=" * 50)

        while True:
            try:
                question = input("\n🤔 Enter your health question: ").strip()
            except EOFError:
                # stdin closed (e.g. non-interactive run): leave the chat loop
                # instead of reporting it as a pipeline error.
                break

            if question.lower() == 'quit':
                break
            elif not question:
                continue

            result = chatbot.generate_answer(question)

            print(f"\n💡 Answer (from medical knowledge base):")
            print(f"   {result['answer']}")
            if result.get('score'):
                print(f"   Confidence: {result['score']:.3f}")

        print("\n✅ Health Chatbot ready for use!")

    except Exception as e:
        # Top-level boundary: keep the traceback in the log, show a short message
        logger.exception(f"Error in main execution: {e}")
        print(f"❌ Error: {e}")

# Entry point: run the full pipeline (main) when this file is executed directly
if __name__ == "__main__":
    main()

# =============================================================================
# QUICK RETRIEVAL-ONLY CHATBOT (RECOMMENDED - FAST & RELIABLE)
# =============================================================================

def quick_retrieval_chatbot():
    """Interactive retrieval-only chat: build the index from the CSV, then answer.

    No model training is involved.  At the prompt, 'quit' leaves the loop,
    'demo' runs a fixed set of sample questions, and blank input is ignored.
    Any error raised while loading data or chatting is printed, not raised.
    """
    print("⚡ Quick Retrieval-Only Health Chatbot")
    print("=" * 50)

    bot = HealthChatbot()

    # Canned questions shown by the 'demo' command
    sample_queries = (
        "What are symptoms of flu?",
        "How to lower blood pressure?",
        "Benefits of exercise?",
        "What is healthy diet?",
    )

    try:
        # Only the retrieval side is prepared here
        bot.build_retrieval_system(
            bot.load_and_preprocess_data("../dataset/merged_health_dataset.csv")
        )

        print("✅ Retrieval system ready!")
        print("You can now ask health questions...\n")

        while True:
            user_text = input("🤔 Your question (or 'quit'): ").strip()
            command = user_text.lower()

            if command == 'quit':
                break
            if not user_text:
                continue

            if command == 'demo':
                for sample in sample_queries:
                    reply = bot.generate_answer(sample)
                    print(f"\nQ: {sample}")
                    print(f"A: {reply['answer'][:100]}...")
                    print(f"   (Score: {reply.get('score', 0):.3f})")
                continue

            reply = bot.generate_answer(user_text)
            print(f"\n💡 {reply['answer']}")
            # A zero/absent score (fallback answer) is not displayed
            if reply.get('score'):
                print(f"   (Confidence: {reply['score']:.3f})")

    except Exception as e:
        print(f"❌ Error: {e}")

# Uncomment for quick retrieval-only chatbot (RECOMMENDED)
# quick_retrieval_chatbot()

# =============================================================================
# LOAD AND USE PRE-BUILT MODEL
# =============================================================================

def load_and_chat(model_dir: str = "health_chatbot_models") -> None:
    """Load previously saved chatbot models and start an interactive chat.

    Creates a ``HealthChatbot``, calls ``load_models(model_dir)`` on it and
    then answers questions read from standard input until the user types
    ``quit`` (case-insensitive). Blank input is ignored and the prompt is
    shown again. Closing stdin or interrupting the prompt also ends the
    session quietly.

    Args:
        model_dir: Directory passed to ``load_models``. Defaults to the
            directory name this function previously hard-coded.

    Returns:
        None. Everything is printed to stdout. A failure while loading is
        reported together with a hint on how to recover; a failure while
        answering is reported as a plain error line. Neither is re-raised.
    """
    print("🔮 Loading pre-trained chatbot...")

    chatbot = HealthChatbot()

    # Keep loading in its own try block so that the "run the training cell"
    # hint is only shown when loading is what actually failed.
    try:
        chatbot.load_models(model_dir)
    except Exception as e:
        print(f"❌ Error loading chatbot: {e}")
        print("Please run the training cell first or use quick_retrieval_chatbot()")
        return

    print("✓ Chatbot loaded successfully!")
    print("Type 'quit' to exit\n")

    try:
        while True:
            try:
                question = input("🤔 Your question: ").strip()
            except (EOFError, KeyboardInterrupt):
                # A closed stdin or an interrupted prompt means "stop
                # chatting", not a chatbot failure.
                print()
                break

            if question.lower() == 'quit':
                break
            if not question:
                # Do not send an empty query to the model/retriever.
                continue

            result = chatbot.generate_answer(question)
            print(f"\n💡 {result['answer']}")
            # A missing or zero score is not shown.
            if result.get('score'):
                print(f"   (Confidence: {result['score']:.3f})")
    except Exception as e:
        # Answering failed after a successful load; report it as such and
        # end the session.
        print(f"❌ Error: {e}")

# Uncomment to load and use pre-built model
# load_and_chat()

2025-10-16 18:19:43.598437: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-16 18:19:43.617114: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760631583.638500  274526 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760631583.645635  274526 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1760631583.662284  274526 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

GPU memory growth enabled
TensorFlow version: 2.19.0
Num GPUs Available: 1
🚀 Health Q&A Chatbot - TensorFlow Optimized Version


INFO:__main__:Loading data from ../dataset/merged_health_dataset.csv



📊 Loading and preprocessing data...


INFO:__main__:Preprocessing dataframe...
INFO:__main__:Dropped 0 rows with missing values


Dataset shape: (16371, 4)
Columns: ['Question', 'Answer', 'topic', 'split']

Topic distribution:
topic
growth_hormone_receptor          5430
Genetic_and_Rare_Diseases        5388
Diabetes_Digestive_Kidney        1157
Neurological_Disorders_Stroke    1088
Other                             981
SeniorHealth                      769
cancer                            729
Heart_Lung_Blood                  559
Disease_Control_Prevention        270
Name: count, dtype: int64


INFO:__main__:After length filtering: 16370 rows
INFO:__main__:After deduplication: 16357 rows
INFO:__main__:Training samples: 4000
INFO:__main__:Validation samples: 666
INFO:__main__:Test samples: 334
INFO:__main__:Building retrieval system...
INFO:__main__:Building FAISS index...



Sample processed data:
                                      Question_clean  \
0         What is (are) Non-Small Cell Lung Cancer ?   
1   Who is at risk for Non-Small Cell Lung Cancer? ?   
2  What are the symptoms of Non-Small Cell Lung C...   

                                        Answer_clean  
0  - Non-small cell lung cancer is a disease in w...  
1  Smoking is the major risk factor for non-small...  
2  Signs of non-small cell lung cancer include a ...  

📈 Preparing data splits...
Using subset of 5000 samples for faster training

🔍 Building retrieval system...
Building index with 15807 unique answers


Batches: 100%|██████████| 988/988 [00:20<00:00, 47.06it/s] 
INFO:__main__:FAISS index built with 15807 entries



🧪 Testing retrieval system...


Batches: 100%|██████████| 1/1 [00:00<00:00, 79.43it/s]
INFO:__main__:Training generative model...
INFO:__main__:Building tokenizers...


Query: 'What are symptoms of depression?'
Retrieved 5 results:
  1. Score: 0.698
     Answer: Common Symptoms There are many symptoms associated with depression, and some will vary depending on ...
  2. Score: 0.697
     Answer: Symptoms of depression often vary depending upon the person. Common symptoms include - feeling nervo...

🧠 Training generative model...


INFO:__main__:Question vocab size: 2940
INFO:__main__:Answer vocab size: 17317
INFO:__main__:Starting model training...
INFO:__main__:Building model architecture...


Training data shape: (4000, 128), (4000, 128)
Validation data shape: (666, 128), (666, 128)


I0000 00:00:1760631621.577880  274526 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5909 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9
INFO:__main__:Model built successfully



Model Architecture:


Epoch 1/20


ERROR:__main__:Error in main execution: Exception encountered when calling Attention.call().

[1mDimensions must be equal, but are 128 and 8 for '{{node health_qa_model_1/attention_layer_1/sub}} = Sub[T=DT_FLOAT](health_qa_model_1/attention_layer_1/MatMul, health_qa_model_1/attention_layer_1/mul)' with input shapes: [8,128,128], [8,128].[0m

Arguments received by Attention.call():
  • inputs=['tf.Tensor(shape=(8, 128, 256), dtype=float32)', 'tf.Tensor(shape=(8, 128, 256), dtype=float32)']
  • mask=['tf.Tensor(shape=(8, 128), dtype=bool)', 'tf.Tensor(shape=(8, 128), dtype=bool)']
  • training=True
  • return_attention_scores=False
  • use_causal_mask=False


❌ Error: Exception encountered when calling Attention.call().

[1mDimensions must be equal, but are 128 and 8 for '{{node health_qa_model_1/attention_layer_1/sub}} = Sub[T=DT_FLOAT](health_qa_model_1/attention_layer_1/MatMul, health_qa_model_1/attention_layer_1/mul)' with input shapes: [8,128,128], [8,128].[0m

Arguments received by Attention.call():
  • inputs=['tf.Tensor(shape=(8, 128, 256), dtype=float32)', 'tf.Tensor(shape=(8, 128, 256), dtype=float32)']
  • mask=['tf.Tensor(shape=(8, 128), dtype=bool)', 'tf.Tensor(shape=(8, 128), dtype=bool)']
  • training=True
  • return_attention_scores=False
  • use_causal_mask=False
Trying to continue with retrieval-only mode...

🔄 Falling back to retrieval-only mode...


Batches: 100%|██████████| 1/1 [00:00<00:00, 231.76it/s]

Q: What are common health tips?
A: Treatment and prevention for P.A.D. often includes making long-lasting lifestyle changes, such as - quitting smoking - lowering blood pressure - lowering high blood cholesterol levels - lowering high blood glucose levels if you have diabetes - getting regular physical activity - following a healthy eating plan that's low in total fat, saturated fat, trans fat, cholesterol, and sodium (salt). quitting smoking lowering blood pressure lowering high blood cholesterol levels lowering high blood glucose levels if you have diabetes getting regular physical activity following a healthy eating plan that's low in total fat, saturated fat, trans fat, cholesterol, and sodium (salt). Two examples of healthy eating plans are Therapeutic Lifestyle Changes (TLC) and Dietary Approaches to Stop Hypertension (DASH).





: 