In [3]:
import os
import json
import torch
from task4 import Task, Blackboard  # adjust if your Task class is elsewhere
from tqdm import tqdm
import trainer

# Define augmentation types
augmentation_types = [
    'rotate_90',
    'rotate_180',
    'rotate_270',
    'flip_horizontal',
    'flip_vertical',
    'value_permutation'
]

def load_task_data(directory):
    """Load raw task input/output grids from JSON files.

    Recursively walks ``directory`` and parses every ``*.json`` file. Each file
    is expected to be a JSON object with ``"train"`` and ``"test"`` lists of
    ``{"input": ..., "output": ...}`` pairs. Files that are not valid JSON, or
    whose top level is not an object containing both keys, are reported with a
    warning and skipped.

    Directories and filenames are visited in sorted order, so the returned list
    is identical across runs and filesystems.

    Args:
        directory: Root directory to search.

    Returns:
        List of dicts with keys ``"task_id"`` (the filename without extension),
        ``"train_pairs"`` and ``"test_pairs"`` (lists of ``(input, output)``
        tuples).
    """
    raw_tasks = []
    for root, dirs, files in os.walk(directory):
        # os.walk yields entries in filesystem order; sorting `dirs` in place
        # fixes the descent order and sorting `files` fixes the task order.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".json"):
                continue
            file_path = os.path.join(root, file)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                # Same skip-with-warning policy as the format check below, so
                # one corrupt file does not abort loading the whole directory.
                print(f"Warning: Could not parse task file {file_path}: {e}")
                continue
            # A non-object top level (e.g. a list or string) would turn the `in`
            # checks into list-membership / substring tests and then crash on
            # data["train"].
            if not isinstance(data, dict) or "train" not in data or "test" not in data:
                print(f"Warning: Invalid task format in {file_path}")
                continue
            raw_tasks.append({
                "task_id": os.path.splitext(file)[0],
                "train_pairs": [(pair["input"], pair["output"]) for pair in data["train"]],
                "test_pairs": [(pair["input"], pair["output"]) for pair in data["test"]],
            })
    return raw_tasks

def precompute_and_save_task(task_dict, augmentation_types=None, save_dir="precomputed_tasks"):
    """
    Precompute task data with optional augmentations and save to disk.

    Builds a ``Task`` from the raw grids, optionally adds augmented copies from
    ``trainer.generate_augmented_dataset``, and writes one ``<task_id>.pt`` file
    per task (original + augmentations) into ``save_dir``. Each file holds a
    dict with keys ``task_id``, ``train_graphs``, ``test_graphs``,
    ``train_targets`` and ``test_targets``.

    Args:
        task_dict: Dictionary containing task data (task_id, train_pairs, test_pairs)
        augmentation_types: List of augmentation methods to apply (default: None,
            meaning only the original task is saved)
        save_dir: Directory to save precomputed tasks (created if missing)
    """
    os.makedirs(save_dir, exist_ok=True)
    task_id = task_dict["task_id"]
    print(f"Precomputing: {task_id}")

    # Create the original task
    original_task = Task(
        task_id=task_id,
        train_pairs=task_dict["train_pairs"],
        test_pairs=task_dict["test_pairs"],
    )

    # Generate augmented versions if augmentation types are provided
    all_tasks = [original_task]
    if augmentation_types:
        augmented_tasks = trainer.generate_augmented_dataset([original_task], augmentation_types)
        # NOTE(review): the helper is assumed to echo the original task back in
        # its result. Dropping entries by task_id (rather than blindly slicing
        # off element 0) keeps the original out of the list either way and
        # never discards a real augmentation.
        all_tasks.extend(t for t in augmented_tasks if t.task_id != task_id)

    # Save each task (original + augmented). Every task is keyed by its own
    # task_id, so a single filename rule covers both cases.
    for task in all_tasks:
        torch.save({
            "task_id": task.task_id,
            "train_graphs": task.train_graphs,
            "test_graphs": task.test_graphs,
            "train_targets": task.train_targets,
            "test_targets": task.test_targets,
        }, os.path.join(save_dir, f"{task.task_id}.pt"))

    print(f"Saved precomputed task(s): {task_id} with {len(all_tasks)-1} augmentations")

# Full pipeline function to process a directory of tasks
def process_task_directory(input_dir, output_dir="precomputed_tasks", augmentation_types=None):
    """
    Process all tasks in a directory, applying augmentations and saving results.

    Loads every task JSON under ``input_dir`` with ``load_task_data`` and hands
    each one to ``precompute_and_save_task``.

    Args:
        input_dir: Directory containing task JSON files (searched recursively)
        output_dir: Directory to save precomputed tasks (created if missing)
        augmentation_types: List of augmentation methods to apply; ``None`` or
            an empty list saves only the original tasks
    """
    print(f"Processing tasks from {input_dir}")
    print(f"Using augmentations: {augmentation_types if augmentation_types else 'None'}")

    # Create the output directory up front so it exists even when no tasks are found.
    os.makedirs(output_dir, exist_ok=True)

    # Load all raw tasks
    raw_tasks = load_task_data(input_dir)
    print(f"Found {len(raw_tasks)} tasks to process")

    # Process each task
    for task_dict in raw_tasks:
        precompute_and_save_task(
            task_dict=task_dict,
            augmentation_types=augmentation_types,
            save_dir=output_dir,
        )

    # Report whether augmentation was actually applied.
    suffix = "with augmentations" if augmentation_types else "without augmentations"
    print(f"Completed processing {len(raw_tasks)} tasks {suffix}")

