# In [None]:
import os, json, zipfile
from collections import deque, Counter
import numpy as np
from tqdm import tqdm
from rich import print as print_rich
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

# Root directory for all file access in this script: task JSON files are read
# from f"{DIR}/input/..." and the submission is written under f"{DIR}/working".
DIR = "/kaggle"

# ---------------- PRIMITIVES ----------------

def rot90_np(g):
    """Return ``g`` rotated by one quarter turn via ``np.rot90`` (k=1)."""
    return np.rot90(g, k=1)
def rot180_np(g):
    """Return ``g`` rotated by two quarter turns via ``np.rot90`` (k=2)."""
    return np.rot90(g, k=2)
def rot270_np(g):
    """Return ``g`` rotated by three quarter turns via ``np.rot90`` (k=3)."""
    return np.rot90(g, k=3)
def fliph_np(g):
    """Return ``g`` mirrored left-to-right (column order reversed)."""
    mirrored = np.fliplr(g)
    return mirrored
def flipv_np(g):
    """Return ``g`` mirrored top-to-bottom (row order reversed)."""
    mirrored = np.flipud(g)
    return mirrored
def invert_colors_np(g):
    """Return ``9 - g``, i.e. map every cell value ``c`` to ``9 - c``."""
    inverted = 9 - g
    return inverted
def identity_np(g):
    """Return ``g`` itself, unchanged and uncopied (the no-op primitive)."""
    return g
def remap_colors_np(g):
    """Relabel the values of ``g`` with their rank among its distinct values.

    The smallest value present becomes 0, the next smallest 1, and so on.
    The result has the same shape as ``g``.

    Uses ``np.unique(..., return_inverse=True)`` instead of a per-element
    ``np.vectorize`` dictionary lookup: it is a single vectorised call and,
    unlike ``np.vectorize`` without ``otypes``, it also accepts size-0 input.
    """
    # ``inverse`` holds, for every cell, the index of its value in the sorted
    # array of unique values, which is exactly the rank we want.
    _, inverse = np.unique(g, return_inverse=True)
    # Reshape explicitly: the shape of ``inverse`` for N-d input differs
    # between NumPy versions.
    return inverse.reshape(np.shape(g))

def majority_fill_np(g):
    """Return a copy of ``g`` whose zero cells hold the most frequent value.

    The most frequent value is taken over all cells of ``g`` (zeros included);
    on a tie the smallest value wins because ``np.unique`` returns sorted
    values and ``argmax`` picks the first maximum. ``g`` is not modified.
    """
    colours, frequencies = np.unique(g, return_counts=True)
    dominant = colours[frequencies.argmax()]
    filled = g.copy()
    zero_mask = filled == 0
    filled[zero_mask] = dominant
    return filled

# Pure-Python (list-of-lists) source for each primitive, keyed by the name of
# the corresponding Transform. Each value is the body of a function of one
# grid ``g`` and must compute the same result as the matching ``*_np``
# primitive, because the search runs on the numpy versions while the emitted
# submission runs these snippets.
primitive_code_snippets = {
    "identity": """return g""",
    # np.rot90(g, 1) is a counterclockwise quarter turn:
    # transpose, then reverse the row order.
    "rot90": """return [list(row) for row in zip(*g)][::-1]""",
    "rot180": """return [row[::-1] for row in g[::-1]]""",
    # np.rot90(g, 3) is a clockwise quarter turn:
    # reverse the row order, then transpose.
    "rot270": """return [list(row) for row in zip(*g[::-1])]""",
    "fliph": """return [row[::-1] for row in g]""",
    "flipv": """return g[::-1]""",
    "invert_colors": """return [[9 - c for c in row] for row in g]""",
    "remap_colors": """uniq = sorted(set(c for row in g for c in row)); mp = {c:i for i,c in enumerate(uniq)}; return [[mp[c] for c in row] for row in g]""",
    # sorted(...) makes ties resolve to the smallest value (max returns the
    # first maximum), matching np.unique + np.argmax; bare set iteration order
    # is not ascending in general.
    "majority_fill": """flat = [c for row in g for c in row]; fill_val = max(sorted(set(flat)), key=flat.count); return [[fill_val if c==0 else c for c in row] for row in g]"""
}

