In [None]:
# Environment setup (run once on a fresh kernel).
# Runtime dependencies plus pytest; NOTE(review): several of these (pillow,
# msgpack, requests, mediapy) are not used directly below and are presumably
# needed by cax/flax — confirm before trimming.
!pip install 'pytest>=8.3.2' 'numpy>=1.26.4' 'pillow>=10.4.0' 'msgpack>=1.1.0' 'requests>=2.32.3' 'mediapy>=1.2.2' tqdm
# Flax/Optax stack installed with --no-deps so pip does not resolve or upgrade
# their dependencies (NOTE(review): presumably to keep the environment's
# preinstalled jax/jaxlib build untouched — confirm).
!pip install --no-deps 'optax==0.2.3' 'chex==0.1.86' 'flax>=0.9.0' orbax-checkpoint tensorstore 'typing-extensions>=4.2' 'absl-py>=2.1.0' 'toolz>=1.0.0' 'etils[epy]>=1.9.4'
# Install the `cax` library from its GitHub source, again without dependency resolution.
!git clone https://github.com/hu-po/cax.git /cax
!pip install --upgrade /cax --no-deps
# Run cax's own test suite as a smoke check that the installed stack works together.
!pytest --color=no /cax/tests

In [None]:
from functools import partial
import json
import os
import psutil
import subprocess
import time

import jax
import jaxlib
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm

import flax
from flax import nnx
import optax
import cax

# Log library versions so results can be matched to the environment that produced them.
for pkg in (jax, jaxlib, cax, flax, optax):
    print(f"{pkg.__name__} {pkg.__version__}")

In [None]:
def get_device_memory():
    """Best-effort query of accelerator memory usage, in megabytes (MiB).

    Probes, in order, and returns the first one that yields numbers:
      1. ``nvidia-smi``: used/total summed over every visible NVIDIA GPU.
      2. JAX allocator statistics (``Device.memory_stats()``) summed over the
         local devices. Here "total" is the allocator's ``bytes_limit``, which
         may be smaller than the physical device memory.
      3. Host RAM via ``psutil``. Jetson-class boards such as AGX Orin share
         system RAM between CPU and GPU, so this also covers them.

    Returns:
        tuple: ``(used_mb, total_mb)``. ``(None, None)`` means every probe failed.
    """
    bytes_per_mb = 1024 * 1024

    # 1) nvidia-smi prints one "total, used" line (MiB, no units) per GPU.
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.total,memory.used', '--format=csv,noheader,nounits'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
            timeout=10,  # a wedged driver must not hang the notebook
        )
        total_mb = used_mb = 0.0
        for line in result.stdout.splitlines():
            if not line.strip():
                continue
            total, used = (float(field) for field in line.split(','))
            total_mb += total
            used_mb += used
        if total_mb > 0:
            return used_mb, total_mb
    except (OSError, subprocess.SubprocessError, ValueError):
        pass  # nvidia-smi missing, failed, timed out, or printed unexpected output

    # 2) JAX allocator statistics. memory_stats() returns None (or may raise) on
    #    backends that do not track memory, e.g. CPU; such devices are skipped.
    try:
        used_bytes = total_bytes = 0
        for device in jax.local_devices():
            stats = device.memory_stats()
            if not stats or 'bytes_limit' not in stats:
                continue
            used_bytes += stats.get('bytes_in_use', 0)
            total_bytes += stats['bytes_limit']
        if total_bytes > 0:
            return used_bytes / bytes_per_mb, total_bytes / bytes_per_mb
    except Exception:
        pass  # backend does not expose memory statistics

    # 3) Host RAM as a last resort.
    try:
        mem = psutil.virtual_memory()
        return mem.used / bytes_per_mb, mem.total / bytes_per_mb
    except Exception:
        pass  # psutil unavailable or failed

    print("Could not retrieve device memory usage.")
    return None, None

