# OCTA / T-Selection Experiments

This notebook implements the toy experiments described in the OCTA companion paper:

1. **Single-agent multi-view prediction** — testing world residue and self-consistency residue.
2. **Multi-agent referential game** — testing world residue and consensus residue.

Baseline vs OCTA models are compared in both setups.

All code is written to run on **CPU or GPU** automatically.

In [None]:
!pip install torch matplotlib --quiet

In [None]:
import math
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports an available GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # bare expression so the notebook cell displays the selected device

## 1. Single-Agent Multi-View Experiment

This experiment tests whether **self-consistency regularization improves robustness to distribution shift**.

- Latent $z$ sampled from a Gaussian mixture
- Two noisy linear views of $z$: $x_1$, $x_2$
- Train a model to predict $x_2$ from $x_1$

Compare:

**Baseline** — single-head predictor

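**OCTA** — shared backbone, multiple heads, KL between head predictions (implemented as half the mean squared difference between the heads' outputs, i.e. the KL between unit-variance Gaussians up to a constant factor)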

In [None]:
class MultiViewToyDataset(Dataset):
    """Fixed dataset of two noisy linear views of a Gaussian-mixture latent.

    A latent ``z`` is drawn from a mixture of ``n_components`` Gaussians
    (standard deviation 0.5 around each component mean) and observed through
    two linear maps, ``x1 = W1 z + noise`` and ``x2 = W2 z + noise``.

    The generative *structure* (component means, ``W1``, ``W2``) depends only
    on ``structure_seed``, so datasets built with different ``seed`` values
    describe the same world and differ only in the samples drawn from it.
    ``shift=True`` translates every component mean by +5.0 while keeping the
    same view maps, which is what makes it a shift of the latent distribution
    rather than a different task.

    All samples are generated once in ``__init__``; ``ds[i]`` therefore always
    returns the same pair and indexing past the end raises ``IndexError``.

    Args:
        n_samples: number of (x1, x2) pairs.
        d_latent: dimensionality of ``z``.
        d_view: dimensionality of each view.
        n_components: number of mixture components.
        noise_std: standard deviation of the additive view noise.
        shift: if True, add 5.0 to every component mean.
        seed: seed for the samples (and for the global RNGs, see below).
        structure_seed: seed for the means and view matrices.
    """

    # Added to ``structure_seed`` so the structure stream does not coincide
    # with a sample stream that happens to be seeded with the same integer.
    _STRUCTURE_SALT = 0x9E3779B9

    def __init__(self, n_samples=10000, d_latent=4, d_view=8, n_components=3,
                 noise_std=0.1, shift=False, seed=0, structure_seed=0):
        super().__init__()
        self.n_samples = n_samples
        self.d_latent = d_latent
        self.d_view = d_view
        self.n_components = n_components
        self.noise_std = noise_std

        # The global RNGs are still seeded here, as before: code that runs
        # after construction (e.g. model initialisation, DataLoader shuffling)
        # draws from them.
        torch.manual_seed(seed)
        random.seed(seed)

        structure_rng = torch.Generator().manual_seed(structure_seed + self._STRUCTURE_SALT)
        base_means = torch.randn(n_components, d_latent, generator=structure_rng)
        self.means = base_means + 5.0 if shift else base_means
        self.W1 = torch.randn(d_view, d_latent, generator=structure_rng)
        self.W2 = torch.randn(d_view, d_latent, generator=structure_rng)

        # Draw every sample up front from a private generator so that the
        # dataset content is a pure function of its constructor arguments.
        sample_rng = torch.Generator().manual_seed(seed)
        component = torch.randint(0, n_components, (n_samples,), generator=sample_rng)
        z = self.means[component] + 0.5 * torch.randn(n_samples, d_latent, generator=sample_rng)
        noise1 = noise_std * torch.randn(n_samples, d_view, generator=sample_rng)
        noise2 = noise_std * torch.randn(n_samples, d_view, generator=sample_rng)
        self.x1 = (z @ self.W1.T + noise1).float()
        self.x2 = (z @ self.W2.T + noise2).float()

    def __len__(self):
        return self.n_samples

    def sample_latent(self):
        """Draw one fresh latent from the mixture using the global RNGs."""
        k = random.randint(0, self.n_components - 1)
        return self.means[k] + 0.5 * torch.randn(self.d_latent)

    def __getitem__(self, idx):
        """Return the stored ``(x1, x2)`` pair at position ``idx``."""
        return self.x1[idx], self.x2[idx]

In [None]:
class SimpleTransformerBackbone(nn.Module):
    """Project the input to ``d_model`` and run it through a Transformer encoder.

    The projected input is given a singleton axis at position 1 before the
    encoder and that axis is removed again afterwards, so the encoder sees
    each input as a sequence of length one.
    """

    def __init__(self, d_in, d_model=64, n_heads=4, n_layers=2):
        super().__init__()
        self.input_proj = nn.Linear(d_in, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, x):
        # assumes x is (batch, d_in) — TODO confirm against callers
        embedded = self.input_proj(x)
        as_sequence = torch.unsqueeze(embedded, 1)
        encoded = self.encoder(as_sequence)
        return torch.squeeze(encoded, 1)

In [None]:
class BaselineModel(nn.Module):
    """Single-head predictor: backbone features followed by one linear output layer."""

    def __init__(self, d_view, d_model=64):
        super().__init__()
        self.backbone = SimpleTransformerBackbone(d_in=d_view, d_model=d_model)
        self.head = nn.Linear(d_model, d_view)

    def forward(self, x):
        features = self.backbone(x)
        prediction = self.head(features)
        return prediction


class OCTAModel(nn.Module):
    """Shared backbone feeding several independent linear prediction heads.

    ``n_heads`` is the number of prediction heads on top of the backbone (it
    is not forwarded to the backbone). ``forward`` returns a list with one
    prediction tensor per head, in head order.
    """

    def __init__(self, d_view, d_model=64, n_heads=3):
        super().__init__()
        self.backbone = SimpleTransformerBackbone(d_in=d_view, d_model=d_model)
        self.heads = nn.ModuleList()
        for _ in range(n_heads):
            self.heads.append(nn.Linear(d_model, d_view))

    def forward(self, x):
        shared = self.backbone(x)
        outputs = []
        for head in self.heads:
            outputs.append(head(shared))
        return outputs

In [None]:
def mse_world(pred, target):
    """Return the world residue: the element-wise mean squared error."""
    return F.mse_loss(pred, target)


def gaussian_kl(mu1, mu2):
    """Return half the mean squared difference between two predictions.

    This equals the KL divergence between unit-variance Gaussians centred at
    ``mu1`` and ``mu2`` up to a constant factor (the mean is taken over all
    elements instead of summing over dimensions).
    """
    return 0.5 * torch.mean((mu1 - mu2) ** 2)


def compute_octa_loss(preds, target, lambda_world=1.0, lambda_self=0.1):
    """Combine the world and self-consistency residues of a multi-head model.

    Args:
        preds: sequence of per-head prediction tensors (at least one).
        target: tensor every head is trained to match.
        lambda_world: weight of the summed per-head MSE.
        lambda_self: weight of the summed pairwise ``gaussian_kl`` terms.

    Returns:
        ``(total, Dw, Ds)`` where ``total`` is a tensor suitable for
        ``backward()`` and ``Dw`` / ``Ds`` are Python floats for logging.

    Raises:
        ValueError: if ``preds`` is empty.
    """
    preds = list(preds)
    if not preds:
        raise ValueError("compute_octa_loss needs at least one head prediction")

    Dw = sum(mse_world(p, target) for p in preds)

    # Start from a zero tensor rather than the float 0.0 so that a single-head
    # model (no pairs) still yields a tensor and `.item()` below is valid.
    Ds = preds[0].new_zeros(())
    for i, first in enumerate(preds):
        for second in preds[i + 1:]:
            Ds = Ds + gaussian_kl(first, second)

    total = lambda_world * Dw + lambda_self * Ds
    return total, Dw.item(), Ds.item()

In [None]:
@dataclass
class SingleAgentConfig:
    """Hyperparameters for the single-agent multi-view experiment."""

    batch_size: int = 128       # samples per DataLoader batch
    epochs: int = 10            # training epochs for each model
    d_view: int = 8             # dimensionality of each view
    lr: float = 1e-3            # Adam learning rate
    lambda_self: float = 0.1    # weight of the self-consistency term
    n_heads: int = 3            # number of OCTA prediction heads
    device: str = device        # defaults to the module-level device

In [None]:
def _per_element_mse(predict, loader, device):
    """Return the mean squared error per output element of ``predict`` over ``loader``.

    ``predict`` maps a batch of ``x1`` to a prediction of ``x2``. The value is
    normalised by the number of target elements so it is on the same scale as
    ``F.mse_loss`` with its default ``reduction='mean'``. Returns NaN when the
    loader yields no data.
    """
    squared_error = 0.0
    n_elements = 0
    with torch.no_grad():
        for x1, x2 in loader:
            x1 = x1.to(device)
            x2 = x2.to(device)
            squared_error += F.mse_loss(predict(x1), x2, reduction="sum").item()
            n_elements += x2.numel()
    return squared_error / n_elements if n_elements else float("nan")


def run_single_agent(cfg=None):
    """Train the baseline and the OCTA model and record per-epoch metrics.

    Args:
        cfg: a ``SingleAgentConfig``; a fresh default instance is created
            when omitted.

    Returns:
        ``(baseline_hist, octa_hist)``. ``baseline_hist`` has the keys
        ``"train"``, ``"test_in"``, ``"test_shift"``; ``octa_hist`` has
        ``"L"``, ``"Dw"``, ``"Ds"``, ``"test_in"``, ``"test_shift"``. Every
        value is a list with one float per epoch. All MSE values are means
        per output element, so train and test numbers are directly comparable.
    """
    # Build the default here rather than in the signature so that no config
    # instance is shared between calls.
    if cfg is None:
        cfg = SingleAgentConfig()

    # The datasets must produce views of the same width the models expect.
    train = MultiViewToyDataset(d_view=cfg.d_view, shift=False, seed=0)
    test_in = MultiViewToyDataset(d_view=cfg.d_view, shift=False, seed=1)
    test_shift = MultiViewToyDataset(d_view=cfg.d_view, shift=True, seed=2)

    train_loader = DataLoader(train, batch_size=cfg.batch_size, shuffle=True)
    test_in_loader = DataLoader(test_in, batch_size=cfg.batch_size)
    test_shift_loader = DataLoader(test_shift, batch_size=cfg.batch_size)

    # Baseline
    baseline = BaselineModel(cfg.d_view).to(cfg.device)
    optb = torch.optim.Adam(baseline.parameters(), lr=cfg.lr)

    baseline_hist = {"train": [], "test_in": [], "test_shift": []}

    for ep in range(cfg.epochs):
        baseline.train()
        total = 0
        for x1, x2 in train_loader:
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)
            pred = baseline(x1)
            loss = mse_world(pred, x2)
            optb.zero_grad()
            loss.backward()
            optb.step()
            total += loss.item() * x1.size(0)

        train_mse = total / len(train)

        baseline.eval()
        ti = _per_element_mse(baseline, test_in_loader, cfg.device)
        ts = _per_element_mse(baseline, test_shift_loader, cfg.device)

        baseline_hist["train"].append(train_mse)
        baseline_hist["test_in"].append(ti)
        baseline_hist["test_shift"].append(ts)

        print(f"[Baseline] Epoch {ep+1} train={train_mse:.4f} in={ti:.4f} shift={ts:.4f}")

    # OCTA
    octa = OCTAModel(cfg.d_view, n_heads=cfg.n_heads).to(cfg.device)
    opto = torch.optim.Adam(octa.parameters(), lr=cfg.lr)

    octa_hist = {"L": [], "Dw": [], "Ds": [], "test_in": [], "test_shift": []}

    def octa_mean_prediction(x1):
        # At test time the OCTA prediction is the average over its heads.
        return torch.stack(octa(x1)).mean(0)

    for ep in range(cfg.epochs):
        octa.train()
        Lsum = Dwsum = Dssum = 0
        for x1, x2 in train_loader:
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)
            preds = octa(x1)
            L, Dw, Ds = compute_octa_loss(preds, x2, lambda_self=cfg.lambda_self)
            opto.zero_grad()
            L.backward()
            opto.step()
            B = x1.size(0)
            Lsum += L.item() * B
            Dwsum += Dw * B
            Dssum += Ds * B

        N = len(train)
        Lavg = Lsum / N
        Dwavg = Dwsum / N
        Dsavg = Dssum / N

        octa.eval()
        ti = _per_element_mse(octa_mean_prediction, test_in_loader, cfg.device)
        ts = _per_element_mse(octa_mean_prediction, test_shift_loader, cfg.device)

        octa_hist["L"].append(Lavg)
        octa_hist["Dw"].append(Dwavg)
        octa_hist["Ds"].append(Dsavg)
        octa_hist["test_in"].append(ti)
        octa_hist["test_shift"].append(ts)

        print(f"[OCTA] Epoch {ep+1} L={Lavg:.4f} Dw={Dwavg:.4f} Ds={Dsavg:.4f} in={ti:.4f} shift={ts:.4f}")

    return baseline_hist, octa_hist


# Run the experiment with the default configuration; the histories feed the
# plotting cell below.
baseline_hist, octa_hist = run_single_agent()

In [None]:
# Epoch axis (1-based) shared by both panels.
epochs = range(1, 1 + len(baseline_hist["test_in"]))

plt.figure(figsize=(10, 4))

# Left panel: test MSE of both models, in-distribution and under shift.
plt.subplot(1, 2, 1)
plt.plot(epochs, baseline_hist["test_in"], label="Baseline in-dist")
plt.plot(epochs, baseline_hist["test_shift"], label="Baseline shifted")
plt.plot(epochs, octa_hist["test_in"], label="OCTA in-dist")
plt.plot(epochs, octa_hist["test_shift"], label="OCTA shifted")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.title("Single-Agent Test MSE")

# Right panel: the two OCTA training residues.
plt.subplot(1, 2, 2)
plt.plot(epochs, octa_hist["Dw"], label="Δ_world")
plt.plot(epochs, octa_hist["Ds"], label="Δ_self")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Residue")
plt.title("OCTA Residues")

plt.tight_layout()
plt.show()

## 2. Multi-Agent Referential Game

Here we test **consensus residue**.

- The speaker sees $x_1$
- The speaker sends a discrete message token
- The listener reconstructs $x_2$ from the message
- Both agents form internal beliefs

### Baseline
- Train only reconstruction loss

### OCTA
- Train reconstruction + consensus KL between beliefs

In [None]:
class ConceptDataset(Dataset):
    """Fixed dataset of ``(x1, x2, z)`` triples for the referential game.

    A latent ``z`` is drawn from a mixture of ``comps`` Gaussians (standard
    deviation 0.5 around each component mean) and observed through two noisy
    linear views ``x1 = W1 z + noise`` and ``x2 = W2 z + noise``.

    The component means and view matrices depend only on ``structure_seed``,
    so datasets built with different ``seed`` values share the same world and
    differ only in their samples. ``shift=True`` adds 5.0 to every component
    mean while keeping the same view maps.

    All samples are generated once in ``__init__``; ``ds[i]`` always returns
    the same triple and indexing past the end raises ``IndexError``.

    Args:
        n: number of samples.
        d_latent: dimensionality of ``z``.
        d_view: dimensionality of each view.
        noise: standard deviation of the additive view noise.
        comps: number of mixture components.
        shift: if True, add 5.0 to every component mean.
        seed: seed for the samples (and for torch's global RNG, see below).
        structure_seed: seed for the means and view matrices.
    """

    # Added to ``structure_seed`` so the structure stream does not coincide
    # with a sample stream that happens to be seeded with the same integer.
    _STRUCTURE_SALT = 0x9E3779B9

    def __init__(self, n=20000, d_latent=4, d_view=8, noise=0.1, comps=4,
                 shift=False, seed=0, structure_seed=0):
        super().__init__()
        # torch's global RNG is still seeded here, as before: code that runs
        # after construction (e.g. model initialisation, DataLoader shuffling)
        # draws from it.
        torch.manual_seed(seed)
        self.n = n
        self.noise = noise
        self.d_view = d_view
        self.d_latent = d_latent

        structure_rng = torch.Generator().manual_seed(structure_seed + self._STRUCTURE_SALT)
        base_means = torch.randn(comps, d_latent, generator=structure_rng)
        self.means = base_means + 5.0 if shift else base_means
        self.W1 = torch.randn(d_view, d_latent, generator=structure_rng)
        self.W2 = torch.randn(d_view, d_latent, generator=structure_rng)

        # Draw every sample up front from a private generator so that the
        # dataset content is a pure function of its constructor arguments.
        sample_rng = torch.Generator().manual_seed(seed)
        component = torch.randint(0, comps, (n,), generator=sample_rng)
        z = self.means[component] + 0.5 * torch.randn(n, d_latent, generator=sample_rng)
        noise1 = noise * torch.randn(n, d_view, generator=sample_rng)
        noise2 = noise * torch.randn(n, d_view, generator=sample_rng)
        self.z = z.float()
        self.x1 = (z @ self.W1.T + noise1).float()
        self.x2 = (z @ self.W2.T + noise2).float()

    def __len__(self):
        return self.n

    def sample_z(self):
        """Draw one fresh latent from the mixture using torch's global RNG."""
        # torch.randint (rather than the `random` module) so that the draw is
        # governed by torch.manual_seed like the Gaussian noise next to it.
        k = torch.randint(0, self.means.size(0), (1,)).item()
        return self.means[k] + 0.5 * torch.randn(self.means.size(1))

    def __getitem__(self, idx):
        """Return the stored ``(x1, x2, z)`` triple at position ``idx``."""
        return self.x1[idx], self.x2[idx], self.z[idx]

In [None]:
class Speaker(nn.Module):
    """Agent that turns a view into a discrete message and an internal belief.

    ``forward(x1, tau)`` returns ``(msg, probs, belief)``:
    ``probs`` is the softmax of the message logits divided by ``tau``,
    ``msg`` is the argmax token index of ``probs``, and ``belief`` is a
    ``d_view``-dimensional linear read-out of the encoder state.
    """

    def __init__(self, d_view, d_hidden=64, vocab=16):
        super().__init__()
        self.vocab = vocab
        # Two-layer ReLU encoder shared by the message and belief heads.
        self.enc = nn.Sequential(
            nn.Linear(d_view, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
        )
        self.msg_head = nn.Linear(d_hidden, vocab)
        self.belief_head = nn.Linear(d_hidden, d_view)

    def forward(self, x1, tau=1.0):
        hidden = self.enc(x1)
        logits = self.msg_head(hidden) / tau
        probs = torch.softmax(logits, dim=-1)
        # NOTE: argmax yields integer indices, so no gradient flows back to
        # the speaker through `msg`; only `probs` and the belief carry one.
        msg = probs.argmax(dim=-1)
        belief = self.belief_head(hidden)
        return msg, probs, belief


class Listener(nn.Module):
    """Agent that decodes a message token into a reconstruction and a belief.

    ``forward(msg)`` embeds the token indices and returns
    ``(reconstruction, belief)``: the embedding passed through a small MLP
    decoder, and a separate linear read-out of the same embedding.
    """

    def __init__(self, d_view, d_hidden=64, vocab=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_hidden)
        self.dec = nn.Sequential(
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_view),
        )
        self.belief = nn.Linear(d_hidden, d_view)

    def forward(self, msg):
        token_embedding = self.embed(msg)
        reconstruction = self.dec(token_embedding)
        belief = self.belief(token_embedding)
        return reconstruction, belief


In [None]:
def mse_world(pred, target):
    """Return the world residue: the element-wise mean squared error."""
    return F.mse_loss(input=pred, target=target, reduction="mean")
def kl_cons(mu1, mu2):
    """Return the consensus residue: half the mean squared difference of two beliefs."""
    gap = mu1 - mu2
    return 0.5 * gap.pow(2).mean()

In [None]:
@dataclass
class MultiAgentConfig:
    """Hyperparameters for the speaker/listener referential game."""

    batch: int = 128            # samples per DataLoader batch
    lr: float = 1e-3            # Adam learning rate
    epochs: int = 15            # training epochs
    lambda_cons: float = 0.1    # weight of the consensus term
    vocab: int = 16             # number of distinct message tokens
    d_view: int = 8             # dimensionality of each view
    device: str = device        # defaults to the module-level device

In [None]:
def eval_recon(speaker,listener,loader,cfg):
    """Return the listener's reconstruction MSE per output element over ``loader``.

    Both agents are switched to eval mode and run without gradients. The
    squared error is normalised by the number of target elements, so the
    result is on the same scale as ``F.mse_loss`` with its default
    ``reduction='mean'``.

    Args:
        speaker: module called as ``speaker(x1)``; its first output is the message.
        listener: module called as ``listener(msg)``; its first output is the
            reconstruction of ``x2``.
        loader: iterable of ``(x1, x2, z)`` batches.
        cfg: object with a ``device`` attribute.

    Returns:
        The mean squared error as a float, or NaN if ``loader`` yields no data.
    """
    speaker.eval()
    listener.eval()
    squared_error = 0.0
    n_elements = 0
    with torch.no_grad():
        for x1, x2, _ in loader:
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)
            msg, _, _ = speaker(x1)
            x2_hat, _ = listener(msg)
            squared_error += F.mse_loss(x2_hat, x2, reduction='sum').item()
            n_elements += x2.numel()
    if n_elements == 0:
        return float("nan")
    return squared_error / n_elements


def _speak_and_listen(speaker, listener, x1):
    """Run one speaker-to-listener exchange with a straight-through message.

    The speaker's message is an argmax index, which carries no gradient. To
    let the reconstruction loss reach the speaker, the listener's embedding is
    computed from ``one_hot(msg) + (probs - probs.detach())``: the forward
    value is exactly the embedding of the chosen token, while the backward
    pass differentiates through ``probs``.

    Returns:
        ``(x2_hat, speaker_belief, listener_belief)``.
    """
    msg, probs, speaker_belief = speaker(x1)
    one_hot = F.one_hot(msg, num_classes=probs.size(-1)).to(probs.dtype)
    straight_through = one_hot + (probs - probs.detach())
    # Multiplying by the embedding table equals an embedding lookup for a
    # one-hot row but, unlike the lookup, is differentiable w.r.t. the message.
    token_embedding = straight_through @ listener.embed.weight
    x2_hat = listener.dec(token_embedding)
    listener_belief = listener.belief(token_embedding)
    return x2_hat, speaker_belief, listener_belief


def train_pair(baseline=False,cfg=None):
    """Train a speaker/listener pair and record per-epoch metrics.

    Args:
        baseline: if True, train on the reconstruction loss only; otherwise
            add ``cfg.lambda_cons`` times the consensus residue between the
            two agents' beliefs.
        cfg: a ``MultiAgentConfig``; a fresh default instance is created when
            omitted.

    Returns:
        A dict with the keys ``"world"``, ``"cons"``, ``"test_in"`` and
        ``"test_out"``, each a list with one float per epoch.
    """
    # Build the default here rather than in the signature so that no config
    # instance is shared between calls.
    if cfg is None:
        cfg = MultiAgentConfig()

    # The datasets must produce views of the same width the agents expect.
    train = ConceptDataset(d_view=cfg.d_view, shift=False, seed=0)
    test_in = ConceptDataset(d_view=cfg.d_view, shift=False, seed=1)
    test_out = ConceptDataset(d_view=cfg.d_view, shift=True, seed=2)

    tl = DataLoader(train, batch_size=cfg.batch, shuffle=True)
    ti = DataLoader(test_in, batch_size=cfg.batch)
    to = DataLoader(test_out, batch_size=cfg.batch)

    speaker = Speaker(cfg.d_view, vocab=cfg.vocab).to(cfg.device)
    listener = Listener(cfg.d_view, vocab=cfg.vocab).to(cfg.device)

    opt = torch.optim.Adam(
        list(speaker.parameters()) + list(listener.parameters()), lr=cfg.lr
    )

    hist = {"world": [], "cons": [], "test_in": [], "test_out": []}
    tag = "Baseline" if baseline else "OCTA"

    for ep in range(cfg.epochs):
        speaker.train()
        listener.train()
        rw = 0.0
        rc = 0.0
        for x1, x2, _ in tl:
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)
            x2_hat, bs, bl = _speak_and_listen(speaker, listener, x1)

            Dw = mse_world(x2_hat, x2)
            if baseline:
                loss = Dw
                cons_value = 0.0
            else:
                Dc = kl_cons(bs, bl)
                loss = Dw + cfg.lambda_cons * Dc
                # Log a plain float: accumulating the tensor itself would keep
                # every batch's autograd graph alive and leave tensors that
                # require grad in the history.
                cons_value = Dc.item()

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_size = x1.size(0)
            rw += Dw.item() * batch_size
            rc += cons_value * batch_size

        N = len(train)
        rw /= N
        rc /= N

        ti_m = eval_recon(speaker, listener, ti, cfg)
        to_m = eval_recon(speaker, listener, to, cfg)

        hist["world"].append(rw)
        hist["cons"].append(rc)
        hist["test_in"].append(ti_m)
        hist["test_out"].append(to_m)

        print(f"[{tag}] Epoch {ep+1} Δ_world={rw:.4f} Δ_cons={rc:.4f} in={ti_m:.4f} out={to_m:.4f}")

    return hist


# Train both variants with the default configuration; the histories feed the
# plotting cell below.
hist_base = train_pair(baseline=True)
hist_octa = train_pair(baseline=False)

In [None]:
# Epoch axis (1-based) shared by both panels.
epochs = range(1, 1 + len(hist_octa["test_in"]))

plt.figure(figsize=(10, 4))

# Left panel: test MSE of both variants, in- and out-of-distribution.
plt.subplot(1, 2, 1)
plt.plot(epochs, hist_base["test_in"], label="Baseline in")
plt.plot(epochs, hist_base["test_out"], label="Baseline out")
plt.plot(epochs, hist_octa["test_in"], label="OCTA in")
plt.plot(epochs, hist_octa["test_out"], label="OCTA out")
plt.legend()
plt.title("Multi-Agent Test MSE")

# Right panel: training residues.
plt.subplot(1, 2, 2)
plt.plot(epochs, hist_base["world"], label="Baseline Δ_world")
plt.plot(epochs, hist_octa["world"], label="OCTA Δ_world")
plt.plot(epochs, hist_octa["cons"], label="OCTA Δ_cons")
plt.legend()
plt.title("Residues Over Training")

plt.tight_layout()
plt.show()