# BlocksWorld TTM

This notebook applies the IBM Granite TinyTimeMixer (TTM) model to the BlocksWorld planning domain, using the **2nd Encoding** format.

Key modifications from standard TTM:

- Input format includes goal state concatenated with current state
- Binary state prediction instead of continuous values
- Custom metrics for planning success
- Sequence padding to handle variable-length plans


In [1]:
import json
import math
import os
from dataclasses import dataclass, asdict
from typing import List, Optional, Dict, Any
import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from pprint import pformat

from tsfm_public import TrackingCallback
from tsfm_public.toolkit.get_model import get_model

from BlocksWorld import BlocksWorldGenerator

In [2]:
# Constants
SEED = 13
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
set_seed(SEED)  # transformers' helper, called once so later cells are reproducible

# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
)

print(f"Using: {pformat([SEED, TTM_MODEL_PATH, DEVICE])}")

Using: [13, 'ibm-granite/granite-timeseries-ttm-r2', device(type='mps')]


## Helper Classes


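### Configuration and Sample Dataclasses

`ModelConfig` holds the model and training hyperparameters; `BlocksWorldSample` stores an individual planning example.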


In [3]:
@dataclass
class ModelConfig:
    """Hyperparameters for building and fine-tuning the BlocksWorld TTM model."""

    # Number of past time steps fed to TTM; the default of 1 is only a placeholder.
    context_length: int = 1
    # Number of future time steps the model forecasts.
    prediction_length: int = 96
    # Learning rate for the AdamW optimizer and the peak of the OneCycle schedule.
    learning_rate: float = 1e-4
    # Per-device batch size, used for both training and evaluation.
    batch_size: int = 32
    # Upper bound on training epochs (early stopping may end training sooner).
    num_epochs: int = 50
    # Length of one state vector; left as None until training fills it in.
    state_dim: Optional[int] = None

In [4]:
@dataclass
class BlocksWorldSample:
    """One planning example, one field per key of a plan entry in the dataset JSON."""

    # Encoded state the plan starts from.
    initial_state: List[int]
    # Encoded state the plan should reach.
    goal_state: List[int]
    # Encoded states along the plan, one vector per step.
    plan: List[List[int]]
    # The plan's actions, each stored as a list of strings.
    actions: List[List[str]]
    # Presumably one name per entry of a state vector — confirm against the generator.
    feature_names: List[str]

### Custom BlocksWorld Dataset Class

The class handles:

- Loading JSON plan data
- Padding sequences to match context length
- Combining state and goal information
- Converting to appropriate tensor format


In [None]:
class BlocksWorldDataset(Dataset):
    """BlocksWorld plans loaded from a JSON file and shaped as TTM forecasting samples.

    Each sample is a dict of tensors (moved to ``DEVICE``):

    * ``past_values`` / ``past_observed_mask``: the plan's starting state
      (``initial_state``) repeated over ``context_length`` steps, fully observed.
      This is the same context that ``BlocksWorldTTM.predict`` and the evaluation
      helpers build at inference time, so training and inference see the same inputs.
    * ``future_values`` / ``future_observed_mask``: the plan states that follow the
      starting state, truncated to ``prediction_length`` and padded with the goal
      state. The mask is 1 for real plan steps and 0 for goal padding.
    * ``static_categorical_values``: the goal state.
    * ``freq_token``: a single zero (long).

    The plan is deliberately not split between context and forecast. When a plan is
    no longer than the context (e.g. 21-step plans against TTM's smallest 52-step
    context), such a split would leave the forecast window without any plan step. It
    would also place goal padding at the most recent end of the context.
    """

    def __init__(self, data_path: str, context_length: int, prediction_length: int):
        """Load every entry of the file's top-level ``"plans"`` list.

        Raises:
            ValueError: if the file contains no plans.
            KeyError: if a plan entry lacks one of the ``BlocksWorldSample`` keys.
        """
        self.context_length: int = context_length
        self.prediction_length: int = prediction_length
        self.device = DEVICE

        with open(data_path, "r") as f:
            raw_data = json.load(f)["plans"]

        self.samples: List[BlocksWorldSample] = [
            BlocksWorldSample(
                initial_state=item["initial_state"],
                goal_state=item["goal_state"],
                plan=item["plan"],
                actions=item["actions"],
                feature_names=item["feature_names"],
            )
            for item in raw_data
        ]
        if not self.samples:
            raise ValueError(f"No plans found in {data_path}")

        # All state vectors are assumed to share the first sample's dimensionality.
        self.state_dim: int = len(self.samples[0].initial_state)

    def __len__(self) -> int:
        """Number of plans in the file."""
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        initial_state_np = np.asarray(sample.initial_state, dtype=np.float32)  # (state_dim,)
        goal_state_np = np.asarray(sample.goal_state, dtype=np.float32)  # (state_dim,)
        # reshape keeps an empty plan as a (0, state_dim) array instead of shape (0,).
        plan_states_np = np.asarray(sample.plan, dtype=np.float32).reshape(-1, self.state_dim)

        # Context: the starting state repeated over the whole window, all observed.
        past_values_np = np.tile(initial_state_np, (self.context_length, 1))
        past_observed_mask_np = np.ones((self.context_length, self.state_dim), dtype=np.float32)

        # Targets: plan states after the starting state. If the plan itself begins with
        # the initial state, skip that step so the model is not asked to echo its input.
        starts_with_initial = len(plan_states_np) > 0 and np.array_equal(plan_states_np[0], initial_state_np)
        first_step = 1 if starts_with_initial else 0
        future_steps = plan_states_np[first_step : first_step + self.prediction_length]
        num_future = len(future_steps)

        # Pad the rest of the horizon with the goal state.
        future_values_np = np.tile(goal_state_np, (self.prediction_length, 1))
        future_values_np[:num_future] = future_steps

        # Design choice for the goal padding: mask 0.0 ignores it in the loss; set the
        # padded rows to 1.0 instead to make the model explicitly learn to hold the goal.
        future_observed_mask_np = np.zeros((self.prediction_length, self.state_dim), dtype=np.float32)
        future_observed_mask_np[:num_future] = 1.0

        # NOTE(review): the goal reaches the model only through static_categorical_values;
        # confirm the loaded TTM config actually consumes it, otherwise the goal is invisible.
        static_categorical_values_np = goal_state_np

        def to_device_tensor(array: np.ndarray) -> torch.Tensor:
            return torch.tensor(array, dtype=torch.float32).to(self.device)

        return {
            "past_values": to_device_tensor(past_values_np),
            "future_values": to_device_tensor(future_values_np),
            "past_observed_mask": to_device_tensor(past_observed_mask_np),
            "future_observed_mask": to_device_tensor(future_observed_mask_np),
            "static_categorical_values": to_device_tensor(static_categorical_values_np),
            "freq_token": torch.zeros(1, dtype=torch.long).to(self.device),
        }

### BlocksWorld-Based TTM Class

A wrapper around the TTM model that handles training, prediction, and saving/loading checkpoints.


In [None]:
class BlocksWorldTTM:
    """Train, run, save and load a TTM forecaster for BlocksWorld state sequences.

    Binary states are read off the model's raw forecasts with a fixed threshold
    (``PREDICTION_THRESHOLD``). Training regresses the raw outputs directly onto
    0/1 target vectors, with no sigmoid involved, so the decision boundary lies at
    0.5 on the raw output. A sigmoid followed by rounding would move that boundary
    to 0 and turn near-zero outputs such as 0.1 into 1.
    """

    # Decision boundary between "0" and "1" on the model's raw output.
    PREDICTION_THRESHOLD: float = 0.5

    def __init__(self, model_config: Optional[ModelConfig] = None, **config_overrides: Any):
        """Create an untrained wrapper.

        Args:
            model_config: Hyperparameters to use, stored as-is. When omitted, a
                ``ModelConfig`` is built from ``config_overrides``.
            **config_overrides: ``ModelConfig`` field values such as
                ``context_length=52``. When ``model_config`` is also given, they
                override its fields in a copy.

        Raises:
            TypeError: if an override is not a ``ModelConfig`` field.
        """
        if model_config is None:
            model_config = ModelConfig(**config_overrides)
        elif config_overrides:
            model_config = ModelConfig(**{**asdict(model_config), **config_overrides})
        self.config = model_config
        self.device = DEVICE
        self.model = None
        self.trainer = None
        self.model_name = None

    def train(self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None):
        """Fine-tune a freshly loaded TTM checkpoint on ``train_dataset``.

        Args:
            train_dataset: A ``BlocksWorldDataset``, or a ``Subset`` of one as returned
                by ``random_split``. Its ``state_dim`` is recorded in the config.
            val_dataset: Optional validation split. When given and non-empty, the
                model is evaluated every epoch, the best checkpoint by ``eval_loss``
                is restored at the end, and early stopping (patience 5) is enabled.
        """
        # A Subset wraps the real dataset in ``.dataset``; read state_dim from whichever holds it.
        base_dataset = train_dataset.dataset if hasattr(train_dataset, "dataset") else train_dataset
        self.config.state_dim = base_dataset.state_dim
        # Truthiness rather than ``is not None``, so an empty split also disables evaluation.
        has_validation = bool(val_dataset)

        # Initialize model
        self.model = get_model(
            TTM_MODEL_PATH,
            context_length=self.config.context_length,
            prediction_length=self.config.prediction_length,
            head_dropout=0.1,
        ).to(self.device)

        # Same lookup again, this time only to obtain the checkpoint key for naming.
        self.model_name = get_model(
            TTM_MODEL_PATH,
            context_length=self.config.context_length,
            prediction_length=self.config.prediction_length,
            head_dropout=0.1,
            return_model_key=True,
        )
        print(f"Received model name: {self.model_name}")

        training_args = TrainingArguments(
            output_dir="blocks_world_ttm",
            learning_rate=self.config.learning_rate,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            eval_strategy="epoch" if has_validation else "no",
            save_strategy="epoch",
            load_best_model_at_end=has_validation,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            seed=SEED,
            report_to="none",
            dataloader_pin_memory=False,
        )

        callbacks = [TrackingCallback()]
        if has_validation:
            # EarlyStoppingCallback requires per-epoch evaluation and
            # load_best_model_at_end=True, so it is only valid with a validation set.
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=5))

        # Optimizer and a OneCycle schedule sized to one step per training batch.
        optimizer = AdamW(self.model.parameters(), lr=self.config.learning_rate)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.config.learning_rate,
            epochs=self.config.num_epochs,
            steps_per_epoch=math.ceil(len(train_dataset) / self.config.batch_size),
        )

        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=callbacks,
            optimizers=(optimizer, scheduler),
        )

        self.trainer.train()

    def predict(self, initial_states: torch.Tensor, goal_states: torch.Tensor) -> torch.Tensor:
        """Forecast binary state sequences starting from ``initial_states``.

        Args:
            initial_states: Batch of starting states, presumably (batch, state_dim).
                Each is repeated over the full context window.
            goal_states: Batch of goal states, passed as ``static_categorical_values``.

        Returns:
            The model's forecast binarised to 0.0/1.0 with ``PREDICTION_THRESHOLD``.
            Expected shape is (batch, prediction_length, state_dim); confirm this
            against the TTM output.

        Raises:
            RuntimeError: if no model has been trained or loaded.
        """
        if self.model is None:
            raise RuntimeError("Model needs to be trained or loaded before prediction")

        self.model.eval()
        with torch.no_grad():
            batch_size = initial_states.shape[0]

            # Context: each initial state repeated over the whole context window.
            context_sequence = initial_states.unsqueeze(1).repeat(1, self.config.context_length, 1)

            inputs = {
                "past_values": context_sequence.to(self.device),
                "past_observed_mask": torch.ones_like(context_sequence).to(self.device),
                "static_categorical_values": goal_states.to(self.device),
                "freq_token": torch.zeros(batch_size, dtype=torch.long).to(self.device),
            }

            outputs = self.model(**inputs)
            # Threshold the raw regression output; see the class docstring for why no sigmoid.
            predictions = (outputs[0] > self.PREDICTION_THRESHOLD).float()

        return predictions

    def save(self, path: str):
        """Write ``model.pt`` (state dict) and ``config.json`` (the ModelConfig) into ``path``.

        Raises:
            RuntimeError: if there is no model to save.
        """
        if self.model is None:
            raise RuntimeError("No model to save. Train or load a model first.")

        os.makedirs(path, exist_ok=True)

        model_path = os.path.join(path, "model.pt")
        torch.save(self.model.state_dict(), model_path)

        config_path = os.path.join(path, "config.json")
        with open(config_path, "w") as f:
            json.dump(asdict(self.config), f)

        print(f"Model saved to {path}")

    @classmethod
    def load(cls, path: str) -> "BlocksWorldTTM":
        """Rebuild a wrapper, including its weights, from a directory written by ``save``.

        Raises:
            FileNotFoundError: if ``config.json`` or ``model.pt`` is missing.
        """
        config_path = os.path.join(path, "config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        # ``save`` writes asdict(ModelConfig), so the keys map 1:1 onto the dataclass
        # fields, including state_dim.
        instance = cls(ModelConfig(**config_dict))

        instance.model = get_model(
            TTM_MODEL_PATH,
            context_length=instance.config.context_length,
            prediction_length=instance.config.prediction_length,
            head_dropout=0.1,
        ).to(instance.device)

        model_path = os.path.join(path, "model.pt")
        instance.model.load_state_dict(torch.load(model_path, map_location=instance.device))
        instance.model.eval()

        print(f"Model loaded from {path}")
        return instance

### Helper Functions


In [7]:
def prepare_datasets(data_path: str, context_length: int, prediction_length: int):
    """Load the BlocksWorld dataset and split it 70% / 15% / 15% into train/val/test.

    The split is reproducible because ``random_split`` draws from a dedicated
    generator seeded with ``SEED``. The test split absorbs any rounding remainder.

    Returns:
        A ``(train, val, test)`` tuple of ``Subset`` objects.
    """
    dataset = BlocksWorldDataset(data_path, context_length, prediction_length)

    n_total = len(dataset)
    n_train = int(0.7 * n_total)
    n_val = int(0.15 * n_total)
    split_sizes = [n_train, n_val, n_total - n_train - n_val]

    split_rng = torch.Generator().manual_seed(SEED)
    return tuple(torch.utils.data.random_split(dataset, split_sizes, generator=split_rng))

In [None]:
def _predict_binary(model, sample, threshold: float = 0.5) -> torch.Tensor:
    """Forecast one dataset sample with ``model`` and binarise the result.

    The context is the sample's first past value (its starting state) repeated over
    ``model.config.context_length`` steps, and the goal is passed as
    ``static_categorical_values``. These are the same inputs ``BlocksWorldTTM.predict``
    builds.

    The raw forecast is thresholded at ``threshold``. Training regresses the raw
    outputs onto 0/1 targets, so 0.5 is the decision boundary. Applying a sigmoid
    before rounding would move the boundary to 0 and turn outputs like 0.1 into 1.

    Returns:
        A 0.0/1.0 tensor for a batch of one; ``[0, -1]`` selects the final forecast step.
    """
    initial_state = sample["past_values"][0]
    goal_state = sample["static_categorical_values"]

    context_sequence = initial_state.reshape(1, 1, -1).repeat(1, model.config.context_length, 1)
    inputs = {
        "past_values": context_sequence.to(model.device),
        "past_observed_mask": torch.ones_like(context_sequence).to(model.device),
        "static_categorical_values": goal_state.unsqueeze(0).to(model.device),
        "freq_token": torch.zeros(1, dtype=torch.long).to(model.device),
    }

    outputs = model.model(**inputs)
    return (outputs[0] > threshold).float()


def evaluate_model(model, test_dataset, verbose=True):
    """Score how often the model's final forecast step matches the target's final step.

    Args:
        model: A ``BlocksWorldTTM``-like object exposing ``.model``,
            ``.config.context_length`` and ``.device``.
        test_dataset: Indexable samples as produced by ``BlocksWorldDataset``.
        verbose: Print a summary when True.

    Returns:
        Dict with ``num_samples``, ``num_exact_matches``, ``exact_match_rate``,
        ``num_partial_matches`` (more than 50% of bits correct), ``partial_match_rate``
        and ``bit_accuracy``.

    Raises:
        ValueError: if ``test_dataset`` is empty.
    """
    num_samples = len(test_dataset)
    if num_samples == 0:
        raise ValueError("test_dataset is empty; nothing to evaluate")

    model.model.eval()
    num_exact_matches = 0
    num_partial_matches = 0
    total_bits_correct = 0
    total_bits = 0

    with torch.no_grad():
        # Index-based access: torch Datasets/Subsets guarantee __getitem__ + __len__.
        for i in range(num_samples):
            sample = test_dataset[i]

            # Only the final step is scored: it should be the goal state.
            pred_goal = _predict_binary(model, sample)[0, -1]
            true_goal = sample["future_values"][-1]

            if torch.all(pred_goal == true_goal):
                num_exact_matches += 1

            num_correct_bits = int(torch.sum(pred_goal == true_goal).item())
            num_bits = pred_goal.numel()
            total_bits_correct += num_correct_bits
            total_bits += num_bits

            # Partial match: strictly more than half of the bits are right.
            if num_correct_bits > num_bits / 2:
                num_partial_matches += 1

    metrics = {
        "num_samples": num_samples,
        "num_exact_matches": num_exact_matches,
        "exact_match_rate": num_exact_matches / num_samples,
        "num_partial_matches": num_partial_matches,
        "partial_match_rate": num_partial_matches / num_samples,
        "bit_accuracy": total_bits_correct / total_bits,
    }

    if verbose:
        print("\nDetailed Model Evaluation Metrics:")
        print("-" * 50)
        print(f"Total number of test samples: {metrics['num_samples']}")
        print(f"Number of exact goal state matches: {metrics['num_exact_matches']}")
        print(f"Exact match rate: {metrics['exact_match_rate']:.4f}")
        print(f"Number of partial matches (>50% correct): {metrics['num_partial_matches']}")
        print(f"Partial match rate: {metrics['partial_match_rate']:.4f}")
        print(f"Bit-level accuracy: {metrics['bit_accuracy']:.4f}")

    return metrics


def analyze_error_patterns(model, test_dataset, verbose=True):
    """Split predictions into successes/failures and count errors per bit position.

    A case is a success when the final forecast step equals the target's final step
    bit for bit.

    Args:
        model: A ``BlocksWorldTTM``-like object exposing ``.model``,
            ``.config.context_length`` and ``.device``.
        test_dataset: Indexable samples as produced by ``BlocksWorldDataset``.
        verbose: Print a summary, including the five most error-prone bits, when True.

    Returns:
        Dict with ``num_successes``, ``num_failures``, ``success_rate``,
        ``bit_error_counts`` ({bit index: error count}) and the ``successes`` and
        ``failures`` case lists. Each case holds numpy copies of the initial, goal,
        predicted and target states plus the error count and positions.

    Raises:
        ValueError: if ``test_dataset`` is empty.
    """
    num_cases = len(test_dataset)
    if num_cases == 0:
        raise ValueError("test_dataset is empty; nothing to analyze")

    model.model.eval()
    successes = []
    failures = []
    bit_error_counts = {}  # bit index -> number of cases where that bit was wrong

    with torch.no_grad():
        for i in range(num_cases):
            sample = test_dataset[i]

            initial_state = sample["past_values"][0]
            goal_state = sample["static_categorical_values"]
            target = sample["future_values"][-1]
            predicted_goal = _predict_binary(model, sample)[0, -1]

            errors = (predicted_goal != target).nonzero().squeeze(1)
            num_errors = len(errors)

            for position in errors.tolist():
                bit_error_counts[position] = bit_error_counts.get(position, 0) + 1

            case = {
                "initial_state": initial_state.cpu().numpy(),
                "goal_state": goal_state.cpu().numpy(),
                "predicted_goal": predicted_goal.cpu().numpy(),
                "target_goal": target.cpu().numpy(),
                "num_errors": num_errors,
                "error_positions": errors.cpu().numpy(),
            }
            (successes if num_errors == 0 else failures).append(case)

    analysis = {
        "num_successes": len(successes),
        "num_failures": len(failures),
        "success_rate": len(successes) / num_cases,
        "bit_error_counts": bit_error_counts,
        "successes": successes,
        "failures": failures,
    }

    if verbose:
        print("\nError Pattern Analysis:")
        print("-" * 50)
        print(f"Number of successful predictions: {analysis['num_successes']}")
        print(f"Number of failed predictions: {analysis['num_failures']}")
        print(f"Success rate: {analysis['success_rate']:.4f}")
        print("\nMost common error positions:")
        sorted_errors = sorted(bit_error_counts.items(), key=lambda x: x[1], reverse=True)
        for bit, count in sorted_errors[:5]:
            print(f"Bit {bit}: {count} errors")

    return analysis

## Model Training


In [10]:
def analyze_dataset(data_path: str) -> Dict[str, Any]:
    """Summarise a BlocksWorld plan file and suggest a forecast length.

    Args:
        data_path: JSON file whose top-level ``"plans"`` list holds entries with
            (at least) ``"plan"`` and ``"initial_state"`` keys.

    Returns:
        Dict with the longest and mean plan length, the state dimension (taken from
        the first entry), the number of plans, and a recommended prediction length
        equal to the longest plan plus a 2-step buffer. The same figures are also
        printed.
    """
    with open(data_path, "r") as handle:
        plans = json.load(handle)["plans"]

    plan_lengths = [len(entry["plan"]) for entry in plans]
    max_plan_length = max(plan_lengths)
    avg_plan_length = sum(plan_lengths) / len(plans)
    state_dim = len(plans[0]["initial_state"])
    num_samples = len(plans)

    stats = {
        "max_plan_length": max_plan_length,
        "avg_plan_length": avg_plan_length,
        "state_dim": state_dim,
        "num_samples": num_samples,
        # Small buffer on top of the longest plan.
        "recommended_prediction_length": max_plan_length + 2,
    }

    report_lines = [
        "\nDataset Statistics:",
        f"Number of samples: {num_samples}",
        f"State dimension: {state_dim}",
        f"Maximum plan length: {max_plan_length}",
        f"Average plan length: {avg_plan_length:.2f}",
        f"Recommended prediction length: {stats['recommended_prediction_length']}",
    ]
    for line in report_lines:
        print(line)

    return stats

In [11]:
# Create datasets
dataset_file = "../data/dataset_6.json"
# The block count is the numeric suffix of the file name (dataset_<N>.json). Parse the
# whole suffix rather than its first character so that N >= 10 also works.
num_blocks = int(os.path.splitext(os.path.basename(dataset_file))[0].rsplit("_", 1)[-1])
print(f"Number of blocks in the dataset: {num_blocks}")

# Analyze dataset
stats = analyze_dataset(dataset_file)

Number of blocks in the dataset: 6

Dataset Statistics:
Number of samples: 44175
State dimension: 48
Maximum plan length: 21
Average plan length: 9.07
Recommended prediction length: 23


In [None]:
# TTM checkpoints exist only for fixed sets of context lengths (CL) and forecast
# lengths (FL). For each, take the largest supported value that does not exceed the
# requested one. This is a floor, not round-to-nearest: for CL, 100 -> 90 and
# 87 -> 52 (not 90). If every supported value is larger than the request, fall back
# to the smallest one.
# NOTE(review): flooring the FL can give a horizon shorter than the longest plan
# (e.g. a recommended 23 becomes 16 while the longest plan has 21 states).
supported_cls = [52, 90, 180, 360, 520, 1024, 1536]
supported_fls = [16, 30, 48, 60, 96, 192, 336, 720]


def _floor_to_supported(requested, supported):
    """Return the largest value in ``supported`` that is <= ``requested``, else ``min(supported)``."""
    not_larger = [value for value in supported if value <= requested]
    return max(not_larger) if not_larger else min(supported)


closest_cl = _floor_to_supported(stats["max_plan_length"], supported_cls)
print(f"Max plan length: {stats['max_plan_length']} | Selected context length: {closest_cl}")

closest_fl = _floor_to_supported(stats["recommended_prediction_length"], supported_fls)
print(
    f"Recommended prediction length: {stats['recommended_prediction_length']} | Selected forecast length: {closest_fl}"
)

Max plan length: 21 | Selected context length: 52
Recommended prediction length: 23 | Selected forecast length: 16


In [None]:
# 70/15/15 split using the context/forecast lengths selected above.
splits = prepare_datasets(dataset_file, context_length=closest_cl, prediction_length=closest_fl)
train_dataset, val_dataset, test_dataset = splits

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}")

Train size: 30922, Val size: 6626, Test size: 6627


In [20]:
# Initialize and train model. BlocksWorldTTM takes its hyperparameters as a ModelConfig.
ttm = BlocksWorldTTM(
    ModelConfig(
        context_length=closest_cl,
        prediction_length=closest_fl,
        learning_rate=1e-4,
        batch_size=16,
        num_epochs=20,
    )
)

print("Starting training...")
ttm.train(train_dataset, val_dataset)

INFO:p-45880:t-8430149376:get_model.py:get_model:Loading model from: ibm-granite/granite-timeseries-ttm-r2


Starting training...


INFO:p-45880:t-8430149376:get_model.py:get_model:Model loaded successfully from ibm-granite/granite-timeseries-ttm-r2, revision = 52-16-ft-r2.1.
INFO:p-45880:t-8430149376:get_model.py:get_model:[TTM] context_length = 52, prediction_length = 16
INFO:p-45880:t-8430149376:get_model.py:get_model:Loading model from: ibm-granite/granite-timeseries-ttm-r2


Received model name: 52-16-ft-r2.1


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Save & Load Model
save_path = f"../models/blocksworld-{num_blocks}_ttm-{ttm.model_name}"

# Ask before overwriting an existing checkpoint directory.
if os.path.exists(save_path):
    user_input = input(f"Path {save_path} already exists. Overwrite? (y/n): ")
    if user_input.strip().lower() != "y":
        # Keep asking until a non-empty path is entered: an empty path would make
        # os.makedirs("") inside ttm.save() raise FileNotFoundError.
        new_path = ""
        while not new_path:
            new_path = input("Enter a new path: ").strip()
        save_path = new_path
    else:
        print(f"Overwriting existing path: {save_path}")
else:
    os.makedirs(save_path, exist_ok=True)

# Save model
ttm.save(save_path)
print(f"Saved to {save_path}")

In [None]:
# Reload the checkpoint just written, replacing the in-memory model.
ttm = BlocksWorldTTM.load(path=save_path)
print(f"Loaded from {save_path}")

In [None]:
# Evaluate model
print("\nEvaluating model performance...")
metrics = evaluate_model(ttm, test_dataset)

# Analyze error patterns
print("\nAnalyzing error patterns...")
analysis = analyze_error_patterns(ttm, test_dataset)
successes = analysis["successes"]
failures = analysis["failures"]

# Decode a few cases back into block configurations to inspect the initial states,
# goal states and predictions by eye.
gen = BlocksWorldGenerator(num_blocks=num_blocks)


def _print_case(case_number, case, show_errors):
    """Print one analysed case; failures also show the target and the error count."""
    print(f"\nCase {case_number}:")
    print(f"Initial State: {gen.decode_vector_to_blocks(case['initial_state'])}")
    print(f"Goal State: {gen.decode_vector_to_blocks(case['goal_state'])}")
    print(f"Predicted Goal: {gen.decode_vector_to_blocks(case['predicted_goal'])}")
    if show_errors:
        print(f"Target Goal: {gen.decode_vector_to_blocks(case['target_goal'])}")
        print(f"Number of Errors: {case['num_errors']}")


print("\n--------------------------------------------------")
print(f"Length of the test dataset: {len(test_dataset)}")
print("\nExample Successes:")
for number, case in enumerate(successes[:3], start=1):
    _print_case(number, case, show_errors=False)

print("\nExample Failures:")
for number, case in enumerate(failures[:3], start=1):
    _print_case(number, case, show_errors=True)