In [None]:
!pip install flash-attn --no-build-isolation
!pip install flash-linear-attention

Collecting flash-attn
  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/8.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━[0m [32m5.8/8.4 MB[0m [31m173.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m8.4/8.4 MB[0m [31m189.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m116.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl size=256040057 sha256=f25da18657a87fc83dc1bfb8b7751b82246e9db355510226b674fd437c34b5fb
  Stored in directory: /root/.cache/pip/wheels/3d/59/46/f282c12c73dd4bb3c2e3fe199f1

In [None]:
import re
import torch
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from datasets import load_dataset

from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast

import fla
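# NOTE(review): `mesa_net` is not referenced by name anywhere below; presumably
# it is imported for its side effect of registering the MesaNet architecture so
# that AutoModelForCausalLM can load the checkpoint — confirm before removing.
from fla.models import mesa_net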

# --- Configuration ---
# Identifier passed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = "ChavyvAkvar/mesanet-baseline-1-epoch"
# Identifier passed to `load_dataset`; main() evaluates its "test" split.
DATASET_ID = "kreasof-ai/ECA-Zero"
# Number of dataset rows tokenized and generated together per iteration.
BATCH_SIZE = 128
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Wolfram class -> list of ECA rule numbers.
# Copied from the dataset generation script.
WOLFRAM_CLASSES_MAP = {
    1: [0, 8, 32, 40, 128, 136, 160, 168],
    2: [1, 19, 23, 29, 37, 50, 108, 178],
    3: [30, 45, 60, 90, 105, 126, 150],
    4: [54, 106, 110, 124, 137, 147, 193]
}

# Reverse index (rule number -> Wolfram class) so a rule resolves to its class
# with a single dictionary lookup.
RULE_TO_CLASS = {
    rule_number: wolfram_class
    for wolfram_class, rule_numbers in WOLFRAM_CLASSES_MAP.items()
    for rule_number in rule_numbers
}

class ECAVerifier:
    """Verify model answers for elementary cellular automaton (ECA) tasks.

    A prompt is scanned for the fields ``Rule: <int>``, ``Start: <bits>``,
    ``End: <bits>``, ``Steps: <int>`` and ``Hint: Class <digit>``. Automaton
    states are lists of 0/1 integers updated with wrap-around neighbours.
    """

    def __init__(self):
        # Prompt fields.
        self.re_rule = re.compile(r"Rule: (\d+)")
        self.re_start = re.compile(r"Start: ([01]+)")
        self.re_end = re.compile(r"End: ([01]+)")
        self.re_steps = re.compile(r"Steps: (\d+)")
        self.re_hint_class = re.compile(r"Hint: Class (\d)")
        # One truth-table entry in a model answer, e.g. "110->1".
        self.re_tt = re.compile(r"([01]{3})->([01])")

    @staticmethod
    def _binary_digits(text):
        """Return every '0'/'1' character of *text* as an int, in order."""
        return [int(ch) for ch in text if ch in '01']

    def _field_bits(self, pattern, prompt):
        """Return the bit list captured by *pattern* in *prompt*, or None."""
        found = pattern.search(prompt)
        if found is None:
            return None
        return self._binary_digits(found.group(1))

    def get_wolfram_class(self, prompt):
        """Return the Wolfram class for *prompt*, or 0 when it is unknown.

        An explicit ``Hint: Class N`` takes precedence; otherwise the rule
        number in the prompt is looked up in ``RULE_TO_CLASS``.
        """
        hint = self.re_hint_class.search(prompt)
        if hint is not None:
            return int(hint.group(1))

        rule_field = self.re_rule.search(prompt)
        if rule_field is None:
            return 0
        # Rules missing from the table are reported as 0 (unknown/other).
        return RULE_TO_CLASS.get(int(rule_field.group(1)), 0)

    def get_next_state(self, state, rule):
        """Apply *rule* once to *state* and return the new list of cells."""
        width = len(state)
        successor = []
        for idx in range(width):
            left = state[(idx - 1) % width]
            centre = state[idx]
            right = state[(idx + 1) % width]
            # The 3-cell neighbourhood selects which bit of the rule to read.
            neighbourhood = (left << 2) | (centre << 1) | right
            successor.append(1 if rule & (1 << neighbourhood) else 0)
        return successor

    def simulate(self, start_state, rule, steps):
        """Return the state reached after applying *rule* *steps* times."""
        cells = list(start_state)
        for _ in range(steps):
            cells = self.get_next_state(cells, rule)
        return cells

    def parse_rule_string(self, text):
        """Build a rule number from ``abc->d`` entries found in *text*.

        Returns None when *text* contains no entry. Patterns mapped to '1'
        set the corresponding bit; everything else leaves it cleared.
        """
        entries = self.re_tt.findall(text)
        if len(entries) == 0:
            return None
        # A set removes repeated patterns, so summing the distinct powers of
        # two equals OR-ing the bits together.
        live_patterns = {int(bits, 2) for bits, output in entries if output == '1'}
        return sum(1 << pattern for pattern in live_patterns)

    def verify(self, task_type, prompt, model_output_str):
        """Return True when *model_output_str* solves the task in *prompt*.

        * ``deduction``: the answer's bits must equal the simulated end state.
        * ``induction``: the answer's truth table must map Start to End.
        * ``abduction``: the answer's bits, simulated forward, must reach End.

        Returns False for a prompt without ``Steps:``, for an unknown
        *task_type*, and whenever evaluating the answer raises.
        """
        try:
            steps = int(self.re_steps.search(prompt).group(1))
            start_state = self._field_bits(self.re_start, prompt)
            end_state = self._field_bits(self.re_end, prompt)
            rule_field = self.re_rule.search(prompt)
            rule = None if rule_field is None else int(rule_field.group(1))
        except AttributeError:
            # No "Steps:" field: search() returned None.
            return False

        answer = model_output_str.strip()
        try:
            if task_type == 'deduction':
                predicted_end = self._binary_digits(answer)
                if len(predicted_end) == 0:
                    return False
                return predicted_end == self.simulate(start_state, rule, steps)

            if task_type == 'induction':
                predicted_rule = self.parse_rule_string(answer)
                if predicted_rule is None:
                    return False
                return self.simulate(start_state, predicted_rule, steps) == end_state

            if task_type == 'abduction':
                predicted_start = self._binary_digits(answer)
                if not predicted_start:
                    return False
                if len(predicted_start) != len(end_state):
                    return False
                return self.simulate(predicted_start, rule, steps) == end_state
        except Exception:
            # Fields missing from the prompt surface here (e.g. a None state)
            # and count as a wrong answer.
            return False
        return False

def main():
    """Evaluate the model on the ECA test split and print accuracy per class.

    Loads the tokenizer and model named by ``MODEL_ID``, generates greedily
    for every row of the ``DATASET_ID`` test split in batches of
    ``BATCH_SIZE``, checks each answer with ``ECAVerifier`` and prints an
    accuracy table broken down by task type and Wolfram class.
    """
    print(f"Loading tokenizer from {MODEL_ID}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)
    except Exception:
        # Fall back to the automatically resolved tokenizer class. A bare
        # `except:` would also have swallowed KeyboardInterrupt/SystemExit.
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Batched generation with a decoder-only model must be left-padded so each
    # prompt ends at the last position; transformers warns about right padding
    # on every batch otherwise, and the continuations are unreliable.
    tokenizer.padding_side = "left"

    print(f"Loading model from {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
    )
    print("Compiling the model")
    model = torch.compile(model)
    model.eval()

    print("Loading Test Set...")
    dataset = load_dataset(DATASET_ID, split="test")
    verifier = ECAVerifier()

    # Storage: results[task][class_id] = [True, False, ...]
    results = defaultdict(lambda: defaultdict(list))

    # A missing BOS/EOS/PAD token is None; guard so the literal text "None"
    # never reaches a prompt and `.replace` is never called with None.
    bos_prefix = tokenizer.bos_token or ""
    tokens_to_strip = [
        tok for tok in (tokenizer.eos_token, tokenizer.pad_token) if tok
    ]

    print("Starting Stratified Evaluation...")

    for i in tqdm(range(0, len(dataset), BATCH_SIZE)):
        batch = dataset[i : i + BATCH_SIZE]
        tasks = batch['task']
        inputs = batch['input']

        # NOTE(review): BOS is prepended manually; confirm the tokenizer does
        # not also add one, which would duplicate it.
        prompts = [f"{bos_prefix}{inp}\n<think>\n" for inp in inputs]

        # token_type_ids are not requested, so they are never part of the
        # encoding; the attention mask is requested explicitly because it is
        # passed to generate() below.
        encodings = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            return_token_type_ids=False,
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=encodings['input_ids'],
                # Mark padding explicitly: the pad token may be the EOS token,
                # in which case the mask cannot be inferred from the ids.
                attention_mask=encodings['attention_mask'],
                max_new_tokens=2048,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, so the answer is never
        # extracted from prompt text or leading padding.
        prompt_length = encodings['input_ids'].shape[1]
        decoded_outputs = tokenizer.batch_decode(
            generated_ids[:, prompt_length:], skip_special_tokens=False
        )

        for task, inp, raw_output in zip(tasks, inputs, decoded_outputs):
            if "</think>" in raw_output:
                # The answer is whatever follows the last closing think tag.
                final_answer = raw_output.split("</think>")[-1]
                for tok in tokens_to_strip:
                    final_answer = final_answer.replace(tok, "")
                final_answer = final_answer.strip()
            else:
                # Reasoning never terminated: treated as no answer.
                final_answer = ""

            # Determine Class
            w_class = verifier.get_wolfram_class(inp)

            # Verify
            is_correct = verifier.verify(task, inp, final_answer)

            # Store per class and in the per-task aggregate
            results[task][w_class].append(is_correct)
            results[task]["ALL"].append(is_correct)

    # --- Print Report ---
    print("\n" + "="*60)
    print("STRATIFIED RESULTS (Accuracy by Wolfram Class)")
    print("="*60)

    # Define column headers
    print(f"{'Task':<12} | {'Class 1':<10} | {'Class 2':<10} | {'Class 3':<10} | {'Class 4':<10} | {'OVERALL':<10}")
    print("-" * 75)

    for task in ["deduction", "induction", "abduction"]:
        row_str = f"{task.capitalize():<12} | "

        for c in [1, 2, 3, 4]:
            outcomes = results[task][c]
            if outcomes:
                acc = sum(outcomes) / len(outcomes)
                row_str += f"{acc:.1%} ({len(outcomes):<3}) | " # concise
            else:
                row_str += "N/A        | "

        # Overall (includes samples whose class resolved to 0/unknown)
        all_outcomes = results[task]["ALL"]
        if all_outcomes:
            total_acc = sum(all_outcomes) / len(all_outcomes)
            row_str += f"{total_acc:.1%} ({len(all_outcomes)})"

        print(row_str)

    print("="*60)
    print("Class Legend:")
    print("1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)")

# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Loading tokenizer from ChavyvAkvar/mesanet-baseline-1-epoch...
Loading model from ChavyvAkvar/mesanet-baseline-1-epoch...


`torch_dtype` is deprecated! Use `dtype` instead!


Compiling the model
Loading Test Set...
Starting Stratified Evaluation...


  0%|          | 0/27 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
This is a friendly reminder - the current text generation call has exceeded the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
  4%|▎         | 1/27 [05:17<2:17:34, 317.48s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  7%|▋         | 2/27 [10:22<2:09:07, 309.88s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
 11%|█         | 3/27 [15:26<2:02:56, 307.37s/it]A decoder-only architecture is being used, but right-padding was detected! Fo


STRATIFIED RESULTS (Accuracy by Wolfram Class)
Task         | Class 1    | Class 2    | Class 3    | Class 4    | OVERALL   
---------------------------------------------------------------------------
Deduction    | 14.2% (113) | 11.5% (226) | 11.9% (412) | 12.4% (410) | 12.2% (1161)
Induction    | 0.0% (113) | 1.8% (227) | 1.0% (414) | 0.2% (411) | 0.8% (1165)
Abduction    | 0.0% (47 ) | 0.0% (185) | 0.0% (388) | 0.0% (387) | 0.0% (1007)
Class Legend:
1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)