# ---------------- TRANSFORM CLASS ----------------
class Transform:
    """A named, composable grid function whose failures surface as ``None``.

    Attributes:
        func: the wrapped callable taking one grid and returning a grid.
        name: label used in ``repr`` and when building composed names.
    """

    def __init__(self, func, name=None):
        self.func = func
        # Not every callable has __name__ (e.g. functools.partial objects),
        # so fall back to its repr instead of raising AttributeError.
        self.name = name or getattr(func, "__name__", repr(func))

    def __call__(self, grid):
        """Apply the wrapped function; return ``None`` if it raises."""
        try:
            return self.func(grid)
        except Exception:
            # Only ordinary errors mean "this transform does not apply".
            # KeyboardInterrupt / SystemExit are not Exception subclasses and
            # therefore still propagate, so a long search can be interrupted.
            return None

    def __repr__(self):
        return f"Transform({self.name})"

    def compose(self, other):
        """Return a Transform that applies ``self`` first, then ``other``.

        The composed name is ``"<other.name>∘<self.name>"``; a ``None`` result
        from ``self`` short-circuits and ``other`` is not called.
        """
        def composed(g):
            r = self(g)
            return None if r is None else other(r)
        return Transform(composed, f"{other.name}∘{self.name}")

# ---------------- SOLVER ----------------
class ARCCombinatorialSolver:
    """Breadth-first search over compositions of grid primitives.

    A rule is accepted when it maps every training input exactly onto its
    training output.

    Attributes:
        max_depth: maximum number of primitives composed onto the identity.
        max_candidates: queue length above which expanding a node stops early.
        learned_rules: cache mapping task id -> rule chosen by ``solve``.
        primitives: the Transform building blocks used by the search.
    """

    def __init__(self, max_depth=3, max_candidates=7000):
        self.max_depth = max_depth
        self.max_candidates = max_candidates
        self.learned_rules = {}
        self.primitives = [
            Transform(identity_np, "identity"),
            Transform(rot90_np, "rot90"),
            Transform(rot180_np, "rot180"),
            Transform(rot270_np, "rot270"),
            Transform(fliph_np, "fliph"),
            Transform(flipv_np, "flipv"),
            Transform(invert_colors_np, "invert_colors"),
            Transform(remap_colors_np, "remap_colors"),
            Transform(majority_fill_np, "majority_fill"),
        ]

    def match(self, rule, pairs):
        """Return True iff ``rule`` reproduces the output of every pair."""
        return all(
            (pred := rule(inp)) is not None and np.array_equal(pred, out)
            for inp, out in pairs
        )

    def score_rule(self, rule, pairs):
        """Count exact matches (higher better), penalize longer names."""
        matches = sum(np.array_equal(rule(inp), out) for inp, out in pairs)
        return matches, -len(rule.name)

    @staticmethod
    def _state_key(grids):
        """Return a hashable fingerprint of a sequence of grids.

        The shape is part of the fingerprint: the raw bytes alone cannot tell
        e.g. a 1x3 grid from a 3x1 grid holding the same values, which would
        wrongly prune a rotation as "already visited".
        """
        return tuple(hash((np.shape(g), np.asarray(g).tobytes())) for g in grids)

    def _rank_rules(self, rules, pairs):
        """Return ``rules`` ordered best-first according to ``score_rule``.

        Score tuples are "higher is better" in both components, so sort in
        descending order; the sort is stable, so ties keep discovery order.
        """
        return sorted(rules, key=lambda r: self.score_rule(r, pairs), reverse=True)

    @staticmethod
    def _predict(rule, test_inputs):
        """Apply ``rule`` to each test grid and return nested lists.

        If the rule yields ``None`` for a test grid, that grid is returned
        unchanged, mirroring the identity fallback used when no rule is found.
        """
        predictions = []
        for inp in test_inputs:
            out = rule(inp)
            predictions.append((inp if out is None else out).tolist())
        return predictions

    def bfs_search(self, pairs):
        """Search breadth-first for up to 5 rules matching all ``pairs``.

        Args:
            pairs: list of ``(input_grid, output_grid)`` numpy arrays.

        Returns:
            List of matching Transforms, best first (possibly empty).
        """
        queue = deque([(Transform(identity_np, "identity"), 0)])
        valid_rules = []
        seen_states = set()

        while queue and len(valid_rules) < 5:
            rule, depth = queue.popleft()
            if depth > self.max_depth:
                continue

            # Evaluate the rule once per training input; stop at first failure.
            outputs = []
            for inp, _ in pairs:
                result = rule(inp)
                if result is None:
                    break
                outputs.append(result)
            if len(outputs) != len(pairs):
                continue

            # Two rules producing identical grids on every training input are
            # interchangeable for the search, so only the first is expanded.
            state = self._state_key(outputs)
            if state in seen_states:
                continue
            seen_states.add(state)

            # Reuse the outputs just computed instead of re-running the rule.
            if all(
                np.array_equal(produced, expected)
                for produced, (_, expected) in zip(outputs, pairs)
            ):
                valid_rules.append(rule)
                continue

            if depth < self.max_depth:
                for prim in self.primitives:
                    queue.append((rule.compose(prim), depth + 1))
                    if len(queue) > self.max_candidates:
                        break

        return self._rank_rules(valid_rules, pairs)

    def solve(self, task):
        """Find a rule for ``task`` and apply it to the task's test inputs.

        Args:
            task: dict with ``'train'`` (list of ``{'input', 'output'}``),
                ``'test'`` (list of ``{'input'}``) and optionally ``'id'``.

        Returns:
            ``(rule, predictions)`` where ``predictions`` is a list of nested
            lists, one per test input. When no rule is found the identity
            transform and the unchanged test inputs are returned.
        """
        tid = task.get('id')
        train_pairs = [
            (np.array(p['input'], int), np.array(p['output'], int))
            for p in task['train']
        ]
        test_inputs = [np.array(p['input'], int) for p in task['test']]

        # Reuse a cached rule only if it still explains the training pairs.
        cached = self.learned_rules.get(tid)
        if cached is not None and self.match(cached, train_pairs):
            return cached, self._predict(cached, test_inputs)

        rules = self.bfs_search(train_pairs)
        if not rules:
            return Transform(identity_np, "identity"), [inp.tolist() for inp in test_inputs]
        best_rule = rules[0]
        self.learned_rules[tid] = best_rule
        return best_rule, self._predict(best_rule, test_inputs)

