In [5]:
import os
import sys
import json
import boto3
from pathlib import Path
from urllib.parse import urlparse
import argparse

import torch
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


"""
This script adapts your original notebook code to run as a SageMaker *training entry point*.
It works from a single notebook by saving this file (train.py) and launching a SageMaker
PyTorch Estimator job. It supports:
  • Reading data from either a LOCAL directory (recommended with input_mode=File/FastFile)
    or directly from S3 (s3://bucket/prefix) using your job's IAM role credentials.
  • Single- or multi-GPU/multi-node (PyTorch DDP) when used with SageMaker's distribution
    configs. Only rank 0 prints to stdout to keep logs clean.
  • Writing artifacts/metrics to SM_MODEL_DIR so they show up in the job's model tarball.

Minimal launcher (in your notebook):

    from sagemaker.pytorch import PyTorch
    from sagemaker.inputs import TrainingInput

    est = PyTorch(
        entry_point="train.py",
        source_dir=".",
        role=role,
        framework_version="2.3",
        py_version="py310",
        instance_type="ml.g5.2xlarge",
        instance_count=1,                       # >1 for multi-node
        distribution={"torch_distributed": {"enabled": True}},  # enable DDP when count>1
        hyperparameters={"batch_size": 4, "epochs": 1, "add_one": True, "pad_value": 0},
        enable_sagemaker_metrics=True,
    )
    est.fit({"train": TrainingInput(s3_uri_to_your_jsons, input_mode="File")})

"""

# --------------------- I/O helpers ---------------------

def _is_s3_uri(p: str) -> bool:
    return str(p).startswith("s3://")


def _list_local_jsons(dir_path: str):
    if not os.path.isdir(dir_path):
        raise FileNotFoundError(f"Local directory not found: {dir_path}")
    # non-recursive, mirror original behavior
    files = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith('.json') and os.path.isfile(os.path.join(dir_path, f))]
    if not files:
        raise FileNotFoundError(f"No .json files found in {dir_path}")
    files.sort(key=lambda p: os.path.basename(p))
    return files


def _list_s3_jsons(s3_uri: str):
    if not _is_s3_uri(s3_uri):
        raise ValueError("Expected an S3 URI like s3://bucket/prefix/")
    o = urlparse(s3_uri)
    bucket = o.netloc
    prefix = o.path.lstrip('/')
    if prefix and not prefix.endswith('/'):
        prefix += '/'

    s3 = boto3.client("s3")  # uses job's IAM role
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/")

    keys = []
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith(".json"):
                keys.append(key)
    if not keys:
        raise FileNotFoundError(f"No .json objects found under s3://{bucket}/{prefix}")
    keys.sort(key=lambda k: os.path.basename(k))
    return [(bucket, k) for k in keys]


def load_jsons_from_folder(path_or_s3: str):
    """
    Load and validate your ARC-style JSONs from either a local directory or s3:// prefix.
    Returns: dict[name -> parsed_json]
    """
    data = {}

    if _is_s3_uri(path_or_s3):
        s3 = boto3.client("s3")
        for bucket, key in _list_s3_jsons(path_or_s3):
            name = os.path.splitext(os.path.basename(key))[0]
            uri = f"s3://{bucket}/{key}"
            try:
                body = s3.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
                obj = json.loads(body)
            except Exception as e:
                print("Failed to read:", uri)
                print(" Error:", e)
                continue
            if _valid_arc_obj(obj):
                data[name] = obj
            else:
                print("Skipping (bad format):", uri)
    else:
        for fpath in _list_local_jsons(path_or_s3):
            name = os.path.splitext(os.path.basename(fpath))[0]
            try:
                with open(fpath, "r", encoding="utf-8") as f:
                    obj = json.load(f)
            except Exception as e:
                print("Failed to read:", fpath)
                print(" Error:", e)
                continue
            if _valid_arc_obj(obj):
                data[name] = obj
            else:
                print("Skipping (bad format):", fpath)

    if not data:
        raise FileNotFoundError("No valid ARC jsons loaded.")
    return data


