# BERT-based Question Answering Demo

This project demonstrates how to use a fine-tuned BERT model for the **Question Answering (QA)** task. It uses the Hugging Face `transformers` library with TensorFlow to load the fine-tuned model and its tokenizer, then answers questions by extracting the relevant span of text from a given context.

## Project Structure
- **`trained_model/`**: Directory containing the fine-tuned BERT model for question answering.
- **`data/`**: Directory containing the `dataset.json` file with context data for answering questions.
- **`script.py`**: The main Python script that performs question answering.

## Requirements
- TensorFlow
- Hugging Face `transformers` library
- A JSON dataset (`data/dataset.json`) containing the contexts to answer questions from
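
The script reads `dataset[0]["context"]`, so `dataset.json` must be a JSON array whose entries are objects with a `"context"` string; any other keys are ignored. The sketch below writes a minimal file of that shape and reads it back the same way the script does. The placeholder passage and the helper names `write_sample_dataset` and `load_first_context` are hypothetical, not part of this project.

```python
import json
import os

# Hypothetical placeholder content: one entry with the "context" key the script reads
SAMPLE_DATASET = [
    {"context": "Replace this placeholder passage with the text you want to ask questions about."}
]


def write_sample_dataset(path):
    """Write SAMPLE_DATASET to `path` as JSON, creating parent directories if needed."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(SAMPLE_DATASET, f, indent=2)


def load_first_context(path):
    """Mirror the script's lookup: return the "context" field of the first entry."""
    with open(path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    return dataset[0]["context"]
```

For example, `write_sample_dataset(os.path.join("data", "dataset.json"))` creates the file the script expects. It overwrites any existing file at that path, so point it elsewhere if you already have a dataset.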

## How It Works

1. **Model and Tokenizer Initialization**:
   - The model is loaded from the specified `trained_model` directory using the `TFBertForQuestionAnswering` class.
   - The tokenizer is loaded from the same directory using `BertTokenizerFast`.

2. **Answering a Question**:
   - The `answer_question` method of the `QAInference` class takes a `context` (a passage of text) and a `question`, and uses the model to predict the answer.
   - The input question and context are tokenized using the tokenizer.
   - The model predicts the start and end positions of the answer in the context.
   - The predicted span of text is decoded back into a readable string and returned as the answer.

3. **User Interaction**:
   - When the script runs, users are prompted to enter a question.
   - The context for the question is retrieved from the first entry in the `dataset.json` file (which you can modify as needed).
   - The script outputs the question along with the model's predicted answer.
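
Choosing the start and end positions with two independent `argmax` calls can produce an end index before the start (which decodes to an empty answer) or an implausibly long span. A common refinement, sketched below in plain Python, scores every candidate pair with `start <= end` by summing its start and end logits and keeps the best one, subject to a length cap. The `best_span` name, the default `max_answer_len=30`, and the logit values in the example are illustrative assumptions, not values from this project's model.

```python
from typing import List, Tuple


def best_span(
    start_logits: List[float],
    end_logits: List[float],
    max_answer_len: int = 30,  # assumed cap on answer length, in tokens
) -> Tuple[int, int]:
    """Return the (start, end) token indices that maximize
    start_logits[start] + end_logits[end], subject to
    start <= end < start + max_answer_len."""
    best_score = float("-inf")
    best = (0, 0)
    for s, s_score in enumerate(start_logits):
        # Only consider end positions at or after the start, within the length cap
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


# Independent argmax would pick start=1 and end=0, an empty span
start_logits = [0.1, 2.0, 0.5, -1.0]
end_logits = [3.0, 0.2, 1.5, 0.0]
print(best_span(start_logits, end_logits))  # (1, 2)
```

With real model output you would also restrict candidates to context tokens (for example with `sequence_ids()` on the encoding returned by `BertTokenizerFast`), so the span cannot start on the question or on a special token such as `[CLS]`. The double loop costs O(n · max_answer_len), which is negligible at BERT's 512-token input limit.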


```python
import json
import os

import tensorflow as tf
from transformers import TFBertForQuestionAnswering, BertTokenizerFast

# Resolve paths relative to the current working directory
workspace_dir = os.getcwd()
model_path = os.path.join(workspace_dir, 'trained_model')
dataset_path = os.path.join(workspace_dir, 'data', 'dataset.json')

class QAInference:
    def __init__(self, model_path):
        """
        Initializes the model and tokenizer for inference.
        """
        # Load the fine-tuned model and its tokenizer from the same directory
        self.model = TFBertForQuestionAnswering.from_pretrained(model_path)
        self.tokenizer = BertTokenizerFast.from_pretrained(model_path)

    def answer_question(self, context, question):
        """
        Given a context and a question, return the answer predicted by the model.
        """
        # Tokenize the question/context pair; if the pair is too long,
        # truncate only the context so the question is always kept intact
        inputs = self.tokenizer(question, context, return_tensors="tf", truncation="only_second")

        # Run the model; it returns one start logit and one end logit per token
        outputs = self.model(**inputs)
        start_scores = outputs.start_logits[0]
        end_scores = outputs.end_logits[0]

        # Most likely start position, then the most likely end at or after it
        # (an independent argmax could place the end before the start)
        start_idx = int(tf.argmax(start_scores).numpy())
        end_idx = start_idx + int(tf.argmax(end_scores[start_idx:]).numpy())

        # Extract the answer tokens and decode them to a string
        answer_tokens = inputs["input_ids"][0, start_idx:end_idx + 1].numpy().tolist()
        answer = self.tokenizer.decode(answer_tokens, skip_special_tokens=True)

        return answer

def main():
    # Load the dataset from JSON (used only for its contexts)
    with open(dataset_path, 'r', encoding='utf-8') as file:
        dataset = json.load(file)

    # Initialize the QA model
    qa_model = QAInference(model_path)

    # Prompt the user for a question
    question = input("Enter your question: ")

    # Use the first context from the dataset (change the index to use another entry)
    context = dataset[0]["context"]

    # Get the predicted answer for the given question and context
    predicted_answer = qa_model.answer_question(context, question)

    print("\nQuestion:", question)
    print("Predicted Answer:", predicted_answer)

if __name__ == "__main__":
    main()
```

Example run:

```text
Question: What does natural language processing enable?
Predicted Answer: building systems that learn from data
```
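
`answer_question` builds the answer by decoding token IDs, so if the tokenizer uses an uncased vocabulary the answer comes back lowercased, and WordPiece decoding can alter spacing around punctuation. An alternative is to slice the original context with character offsets. `BertTokenizerFast` can report each token's `(start, end)` character span when called with `return_offsets_mapping=True`. The hypothetical helper `span_to_text` below assumes you already have those offsets and the per-token sequence IDs; the hand-written encoding in the demo is illustrative, not real tokenizer output.

```python
from typing import List, Optional, Tuple


def span_to_text(
    context: str,
    offsets: List[Tuple[int, int]],
    sequence_ids: List[Optional[int]],
    start_idx: int,
    end_idx: int,
) -> str:
    """Slice the original context using the character offsets of the predicted tokens.

    offsets[i] is the (char_start, char_end) of token i within its own sequence;
    sequence_ids[i] is 0 for question tokens, 1 for context tokens, and None for
    special tokens such as [CLS] and [SEP].
    """
    # Reject spans that are reversed or whose endpoints are not context tokens
    if end_idx < start_idx:
        return ""
    if sequence_ids[start_idx] != 1 or sequence_ids[end_idx] != 1:
        return ""
    char_start = offsets[start_idx][0]
    char_end = offsets[end_idx][1]
    return context[char_start:char_end]


# Illustrative encoding of ("What does NLP enable?", "NLP enables Building systems."):
# token 0 is [CLS], 1-5 are question tokens, 6 is [SEP], 7-11 are context tokens, 12 is [SEP]
context = "NLP enables Building systems."
offsets = [(0, 0), (0, 4), (5, 9), (10, 13), (14, 20), (20, 21), (0, 0),
           (0, 3), (4, 11), (12, 20), (21, 28), (28, 29), (0, 0)]
seq_ids = [None, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, None]

print(span_to_text(context, offsets, seq_ids, 9, 10))  # Building systems
```

With a real tokenizer, calling `self.tokenizer(question, context, return_offsets_mapping=True, truncation="only_second")` returns `offset_mapping` alongside the usual inputs, and `sequence_ids(0)` on the returned encoding gives the per-token sequence IDs. If you also pass `return_tensors="tf"`, pop `offset_mapping` from the encoding before calling the model, since it is not a model input.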
