In [None]:
# change data path as you need
import os
import pandas as pd

# Machine-specific absolute Windows directories; edit both before running on another machine.
DATA_PATH = "C:\\Users\\notin\\Bahar\\llm\\AIED\\data"
RESULTS_PATH = "C:\\Users\\notin\\Bahar\\llm\\AIED\\results"  # not read in this cell

# Name of the CSV file expected inside DATA_PATH.
data_file_name = "trans_df.csv"
data_path = os.path.join(DATA_PATH, data_file_name)
print(f"Loading data from: {data_path}")
# Full data set; the next cell filters it on its 'year' column.
df = pd.read_csv(data_path)

Loading data from: C:\Users\notin\Bahar\llm\AIED\data\trans_df.csv


In [None]:

# Build the 2022 subset as its own frame in a single expression. The previous
# version called drop(..., inplace=True) on a boolean-mask slice, which raises
# SettingWithCopyWarning and makes the cell non-idempotent.
df_2022 = (
    df.loc[df['year'] == 2022]
    .drop(columns=['year', "audio_file_path", "session", "audio_file_name"])
)
print("Number of records in 2022:", len(df_2022))

# Preview of the unfiltered frame (df itself is left untouched).
df.head()

## To do for later runs
* [ ] fix logic: if all agents agree do not go to the next round
* [ ] fix prompt:
    - CODEBOOK needs more context; agents seem to be confused about what the "Other" category means
    - Reduce emphasis on short sentences
    - Emphasize returning None only when needed
* [ ] Currently, if the LLM response is not formatted correctly after <max_tries>, the code returns the heuristic label instead of the LLM response. The problem is that this happens a lot. A possible fix is to return the raw response as well, together with the heuristic label.

## Configs

In [None]:
# Default for the `max_retries` argument of SingleAgentCoding.assign_code.
MAX_RETRIES = 2
# NOTE(review): not referenced by the code visible in this part of the notebook;
# presumably a batch size/count used by a later cell -- confirm.
BATCH_NUM = 10
# Passed as `max_new_tokens` by ModelManager.generate; assigned again (same value)
# in the model-manager cell.
MAX_NEW_TOKENS = 512

# Run this

## setup

In [None]:
import os, re, json, time, ast
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field


import math
import numpy as np
import pandas as pd



import torch
# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# get_device_name() raises when no CUDA device exists, so only query it on GPU runs.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))


from transformers import AutoTokenizer, AutoModelForCausalLM
import accelerate

In [None]:
import torch, gc
# del model   # if available
# Run a garbage-collection pass, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())

In [None]:
# Codebook: CAD code -> description. PromptBuilder.build_context_prompt renders a
# codebook dict as "- <code>: <description>" lines inside the prompt, so the text
# must be clean: the apostrophes below were mis-encoded (UTF-8 read as Mac Roman)
# and are restored to the intended right single quotation mark.
CAD_CODEBOOK_DICT = {
    "WCT": "The teacher is addressing the whole class.",
    "GT":  "The teacher is addressing a group or a student in a group. It also includes any talk: student level",
    "Other": "The teacher isn’t talking to the whole class or any groups or students. Either she’s silent or talking to herself or a visitor in a non-distracting way: "
}
# NOTE(review): this example uses "code"/"reasoning" keys, whereas OutputValidator
# requires "CAD-code"/"rationale" -- confirm whether it is still used anywhere.
output_example = { "code": "WCT", "reasoning": "it is more likly that the teacher is addressing the whole class" }

In [None]:
import logging
import sys

# Configure the root logger so that records at INFO and above are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Detach whatever handlers are already attached (from an earlier basicConfig call
# or a previous run of this cell) so messages are not printed more than once.
# list(...) snapshots the handlers because removeHandler mutates the live list.
for handler in list(logger.handlers):
    logger.removeHandler(handler)

# Timestamp - level - message layout for every record.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

# Single handler that writes to standard output.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.info("This is an info message from the explicitly configured logger.")

In [None]:
@dataclass
class AnnotationResult:
    """Result from annotating a single text.

    The class is a dataclass: without the decorator the annotated fields below
    are plain class attributes and ``AnnotationResult(success=True, ...)``
    raises ``TypeError`` because no ``__init__`` accepting them exists.

    Attributes:
        success: Selects which form ``__str__`` renders (result vs. error).
        code: Assigned code; ``parsed`` reports "NONE" when it is missing.
        rationale: Explanation for the code; ``parsed`` reports "" when missing.
        raw_output: Not read by this class (presumably the unparsed model text).
        error: Message shown by ``__str__`` when ``success`` is false.
        text: Not read by this class (presumably the annotated input text).
    """
    success: bool
    code: Optional[str] = None
    rationale: Optional[str] = None
    raw_output: Optional[str] = None
    error: Optional[str] = None
    text: Optional[str] = None

    def __str__(self) -> str:
        """Render a one-line summary prefixed with a check mark or a cross."""
        # The glyphs were mis-encoded in the original source; restored to U+2713 / U+2717.
        if self.success:
            return f"✓ {self.code}: {self.rationale}"
        return f"✗ Error: {self.error}"

    @property
    def parsed(self) -> Dict[str, str]:
        """Get parsed result in legacy format ("CAD-code" / "rationale" keys)."""
        return {
            "CAD-code": self.code or "NONE",
            "rationale": self.rationale or ""
        }

## Model/gpu manager

In [None]:
# Upper bound on generated tokens; ModelManager.generate passes it as `max_new_tokens`.
# Same value as the assignment in the Configs cell.
MAX_NEW_TOKENS = 512
# NOTE(review): TEMPERATURE and TOP_K are not referenced by the code visible here;
# ModelManager and BaseCodingAgent carry their own defaults -- confirm they are still needed.
TEMPERATURE = 0.4
TOP_K = 40
# Hugging Face model ids: the 1.5B distilled checkpoint and the larger 7B one.
CPU_MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
GPU_MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

# ============================================================================
# MEMORY UTILITIES
# ============================================================================

def clear_cache():
    """Clears GPU cache and runs garbage collection.

    Defined as a plain module-level function: wrapping it in ``@staticmethod``
    outside a class yields a staticmethod object that is not callable on
    Python versions before 3.10.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def get_memory_info() -> Dict[str, float]:
    """Returns current GPU memory usage in GB.

    Defined as a plain module-level function (a module-level ``@staticmethod``
    object is not callable before Python 3.10).

    Returns:
        An empty dict when CUDA is unavailable; otherwise a dict with
        ``allocated_gb``, ``reserved_gb``, ``total_gb`` and ``free_gb`` for the
        current CUDA device. ``free_gb`` is computed as total minus allocated,
        so memory that is reserved but not allocated still counts as free.
    """
    if not torch.cuda.is_available():
        return {}

    device = torch.cuda.current_device()
    gib = 1024**3  # bytes per GiB
    allocated = torch.cuda.memory_allocated(device) / gib
    reserved = torch.cuda.memory_reserved(device) / gib
    total = torch.cuda.get_device_properties(device).total_memory / gib
    free = total - allocated

    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "total_gb": total,
        "free_gb": free
    }

def log_memory_usage():
    """Logs current memory usage at DEBUG level.

    Defined as a plain module-level function (a module-level ``@staticmethod``
    object is not callable before Python 3.10). Nothing is logged when
    ``get_memory_info`` returns an empty dict, i.e. when CUDA is unavailable.
    """
    info = get_memory_info()
    if info:
        logger.debug(
            f"GPU Memory - Allocated: {info['allocated_gb']:.2f}GB, "
            f"Free: {info['free_gb']:.2f}GB, Total: {info['total_gb']:.2f}GB"
        )

def setup_memory_optimization():
    """Sets up environment variables for better memory management.

    Defined as a plain module-level function (a module-level ``@staticmethod``
    object is not callable before Python 3.10). Sets ``PYTORCH_CUDA_ALLOC_CONF``
    to ``expandable_segments:True`` in the current process environment,
    overwriting any existing value.
    """
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
class ModelManager:
    """Handles model loading and inference.

    The tokenizer and model are loaded lazily on the first ``load_model`` /
    ``generate`` call and cached on the instance. ``unload_model`` drops them
    again so the memory can be reclaimed.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        temperature: float = 0.7,
        top_k: int = 40,
        use_8bit: bool = False,
        use_4bit: bool = True, # Essential for 7B on 8GB GPU
        max_memory_gb: Optional[float] = 7.5 # Leave some headroom
    ):
        """Store the configuration; no model is loaded yet.

        Args:
            model_id: Hugging Face model id. When falsy, a default is chosen
                from the resolved device (see ``_get_default_model_id``).
            device: Device string such as "cuda" or "cpu"; auto-detected when falsy.
            temperature: Default sampling temperature for ``generate``.
            top_k: Default top-k value for ``generate``.
            use_8bit: Load with 8-bit quantization (only consulted when
                ``use_4bit`` is false).
            use_4bit: Load with 4-bit NF4 quantization.
            max_memory_gb: Per-GPU memory cap handed to the loader on CUDA devices.
        """
        self.device = self._resolve_device(device)
        self.model_id = model_id or self._get_default_model_id()
        self.temperature = temperature
        self.top_k = top_k
        self.use_8bit = use_8bit
        self.use_4bit = use_4bit
        self.max_memory_gb = max_memory_gb

        self._tokenizer = None
        self._model = None
        # Setup memory optimization
        setup_memory_optimization()

    def _resolve_device(self, device: Optional[str]) -> torch.device:
        """Determines the appropriate device for model execution."""
        if device:
            return torch.device(device)
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _get_default_model_id(self) -> str:
        """Selects default model based on available hardware.

        Both branches previously returned ``CPU_MODEL_ID``, which left
        ``GPU_MODEL_ID`` unused and made the device check pointless.
        """
        if self.device.type == "cuda":
            return GPU_MODEL_ID
        return CPU_MODEL_ID

    def _build_model_kwargs(self) -> Dict[str, Any]:
        """Builds the keyword arguments passed to ``from_pretrained``."""
        model_kwargs: Dict[str, Any] = {"device_map": "auto"}

        # Apply 4-bit or 8-bit quantization if specified (4-bit takes precedence).
        if self.use_4bit:
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
            logging.info("Loading with 4-bit quantization")
        elif self.use_8bit:
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True
            )
            logging.info("Loading with 8-bit quantization")
        else:
            # Only set dtype if NOT using quantization
            model_kwargs["torch_dtype"] = (
                torch.float16 if self.device.type == "cuda" else torch.float32
            )

        if self.max_memory_gb and self.device.type == "cuda":
            # A bare "cuda" device has no index; treat it as GPU 0.
            device_index = self.device.index if self.device.index is not None else 0
            model_kwargs["max_memory"] = {device_index: f"{self.max_memory_gb}GB"}

        return model_kwargs

    def load_model(self):
        """Loads tokenizer and model if not already loaded.

        Raises:
            RuntimeError: When loading runs out of GPU memory (chained from
                ``torch.cuda.OutOfMemoryError``).
        """
        # Compare against None explicitly: a tokenizer defines __len__, so relying
        # on truthiness ties the cache check to the vocabulary size.
        if self._model is not None and self._tokenizer is not None:
            return
        logging.info(f"Loading model: {self.model_id}")
        clear_cache()
        try:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                **self._build_model_kwargs()
            )
            self._model.eval()

            logging.info("Model loaded successfully")
            log_memory_usage()
        except torch.cuda.OutOfMemoryError as e:
            logging.error(f"CUDA OOM while loading model: {e}")
            clear_cache()
            raise RuntimeError(
                "Out of GPU memory. Try: \n"
                "1. Use smaller model (1.5B instead of 7B)\n"
                "2. Enable quantization: use_8bit=True or use_4bit=True\n"
                "3. Set max_memory_gb to limit memory per GPU\n"
                "4. Close other GPU processes"
            ) from e

    def unload_model(self):
        """Drops the cached model and tokenizer and clears the GPU cache.

        This is the method the out-of-memory message in ``generate`` refers to;
        the next ``generate`` call reloads the model.
        """
        self._model = None
        self._tokenizer = None
        clear_cache()

    # def _clean_deepseek_output(self, text: str) -> str:
    #     """
    #     Cleans DeepSeek-R1 model output by removing reasoning tokens.
    #     DeepSeek-R1 models wrap reasoning in <think></think> or similar tags.
    #     """
    #     # Remove content between think tags
    #     text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    #     text = re.sub(r'<sub>.*?</sub>', '', text, flags=re.DOTALL)

    #     # Try to extract JSON object if present
    #     json_match = re.search(r'\{[^{}]*"CAD-code"[^{}]*\}', text, flags=re.DOTALL)
    #     if json_match:
    #         return json_match.group(0)

    #     return text.strip()

    def generate(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None
    ) -> str:
        """Generates text from the model given a prompt.

        Args:
            prompt: Full prompt text.
            temperature: Overrides the instance temperature when not None. A
                value <= 0 switches to greedy decoding, because sampling with a
                non-positive temperature is rejected by ``generate``.
            top_k: Overrides the instance top-k when not None.

        Returns:
            Only the newly generated text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: When generation runs out of GPU memory.
        """
        self.load_model()
        temp = temperature if temperature is not None else self.temperature
        tk = top_k if top_k is not None else self.top_k

        try:
            # Clear cache before generation
            clear_cache()

            inputs = self._tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True
            ).to(self._model.device)
            logging.debug(" Calling model with these inputs: %s", inputs)
            prompt_len = inputs["input_ids"].shape[-1]

            generate_kwargs = {
                "max_new_tokens": MAX_NEW_TOKENS,
                "pad_token_id": self._tokenizer.eos_token_id,
                "eos_token_id": self._tokenizer.eos_token_id,
            }
            if temp is not None and float(temp) > 0:
                generate_kwargs.update({
                    "do_sample": True,
                    "temperature": float(temp),
                    "top_k": int(tk),
                })
            else:
                generate_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self._model.generate(**inputs, **generate_kwargs)

            logging.debug(
                "Model parameters: temperature=%s, top_k=%s, max_new_tokens=%s",
                temp, tk, MAX_NEW_TOKENS,
            )
            logging.debug("Model raw output: %s", outputs)

            # A causal LM returns prompt + continuation. Slice off the prompt tokens
            # rather than comparing decoded strings, which fails whenever decoding
            # does not reproduce the prompt text exactly.
            generated = self._tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            ).strip()

            # Clean up inputs/outputs tensors
            del inputs, outputs
            clear_cache()

            # Clean DeepSeek-R1 reasoning tokens
            # generated = self._clean_deepseek_output(generated)
            return generated

        except torch.cuda.OutOfMemoryError as e:
            logging.error(f"CUDA OOM during generation: {e}")
            log_memory_usage()

            clear_cache()
            raise RuntimeError(
                "Out of GPU memory during generation. Try:\n"
                "1. Reduce max_new_tokens\n"
                "2. Process texts in smaller batches\n"
                "3. Unload and reload model: agent.model_manager.unload_model()\n"
                "4. Enable quantization if not already enabled"
            ) from e