def _valid_arc_obj(obj) -> bool:
    if ("train" not in obj) or ("test" not in obj):
        return False
    for split in ("train", "test"):
        if not isinstance(obj[split], list):
            return False
        for pairs in obj[split]:
            if ("input" not in pairs) or ("output" not in pairs):
                return False
    return True


# --------------------- your original transforms ---------------------

def _add_one_to_all_values_in_place(data):
    """
    Adds +1 to every scalar value in each input/output grid across all samples.
    Done BEFORE padding so pad_value=0 remains 0.
    """
    for sample in data.values():
        for split in ["train", "test"]:
            for pairs in sample.get(split, []):
                # input grid
                r = 0
                while r < len(pairs["input"]):
                    c = 0
                    row = pairs["input"][r]
                    while c < len(row):
                        row[c] = row[c] + 1
                        c += 1
                    r += 1
                # output grid
                r = 0
                while r < len(pairs["output"]):
                    c = 0
                    row = pairs["output"][r]
                    while c < len(row):
                        row[c] = row[c] + 1
                        c += 1
                    r += 1


def get_metrics(data):
    metric_dict = {
        "max_train_len": 0,
        "max_test_len": 0,
        "max_train_input_height": 0,
        "max_test_input_height": 0,
        "max_train_output_height": 0,
        "max_test_output_height": 0,
        "max_train_input_width": 0,
        "max_test_input_width": 0,
        "max_train_output_width": 0,
        "max_test_output_width": 0,
    }

    for sample in data.values():
        if (len(sample['train']) > metric_dict['max_train_len']):
            metric_dict['max_train_len'] = len(sample['train'])
        if (len(sample['test']) > metric_dict['max_test_len']):
            metric_dict['max_test_len'] = len(sample['test'])
        for pairs in sample['train']:
            if (len(pairs['input']) > metric_dict['max_train_input_height']):
                metric_dict['max_train_input_height'] = len(pairs['input'])
            if (len(pairs['output']) > metric_dict['max_train_output_height']):
                metric_dict['max_train_output_height'] = len(pairs['output'])
            for inp in pairs['input']:
                if (len(inp) > metric_dict['max_train_input_width']):
                    metric_dict['max_train_input_width'] = len(inp)
            for output in pairs['output']:
                if (len(output) > metric_dict['max_train_output_width']):
                    metric_dict['max_train_output_width'] = len(output)
        for pairs in sample['test']:
            if (len(pairs['input']) > metric_dict['max_test_input_height']):
                metric_dict['max_test_input_height'] = len(pairs['input'])
            if (len(pairs['output']) > metric_dict['max_test_output_height']):
                metric_dict['max_test_output_height'] = len(pairs['output'])
            for inp in pairs['input']:
                if (len(inp) > metric_dict['max_test_input_width']):
                    metric_dict['max_test_input_width'] = len(inp)
            for output in pairs['output']:
                if (len(output) > metric_dict['max_test_output_width']):
                    metric_dict['max_test_output_width'] = len(output)
    return metric_dict


