In [None]:
from flask import Flask, request, jsonify
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import json

# Two separate Flask applications; each one is started on its own port
# by the run_* helpers further down in this file.
receive_app = Flask(__name__)
llama_app = Flask(__name__)

# Load the tokenizer and the half-precision causal LM from a local checkpoint
# directory, then switch the model to evaluation mode.
model_path = "./Test1"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
model = model.eval()

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Wrap in DataParallel and move to the GPU(s) when CUDA is present.
# NOTE(review): the generation helpers in this file call `model.module`, an
# attribute that only exists after this wrap.
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model).cuda()

# Single module-level record of the tutoring conversation: which step the
# dialogue is at, plus the latest question, answer and generated MCQ.
conversation_state = dict(
    current_state="waiting_for_question",
    current_question="",
    current_answer="",
    current_mcq=None,
    correct_answer=None,
)

def clean_topic(question):
    """Reduce a question to a bare, lower-cased topic phrase.

    The lead-in phrases "what is", "how does", "why does", "explain" and
    "describe" are removed, question marks are dropped and runs of
    whitespace are collapsed to single spaces.

    The lead-ins are matched as whole words only, so a word that merely
    contains one of them (for example "explained" or "describes") is left
    intact instead of being cut down to a fragment.

    Args:
        question: The question text as typed by the user.

    Returns:
        The cleaned topic string; may be empty if nothing else remains.
    """
    import re  # local import: the file's import block does not provide `re`

    topic = question.lower()
    topic = re.sub(r"\b(?:what is|how does|why does|explain|describe)\b", "", topic)
    topic = topic.replace("?", "")
    # split()/join collapses internal whitespace and trims both ends.
    return " ".join(topic.split())

def _extract_mcq_json(text):
    """Parse the span from the first '{' to the last '}' of *text* as an MCQ.

    Returns the parsed dict when it has "question", "options" (a dict with
    exactly four entries) and a "correct_answer" that names one of those
    options with a letter A-D; otherwise returns None. The stored
    "correct_answer" is normalised to an upper-case letter.
    """
    start_idx = text.find('{')
    end_idx = text.rfind('}')
    if start_idx == -1 or end_idx <= start_idx:
        return None

    try:
        mcq_data = json.loads(text[start_idx:end_idx + 1])
    except json.JSONDecodeError:
        return None

    if not isinstance(mcq_data, dict):
        return None
    if not all(key in mcq_data for key in ("question", "options", "correct_answer")):
        return None

    options = mcq_data["options"]
    if not isinstance(options, dict) or len(options) != 4:
        return None

    # The answer is later compared against an upper-cased A-D choice, so a
    # key that is not one of those letters could never be selected.
    correct = str(mcq_data["correct_answer"]).strip().upper()
    if correct not in ("A", "B", "C", "D") or correct not in options:
        return None

    mcq_data["correct_answer"] = correct
    return mcq_data


