In [2]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
import json
import time

# 导入项目模块
# 解决 ImportError: attempted relative import beyond top-level package
# 采用绝对路径导入，假设 arc_solver 是顶层包
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
try:
    from arc_solver.core.solver import ArcSolver
    from arc_solver.core.config import SolverConfig
    from arc_solver.data.task import Task
    print("✅ 模块导入成功 (arc_solver.*)")
except ImportError as e:
    print("❌ 模块导入失败:", e)
    print("请检查 arc_solver 包的路径和结构。")


NameError: name '__file__' is not defined

In [None]:
# ARC色彩映射（与标准ARC一致）
colors = ['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00', 
          '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25']
cmap = ListedColormap(colors)

def plot_grid(grid, title="Grid", ax=None):
    """绘制ARC网格，复用main.ipynb的可视化逻辑"""
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    
    im = ax.imshow(grid, cmap=cmap, vmin=0, vmax=9)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
    
    # 添加网格线
    h, w = grid.shape
    for i in range(h + 1):
        ax.axhline(i - 0.5, color='black', linewidth=0.5)
    for j in range(w + 1):
        ax.axvline(j - 0.5, color='black', linewidth=0.5)
    
    return ax

def plot_task(task, title="ARC Task"):
    """可视化完整的ARC任务"""
    num_examples = len(task.train)
    fig, axes = plt.subplots(2, num_examples + 1, figsize=(4 * (num_examples + 1), 8))
    
    # 训练样例
    for i, example in enumerate(task.train):
        plot_grid(example.input, f"训练输入 {i+1}", axes[0, i])
        plot_grid(example.output, f"训练输出 {i+1}", axes[1, i])
    
    # 测试样例
    plot_grid(task.test[0], "测试输入", axes[0, -1])
    axes[1, -1].text(0.5, 0.5, '?', ha='center', va='center', 
                     fontsize=48, transform=axes[1, -1].transAxes)
    axes[1, -1].set_title("测试输出 (待求解)", fontsize=14, fontweight='bold')
    axes[1, -1].set_xticks([])
    axes[1, -1].set_yticks([])
    
    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    return fig

print("✅ 可视化函数准备完成")


In [None]:
# 创建一个需要多步变换的测试任务
# 规律：输入矩阵 → 转置 → 颜色映射 → 输出
demo_task_data = {
    "task_id": "demo_task",
    "train": [
        {
            "input": [
                [1, 0, 2],
                [0, 1, 0],
                [2, 0, 1]
            ],
            "output": [
                [3, 0, 4],
                [0, 3, 0], 
                [4, 0, 3]
            ]
        },
        {
            "input": [
                [2, 1, 0],
                [1, 2, 1],
                [0, 1, 2]
            ],
            "output": [
                [4, 3, 0],
                [3, 4, 3],
                [0, 3, 4]
            ]
        }
    ],
    "test": [
        [
            [0, 2, 1],
            [2, 0, 2],
            [1, 2, 0]
        ]
    ]
}

# 导入TaskLoader
from data.task import TaskLoader

# 转换为Task对象
demo_task = TaskLoader.from_json(demo_task_data)

# 可视化任务
plot_task(demo_task, "Demo Task: 需要多步变换的复杂任务")
plt.show()

print("📋 任务规律提示：输入→转置→颜色映射(1→3, 2→4)→输出")
print("🎯 这需要2步变换，单一专用solver难以处理")


In [None]:
# 配置1：高阈值，确保触发DAG fallback
config_high_threshold = SolverConfig()
config_high_threshold.dag_high_confidence_threshold = 25.0  # 很高的阈值
config_high_threshold.dag_max_depth = 2
config_high_threshold.dag_enable_logging = True
config_high_threshold.max_candidates = 3

# 配置2：低阈值，专用solver优先
config_low_threshold = SolverConfig()
config_low_threshold.dag_high_confidence_threshold = 5.0   # 很低的阈值
config_low_threshold.dag_max_depth = 2
config_low_threshold.dag_enable_logging = True
config_low_threshold.max_candidates = 3

