# Exploring the effectiveness of chain-of-thought on out-of-distribution (OOD) generalization

We have both empirical and theoretical evidence that chain-of-thought improves LLM performance on certain tasks.
Anecdotally, OpenAI o1 is clearly better than GPT-4o on most tasks that we consider to require some level of complex
reasoning.

The goal of this experiment is to measure the effectiveness of chain-of-thought itself for OOD generalization.
We do this by fine-tuning GPT-2 to solve the countdown game via chain-of-thought,
then comparing its performance to a baseline GPT-2.
For both models, we also test performance when the numbers are made much larger, pushing each model to
generalize what it learned during training to similar but novel and more complex problems.

The hope is that we can learn something about the nature of chain-of-thought as a tool for OOD generalization.

In [1]:
# Necessary imports
import random
import pandas as pd
import json

# Set the random seed
random.seed(9001)

Before we start generating our countdown game data, let's clarify some terms:

**Generalization:** This is the ability to take our learned inductive biases and use them to come to novel conclusions about new information.
- Simple example: "The sun has risen every day I've been alive, so it will probably rise again tomorrow."
  - Of course, there are limits to induction. One day the sun will almost certainly die out and not rise. But the fact that we can derive this useful heuristic for predicting the future based on past experience is an example of the utility of generalization.

**In-Distribution:** When we train our models on data, the goal is to learn a distribution that maximizes the likelihood of the training data. So if data is in-distribution, it was generated by the same function that generated the training samples.

**Out-of-Distribution:** In contrast, data that was generated via a different function is considered out-of-distribution. A model's ability to maintain its predictive power on out-of-distribution tasks shows that the model is generalizing from whatever inductive heuristics it learned on its training data.
- Of course, there is a spectrum to how far out of distribution we can get. Creative writing would clearly be an out-of-distribution task for a model trained on simple addition, but the model failing to write short stories wouldn't tell us much, since the task is *so* far OOD. If instead we give it addition tasks significantly harder than those it was trained on and it still solves them, this suggests that the model has learned some fundamental representation of the nature of addition and is not simply memorizing and regurgitating the answers to questions it has already seen.
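
To make the distinction concrete for arithmetic expressions, here is a minimal sketch of an in-distribution check. It assumes that "in-distribution" means every operand falls in the training operand range and the parenthesis nesting depth is one of the training depths; the defaults `(1, 50)` and `(1, 2)` mirror the training settings used later in this notebook, and the helper names are illustrative rather than part of the notebook.

```python
import re

def operands(expr):
    """Extract the integer literals from a fully parenthesized expression."""
    return [int(tok) for tok in re.findall(r"\d+", expr)]

def nesting_depth(expr):
    """Return the maximum parenthesis nesting depth of the expression."""
    depth = max_depth = 0
    for ch in expr:
        if ch == "(":
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == ")":
            depth -= 1
    return max_depth

def is_in_distribution(expr, operand_range=(1, 50), depths=(1, 2)):
    """True if all operands and the nesting depth match the assumed training settings."""
    low, high = operand_range
    in_range = all(low <= n <= high for n in operands(expr))
    return in_range and nesting_depth(expr) in depths

print(is_in_distribution("((37 * 39) - (4 * 46))"))    # True
print(is_in_distribution("((120 * 39) - (4 * 46))"))   # False: 120 is outside 1..50
```

This only captures the two knobs varied in this experiment (operand size and depth); it says nothing about how far a given OOD example is from the training data.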



In [2]:
# Cell 2: Define a recursive function to generate an arithmetic expression and its evaluation chain-of-thought.
def generate_expr(depth, operand_range):
    """
    Recursively generate a fully parenthesized arithmetic expression along with:
      - its evaluated numerical result
      - a list of chain-of-thought (CoT) steps that show the intermediate computation.

    Parameters:
      depth (int): The recursion depth. A depth of 0 returns a single random number.
      operand_range (tuple): A tuple (low, high) specifying the range from which to sample integers.
    
    Returns:
      expr_str (str): The generated arithmetic expression as a string.
      result (int): The evaluated result of the expression.
      chain (list of str): A list of strings, each describing one computation step.
    """
    # Base case: if depth is 0, return a random number within the operand_range.
    if depth == 0:
        num = random.randint(operand_range[0], operand_range[1])
        # For a leaf, there's no computation step.
        return str(num), num, []
    
    # Recursive case: generate a left and a right sub-expression.
    left_expr, left_val, left_chain = generate_expr(depth - 1, operand_range)
    right_expr, right_val, right_chain = generate_expr(depth - 1, operand_range)
    
    # Randomly choose an operator from the allowed list.
    op = random.choice(["+", "-", "*"])
    
    # Form the new expression by fully parenthesizing the sub-expressions.
    expr_str = f"({left_expr} {op} {right_expr})"
    
    # Compute the result based on the chosen operator.
    if op == "+":
        result = left_val + right_val
    elif op == "-":
        result = left_val - right_val
    elif op == "*":
        result = left_val * right_val
    else:
        raise ValueError("Unsupported operator")
    
    # Create a new chain-of-thought step for this operation.
    # Note: We refer to the already computed numerical values for clarity.
    new_step = f"Compute {left_val} {op} {right_val} = {result}."
    
    # Combine the chain-of-thought from left and right sub-expressions with this new step.
    # The order (left_chain, right_chain, then new_step) reflects a post-order (bottom-up) evaluation.
    chain = left_chain + right_chain + [new_step]
    
    return expr_str, result, chain
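
Since the chain-of-thought strings will serve as training targets, it may be worth sanity-checking them independently of the generator. The sketch below re-parses each step in the `Compute a op b = c.` format produced above and confirms the arithmetic, plus that the final step matches the stated answer. The name `check_chain` and the regular expression are assumptions made for illustration.

```python
import operator
import re

# Matches steps of the form "Compute <int> <op> <int> = <int>."
STEP_PATTERN = re.compile(r"^Compute (-?\d+) ([+\-*]) (-?\d+) = (-?\d+)\.$")
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}

def check_chain(chain, answer):
    """Return True if every step is arithmetically correct and the last step equals `answer`."""
    last = None
    for step in chain:
        match = STEP_PATTERN.match(step)
        if match is None:
            return False  # step does not follow the expected format
        left, op, right, stated = match.groups()
        if OPS[op](int(left), int(right)) != int(stated):
            return False  # the stated result is wrong
        last = int(stated)
    # An empty chain (a depth-0 leaf) has no steps to verify.
    return last is None or last == answer

good_chain = [
    "Compute 37 * 39 = 1443.",
    "Compute 4 * 46 = 184.",
    "Compute 1443 - 184 = 1259.",
]
print(check_chain(good_chain, 1259))  # True
```

The check does not confirm that the steps correspond to the expression string itself, only that each step is internally consistent.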

In [3]:
# Cell 3: Create a helper function to generate a dataset of examples.
def generate_dataset(num_examples, depth_choices, operand_range, dataset_label):
    """
    Generate a dataset of arithmetic expressions with chain-of-thought explanations.

    Parameters:
      num_examples (int): Number of examples to generate.
      depth_choices (list): A list of possible depths to randomly choose from.
      operand_range (tuple): A tuple (low, high) for sampling random numbers.
      dataset_label (str): A label for the dataset (e.g., 'train', 'validation', 'OOD').

    Returns:
      df (pandas.DataFrame): A DataFrame containing the dataset with columns:
          - 'dataset': the label (train, validation, or OOD)
          - 'expression': the arithmetic expression as a string.
          - 'chain_of_thought': the chain-of-thought steps, serialized as a JSON string.
          - 'answer': the computed numerical result.
    """
    data = []
    for _ in range(num_examples):
        # Randomly select a depth from the provided choices.
        depth = random.choice(depth_choices)
        expr, answer, chain = generate_expr(depth, operand_range)
        # Save the example as a dictionary.
        example = {
            "dataset": dataset_label,
            "expression": expr,
            "chain_of_thought": json.dumps(chain),  # stored as a JSON string for readability
            "answer": answer
        }
        data.append(example)
    # Convert list of dictionaries into a pandas DataFrame.
    df = pd.DataFrame(data)
    return df
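
Because `chain_of_thought` is stored as a JSON string, it has to be decoded before it can be turned into training text. The sketch below shows one possible way to build a prompt and target from a row, with or without the reasoning steps, so that a chain-of-thought model and a baseline can be trained on the same examples. The `Expression:` / `Answer:` template and the `format_example` helper are assumptions for illustration; the notebook does not specify a prompt format.

```python
import json

def format_example(row, with_cot=True):
    """Build (prompt, target) strings from one dataset row (hypothetical template)."""
    prompt = f"Expression: {row['expression']}\n"
    # Decode the JSON-encoded list of steps only when chain-of-thought is wanted.
    steps = json.loads(row["chain_of_thought"]) if with_cot else []
    target = "".join(f"{step}\n" for step in steps) + f"Answer: {row['answer']}"
    return prompt, target

# A row shaped like the dictionaries built in generate_dataset.
row = {
    "dataset": "train",
    "expression": "(20 + 17)",
    "chain_of_thought": json.dumps(["Compute 20 + 17 = 37."]),
    "answer": 37,
}

cot_prompt, cot_target = format_example(row, with_cot=True)
base_prompt, base_target = format_example(row, with_cot=False)
print(cot_prompt + cot_target)
print(base_prompt + base_target)
```

Keeping the prompt identical across both variants means any performance difference comes from the target text alone.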

In [4]:
# Cell 4: Generate the training, validation, and OOD test sets.
# For training and validation, we use smaller numbers (e.g., 1 to 50) and lower expression depths.
# For OOD, we use larger numbers (e.g., 100 to 500) and possibly higher depths to simulate harder problems.

# Define dataset sizes and parameters.
num_train = 2000        # Adjust as needed (e.g., between 1k-5k examples)
num_validation = 300    # Typically 200-500 examples
num_ood = 300           # Out-of-distribution set

# For training and validation, use depth 1 or 2 (resulting in 1 or 3 operations, since depth d yields 2**d - 1 operations).
train_depth_choices = [1, 2]
train_operand_range = (1, 50)  # smaller numbers

# For OOD, use a higher depth (3, i.e., 7 operations) and larger numbers.
ood_depth_choices = [3]
ood_operand_range = (100, 500)  # larger numbers to simulate OOD conditions

# Generate the datasets.
train_df = generate_dataset(num_train, train_depth_choices, train_operand_range, "train")
validation_df = generate_dataset(num_validation, train_depth_choices, train_operand_range, "validation")
ood_df = generate_dataset(num_ood, ood_depth_choices, ood_operand_range, "OOD")

# For demonstration, display the first few rows of each dataset.
print("Training Set Examples:")
print(train_df.head(), "\n")

print("Validation Set Examples:")
print(validation_df.head(), "\n")

print("OOD Set Examples:")
print(ood_df.head())

Training Set Examples:
  dataset              expression  \
0   train               (20 + 17)   
1   train  ((37 * 39) - (4 * 46))   
2   train  ((19 * 38) * (5 - 35))   
3   train               (39 * 30)   
4   train  ((15 - 16) + (9 + 20))   

                                    chain_of_thought  answer  
0                          ["Compute 20 + 17 = 37."]      37  
1  ["Compute 37 * 39 = 1443.", "Compute 4 * 46 = ...    1259  
2  ["Compute 19 * 38 = 722.", "Compute 5 - 35 = -...  -21660  
3                        ["Compute 39 * 30 = 1170."]    1170  
4  ["Compute 15 - 16 = -1.", "Compute 9 + 20 = 29...      28   

Validation Set Examples:
      dataset               expression  \
0  validation   ((8 * 24) + (22 - 12))   
1  validation  ((25 - 10) - (21 - 47))   
2  validation                  (2 * 4)   
3  validation  ((38 - 18) - (11 * 30))   
4  validation  ((45 * 39) - (36 - 14))   

                                    chain_of_thought  answer  
0  ["Compute 8 * 24 = 192.", "Com
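
The depth and operand settings above shift two things at once. A depth-d expression from `generate_expr` is a full binary tree, so it contains 2^d - 1 operations, and its largest possible magnitude is high^(2^d), reached when every operator is a multiplication and every operand equals the upper bound. The sketch below tabulates both quantities; the helper names are illustrative, and the bound assumes low >= 1 and high >= 2.

```python
def num_operations(depth):
    """Number of binary operations in a full expression tree of the given depth."""
    return 2 ** depth - 1

def max_abs_result(depth, high):
    """Largest possible |result|: all operators are '*' and all operands equal `high`."""
    return high ** (2 ** depth)

settings = [("train/validation", 2, 50), ("OOD", 3, 500)]
for label, depth, high in settings:
    print(f"{label}: up to {num_operations(depth)} operations, "
          f"|result| <= {max_abs_result(depth, high):.3e}")
```

Under these settings the OOD answers can be many orders of magnitude larger than anything seen in training, so a failure on the OOD set could reflect longer chains, larger operands, much longer answer strings, or a combination of these.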

In [5]:
# Cell 5: (Optional) Save the datasets to CSV files for later use.
# You can adapt the file paths as needed.

train_df.to_csv("countdown_train.csv", index=False)
validation_df.to_csv("countdown_validation.csv", index=False)
ood_df.to_csv("countdown_ood.csv", index=False)

print("Datasets have been saved to CSV files.")

Datasets have been saved to CSV files.
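
Before relying on the validation split, it may be worth checking how many of its expressions also appear verbatim in the training split. With operands from 1 to 50 and three operators there are only 50 × 50 × 3 = 7,500 distinct depth-1 expressions, so some repetition between independently sampled splits is plausible. The `overlap_fraction` helper below is a hypothetical sketch, shown on a toy input rather than the real data.

```python
def overlap_fraction(train_expressions, eval_expressions):
    """Fraction of evaluation expressions that also appear verbatim in training."""
    seen = set(train_expressions)
    eval_list = list(eval_expressions)
    if not eval_list:
        return 0.0
    return sum(expr in seen for expr in eval_list) / len(eval_list)

# Toy example; with the DataFrames above one could pass
# train_df["expression"] and validation_df["expression"] instead.
train = ["(20 + 17)", "(39 * 30)", "((15 - 16) + (9 + 20))"]
validation = ["(2 * 4)", "(20 + 17)"]
print(overlap_fraction(train, validation))  # 0.5
```

A high overlap would mean validation accuracy partly measures memorization, which matters when comparing it against the OOD results.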