def generate_mcq(question, answer):
    """Build a four-option multiple choice question about an answered topic.

    The model is prompted to emit a JSON object describing the MCQ. If the
    newly generated text contains a valid object it is returned; otherwise a
    template-based fallback MCQ is built from the first sentence of *answer*.

    Args:
        question: The original user question; used to derive the topic.
        answer: The answer text the MCQ should be based on.

    Returns:
        A dict with "question", "options" (dict keyed by letter),
        "correct_answer" and, where available, "explanation".

    Raises:
        Exception: Any error from tokenisation or generation is logged with
            print() and re-raised unchanged.
    """
    try:
        # Clean the topic for better MCQ generation
        topic = clean_topic(question)

        mcq_prompt = f"""Create an educational multiple choice question based on this content:

Topic: {topic}
Detailed Content: {answer}

Generate 4 options where only one is correct. Format exactly like this:

{{
    "question": "What is the primary function of the {topic}?",
    "options": {{
        "A": "A complete, accurate statement about the {topic}",
        "B": "A plausible but incorrect statement about the {topic}",
        "C": "Another plausible but incorrect statement about the {topic}",
        "D": "Another plausible but incorrect statement about the {topic}"
    }},
    "correct_answer": "A",
    "explanation": "The correct answer is A because [explanation based on the content]"
}}

Rules:
1. All options must be complete, clear sentences
2. Options must be directly related to {topic}
3. Only one option should be correct
4. Don't use 'all of the above' or 'none of the above'
5. Make options similar in length
6. Do not use placeholder text in the options

Generate MCQ now:"""

        device = "cuda" if torch.cuda.is_available() else "cpu"
        mcq_inputs = tokenizer(mcq_prompt, return_tensors="pt", padding=True).to(device)

        # The model is only wrapped (and so only has `.module`) when CUDA is
        # available; fall back to the bare model otherwise.
        generator = getattr(model, "module", model)
        with torch.no_grad():  # inference only, no autograd bookkeeping needed
            mcq_ids = generator.generate(
                **mcq_inputs,
                max_new_tokens=500,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

        # Decode only the continuation. The returned sequence starts with the
        # prompt, and the prompt's own JSON template would otherwise supply the
        # first '{' and make the extracted span unparseable.
        prompt_len = mcq_inputs["input_ids"].shape[1]
        mcq_response = tokenizer.decode(mcq_ids[0][prompt_len:], skip_special_tokens=True)

        mcq_data = _extract_mcq_json(mcq_response)
        if mcq_data is not None:
            return mcq_data

        # Fallback MCQ generation: first non-empty sentence of the answer is
        # the correct option.
        sentences = [s.strip() for s in answer.split('.') if s.strip()]
        main_fact = sentences[0] if sentences else answer[:100]

        # NOTE(review): the distractors below are worded for human-biology
        # topics and may read oddly for other subjects.
        return {
            "question": f"Which statement about the {topic} is correct?",
            "options": {
                "A": main_fact,
                "B": f"The {topic} is responsible for breaking down fats in the blood",
                "C": f"The {topic} is only active during physical exercise",
                "D": f"The {topic} plays no role in the human body's functioning"
            },
            "correct_answer": "A",
            "explanation": f"The correct answer is A because {main_fact.lower()}"
        }

    except Exception as e:
        print(f"Error generating MCQ: {e}")
        raise

def generate_response(message, context=""):
    """Grade an MCQ choice or generate a model answer for a question.

    Args:
        message: Either the selected option letter (when grading) or the
            user's question (when answering).
        context: Pass "verify_mcq" to grade *message* against the MCQ stored
            in ``conversation_state``. Any other value generates an answer.
            Selecting the mode from this argument rather than from the global
            state keeps callers that only want a generated answer from being
            switched into grading while a quiz is pending.

    Returns:
        The feedback or answer text. On any failure the error is logged with
        print() and a fixed apology string is returned instead of raising.
    """
    try:
        if context == "verify_mcq":
            # Verify MCQ answer
            selected_option = message.strip().upper()
            correct_answer = conversation_state["correct_answer"]
            explanation = (conversation_state["current_mcq"] or {}).get("explanation", "")

            if selected_option == correct_answer:
                return f"Correct! Well done! {explanation}"
            return f"That's not quite right. {explanation}"

        # Normal question-answer
        answer_prompt = f"Question: {message}\nProvide a detailed educational answer:"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        answer_inputs = tokenizer(answer_prompt, return_tensors="pt", padding=True).to(device)

        # The model is only wrapped (and so only has `.module`) when CUDA is
        # available; fall back to the bare model otherwise.
        generator = getattr(model, "module", model)
        with torch.no_grad():  # inference only, no autograd bookkeeping needed
            answer_ids = generator.generate(
                **answer_inputs,
                max_new_tokens=200,
                num_return_sequences=1,
                temperature=0.7,
                # temperature only takes effect when sampling is enabled;
                # this matches the settings used for MCQ generation.
                do_sample=True
            )

        # Decode only the continuation so the prompt is not echoed back.
        prompt_len = answer_inputs["input_ids"].shape[1]
        answer = tokenizer.decode(answer_ids[0][prompt_len:], skip_special_tokens=True)
        # Drop a leading "Answer:" label if the model produced one.
        return answer.split("Answer:")[-1].strip()
    except Exception as e:
        print(f"Error during response generation: {e}")
        return "An error occurred while generating the answer."

@receive_app.route('/receive', methods=['POST'])
def receive_data():
    """Drive the question -> confirmation -> MCQ -> result dialogue.

    Expects a JSON object with a string "message". The handling depends on
    ``conversation_state["current_state"]``:

    * "waiting_for_question": answer the question and ask for confirmation.
    * "waiting_for_understanding": on "i got it", generate and return an MCQ.
    * "waiting_for_mcq_answer": grade the chosen option (A-D) and reset to
      "waiting_for_question".

    Malformed requests get a 400 response; unexpected failures are logged
    with print() and reported as a generic 500 without exception details.

    NOTE(review): ``conversation_state`` is one module-level dict, so every
    client of this endpoint shares the same conversation.
    """
    try:
        if not request.is_json:
            return jsonify({"error": "Request must be in JSON format"}), 400

        # silent=True yields None for an unparseable body instead of raising,
        # so a bad payload is reported as 400 rather than falling through to
        # the generic 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        message = data.get("message")
        if not isinstance(message, str) or not message.strip():
            return jsonify({"error": "No 'message' provided"}), 400
        message = message.strip()

        state = conversation_state["current_state"]

        if state == "waiting_for_question":
            conversation_state["current_question"] = message
            answer = generate_response(message)
            conversation_state["current_answer"] = answer
            conversation_state["current_state"] = "waiting_for_understanding"
            return jsonify({
                "response_type": "answer",
                "answer": answer,
                "instruction": "Please reply with 'i got it' when you understand the answer."
            })

        if state == "waiting_for_understanding":
            if "i got it" not in message.lower():
                return jsonify({
                    "error": "Please confirm your understanding by saying 'I got it'"
                }), 400

            try:
                mcq_data = generate_mcq(
                    conversation_state["current_question"],
                    conversation_state["current_answer"]
                )
            except Exception as e:
                print(f"MCQ generation error: {e}")
                return jsonify({"error": "Failed to generate MCQ"}), 500

            conversation_state["current_mcq"] = mcq_data
            conversation_state["correct_answer"] = mcq_data["correct_answer"]
            conversation_state["current_state"] = "waiting_for_mcq_answer"

            return jsonify({
                "response_type": "mcq",
                "question": mcq_data["question"],
                "options": mcq_data["options"],
                "instruction": "Please select one option (A, B, C, or D)"
            })

        if state == "waiting_for_mcq_answer":
            choice = message.upper()
            if choice not in ("A", "B", "C", "D"):
                return jsonify({
                    "error": "Please select a valid option (A, B, C, or D)"
                }), 400

            result = generate_response(choice, "verify_mcq")
            conversation_state["current_state"] = "waiting_for_question"  # Reset state
            return jsonify({
                "response_type": "mcq_result",
                "result": result,
                "correct_answer": conversation_state["correct_answer"],
                "explanation": conversation_state["current_mcq"].get("explanation", ""),
                "message": "You can ask another question now!"
            })

        return jsonify({"error": "Invalid state"}), 400

    except Exception as e:
        print(f"Error in /receive endpoint: {e}")
        return jsonify({"error": "An internal error occurred"}), 500

@llama_app.route('/generate_response', methods=['POST'])
def get_response():
    """Return a generated answer for the "message" field of a JSON request.

    Responds with ``{"answer": ...}`` on success, a 400 error for a request
    that is not a JSON object or carries no message, and a generic 500 error
    (details are only printed server-side) for unexpected failures.
    """
    try:
        if not request.is_json:
            return jsonify({"error": "Request must be in JSON format"}), 400

        # silent=True yields None for an unparseable body instead of raising,
        # so a bad payload is reported as 400 rather than as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        message = data.get("message", "")

        if not message:
            return jsonify({"error": "No 'message' provided"}), 400

        answer = generate_response(message)
        return jsonify({"answer": answer})

    except Exception as e:
        print(f"Error in /generate_response endpoint: {e}")
        return jsonify({"error": "An internal error occurred"}), 500

def run_receive_app():
    """Start ``receive_app`` listening on all interfaces at port 2025."""
    receive_app.run(
        port=2025,
        host='0.0.0.0',
    )

def run_llama_app():
    """Start ``llama_app`` listening on all interfaces at port 2026."""
    llama_app.run(
        port=2026,
        host='0.0.0.0',
    )

if __name__ == '__main__':
    # Give each Flask app its own thread, start both, then wait on both.
    server_threads = [
        threading.Thread(target=run_receive_app),
        threading.Thread(target=run_llama_app),
    ]

    for worker in server_threads:
        worker.start()

    for worker in server_threads:
        worker.join()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

 * Serving Flask app '__main__'
 * Serving Flask app '__main__'
 * Debug mode: off
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:2026
 * Running on http://192.168.0.13:2026
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:2025
 * Running on http://192.168.0.13:2025
Press CTRL+C to quit
Press CTRL+C to quit
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
66.207.195.70 - - [24/Nov/2024 16:37:45] "POST /receive HTTP/1.1" 200 -
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
66.207.195.70 - - [24/Nov/2024 16:38:02] "POST /receive HTTP/1.1" 200 -
66.207.195.70 - - [24/Nov/2024 16:38:16] "POST /receive HTTP/1.1" 400 -
66.207.195.70 - - [24/Nov/2024 16:38:29] "POST /receive HTTP/1.1" 400 -
66.207.195.70 - - [24/Nov/2024 16:38:36] "POST /receive HTTP/1.1" 200 -
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
66.207.195.70 - - [24/Nov/2024 16:39:13] "POST /receive HTTP/1.1" 200 -
Setting `pad_token_id` to `eos_token_