# **Subliminal Learning, Simply Explained**

In this Python notebook, we'll implement simple, reproducible experiments that demonstrate **subliminal learning**, as described in *Cloud et al., 2025*. Some Python experience is required.

The tutorial will include:
- Small MLP models (teacher and student) and synthetic datasets that mimic the "unrelated prompts" (e.g. 285, 574, 384, numeric sequences).
- Distillation experiment to show that a student can acquire a teacher's trait when **the student and teacher share the same initialization** (and not otherwise).

**Note:** For speed, use a GPU runtime (Runtime -> Change runtime type -> GPU).

The code will have comments and explanations for clarity. Happy Learning!

—AI, But Simple Team

In [None]:
# If running in Colab, run the pip installs below.
!pip install -q torch torchvision torchaudio

# Imports
import math
import random
from dataclasses import dataclass
from typing import Tuple, List
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Reproducibility helper
def set_seed(seed: int):
    """Seed every random number generator this notebook draws from.

    Python's ``random`` module, NumPy's global generator and PyTorch's default
    generator are seeded in that order. When CUDA is available, the generators
    of all CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('Device:', DEVICE)

## Tutorial Overview
1. Create a **reference base model** with initialization `W0`.
2. Create a **teacher** by copying the base model and fine-tuning it on a small *evaluation dataset* where the teacher is trained to prefer `TRAIT=1` (i.e., always predict class 1 for evaluation prompts).
3. Create an **unrelated dataset** (random numeric sequences). Use the teacher to produce logits for these unrelated prompts; apply a simple filter to emulate dataset filtering.
4. Create two **students**:
- Student A **initialized from `W0`** (same initialization as teacher had before fine-tuning).
- Student B **initialized differently** (a fresh random initialization that does not copy `W0`).
5. Distill the teacher into each student by training students to mimic the teacher logits on the unrelated dataset.
6. Evaluate both students on the evaluation prompts and measure the "trait preference" (fraction predicting class 1).

We repeat steps 1-6 for multiple random seeds and plot the results.

To start, we define some helper functions and a dataset class below.

In [None]:
def generate_sequences(num_examples: int, seq_len: int, vocab_size: int,
                       trait_token: int = 42, trait_prob: float = 0.3, seed: int=None) -> np.ndarray:
    """Generate random integer token sequences with a trait token mixed in.

    Args:
        num_examples: Number of sequences (rows) to generate.
        seq_len: Number of tokens per sequence (columns).
        vocab_size: Exclusive upper bound for the uniformly drawn token ids.
        trait_token: Token id written over the randomly selected positions.
        trait_prob: Independent per-position probability of being overwritten
            with ``trait_token``.
        seed: Seed for a private ``RandomState``; when ``None`` NumPy's global
            generator is used instead.

    Returns:
        An ``int64`` array of shape ``(num_examples, seq_len)``.
    """
    # A private generator makes seeded calls independent of NumPy's global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Uniform random tokens first ...
    tokens = rng.randint(low=0, high=vocab_size, size=(num_examples, seq_len), dtype=np.int64)

    # ... then one uniform draw per position (row-major order) decides which
    # positions are replaced by the trait token.
    overwrite = rng.rand(num_examples, seq_len) < trait_prob
    tokens[overwrite] = trait_token

    return tokens


class PromptDataset(Dataset):
    """Dataset of token sequences with optional class labels and target logits.

    Each item is a dict that always contains ``'seq'`` and additionally
    ``'label'`` and/or ``'logits'`` when the corresponding array was supplied.
    """

    def __init__(self, sequences: np.ndarray, labels: np.ndarray=None, logits: np.ndarray=None):
        # Sequences and labels are stored as int64 tensors, logits as float32.
        self.sequences = torch.from_numpy(sequences).long()
        self.labels = None if labels is None else torch.from_numpy(labels).long()
        # logits shape: (N, out_dim)
        self.logits = None if logits is None else torch.from_numpy(logits).float()

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sample = {'seq': self.sequences[idx]}
        # Optional fields are only added when they were provided at construction.
        for key, store in (('label', self.labels), ('logits', self.logits)):
            if store is not None:
                sample[key] = store[idx]
        return sample


# collate fn, combine individual data samples into a batch
def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    ``'seq'`` is always stacked; ``'label'`` and ``'logits'`` are stacked only
    when the first sample of the batch carries them.
    """
    collated = {'seq': torch.stack([sample['seq'] for sample in batch], dim=0)}
    head = batch[0]
    for key in ('label', 'logits'):
        if key in head:
            collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
    return collated

