# TTM
*Attempt 2*

In [9]:
import pandas as pd
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from tsfm_public import TimeSeriesPreprocessor, get_datasets
from tsfm_public.toolkit.get_model import get_model
import torch.nn as nn
from torch import round, no_grad, FloatTensor
import numpy as np
import json
from sklearn.model_selection import train_test_split

In [2]:
SEED = 13
set_seed(SEED)

## TTM Class

In [3]:
def prepare_ttm_data(data, context_length=512, forecast_length=96):
    # Convert JSON data to DataFrame format
    processed_data = []
    
    for plan_data in data:
        # Extract state sequences
        states = plan_data['plan']
        
        # Create a row for each state in the plan
        for state in states:
            row = {
                'timestamp': len(processed_data),  # Use index as timestamp
                **state  # Unpack state variables
            }
            processed_data.append(row)
    
    df = pd.DataFrame(processed_data)
    
    # Define column specifications for TTM
    column_specifiers = {
        "timestamp_column": "timestamp",
        "id_columns": [],  # No separate time series IDs
        "target_columns": ["V1", "V21", "V22", "V31", "V32"],
        "control_columns": []
    }
    
    # Split configuration
    total_length = len(df)
    split_config = {
        "train": [0, int(0.7 * total_length)],
        "valid": [int(0.7 * total_length), int(0.85 * total_length)],
        "test": [int(0.85 * total_length), total_length]
    }
    
    # Create TTM preprocessor
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard"
    )
    
    return get_datasets(tsp, df, split_config)

In [4]:
def create_ttm_planner(context_length=512, forecast_length=96):    
    # Get base TTM model
    model = get_model(
        "ibm-granite/granite-timeseries-ttm-r2",
        context_length=context_length,
        prediction_length=forecast_length,
        head_dropout=0.3
    )
    
    # Modify for discrete state prediction
    # Add discretization layer or post-processing
    class TTMPlanner(nn.Module):
        def __init__(self, base_model):
            super().__init__()
            self.base_model = base_model
            self.discretize = nn.Linear(forecast_length, 5)  # 5 state variables
            
        def forward(self, x):
            # Get TTM predictions
            base_output = self.base_model(x)
            # Discretize outputs to valid state values
            discrete_states = round(self.discretize(base_output))
            return discrete_states
            
    return TTMPlanner(model)

In [5]:
def generate_plan(model, initial_state, goal_state):
    current_state = list(initial_state.values())
    goal = list(goal_state.values())
    plan = [initial_state]
    
    # Generate plan steps until we reach goal or max steps
    max_steps = 10
    for _ in range(max_steps):
        # Prepare input - concatenate current state and goal
        input_vector = np.array(current_state + goal)  # Convert to numpy array first
        x = FloatTensor(input_vector).unsqueeze(0)  # Add batch dimension
        
        # Get model prediction
        with no_grad():
            next_state = model(x).numpy()[0]
            
        # Round to nearest valid state values
        next_state = np.round(next_state).astype(int)
        
        # Convert to dict format
        next_state_dict = {
            'V1': next_state[0],
            'V21': next_state[1],
            'V22': next_state[2],
            'V31': next_state[3],
            'V32': next_state[4]
        }
        
        plan.append(next_state_dict)
        
        # Check if we reached goal
        if np.array_equal(next_state, goal):
            break
            
        current_state = next_state.tolist()  # Convert numpy array to list for next iteration
    
    return plan

## Evaluation