def pad_data(data, metric_dict=None, pad_value=0):
    """
    Pads each sample independently to its own max square size.
    metric_dict is ignored (kept for backward compatibility).
    """
    for sample in data.values():
        # ----- compute per-sample maxima for TRAIN -----
        max_train_input_height = 0
        max_train_input_width  = 0
        max_train_output_height = 0
        max_train_output_width  = 0

        for pairs in sample.get('train', []):
            if len(pairs['input'])  > max_train_input_height:  max_train_input_height  = len(pairs['input'])
            if len(pairs['output']) > max_train_output_height: max_train_output_height = len(pairs['output'])
            for inp in pairs['input']:
                if len(inp) > max_train_input_width:  max_train_input_width  = len(inp)
            for outp in pairs['output']:
                if len(outp) > max_train_output_width: max_train_output_width = len(outp)

        # ----- compute per-sample maxima for TEST -----
        max_test_input_height = 0
        max_test_input_width  = 0
        max_test_output_height = 0
        max_test_output_width  = 0

        for pairs in sample.get('test', []):
            if len(pairs['input'])  > max_test_input_height:  max_test_input_height  = len(pairs['input'])
            if len(pairs['output']) > max_test_output_height: max_test_output_height = len(pairs['output'])
            for inp in pairs['input']:
                if len(inp) > max_test_input_width:  max_test_input_width  = len(inp)
            for outp in pairs['output']:
                if len(outp) > max_test_output_width: max_test_output_width = len(outp)

        # ----- per-sample square sizes -----
        max_train_size = max(
            max_train_input_height,
            max_train_input_width,
            max_train_output_height,
            max_train_output_width
        )
        max_test_size = max(
            max_test_input_height,
            max_test_input_width,
            max_test_output_height,
            max_test_output_width
        )

        # ----- pad TRAIN for this sample -----
        for pairs in sample.get('train', []):
            # input
            while len(pairs['input']) < max_train_size:
                pairs['input'].append([pad_value] * max_train_size)
            for inp in pairs['input']:
                while len(inp) < max_train_size:
                    inp.append(pad_value)
            # output
            while len(pairs['output']) < max_train_size:
                pairs['output'].append([pad_value] * max_train_size)
            for outp in pairs['output']:
                while len(outp) < max_train_size:
                    outp.append(pad_value)

        # ----- pad TEST for this sample -----
        for pairs in sample.get('test', []):
            # input
            while len(pairs['input']) < max_test_size:
                pairs['input'].append([pad_value] * max_test_size)
            for inp in pairs['input']:
                while len(inp) < max_test_size:
                    inp.append(pad_value)
            # output
            while len(pairs['output']) < max_test_size:
                pairs['output'].append([pad_value] * max_test_size)
            for outp in pairs['output']:
                while len(outp) < max_test_size:
                    outp.append(pad_value)

    return data


def _infer_original_size_from_padded(grid, pad_value=0):
    h = 0
    w = 0
    r = 0
    while r < len(grid):
        row = grid[r]
        any_nonpad = False
        last_nonpad = -1
        c = 0
        while c < len(row):
            if row[c] != pad_value:
                any_nonpad = True
                last_nonpad = c
            c += 1
        if any_nonpad:
            if (r + 1) > h:
                h = r + 1
            if (last_nonpad + 1) > w:
                w = last_nonpad + 1
        r += 1
    return (h, w)