# ---------------- CODE GENERATOR ----------------
def generate_code(transform):
    """Emit Python source for ``def p(g)`` implementing ``transform``.

    ``transform.name`` is expected to be primitive names joined by "∘" with
    the last-applied primitive first (as built by ``Transform.compose``).
    Each primitive's snippet is looked up in ``primitive_code_snippets``;
    unknown names fall back to a no-op.

    Every step rebinds ``g`` to its own result, so the next step operates on
    the previous step's output. (Writing each result to a separate variable
    while the snippets keep reading ``g`` would make only the final step
    take effect.)

    Returns:
        The source text of the generated function.
    """
    # Reverse into application order and drop identity steps, which are no-ops.
    steps = [n for n in reversed(transform.name.split("∘")) if n != "identity"]
    if not steps:
        return "def p(g):\n    return g\n"

    code_lines = ["def p(g):"]
    for name in steps:
        snippet = primitive_code_snippets.get(name, "return g")
        for line in snippet.strip().split("\n"):
            # Turn the snippet's final ``return <expr>`` into ``g = <expr>``;
            # the right-hand side is evaluated before ``g`` is rebound.
            code_lines.append("    " + line.strip().replace("return", "g =", 1))
    code_lines.append("    return g")
    return "\n".join(code_lines)

# ---------------- MAIN ----------------
# Load the 400 task files, keyed by their zero-padded three-digit id.
print("Loading tasks...")
train_tasks = {}
for i in tqdm(range(1, 401)):
    tid = f"{i:03d}"
    with open(f"{DIR}/input/google-code-golf-2025/task{tid}.json") as f:
        t = json.load(f)
    t['id'] = tid
    train_tasks[tid] = t

solver = ARCCombinatorialSolver(max_depth=3, max_candidates=10000)

submission_dir = f"{DIR}/working/submission"
os.makedirs(submission_dir, exist_ok=True)

# Write one generated source file per task. ``solved`` counts files written,
# including tasks for which the solver fell back to the identity transform.
solved = 0
for solved, (tid, task) in enumerate(tqdm(train_tasks.items()), start=1):
    transform, preds = solver.solve(task)
    code = generate_code(transform)
    with open(f"{submission_dir}/task{tid}.py", "w") as f:
        f.write(code)

print_rich(f"[green]Generated solutions for {solved} tasks[/green]")

# Bundle every generated file at the top level of the submission archive.
with zipfile.ZipFile(f"{DIR}/working/submission.zip", "w") as z:
    for i in range(1, 401):
        tid = f"{i:03d}"
        z.write(f"{submission_dir}/task{tid}.py", arcname=f"task{tid}.py")
print_rich("[green]Submission zip created[/green]")