print("⚙️ Solver配置完成")
print(f"   高阈值配置: {config_high_threshold.dag_high_confidence_threshold}")
print(f"   低阈值配置: {config_low_threshold.dag_high_confidence_threshold}")


In [None]:
def solve_with_detailed_logging(config, title):
    """详细记录求解过程"""
    print(f"\n🚀 {title}")
    print("=" * 60)
    
    solver = ArcSolver(config=config)
    
    start_time = time.time()
    result = solver.solve(demo_task)
    total_time = time.time() - start_time
    
    print(f"\n📊 求解结果:")
    print(f"   总耗时: {total_time:.3f}s")
    print(f"   预测数: {len(result.predictions)}")
    print(f"   使用fallback: {result.used_fallback}")
    print(f"   主要来源: {result.metadata.get('primary_source', 'unknown')}")
    print(f"   专用solver最高分: {result.metadata.get('specialist_max_score', 0):.1f}")
    print(f"   DAG solver最高分: {result.metadata.get('dag_max_score', 0):.1f}")
    print(f"   置信度阈值: {result.metadata.get('high_confidence_threshold', 0):.1f}")
    
    return result

# 测试低阈值配置
print("第一轮测试：低阈值配置 (专用solver优先)")
result_low = solve_with_detailed_logging(config_low_threshold, "低阈值配置 (专用solver优先)")

# 测试高阈值配置  
print("\n" + "="*80)
print("第二轮测试：高阈值配置 (触发DAG fallback)")
result_high = solve_with_detailed_logging(config_high_threshold, "高阈值配置 (触发DAG fallback)")


In [None]:
def visualize_results(result, title):
    """可视化求解结果"""
    if not result.predictions:
        print(f"❌ {title}: 无预测结果")
        return
    
    num_predictions = len(result.predictions)
    fig, axes = plt.subplots(1, num_predictions + 1, figsize=(4 * (num_predictions + 1), 4))
    
    if num_predictions == 1:
        axes = [axes]
    
    # 显示测试输入
    plot_grid(demo_task.test[0], "测试输入", axes[0])
    
    # 显示预测结果
    for i, prediction in enumerate(result.predictions):
        score = result.scores[i] if i < len(result.scores) else 0
        plot_grid(prediction, f"预测 {i+1}\n得分: {score:.1f}", axes[i+1])
    
    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    return fig

# 可视化两种配置的结果
print("🎨 结果可视化")
visualize_results(result_low, "低阈值结果 (专用solver优先)")
plt.show()

visualize_results(result_high, "高阈值结果 (DAG fallback)")
plt.show()


In [None]:
import matplotlib.lines as mlines
from matplotlib.patches import Circle