# Load data
def load_data(path):
    """Load and return the JSON object stored at ``path``.

    Args:
        path: filesystem path to a JSON file.

    Returns:
        The decoded JSON value. For the ARC files loaded below this is a dict
        keyed by task id (as consumed by ``process_tasks``).
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} tasks from {path}")
    return data

# Root directory of the ARC Prize 2024 JSON files. Defaults to the Kaggle
# dataset mount; set the ARC_DATA_DIR environment variable to run elsewhere.
ARC_DATA_DIR = os.environ.get('ARC_DATA_DIR', '/kaggle/input/arc-prize-2024')

train_challenges = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_training_challenges.json'))
train_solutions = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_training_solutions.json'))
eval_challenges = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_evaluation_challenges.json'))
eval_solutions = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_evaluation_solutions.json'))

# Process tasks
def process_tasks(challenges, solutions):
    """Flatten ARC tasks into a list of ``(input_grid, output_grid, task_index)`` samples.

    Every train pair of a task is emitted first, then every test input paired
    with its expected output from ``solutions`` (matched by position). Grids
    are converted to int32 numpy arrays.

    Args:
        challenges: mapping task_id -> {'train': [...], 'test': [...]}.
        solutions: mapping task_id -> list of expected test outputs.

    Returns:
        (samples, task_id_to_index) where task indices follow the iteration
        order of ``challenges``.
    """
    samples = []
    task_index_by_id = {}
    for task_index, (task_id, task) in enumerate(challenges.items()):
        task_index_by_id[task_id] = task_index
        expected_outputs = solutions[task_id]
        pairs = [(example['input'], example['output']) for example in task['train']]
        pairs += [
            (test_case['input'], expected_outputs[position])
            for position, test_case in enumerate(task['test'])
        ]
        samples.extend(
            (np.array(source, dtype=np.int32), np.array(target, dtype=np.int32), task_index)
            for source, target in pairs
        )
    return samples, task_index_by_id

# Pad grid function
def pad_grid(grid, max_size=30, pad_value=0, center_offset=(0, 0)):
    """Place a 2-D grid inside a square ``max_size x max_size`` canvas.

    The grid is centred, shifted by ``center_offset`` and then clamped so it
    always lies fully inside the canvas.

    Args:
        grid: 2-D array-like of integer values.
        max_size: side length of the square output canvas.
        pad_value: value written to cells not covered by ``grid``.
            NOTE(review): in ARC data 0 is usually a real colour, so with the
            default the padding may be indistinguishable from grid content —
            confirm this is intended.
        center_offset: (row, col) shift applied to the centred position.

    Returns:
        np.ndarray of shape (max_size, max_size), dtype int32.

    Raises:
        ValueError: if ``grid`` is not 2-D or is larger than the canvas.
    """
    grid = np.asarray(grid)
    if grid.ndim != 2:
        raise ValueError(f"pad_grid expects a 2-D grid, got shape {grid.shape}")
    rows, cols = grid.shape
    if rows > max_size or cols > max_size:
        raise ValueError(
            f"grid of shape {grid.shape} does not fit in a {max_size}x{max_size} canvas")
    padded = np.full((max_size, max_size), pad_value, dtype=np.int32)
    # Centre, shift, then clamp into [0, max_size - extent] so the grid stays in bounds.
    offset_row = min(max((max_size - rows) // 2 + center_offset[0], 0), max_size - rows)
    offset_col = min(max((max_size - cols) // 2 + center_offset[1], 0), max_size - cols)
    padded[offset_row:offset_row + rows, offset_col:offset_col + cols] = grid
    return padded

# Augmentation functions
@jax.jit
def flip_horizontal(grid):
    """Mirror a grid left-to-right by reversing its column axis (axis 1)."""
    return jnp.flip(grid, axis=1)


@jax.jit
def flip_vertical(grid):
    """Mirror a grid top-to-bottom by reversing its row axis (axis 0)."""
    return jnp.flip(grid, axis=0)


@partial(jax.jit, static_argnums=(1,))
def rotate(grid, k):
    """Rotate a grid by ``k`` quarter turns counter-clockwise in the (0, 1) plane.

    ``k`` is a static argument, so each distinct value is compiled once.
    Non-square grids swap their dimensions when ``k`` is odd.
    """
    return jnp.rot90(grid, k, axes=(0, 1))

# Registry of augmentation names -> single-grid transforms. The names are what
# callers pass in their `augmentations` lists.
AUGMENTATIONS = dict(
    flip_horizontal=flip_horizontal,
    flip_vertical=flip_vertical,
    rotate_90=partial(rotate, k=1),
    rotate_180=partial(rotate, k=2),
    rotate_270=partial(rotate, k=3),
)

# Apply random augmentations to a batch
@partial(jax.jit, static_argnums=(3,))
def _augment_batch_jit(input_grids, output_grids, rng_keys, augmentations):
    """Jitted core of `apply_random_augmentations_batch`; `augmentations` must be a hashable tuple.

    AUGMENTATIONS is read at trace time, so later edits to the registry are not
    seen for an already-compiled tuple of names.
    """
    num_augs = len(augmentations)

    def augment_sample(input_grid, output_grid, rng_key):
        # One fair coin per augmentation; the same decision is applied to the
        # input and output grid so the pair stays consistent.
        apply_aug = jax.random.bernoulli(rng_key, p=0.5, shape=(num_augs,))
        for i, aug_name in enumerate(augmentations):
            aug_func = AUGMENTATIONS[aug_name]
            input_grid, output_grid = jax.lax.cond(
                apply_aug[i],
                # Bind aug_func explicitly so the branch uses this iteration's transform.
                lambda grids, f=aug_func: (f(grids[0]), f(grids[1])),
                lambda grids: grids,
                (input_grid, output_grid),
            )
        return input_grid, output_grid

    # Map over the leading (batch) axis of grids and keys.
    return jax.vmap(augment_sample, in_axes=(0, 0, 0), out_axes=(0, 0))(
        input_grids, output_grids, rng_keys)


def apply_random_augmentations_batch(input_grids, output_grids, rng_keys, augmentations):
    """Independently apply each named augmentation with probability 0.5 to every sample.

    Augmentations are applied in the order given, so several can compose.
    Both branches of `lax.cond` must produce the same shape, so grids must
    stay the same shape under every transform. Rotations therefore require
    square grids; the padded batches from `data_generator` are square.

    Args:
        input_grids: batch of input grids, batch on axis 0.
        output_grids: batch of output grids, same batch size.
        rng_keys: one PRNG key per sample (e.g. from `jax.random.split`).
        augmentations: iterable of names registered in AUGMENTATIONS.

    Returns:
        (augmented_inputs, augmented_outputs), same shapes as the inputs.
    """
    # The whole batch runs as one compiled program instead of re-tracing the
    # vmap/cond graph eagerly on every call; jit needs a hashable static arg.
    return _augment_batch_jit(input_grids, output_grids, rng_keys, tuple(augmentations))

# Data generator
def data_generator(data, batch_size, augmentations=None, max_size=30, pad_value=0,
                   center_offset=(0, 0), shuffle=True):
    """Yield padded (optionally augmented) batches forever, epoch after epoch.

    Args:
        data: sequence of (input_grid, output_grid, task_index) samples.
        batch_size: samples per batch; the last batch of an epoch may be smaller.
        augmentations: names from AUGMENTATIONS to apply at random, or None/empty.
        max_size, pad_value, center_offset: forwarded to `pad_grid`.
        shuffle: reshuffle sample order at the start of every epoch.

    Yields:
        (inputs, outputs, task_indices): int32 arrays of shape
        (B, max_size, max_size), (B, max_size, max_size) and (B,).

    Raises:
        ValueError: on the first `next()` if `data` is empty or `batch_size` < 1.
            Previously an empty dataset or a negative batch size looped forever
            without ever yielding.
    """
    num_samples = len(data)
    if num_samples == 0:
        raise ValueError("data_generator needs at least one sample")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    indices = np.arange(num_samples)
    epoch = 0
    while True:
        if shuffle:
            # Seeded by epoch for reproducibility; shuffles the previous epoch's
            # order in place.
            rng = np.random.default_rng(seed=epoch)
            rng.shuffle(indices)
        for start_idx in range(0, num_samples, batch_size):
            excerpt = indices[start_idx:start_idx + batch_size]
            inputs_np, outputs_np, task_ids = [], [], []
            for idx in excerpt:
                input_grid, output_grid, task_index = data[idx]
                inputs_np.append(pad_grid(input_grid, max_size, pad_value, center_offset))
                outputs_np.append(pad_grid(output_grid, max_size, pad_value, center_offset))
                task_ids.append(task_index)
            # Stack on the host, then do a single host->device transfer per array
            # rather than one per sample.
            batch_inputs = jnp.asarray(np.stack(inputs_np))
            batch_outputs = jnp.asarray(np.stack(outputs_np))
            batch_task_indices = jnp.asarray(np.asarray(task_ids, dtype=np.int32))
            if augmentations:
                # Keys depend only on (epoch, batch start), so runs are reproducible.
                rng_key = jax.random.fold_in(jax.random.PRNGKey(epoch), start_idx)
                rng_keys = jax.random.split(rng_key, len(excerpt))
                batch_inputs, batch_outputs = apply_random_augmentations_batch(
                    batch_inputs, batch_outputs, rng_keys, augmentations)
            yield batch_inputs, batch_outputs, batch_task_indices
        epoch += 1  # new seed for the next epoch

# Flatten the ARC tasks into (input, output, task_index) samples. Only the
# training task-id mapping is kept.
train_data, task_id_to_index = process_tasks(train_challenges, train_solutions)
eval_data, _ = process_tasks(eval_challenges, eval_solutions)

# Geometric augmentations applied at random to training batches.
augmentations = [
    'flip_horizontal',
    'flip_vertical',
    'rotate_90',
    'rotate_180',
    'rotate_270',
]

batch_size = 32  # samples per batch; tune to the available device memory

# Infinite batch streams: training batches are shuffled and augmented,
# evaluation batches keep dataset order and are not augmented.
train_generator = data_generator(
    train_data, batch_size=batch_size, augmentations=augmentations,
    center_offset=(0, 0), shuffle=True,
)
eval_generator = data_generator(
    eval_data, batch_size=batch_size, augmentations=None,
    center_offset=(0, 0), shuffle=False,
)

# Example usage: time and profile the generation of one training batch.

# Warm-up run so one-time JIT compilation is excluded from the timing below.
_ = next(train_generator)

# Snapshot memory *before* starting the clock: get_device_memory() may spawn an
# nvidia-smi subprocess, which must not be counted as batch-generation time.
process = psutil.Process()
initial_memory = process.memory_info().rss  # in bytes
total_ram = psutil.virtual_memory().total  # in bytes
initial_gpu_memory_used, total_gpu_memory = get_device_memory()  # in MB; either may be None

start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
train_batch_inputs, train_batch_outputs, train_batch_task_indices = next(train_generator)
# JAX dispatches asynchronously: wait on every returned array, not only the
# inputs, before stopping the clock.
jax.block_until_ready((train_batch_inputs, train_batch_outputs, train_batch_task_indices))
end_time = time.perf_counter()

# Record final memory usage.
final_memory = process.memory_info().rss  # in bytes
final_gpu_memory_used, _ = get_device_memory()

# Compute time and memory differences.
time_taken = end_time - start_time
memory_used = final_memory - initial_memory  # in bytes

print("Train batch inputs shape:", train_batch_inputs.shape)
print("Train batch outputs shape:", train_batch_outputs.shape)
print("Train batch task indices shape:", train_batch_task_indices.shape)
print(f"Time taken to generate one batch: {time_taken:.4f} seconds")

# Resident memory of this process as a share of total RAM.
ram_utilization = (final_memory / total_ram) * 100
print(f"RAM utilization after batch generation: {ram_utilization:.2f}%")

# Growth of resident memory during the batch, as a share of total RAM.
ram_used_percent = (memory_used / total_ram) * 100
print(f"RAM used to generate one batch: {memory_used / (1024 * 1024):.2f} MB ({ram_used_percent:.6f}% of total RAM)")

# get_device_memory() can fall back to host RAM when no accelerator probe works,
# so these figures are labelled "device" rather than "GPU" memory. A zero/None
# total is treated as unavailable to avoid dividing by zero.
if initial_gpu_memory_used is not None and final_gpu_memory_used is not None and total_gpu_memory:
    gpu_memory_used = final_gpu_memory_used - initial_gpu_memory_used
    gpu_utilization = (final_gpu_memory_used / total_gpu_memory) * 100
    gpu_used_percent = (gpu_memory_used / total_gpu_memory) * 100
    print(f"Device memory used to generate one batch: {gpu_memory_used:.2f} MB ({gpu_used_percent:.6f}% of total device memory)")
    print(f"Device memory utilization after batch generation: {gpu_utilization:.2f}%")
else:
    print("Device memory usage information is not available.")