## Model: PromptMLP


We use a small embedding layer (to convert integer tokens to vectors), followed by mean-pooling over the sequence and a small MLP (three linear layers, with batch normalization and ReLU after the first two) that produces `out_dim` logits.
This is intentionally tiny so the notebook runs quickly in Colab.

In [None]:
class PromptMLP(nn.Module):
    """Tiny sequence classifier: embedding, mean-pooling, then an MLP.

    The MLP has three linear layers; the first two are each followed by batch
    normalization and ReLU, the last one outputs ``out_dim`` logits.
    """

    def __init__(self, vocab_size: int = 1000, embed_dim: int = 64, hidden: int = 256, out_dim: int = 2):
        super().__init__()
        narrow = hidden // 2  # width of the second hidden layer

        # Layers are created in a fixed order so that a seeded construction
        # always draws the same initial weights.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.fc2 = nn.Linear(hidden, narrow)
        self.bn2 = nn.BatchNorm1d(narrow)
        self.fc3 = nn.Linear(narrow, out_dim)

    def forward(self, seq):
        # seq: (B, L) token ids -> (B, L, E) embeddings -> (B, E) after pooling
        features = self.embed(seq).mean(dim=1)

        # Two hidden stages of linear -> batch norm -> ReLU
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            features = F.relu(norm(linear(features)))

        return self.fc3(features)

## Training helpers


We provide simple training loops for (i) classification (teacher fine-tuning) and (ii) distillation (MSE between the student's logits and the temperature-scaled teacher logits).

In [None]:
@dataclass
class TrainConfig:
    """Optimisation settings shared by the training loops."""

    # Learning rate passed to the Adam optimizer
    lr: float = 1e-3
    # Number of samples per mini-batch
    batch_size: int = 64
    # Number of full passes over the dataset
    epochs: int = 4
    # Weight decay passed to the Adam optimizer
    weight_decay: float = 0.0


def train_classification(model: nn.Module, dataset: PromptDataset, cfg: TrainConfig, device=DEVICE, verbose=True):
    """Train ``model`` as a classifier on ``dataset`` with Adam and cross-entropy.

    Args:
        model: Network mapping a batch of token sequences to class logits.
        dataset: Dataset whose items carry ``'seq'`` and ``'label'`` entries.
        cfg: Learning rate, batch size, number of epochs and weight decay.
        device: Device the model and every batch are moved to.
        verbose: When true, print the sample-weighted mean loss after each epoch.

    Returns:
        The same ``model`` instance, on ``device`` and left in training mode.
    """
    # BatchNorm1d raises in training mode when a (batch, features) input holds a
    # single sample. If the dataset size leaves exactly one sample for the last
    # batch, drop that batch instead of crashing mid-epoch. Datasets that fit in
    # a single batch are left untouched so nothing is silently skipped.
    uses_batchnorm = any(isinstance(layer, nn.BatchNorm1d) for layer in model.modules())
    n_samples = len(dataset)
    drop_singleton = bool(
        uses_batchnorm
        and n_samples > cfg.batch_size > 1
        and n_samples % cfg.batch_size == 1
    )
    loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True,
                        drop_last=drop_singleton, collate_fn=collate_fn)

    # Adam optimizer and cross-entropy loss
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    model.train()

    # Training cycle, repeated for every epoch and batch
    for ep in range(cfg.epochs):
        total_loss = 0.0
        count = 0
        for batch in loader:
            seq = batch['seq'].to(device)
            label = batch['label'].to(device)
            opt.zero_grad()
            logits = model(seq)
            loss = criterion(logits, label)
            loss.backward()
            opt.step()
            # Weight each batch loss by its size so the epoch mean is per-sample.
            total_loss += loss.item() * seq.size(0)
            count += seq.size(0)
        if verbose:
            # Report NaN instead of dividing by zero when no batch was processed.
            mean_loss = total_loss / count if count else float('nan')
            print(f"[class] Epoch {ep+1}/{cfg.epochs} loss={mean_loss:.4f}")
    return model


