In [111]:
import json
import re
import uuid
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Optional

from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END

# Chat model shared by every LLM-backed node of the grading workflow.
OPENAI_MODEL_NAME = "gpt-3.5-turbo"
llm = ChatOpenAI(model=OPENAI_MODEL_NAME)

# Define the state object to store data during the workflow
class ExtractionState:
    def __init__(self, student_file_content: str, model_file_content: str, rubric_content: str, question_content: str = None):
        self.student_file_content = student_file_content
        self.model_file_content = model_file_content
        self.rubric_content = rubric_content
        self.question_content = question_content  # This is now optional
        self.student_classes = {}
        self.model_classes = {}
        self.class_rubric_mapping = {}
        self.class_evaluations = {}  # Ensure this is initialized as an empty dictionary
        self.final_class_evaluations = {}
        self.total_marks = None

# Class Extraction Module
def _parse_json_object(raw: str, what: str) -> dict:
    """Parse an LLM reply that is expected to be a single JSON object.

    Chat models often wrap JSON in a Markdown code fence (```json ... ```),
    which ``json.loads`` rejects outright, so a surrounding fence is stripped
    first. On failure a message naming *what* was being parsed is printed and
    an empty dict is returned, so the workflow can continue.
    """
    text = raw.strip()
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        print(f"Failed to parse {what}: {e}")
        return {}
    if not isinstance(parsed, dict):
        # e.g. a JSON list: downstream nodes call .items() / .get() on this value.
        print(f"Failed to parse {what}: expected a JSON object, got {type(parsed).__name__}")
        return {}
    return parsed


async def extract_student_classes(state: ExtractionState) -> ExtractionState:
    """Split the student's Java source into ``{class name: full class code}``.

    The parsed mapping is stored in ``state.student_classes`` (an empty dict
    when the reply cannot be parsed as a JSON object).
    """
    # Ask for strict JSON: a "dictionary" request invites Python-dict syntax
    # (single quotes) that json.loads cannot read.
    prompt = (
        "Extract the entire Java classes from the student's solution. "
        "Return only a JSON object (no Markdown, no commentary) where the key is the class name, "
        f"and the value is the entire class code:\n{state.student_file_content}"
    )
    response = await llm.ainvoke([HumanMessage(content=prompt)])
    state.student_classes = _parse_json_object(response.content, "student classes")
    return state


async def extract_model_classes(state: ExtractionState) -> ExtractionState:
    """Split the model solution into ``{class name: full class code}``.

    The parsed mapping is stored in ``state.model_classes`` (an empty dict
    when the reply cannot be parsed as a JSON object).
    """
    prompt = (
        "Extract the entire Java classes from the model solution. "
        "Return only a JSON object (no Markdown, no commentary) where the key is the class name, "
        f"and the value is the entire class code:\n{state.model_file_content}"
    )
    response = await llm.ainvoke([HumanMessage(content=prompt)])
    state.model_classes = _parse_json_object(response.content, "model classes")
    return state


# Rubric Extraction Module
async def extract_rubric_for_classes(state: ExtractionState) -> ExtractionState:
    """Map each extracted class name to its rubric criteria.

    The prompt lists the class names found in the student and model solutions
    so that the returned keys can match ``state.student_classes`` (the
    evaluation step looks criteria up by exact class name). The result is
    stored in ``state.class_rubric_mapping`` (an empty dict on parse failure).
    """
    class_names = sorted(set(state.student_classes) | set(state.model_classes))
    prompt = (
        "Extract the relevant rubric for evaluating the following Java classes based on the provided rubric. "
        "Return only a JSON object (no Markdown, no commentary) where each key is exactly one of the "
        "class names below and the value is the corresponding rubric criteria.\n"
        f"Class names: {', '.join(class_names)}\n\n"
        f"Rubric:\n{state.rubric_content}"
    )
    response = await llm.ainvoke([HumanMessage(content=prompt)])
    state.class_rubric_mapping = _parse_json_object(response.content, "class rubric mapping")
    return state

# Initial Evaluation Module
def generate_evaluation_prompt(class_name: str, student_class_code: str, model_class_code: str, rubric_criteria: list) -> str:
    """Build the prompt for the first-pass evaluation of one student class.

    The prompt opens with a one-line instruction naming the class, followed by
    the student code, the model code, the rubric criteria (rendered with
    ``str``) and a closing request. Sections are separated by blank lines.
    """
    intro = f"Evaluate the student's Java class '{class_name}' against the model class and rubric. "
    sections = [
        f"**Student Class Code:**\n{student_class_code}",
        f"**Model Class Code:**\n{model_class_code}",
        f"**Rubric Criteria:**\n{rubric_criteria}",
        "Provide a detailed score for each criterion and specific feedback.",
    ]
    return intro + "\n\n".join(sections)