In [6]:
def check_state_validity(state):
    """
    Check if a state follows all Blocks World constraints based on the encoding:
    V1: What's on table (0=None, 1=A, 2=B, 3=A,B)
    V21: What's below A (0=None, 1=B, 2=Table)
    V22: What's on top of A (0=None, 1=B)
    V31: What's below B (0=None, 1=A, 2=Table)
    V32: What's on top of B (0=None, 1=A)
    """
    # Initialize validity flags
    valid = True
    reasons = []

    # 1. Range checks
    if not (0 <= state['V1'] <= 3):
        valid = False
        reasons.append("V1 out of range [0-3]")
    if not (0 <= state['V21'] <= 2):
        valid = False
        reasons.append("V21 out of range [0-2]")
    if not (0 <= state['V22'] <= 1):
        valid = False
        reasons.append("V22 out of range [0-1]")
    if not (0 <= state['V31'] <= 2):
        valid = False
        reasons.append("V31 out of range [0-2]")
    if not (0 <= state['V32'] <= 1):
        valid = False
        reasons.append("V32 out of range [0-1]")

    # 2. Table consistency (V1)
    # If A is on table (V21=2), V1 should include A (V1=1 or V1=3)
    if state['V21'] == 2 and state['V1'] not in [1, 3]:
        valid = False
        reasons.append("A on table but V1 doesn't reflect this")
    
    # If B is on table (V31=2), V1 should include B (V1=2 or V1=3)
    if state['V31'] == 2 and state['V1'] not in [2, 3]:
        valid = False
        reasons.append("B on table but V1 doesn't reflect this")

    # 3. Block A position consistency
    # If A is on B (V21=1), then B must have A on top (V32=1)
    if state['V21'] == 1 and state['V32'] != 1:
        valid = False
        reasons.append("A is on B but B doesn't have A on top")

    # 4. Block B position consistency
    # If B is on A (V31=1), then A must have B on top (V22=1)
    if state['V31'] == 1 and state['V22'] != 1:
        valid = False
        reasons.append("B is on A but A doesn't have B on top")

    # 5. Single position constraints
    # A block can't be in multiple positions simultaneously
    if state['V21'] == 1 and state['V21'] == 2:  # A can't be both on B and table
        valid = False
        reasons.append("A can't be both on B and table")
    if state['V31'] == 1 and state['V31'] == 2:  # B can't be both on A and table
        valid = False
        reasons.append("B can't be both on A and table")

    # 6. Mutual exclusion
    # A and B can't be on top of each other simultaneously
    if state['V22'] == 1 and state['V32'] == 1:
        valid = False
        reasons.append("A and B can't be on top of each other simultaneously")

    # 7. Support consistency
    # If nothing is below A (V21=0), it can't have anything on top (V22=0)
    if state['V21'] == 0 and state['V22'] != 0:
        valid = False
        reasons.append("A has no support but has something on top")
    
    # If nothing is below B (V31=0), it can't have anything on top (V32=0)
    if state['V31'] == 0 and state['V32'] != 0:
        valid = False
        reasons.append("B has no support but has something on top")

    return {
        'valid': valid,
        'reasons': reasons if not valid else ["State is valid"]
    }


def evaluate_plan_validity(predicted_plan):
    """Evaluate if each state transition is valid"""
    valid_states = sum(check_state_validity(state)['valid'] for state in predicted_plan)
    return {
        'valid_states_ratio': valid_states / len(predicted_plan),
        'completely_valid': valid_states == len(predicted_plan)
    }


def evaluate_goal_achievement(predicted_plan, goal_state):
    """Check if plan reaches the goal state"""
    final_state = predicted_plan[-1]
    goal_reached = all(final_state[k] == goal_state[k] for k in goal_state)
    
    if goal_reached:
        steps_to_goal = len(predicted_plan) - 1
    else:
        steps_to_goal = float('inf')
        
    return {
        'goal_reached': goal_reached,
        'steps_to_goal': steps_to_goal,
        # 'final_state_similarity': sum(final_state[k] == goal_state[k]  for k in goal_state) / len(goal_state)
    }

In [7]:
def evaluate_test_set(model, test_data):
    results = []
    
    for test_case in test_data:
        initial_state = test_case['initial_state']
        goal_state = test_case['goal_state']
        reference_plan = test_case['plan']
        
        # Generate plan using model
        predicted_plan = generate_plan(model, initial_state, goal_state)
        
        # Calculate metrics
        goal_achievement = evaluate_goal_achievement(predicted_plan, goal_state)
        plan_validation = evaluate_plan_validity(predicted_plan)
        
        results.append({
            'goal_reached': goal_achievement['goal_reached'],  # Bool
            'steps_to_goal': goal_achievement['steps_to_goal'],  # Float
            'valid_states_ratio': plan_validation['valid_states_ratio'],  # Float
            'completely_valid': plan_validation['completely_valid'],  # Bool
            'reference_length': len(reference_plan),  # Float
        })
    
    # Calculate aggregate metrics for all results
    return {
        "success_rate" : sum(r['goal_reached'] for r in results) / len(results),
        "avg_steps_to_goal" : np.mean([r['steps_to_goal'] for r in results]),
        "avg_valid_states_ratio" : np.mean([r['valid_states_ratio'] for r in results]),
        "valid_rate" : sum(r['completely_valid'] for r in results) / len(results),
        "avg_reference_length" : np.mean([r['reference_length'] for r in results]),
        'detailed_results': results,
    }