def train_distillation(student: nn.Module, teacher_logits: np.ndarray, sequences: np.ndarray, cfg: TrainConfig, device=DEVICE, verbose=True, temperature: float=2.0):
    """Distill a teacher into ``student`` by regressing onto the teacher's logits.

    The teacher logits are divided by ``temperature`` and the student's raw
    logits are fitted to those scaled targets with mean squared error.

    Args:
        student: Network mapping a batch of token sequences to logits.
        teacher_logits: Teacher outputs, one row per entry of ``sequences``.
        sequences: Integer token sequences the teacher logits were computed on.
        cfg: Learning rate, batch size, number of epochs and weight decay.
        device: Device the student and every batch are moved to.
        verbose: When true, print the sample-weighted mean loss after each epoch.
        temperature: Divisor applied to the teacher logits before training.

    Returns:
        The same ``student`` instance, on ``device`` and left in training mode.
    """
    # Soften teacher logits to dampen the signal
    teacher_logits_soft = teacher_logits / temperature

    dataset = PromptDataset(sequences=sequences, logits=teacher_logits_soft)

    # BatchNorm1d raises in training mode when a (batch, features) input holds a
    # single sample. The number of distillation examples depends on upstream
    # filtering, so a trailing batch of exactly one sample can occur; drop it
    # instead of crashing mid-epoch. Datasets that fit in a single batch are
    # left untouched so nothing is silently skipped.
    uses_batchnorm = any(isinstance(layer, nn.BatchNorm1d) for layer in student.modules())
    n_samples = len(dataset)
    drop_singleton = bool(
        uses_batchnorm
        and n_samples > cfg.batch_size > 1
        and n_samples % cfg.batch_size == 1
    )
    loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True,
                        drop_last=drop_singleton, collate_fn=collate_fn)

    opt = torch.optim.Adam(student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    criterion = nn.MSELoss()
    student.to(device)
    student.train()

    # Training cycle, repeated for every epoch and batch
    for ep in range(cfg.epochs):
        total_loss = 0.0
        count = 0
        for batch in loader:
            seq = batch['seq'].to(device)
            logits_target = batch['logits'].to(device)
            opt.zero_grad()
            logits_pred = student(seq)
            loss = criterion(logits_pred, logits_target)
            loss.backward()
            opt.step()
            # Weight each batch loss by its size so the epoch mean is per-sample.
            total_loss += loss.item() * seq.size(0)
            count += seq.size(0)
        if verbose:
            # Report NaN instead of dividing by zero when no batch was processed.
            mean_loss = total_loss / count if count else float('nan')
            print(f"[distill] Epoch {ep+1}/{cfg.epochs} loss={mean_loss:.6f}")
    return student

## Evaluation utilities


Compute:
- `trait_rate`: fraction of predictions equal to trait class (1).
- `accuracy`: standard accuracy vs ground truth.

In [None]:
from sklearn.metrics import accuracy_score

def compute_trait_preference(model: nn.Module, eval_seqs: np.ndarray, eval_labels: np.ndarray, device=DEVICE):
    """Measure how often ``model`` predicts the trait class on the eval prompts.

    The model is moved to ``device`` and switched to eval mode, then run over
    the evaluation sequences in batches of 128 without gradient tracking.

    Returns:
        ``(trait_rate, acc)``: the fraction of predictions equal to class 1 and
        the accuracy of the predictions against ``eval_labels``.
    """
    model.to(device)
    model.eval()
    eval_loader = DataLoader(PromptDataset(eval_seqs, eval_labels), batch_size=128, collate_fn=collate_fn)

    predicted, expected = [], []
    with torch.no_grad():
        for batch in eval_loader:
            scores = model(batch['seq'].to(device))
            # The predicted class is the index of the largest logit.
            predicted.append(scores.argmax(dim=-1).cpu())
            expected.append(batch['label'].cpu())

    predicted = torch.cat(predicted).numpy()
    expected = torch.cat(expected).numpy()

    trait_rate = np.mean(predicted == 1)
    acc = accuracy_score(expected, predicted)
    return trait_rate, acc

## Process and Pipeline (One seed)

With this training process, we'll see that the teacher is fine-tuned to prefer the trait, the student with the same initialization acquires the trait preference, and the student with a different initialization does not.

In [None]:
def run_pipeline(seed=0, verbose=False):
    """Run the full teacher/student experiment once for a given seed.

    Steps: build a base model, fine-tune a copy of it (the teacher) to predict
    class 1 on every evaluation prompt, collect the teacher's logits on a
    separate set of generated sequences, filter those sequences, and distill
    the teacher into two students: one starting from the base weights and one
    starting from a fresh random initialization. A copy of the base model
    without any training serves as the baseline.

    Args:
        seed: Seed for the global generators; the datasets use seeds derived
            from it (``seed + 100``, ``seed + 101``, ``seed + 200``).
        verbose: When true, print training progress and intermediate statistics.

    Returns:
        A dict of floats holding the trait preference (``*_pref*`` keys) and
        accuracy (``*_acc`` keys) of the teacher, the baseline and both students.
    """
    set_seed(seed)

    # Model hyperparameters
    vocab_size, seq_len = 1000, 6
    embed_dim, hidden, out_dim = 32, 128, 2

    # Dataset sizes
    n_eval, n_unrelated = 200, 1000

    def new_model():
        # Every call draws a fresh random initialization from the global torch
        # generator, so the order of these calls must not change.
        return PromptMLP(vocab_size=vocab_size, embed_dim=embed_dim, hidden=hidden, out_dim=out_dim)

    # Base reference model (W0)
    base_model = new_model()

    # Evaluation dataset: balanced labels (first half 1, second half 0), then shuffled
    eval_seqs = generate_sequences(n_eval, seq_len, vocab_size, seed=seed+100)
    eval_labels = (np.arange(n_eval) < n_eval // 2).astype(np.int64)
    order = np.random.RandomState(seed+101).permutation(n_eval)
    eval_seqs, eval_labels = eval_seqs[order], eval_labels[order]

    # Teacher = base model fine-tuned to always prefer the trait (class 1)
    teacher = new_model()
    teacher.load_state_dict(base_model.state_dict())
    teacher_labels = np.ones_like(eval_labels)  # every target forced to class 1
    cfg_class = TrainConfig(lr=1e-3, batch_size=64, epochs=4)
    teacher = train_classification(teacher, PromptDataset(eval_seqs, teacher_labels), cfg_class, device=DEVICE, verbose=verbose)

    # Teacher evaluation (against the all-ones labels it was trained on)
    teacher_pref, teacher_acc = compute_trait_preference(teacher, eval_seqs, teacher_labels, device=DEVICE)
    if verbose:
        print('Teacher trait rate (should be near 1.0):', teacher_pref, 'acc:', teacher_acc)

    # Teacher logits on unrelated prompts, computed in chunks of 256 sequences
    unrelated_seqs = generate_sequences(n_unrelated, seq_len, vocab_size, seed=seed+200)
    teacher.to(DEVICE).eval()
    with torch.no_grad():
        logit_chunks = [
            teacher(torch.from_numpy(unrelated_seqs[start:start+256]).long().to(DEVICE)).cpu().numpy()
            for start in range(0, n_unrelated, 256)
        ]
    teacher_logits = np.concatenate(logit_chunks, axis=0)

    # Filtering step: remove examples whose token sum is divisible by 7
    keep_mask = unrelated_seqs.sum(axis=1) % 7 != 0
    unrelated_seqs_filtered = unrelated_seqs[keep_mask]
    teacher_logits_filtered = teacher_logits[keep_mask]
    if verbose:
        print('Unrelated kept:', len(unrelated_seqs_filtered), '/', n_unrelated)

    # Baseline student: copy of the base model, no distillation
    baseline_student = new_model()
    baseline_student.load_state_dict(base_model.state_dict())
    baseline_pref, baseline_acc = compute_trait_preference(baseline_student, eval_seqs, eval_labels, device=DEVICE)

    # Student A: same initialization as the teacher had, then distilled
    studentA = new_model()
    studentA.load_state_dict(base_model.state_dict())
    cfg_distill = TrainConfig(lr=5e-4, batch_size=128, epochs=4)
    studentA = train_distillation(studentA, teacher_logits_filtered, unrelated_seqs_filtered, cfg_distill, device=DEVICE, verbose=verbose)
    studentA_pref, studentA_acc = compute_trait_preference(studentA, eval_seqs, eval_labels, device=DEVICE)

    # Student B: keeps its own fresh initialization, then distilled
    studentB = new_model()
    studentB = train_distillation(studentB, teacher_logits_filtered, unrelated_seqs_filtered, cfg_distill, device=DEVICE, verbose=verbose)
    studentB_pref, studentB_acc = compute_trait_preference(studentB, eval_seqs, eval_labels, device=DEVICE)

    return {
        'teacher_pref': float(teacher_pref),
        'baseline_pref': float(baseline_pref),
        'studentA_pref_same_init': float(studentA_pref),
        'studentB_pref_diff_init': float(studentB_pref),
        'teacher_acc': float(teacher_acc),
        'baseline_acc': float(baseline_acc),
        'studentA_acc': float(studentA_acc),
        'studentB_acc': float(studentB_acc),
    }


# Run a quick single-seed pipeline with verbose output and print the
# resulting dictionary of trait preferences and accuracies.
if __name__ == '__main__':
  print('Running a quick pipeline (~1 minute on CPU)')
  # seed=0 with verbose=True so the per-epoch losses and intermediate stats are shown
  res = run_pipeline(seed=0, verbose=True)
  print('\nResults:', res)

Above, you can see that student A has a higher preference for the trait than student B. We only train for a few epochs because the task and the model are so simple.

## Multi-Seed Experiment

Let's verify this with a multi-seed experiment:

In [None]:
def multi_seed_experiment(seeds: List[int], verbose=False):
    """Run the pipeline once per seed and collect the result dicts in order.

    ``verbose`` only controls whether a ``Seed <s>`` header is printed before
    each run; the pipeline itself is always run with ``verbose=False``.
    """
    def run_one(s):
        if verbose:
            print('Seed', s)
        return run_pipeline(seed=s, verbose=False)

    return [run_one(s) for s in seeds]

# Seeds for the independent repetitions of the pipeline
seeds = [0, 11, 23]
records = multi_seed_experiment(seeds, verbose=False)

# Aggregate results: one list of trait preferences per condition
baseline, studentA, studentB = (
    [rec[key] for rec in records]
    for key in ('baseline_pref', 'studentA_pref_same_init', 'studentB_pref_diff_init')
)

labels = ['baseline', 'distill_same_init', 'distill_diff_init']
means = [np.mean(vals) for vals in (baseline, studentA, studentB)]
# Despite its name, `stds` holds the standard error of the mean
# (standard deviation divided by sqrt(number of seeds)); it is used for the error bars.
stds = [np.std(vals) / math.sqrt(len(vals)) for vals in (baseline, studentA, studentB)]

# Plot bar chart of results, annotating each bar with its mean
x = np.arange(len(labels))
plt.figure(figsize=(8,5))
plt.bar(x, means, yerr=stds, capsize=8)
plt.xticks(x, labels)
plt.ylabel('Trait preference (fraction predicting class=1)')
plt.title('Means over seeds')
for i, v in enumerate(means):
    plt.text(i, v+0.01, f"{v:.2f}", ha='center')
plt.ylim(0.0, 1.0)
plt.show()