def build_sample_level_dataset(data, pad_value=0):
    """
    Build a list of per-sample records.
    Stores per-pair masks: 1 where value != pad_value, else 0.
    """
    dataset = []
    for sample_name, sample in data.items():
        # containers
        train_pairs = []
        test_pairs = []

        # track original (unpadded) sizes per split
        train_max_h = 0
        train_max_w = 0
        test_max_h = 0
        test_max_w = 0

        # ----- TRAIN -----
        for pairs in sample['train']:
            inp_grid = pairs['input']
            out_grid = pairs['output']

            # original sizes (prefer stored, else infer)
            if ('orig_input_size' in pairs):
                in_h, in_w = pairs['orig_input_size']
            else:
                in_h, in_w = _infer_original_size_from_padded(inp_grid, pad_value)
            if ('orig_output_size' in pairs):
                out_h, out_w = pairs['orig_output_size']
            else:
                out_h, out_w = _infer_original_size_from_padded(out_grid, pad_value)

            # update split-wide original size (max over inputs/outputs)
            if in_h > train_max_h: train_max_h = in_h
            if out_h > train_max_h: train_max_h = out_h
            if in_w > train_max_w: train_max_w = in_w
            if out_w > train_max_w: train_max_w = out_w

            # tensors
            inp_tensor = torch.tensor(inp_grid).long()
            out_tensor = torch.tensor(out_grid).long()

            # masks (1 for non-pad, 0 for pad)
            inp_mask = (inp_tensor != pad_value).long()
            out_mask = (out_tensor != pad_value).long()

            # store pair
            train_pairs.append({
                "input": inp_tensor,
                "output": out_tensor,
                "input_mask": inp_mask,
                "output_mask": out_mask,
            })

        # ----- TEST -----
        for pairs in sample['test']:
            inp_grid = pairs['input']
            out_grid = pairs['output']

            if ('orig_input_size' in pairs):
                in_h, in_w = pairs['orig_input_size']
            else:
                in_h, in_w = _infer_original_size_from_padded(inp_grid, pad_value)
            if ('orig_output_size' in pairs):
                out_h, out_w = pairs['orig_output_size']
            else:
                out_h, out_w = _infer_original_size_from_padded(out_grid, pad_value)

            if in_h > test_max_h: test_max_h = in_h
            if out_h > test_max_h: test_max_h = out_h
            if in_w > test_max_w: test_max_w = in_w
            if out_w > test_max_w: test_max_w = out_w

            inp_tensor = torch.tensor(inp_grid).long()
            out_tensor = torch.tensor(out_grid).long()

            inp_mask = (inp_tensor != pad_value).long()
            out_mask = (out_tensor != pad_value).long()

            test_pairs.append({
                "input": inp_tensor,
                "output": out_tensor,
                "input_mask": inp_mask,
                "output_mask": out_mask,
            })

        # assemble sample-level record
        item = {
            "id": str(sample_name),
            "train_pairs": train_pairs,
            "test_pairs": test_pairs,
            "train_original_size": (train_max_h, train_max_w),
            "test_original_size": (test_max_h, test_max_w),
        }
        dataset.append(item)

    return dataset


class ARCSampleDataset(Dataset):
    def __init__(self, sample_list):
        self.data = sample_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # stack per-sample pairs into tensors
        train_inputs = torch.stack([p["input"] for p in sample["train_pairs"]])      # [num_train, H, W]
        train_outputs = torch.stack([p["output"] for p in sample["train_pairs"]])    # [num_train, H, W]
        test_inputs = torch.stack([p["input"] for p in sample["test_pairs"]])        # [num_test, H, W]
        test_outputs = torch.stack([p["output"] for p in sample["test_pairs"]])      # [num_test, H, W]

        # masks
        train_input_masks = torch.stack([p["input_mask"] for p in sample["train_pairs"]])
        train_output_masks = torch.stack([p["output_mask"] for p in sample["train_pairs"]])
        test_input_masks  = torch.stack([p["input_mask"] for p in sample["test_pairs"]])
        test_output_masks = torch.stack([p["output_mask"] for p in sample["test_pairs"]])

        return {
            "id": sample["id"],
            "train_inputs": train_inputs,
            "train_outputs": train_outputs,
            "test_inputs": test_inputs,
            "test_outputs": test_outputs,
            "train_input_masks": train_input_masks,
            "train_output_masks": train_output_masks,
            "test_input_masks": test_input_masks,
            "test_output_masks": test_output_masks,
            "train_original_size": torch.tensor(sample["train_original_size"], dtype=torch.long),
            "test_original_size": torch.tensor(sample["test_original_size"], dtype=torch.long),
        }


def arc_collate_fn_bs1(batch):
    # batch size is guaranteed to be 1; return the single dict unchanged
    return batch[0]


# --------------------- training entry point ---------------------

def init_distributed_if_needed():
    """Initialize torch.distributed if launched with torchrun (SageMaker DDP)."""
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size > 1:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend=backend, init_method="env://")
    return world_size, rank, local_rank


def is_primary(rank: int) -> bool:
    return rank == 0