def load_precomputed_tasks(directory):
    """Load precomputed tasks from .pt files in directory.

    Walks ``directory`` recursively, visiting subdirectories and files in
    sorted order so the returned list is reproducible regardless of filesystem
    ordering. A file that fails to load is reported with a warning and skipped,
    so one bad file does not abort the whole load.

    Args:
        directory: Root directory to search.

    Returns:
        List of tasks as returned by ``load_precomputed_task``.
    """
    tasks = []
    for root, dirs, files in os.walk(directory):
        dirs.sort()  # in-place sort controls the order os.walk descends in
        for file in sorted(files):
            if not file.endswith(".pt"):
                continue
            file_path = os.path.join(root, file)
            try:
                tasks.append(load_precomputed_task(file_path))
            except Exception as e:
                # Deliberate best-effort: report and move on to the next file.
                print(f"Warning: Failed to load task from {file_path}: {e}")
    return tasks


def load_precomputed_task(path, map_location=None):
    """Load a single precomputed task from a .pt file.

    Rebuilds a ``Task`` from the dict saved by ``precompute_and_save_task``
    (keys ``task_id``, ``train_graphs``, ``test_graphs``, ``train_targets``,
    ``test_targets``) without calling ``Task.__init__``.

    Args:
        path: Path to the ``.pt`` file.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"`` allows
            files saved with CUDA tensors to be loaded on a CPU-only machine.
            ``None`` (the default) keeps torch's default device mapping.

    Returns:
        A ``Task`` with the saved graphs/targets, ``edge_types`` and a fresh
        ``Blackboard`` attached.

    Raises:
        KeyError: If the file does not contain one of the expected keys.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load .pt files from trusted sources.
    data = torch.load(path, map_location=map_location, weights_only=False)
    task = Task.__new__(Task)  # Bypass __init__; the graph data is restored from disk instead
    task.task_id = data["task_id"]
    task.train_graphs = data["train_graphs"]
    task.test_graphs = data["test_graphs"]
    task.train_targets = data["train_targets"]
    task.test_targets = data["test_targets"]
    # NOTE(review): hard-coded here because __init__ is skipped; presumably this
    # must stay in sync with the edge types Task.__init__ sets — verify in task4.
    task.edge_types = ["edge_index", "value_edge_index", "region_edge_index",
                       "contextual_edge_index", "alignment_edge_index"]
    task.blackboard = Blackboard()
    return task
    

In [4]:
# Precompute the 400 evaluation tasks using the trainer module's pipeline
# (trainer.precompute_tasks, not process_task_directory from the cell above).
# No augmentations are applied.
import trainer

evaluation_precompute_args = {
    "input_dir": "data/evaluation",
    "output_dir": "precomputed_tasks/evaluation_400",
    "augmentation_types": None,
}
trainer.precompute_tasks(**evaluation_precompute_args)

Loaded 400 original tasks from data/evaluation
Processing task: 00576224
Processing task: 009d5c81
Processing task: 00dbd492
Processing task: 03560426
Processing task: 05a7bcf2
Processing task: 0607ce86
Processing task: 0692e18c
Processing task: 070dd51e
Processing task: 08573cc6
Processing task: 0934a4d8
Processing task: 09c534e7
Processing task: 0a1d4ef5
Processing task: 0a2355a6
Processing task: 0b17323b
Processing task: 0bb8deee
Processing task: 0becf7df
Processing task: 0c786b71
Processing task: 0c9aba6e
Processing task: 0d87d2a6
Processing task: 0e671a1a
Processing task: 0f63c0b9
Processing task: 103eff5b
Processing task: 11e1fe23
Processing task: 12422b43
Processing task: 12997ef3
Processing task: 12eac192
Processing task: 136b0064
Processing task: 13713586
Processing task: 137f0df0
Processing task: 140c817e
Processing task: 14754a24
Processing task: 15113be4
Processing task: 15663ba9
Processing task: 15696249
Processing task: 16b78196
Processing task: 17b80ad2
Processing task: 

In [7]:
# Sanity-check one precomputed file by reloading it with load_precomputed_task
# (defined in the helper cell above; the previous `traiload_precomputed_task`
# name did not exist and raised NameError on a fresh kernel).
# NOTE(review): the saved output below reports task_id 0a938d79_rotate_90, which
# does not match this filename — likely stale output from an earlier run; re-run to confirm.
task = load_precomputed_task("precomputed_tasks/007bbfb7_rotate_90.pt")
print(task.task_id, len(task.train_graphs), len(task.test_graphs))

0a938d79_rotate_90 4 1