async def evaluate_classes(state: ExtractionState) -> ExtractionState:
    """Ask the LLM for an initial rubric-based evaluation of every student class.

    For each class in ``state.student_classes``, the prompt includes the
    matching model class (an empty string when the model solution has no class
    of that name) and rubric criteria (an empty list when none were mapped).
    The raw reply text is stored under the class name in
    ``state.class_evaluations``.
    """
    # Collect into a new dict and assign once: this node neither depends on
    # `class_evaluations` already existing on `state` nor mutates a dict that
    # may be shared with earlier state.
    evaluations = {}
    for class_name, student_class_code in state.student_classes.items():
        prompt = generate_evaluation_prompt(
            class_name,
            student_class_code,
            state.model_classes.get(class_name, ""),
            state.class_rubric_mapping.get(class_name, []),
        )
        response = await llm.ainvoke([HumanMessage(content=prompt)])
        evaluations[class_name] = response.content
    state.class_evaluations = evaluations
    return state

# Review Evaluation Module
def generate_review_prompt(class_name: str, student_class_code: str, model_class_code: str, rubric_criteria: list, initial_evaluation: str) -> str:
    """Build the prompt asking the LLM to review and finalise an initial evaluation.

    Args:
        class_name: Name of the class under review.
        student_class_code: The student's source for the class.
        model_class_code: The model solution's source for the class ("" if absent).
        rubric_criteria: Criteria for this class, rendered with ``str``.
        initial_evaluation: Reply text from the first-pass evaluation.

    Returns:
        The prompt text. The reply is requested as a JSON object whose
        ``criterion_evaluations`` list carries an integer ``score`` per
        criterion, which is the shape the marks-extraction step reads.
    """
    return (
        f"Review the initial evaluation of the student's class '{class_name}' and the provided feedback. "
        f"**Student Class Code:**\n{student_class_code}\n\n**Model Class Code:**\n{model_class_code}\n\n"
        f"**Rubric Criteria:**\n{rubric_criteria}\n\n"
        f"**Initial Evaluation:**\n{initial_evaluation}\n\n"
        "Make necessary corrections, and provide the final assessment with feedback.\n"
        "Respond with a single JSON object only (no Markdown fences, no other text) in this shape:\n"
        '{"criterion_evaluations": [{"criterion": "<criterion>", "score": <integer>, '
        '"feedback": "<feedback>"}], "overall_feedback": "<summary>"}'
    )

async def review_evaluations(state: ExtractionState) -> ExtractionState:
    """Have the LLM review each initial evaluation and produce a final one.

    For every entry in ``state.class_evaluations``, the class's student code,
    model code and rubric criteria are looked up by name, with "" or [] as the
    defaults. The review reply text is stored under the class name in
    ``state.final_class_evaluations``.
    """
    # Build a fresh dict instead of writing into `final_class_evaluations`
    # in place, so the node does not rely on that attribute already existing.
    reviews = {}
    for class_name, initial_evaluation in state.class_evaluations.items():
        prompt = generate_review_prompt(
            class_name,
            state.student_classes.get(class_name, ""),
            state.model_classes.get(class_name, ""),
            state.class_rubric_mapping.get(class_name, []),
            initial_evaluation,
        )
        response = await llm.ainvoke([HumanMessage(content=prompt)])
        reviews[class_name] = response.content
    state.final_class_evaluations = reviews
    return state

# Marks Extraction Module
def _evaluation_as_dict(evaluation: object) -> dict:
    """Return a final evaluation as a new dict.

    A dict is shallow-copied. Anything else is converted to text and parsed as
    JSON, after stripping an optional surrounding Markdown code fence. If that
    does not yield a JSON object, the original text is kept under the
    ``"evaluation"`` key.
    """
    if isinstance(evaluation, dict):
        return dict(evaluation)
    text = str(evaluation).strip()
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    return parsed if isinstance(parsed, dict) else {"evaluation": str(evaluation)}


def extract_marks(state: ExtractionState) -> ExtractionState:
    """Attach a comma-separated ``"marks"`` string to every final evaluation.

    Entries of ``state.final_class_evaluations`` may be raw model text
    (``response.content``), which has no ``.get``. Each entry is therefore first
    normalised to a dict by ``_evaluation_as_dict``. Scores are then read from
    its ``"criterion_evaluations"`` list; a missing or ``None`` score counts as
    0. An evaluation without such a list gets an empty ``"marks"`` string.
    """
    normalised = {}
    for class_name, evaluation in state.final_class_evaluations.items():
        record = _evaluation_as_dict(evaluation)
        criteria = record.get("criterion_evaluations")
        if not isinstance(criteria, list):
            criteria = []
        marks_list = []
        for crit in criteria:
            if isinstance(crit, dict):
                score = crit.get("score")
                marks_list.append(str(0 if score is None else score))
        record["marks"] = ", ".join(marks_list)
        normalised[class_name] = record
    state.final_class_evaluations = normalised
    return state

