# The Continuous Thought Machine – Tutorial 04: Parity [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SakanaAI/continuous-thought-machines/blob/main/examples/04_parity.ipynb) [![arXiv](https://img.shields.io/badge/arXiv-2505.05522-b31b1b.svg)](https://arxiv.org/abs/2505.05522)

### Parity

The parity of a binary sequence, given by the sign of the product of its elements, can readily be predicted by an RNN when the data is fed sequentially: the model need only maintain an internal state, flipping a ‘switch’ whenever a negative number is encountered. When the entire sequence is provided at once, however, the task is significantly more challenging.
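To make the two views concrete, here is a small pure-Python sketch (an illustration only, not code from the CTM repo). It shows that the sign of the product of a ±1 sequence equals the result of sequentially flipping a one-bit "switch" at every -1, which is the strategy available to an RNN.

```python
from math import prod


def parity_by_product(seq):
    """Parity as the sign of the product of the +1/-1 elements."""
    return 1 if prod(seq) > 0 else -1


def parity_by_switch(seq):
    """RNN-style parity: keep one bit of state and flip it on every -1."""
    state = 1  # start in the 'positive' state
    for x in seq:
        if x < 0:
            state = -state  # flip the switch
    return state


seq = [1, -1, -1, 1, -1]
print(parity_by_product(seq), parity_by_switch(seq))  # -1 -1
```

The sequential version needs only constant memory but must visit every element in order; when the whole sequence arrives at once, there is no such built-in ordering to exploit.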

In Section 8 of the [technical report](https://arxiv.org/pdf/2505.05522), we show how a CTM can be trained to do exactly this. In particular, we feed the CTM a binary sequence and train the model to predict the cumulative parity at each position along the sequence.

### Tutorial Overview

In this tutorial, we walk through how we trained the CTM, using sequences of length 16.

### Setup

In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being run in Colab), so that we can access the base CTM model.

Imports

In [1]:
import sys
sys.path.append("./continuous-thought-machines")

import torch
import torch.nn as nn
import numpy as np
import random
# import argparse
import os
from IPython.display import display, clear_output
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
# From CTM repo
from models.ctm import ContinuousThoughtMachine
from models.modules import CustomRotationalEmbedding1D, TemporalBackbone

from data.data_factory import data_provider



Set a seed for reproducibility

In [2]:
def set_seed(seed=42, deterministic=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = False

In [3]:
set_seed(42)

### Data

We define a dataset that creates the parity sequences for training and testing. Each sample is a sequence of length `sequence_length` in which every position is randomly set to -1 or 1. The target sequence (of the same length) holds, at each position, the parity of the input up to and including that position, with 0 denoting negative parity and 1 denoting positive parity.

We set the parity sequence length to `grid_size ** 2 = 16`, and prepare the train and test loaders. We use a `batch_size` of 64.
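The data generation described above can be sketched in pure Python as follows. This is a hypothetical illustration: the helper names `make_parity_sample` and `make_parity_batch` are ours, not the tutorial's dataset class, and the sketch returns nested lists rather than tensors. It uses `grid_size = 4` so that `sequence_length = grid_size ** 2 = 16`.

```python
import random


def make_parity_sample(sequence_length, rng=random):
    """Return (inputs, targets) for one cumulative-parity example.

    inputs:  list of +1/-1 values
    targets: targets[i] is the parity of inputs[0..i] (inclusive),
             encoded as 0 for negative parity and 1 for positive parity
    """
    inputs = [rng.choice([-1, 1]) for _ in range(sequence_length)]
    targets = []
    running_sign = 1
    for x in inputs:
        running_sign *= x  # sign of the product of the prefix so far
        targets.append(1 if running_sign > 0 else 0)
    return inputs, targets


def make_parity_batch(batch_size, grid_size, seed=None):
    """Build a batch whose sequences have length grid_size ** 2."""
    rng = random.Random(seed)
    sequence_length = grid_size ** 2
    samples = [make_parity_sample(sequence_length, rng) for _ in range(batch_size)]
    inputs = [s[0] for s in samples]
    targets = [s[1] for s in samples]
    return inputs, targets


inputs, targets = make_parity_batch(batch_size=64, grid_size=4, seed=0)
print(len(inputs), len(inputs[0]))  # 64 16
```

Tracking the running sign keeps target generation linear in the sequence length, instead of recomputing a product for every prefix.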

### Loss Function

Next we define the loss function. First, for every internal tick of the CTM, we compute the loss at every position of the output sequence (in the reconstruction variant implemented below, a mean-squared error averaged over time steps and channels). Then, as in the other experiments, we use the loss at only two internal ticks per sample: the tick where the loss is lowest, and the tick where the model is most certain (or, when `use_most_certain=False`, the final tick). We use advanced indexing into the losses tensor to extract these two losses, average each over the batch, and then average the pair.
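The tick-selection step can be illustrated with NumPy, whose advanced indexing behaves like the PyTorch indexing used in the loss for this pattern. This is a minimal sketch (the function name `select_tick_losses` is hypothetical), assuming per-tick losses and certainties have already been reduced to shape `(B, ticks)`.

```python
import numpy as np


def select_tick_losses(losses, certainty, use_most_certain=True):
    """Combine per-tick losses using the two-tick selection rule.

    losses:    (B, ticks) loss of each sample at each internal tick
    certainty: (B, ticks) certainty score (higher = more certain)
    Returns the average of (a) each sample's lowest loss and
    (b) each sample's loss at its most certain (or final) tick,
    together with the selected tick indices.
    """
    B, ticks = losses.shape
    batch_idx = np.arange(B)                # one row index per sample
    min_idx = losses.argmin(axis=1)         # tick with the lowest loss
    if use_most_certain:
        sel_idx = certainty.argmax(axis=1)  # tick with the highest certainty
    else:
        sel_idx = np.full(B, ticks - 1)     # fall back to the last tick
    # Advanced indexing pairs batch_idx[i] with min_idx[i] / sel_idx[i]
    loss_min = losses[batch_idx, min_idx].mean()
    loss_sel = losses[batch_idx, sel_idx].mean()
    return (loss_min + loss_sel) / 2, sel_idx


losses = np.array([[0.9, 0.4, 0.6],
                   [0.8, 0.7, 0.2]])
certainty = np.array([[0.1, 0.5, 0.9],
                      [0.3, 0.9, 0.4]])
loss, sel = select_tick_losses(losses, certainty)
print(sel)  # [2 1]
```

Here the lowest losses are 0.4 and 0.2 (mean 0.3), and the most certain ticks give 0.6 and 0.7 (mean 0.65), so the combined loss is about 0.475.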

In [4]:
def reconstruction_loss(predictions, certainties, targets, use_most_certain=False):
    """
    Compute the loss for a time-series reconstruction task, mirroring the logic
    of the original image-classification loss.

    Args:
        predictions (torch.Tensor): CTM outputs reshaped to (B, T, C, iterations)
        certainties (torch.Tensor): CTM certainty measure, shape (B, 2, iterations)
        targets (torch.Tensor): target sequence to reconstruct, shape (B, T, C)
        use_most_certain (bool): if True, select the most certain tick; if False, select the last tick.

    Returns:
        torch.Tensor: the scalar loss.
        torch.Tensor: the tick indices used to compute loss_selected, shape (B,).
    """
    B, T, C, iterations = predictions.shape

    # Broadcast the targets across internal ticks: (B, T, C) -> (B, T, C, iterations)
    targets_expanded = targets.unsqueeze(-1).expand(-1, -1, -1, iterations)

    # Per-element MSE at every internal ("thought") tick
    mse = nn.MSELoss(reduction='none')(predictions, targets_expanded)
    # Average over the time and channel dimensions -> one loss per sample per tick
    # losses: (B, iterations)
    losses = mse.mean(dim=(1, 2))

    # --- Mirror the logic of the original loss function ---
    # 1. For each sample, find the tick with the lowest loss
    loss_index_min_mse = losses.argmin(dim=1)

    # 2. Select either the most certain tick or, by default, the last tick
    if use_most_certain:
        # certainties[:, 1] is 1 - normalized entropy; larger means more certain
        loss_index_selected = certainties[:, 1].argmax(dim=1)
    else:
        # Select the last tick for every sample
        loss_index_selected = torch.full((B,), iterations - 1, dtype=torch.long, device=predictions.device)

    # Advanced indexing: pair each sample with its chosen tick
    batch_indexer = torch.arange(B, device=predictions.device)

    loss_min = losses[batch_indexer, loss_index_min_mse].mean()
    loss_selected = losses[batch_indexer, loss_index_selected].mean()

    # Average the two losses to get the final loss
    final_loss = (loss_min + loss_selected) / 2

    return final_loss, loss_index_selected

### Training

We define some helper functions for making the progress bar look pretty, and to display the training curves.

In [5]:
def make_pbar_desc(train_loss, test_loss, lr, where_most_certain):
    """A helper function to create a description for the tqdm progress bar"""
    pbar_desc = f'Train Loss={train_loss:0.3f}. Test Loss={test_loss:0.3f}. LR={lr:0.6f}.'
    pbar_desc += f' Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d}).'
    return pbar_desc

def update_training_curve_plot(fig, ax1, ax2, train_losses, test_losses, steps):
    clear_output(wait=True)

    # Plot loss
    ax1.clear()
    ax1.plot(range(len(train_losses)), train_losses, 'b-', alpha=0.7, label=f'Train Loss: {train_losses[-1]:.3f}')
    ax1.plot(steps, test_losses, 'r-', marker='o', label=f'Test Loss: {test_losses[-1]:.3f}')
    ax1.set_title('Loss')
    ax1.set_xlabel('Step')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot accuracy
    # ax2.clear()
    # ax2.plot(range(len(train_accuracies)), train_accuracies, 'b-', alpha=0.7, label=f'Train Accuracy: {train_accuracies[-1]:.3f}')
    # ax2.plot(steps, test_accuracies, 'r-', marker='o', label=f'Test Accuracy: {test_accuracies[-1]:.3f}')
    # ax2.set_title('Accuracy')
    # ax2.set_xlabel('Step')
    # ax2.set_ylabel('Accuracy')
    # ax2.legend()
    # ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    display(fig)

We then write the experiment class, whose `train` method trains the CTM.

In [6]:
class Exp_Anomaly_Detection():

    def __init__(self, args):
        super(Exp_Anomaly_Detection, self).__init__()
        self.args = args

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader


    def train(self, model, device='cuda', training_iterations=10000, test_every=1000, lr=1e-4, log_dir='./logs'):

        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        # test_data, test_loader = self._get_data(flag='test')


        os.makedirs(log_dir, exist_ok=True)

        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        iterator = iter(train_loader)

        train_losses = []
        test_losses = []
        steps = []

        plt.ion()
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        with tqdm(total=training_iterations) as pbar:
            for stepi in range(training_iterations):

                try:
                    x,_ = next(iterator)
                except StopIteration:
                    iterator = iter(train_loader)
                    x,_ = next(iterator)

                x = x.to(device)

                optimizer.zero_grad()


                predictions_raw, certainties, _ = model(x)
                B, _, iterations = predictions_raw.shape
                _,T, C = x.shape
                # Reshape: (B, SeqLength, C, T)
                predictions = predictions_raw.view(B, T, C, iterations)

                train_loss, where_most_certain = reconstruction_loss(predictions, certainties, x, use_most_certain=False)


                # train_accuracy = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == x).float().mean().item()

                train_losses.append(train_loss.item())
                # train_accuracies.append(train_accuracy)

                train_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if stepi % test_every == 0 or stepi == 0:
                    model.eval()
                    with torch.no_grad():
                        all_test_predictions = []
                        all_test_targets = []
                        all_test_where_most_certain = []
                        all_test_losses = []

                        for x,_ in vali_loader:
                            x = x.to(device)

                            # The third output is the synchronisation representation, which is not needed here
                            predictions_raw, certainties, _ = model(x)
                            B, _, iterations = predictions_raw.shape
                            _,T, C = x.shape
                            # Reshape: (B, SeqLength, C, T)
                            predictions = predictions_raw.view(B, T, C, iterations)

                            test_loss, where_most_certain = reconstruction_loss(predictions, certainties, x, use_most_certain=False)
                            all_test_losses.append(test_loss.item())
                            all_test_predictions.append(predictions)
                            all_test_targets.append(x)
                            all_test_where_most_certain.append(where_most_certain)

                        # test_accuracy = (all_test_predictions.argmax(2)[torch.arange(all_test_predictions.size(0), device=predictions.device), :, all_test_where_most_certain] == all_test_targets).float().mean().item()
                        test_loss = sum(all_test_losses) / len(all_test_losses)

                        test_losses.append(test_loss)
                        # test_accuracies.append(test_accuracy)
                        steps.append(stepi)

                        # create_recon_gif_visualization(model, vali_loader, device, log_dir)

                    model.train()

                    update_training_curve_plot(fig, ax1, ax2, train_losses, test_losses, steps)

                pbar_desc = make_pbar_desc(train_loss=train_loss.item(), test_loss=test_loss, lr=optimizer.param_groups[-1]["lr"], where_most_certain=where_most_certain)
                pbar.set_description(pbar_desc)
                pbar.update(1)

        plt.ioff()
        plt.close(fig)

        return model
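The `train()` loop restarts the training iterator by catching `StopIteration`. The same behaviour can be packaged as a small generator. This is an alternative sketch, not part of the CTM repo; `infinite_batches` is a hypothetical helper.

```python
def infinite_batches(loader):
    """Yield batches forever, starting a new pass over `loader` when one ends.

    Behaves like the try/next/except StopIteration pattern in train(): each
    new pass calls iter(loader) again, so a DataLoader built with
    shuffle=True draws a fresh order for every epoch.
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            # Guard against spinning forever on an empty loader
            raise ValueError("loader produced no batches")


# A plain list stands in for a DataLoader here.
batches = infinite_batches([("x0", 0), ("x1", 1)])
for step in range(5):
    x, _ = next(batches)
    print(step, x)  # x0, x1, x0, x1, x0
```

Inside `train()`, a single `x, _ = next(batches)` could then replace the `try`/`except` block, keeping the step loop focused on the forward and backward passes.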



In [None]:
def create_recon_gif_visualization(model, vali_loader, device, log_dir):
    """Render a GIF of the model's predictions on one validation batch.

    NOTE: reshape_attention_weights, reshape_inputs, make_parity_gif and grid_size
    are not defined in this notebook; they must be imported or defined (for example
    from the parity task's plotting utilities) before this function is called.
    """
    model.eval()
    with torch.no_grad():
        inputs_viz, targets_viz = next(iter(vali_loader))
        inputs_viz = inputs_viz.to(device)
        targets_viz = targets_viz.to(device)

        # With track=True the CTM also returns its internal activations and attention
        predictions_raw, certainties, _, pre_activations, post_activations, attention = model(inputs_viz, track=True)

        # Reshape predictions; this assumes 2 classes per position, as in the parity task
        predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 2, predictions_raw.size(-1))

        attention = reshape_attention_weights(attention)
        inputs = reshape_inputs(inputs_viz, 50, grid_size=grid_size)

        # Generate the parity GIF
        make_parity_gif(
            predictions.detach().cpu().numpy(),
            certainties.detach().cpu().numpy(),
            targets_viz.detach().cpu().numpy(),
            pre_activations,
            post_activations,
            attention,
            inputs,
            f'{log_dir}/prediction.gif',
        )

### Initializing the CTM

Next we initialize the CTM. Three arguments are worth highlighting for this task, because they differ from those used in, for example, the image classification task.

- `backbone_type = 'parity_backbone'`: the backbone type `'parity_backbone'`, defined in the CTM repo, is a learned embedding layer that embeds the binary values in the input sequence.
- `positional_embedding_type = 'custom-rotational-1d'`: a positional embedding for each position in the parity sequence. These positional embeddings are added to the embedding vectors produced by the backbone during the forward pass.
- `prediction_reshaper = [parity_sequence_length, 2]`: `prediction_reshaper` is an optional CTM argument that is required when the model's output is a sequence, for instance here, where the output is a sequence of parities, or in the maze task, where the output is a sequence of actions. At each internal tick, the CTM uses it to reshape the output before computing the model's certainty. In general, the prediction reshaper should take the form `[SEQUENCE_LENGTH, NUM_CLASS]`.

In the reconstruction wrapper below, `backbone_type` and `positional_embedding_type` are instead both set to `'none'`, because a `TemporalBackbone` and a `CustomRotationalEmbedding1D` are applied outside the CTM, and `prediction_reshaper` is set to `[input_timesteps, input_channels]`.
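To show what `prediction_reshaper` is used for, here is a NumPy sketch of a per-tick certainty for sequence outputs: reshape the flat output to `[B] + prediction_reshaper`, take the softmax over the class dimension, compute the entropy normalized by `log(NUM_CLASS)`, average over positions, and report `1 - entropy` as the certainty (the quantity read from `certainties[:, 1]` in the loss). It illustrates the idea rather than reproducing the repo's implementation; averaging over positions and the function name `certainty_from_logits` are assumptions.

```python
import numpy as np


def certainty_from_logits(logits, prediction_reshaper):
    """Sketch of a per-tick certainty for sequence outputs.

    logits: (B, prod(prediction_reshaper)) flat output of one internal tick.
    prediction_reshaper: e.g. [SEQUENCE_LENGTH, NUM_CLASS].
    Returns (B, 2): [normalized entropy, 1 - normalized entropy].
    """
    B = logits.shape[0]
    x = logits.reshape([B] + list(prediction_reshaper))  # (B, L, K)
    x = x - x.max(axis=-1, keepdims=True)                # numerically stable softmax
    p = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)
    entropy = -(p * np.log(p + 1e-12)).sum(axis=-1)      # (B, L), in nats
    num_classes = x.shape[-1]
    normalized = entropy / np.log(num_classes)           # in [0, 1]
    ne = normalized.reshape(B, -1).mean(axis=1)          # average over positions
    return np.stack([ne, 1 - ne], axis=-1)


B, L, K = 2, 16, 2
uniform = np.zeros((B, L * K))                 # identical logits: maximally uncertain
confident = np.tile([20.0, -20.0], (B, L))     # one class dominates: nearly certain
print(certainty_from_logits(uniform, [L, K])[:, 1])    # close to 0
print(certainty_from_logits(confident, [L, K])[:, 1])  # close to 1
```

Without the reshape, the softmax would run over all `L * K` outputs at once and the certainty would no longer describe a per-position class distribution, which is why sequence outputs need `prediction_reshaper`.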

In [7]:
class CTMForReconstruction(nn.Module):
    """
    A module that wraps the CTM specifically for time-series reconstruction.
    """

    def __init__(self,
                 input_timesteps: int,
                 input_channels: int,
                 d_backbone: int,
                 ctm_args: dict):
        """
        Args:
            input_timesteps (int): number of time steps in the input sequence (T).
            input_channels (int): number of channels in the input sequence (C).
            d_backbone (int): feature dimension output by the temporal backbone.
            ctm_args (dict): keyword arguments for the underlying ContinuousThoughtMachine.
        """
        super().__init__()
        self.input_timesteps = input_timesteps
        self.input_channels = input_channels

        # 1. Temporal backbone
        self.backbone = TemporalBackbone(input_channels, d_backbone, input_timesteps)

        # 2. 1D positional embedding
        # Already provided in the CTM code, so we use it directly
        self.positional_embedding = CustomRotationalEmbedding1D(d_backbone)

        # 3. Prepare the CTM arguments, overriding some of them for the new task.
        # Copy first so that the caller's dict is not mutated.
        ctm_args = dict(ctm_args)
        ctm_args['backbone_type'] = 'none'  # we use our own backbone
        ctm_args['positional_embedding_type'] = 'none'  # positional embedding is handled outside the CTM
        ctm_args['d_input'] = d_backbone  # the CTM input dimension is the backbone output dimension
        # The output dimension must equal T * C so the whole sequence can be reconstructed
        ctm_args['out_dims'] = input_timesteps * input_channels
        # prediction_reshaper is used when computing certainty, so it must be updated too
        ctm_args['prediction_reshaper'] = [input_timesteps, input_channels]

        # 4. CTM core
        self.ctm = ContinuousThoughtMachine(**ctm_args)

    def forward(self, x: torch.Tensor, track: bool = False):
        """
        Args:
            x (torch.Tensor): input time series, shape (B, T, C)
            track (bool): whether to track the CTM's internal state.

        Returns:
            Tuple: (predictions, certainties, ...), the raw CTM outputs
        """
        # x shape: (B, T, C)

        # Conv1d expects (B, C, T), so permute the dimensions
        x_permuted = x.permute(0, 2, 1)  # -> (B, C, T)

        # 1. Extract features with the temporal backbone
        features = self.backbone(x_permuted)  # -> (B, d_backbone, T)

        # 2. Add the 1D positional embedding
        pos_emb = self.positional_embedding(features)  # -> (B, d_backbone, T)
        combined_features = features + pos_emb

        # 3. Rearrange to the (B, SeqLen, FeatDim) layout the CTM expects
        # kv_features: (B, T, d_backbone)
        kv_features = combined_features.transpose(1, 2)

        # 4. Call the CTM core's forward method with the precomputed key/value features.
        # kv_proj and q_proj handle the projection from d_backbone to d_input;
        # since __init__ sets d_input = d_backbone, the dimensions match.
        outputs = self.ctm(x=None, track=track, precomputed_kv=kv_features)

        return outputs
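The wrapper asks the CTM for `out_dims = T * C` outputs per tick, and `train()` recovers the `(B, T, C, iterations)` layout with `predictions_raw.view(B, T, C, iterations)`. The NumPy sketch below (hypothetical helper names) shows which flat output unit maps to which (time step, channel) under that row-major reshape, and computes the per-tick MSE by broadcasting the targets over ticks instead of repeating them.

```python
import numpy as np


def split_flat_predictions(flat, T, C):
    """(B, T*C, ticks) -> (B, T, C, ticks); output unit t*C + c maps to (t, c)."""
    B, TC, ticks = flat.shape
    if TC != T * C:
        raise ValueError("out_dims must equal T * C")
    return flat.reshape(B, T, C, ticks)


def per_tick_mse(pred, targets):
    """pred (B, T, C, ticks), targets (B, T, C) -> losses (B, ticks)."""
    diff = pred - targets[..., None]  # broadcast the target across ticks
    return (diff ** 2).mean(axis=(1, 2))


B, T, C, ticks = 2, 5, 3, 4
flat = np.arange(B * T * C * ticks, dtype=float).reshape(B, T * C, ticks)
pred = split_flat_predictions(flat, T, C)
print(pred.shape)                               # (2, 5, 3, 4)
print(pred[1, 2, 1, 3] == flat[1, 2 * C + 1, 3])  # True
```

Because the reshape is row-major, the targets must be laid out as `(B, T, C)` with time as the slower-varying axis; any other ordering would silently pair predictions with the wrong channels.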


In [8]:
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Task-specific parameters
BATCH_SIZE = 4
INPUT_TIMESTEPS = 100  # e.g. 100 time steps
INPUT_CHANNELS = 20    # e.g. 20 features per time step
D_BACKBONE = 64        # feature dimension output by our temporal backbone

# CTM core parameters (adjust to your needs)
ctm_parameters = {
    'iterations': 16,
    'd_model': 256,
    'd_input': D_BACKBONE,  # overwritten by the wrapper, but best kept consistent
    'heads': 4,
    'n_synch_out': 128,
    'n_synch_action': 64,
    'synapse_depth': 2,
    'memory_length': 16,
    'deep_nlms': True,
    'memory_hidden_dims': 32,
    'do_layernorm_nlm': False,
    'dropout': 0.1,
    'neuron_select_type': 'random-pairing',
}

class Args:
    """Minimal attribute container: turns dictionary keys into attributes."""
    def __init__(self, dictionary):
        for key, value in dictionary.items():
            setattr(self, key, value)

args = Args({
    'batch_size' : BATCH_SIZE,
    'device': device,
    'seq_len': INPUT_TIMESTEPS,
    'data': 'BSM',
    'num_workers': 1,
    'task_name': 'anomaly_detection',
    'root_path': '../dataset/BSM'
})

# Define the model
# Create the model with the wrapper class
model = CTMForReconstruction(
    input_timesteps=INPUT_TIMESTEPS,
    input_channels=INPUT_CHANNELS,
    d_backbone=D_BACKBONE,
    ctm_args=ctm_parameters
).to(device)

# Initialize model parameters with a dummy forward pass
# dummy_input = torch.randn(BATCH_SIZE, INPUT_TIMESTEPS, INPUT_CHANNELS).to(device)
# dummy_target = dummy_input.clone() + torch.randn_like(dummy_input) * 0.1  # simulated reconstruction target
#
# # Forward pass
# predictions, certainties, _ = model(dummy_input)
#
# print(f"\nInput shape: {dummy_input.shape}")
# print(f"Prediction output shape: {predictions.shape}")  # should be (B, T*C, iterations)
# print(f"Certainty output shape: {certainties.shape}")  # should be (B, 2, iterations)
#
# # print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')
#
# # reconstruction_loss expects predictions of shape (B, T, C, iterations)
# predictions = predictions.view(BATCH_SIZE, INPUT_TIMESTEPS, INPUT_CHANNELS, -1)
# loss, selected_indices = reconstruction_loss(predictions, certainties, dummy_target, use_most_certain=True)
#
# print(f"\nComputed reconstruction loss: {loss.item():.4f}")
# print(f"Tick indices used for 'loss_selected': {selected_indices.tolist()}")
#
# # Backpropagate
# loss.backward()

Using neuron select type: random-pairing
Synch representation size action: 64
Synch representation size out: 128
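The `Args` class above only turns dictionary keys into attributes. The standard library's `types.SimpleNamespace` provides the same attribute-bag behaviour, so it could stand in for `Args`, assuming `data_provider` only reads attributes from the object it receives. A sketch with a subset of the keys used above:

```python
from types import SimpleNamespace

# Same attribute access as Args(...), using only the standard library
args_ns = SimpleNamespace(**{
    'batch_size': 4,
    'seq_len': 100,
    'data': 'BSM',
})
print(args_ns.batch_size, args_ns.data)  # 4 BSM
```

`SimpleNamespace` also provides a readable `repr` and equality comparison, which helps when logging experiment configurations.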


In [9]:
exp = Exp_Anomaly_Detection(args)
# Get the current working directory
current_dir = os.getcwd()
print(f"Current working directory: {current_dir}")
model = exp.train(model=model, device=device, training_iterations=20, lr=1e-4, log_dir='./recon_logs')

Train Loss=0.710. Test Loss=1.439. LR=0.000100. Where_certain=11.00+-7.35 (0<->15).: 100%|██████████| 20/20 [01:12<00:00,  3.63s/it] 


Visualise a gif of a solution