In [None]:
import logging
# NOTE: basicConfig does nothing when the root logger already has a handler,
# which is the case once the logging-setup cell above has run.
logging.basicConfig(level=logging.INFO)

# Start with 1.5B model (no quantization needed)
# use_4bit defaults to True in ModelManager, so it is switched off explicitly
# to match the "no quantization" intent of this smoke test.
print("Testing 1.5B model...")
manager = ModelManager(
    model_id=CPU_MODEL_ID,
    temperature=0.7,
    top_k=40,
    use_4bit=False,
)

prompt = "hi?"
response = manager.generate(prompt)
print(f"Response: {response}")

## OutputValidator

In [None]:
class OutputValidator:
    """Validates and parses model outputs.

    A valid output contains a JSON object whose keys are exactly
    ``REQUIRED_KEYS`` and whose "CAD-code" value is one of ``ALLOWED_CODES``.
    """

    ALLOWED_CODES = {"WCT", "GT", "Other", "NONE"}
    REQUIRED_KEYS = {"CAD-code", "rationale"}

    _NO_JSON_ERROR = "Could not extract valid JSON from response"

    @classmethod
    def validate_and_parse(cls, text: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
        """
        Validates model output and parses JSON.

        Args:
            text: Raw model output. Non-string input (e.g. None) is reported as
                invalid instead of raising.

        Returns:
            Tuple of (is_valid, parsed_dict, error_message)
        """
        if not isinstance(text, str):
            return False, None, cls._NO_JSON_ERROR

        parsed = cls._extract_json(text.strip())
        if parsed is None:
            return False, None, cls._NO_JSON_ERROR

        # Validate structure
        error = cls._validate_structure(parsed)
        if error:
            return False, None, error

        return True, parsed, None

    @staticmethod
    def _extract_json(text: str) -> Optional[Dict[str, Any]]:
        """Attempts to extract and parse a JSON object from text.

        The whole string is tried first. Otherwise every "{" is tried as the
        start of a JSON value and the LAST object that decodes is returned, so
        stray braces in reasoning text before the answer do not hide it.
        Only dict values are returned: bare numbers, strings and lists are
        valid JSON but would break the key-based validation.
        """
        try:
            whole = json.loads(text)
        except json.JSONDecodeError:
            whole = None
        if isinstance(whole, dict):
            return whole

        decoder = json.JSONDecoder()
        found: Optional[Dict[str, Any]] = None
        pos = text.find("{")
        while pos != -1:
            try:
                candidate, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                # Not a JSON object here; move on to the next opening brace.
                pos = text.find("{", pos + 1)
                continue
            if isinstance(candidate, dict):
                found = candidate
            # Resume after the decoded value so nested objects are not revisited.
            pos = text.find("{", end)
        return found

    @classmethod
    def _validate_structure(cls, parsed: Dict[str, Any]) -> Optional[str]:
        """Validates the structure and content of parsed JSON.

        Returns:
            An error message, or None when the structure is valid.
        """
        if not isinstance(parsed, dict):
            return f"Expected a JSON object, got {type(parsed).__name__}"

        # Check keys
        if set(parsed.keys()) != cls.REQUIRED_KEYS:
            return f"Unexpected keys: {list(parsed.keys())}"

        # Validate code. The type check comes first because an unhashable value
        # (e.g. a list) would raise TypeError in the set-membership test.
        code = parsed.get("CAD-code")
        if not isinstance(code, str) or code not in cls.ALLOWED_CODES:
            return f"Invalid CAD-code: {code}"

        return None

## Prompt Builder


In [None]:
class PromptBuilder:
    """Handles all prompt construction logic.

    A prompt is assembled as a dict with "system", "context", "user" and an
    optional "extra" section, then flattened to a single string by ``to_string``.
    """

    # JSON Schema definition
    # The comparison sign was mis-encoded in the original source; restored to U+2264.
    SCHEMA = {
        "CAD-code": "<ONE OF: WCT, GT, Other, NONE>",
        "rationale": "<≤5 sentences, evidence-based>"
    }

    # Valid codes for validation
    VALID_CODES = {"WCT", "GT", "Other", "NONE"}

    # Default examples for few-shot learning (rendered by build_system_prompt)
    DEFAULT_EXAMPLES = [
        {
            "input": "Everybody please listen.",
            "output": {
                "CAD-code": "WCT",
                "rationale": 'Addresses the whole class using "Everybody" to get attention.'
            }
        },
        {
            "input": "Group 3, read the next paragraph.",
            "output": {
                "CAD-code": "GT",
                "rationale": 'Directs a specific group "Group 3" to perform an action.'
            }
        }
    ]

    def __init__(self, name: str, personality: str, role: str, codebook: Dict[str, str], config: Dict[str, Any]):
        """Store the agent identity, the codebook and the prompt configuration.

        Args:
            name: Agent name inserted into the system prompt.
            personality: Adjective(s) describing the agent in the system prompt.
            role: Stored on the instance; the role actually rendered is the one
                passed to ``build_system_prompt`` / ``build_full_prompt``.
            codebook: Mapping of code -> description for the context section.
            config: Options; only "user_template" is read here.
        """
        self.name = name
        self.personality = personality
        self.role = role
        self.codebook = codebook
        self.config = config

    def _schema_json(self) -> str:
        """Serialize SCHEMA without \\u escapes so non-ASCII characters stay readable."""
        return json.dumps(self.SCHEMA, ensure_ascii=False)

    def build_system_prompt(self, role: str) -> str:
        """Creates the system prompt with instructions and examples.

        The examples are rendered from ``DEFAULT_EXAMPLES`` in compact JSON
        rather than being hard-coded a second time.
        """
        examples = "".join(
            f"Input: {json.dumps(ex['input'], ensure_ascii=False)}\n"
            f"Output: {json.dumps(ex['output'], separators=(',', ':'), ensure_ascii=False)}\n\n"
            for ex in self.DEFAULT_EXAMPLES
        )
        return (
            f"You are {self.name}, a {self.personality} qualitative-coding agent.\n"
            f"Task: {role}.\n\n"
            "CRITICAL: Output ONLY a single JSON object.\n\n"
            "REQUIREMENTS (follow exactly):\n"
            "1) Your ENTIRE response must be ONLY this JSON object and nothing else:\n"
            f"   {self._schema_json()}\n"
            "2) Use double quotes for JSON strings.\n"
            "3) CAD-code must be one of: WCT, GT, Other, NONE\n"
            "4) Rationale:  grounded in evidence from the text.\n"
            "5) If multiple codes could apply, choose the most likely one; if ambiguous, use NONE.\n\n"
            "CORRECT OUTPUT EXAMPLES:\n"
            f"{examples}"
            "Remember: Output ONLY the JSON object. Start your response with { and end with }"
        )
    # def build_system_prompt(self, role: str) -> str:
    #     """Creates the system prompt with instructions and examples."""
    #     return (
    #         f"You are {self.name}, a {self.personality} qualitative-coding agent.\n"
    #         f"Task: {role}.\n\n"
    #         "Remember: Start your response with { and end with } \n\n"
    #     )

    def build_context_prompt(self) -> str:
        """Creates the codebook context ("" when the codebook is empty)."""
        if not self.codebook:
            return ""
        lines = [f"- {k}: {v}" for k, v in self.codebook.items()]
        return "Codebook:\n" + "\n".join(lines)

    def build_user_prompt(self, text: str) -> str:
        """Creates the user prompt with the text to annotate.

        Uses ``config["user_template"]`` when present; the template must contain
        a ``{text}`` placeholder.
        """
        template = self.config.get(
            "user_template",
            'text to code: \n{text}\n\n'
        )
        return template.format(text=text)

    def build_full_prompt(
        self,
        text: str,
        role: str,
        extra_context: Optional[str] = None,
        previous_turn: Optional[str] = None
    ) -> Dict[str, str]:
        """Builds complete prompt dictionary with all components.

        Returns:
            Dict with "system", "context" and "user" keys, plus "extra" when
            ``extra_context`` is given. ``previous_turn`` is appended to the
            context section, or to the user section when the context is empty.
        """
        prompt = {
            "system": self.build_system_prompt(role),
            "context": self.build_context_prompt(),
            "user": self.build_user_prompt(text),
        }

        if extra_context:
            prompt["extra"] = extra_context

        if previous_turn:
            suffix = f"\n\n###\nPrevious turn: {previous_turn}"
            target = "context" if prompt.get("context") else "user"
            prompt[target] += suffix

        return prompt

    def build_retry_prompt(self, original_prompt: str, failed_output: str) -> str:
        """Builds a retry prompt when the model fails to produce valid JSON.

        Only the first 500 characters of the failed output are echoed back; a
        missing (None) output is treated as empty.
        """
        previous = (failed_output or "")[:500]
        return (
            f"{original_prompt}\n\n"
            "--- RETRY REQUEST ---\n"
            "Your previous output was invalid or incorrectly formatted.\n"
            f"Previous output:\n{previous}\n\n"
            "Please output ONLY a valid JSON object with this exact structure:\n"
            f"{self._schema_json()}\n\n"
            "Requirements:\n"
            "- Start with { and end with }\n"
            "- No markdown, no explanations, no extra text\n"
            f"- CAD-code must be exactly one of: {', '.join(sorted(self.VALID_CODES))}\n"
            "- Use double quotes for strings\n\n"
            "Return ONLY the JSON object now:"
        )

    @staticmethod
    def to_string(prompt_dict: Dict[str, str]) -> str:
        """Converts prompt dictionary to a formatted string.

        Sections are emitted in the fixed order system, context, extra, user;
        missing or empty sections are skipped.
        """
        parts = []
        for key in ["system", "context", "extra", "user"]:
            if prompt_dict.get(key):
                parts.append(f"=== {key.upper()} ===\n{prompt_dict[key]}")
        return "\n\n".join(parts)

    def __repr__(self) -> str:
        """String representation for debugging."""
        # The closing parenthesis was lost when the examples_count field was commented out.
        return (
            f"PromptBuilder(name={self.name}, "
            f"personality={self.personality}, "
            f"role={self.role}, "
            f"codebook_size={len(self.codebook)})"
        )



## BaseCodingAgent

In [None]:
from typing import Optional, Dict, Any, Union
import torch
import logging
import json
from transformers import AutoTokenizer, AutoModelForCausalLM


class BaseCodingAgent:
    """Base class for coding agents with LLM-based text annotation.

    An agent combines a PromptBuilder (prompt text), a ModelManager (generation)
    and an OutputValidator (JSON checking). ``chat`` is the main entry point.
    """

    ALLOWED_CODES = {"WCT", "GT", "Other", "NONE"}
    DEFAULT_TEMPERATURE = 0.4  # Define constant
    DEFAULT_MAX_RETRIES = 2    # Retries after the first invalid answer

    def __init__(
        self,
        name: str,
        personality: str,
        role: str,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = 40,
        codebook: Optional[Dict[str, str]] = None,
        config: Optional[Dict[str, Any]] = None,
        debug: bool = False,
    ):
        """Create the agent and its helper components.

        Args:
            name: Agent name used in prompts.
            personality: Personality description used in prompts.
            role: Default task description used when ``chat`` gets no role.
            model_id: Model id forwarded to ModelManager (None selects its default).
            device: Device string forwarded to ModelManager.
            temperature: Sampling temperature forwarded to ModelManager.
            top_k: Top-k value forwarded to ModelManager.
            codebook: Code -> description mapping ({} when omitted).
            config: Prompt configuration ({} when omitted).
            debug: When true, ``chat`` logs the full prompt at INFO level.
        """
        self.name = name
        self.personality = personality
        self.role = role
        self.debug = debug

        # Initialize components
        self.validator = OutputValidator()
        self.prompt_builder = PromptBuilder(
            name, personality, role,
            codebook or {}, config or {}
        )
        self.model_manager = ModelManager(
            model_id, device, temperature, top_k
        )

        # Store options
        self.options = {"temperature": temperature, "top_k": top_k}
        self.codebook = codebook or {}
        self.config = config or {}

    # Main chat interface for the agent to interact with the language model
    def chat(
        self,
        text: str,
        max_retries: int = DEFAULT_MAX_RETRIES,
        role: Optional[str] = None,
        extra_context: Optional[str] = None,
        **gen_opts
    ) -> str:
        """Build the prompt for ``text`` and return the model's (validated) answer.

        Args:
            text: Text to code.
            max_retries: Number of retry prompts after an invalid first answer.
                (The default used to be the ``int`` type itself, which broke the
                retry loop whenever the argument was omitted.)
            role: Task description; falls back to ``self.role``.
            extra_context: Optional additional prompt section.
            **gen_opts: Generation options (e.g. temperature, top_k) forwarded
                to the model only, not to the prompt builder.

        Returns:
            A JSON string: the raw valid model output, the heuristic fallback,
            or a NONE record describing an error.
        """
        prompt_str = self.get_prompt_str(
            text=text,
            role=role or self.role,
            extra_context=extra_context,
        )

        # Log if debug enabled
        if self.debug:
            logging.info("=== PROMPT ===")
            logging.info(prompt_str)

        # Generate response
        return self._call_and_retry(
            prompt_str, max_retries, source_text=text, **gen_opts
        )

    def _call_and_retry(
        self,
        prompt_str: str,
        max_retries: int,
        source_text: Optional[str] = None,
        **gen_opts
    ) -> str:
        """Internal method to handle generation with retries.

        Args:
            prompt_str: Full prompt sent to the model.
            max_retries: Number of retry prompts after the first invalid output.
            source_text: The text being coded. The heuristic fallback is applied
                to this text; when omitted it falls back to ``prompt_str`` (the
                previous behaviour, which matches on the instructions and
                examples rather than on the text itself).
            **gen_opts: Forwarded to ``ModelManager.generate``.

        Returns:
            JSON string (see ``chat``). Exceptions are logged and converted to a
            NONE record rather than propagated.
        """
        try:
            raw = self.model_manager.generate(prompt_str, **gen_opts)
            valid, parsed, err = self.validator.validate_and_parse(raw)

            if valid:
                return raw

            logger.debug("Initial attempt failed: %s. Raw: %s", err, raw[:500])

            # Retry with structured prompt
            for attempt in range(max_retries):
                retry_prompt = self.prompt_builder.build_retry_prompt(
                    prompt_str, raw
                )
                logger.debug("Retry attempt %d/%d", attempt + 1, max_retries)

                raw = self.model_manager.generate(retry_prompt, **gen_opts)
                valid, parsed, err = self.validator.validate_and_parse(raw)

                if valid:
                    return raw

                logger.debug(
                    "Retry %d failed: %s. Raw: %s",
                    attempt + 1, err, raw[:500]
                )

            # Fallback to heuristic
            logger.warning(
                "Model failed after %d retries. Using heuristic fallback.",
                max_retries
            )
            heuristic_input = source_text if source_text is not None else prompt_str
            return json.dumps(self._heuristic_label(heuristic_input))

        except Exception as e:
            logger.error("Error during generation: %s", e, exc_info=True)
            return json.dumps({
                "CAD-code": "NONE",
                "rationale": f"Error: {str(e)}"
            })

    def _heuristic_label(self, text: str) -> Dict[str, str]:
        """Fallback heuristic labeling when model fails.

        Keyword patterns are matched case-insensitively (the keywords are written
        in lowercase, so a case-sensitive search misses "Everybody" or "Group").
        The "Name," pattern stays case-sensitive because it relies on the
        capital letter.
        """
        stripped = (text or "").strip()

        if not stripped:
            return {
                "CAD-code": "NONE",
                "rationale": "Empty input"
            }

        # Check patterns in priority order: (regex, flags, code, reason)
        patterns = [
            (r'\b(everybody|everyone|class|students|all of you|all)\b',
             re.IGNORECASE, "WCT", "whole-class addressing"),
            (r'\b(group|pair|you two|you three)\b',
             re.IGNORECASE, "GT", "group-level addressing"),
            (r'^[A-Z][a-z]+,', 0, "GT", "direct student address"),
        ]

        for pattern, flags, code, reason in patterns:
            if re.search(pattern, stripped, flags):
                return {"CAD-code": code, "rationale": reason}

        return {
            "CAD-code": "Other",
            "rationale": "Non-directed teacher talk"
        }

    def get_prompt_str(
        self,
        text: str,
        role: Optional[str] = None,
        extra_context: Optional[str] = None,
        **gen_opts
    ) -> str:
        """Return the flattened prompt string for ``text``.

        ``**gen_opts`` is accepted for backward compatibility but ignored:
        generation options are not prompt components, and forwarding them to
        ``build_full_prompt`` raised TypeError.
        """
        prompt_dict = self.prompt_builder.build_full_prompt(
            text=text,
            role=role or self.role,
            extra_context=extra_context,
            previous_turn=None,
        )

        return self.prompt_builder.to_string(prompt_dict)

    def get_agent_info(self) -> Dict[str, Any]:
        """Return agent configuration as dictionary."""
        return {
            "name": self.name,
            "personality": self.personality,
            "role": self.role,
            "model": self.model_manager.model_id,
            "device": str(self.model_manager.device),
            "options": self.options,
            "codebook": self.codebook,
            "config": self.config,
            "debug": self.debug,
        }

    # ---------------- Output validation & parsing ----------------
    def validate_and_parse(self, text: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
        """Delegate to the validator; returns (is_valid, parsed_dict, error_message)."""
        return self.validator.validate_and_parse(text)

    def get_parsed_resp(self, text: str):
        """Parse response; returns the parsed dict, or None when it is invalid."""
        valid, parsed, err = self.validator.validate_and_parse(text)
        return parsed if valid else None

class SingleAgentCoding(BaseCodingAgent):
    """Agent that codes one text at a time using its own configured role."""

    def assign_code(self, text: str,
                    max_retries: int = MAX_RETRIES,
                    extra_context: Optional[str] = None,
                    **gen_opts) -> str:
        """Assign a codebook code (with rationale) to ``text``.

        Thin wrapper around ``chat`` that always uses ``self.role`` and logs the
        input text and the returned response at DEBUG level.

        Returns:
            Whatever ``chat`` returns for this text.
        """
        logging.debug("Assigning code for text: %s", text)
        chat_kwargs = dict(
            text=text,
            role=self.role,
            max_retries=max_retries,
            extra_context=extra_context,
        )
        agent_reply = self.chat(**chat_kwargs, **gen_opts)
        logging.debug("Raw response from agent: %s", agent_reply)
        return agent_reply

## MAS


**High-level logic (fixed)**

1.   Round 0: each agent labels independently.
2.   If all agree → stop.
3.   Else iterate discussion rounds:
  *   compute majority (strict majority preferred)
  *   if majority exists: reprompt only minority agents with majority + arguments
  *   else: general discussion (everyone sees everyone’s arguments)

4.  Stop early if: unanimous OR stable majority OR no changes OR reached T.

5. Extract final answer via EXT.






### AgentResponse and DiscussionResult

In [None]:
from typing import Optional, Dict, Any, Union, Tuple, List
from collections import Counter
from dataclasses import dataclass, field
import logging
from tqdm import tqdm
from collections import Counter


# from prompt_helper import PromptBuilder
# from model_helper import ModelManager
# from agent_helper import OutputValidator, SingleAgentCoding

@dataclass
class AgentResponse:
    """Represents a single agent's response in a discussion round.

    Attributes:
        agent: Name/identifier of the agent
        code: CAD code assigned (e.g., 'WCT', 'GT', 'Other', 'NONE')
        rationale: Agent's reasoning for the code choice
        raw: Raw unparsed response from the agent
        round: Round number (1-indexed)

    Raises:
        ValueError: From ``__post_init__`` when ``round`` is below 1 or
            ``agent`` is empty.
    """
    agent: str
    code: str
    rationale: str
    raw: str
    round: int

    def __post_init__(self):
        """Validate response data."""
        if self.round < 1:
            raise ValueError(f"Round must be >= 1, got {self.round}")
        if not self.agent:
            raise ValueError("Agent name cannot be empty")

    def __repr__(self) -> str:
        """Developer-friendly representation (rationale cut to 50 characters)."""
        # Tolerate a missing rationale: slicing None would raise TypeError.
        preview = (self.rationale or "")[:50]
        return (
            f"AgentResponse(agent={self.agent!r}, code={self.code!r}, "
            f"round={self.round}, rationale={preview!r}...)"
        )

    def __str__(self) -> str:
        """Human-readable multi-line representation."""
        # Each field goes on its own line; the newline before "Raw" was missing,
        # which glued the raw output onto the end of the rationale line.
        return (
            f"Response from: "
            f"Agent '{self.agent}' in Round {self.round}:\n"
            f"  Code: {self.code}\n"
            f"  Rationale: {self.rationale or '(none provided)'}\n"
            f"  Raw: {self.raw}"
        )

    def convert_to_dict(self) -> Dict[str, Any]:
        """Convert response to a dictionary keyed by field name."""
        return {
            "agent": self.agent,
            # "text": self.text,
            "code": self.code,
            "rationale": self.rationale,
            "raw": self.raw,
            "round": self.round,
        }


@dataclass
class DiscussionResult:
    """Results from a multi-agent discussion process.

    Attributes:
        text_to_code: The text that was coded.
        human_code: Human reference code for the text ("" when unavailable).
        final_code: Consensus or plurality code.
        final_rationale: Combined rationale from agents who chose final_code.
        confidence: Agreement ratio (0.0 to 1.0).
        history: List of responses per round: history[round_idx] = [AgentResponse, ...]
        round_dicts: Cache of the last ``get_round_dicts()`` result.
        tallies: Vote counts per round: tallies[round_idx] = {code: count, ...}
        consensus_reached: Whether consensus threshold was met.
        num_rounds: Total number of rounds conducted.
        num_agents: Number of participating agents; filled in from the first
            round of ``history`` when left at 0.
    """
    text_to_code: str
    human_code: str
    final_code: str
    final_rationale: str
    confidence: float
    history: List[List["AgentResponse"]] = field(default_factory=list)
    round_dicts: List[Dict[str, Any]] = field(default_factory=list)
    tallies: List[Dict[str, int]] = field(default_factory=list)
    consensus_reached: bool = False
    num_rounds: int = 0
    num_agents: int = 0

    def __post_init__(self):
        """Validate result data and derive ``num_agents`` when possible.

        Raises:
            ValueError: If ``confidence`` is outside [0, 1], ``num_rounds`` is
                negative, or a non-empty ``history`` has a length different
                from ``num_rounds``.
        """
        if not 0.0 <= self.confidence <= 1.0:
            raise ValueError(f"Confidence must be in [0, 1], got {self.confidence}")
        if self.num_rounds < 0:
            raise ValueError(f"num_rounds must be >= 0, got {self.num_rounds}")
        if self.history and len(self.history) != self.num_rounds:
            raise ValueError(
                f"History length ({len(self.history)}) doesn't match "
                f"num_rounds ({self.num_rounds})"
            )
        # The field used to stay at 0 forever; derive it from the trace.
        if not self.num_agents and self.history:
            self.num_agents = len(self.history[0])

    def __repr__(self) -> str:
        """Developer-friendly representation."""
        return (
            f"DiscussionResult(text_to_code={self.text_to_code!r}, "
            f"human_code={self.human_code!r}, "
            f"final_code={self.final_code!r}, "
            f"confidence={self.confidence:.2f}, "
            f"consensus={self.consensus_reached}, rounds={self.num_rounds}, "
            f"history={len(self.history)} rounds, "
            f"tallies={len(self.tallies)} rounds)"
        )

    def __str__(self) -> str:
        """Human-readable summary, one field per line."""
        consensus_str = "✓ Consensus" if self.consensus_reached else "✗ Plurality"
        rationale = self.final_rationale or ""
        return (
            f"Discussion Result:\n"
            f"  Text to Code: {self.text_to_code}\n"
            f"  Human Code: {self.human_code}\n"
            f"  Result ({consensus_str}):\n"
            f"  Final Code: {self.final_code}\n"
            f"  Confidence: {self.confidence:.1%}\n"
            f"  Rounds: {self.num_rounds}\n"
            f"  Rationale: {rationale[:100]}...\n"
            f"  History: {len(self.history)} rounds\n"
            f"  Tallies: {len(self.tallies)} rounds"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to a dictionary.

        This replaces a method that was named ``__dict__``. That name is
        reserved for the instance attribute dictionary; overriding it with a
        method made ``vars(result)`` return a bound method instead of a dict.
        ``history`` holds the response objects as-is; ``round_dicts`` holds
        the JSON-friendly form of the same trace.
        """
        return {
            "text_to_code": self.text_to_code,
            "human_code": self.human_code,
            "final_code": self.final_code,
            "final_rationale": self.final_rationale,
            "confidence": self.confidence,
            "history": self.history,
            "tallies": self.tallies,
            "consensus_reached": self.consensus_reached,
            "num_rounds": self.num_rounds,
            "round_dicts": self.get_round_dicts(),
        }

    def get_num_agents(self) -> int:
        """Get number of participating agents.

        Counts the responses in the first round of the history. Without a
        history it falls back to ``num_agents`` (0 unless set by the caller).
        """
        if self.history:
            return len(self.history[0])
        return self.num_agents

    def display(self, verbose: bool = True) -> str:
        """Formatted display with optional round-by-round details.

        Args:
            verbose: Show detailed round-by-round breakdown

        Returns:
            Formatted string representation
        """
        n_agents = self.get_num_agents()
        lines = [
            f"\n{'='*70}",
            "DISCUSSION RESULT",
            f"{'='*70}",
            f"Final Code: {self.final_code}",
            f"Confidence: {self.confidence:.1%} ({self.confidence * n_agents:.0f}/{n_agents} agents)",
            f"Consensus: {'✓ Yes' if self.consensus_reached else '✗ No (plurality vote)'}",
            f"Rounds: {self.num_rounds}",
            "\nFinal Rationale:",
            f"{self.final_rationale or '(none provided)'}",
        ]

        if verbose and self.history:
            lines.append(f"\n{'-'*70}")
            lines.append("ROUND-BY-ROUND BREAKDOWN:")
            lines.append('-'*70)

            # zip() stops at the shorter of history/tallies.
            for round_idx, (responses, tally) in enumerate(zip(self.history, self.tallies), 1):
                lines.append(f"\n📍 Round {round_idx}:")
                lines.append(f"   Votes: {dict(tally)}")

                for resp in responses:
                    lines.append(f"   • {resp.agent}: {resp.code}")
                    if resp.rationale:
                        lines.append(f"     → {resp.rationale[:80]}...")

        lines.append('='*70 + '\n')
        return '\n'.join(lines)

    def get_agent_journey(self, agent_name: str) -> List["AgentResponse"]:
        """Track how a specific agent voted across rounds.

        Args:
            agent_name: Name of the agent to track

        Returns:
            The first matching response from each round in which the agent
            appears, in round order.
        """
        journey = []
        for round_responses in self.history:
            for resp in round_responses:
                if resp.agent == agent_name:
                    journey.append(resp)
                    break
        return journey

    def get_round_dicts(self) -> List[Dict[str, Any]]:
        """Convert history to a list of JSON-friendly dictionaries.

        Each entry has ``round_num`` (1-indexed), ``votes`` (the round tally)
        and ``responses`` (each response's ``convert_to_dict()``). The result
        is also cached on ``self.round_dicts``.
        """
        res = [
            {
                "round_num": round_idx,
                "votes": dict(tally),
                "responses": [resp.convert_to_dict() for resp in responses],
            }
            for round_idx, (responses, tally) in enumerate(zip(self.history, self.tallies), 1)
        ]
        # Assigned once, outside any loop, so an empty history also refreshes
        # the cache instead of leaving a stale value behind.
        self.round_dicts = res
        return res

### utils for write_row_html_log during code run

In [None]:
import os
import html
import json
from datetime import datetime
from typing import Any, Dict, List, Optional

def _esc(x: Any) -> str:
    return html.escape("" if x is None else str(x))

def _safe_pre(x: Any, max_len: int = 5000) -> str:
    s = "" if x is None else str(x)
    if len(s) > max_len:
        s = s[:max_len] + "\n...(truncated)..."
    return html.escape(s)

def _flatten_discussion_result(dr: "DiscussionResult") -> List[Dict[str, Any]]:
    """
    Returns list of rows: one per agent per round.
    Uses dr.get_round_dicts() (your JSON-friendly trace).
    """
    flat = []
    for rd in (dr.get_round_dicts() or []):
        rnum = rd.get("round_num")
        votes = rd.get("votes", {})
        for resp in rd.get("responses", []):
            flat.append({
                "round": rnum,
                "agent": resp.get("agent", ""),
                "code": resp.get("code", ""),
                "rationale": resp.get("rationale", ""),
                "raw": resp.get("raw", ""),
                "votes": votes,
            })
    return flat

def write_row_html_log(
    out_dir: str,
    idx: Any,
    transcript: str,
    dr: Optional["DiscussionResult"] = None,
    human_code: str = "",
    error: str = "",
) -> str:
    """
    Write a standalone HTML report for one coded row and return its path.

    Args:
        out_dir: Directory the report is written to; created if missing.
        idx: Row identifier, used in the file name (``row_<idx>.html``) and title.
        transcript: Text that was coded.
        dr: Discussion result to render. When ``None`` an error page is written.
        human_code: Human reference code; drives the agreement badge.
        error: Error message. When non-empty an error page is written.

    Returns:
        Path of the written file (``out_dir/row_<idx>.html``).
    """
    # Create the directory the file is actually written to. The previous
    # version created RESULTS_PATH/out_dir but wrote to out_dir, so the write
    # failed whenever out_dir did not already exist in the working directory.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"row_{idx}.html")

    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    if error or dr is None:
        html_doc = f"""<!doctype html>
<html><head><meta charset="utf-8"><title>Row {_esc(idx)} (ERROR)</title>
<style>
body{{font-family:system-ui,-apple-system,Segoe UI,Roboto,Arial,sans-serif;margin:24px;line-height:1.35}}
.card{{border:1px solid #ddd;border-radius:14px;padding:16px;margin:14px 0}}
pre{{white-space:pre-wrap;word-wrap:break-word;background:#fafafa;padding:12px;border-radius:10px;border:1px solid #eee;margin:0}}
</style></head><body>
<h1>Row {_esc(idx)} (ERROR)</h1>
<p style="color:#666">Generated {_esc(now)}</p>
<div class="card"><h2>Transcript</h2><pre>{_safe_pre(transcript)}</pre></div>
<div class="card"><h2>Error</h2><pre>{_safe_pre(error)}</pre></div>
</body></html>"""
        with open(path, "w", encoding="utf-8") as f:
            f.write(html_doc)
        return path

    flat = _flatten_discussion_result(dr)

    # One table row per agent per round; votes and raw output are collapsed
    # behind <details> so the table stays readable.
    trs = []
    for row in flat:
        votes_json = json.dumps(row["votes"], ensure_ascii=False)
        trs.append(
            "<tr>"
            f"<td>{_esc(row['round'])}</td>"
            f"<td><b>{_esc(row['agent'])}</b></td>"
            f"<td><b>{_esc(row['code'])}</b></td>"
            f"<td style='white-space:pre-wrap'>{_esc(row['rationale'])}</td>"
            f"<td><details><summary>votes</summary><pre style='white-space:pre-wrap'>{_safe_pre(votes_json, max_len=2000)}</pre></details></td>"
            f"<td><details><summary>raw</summary><pre style='white-space:pre-wrap'>{_safe_pre(row['raw'])}</pre></details></td>"
            "</tr>"
        )
    body_rows = "\n".join(trs) if trs else "<tr><td colspan='6'>(No trace)</td></tr>"

    # Agreement badge: check mark when the human code matches the final code,
    # cross when a human code exists but differs, dash when there is none.
    human_code = "" if human_code is None else str(human_code)
    has_human_code = human_code.strip() != ""
    agree = has_human_code and human_code == dr.final_code
    badge = "✅" if agree else ("❌" if has_human_code else "—")

    html_doc = f"""<!doctype html>
<html><head><meta charset="utf-8"><title>Row {_esc(idx)}</title>
<style>
body{{font-family:system-ui,-apple-system,Segoe UI,Roboto,Arial,sans-serif;margin:24px;line-height:1.35}}
h1{{margin:0 0 8px 0}}
.sub{{color:#666;margin-bottom:18px}}
.card{{border:1px solid #ddd;border-radius:14px;padding:16px;margin:14px 0}}
.grid{{display:grid;grid-template-columns:180px 1fr;gap:8px 12px}}
.label{{color:#555}}
pre{{white-space:pre-wrap;word-wrap:break-word;background:#fafafa;padding:12px;border-radius:10px;border:1px solid #eee;margin:0}}
table{{border-collapse:collapse;width:100%}}
th,td{{border:1px solid #ddd;padding:10px;vertical-align:top}}
th{{background:#f5f5f5;text-align:left}}
.pill{{display:inline-block;padding:2px 10px;border:1px solid #ddd;border-radius:999px;font-size:12px}}
</style></head><body>

<h1>{badge} Row {_esc(idx)}</h1>
<div class="sub">Generated {_esc(now)}</div>

<div class="card">
  <h2>Transcript</h2>
  <pre>{_safe_pre(transcript)}</pre>
</div>

<div class="card">
  <h2>Final decision</h2>
  <div class="grid">
    <div class="label">final_code</div><div><span class="pill">{_esc(dr.final_code)}</span></div>
    <div class="label">confidence</div><div>{_esc(dr.confidence)}</div>
    <div class="label">consensus_reached</div><div>{_esc(dr.consensus_reached)}</div>
    <div class="label">num_rounds</div><div>{_esc(dr.num_rounds)}</div>
    <div class="label">human_code</div><div>{_esc(human_code)}</div>
    <div class="label">final_rationale</div><div style="white-space:pre-wrap">{_esc(dr.final_rationale)}</div>
  </div>
</div>

<div class="card">
  <h2>Rounds (agent × round)</h2>
  <table>
    <tr><th>Round</th><th>Agent</th><th>Code</th><th>Rationale</th><th>Votes</th><th>Raw</th></tr>
    {body_rows}
  </table>
</div>

</body></html>"""

    with open(path, "w", encoding="utf-8") as f:
        f.write(html_doc)

    return path

def write_index_html(out_dir: str, row_files: List[str], title: str = "Discussion Logs") -> str:
    """
    Write an ``index.html`` in ``out_dir`` that links to all row files.

    Args:
        out_dir: Directory for the index; created if it does not exist.
        row_files: File names or paths of the row reports. Only the base name
            of each entry is used for the link, so the reports are expected to
            sit next to the index.
        title: Page title and heading.

    Returns:
        Path of the written index file.
    """
    # The index may be the first file written here (e.g. when every row
    # failed before producing a report), so make sure the folder exists.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "index.html")

    link_items = []
    for p in row_files:
        name = html.escape(os.path.basename(p))
        link_items.append(f"<li><a href='{name}'>{name}</a></li>")
    links = "\n".join(link_items)

    doc = f"""<!doctype html>
<html><head><meta charset="utf-8"><title>{html.escape(title)}</title>
<style>
body{{font-family:system-ui,-apple-system,Segoe UI,Roboto,Arial,sans-serif;margin:24px}}
li{{margin:6px 0}}
</style></head><body>
<h1>{html.escape(title)}</h1>
<ul>
{links}
</ul>
</body></html>"""
    with open(path, "w", encoding="utf-8") as f:
        f.write(doc)
    return path


### MAD

In [None]:
@dataclass
class DiscussionConfig:
    """Configuration for multi-agent discussion behavior.

    Attributes:
        max_rounds: Maximum number of discussion rounds (>= 1).
        consensus_threshold: Agreement ratio in (0, 1] that ends a discussion early.
        max_retries_per_agent: Retry budget handed to each agent per round (>= 1).
        allowed_codes: Codes an agent may return. Any iterable of strings is
            accepted and stored as a ``frozenset``.
    """
    max_rounds: int = 3
    consensus_threshold: float = 0.9
    max_retries_per_agent: int = 3
    allowed_codes: frozenset = field(
        default_factory=lambda: frozenset({"WCT", "GT", "Other", "NONE"})
    )

    def __post_init__(self):
        """Validate configuration parameters and normalise ``allowed_codes``.

        Raises:
            ValueError: If any numeric setting is out of range or
                ``allowed_codes`` is empty.
        """
        if self.max_rounds < 1:
            raise ValueError("max_rounds must be at least 1")
        if not 0 < self.consensus_threshold <= 1:
            raise ValueError("consensus_threshold must be in (0, 1]")
        if self.max_retries_per_agent < 1:
            raise ValueError("max_retries_per_agent must be at least 1")

        # Callers pass plain sets (or lists); store an immutable frozenset so
        # the value matches its annotation and supports set arithmetic.
        codes = self.allowed_codes
        if codes is None:
            codes = ()
        elif isinstance(codes, str):
            # A bare string is one code, not an iterable of characters.
            codes = (codes,)
        self.allowed_codes = frozenset(codes)
        if not self.allowed_codes:
            raise ValueError("allowed_codes cannot be empty")

class MultiAgentDiscussion:
    """
    Multi-round discussion among N coding agents that settles on one code.

    Protocol implemented by :meth:`discuss`:
    - Round 1: every agent codes the text independently.
    - Later rounds: every agent is prompted again and sees a summary of the
      previous round (each agent's code and rationale, the vote distribution
      and the strict-majority code, if any).
    - After each round the share of agents backing the most common
      non-"NONE" code is compared with ``consensus_threshold``; reaching it
      ends the discussion.
    - If ``max_rounds`` pass without consensus, the last round is settled by
      plurality vote.

    Agents are duck-typed. The class calls
    ``agent.assign_code(text, extra_context=..., max_retries=..., **kwargs)``
    and ``agent.validate_and_parse(raw)`` and reads ``agent.name`` when present.
    """

    def __init__(self,
                 agents: List[Any],
                 config: Optional["DiscussionConfig"] = None):
        """
        Args:
            agents: Non-empty list of agent objects.
            config: Discussion settings; a default ``DiscussionConfig`` is
                created when omitted.

        Raises:
            ValueError: If ``agents`` is empty.
        """
        if not agents:
            raise ValueError("At least one agent is required")

        self.config = config or DiscussionConfig()

        self.agents = agents
        self.round_num = 0
        self.prev_responses = None
        self.history = []
        self.tallies = []
        self.consensus_reached = False
        self.logger = logging.getLogger(self.__class__.__name__)
        self.store_trace = True

        self.max_rounds = self.config.max_rounds
        self.threshold = self.config.consensus_threshold
        # Frozen copy so set arithmetic works for any iterable the config holds.
        self.allowed_codes = frozenset(self.config.allowed_codes)

    def _top_non_none(self, items, allowed_codes) -> Tuple[Optional[str], int, int]:
        """
        Find the most common code, ignoring "NONE" votes.

        Args:
            items: Either a ``{code: count}`` tally or an iterable of codes.
            allowed_codes: Codes considered valid.

        Returns:
            ``(top_code, top_count, total_n)`` where ``total_n`` counts every
            vote including "NONE". ``top_code`` is ``None`` (with count 0)
            when there are no votes, only "NONE" votes, or any invalid code
            (which is logged as an error).
        """
        # Normalize input into a flat list of votes.
        if isinstance(items, dict):
            codes = []
            for code, count in items.items():
                codes.extend([code] * int(count))
        else:
            codes = list(items)

        total_n = len(codes)
        if total_n == 0:
            return None, 0, 0

        invalid = set(codes) - set(allowed_codes)
        if invalid:
            self.logger.error(f"Invalid codes encountered: {invalid}")
            return None, 0, total_n

        valid = [c for c in codes if c != "NONE"]
        if not valid:
            return None, 0, total_n

        top_code, top_count = Counter(valid).most_common(1)[0]
        return top_code, top_count, total_n

    def is_consensus_reached(self, tally: Dict[str, int], threshold: float = 0.8) -> Tuple[Optional[str], float]:
        """
        Determine whether the agents reached consensus.

        Args:
            tally: Dictionary mapping code to count of agents who chose that code
            threshold: Minimum agreement ratio required for consensus (default: 0.8 = 80%)

        Returns:
            ``(code, agreement_ratio)``. ``agreement_ratio`` is the top
            non-"NONE" code's count divided by ALL votes ("NONE" included).
            ``code`` is the top code when the ratio meets ``threshold``,
            otherwise ``None``. When no usable top code exists the result is
            ``(None, 0.0)``.
        """
        top_code, top_count, total_n = self._top_non_none(tally, self.allowed_codes)
        if top_code is None or total_n == 0:
            return None, 0.0

        agreement_ratio = top_count / total_n
        if agreement_ratio >= threshold:
            self.logger.info(f"Consensus reached on code '{top_code}' with agreement {agreement_ratio:.2f}")
            return top_code, agreement_ratio

        self.logger.info(f"No consensus: top code '{top_code}' has agreement {agreement_ratio:.2f}")
        return None, agreement_ratio

    def get_majority_vote(self, codes: List[str]) -> Union[str, int]:
        """
        Get the strict-majority code from a list of codes.

        Args:
            codes: List of code strings

        Returns:
            The code backed by more than half of all votes ("NONE" votes count
            towards the total), else -1. Invalid codes are logged and also
            yield -1.
        """
        top_code, top_count, total_n = self._top_non_none(codes, self.allowed_codes)
        if not top_code or total_n == 0:
            return -1

        return top_code if top_count > (total_n / 2) else -1

    def _finalize(self, last_round_responses: List["AgentResponse"]) -> Tuple[str, str, float]:
        """
        Finalize the discussion by plurality vote over the last round.

        "NONE" competes like any other code here. Confidence is the winning
        count divided by the number of agents.

        Returns:
            ``(final_code, final_rationale, confidence)``; the rationale joins
            the non-empty rationales of agents who voted for the final code.

        Raises:
            ValueError: If a response carries a code outside ``allowed_codes``.
        """
        codes = [resp.code for resp in last_round_responses]

        # Fail fast rather than report a result built on invalid codes.
        invalid = set(codes) - self.allowed_codes
        if invalid:
            raise ValueError(f"Invalid codes encountered: {invalid}")

        counts = Counter(codes)
        if not counts:
            return "NONE", "", 0.0

        final_code, freq = counts.most_common(1)[0]
        confidence = freq / len(self.agents)

        rationales = [resp.rationale for resp in last_round_responses
                      if resp.code == final_code and resp.rationale]
        final_rationale = " | ".join(rationales)

        return final_code, final_rationale, confidence

    def _build_simple_discussion_context(self, prev_responses: Optional[List["AgentResponse"]],
                                  round_num: int) -> str:
        """
        Build the human-readable context shown to agents for a round.

        Args:
            prev_responses: Previous round's agent responses (None or empty for round 1)
            round_num: Current round number

        Returns:
            For round 1 a one-line instruction. Otherwise every previous
            position with its rationale, the vote distribution over
            non-"NONE" codes, the strict-majority code if one exists, and an
            instruction to reconsider.
        """
        if not prev_responses:
            return "Round 1: Provide your independent assessment."

        lines = [f"Round {round_num}: Previous round responses:", ""]

        for resp in prev_responses:
            rationale = resp.rationale if resp.rationale else "(no rationale provided)"
            lines.append(f"- {resp.agent} chose '{resp.code}': {rationale}")

        lines.append("")  # Blank line for readability

        # Vote distribution is shown only when at least one agent gave a real code.
        codes = [r.code for r in prev_responses if r.code != "NONE"]
        if codes:
            tally = Counter(codes)
            tally_str = ", ".join(f"{code}: {count}" for code, count in tally.most_common())
            lines.append(f"Vote distribution: {tally_str}")

            # Majority is judged against all votes, "NONE" included.
            majority = self.get_majority_vote([r.code for r in prev_responses])
            if majority != -1:
                lines.append(f"Majority position: {majority}")

        lines.extend([
            "",
            "Consider the above responses. You may:",
            "- Change your assessment if you find others' reasoning convincing",
            "- Maintain your position if you believe your reasoning is stronger",
            "- Provide additional rationale to explain your choice"
        ])

        return "\n".join(lines)

    def reset(self) -> None:
        """
        Reset discussion state for reuse with the same agent pool.

        ``history`` and ``tallies`` are rebound to NEW lists instead of being
        cleared in place: results returned earlier may still reference the old
        lists, and clearing them would wipe those results' traces.
        """
        self.round_num = 0
        self.prev_responses = None
        self.history = []
        self.tallies = []
        self.consensus_reached = False
        self.logger.debug("Discussion state reset")

    def _to_agent_response(self, agent: Any, raw: str, round_num: int) -> "AgentResponse":
        """
        Convert raw agent output into AgentResponse using agent.validate_and_parse().
        Invalid outputs are kept as code "NONE" with the parse error as rationale.
        """
        valid, parsed, err = agent.validate_and_parse(raw)

        if not valid or not isinstance(parsed, dict):
            code = "NONE"
            rationale = f"Parse error/invalid format: {err}"
        else:
            code = parsed.get("CAD-code", "NONE")
            rationale = parsed.get("rationale", "") or ""

        code = self._validate_code(code)

        return AgentResponse(
            agent=getattr(agent, "name", str(agent)),
            code=code,
            rationale=rationale,
            raw=raw,
            round=round_num,
        )

    def _validate_code(self, code: str) -> str:
        """Return ``code`` if it is an allowed code, else log a warning and return "NONE"."""
        # The isinstance check also guards against unhashable parsed values
        # (e.g. a list), which would raise on the set membership test.
        if not isinstance(code, str) or code not in self.allowed_codes:
            self.logger.warning(f"Invalid code received: {code}")
            return "NONE"
        return code

    def _tally_round(self, round_responses: List["AgentResponse"]) -> Dict[str, int]:
        """Count votes for a single round."""
        return dict(Counter(r.code for r in round_responses))

    ############### Discussion funcs ##################
    def discuss(self, text: str, human_code: str = "", **kwargs) -> "DiscussionResult":
        """
        Run the discussion for one text and return its result.

        Args:
            text: Text to code.
            human_code: Human reference code, stored on the result unchanged.
            **kwargs: Forwarded to every ``agent.assign_code`` call.

        Returns:
            A ``DiscussionResult``. It holds its own copies of the round
            history and tallies, so later calls do not alter it.
        """
        self.reset()  # avoid leaking history across calls
        self.logger.info("== Starting MultiAgentDiscussion with %d agents", len(self.agents))
        self.text = text
        max_retries = self.config.max_retries_per_agent or 1

        for round_idx in range(self.config.max_rounds):
            self.round_num = round_idx + 1
            self.logger.info("Round %d/%d", self.round_num, self.config.max_rounds)

            # Agents see previous round responses as context (None on round 1)
            prev_resp = self.history[-1] if self.history else None
            ctx = self._build_simple_discussion_context(prev_resp, self.round_num)

            round_responses: List["AgentResponse"] = []
            for agent in self.agents:
                raw = agent.assign_code(text, extra_context=ctx, max_retries=max_retries, **kwargs)
                round_responses.append(self._to_agent_response(agent, raw, self.round_num))

            # Save round artifacts
            self.history.append(round_responses)
            tally = self._tally_round(round_responses)
            self.tallies.append(tally)
            self.logger.debug(f"Round {self.round_num} tally: {tally}")

            # Stop early if consensus reached
            consensus_code, agreement = self.is_consensus_reached(tally, self.config.consensus_threshold)
            if consensus_code:
                self.consensus_reached = True
                rationales = [resp.rationale for resp in round_responses
                              if resp.code == consensus_code and resp.rationale]

                return DiscussionResult(
                    text_to_code=text,
                    human_code=human_code,
                    final_code=consensus_code,
                    final_rationale=" | ".join(rationales),
                    confidence=agreement,
                    history=list(self.history),
                    tallies=list(self.tallies),
                    consensus_reached=True,
                    num_rounds=self.round_num
                )

        # Max rounds reached without consensus - finalize by plurality
        final_code, final_rationale, confidence = self._finalize(self.history[-1])

        return DiscussionResult(
            text_to_code=text,
            human_code=human_code,
            final_code=final_code,
            final_rationale=final_rationale,
            confidence=confidence,
            history=list(self.history),
            tallies=list(self.tallies),
            consensus_reached=False,
            num_rounds=self.round_num
        )

    def run_batch_discussions(
        self,
        data_df: "pd.DataFrame",
        text_col: str = "transcript",
        store_traces: bool = True,
        batch_num: int = -1,
        log_every: int = 10,
        save_name: Optional[str] = None,   # base path without extension
        save_every: int = 50,
        stop_on_error: bool = False,
        html_dir: str = "discussion_html",
        **kwargs
    ) -> Tuple[List["DiscussionResult"], "pd.DataFrame"]:
        """
        Run :meth:`discuss` over the first rows of a DataFrame.

        Args:
            data_df: Input data. A "CAD" column, when present, supplies the
                human code of each row.
            text_col: Column holding the text to code.
            store_traces: Keep per-round traces in the results and output rows.
            batch_num: Number of leading rows to process; ``None`` or a
                negative value means all rows.
            log_every: Log progress every this many rows (0 disables).
            save_name: Base path (no extension) for CSV checkpoints; ``None``
                disables checkpoints.
            save_every: Write a checkpoint every this many rows.
            stop_on_error: Stop at the first row whose discussion raises.
            html_dir: Folder for the per-row HTML reports and their index.
            **kwargs: Forwarded to :meth:`discuss`.

        Returns:
            ``(results, output_df)``: one ``DiscussionResult`` per processed
            row (a "NONE" placeholder for failed rows) and a DataFrame with
            one summary row per processed input row, indexed like the input.

        Raises:
            ValueError: If ``text_col`` is not a column of ``data_df``.
        """
        if text_col not in data_df.columns:
            raise ValueError(f"Column '{text_col}' not found. Available: {list(data_df.columns)}")

        total = len(data_df)
        n = total if (batch_num is None or int(batch_num) < 0) else min(int(batch_num), total)

        # Positional selection keeps one text and one human code per row even
        # if the index has duplicate labels.
        batch_df = data_df.iloc[:n]
        idxs = list(batch_df.index)
        texts = batch_df[text_col].tolist()
        if "CAD" in batch_df.columns:
            human_codes = [str(v) for v in batch_df["CAD"].tolist()]
        else:
            human_codes = [""] * n

        results: List["DiscussionResult"] = []
        rows: List[Dict[str, Any]] = []  # for output df
        row_html_files: List[str] = []

        def to_row(idx, text, human_code_val, result: Optional["DiscussionResult"], error: str) -> Dict[str, Any]:
            """One output row (success or failure)."""
            if error:
                return {
                    "row_index": idx,
                    "text_to_code": text,
                    "human_code": human_code_val,
                    "final_code": "NONE",
                    "final_rationale": "",
                    "confidence": 0.0,
                    "consensus_reached": False,
                    "num_rounds": 0,
                    "tallies": [],
                    "round_dicts": [],
                    "error": error,
                }

            assert result is not None
            round_dicts = []
            tallies = []
            if store_traces:
                try:
                    round_dicts = result.get_round_dicts()
                except Exception as e:
                    self.logger.warning("Trace serialization failed at row %s: %s", idx, e)
                    round_dicts = []
                tallies = result.tallies

            return {
                "row_index": idx,
                "text_to_code": result.text_to_code,
                "human_code": human_code_val,
                "final_code": result.final_code,
                "final_rationale": result.final_rationale,
                "confidence": result.confidence,
                "consensus_reached": result.consensus_reached,
                "num_rounds": result.num_rounds,
                "tallies": tallies,
                "round_dicts": round_dicts,
                "error": "",
            }

        def checkpoint(reason: str) -> None:
            """Write a CSV checkpoint to ``<save_name>.csv`` (no-op if save_name is None)."""
            if not save_name:
                return
            csv_path = f"{save_name}.csv"
            parent = os.path.dirname(csv_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            out_df = pd.DataFrame(rows, index=[r["row_index"] for r in rows])
            out_df.to_csv(csv_path, index=True)
            self.logger.info("Checkpoint (%s): wrote %d rows -> %s.csv", reason, len(out_df), save_name)

        self.logger.info(
            "Batch start: %d/%d rows | text_col='%s' | traces=%s | checkpoints=%s",
            n, total, text_col, store_traces, ("on" if save_name else "off")
        )

        success = 0
        errors = 0

        batch_items = list(zip(idxs, texts, human_codes))
        for i, (idx, text, human_code) in enumerate(tqdm(batch_items, desc="Processing", total=n), start=1):
            try:
                res = self.discuss(text, human_code=human_code, **kwargs)

                if not store_traces:
                    # Rebuild the result without the per-round trace.
                    res = DiscussionResult(
                        text_to_code=res.text_to_code,
                        human_code=human_code,
                        final_code=res.final_code,
                        final_rationale=res.final_rationale,
                        confidence=res.confidence,
                        history=[],
                        tallies=[],
                        consensus_reached=res.consensus_reached,
                        num_rounds=res.num_rounds,
                    )

            except Exception as e:
                err = f"{type(e).__name__}: {str(e)[:500]}"
                self.logger.exception("Row %s failed (%d/%d): %s", idx, i, n, err)

                placeholder = DiscussionResult(
                    text_to_code=str(text),
                    human_code=human_code,
                    final_code="NONE",
                    final_rationale="",
                    confidence=0.0,
                    history=[],
                    tallies=[],
                    consensus_reached=False,
                    num_rounds=0,
                )
                results.append(placeholder)
                rows.append(to_row(idx, text, human_code, None, error=err))
                errors += 1

                if save_name:
                    checkpoint(reason=f"error_at_{idx}")

                if stop_on_error:
                    break

            else:
                # The HTML report is a convenience log: a failure to write it
                # must not turn a successfully coded row into an error row.
                try:
                    html_path = write_row_html_log(
                        out_dir=html_dir,
                        idx=idx,
                        transcript=text,
                        dr=res,
                        human_code=human_code
                    )
                    row_html_files.append(html_path)
                except Exception as log_err:
                    self.logger.warning("HTML log failed for row %s: %s", idx, log_err)

                results.append(res)
                rows.append(to_row(idx, text, human_code, res, error=""))
                success += 1

            if log_every and (i % log_every == 0 or i == n):
                rate = (success / i) * 100 if i else 0.0
                self.logger.info("Progress: %d/%d | success=%d | errors=%d | rate=%.1f%%", i, n, success, errors, rate)

            if save_every and save_name and (i % save_every == 0):
                checkpoint(reason=f"periodic_{i}")

        # Same reasoning as above: never lose the batch results to a log file.
        try:
            write_index_html(html_dir, row_html_files, title="MAD / MultiAgentDiscussion Logs")
        except OSError as log_err:
            self.logger.warning("Could not write HTML index in %s: %s", html_dir, log_err)

        if save_name:
            checkpoint(reason="final")

        output_df = pd.DataFrame(rows, index=[r["row_index"] for r in rows])
        self.logger.info("Batch done: success=%d errors=%d total=%d", success, errors, len(output_df))
        return results, output_df

## Experiment run


### Configs

In [None]:
# Role prompts handed to the agents. The variable names keep their original
# spelling because other cells reference them.
# The apostrophe in "agents’ codes" was mis-encoded (UTF-8 read with the
# wrong codec) and is restored here so the prompt contains no garbled bytes.
balenced_role = "Your job is to weigh evidence, reconcile disagreements, and enforce codebook fidelity."
adversery_role = "Rigorous prosecutor. Be skeptical. Demand direct textual evidence (quote a short phrase). Actively try to falsify other agents’ codes. If the text is ambiguous, say so and propose a safe fallback."
creative_role = "Creative empathic explorer. Look for subtle intent, context, and edge cases. Propose alternative readings and uncommon-but-plausible codes, but justify with text evidence."

In [None]:
# Discussion settings. allowed_codes is given as a frozenset to match the
# type declared on DiscussionConfig (an immutable set of valid labels).
MAD_config = DiscussionConfig(
    max_rounds=3,
    consensus_threshold=0.9,
    max_retries_per_agent=2,
    allowed_codes=frozenset({"WCT", "GT", "Other", "NONE"}),
)

# Sample utterance for quick manual checks.
text = "So remember you guys are in groups so talk to your partner about the cards you move. Make sure your partner agrees with you."

# --- Agents ---
# Three agents that share the codebook and differ in the role prompt given to each.
agents = {
    "a1": SingleAgentCoding("Ava", "balanced arbiter", balenced_role, debug=False, codebook=CAD_CODEBOOK_DICT),
    "a2": SingleAgentCoding("Ben", "rigorous and concise", adversery_role, debug=False, codebook=CAD_CODEBOOK_DICT),
    "a3": SingleAgentCoding("Cam", "creative and empathic", creative_role, debug=False, codebook=CAD_CODEBOOK_DICT)
}


mad = MultiAgentDiscussion(
    list(agents.values()),
    config=MAD_config
)



### discussion run

In [None]:
from datetime import datetime
# Date stamp (YYYY-MM-DD) used as the prefix of the output file name.
today = f"{datetime.today():%Y-%m-%d}"

# Fall back to 2 when BATCH_NUM is unset or zero.
batch_num = BATCH_NUM if BATCH_NUM else 2
output_file_name = "{}-batch_0-{}_results".format(today, batch_num)

# Keyword options for the batch run, gathered in one place for readability.
batch_kwargs = dict(
    text_col="transcript",
    batch_num=batch_num,
    log_every=10,
    save_every=20,
    stop_on_error=False,
    store_traces=True,
    save_name=output_file_name,
)
res, new_df = mad.run_batch_discussions(df_2022, **batch_kwargs)


In [None]:
# Preview the first rows of the DataFrame returned by the batch run.
new_df.head()

### Download the output files

In [None]:
import shutil
from google.colab import files
import os

# Write the results table in both formats, then request a download of each file.
parquet_path = f"{output_file_name}.parquet"
csv_path = f"{output_file_name}.csv"
new_df.to_parquet(parquet_path)
new_df.to_csv(csv_path)

for result_path in (parquet_path, csv_path):
    files.download(result_path)

# Zip the discussion_html folder so it can be downloaded as a single file.
folder_to_download = "discussion_html"
zip_filename = folder_to_download + ".zip"
shutil.make_archive(folder_to_download, 'zip', folder_to_download)

# Only request the download when the archive actually exists on disk.
if not os.path.exists(zip_filename):
    print(f"Error: {zip_filename} not created or found.")
else:
    files.download(zip_filename)
    print(f"Downloaded {zip_filename}")

### Write agent info to file

In [None]:
import html
import json
import json

# Collect each agent's info dictionary, keyed by agent id.
all_agent_info = {agent_id: agent.get_agent_info() for agent_id, agent in agents.items()}

# Save to file.
# A dedicated name is used here rather than rebinding `output_file_name`: that variable holds
# the batch-results base name, and the cell that re-reads f"{output_file_name}.parquet" would
# otherwise look for "agent_infos.parquet".
agent_info_file_name = "agent_infos"
output_file = f"{agent_info_file_name}.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(all_agent_info, f, indent=2)

print(f"Agent information saved to {output_file}")

def write_agent_info_to_html(agent_info_dict, filename="agent_infos.html"):
    """Render agent configuration dictionaries as one styled HTML page.

    Args:
        agent_info_dict: Mapping of agent id -> info dict. Each info dict is shown as a card
            whose heading uses the ``name`` entry (falling back to the agent id).
        filename: Path of the HTML file to write (UTF-8).

    Every key and value is HTML-escaped. A dict stored under ``codebook`` becomes a bullet list,
    any other dict/list is shown as pretty-printed JSON, and everything else is shown as text.
    """

    def _render_value(key, value):
        # Codebook dictionaries read better as a bullet list than as raw JSON.
        if key == 'codebook' and isinstance(value, dict):
            bullets = "".join(
                f"<li><b>{html.escape(str(k))}:</b> {html.escape(str(v))}</li>"
                for k, v in value.items()
            )
            return "<ul>" + bullets + "</ul>"
        if isinstance(value, (dict, list)):
            return f"<pre>{html.escape(json.dumps(value, indent=2))}</pre>"
        return html.escape(str(value))

    # Static page head: doctype, inline stylesheet and the page title.
    page = [
        "<!doctype html>",
        "<html><head><meta charset='utf-8'><title>Agent Information</title>",
        "<style>",
        "body{font-family:system-ui,-apple-system,sans-serif;margin:20px;line-height:1.5;color:#333;background:#f4f4f9;}",
        ".agent-card{border:1px solid #ddd;margin-bottom:20px;padding:25px;border-radius:10px;box-shadow:0 4px 6px rgba(0,0,0,0.05);background:#fff;}",
        "h1{color:#2c3e50;text-align:center;margin-bottom:30px;}",
        "h2{margin-top:0;color:#34495e;border-bottom:2px solid #f0f0f0;padding-bottom:10px;}",
        ".property{margin:10px 0;}",
        ".label{font-weight:600;color:#555;display:inline-block;width:140px;vertical-align:top;}",
        ".value{display:inline-block;width:calc(100% - 150px);}",
        "pre{background:#f8f9fa;padding:10px;border-radius:4px;overflow-x:auto;margin:0;}",
        "ul{margin:0;padding-left:20px;}",
        "</style></head><body>",
        "<h1>Agent Configuration</h1>"
    ]

    # One card per agent: heading first, then one row per property.
    for agent_id, info in agent_info_dict.items():
        display_name = info.get('name', agent_id)
        page.append("<div class='agent-card'>")
        page.append(
            f"<h2>Agent: {html.escape(str(display_name))} "
            f"<small style='color:#777;font-weight:normal'>({html.escape(str(agent_id))})</small></h2>"
        )
        page.extend(
            f"<div class='property'><span class='label'>{html.escape(str(key))}:</span>"
            f"<span class='value'>{_render_value(key, value)}</span></div>"
            for key, value in info.items()
        )
        page.append("</div>")

    page.append("</body></html>")

    with open(filename, "w", encoding="utf-8") as out:
        out.write("\n".join(page))
    print(f"Agent info HTML saved to {filename}")

# # Execute using the dictionary from the previous step
# if 'all_agent_info' in locals():
#     write_agent_info_to_html(all_agent_info)
# else:
#     # Fallback to regenerate info if variable missing
#     temp_info = {aid: ag.get_agent_info() for aid, ag in agents.items()}
#     write_agent_info_to_html(temp_info)

# Test (don't run)

In [None]:

# Scratch cell for inspecting a finished run.

# Re-read the saved parquet output and mirror it to CSV.
df = pd.read_parquet(f"{output_file_name}.parquet")
df.head()
df.to_csv(f"{output_file_name}.csv")

# Show the element types stored in the first row's round records.
# .iloc is positional; the saved index holds the original row labels, so label 0 may not exist.
for round_resp in df["round_dicts"].iloc[0]:
  print(type(round_resp))


for result in res:
  # `__dict__` is normally an attribute (a plain dict), not a method; only call it when the
  # result class really exposes it as a callable.
  result_fields = result.__dict__
  if callable(result_fields):
    result_fields = result_fields()

  # `round_info` avoids shadowing the builtin round(); the inner quotes differ from the outer
  # ones so the f-string also parses on Python versions before 3.12.
  for round_info in result_fields.get("round_dicts") or []:
    print(f"round #{round_info['round_num']}")
    for agent in round_info.get("responses"):
      print(f"agent: {agent['agent']}, coded: {agent['code']} with rationale: {agent['rationale']}")
      print(f"======raw: {agent['raw']}")

  print(result)

In [None]:
def convert_agent_response(agent_resp: str, print_raw = False) -> AgentResponse:
  """Rebuild an AgentResponse from its stringified form.

  The text is cut on the markers ``agent=``, ``code=``, ``rationale=``, ``raw=`` and ``round=``.
  NOTE(review): this presumably targets a repr such as
  ``AgentResponse(agent=..., code=..., rationale=..., raw=..., round=N)`` -- confirm against
  the data being parsed. Values are returned exactly as sliced (no quote stripping).

  Args:
    agent_resp: Stringified agent response.
    print_raw: When true, also print the extracted raw model output.

  Returns:
    AgentResponse built from the extracted fields, with ``round`` converted to int.

  Raises:
    IndexError: if an expected marker is missing.
    ValueError: if the round field is not an integer.
  """
  name = agent_resp.split("agent=")[1].split(",")[0]
  code = agent_resp.split("code=")[1].split(",")[0]
  # maxsplit=1 keeps the whole remainder even if the marker text reappears inside a later field.
  rationale = agent_resp.split("rationale=", 1)[1].split(", raw=")[0]
  # The round marker is taken from its LAST occurrence so that "round=" appearing inside the
  # rationale or raw text cannot hijack the value.
  round_num = agent_resp.rsplit("round=", 1)[1].split(")")[0]
  # Drop the ", " separator that sits between the raw field and the round field.
  raw = agent_resp.split("raw=", 1)[1].rsplit("round=", 1)[0].rstrip(", ")

  # Report the parsed round (previously a literal 1 was printed for every response).
  print(f" At round {round_num} - agent {name}, coded this text as {code}, with rational that {rationale}")
  if print_raw:
    print(f"raw agent output was: {raw}")
  agent_resp_obj = AgentResponse(agent=name, code=code, rationale=rationale, raw=raw, round=int(round_num))
  return agent_resp_obj




# text = df['transcript'][0]
# # print(text)
# agents_text = df['raw'][0]
# discussion_results = agents_text.replace("'", '"')
# # print(discussion_res)
# # convert agents_text str to list
# b = discussion_results.split("], ")
# print(b[0].split("round=")[0])
# # print(b[1])?
# # print(b[2])

# # agents_res_list = agents_text.split("[AgentResponse(")
# # agents_res_valid = []
# # for agent_res in agents_res_list:
# #   if "agent=" in agent_res:
# #     agents_res_valid.append(agent_res)

# # # agent_text
# # len(agents_res_valid)
# # for agent_text in agents_res_valid:
# #   print(agent_text)
# #   agent_resp_obj = convert_agent_response(agent_text)
#   # print(agent_resp_obj)
# # # convert str to list
# # agent_text = agent_text.replace("'", '"')
# # ares = convert_agent_response(agent_text)
# # ares

# **NEW CODE -1/20/2026**

In [None]:
#!/usr/bin/env python3
# ============================================================
# MAD (3 agents x 3 rounds) for CAD coding on MacBook
# TIGHT JSON FORCING VERSION:
# - Balanced-brace JSON extraction (most reliable)
# - Retry-on-invalid JSON (up to 3 tries per agent/round)
# - Deterministic retry (temperature=0) + stronger constraints
# - Decode ONLY new tokens (prevents prompt echo)
# - Normalizes human CAD for accurate agreement
# - Saves Excel + HTML to Desktop
# ============================================================

import os, re, json, time, gc, html
from collections import Counter
from typing import Optional, List, Dict, Any

import numpy as np
import pandas as pd

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# -------------------------
# USER SETTINGS
# -------------------------
# Folder that holds the input workbook and receives the outputs. Defaults to the current
# user's Desktop; set the MAD_BASE_DIR environment variable to use another folder without
# editing the code.
BASE_DIR    = os.environ.get("MAD_BASE_DIR", os.path.join(os.path.expanduser("~"), "Desktop"))
FILE_PATH   = os.path.join(BASE_DIR, "data -ASU.xlsx")
SHEET_NAME  = None   # None -> read_excel is called without sheet_name
YEAR_FILTER = 2022   # rows are kept only when their 'year' column equals this value
K           = 200    # number of leading rows to process

# Output names are derived from YEAR_FILTER and K so they cannot drift out of sync with them.
OUT_XLSX = os.path.join(BASE_DIR, f"mad_outputs_{YEAR_FILTER}_K{K}.xlsx")
OUT_HTML = os.path.join(BASE_DIR, f"mad_report_{YEAR_FILTER}_K{K}.html")

MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Generation controls
TEMP_ROUND1    = 0.3   # sampling temperature for the independent first round
TEMP_DISCUSS   = 0.0   # 0.0 -> greedy decoding in the revision rounds
# NOTE: this rebinds the MAX_NEW_TOKENS defined in the notebook's earlier "Configs" cell.
MAX_NEW_TOKENS = 160

# JSON retry controls
JSON_RETRIES = 3  # attempts per agent per round

# -------------------------
# CAD CODEBOOK
# -------------------------
CAD_CODEBOOK_DICT = {
    "WCT":   "Teacher is addressing the whole class.",
    "GT":    "Teacher is addressing a group or a student in a group (student-level talk).",
    "Other": "Teacher is not addressing the whole class or groups/students (silent/self talk/visitor/tech).",
    "NONE":  "Ambiguous / cannot determine."
}
# Single source of truth: the allowed codes are exactly the codebook keys.
ALLOWED_CODES = set(CAD_CODEBOOK_DICT)

# -------------------------
# DEVICE
# -------------------------
# Preference order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("‚úÖ Device:", device)

# -------------------------
# LOAD DATA
# -------------------------
if not os.path.exists(FILE_PATH):
    raise FileNotFoundError(f"Excel not found: {FILE_PATH}")

# Forward sheet_name only when one was configured; otherwise rely on read_excel's default.
excel_kwargs = {} if SHEET_NAME is None else {"sheet_name": SHEET_NAME}
df = pd.read_excel(FILE_PATH, **excel_kwargs)
print("‚úÖ Loaded rows:", len(df))
print("‚úÖ Columns:", df.columns.tolist())

if "transcript" not in df.columns:
    raise ValueError("‚ùå Your Excel must have a column named exactly: transcript")

# Work on a copy so the loaded frame stays untouched; apply the year filter when possible.
if "year" not in df.columns:
    df_work = df.copy()
    print("‚ö†Ô∏è No 'year' column found; skipping year filter.")
else:
    df_work = df[df["year"] == YEAR_FILTER].copy()
    print(f"‚úÖ After year=={YEAR_FILTER} filter:", len(df_work))

# Missing transcripts become empty strings; surrounding whitespace is removed.
transcripts_filled = df_work["transcript"].fillna("")
df_work["transcript"] = transcripts_filled.astype(str).str.strip()

# -------------------------
# NORMALIZE HUMAN CAD
# -------------------------
def normalize_human_code(x: Any) -> str:
    """Map a human-entered CAD label onto the canonical code spelling.

    Matching is case-insensitive and happens in two passes: exact aliases first, then
    substring / word matches for messy labels. Within each pass the order is
    WCT, GT, Other, NONE.

    Returns:
        "WCT", "GT", "Other" or "NONE" on a match; "" for missing or blank input;
        otherwise the stripped input unchanged.
    """
    label = "" if pd.isna(x) else str(x)
    label = label.strip()
    if label == "":
        return ""
    lowered = label.lower()

    # Pass 1: exact aliases.
    exact_aliases = (
        ("WCT", {"wct", "whole class", "whole-class", "wholeclass"}),
        ("GT", {"gt", "group", "small group", "partner"}),
        ("Other", {"other", "oth"}),
        ("NONE", {"none", "na", "n/a", "ambiguous", "unclear"}),
    )
    for canonical, aliases in exact_aliases:
        if lowered in aliases:
            return canonical

    # Pass 2: containment. "gt" must be a standalone word so it cannot match inside other words.
    fuzzy_rules = (
        ("WCT", None, ("wct", "whole")),
        ("GT", r"\bgt\b", ("group", "partner")),
        ("Other", None, ("other",)),
        ("NONE", None, ("none", "ambig", "unclear")),
    )
    for canonical, word_pattern, fragments in fuzzy_rules:
        if word_pattern is not None and re.search(word_pattern, lowered):
            return canonical
        if any(fragment in lowered for fragment in fragments):
            return canonical

    return label

# Normalise the human CAD labels when the column exists; otherwise use an empty placeholder
# column so later code can always read "CAD_norm".
has_human_labels = "CAD" in df_work.columns
if has_human_labels:
    cad_as_text = df_work["CAD"].fillna("").astype(str)
    df_work["CAD"] = cad_as_text
    df_work["CAD_norm"] = cad_as_text.apply(normalize_human_code)
else:
    df_work["CAD_norm"] = ""

# Keep only the first K rows for this run.
df_work = df_work.head(K).copy()
print("‚úÖ Running K =", len(df_work))

# -------------------------
# LOAD MODEL
# -------------------------
print("‚è≥ Loading model:", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# float16 on cuda/mps, float32 on CPU.
use_half_precision = device in {"cuda", "mps"}
torch_dtype = torch.float16 if use_half_precision else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch_dtype, device_map=None)
model = model.to(device)

# Switch to evaluation mode.
model.eval()
print("‚úÖ Model loaded.")

# -------------------------
# JSON EXTRACTION (brace-balanced)
# -------------------------
def _balanced_object_end(s: str, start: int) -> Optional[int]:
    """Return the index of the '}' that closes the '{' at ``start``, or None if it never closes.

    Braces inside double-quoted strings are ignored, and backslash escapes inside strings are
    honoured so an escaped quote does not end the string.
    """
    depth = 0
    in_str = False
    escape = False

    for i in range(start, len(s)):
        ch = s[i]

        if in_str:
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_str = False
            continue

        # not in string
        if ch == '"':
            in_str = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return i

    return None


def extract_first_balanced_json(text: str) -> Optional[dict]:
    """
    Return the first balanced {...} span in ``text`` that parses as a JSON object.

    Leading/trailing markdown code fences are removed first. Candidates are located by
    counting braces (more reliable than regex for messy outputs). If a candidate is not valid
    JSON the search resumes after it, and if an opening brace is never closed the search
    resumes at the next '{', so stray braces in free text before the real answer no longer
    cause a miss.

    Returns:
        The parsed dict, or None when ``text`` is empty or contains no parseable JSON object.
    """
    if not text:
        return None

    s = text.strip()
    # remove fences if any
    s = re.sub(r"^```(?:json)?\s*", "", s).strip()
    s = re.sub(r"\s*```$", "", s).strip()

    start = s.find("{")
    while start != -1:
        end = _balanced_object_end(s, start)
        if end is None:
            # This '{' never closes; a later one may still begin a complete object.
            start = s.find("{", start + 1)
            continue

        try:
            obj = json.loads(s[start:end + 1].strip())
        except (ValueError, RecursionError):
            # Balanced but not JSON (e.g. "{like this}" in free text).
            obj = None
        if isinstance(obj, dict):
            return obj

        # Skip the whole rejected span, including anything nested inside it.
        start = s.find("{", end + 1)

    return None

def normalize_model_json(d: dict, allowed_codes: Optional[set] = None) -> dict:
    """Coerce a parsed model answer into the canonical {"CAD-code", "rationale"} shape.

    Args:
        d: Parsed model output. The code is read from "CAD-code", "code" or "CAD" (first
            non-empty wins); the rationale from "rationale" or "reasoning".
        allowed_codes: Codes accepted as valid. Defaults to the module-level ALLOWED_CODES.

    Returns:
        Dict with "CAD-code" (an allowed code, or "NONE" when unrecognised) and "rationale"
        (a fixed notice when missing or when it merely echoes the schema hint; capped at
        350 characters plus "...").

    Model output is untrusted: non-string values are converted to text instead of raising,
    and a code that differs from an allowed one only by letter case (e.g. "gt", "OTHER") is
    mapped to the canonical spelling rather than being discarded as "NONE".
    """
    allowed = ALLOWED_CODES if allowed_codes is None else allowed_codes

    code = str(d.get("CAD-code") or d.get("code") or d.get("CAD") or "").strip()
    rationale = str(d.get("rationale") or d.get("reasoning") or "").strip()

    if code not in allowed:
        # Case-insensitive rescue before giving up on the code.
        by_lowercase = {str(c).lower(): c for c in allowed}
        code = by_lowercase.get(code.lower(), "NONE")

    # Treat an echo of the schema hint ("1-3 short evidence-based sentences") as no rationale.
    low = rationale.lower()
    if (not rationale) or ("evidence-based" in low and "sentences" in low):
        rationale = "Rationale missing/placeholder."

    if len(rationale) > 350:
        rationale = rationale[:350] + "..."

    return {"CAD-code": code, "rationale": rationale}

# -------------------------
# PROMPTS
# -------------------------
def agent_system(name: str, personality: str) -> str:
    """Build the system prompt for one coding agent.

    The prompt introduces the agent by name and personality, states the coding task, demands
    JSON-only output with the expected schema, and lists one-line heuristics for each code.
    Every line, including the last, ends with a newline.
    """
    prompt_lines = [
        f"You are {name}, a {personality} qualitative-coding agent.",
        "Task: Assign ONE CAD code for the teacher transcript using the codebook.",
        "STRICT OUTPUT RULE:",
        "- Output ONLY a JSON object. No extra text. No markdown. No thinking.",
        'JSON schema: {"CAD-code":"WCT|GT|Other|NONE","rationale":"1-3 short evidence-based sentences"}',
        "Heuristics:",
        "- WCT: whole-class address/instructions.",
        "- GT: talk to a small group/partner(s).",
        "- Other: teacher tech/self/visitor talk not directing students.",
        "- NONE: only if truly impossible to infer.",
    ]
    return "".join(line + "\n" for line in prompt_lines)

def agent_user(transcript: str, extra_context: str) -> str:
    """Build the user message: codebook, round context, the text to code, and a JSON reminder.

    The four sections are separated by blank lines; the transcript is wrapped in triple
    double-quotes.
    """
    codebook_lines = []
    for code, meaning in CAD_CODEBOOK_DICT.items():
        codebook_lines.append(f"- {code}: {meaning}")

    sections = [
        "Codebook:\n" + "\n".join(codebook_lines),
        f"{extra_context}",
        f'Text to code:\n"""{transcript}"""',
        "Return ONLY JSON now.",
    ]
    return "\n\n".join(sections)

# -------------------------
# GENERATION (decode only NEW tokens)
# -------------------------
@torch.no_grad()
def generate_once(system_prompt: str, user_prompt: str, temperature: float, max_new_tokens: int) -> str:
    """Run a single generation and return only the newly generated text.

    Args:
        system_prompt: System message content.
        user_prompt: User message content.
        temperature: Values > 0 enable sampling at that temperature; otherwise decoding is greedy.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens skipped), stripped.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    messages = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}]

    uses_chat_template = hasattr(tokenizer, "apply_chat_template")
    if uses_chat_template:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        prompt = system_prompt + "\n\n" + user_prompt

    # A chat-templated prompt already carries the special tokens it needs; adding them again
    # while tokenizing would duplicate them (e.g. a second BOS). Plain prompts still get them.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=not uses_chat_template,
    ).to(device)

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=(temperature > 0.0),
    )
    # Temperature is only passed when sampling.
    if temperature > 0.0:
        gen_kwargs["temperature"] = float(temperature)

    out = model.generate(**inputs, **gen_kwargs)

    # Slice off the prompt tokens so the returned text cannot echo the prompt.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return text

def generate_json_with_retries(system_prompt: str, user_prompt: str,
                              temperature: float, max_new_tokens: int,
                              retries: int = JSON_RETRIES) -> Dict[str, Any]:
    """
    Try up to ``retries`` times to obtain a parseable JSON object from the model.

    - attempt 1: the given temperature and the prompt as-is
    - attempt 2+: temperature=0 with a stronger JSON-only instruction appended

    Greedy decoding of an identical prompt is expected to reproduce the same output, so a
    deterministic attempt whose prompt was already tried deterministically is skipped rather
    than regenerated.

    Returns:
        {"raw": ..., "parsed": dict | None}. On success "raw" is the output trimmed to start at
        its first '{'; on failure "raw" is the last untrimmed output and "parsed" is None.
    """
    forcing_suffix = (
        "\n\nCRITICAL: Your entire output MUST be ONLY the JSON object."
        "\nIt MUST BEGIN with '{' and END with '}'. No other characters."
    )
    last_raw = ""
    greedy_prompts_tried = set()

    for attempt in range(1, retries + 1):
        if attempt == 1:
            temp = temperature
            user = user_prompt
        else:
            temp = 0.0
            # strong forcing on retries
            user = user_prompt + forcing_suffix

        # generate_once samples only when temperature > 0, so anything else is deterministic.
        if not temp > 0.0:
            if user in greedy_prompts_tried:
                break  # same prompt, same greedy decode: nothing new to gain
            greedy_prompts_tried.add(user)

        raw = generate_once(system_prompt, user, temperature=temp, max_new_tokens=max_new_tokens)
        last_raw = raw

        # quick trim: if there's a '{', start there (helps)
        if "{" in raw:
            raw_trim = raw[raw.find("{"):].strip()
        else:
            raw_trim = raw.strip()

        parsed = extract_first_balanced_json(raw_trim)
        # An empty dict counts as a failed attempt and triggers a retry.
        if parsed and isinstance(parsed, dict):
            return {"raw": raw_trim, "parsed": parsed}

    return {"raw": last_raw, "parsed": None}

# -------------------------
# MAD
# -------------------------
# (name, personality) pairs for the debating agents. Both values are passed to agent_system()
# to build each agent's system prompt; the name also labels the agent's answers.
AGENTS = [
    ("Ava", "balanced (neutral perspective)"),
    ("Ben", "rigorous and concise (strict evidence grounding)"),
    ("Cam", "creative and empathic (attentive to contextual nuance)")
]

def summarize_peers(prev: List[dict]) -> str:
    """Format the other agents' previous-round answers as context for the next round.

    Each peer contributes one "- Code: ... | Rationale: ..." line between a fixed header and a
    fixed closing instruction. Peers are listed without their names.
    """
    header = "Previous round peer responses:"
    closing = "Revise if needed, but remain evidence-based. Output ONLY JSON."
    peer_lines = [
        f'- Code: {peer["CAD-code"]} | Rationale: {peer["rationale"]}'
        for peer in prev
    ]
    return "\n".join([header, *peer_lines, closing])

def run_mad_for_text(text: str, rounds_total: int = 3) -> dict:
    """Run the multi-agent discussion for one transcript.

    Round 1 queries every agent independently at TEMP_ROUND1. Each later round (2..rounds_total)
    shows every agent the other agents' previous answers and queries again at TEMP_DISCUSS.
    Round 1 always runs, even when rounds_total is below 1.

    Returns:
        {"history": list of rounds (each a list of per-agent dicts with "CAD-code",
                    "rationale", "agent" and "raw"),
         "majority": most common final-round code ("NONE" if there is none),
         "conf": share of final-round answers that carry the majority code}
    """

    def _ask(agent_name: str, personality: str, context: str, temperature: float) -> dict:
        # One agent query: build prompts, generate, normalise, then tag with agent and raw text.
        got = generate_json_with_retries(
            agent_system(agent_name, personality),
            agent_user(text, extra_context=context),
            temperature=temperature,
            max_new_tokens=MAX_NEW_TOKENS,
        )
        answer = normalize_model_json(got["parsed"] or {})
        answer["agent"] = agent_name
        answer["raw"] = got["raw"]
        return answer

    # Round 1: independent assessments.
    history: List[List[dict]] = [
        [_ask(name, pers, "Round 1: Independent assessment.", TEMP_ROUND1) for (name, pers) in AGENTS]
    ]

    # Rounds 2..N: each agent revises after seeing everyone else's previous answer.
    for r in range(2, rounds_total + 1):
        previous = history[-1]
        revised = []
        for i, (name, pers) in enumerate(AGENTS):
            peers = previous[:i] + previous[i + 1:]
            ctx = f"Round {r}: Revise after seeing peers.\n" + summarize_peers(peers)
            revised.append(_ask(name, pers, ctx, TEMP_DISCUSS))
        history.append(revised)

    # Majority vote over the final round.
    final_codes = [answer["CAD-code"] for answer in history[-1]]
    tally = Counter(code for code in final_codes if code in ALLOWED_CODES)
    if tally:
        label, freq = tally.most_common(1)[0]
    else:
        label, freq = "NONE", 0
    conf = freq / len(final_codes) if final_codes else 0.0

    return {"history": history, "majority": label, "conf": conf}

# -------------------------
# RUN + SAVE
# -------------------------
# Accumulators: one output record and one HTML report fragment per processed row.
rows_out = []
html_blocks = []

# Wall-clock reference for the progress and completion messages.
start = time.time()

# n is a 1-based progress counter; idx is the row's index label in df_work.
for n, (idx, row) in enumerate(df_work.iterrows(), start=1):
    transcript = row["transcript"]
    human = row["CAD_norm"] if "CAD_norm" in df_work.columns else ""

    # Run the 3-round discussion for this transcript.
    res = run_mad_for_text(transcript, rounds_total=3)
    hist = res["history"]
    maj = res["majority"]
    conf = res["conf"]

    # Agreement is only scored when the human label is one of the allowed codes;
    # a missing or unrecognised human label yields 0.
    agree = int(human == maj) if human in ALLOWED_CODES else 0

    # Start from the original row's columns and append the MAD results.
    outrow = dict(row)
    outrow["mad_final_code"] = maj
    outrow["mad_final_conf"] = float(conf)
    outrow["agree_with_human_CAD"] = agree

    # Flatten per-agent per-round columns
    # (column names look like R<round>_<agent>_code / _rationale / _raw)
    for r_i, round_list in enumerate(hist, start=1):
        for a in round_list:
            agent = a["agent"]
            outrow[f"R{r_i}_{agent}_code"] = a["CAD-code"]
            outrow[f"R{r_i}_{agent}_rationale"] = a["rationale"]
            outrow[f"R{r_i}_{agent}_raw"] = a["raw"]

    # Final rationale = rationales of majority voters in final round
    final_round = hist[-1]
    rats = [a["rationale"] for a in final_round if a["CAD-code"] == maj and a["rationale"]]
    outrow["mad_final_rationale"] = " | ".join(rats) if rats else ""

    rows_out.append(outrow)

    # HTML report block
    # All dynamic text is escaped before being placed into markup.
    badge = "‚úÖ" if agree == 1 else "‚ùå"
    esc_t = html.escape(transcript)
    esc_h = html.escape(human)
    esc_m = html.escape(maj)

    block = []
    block.append(f"<h2>{badge} Row {idx} | Final: {esc_m} (conf={conf:.2f}) | Human: {esc_h} | Agree={agree}</h2>")
    block.append(f"<p><b>Transcript:</b> {esc_t}</p>")

    # One table per round, one table row per agent; the raw output sits in a collapsible <details>.
    for r_i, round_list in enumerate(hist, start=1):
        block.append(f"<h3>Round {r_i}</h3>")
        block.append("<table border='1' cellpadding='6' cellspacing='0' style='border-collapse:collapse;width:100%'>")
        block.append("<tr><th>Agent</th><th>Code</th><th>Rationale</th><th>Raw</th></tr>")
        for a in round_list:
            raw = html.escape(a["raw"] or "")
            rat = html.escape(a["rationale"] or "")
            code = html.escape(a["CAD-code"] or "")
            agent = html.escape(a["agent"] or "")
            block.append(
                f"<tr><td><b>{agent}</b></td><td><b>{code}</b></td><td>{rat}</td>"
                f"<td><details><summary>show</summary><pre style='white-space:pre-wrap'>{raw}</pre></details></td></tr>"
            )
        block.append("</table>")

    html_blocks.append("\n".join(block))

    # cleanup
    # The accelerator cache is emptied only when running on CUDA.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # Progress message every 10 rows.
    if n % 10 == 0:
        elapsed = time.time() - start
        print(f"Progress: {n}/{len(df_work)} rows | {elapsed/60:.1f} min elapsed")

# Total runtime for the whole batch.
elapsed = time.time() - start
print(f"‚úÖ Done. Processed {len(rows_out)} rows in {elapsed/60:.2f} min")

# One spreadsheet row per processed transcript.
out_df = pd.DataFrame(rows_out)
out_df.to_excel(OUT_XLSX, index=False)

# Assemble the whole report first, then write it in one go; row blocks are separated by <hr>.
report_parts = [
    "<html><head><meta charset='utf-8'><title>MAD Report</title></head><body>",
    f"<h1>MAD Results | YEAR={YEAR_FILTER} | K={len(df_work)} | Model={MODEL_ID}</h1>",
    "<p>Report includes: transcript, human CAD, final MAD code, and all agents/rounds with raw.</p>",
    "<hr>",
    "\n<hr>\n".join(html_blocks),
    "</body></html>",
]
with open(OUT_HTML, "w", encoding="utf-8") as f:
    f.write("".join(report_parts))

# Tell the user where the outputs are and how to open the report from a terminal.
print("‚úÖ Saved Excel:", OUT_XLSX)
print("‚úÖ Saved HTML :", OUT_HTML)
print("Open HTML with:")
print(f"open '{OUT_HTML}'")