# Total Marks Calculation Module
@tool
def sum_marks(marks_list: str) -> int:
    """
    Takes a comma-separated list of marks and returns their sum.
    Each mark is expected to be a valid integer; blank entries are ignored.
    """
    # The docstring above doubles as the tool description shown to the model.
    # Blank entries (an empty list, or a stray or trailing comma) would otherwise
    # reach int("") and raise ValueError, so they are skipped and "" sums to 0.
    # A non-integer entry such as "2.5" still raises ValueError.
    stripped = (part.strip() for part in marks_list.split(","))
    return sum(int(mark) for mark in stripped if mark)

async def calculate_total_marks(state: ExtractionState) -> ExtractionState:
    """Sum the per-class marks into ``state.total_marks``.

    Reads the comma-separated ``"marks"`` string of every final evaluation
    that is a dict, skips blank ones, and adds them with the ``sum_marks``
    tool. ``total_marks`` is set to ``None``, and a message is printed, when
    there are no marks to add or a mark is not an integer.

    The tool is called directly. ``llm.bind_tools`` returns a new runnable
    rather than modifying ``llm``, and a bound tool is only requested by the
    model, never executed automatically. Model-performed arithmetic is also
    not reliable enough for grading totals.
    """
    marks_strings = []
    for evaluation in state.final_class_evaluations.values():
        marks = evaluation.get("marks", "") if isinstance(evaluation, dict) else ""
        if str(marks).strip():
            marks_strings.append(str(marks))
    combined_marks_list = ", ".join(marks_strings)
    if not combined_marks_list:
        print("No marks found to total.")
        state.total_marks = None
        return state
    try:
        state.total_marks = sum_marks.invoke({"marks_list": combined_marks_list})
    except ValueError as e:
        # e.g. a fractional or textual score such as "2.5" or "3/5".
        print(f"Failed to sum marks '{combined_marks_list}': {e}")
        state.total_marks = None
    return state

# Save Final Evaluations to a File
def save_final_evaluations(state: ExtractionState, filename="final_evaluations.txt"):
    """Write each class's final evaluation, then the total marks, to *filename*.

    The file is overwritten. Plain-text evaluations are written verbatim;
    other values (e.g. parsed dicts) are pretty-printed as JSON. UTF-8 is used
    explicitly so that model output containing non-ASCII characters can be
    written regardless of the platform's default encoding.
    """
    with open(filename, "w", encoding="utf-8") as file:
        for class_name, evaluation in state.final_class_evaluations.items():
            file.write(f"Class: {class_name}\n")
            if isinstance(evaluation, str):
                # json.dumps on a str would emit one quoted line with escaped "\n"s.
                file.write(evaluation)
            else:
                file.write(json.dumps(evaluation, indent=2, ensure_ascii=False))
            file.write("\n\n")
        file.write(f"Total Marks: {state.total_marks}\n")
    print(f"Final evaluations and total marks saved to {filename}")

# LangGraph Workflow Construction
# The workflow is a single linear chain: the order of this list is the
# execution order between START and END.
_pipeline_steps = [
    ("extract_student_classes", extract_student_classes),
    ("extract_model_classes", extract_model_classes),
    ("extract_rubric", extract_rubric_for_classes),
    ("evaluate_classes", evaluate_classes),
    ("review_evaluations", review_evaluations),
    ("extract_marks", extract_marks),
    ("calculate_total_marks", calculate_total_marks),
]

graph = StateGraph(ExtractionState)
for _step_name, _step_fn in _pipeline_steps:
    graph.add_node(_step_name, _step_fn)

# Connect each step to the next: START -> first step -> ... -> last step -> END.
_chain = [START, *(name for name, _ in _pipeline_steps), END]
for _upstream, _downstream in zip(_chain, _chain[1:]):
    graph.add_edge(_upstream, _downstream)

# In-memory checkpointer; LangGraph keys checkpoints by the `thread_id` in the invoke config.
memory = MemorySaver()
class_extraction_graph = graph.compile(checkpointer=memory)

# Execute the graph and save results
await class_extraction_graph.ainvoke(state, config={"configurable": {"thread_id": str(uuid.uuid4())}})
save_final_evaluations(state)


AttributeError: 'ExtractionState' object has no attribute 'class_evaluations'