In [8]:
def train_ttm_planner(train_data, test_data, context_length=512, forecast_length=96):
    # Prepare data
    dset_train, dset_valid, dset_test = prepare_ttm_data(
        train_data, 
        context_length, 
        forecast_length
    )
    
    # Create model
    model = create_ttm_planner(context_length, forecast_length)
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir="ttm_planning_model",
        learning_rate=1e-4,
        num_train_epochs=50,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss"
    )
    
    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dset_train,
        eval_dataset=dset_valid,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )
    
    # Train
    trainer.train()
    
    # Evaluate
    results = evaluate_test_set(trainer.model, test_data)
    return trainer.model, results

### Dataset Preparation

In [10]:
def load_and_split_data(json_file, test_size=0.2, random_state=42):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    # Split into train and test
    train_data, test_data = train_test_split(
        data, 
        test_size=test_size, 
        random_state=random_state
    )
    
    return train_data, test_data

In [11]:
def prepare_sequences(data):
    # Convert each plan into sequences
    X = []  # Input sequences
    y = []  # Next states
    
    for plan_data in data:
        plan = plan_data['plan']
        goal = plan_data['goal_state']
        
        # For each state except the last one
        for i in range(len(plan)-1):
            # Current state + goal state as input
            current = list(plan[i].values())
            goal_state = list(goal.values())
            X.append(current + goal_state)  # Concatenate current and goal
            
            # Next state as target
            next_state = list(plan[i+1].values())
            y.append(next_state)
    
    return np.array(X), np.array(y)

### Execution

In [None]:
train_data, test_data = load_and_split_data('dataset.json')
print(f"Training set size: {len(train_data)}")
print(f"Test set size: {len(test_data)}")

Training set size: 40
Test set size: 10


In [None]:
X_train, y_train = prepare_sequences(train_data)

In [None]:
model = train_model(X_train, y_train)

Epoch 0, Loss: 1.9114
Epoch 10, Loss: 1.5399
Epoch 20, Loss: 0.6199
Epoch 30, Loss: 0.3702
Epoch 40, Loss: 0.3317
Epoch 50, Loss: 0.2879
Epoch 60, Loss: 0.2234
Epoch 70, Loss: 0.1398
Epoch 80, Loss: 0.0740
Epoch 90, Loss: 0.0474


In [None]:
test_results = evaluate_test_set(model, test_data)

In [None]:
print("\nTest Set Evaluation Results:")
print(f"Total cases: {len(test_data)}")
print(f"Success Rate: {test_results['success_rate']:.2%}")
print(f"Average Steps: {test_results['avg_steps_to_goal']:.2f}")
print(f"Average Valid States Ratio: {test_results['avg_valid_states_ratio']:.2%}")
print(f"Valid Rate: {test_results['valid_rate']:.2%}")
print(f"Average Reference Length: {test_results['avg_reference_length']:.2f}")


Test Set Evaluation Results:
Total cases: 10
Success Rate: 30.00%
Average Steps: inf
Average Valid States Ratio: 66.36%
Valid Rate: 30.00%
Average Reference Length: 2.40


In [None]:
failed_cases = [r for r in test_results['detailed_results'] if not r['goal_reached']]
if failed_cases:
    print(f"\nNumber of failed cases: {len(failed_cases)}")
    print("Sample failed case metrics:")
    pprint(failed_cases[0])


Number of failed cases: 7
Sample failed case metrics:
{'completely_valid': False,
 'goal_reached': False,
 'reference_length': 3,
 'steps_to_goal': inf,
 'valid_states_ratio': 0.5454545454545454}
