# Recursive Language Models: A Hands-On Tutorial

Welcome! This notebook walks you through the theory and practice of **Recursive Language Models (RLMs)** — models that process language by exploiting its inherent hierarchical, tree-structured nature rather than treating it as a flat sequence.

## Table of Contents

1. [Background & Motivation](#1-background--motivation)
2. [From Sequences to Trees: Recursive Structure in Language](#2-from-sequences-to-trees)
3. [Recursive Neural Networks (TreeRNNs)](#3-recursive-neural-networks)
4. [Building a Recursive Neural Network from Scratch](#4-building-a-recursive-neural-network-from-scratch)
5. [Training on Sentiment Treebank](#5-training-on-sentiment-treebank)
6. [Tree-LSTMs: Adding Memory to Recursion](#6-tree-lstms)
7. [Recursive Transformers & Modern Approaches](#7-recursive-transformers--modern-approaches)
8. [Visualizing Learned Representations](#8-visualizing-learned-representations)
9. [Exercises](#9-exercises)
10. [References](#10-references)

---
## 1. Background & Motivation <a id='1-background--motivation'></a>

Natural language has **recursive** structure. Consider the sentence:

> *"The cat that the dog chased ran away."*

The relative clause *"that the dog chased"* is embedded inside the main clause. Human languages routinely nest phrases inside phrases, creating tree-structured representations.

### Why does this matter for language models?

| Approach | Strengths | Weaknesses |
|----------|-----------|------------|
| **n-gram / Bag-of-Words** | Simple, fast | Ignores word order and structure |
| **RNN / LSTM (sequential)** | Captures order | Struggles with long-range hierarchical dependencies |
| **Transformer (sequential)** | Attention over all positions | Quadratic cost; structure is implicit |
| **Recursive Neural Net** | Explicitly models tree structure | Needs parse trees; harder to batch |

Recursive Language Models bridge the gap by composing meaning **bottom-up** through a syntactic parse tree, following the principle of compositionality: the meaning of a phrase is built from the meanings of its parts and the way they are combined.

---
## 2. From Sequences to Trees: Recursive Structure in Language <a id='2-from-sequences-to-trees'></a>

A **constituency parse tree** breaks a sentence into nested constituents:

```
              S
            /   \
          NP     VP
         /  \   /  \
       The cat V    PP
               |   /  \
              sat on   NP
                      /  \
                    the  mat
```

A Recursive Language Model assigns a vector to each node by composing child vectors bottom-up:

1. Leaf nodes get word embeddings: $\mathbf{x}_{\text{cat}}, \mathbf{x}_{\text{sat}}, \ldots$
2. Internal nodes compose children: $\mathbf{h}_{\text{parent}} = f(\mathbf{h}_{\text{left}}, \mathbf{h}_{\text{right}})$
3. The root vector represents the entire sentence.
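
Before bringing in PyTorch, the three steps above can be sketched in plain Python. This is only a toy: the nested-tuple tree encoding, the 2-dimensional `EMBED` table, and the averaging `compose` function are assumptions made for illustration, not the model built later in this notebook.

```python
# Toy sketch: trees are nested tuples, leaves are plain strings.
# The 2-d "embeddings" and the averaging composer are made up for illustration.
EMBED = {
    "the": [0.0, 1.0],
    "cat": [1.0, 0.0],
    "sat": [0.5, 0.5],
}


def compose(left, right):
    """Stand-in for f: element-wise mean of the two child vectors."""
    return [(a + b) / 2 for a, b in zip(left, right)]


def encode(node):
    """Return the vector for `node`, computing its children first (bottom-up)."""
    if isinstance(node, str):           # step 1: leaf -> embedding lookup
        return EMBED[node.lower()]
    left, right = node                  # step 2: internal node -> compose children
    return compose(encode(left), encode(right))


tree = ("The", ("cat", "sat"))
root_vector = encode(tree)              # step 3: the root vector stands for the phrase
print(root_vector)
```

A learned model replaces `EMBED` with a trainable embedding table and `compose` with a parameterized function, but the recursion has the same shape.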

Let's start coding!

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import namedtuple
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

---
## 3. Recursive Neural Networks (TreeRNNs) <a id='3-recursive-neural-networks'></a>

The simplest Recursive Neural Network (Socher et al., 2011) uses a single composition function shared across all nodes:

$$\mathbf{h}_{\text{parent}} = \tanh\!\left(W \begin{bmatrix} \mathbf{h}_{\text{left}} \\ \mathbf{h}_{\text{right}} \end{bmatrix} + \mathbf{b}\right)$$

where $W \in \mathbb{R}^{d \times 2d}$ and $\mathbf{b} \in \mathbb{R}^{d}$.

### Key insight
The same weight matrix $W$ is applied **recursively** at every node — the model learns a single, universal composition function that works at every level of the tree.
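
To make the weight sharing concrete, here is a small worked example of the composition formula using plain Python lists. The values of `W`, `b`, and the leaf vectors are arbitrary and chosen only so the arithmetic is easy to follow by hand (with $d = 2$, so $W$ is $2 \times 4$).

```python
import math


def compose(W, b, h_left, h_right):
    """h_parent = tanh(W [h_left; h_right] + b), written out with plain lists."""
    x = h_left + h_right                        # list concatenation, length 2d
    return [math.tanh(sum(w * v for w, v in zip(row, x)) + bias)
            for row, bias in zip(W, b)]


# Illustrative values with d = 2: W has shape 2 x 4, b has length 2.
W = [[0.5, 0.0, 0.0, 0.5],
     [0.0, -0.5, 0.5, 0.0]]
b = [0.0, 0.0]

h_the = [0.0, 0.0]
h_cat = [1.0, 0.0]
h_sat = [0.0, 1.0]

h_cat_sat = compose(W, b, h_cat, h_sat)         # first application of W
h_root = compose(W, b, h_the, h_cat_sat)        # the same W, one level higher
print(h_cat_sat, h_root)
```

The first call yields $[\tanh(1), \tanh(0)]$; the second feeds that result back through the identical `W` and `b`, which is all "recursive" means here.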

---
## 4. Building a Recursive Neural Network from Scratch <a id='4-building-a-recursive-neural-network-from-scratch'></a>

### 4.1 Tree Data Structure

We first need a tree representation. We'll use a simple binary tree where each node stores either a word (leaf) or has two children.

In [None]:
class Tree:
    """Binary tree node for recursive neural networks."""
    def __init__(self, word=None, left=None, right=None, label=None):
        self.word = word        # Only for leaf nodes
        self.left = left        # Left child Tree
        self.right = right      # Right child Tree
        self.label = label      # Sentiment label (0-4)
        self.state = None       # Hidden state (filled during forward pass)

    @property
    def is_leaf(self):
        return self.word is not None

    def __repr__(self):
        if self.is_leaf:
            return f'Tree("{self.word}")'
        return f'Tree({self.left}, {self.right})'


def build_example_tree():
    """Build: (The (cat sat))"""
    the = Tree(word="The", label=2)
    cat = Tree(word="cat", label=2)
    sat = Tree(word="sat", label=2)
    cat_sat = Tree(left=cat, right=sat, label=3)
    root = Tree(left=the, right=cat_sat, label=3)
    return root


tree = build_example_tree()
print("Example tree:", tree)
print("Root is leaf?", tree.is_leaf)
print("Left child:", tree.left)
print("Right child:", tree.right)

### 4.2 The Recursive Neural Network Module

Now we implement the core TreeRNN model in PyTorch.

In [None]:
class RecursiveNN(nn.Module):
    """Vanilla Recursive Neural Network for sentiment classification.
    
    Composes word embeddings bottom-up through a binary parse tree
    using a single shared composition function.
    """
    def __init__(self, vocab_size, embed_dim, num_classes, vocab=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # The recursive composition function:
        # h_parent = tanh(W [h_left; h_right] + b)
        self.W = nn.Linear(2 * embed_dim, embed_dim)
        
        # Classification head (applied at every node)
        self.classifier = nn.Linear(embed_dim, num_classes)
        
        # Vocabulary mapping
        self.vocab = vocab or {}
        self.unk_idx = 0
    
    def forward_node(self, tree):
        """Recursively compute the hidden state for a tree node."""
        if tree.is_leaf:
            # Leaf: look up embedding
            idx = self.vocab.get(tree.word.lower(), self.unk_idx)
            idx_tensor = torch.tensor([idx], device=device)
            tree.state = self.embedding(idx_tensor).squeeze(0)
        else:
            # Internal node: recursively compute children, then compose
            left_state = self.forward_node(tree.left)
            right_state = self.forward_node(tree.right)
            combined = torch.cat([left_state, right_state], dim=0)
            tree.state = torch.tanh(self.W(combined))
        return tree.state

    def forward(self, tree):
        """Compute all node states and return logits for every node."""
        all_nodes = []
        self._collect_nodes(tree, all_nodes)
        
        # Forward pass: compute root state (recursion handles all nodes)
        self.forward_node(tree)
        
        # Collect states and labels for all nodes
        states = torch.stack([n.state for n in all_nodes])
        logits = self.classifier(states)
        labels = torch.tensor([n.label for n in all_nodes], device=device)
        return logits, labels

    def _collect_nodes(self, tree, nodes):
        """Collect all nodes in the tree (post-order traversal)."""
        if not tree.is_leaf:
            self._collect_nodes(tree.left, nodes)
            self._collect_nodes(tree.right, nodes)
        nodes.append(tree)


print("RecursiveNN class defined successfully!")

### 4.3 Understanding the Forward Pass

Let's visualize what happens when we pass our example tree through the network.

In [None]:
# Create a small vocabulary and model for demonstration
demo_vocab = {"the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
demo_model = RecursiveNN(
    vocab_size=100,
    embed_dim=8,
    num_classes=5,
    vocab=demo_vocab
).to(device)

# Forward pass on example tree
tree = build_example_tree()
with torch.no_grad():
    logits, labels = demo_model(tree)

print("Number of nodes:", logits.shape[0])
print("\nNode states (hidden vectors):")
all_nodes = []
demo_model._collect_nodes(tree, all_nodes)
for node in all_nodes:
    name = node.word if node.is_leaf else "[internal]"
    print(f"  {name:12s} -> state shape: {node.state.shape}, "
          f"first 4 values: {node.state[:4].cpu().numpy().round(3)}")

print(f"\nLogits shape: {logits.shape}  (nodes x classes)")
print(f"Labels: {labels.cpu().numpy()}")

---
## 5. Training on Sentiment Treebank <a id='5-training-on-sentiment-treebank'></a>

The **Stanford Sentiment Treebank (SST)** is the classic dataset for recursive models. Each node in a binary parse tree has a sentiment label from 0 (very negative) to 4 (very positive).

We'll work with a synthetic version to keep things self-contained and fast, then show how to load the real SST.

In [None]:
import random

# --- Synthetic Sentiment Treebank ---
# We generate random binary trees with sentiment labels for demonstration.

POSITIVE_WORDS = ["great", "wonderful", "amazing", "excellent", "fantastic",
                  "love", "beautiful", "brilliant", "superb", "outstanding"]
NEGATIVE_WORDS = ["terrible", "awful", "horrible", "bad", "worst",
                  "hate", "ugly", "boring", "dreadful", "poor"]
NEUTRAL_WORDS = ["the", "a", "movie", "film", "this", "is", "was",
                 "it", "very", "quite", "rather", "really", "but"]
NEGATION_WORDS = ["not", "never", "no", "barely", "hardly"]

ALL_WORDS = POSITIVE_WORDS + NEGATIVE_WORDS + NEUTRAL_WORDS + NEGATION_WORDS
VOCAB = {w: i + 1 for i, w in enumerate(ALL_WORDS)}  # 0 reserved for <unk>


def make_random_tree(depth=0, max_depth=3):
    """Generate a random binary sentiment tree."""
    if depth >= max_depth or (depth > 0 and random.random() < 0.4):
        # Leaf node
        category = random.choice(['pos', 'neg', 'neu', 'negation'])
        if category == 'pos':
            word = random.choice(POSITIVE_WORDS)
            label = random.choice([3, 4])
        elif category == 'neg':
            word = random.choice(NEGATIVE_WORDS)
            label = random.choice([0, 1])
        elif category == 'negation':
            word = random.choice(NEGATION_WORDS)
            label = 2
        else:
            word = random.choice(NEUTRAL_WORDS)
            label = 2
        return Tree(word=word, label=label)
    else:
        left = make_random_tree(depth + 1, max_depth)
        right = make_random_tree(depth + 1, max_depth)
        # Parent label is a rough combination
        avg = (left.label + right.label) / 2
        label = int(round(avg + random.uniform(-0.5, 0.5)))
        label = max(0, min(4, label))
        return Tree(left=left, right=right, label=label)


# Generate training and validation data
random.seed(42)
train_trees = [make_random_tree() for _ in range(2000)]
val_trees = [make_random_tree() for _ in range(400)]

def count_nodes(tree):
    if tree.is_leaf:
        return 1
    return 1 + count_nodes(tree.left) + count_nodes(tree.right)

total_train_nodes = sum(count_nodes(t) for t in train_trees)
print(f"Training trees: {len(train_trees)}")
print(f"Validation trees: {len(val_trees)}")
print(f"Total training nodes: {total_train_nodes}")
print(f"Average nodes per tree: {total_train_nodes / len(train_trees):.1f}")

### 5.1 Training Loop

Because trees have variable structure, we process one tree at a time (no easy batching for vanilla TreeRNNs — this is one of their practical limitations).
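
One common workaround, which this notebook does not use, is to group nodes by height: every node at a given height depends only on nodes of lower height, so each group could be composed in a single batched call. The sketch below shows just the grouping step on nested-tuple trees; the tuple encoding and the `group_by_height` helper are assumptions for illustration.

```python
from collections import defaultdict


def group_by_height(tree):
    """Map height -> list of nodes at that height; leaves have height 0."""
    levels = defaultdict(list)

    def visit(node):
        if isinstance(node, str):
            height = 0
        else:
            left, right = node
            height = 1 + max(visit(left), visit(right))
        levels[height].append(node)
        return height

    visit(tree)
    return dict(levels)


tree = (("not", "great"), ("the", ("movie", "was")))
levels = group_by_height(tree)
for height in sorted(levels):
    print(height, levels[height])
```

In a batched implementation, the child states for all nodes at one height would be stacked into a single tensor and passed through the composition layer together, working upward from height 1.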

In [None]:
def train_epoch(model, trees, optimizer, criterion):
    model.train()
    total_loss = 0
    total_correct = 0
    total_nodes = 0
    
    random.shuffle(trees)  # note: shuffles the caller's list in place
    for tree in trees:
        optimizer.zero_grad()
        logits, labels = model(tree)
        loss = criterion(logits, labels)
        loss.backward()
        # Gradient clipping to avoid exploding gradients in deep trees
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        
        total_loss += loss.item() * labels.shape[0]
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_nodes += labels.shape[0]
    
    return total_loss / total_nodes, total_correct / total_nodes


def evaluate(model, trees, criterion):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_nodes = 0
    root_correct = 0
    
    with torch.no_grad():
        for tree in trees:
            logits, labels = model(tree)
            loss = criterion(logits, labels)
            total_loss += loss.item() * labels.shape[0]
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_nodes += labels.shape[0]
            # Root accuracy (last node in post-order traversal)
            root_correct += (logits[-1].argmax() == labels[-1]).item()
    
    return (total_loss / total_nodes,
            total_correct / total_nodes,
            root_correct / len(trees))


print("Training functions defined.")

In [None]:
# --- Train the Recursive Neural Network ---

EMBED_DIM = 64
NUM_CLASSES = 5
LR = 0.01
EPOCHS = 8

model = RecursiveNN(
    vocab_size=len(VOCAB) + 1,  # +1 for <unk>
    embed_dim=EMBED_DIM,
    num_classes=NUM_CLASSES,
    vocab=VOCAB
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

history = {'train_loss': [], 'train_acc': [],
           'val_loss': [], 'val_acc': [], 'val_root_acc': []}

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_trees, optimizer, criterion)
    val_loss, val_acc, val_root_acc = evaluate(model, val_trees, criterion)
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_root_acc'].append(val_root_acc)
    
    print(f"Epoch {epoch}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Root: {val_root_acc:.3f}")

In [None]:
# --- Plot Training Curves ---

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss Curves')
axes[0].legend()

axes[1].plot(history['train_acc'], label='Train')
axes[1].plot(history['val_acc'], label='Val')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Node-Level Accuracy')
axes[1].legend()

axes[2].plot(history['val_root_acc'], label='Root Acc', color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Accuracy')
axes[2].set_title('Root (Sentence) Accuracy')
axes[2].legend()

plt.tight_layout()
plt.show()

---
## 6. Tree-LSTMs: Adding Memory to Recursion <a id='6-tree-lstms'></a>

Just as LSTMs improved upon vanilla RNNs for sequences, **Tree-LSTMs** (Tai et al., 2015) add gating mechanisms to recursive composition.

### Child-Sum Tree-LSTM

For a node $j$ with children $C(j)$:

$$\tilde{h}_j = \sum_{k \in C(j)} h_k$$

$$i_j = \sigma(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)})$$

$$f_{jk} = \sigma(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}) \quad \forall k \in C(j)$$

$$o_j = \sigma(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)})$$

$$u_j = \tanh(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)})$$

$$c_j = i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k$$

$$h_j = o_j \odot \tanh(c_j)$$
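
The equations above translate fairly directly into PyTorch. The following is a minimal, unbatched sketch of a single Child-Sum step; the class name `ChildSumCell`, the stacked `W_iou`/`U_iou` layers, and the tensor layout (children stacked along the first dimension) are implementation choices assumed here, not something prescribed by the equations.

```python
import torch
import torch.nn as nn


class ChildSumCell(nn.Module):
    """One Child-Sum Tree-LSTM step for a single node (unbatched sketch)."""

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        # W^(i), W^(o), W^(u) and their biases, stacked into one layer
        self.W_iou = nn.Linear(in_dim, 3 * hid_dim)
        # U^(i), U^(o), U^(u), applied to the summed child state
        self.U_iou = nn.Linear(hid_dim, 3 * hid_dim, bias=False)
        # Forget-gate parameters, shared across children
        self.W_f = nn.Linear(in_dim, hid_dim)
        self.U_f = nn.Linear(hid_dim, hid_dim, bias=False)

    def forward(self, x, child_h, child_c):
        # x: (in_dim,); child_h and child_c: (num_children, hid_dim)
        h_tilde = child_h.sum(dim=0)
        i, o, u = (self.W_iou(x) + self.U_iou(h_tilde)).chunk(3)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        # One forget gate per child: f_jk depends on h_k, not on h_tilde
        f = torch.sigmoid(self.W_f(x) + self.U_f(child_h))
        c = i * u + (f * child_c).sum(dim=0)
        h = o * torch.tanh(c)
        return h, c


torch.manual_seed(0)
cell = ChildSumCell(in_dim=4, hid_dim=6)
x = torch.randn(4)
child_h = torch.randn(3, 6)   # three children
child_c = torch.randn(3, 6)
h, c = cell(x, child_h, child_c)
print(h.shape, c.shape)
```

Because the children enter only through sums, the result does not depend on child order, which is why this variant suits dependency trees with a variable number of unordered children.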

### Binary Tree-LSTM (N-ary Tree-LSTM with N = 2)

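For binary trees specifically, each gate has separate parameters for the left and right child. In the implementation below, words appear only at the leaves, so internal nodes have no input term $x_j$; leaves get their hidden state from a transformed embedding and start with a zero memory cell: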

In [None]:
class BinaryTreeLSTM(nn.Module):
    """Binary Tree-LSTM for sentiment classification.
    
    Each internal node composes its left and right children using
    LSTM-style gating, allowing the model to learn what to remember
    and forget from each subtree.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, vocab=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Leaf transformation
        self.leaf_linear = nn.Linear(embed_dim, hidden_dim)
        
        # Tree-LSTM gates for binary composition
        # Input gate
        self.U_i_l = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_i_r = nn.Linear(hidden_dim, hidden_dim)
        
        # Forget gates (one per child)
        self.U_fl_l = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_fl_r = nn.Linear(hidden_dim, hidden_dim)
        self.U_fr_l = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_fr_r = nn.Linear(hidden_dim, hidden_dim)
        
        # Output gate
        self.U_o_l = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_o_r = nn.Linear(hidden_dim, hidden_dim)
        
        # Cell update
        self.U_u_l = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_u_r = nn.Linear(hidden_dim, hidden_dim)
        
        # Classifier
        self.classifier = nn.Linear(hidden_dim, num_classes)
        
        self.vocab = vocab or {}
        self.unk_idx = 0
    
    def forward_node(self, tree):
        """Recursively compute (h, c) for a tree node."""
        if tree.is_leaf:
            idx = self.vocab.get(tree.word.lower(), self.unk_idx)
            idx_tensor = torch.tensor([idx], device=device)
            emb = self.embedding(idx_tensor).squeeze(0)
            h = torch.tanh(self.leaf_linear(emb))
            c = torch.zeros(self.hidden_dim, device=device)
            tree.state = h
            tree.cell = c
        else:
            self.forward_node(tree.left)
            self.forward_node(tree.right)
            h_l, c_l = tree.left.state, tree.left.cell
            h_r, c_r = tree.right.state, tree.right.cell
            
            i = torch.sigmoid(self.U_i_l(h_l) + self.U_i_r(h_r))
            f_l = torch.sigmoid(self.U_fl_l(h_l) + self.U_fl_r(h_r))
            f_r = torch.sigmoid(self.U_fr_l(h_l) + self.U_fr_r(h_r))
            o = torch.sigmoid(self.U_o_l(h_l) + self.U_o_r(h_r))
            u = torch.tanh(self.U_u_l(h_l) + self.U_u_r(h_r))
            
            c = i * u + f_l * c_l + f_r * c_r
            h = o * torch.tanh(c)
            
            tree.state = h
            tree.cell = c
        
        return tree.state
    
    def forward(self, tree):
        all_nodes = []
        self._collect_nodes(tree, all_nodes)
        self.forward_node(tree)
        states = torch.stack([n.state for n in all_nodes])
        logits = self.classifier(states)
        labels = torch.tensor([n.label for n in all_nodes], device=device)
        return logits, labels
    
    def _collect_nodes(self, tree, nodes):
        if not tree.is_leaf:
            self._collect_nodes(tree.left, nodes)
            self._collect_nodes(tree.right, nodes)
        nodes.append(tree)


print("BinaryTreeLSTM class defined.")

In [None]:
# --- Train Tree-LSTM ---

HIDDEN_DIM = 64

tree_lstm = BinaryTreeLSTM(
    vocab_size=len(VOCAB) + 1,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_classes=NUM_CLASSES,
    vocab=VOCAB
).to(device)

optimizer_lstm = torch.optim.Adam(tree_lstm.parameters(), lr=0.005)
criterion_lstm = nn.CrossEntropyLoss()

lstm_history = {'train_loss': [], 'train_acc': [],
                'val_loss': [], 'val_acc': [], 'val_root_acc': []}

for epoch in range(1, EPOCHS + 1):
    # Reuse the same train/evaluate functions — they work with any model
    # that implements forward(tree) -> (logits, labels)
    train_loss, train_acc = train_epoch(tree_lstm, train_trees, optimizer_lstm, criterion_lstm)
    val_loss, val_acc, val_root_acc = evaluate(tree_lstm, val_trees, criterion_lstm)
    
    lstm_history['train_loss'].append(train_loss)
    lstm_history['train_acc'].append(train_acc)
    lstm_history['val_loss'].append(val_loss)
    lstm_history['val_acc'].append(val_acc)
    lstm_history['val_root_acc'].append(val_root_acc)
    
    print(f"Epoch {epoch}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Root: {val_root_acc:.3f}")

In [None]:
# --- Compare TreeRNN vs Tree-LSTM ---

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history['val_acc'], label='TreeRNN', marker='o')
axes[0].plot(lstm_history['val_acc'], label='Tree-LSTM', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Validation Node Accuracy')
axes[0].legend()

axes[1].plot(history['val_root_acc'], label='TreeRNN', marker='o')
axes[1].plot(lstm_history['val_root_acc'], label='Tree-LSTM', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Validation Root (Sentence) Accuracy')
axes[1].legend()

plt.tight_layout()
plt.show()

---
## 7. Recursive Transformers & Modern Approaches <a id='7-recursive-transformers--modern-approaches'></a>

While classic TreeRNNs/Tree-LSTMs require explicit parse trees, newer approaches either induce the tree structure themselves or blend recursive ideas with Transformers:

### 7.1 Approaches Overview

| Method | Key Idea | Parse Tree Required? |
|--------|----------|---------------------|
| **TreeRNN** (Socher, 2011) | Single composition function | Yes |
| **Tree-LSTM** (Tai, 2015) | Gated composition | Yes |
| **SPINN** (Bowman, 2016) | Stack-based shift-reduce parser + composition | Uses parser |
| **Ordered Neurons** (Shen, 2019) | LSTM with implicit tree structure via ordered gates | No |
| **Tree Transformer** (Wang, 2019) | Self-attention constrained by an induced constituent prior | No (induced) |
| **Gumbel Tree-LSTM** (Choi, 2018) | Learns which adjacent nodes to merge via straight-through Gumbel-Softmax (no given tree) | No (learned) |
| **R2D2** (Hu, 2021) | Recursive Transformer with CKY-style dynamic programming | No (learned) |

### 7.2 A Simple Recursive Transformer Block

Let's implement a composition function that uses multi-head attention instead of a simple linear layer.

In [None]:
class RecursiveTransformerComposer(nn.Module):
    """Compose two child vectors using a Transformer-style attention block.
    
    Instead of h_parent = tanh(W [h_l; h_r] + b), we treat the two
    child states as a length-2 sequence and apply self-attention + FFN.
    """
    def __init__(self, d_model, nhead=4, dim_feedforward=128, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, h_left, h_right):
        # Stack children as a length-2 sequence: [batch=1, seq=2, d_model]
        x = torch.stack([h_left, h_right], dim=0).unsqueeze(0)
        
        # Self-attention over the two children
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        
        # Feed-forward
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        # Pool: mean of the two attended representations
        return x.squeeze(0).mean(dim=0)


class RecursiveTransformer(nn.Module):
    """Full Recursive Transformer model for tree-structured classification."""
    def __init__(self, vocab_size, d_model, num_classes, nhead=4, vocab=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.composer = RecursiveTransformerComposer(d_model, nhead=nhead)
        self.classifier = nn.Linear(d_model, num_classes)
        self.vocab = vocab or {}
        self.unk_idx = 0
    
    def forward_node(self, tree):
        if tree.is_leaf:
            idx = self.vocab.get(tree.word.lower(), self.unk_idx)
            idx_tensor = torch.tensor([idx], device=device)
            tree.state = self.embedding(idx_tensor).squeeze(0)
        else:
            self.forward_node(tree.left)
            self.forward_node(tree.right)
            tree.state = self.composer(tree.left.state, tree.right.state)
        return tree.state
    
    def forward(self, tree):
        all_nodes = []
        self._collect_nodes(tree, all_nodes)
        self.forward_node(tree)
        states = torch.stack([n.state for n in all_nodes])
        logits = self.classifier(states)
        labels = torch.tensor([n.label for n in all_nodes], device=device)
        return logits, labels
    
    def _collect_nodes(self, tree, nodes):
        if not tree.is_leaf:
            self._collect_nodes(tree.left, nodes)
            self._collect_nodes(tree.right, nodes)
        nodes.append(tree)


# Quick test
rt_model = RecursiveTransformer(
    vocab_size=len(VOCAB) + 1,
    d_model=64,
    num_classes=5,
    nhead=4,
    vocab=VOCAB
).to(device)

test_tree = build_example_tree()
with torch.no_grad():
    logits, labels = rt_model(test_tree)
print(f"Recursive Transformer output shape: {logits.shape}")
print("Model parameters:", sum(p.numel() for p in rt_model.parameters()))

---
## 8. Visualizing Learned Representations <a id='8-visualizing-learned-representations'></a>

One of the most compelling aspects of recursive models is that every node in the tree gets a vector representation — we can visualize how sentiment is composed through the tree.

In [None]:
def visualize_tree_sentiment(model, tree, ax=None):
    """Visualize predicted sentiment at each node of a tree."""
    model.eval()
    with torch.no_grad():
        logits, labels = model(tree)
    
    all_nodes = []
    model._collect_nodes(tree, all_nodes)
    
    predictions = logits.argmax(dim=1).cpu().numpy()
    true_labels = labels.cpu().numpy()
    probs = F.softmax(logits, dim=1).cpu().numpy()
    
    # Build a text representation showing predictions
    sentiment_map = {0: 'V-Neg', 1: 'Neg', 2: 'Neutral', 3: 'Pos', 4: 'V-Pos'}
    color_map = {0: '#d73027', 1: '#fc8d59', 2: '#ffffbf', 3: '#91bfdb', 4: '#4575b4'}
    
    # Create a simple visualization
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 3))
    
    for i, node in enumerate(all_nodes):
        name = node.word if node.is_leaf else 'Node'
        pred = predictions[i]
        true = true_labels[i]
        conf = probs[i][pred]
        
        color = color_map[pred]
        ax.barh(i, conf, color=color, edgecolor='gray', height=0.7)
        ax.text(conf + 0.02, i,
                f'{name} | pred: {sentiment_map[pred]} ({conf:.0%}) | true: {sentiment_map[true]}',
                va='center', fontsize=9)
    
    ax.set_xlim(0, 1.5)
    ax.set_xlabel('Confidence')
    ax.set_title('Per-Node Sentiment Predictions (post-order)')
    ax.set_yticks(range(len(all_nodes)))
    ax.set_yticklabels([n.word if n.is_leaf else '(compose)' for n in all_nodes])
    plt.tight_layout()
    return ax


# Visualize a few validation trees
fig, axes = plt.subplots(3, 1, figsize=(12, 10))
for i, ax in enumerate(axes):
    visualize_tree_sentiment(model, val_trees[i], ax=ax)
plt.tight_layout()
plt.show()

In [None]:
# --- Visualize embeddings with t-SNE ---

from sklearn.manifold import TSNE

def collect_root_states(model, trees, max_trees=500):
    """Collect root hidden states and labels from a set of trees."""
    states = []
    labels = []
    model.eval()
    with torch.no_grad():
        for tree in trees[:max_trees]:
            model.forward_node(tree)
            all_nodes = []
            model._collect_nodes(tree, all_nodes)
            # Root is the last node in post-order
            states.append(all_nodes[-1].state.cpu().numpy())
            labels.append(all_nodes[-1].label)
    return np.array(states), np.array(labels)


states, labels = collect_root_states(model, val_trees)
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
coords = tsne.fit_transform(states)

plt.figure(figsize=(8, 6))
sentiment_names = ['Very Negative', 'Negative', 'Neutral', 'Positive', 'Very Positive']
colors = ['#d73027', '#fc8d59', '#ffffbf', '#91bfdb', '#4575b4']

for label_idx in range(5):
    mask = labels == label_idx
    if mask.sum() > 0:
        plt.scatter(coords[mask, 0], coords[mask, 1],
                   c=colors[label_idx], label=sentiment_names[label_idx],
                   alpha=0.7, edgecolors='gray', s=40)

plt.legend()
plt.title('t-SNE of Root Node Representations (TreeRNN)')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.tight_layout()
plt.show()

---
## 9. Exercises <a id='9-exercises'></a>

Now it's your turn! Try these exercises to deepen your understanding:

### Exercise 1: Composition Function Variants
Modify the `RecursiveNN` composition function to use:
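- (a) A **bilinear** composition: $h_p = \tanh(h_l^\top W^{[1:d]} h_r + b)$, where $W^{[1:d]} \in \mathbb{R}^{d \times d \times d}$ is a third-order tensor with one $d \times d$ slice per output dimension (hint: use `nn.Bilinear`)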
- (b) An **MLP** with a hidden layer: $h_p = \text{MLP}([h_l; h_r])$

Compare their performance against the baseline.
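
As a starting point for part (a), the snippet below only checks the shapes involved when `nn.Bilinear` is used as a composition function. The dimension `d = 8` and the random inputs are placeholders; wiring this into a full model is left as the exercise.

```python
import torch
import torch.nn as nn

d = 8
# One d x d slice per output dimension: weight has shape (d, d, d), bias (d,)
bilinear = nn.Bilinear(d, d, d)

h_left = torch.randn(d)
h_right = torch.randn(d)

# Add a batch dimension of 1, compose, then drop it again.
h_parent = torch.tanh(
    bilinear(h_left.unsqueeze(0), h_right.unsqueeze(0))
).squeeze(0)

print(h_parent.shape, tuple(bilinear.weight.shape))
```

Note that the weight tensor grows as $d^3$, so a bilinear composer at `embed_dim=64` has far more parameters than the baseline's $d \times 2d$ matrix; keep that in mind when comparing results.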

### Exercise 2: Sentiment Negation
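Create test trees where negation flips sentiment (e.g., "not great" should be negative). Test how well each model handles negation. Which architecture handles it best? Keep in mind that `make_random_tree` labels each parent by averaging its children's labels plus noise, so the synthetic training data contains no negation signal; to see a real difference you will need to extend the generator or train on real SST.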

### Exercise 3: Tree Depth Analysis  
Measure how accuracy varies with tree depth. Do deeper trees pose more difficulty? Plot accuracy vs. tree depth for TreeRNN and Tree-LSTM.
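
The bookkeeping for this exercise might look like the sketch below. It uses nested tuples instead of the notebook's `Tree` class so that it runs on its own, and the correctness flags in `records` are made up purely to show the mechanics; `depth` and `accuracy_by_depth` are hypothetical helper names.

```python
from collections import defaultdict


def depth(node):
    """Depth of a nested-tuple tree; a single leaf has depth 0."""
    if isinstance(node, str):
        return 0
    left, right = node
    return 1 + max(depth(left), depth(right))


def accuracy_by_depth(records):
    """records: iterable of (tree, is_correct) pairs -> {depth: accuracy}."""
    hits, totals = defaultdict(int), defaultdict(int)
    for tree, is_correct in records:
        d = depth(tree)
        hits[d] += int(is_correct)
        totals[d] += 1
    return {d: hits[d] / totals[d] for d in sorted(totals)}


# Made-up correctness flags, only to show the bookkeeping.
records = [
    ("great", True),
    (("not", "great"), True),
    (("not", "bad"), False),
    (("the", ("movie", "was")), True),
]
print(accuracy_by_depth(records))
```

To adapt this to the notebook's trees, replace the `isinstance` check with `node.is_leaf` and read children from `node.left` and `node.right`; the `is_correct` flag would come from comparing the root prediction with the root label.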

### Exercise 4: Real SST Data
Load the actual Stanford Sentiment Treebank using HuggingFace datasets and train the models on it. You'll need to convert the SST format to our `Tree` class.

In [None]:
# --- Exercise 1 Starter Code ---

class RecursiveNN_MLP(nn.Module):
    """TODO: Implement a Recursive NN with a 2-layer MLP composition function.
    
    h_parent = MLP([h_left; h_right])
    where MLP has one hidden layer with ReLU activation.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, vocab=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # YOUR CODE HERE
        # self.composer = nn.Sequential(
        #     nn.Linear(2 * embed_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, embed_dim),
        # )
        
        self.classifier = nn.Linear(embed_dim, num_classes)
        self.vocab = vocab or {}
        self.unk_idx = 0
    
    def forward_node(self, tree):
        # YOUR CODE HERE
        pass
    
    def forward(self, tree):
        # YOUR CODE HERE
        pass


print("Exercise 1 starter code loaded. Uncomment and complete the TODOs!")

In [None]:
# --- Exercise 2 Starter Code ---

def create_negation_tests():
    """Create test cases where negation should flip sentiment."""
    tests = []
    
    # "not great" should be negative
    not_node = Tree(word="not", label=2)
    great_node = Tree(word="great", label=4)
    not_great = Tree(left=not_node, right=great_node, label=1)
    tests.append(('not great', not_great))
    
    # "not terrible" should be positive-ish
    terrible_node = Tree(word="terrible", label=0)
    not_terrible = Tree(left=Tree(word="not", label=2), right=terrible_node, label=3)
    tests.append(('not terrible', not_terrible))
    
    # TODO: Add more negation test cases
    
    return tests


negation_tests = create_negation_tests()
sentiment_map = {0: 'V-Neg', 1: 'Neg', 2: 'Neutral', 3: 'Pos', 4: 'V-Pos'}
model.eval()
print("Negation tests:")
for name, test_tree in negation_tests:
    with torch.no_grad():
        logits, labels = model(test_tree)
    pred = logits[-1].argmax().item()  # root is the last node in post-order
    true = labels[-1].item()
    print(f"  '{name}' -> predicted: {sentiment_map[pred]}, expected: {sentiment_map[true]}")

In [None]:
# --- Exercise 4 Starter Code: Loading Real SST ---

# Uncomment and run to load the Stanford Sentiment Treebank from HuggingFace
# Note: This downloads the dataset (~30MB)

# from datasets import load_dataset
# 
# sst = load_dataset("sst", "default")
# print(sst)
# print("\nExample:")
# print(sst['train'][0])
# 
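# # TODO: Convert the SST tree encoding to Tree objects
# # The original PTB-style SST release stores trees in bracket notation like:
# # (3 (2 (2 The) (2 Rock)) (4 (3 (2 is) ...) ...))
# # Each number is the sentiment label (0-4). The HuggingFace copy may encode
# # trees differently, so inspect the printed example before writing the parser.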
# 
# def parse_sst_tree(s):
#     """Parse SST bracket notation into our Tree class."""
#     # YOUR CODE HERE
#     pass

print("Exercise 4 starter code loaded. Uncomment to try with real SST data!")

---
## 10. References <a id='10-references'></a>

1. **Socher, R., et al.** (2011). *Semi-supervised Recursive Autoencoders for Predicting Sentiment Distributions.* EMNLP.

2. **Socher, R., et al.** (2013). *Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank.* EMNLP.

3. **Tai, K.S., Socher, R., & Manning, C.D.** (2015). *Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks.* ACL.

4. **Bowman, S.R., et al.** (2016). *A Fast Unified Model for Parsing and Sentence Understanding.* ACL.

5. **Shen, Y., et al.** (2019). *Ordered Neurons: Integrating Tree Structures into Recurrent Neural Networks.* ICLR.

6. **Wang, Y., et al.** (2019). *Tree Transformer: Integrating Tree Structures into Self-Attention.* EMNLP.

7. **Choi, J., Yoo, K.M., & Lee, S.** (2018). *Learning to Compose Task-Specific Tree Structures.* AAAI.

8. **Hu, X., et al.** (2021). *R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling.* ACL.

9. **Drozdov, A., et al.** (2019). *Unsupervised Latent Tree Induction with Deep Inside-Outside Recursive Autoencoders.* NAACL.

---

**Congratulations!** You've implemented three types of recursive language models (TreeRNN, Tree-LSTM, Recursive Transformer), trained the first two on a synthetic sentiment task, and visualized the TreeRNN's learned representations. These models represent a fundamental approach to incorporating linguistic structure into neural networks for NLP.