def save_json(obj, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", ""),
                        help="Path to training data directory (local) or s3:// prefix")
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--pad-value", type=int, default=0)
    parser.add_argument("--add-one", dest="add_one", action="store_true", default=False)
    parser.add_argument("--workers", type=int, default=max(0, (os.cpu_count() or 2)//2))
    parser.add_argument("--log-n", type=int, default=2, help="Log first N samples only (rank 0)")

    args = parser.parse_args()

    world_size, rank, local_rank = init_distributed_if_needed()

    def log(*a, **k):
        if is_primary(rank):
            print(*a, **k, flush=True)

    log("=== Config ===")
    log(vars(args))
    log(f"world_size={world_size} rank={rank} local_rank={local_rank} gpu={torch.cuda.is_available()}")

    # ----------------- Load & preprocess -----------------
    if not args.train:
        raise ValueError("--train path was not provided and SM_CHANNEL_TRAIN is empty.")

    data = load_jsons_from_folder(args.train)

    if args.add_one:
        _add_one_to_all_values_in_place(data)

    metrics = get_metrics(data)
    padded_data = pad_data(data, metrics, pad_value=args.pad_value)
    sample_level = build_sample_level_dataset(padded_data, pad_value=args.pad_value)

    # Persist a tiny metrics file for inspection
    if is_primary(rank):
        save_json({"metrics": metrics, "num_samples": len(sample_level)}, os.path.join(args.model_dir, "preprocess_metrics.json"))

    # ----------------- DataLoader -----------------
    ds = ARCSampleDataset(sample_list=sample_level)
    # Pin memory helps if using GPU; persistent workers when workers>0
    pin = torch.cuda.is_available()
    loader = DataLoader(
        ds,
        batch_size=1,
        shuffle=True,
        collate_fn=arc_collate_fn_bs1,
        num_workers=args.workers,
        pin_memory=pin,
        persistent_workers=(args.workers > 0),
        prefetch_factor=2 if args.workers > 0 else None,
    )

    # ----------------- (Placeholder) Training Loop -----------------
    # This script focuses on data preparation/ingest; plug in your model below.
    # We just iterate a few samples to validate the pipeline and save shapes.
    seen = 0
    shapes = []
    for epoch in range(args.epochs):
        for batch in loader:
            if seen < args.log_n and is_primary(rank):
                log("=== SAMPLE ===")
                log("ID:", batch["id"]) 
                log("train_inputs:", tuple(batch["train_inputs"].shape))
                log("train_outputs:", tuple(batch["train_outputs"].shape))
                log("test_inputs:", tuple(batch["test_inputs"].shape))
                log("test_outputs:", tuple(batch["test_outputs"].shape))
                log("train_original_size:", batch["train_original_size"].tolist())
                log("test_original_size:", batch["test_original_size"].tolist())
            shapes.append({
                "id": batch["id"],
                "train_inputs": list(batch["train_inputs"].shape),
                "train_outputs": list(batch["train_outputs"].shape),
                "test_inputs": list(batch["test_inputs"].shape),
                "test_outputs": list(batch["test_outputs"].shape),
            })
            seen += 1

    if is_primary(rank):
        save_json({"sample_shapes": shapes[: args.log_n]}, os.path.join(args.model_dir, "sample_shapes.json"))
        # drop a tiny file to ensure model artifact is produced
        with open(os.path.join(args.model_dir, "_SUCCESS"), "w") as f:
            f.write("ok\n")

    # Clean up DDP
    if world_size > 1:
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] [--train TRAIN] [--model-dir MODEL_DIR]
                             [--batch-size BATCH_SIZE] [--epochs EPOCHS]
                             [--pad-value PAD_VALUE] [--add-one]
                             [--workers WORKERS] [--log-n LOG_N]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/eliholm/.local/share/jupyter/runtime/kernel-60ca2500-d71d-4405-ac7e-3cd5eb96dfac.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