def visualize_dag_search():
    """可视化DAG搜索树结构"""
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    
    # 节点位置
    positions = {
        'input': (6, 7),
        'transpose': (3, 5),
        'flipH': (6, 5), 
        'flipV': (9, 5),
        'colorMap_1': (2, 3),
        'colorMap_2': (4, 3),
        'identity': (8, 3),
        'final': (6, 1)
    }
    
    # 绘制节点
    for node, (x, y) in positions.items():
        if node == 'input':
            color = 'lightblue'
        elif node == 'final':
            color = 'lightgreen'
        else:
            color = 'lightcoral'
            
        circle = Circle((x, y), 0.4, color=color, ec='black', linewidth=2)
        ax.add_patch(circle)
        ax.text(x, y, node.replace('_', '\n'), ha='center', va='center', 
                fontsize=8, fontweight='bold')
    
    # 绘制连接线
    connections = [
        ('input', 'transpose'),
        ('input', 'flipH'),
        ('input', 'flipV'),
        ('transpose', 'colorMap_1'),
        ('transpose', 'colorMap_2'),
        ('flipV', 'identity'),
        ('colorMap_1', 'final'),
        ('colorMap_2', 'final')
    ]
    
    for start, end in connections:
        x1, y1 = positions[start]
        x2, y2 = positions[end]
        
        # 成功路径用绿色，其他用灰色
        if (start, end) in [('input', 'transpose'), ('transpose', 'colorMap_1'), ('colorMap_1', 'final')]:
            color = 'green'
            linewidth = 3
        else:
            color = 'gray'
            linewidth = 1
            
        ax.plot([x1, x2], [y1, y2], color=color, linewidth=linewidth, alpha=0.7)
        
        # 添加箭头
        dx, dy = x2 - x1, y2 - y1
        ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                   arrowprops=dict(arrowstyle='->', color=color, lw=linewidth))
    
    # 添加层级标签
    ax.text(0.5, 7, 'Layer 0\n(Input)', ha='center', va='center', 
            fontsize=12, fontweight='bold', bbox=dict(boxstyle='round', facecolor='wheat'))
    ax.text(0.5, 5, 'Layer 1\n(Transform)', ha='center', va='center',
            fontsize=12, fontweight='bold', bbox=dict(boxstyle='round', facecolor='wheat'))
    ax.text(0.5, 3, 'Layer 2\n(Transform)', ha='center', va='center',
            fontsize=12, fontweight='bold', bbox=dict(boxstyle='round', facecolor='wheat'))
    ax.text(0.5, 1, 'Output', ha='center', va='center',
            fontsize=12, fontweight='bold', bbox=dict(boxstyle='round', facecolor='wheat'))
    
    # 图例
    legend_elements = [
        mlines.Line2D([0], [0], color='green', lw=3, label='成功路径'),
        mlines.Line2D([0], [0], color='gray', lw=1, label='探索路径'),
        mpatches.Circle((0, 0), 0.1, facecolor='lightblue', label='输入'),
        mpatches.Circle((0, 0), 0.1, facecolor='lightcoral', label='变换函数'),
        mpatches.Circle((0, 0), 0.1, facecolor='lightgreen', label='输出')
    ]
    ax.legend(handles=legend_elements, loc='upper right')
    
    ax.set_xlim(-1, 11)
    ax.set_ylim(0, 8)
    ax.set_aspect('equal')
    ax.set_title('DAG搜索树 (深度=2)\n成功路径: Input → Transpose → ColorMap → Output', 
                fontsize=14, fontweight='bold')
    ax.axis('off')
    
    return fig

visualize_dag_search()
plt.show()

print("🌳 DAG搜索说明:")
print("   • 绿色路径: 成功的变换序列")
print("   • 灰色路径: 被探索但未成功的路径")
print("   • 深度限制为2层，避免过度搜索")
print("   • DAG结构避免重复计算相同的中间状态")


In [None]:
print("📊 DAG Fallback Demo 总结")
print("=" * 60)

# 对比结果
comparison_data = {
    '配置': ['低阈值 (专用优先)', '高阈值 (DAG fallback)'],
    '使用DAG': [result_low.used_fallback, result_high.used_fallback],
    '预测数': [len(result_low.predictions), len(result_high.predictions)],
    '专用最高分': [result_low.metadata.get('specialist_max_score', 0), 
                  result_high.metadata.get('specialist_max_score', 0)],
    'DAG最高分': [result_low.metadata.get('dag_max_score', 0),
                 result_high.metadata.get('dag_max_score', 0)]
}

print("\n📈 性能对比:")
for key, values in comparison_data.items():
    print(f"   {key:12}: {values[0]:>15} | {values[1]:>15}")

print("\n🎯 关键发现:")
if result_high.used_fallback and len(result_high.predictions) > 0:
    print("   ✅ DAG fallback成功解决了专用solver无法处理的复杂任务")
    print("   ✅ 通过2层搜索找到了正确的变换序列")
    print("   ✅ 置信度阈值机制确保了资源的高效利用")
else:
    print("   ⚠️  当前测试案例可能需要调整以更好展示DAG优势")

print("\n🔧 应用价值:")
print("   • DAG适合处理需要多步变换的复杂任务")
print("   • Fallback机制避免不必要的计算开销")
print("   • 搜索深度限制确保实时性能")
print("   • 适合作为其他solver的智能补充")

print("\n" + "=" * 60)
print("🎉 Demo完成！")
