### Hierarchical Probabilistic Neural Network Language Model, Morin & Y Bengio (2005)

In [12]:
# =========================
# Hierarchical Probabilistic Neural 
# Network Language Model, 
# Morin & Bengio (2005) 
# =========================

# 1. Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# 2. Config
n = 4        # context size = n-1 previous words + 1 target
m = 10       # embedding dimension of each word [m x 1]
h = 16       # hidden layer dimension [h x 1]
d_node = 16  # dimension of each node in the tree [d_node x 1]

epochs = 25
lr = 0.01


#### Toy dataset

In [13]:
# -----------------------------
# 3. Toy Corpus (~20 sentences)
# -----------------------------
corpus = [
    "the cat sat down",
    "the cat ate food",
    "the dog sat down",
    "the dog ate food",
    "a cat chased a mouse",
    "the dog chased the cat",
    "a dog barked loudly",
    "the cat meowed softly",
    "the bird sang sweetly",
    "a bird flew away",
    "the fish swam fast",
    "a fish jumped high",
    "the boy played ball",
    "the girl sang song",
    "a boy read book",
    "a girl wrote letter",
    "the sun shines bright",
    "the moon glows softly",
    "the stars twinkle bright",
    "a cat slept quietly"
]

words = sorted(set(" ".join(corpus).split()))

print("Vocabulary:", words)
print("Vocabulary Size:", len(words))

# 4. Preprocessing
tokens = set(" ".join(corpus).split())
word2idx = {word: i for i, word in enumerate(sorted(tokens))}
idx2word = {i: word for word, i in word2idx.items()}
V = len(word2idx)   # vocabulary size |V|

# make context-target pairs for n-gram model
def make_ngrams(corpus, n):
    X, y = [], []
    for sentence in corpus:
        words = sentence.split()
        for i in range(len(words) - n):
            context = words[i:i+n]
            target = words[i+n]
            X.append([word2idx[w] for w in context])
            y.append(word2idx[target])
    return torch.tensor(X), torch.tensor(y)

X, y = make_ngrams(corpus, n-1)


print("Vocabulary size:", V)
print("Number of training samples:", len(X))
print('Example context:', X[0].tolist(), '-> target:', y[0].item())
print('Example context words:', [idx2word[i] for i in X[0].tolist()])
print('Example target word:', idx2word[y[0].item()])


Vocabulary: ['a', 'ate', 'away', 'ball', 'barked', 'bird', 'book', 'boy', 'bright', 'cat', 'chased', 'dog', 'down', 'fast', 'fish', 'flew', 'food', 'girl', 'glows', 'high', 'jumped', 'letter', 'loudly', 'meowed', 'moon', 'mouse', 'played', 'quietly', 'read', 'sang', 'sat', 'shines', 'slept', 'softly', 'song', 'stars', 'sun', 'swam', 'sweetly', 'the', 'twinkle', 'wrote']
Vocabulary Size: 42
Vocabulary size: 42
Number of training samples: 22
Example context: [39, 9, 30] -> target: 12
Example context words: ['the', 'cat', 'sat']
Example target word: down


#### Model architecture and data loader

In [14]:
# 5. Dataset/Dataloader
class NGramDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_loader = DataLoader(NGramDataset(X, y), batch_size=4, shuffle=True)


In [15]:
def analyze_tree(tree, prefix="", paths=None, depths=None, freqs=None):
    """
    Analyze tree paths, depths and frequencies recursively
    If freqs is None, uses weight=1 for each word
    """
    # Initialize dictionaries
    paths = {} if paths is None else paths
    depths = {} if depths is None else depths
    freqs = {} if freqs is None else freqs
    
    # If leaf node, record path, depth and frequency
    if not tree.get('left') and not tree.get('right'):
        word = tree['name']
        freq = tree.get('freq', 1)  # Default frequency = 1 if not specified
        paths[word] = prefix
        depths[word] = len(prefix)
        freqs[word] = freq
        return paths, depths, freqs
    
    # Recurse on children
    if tree.get('left'):
        analyze_tree(tree['left'], prefix + "0", paths, depths, freqs)
    if tree.get('right'):
        analyze_tree(tree['right'], prefix + "1", paths, depths, freqs)
    
    return paths, depths, freqs

In [16]:
def build_strict_balanced(words, d):
    counter = [0]  # Mutable counter for naming internal nodes
    
    def build_recursive(words):
        if not words:
            return None
        if len(words) == 1:
            return {"name": words[0]}  # Leaf node - no parameters needed
            
        # Internal node with parameters
        node_name = f"n{counter[0]}"
        counter[0] += 1
        
        mid = len(words)//2
        node = {
            "name": node_name,
            "alpha": torch.zeros(1),  # Bias parameter
            "N": torch.randn(1, d),   # Node embedding [1 x d_node]
            "left": build_recursive(words[:mid]),
            "right": build_recursive(words[mid:])
        }
        return node
    
    return build_recursive(words)

# Create balanced tree with parameters
balanced_tree = build_strict_balanced(words, d_node)

In [17]:
# =============================
# Hierarchical Softmax Model
# =============================
class HierarchicalSoftmaxLM(nn.Module):
    def __init__(self, vocab_size, n, m, h, d_node, word2idx, paths, tree):
        super().__init__()
        self.vocab_size = vocab_size
        self.n = n
        self.m = m
        self.h = h
        self.d_node = d_node
        self.word2idx = word2idx
        self.paths = paths
        self.tree = tree
        
        # word embeddings
        self.embeddings = nn.Embedding(vocab_size, m)

        # shared parameters
        self.W = nn.Linear((n-1)*m, h, bias=False)   # context → hidden
        self.U = nn.Linear(d_node, h, bias=False)    # node embedding → hidden
        self.c = nn.Parameter(torch.zeros(h))        # hidden bias
        self.beta = nn.Parameter(torch.randn(h))     # projection vector β

        # Register tree parameters
        def register_tree_params(node):
            if node is None or not ("left" in node or "right" in node):
                return
            # Convert node parameters to nn.Parameters
            node["alpha"] = nn.Parameter(node["alpha"])
            node["N"] = nn.Parameter(node["N"])
            # Register parameters with PyTorch
            self.register_parameter(f"alpha_{node['name']}", node["alpha"])
            self.register_parameter(f"N_{node['name']}", node["N"])
            register_tree_params(node.get("left"))
            register_tree_params(node.get("right"))
        
        register_tree_params(self.tree)

        return
    

    def forward(self, context_idxs, target_idxs):
        B = context_idxs.size(0)
        
        # Context representation x
        ctx_emb = self.embeddings(context_idxs)   # [B, n-1, m]
        x = ctx_emb.view(B, -1)                   # [B, (n-1)*m]
        
        losses = []
        for i in range(B):
            target_word = idx2word[target_idxs[i].item()]
            path = self.paths[target_word]
            
            loss_i = 0.0
            curr_node = self.tree
            for bit in path:
                # Use parameters directly from tree node
                alpha = curr_node["alpha"]
                N = curr_node["N"]
                
                # Compute hidden state
                h_in = self.W(x[i].unsqueeze(0)) + self.U(N) + self.c
                h_out = torch.tanh(h_in)
                
                # Compute probability
                p = torch.sigmoid(alpha + (h_out @ self.beta))
                
                # Binary cross entropy
                b_val = float(bit == "1")
                loss_i += -(b_val * torch.log(p + 1e-9) + (1-b_val) * torch.log(1-p + 1e-9))
                
                # Move to next node
                curr_node = curr_node["right"] if bit == "1" else curr_node["left"]
            
            losses.append(loss_i)
        
        return torch.mean(torch.stack(losses))    


In [18]:
# =============================
# Model definition
# =============================

device = torch.device("cpu")  # force CPU for small data
paths, depths, freqs = analyze_tree(balanced_tree)
model = HierarchicalSoftmaxLM(V, n, m, h, d_node, word2idx, paths, balanced_tree).to(device)

def count_model_parameters(model):
    """
    Count and display parameter sizes for each component of the model
    """
    def count_tree_params(node):
        """Count parameters in tree nodes recursively"""
        if node is None or not ("left" in node or "right" in node):
            return 0
        # Count parameters in current node (alpha: 1, N: d_node)
        node_params = 1 + node["N"].numel()
        # Add parameters from children
        return node_params + count_tree_params(node.get("left")) + count_tree_params(node.get("right"))

    # Count embedding parameters
    emb_params = model.embeddings.weight.numel()
    
    # Count shared parameters
    w_params = model.W.weight.numel()
    u_params = model.U.weight.numel()
    c_params = model.c.numel()
    beta_params = model.beta.numel()
    
    # Count tree parameters
    tree_params = count_tree_params(model.tree)
    
    # Print parameter counts
    print("\nModel Parameter Counts:")
    print("-" * 40)
    print(f"Word Embeddings       : {emb_params:,d}")
    print(f"Context Matrix W      : {w_params:,d}")
    print(f"Node Matrix U         : {u_params:,d}")
    print(f"Hidden Bias c         : {c_params:,d}")
    print(f"Projection Vector β   : {beta_params:,d}")
    print(f"Tree Node Parameters  : {tree_params:,d}")
    print("-" * 40)
    total_params_manual = emb_params + w_params + u_params + c_params + beta_params + tree_params
    total_params_model = sum(p.numel() for p in model.parameters())

    print(f"Total Parameters (manual) : {total_params_manual:,d}")
    print(f"Total Parameters (model)  : {total_params_model:,d}")


print("\nModel Architecture:")
print("==================")
print(f"Vocabulary size (V): {V}")
print(f"Context size (n-1): {n-1}")
print(f"Embedding dim (m): {m}")
print(f"Hidden dim (h): {h}")

# Use after model creation
count_model_parameters(model)


Model Architecture:
Vocabulary size (V): 42
Context size (n-1): 3
Embedding dim (m): 10
Hidden dim (h): 16

Model Parameter Counts:
----------------------------------------
Word Embeddings       : 420
Context Matrix W      : 480
Node Matrix U         : 256
Hidden Bias c         : 16
Projection Vector β   : 16
Tree Node Parameters  : 697
----------------------------------------
Total Parameters (manual) : 1,885
Total Parameters (model)  : 1,885


In [21]:
# =============================
# Training
# =============================


device = torch.device("cpu")  # force CPU for small data
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    total_loss = 0
    for context, target in train_loader:
        context, target = context.to(device), target.to(device)

        optimizer.zero_grad()
        loss = model(context, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")


Epoch 1: Loss = 0.6902
Epoch 2: Loss = 0.6781
Epoch 3: Loss = 0.4924
Epoch 4: Loss = 0.3964
Epoch 5: Loss = 0.3547
Epoch 6: Loss = 0.2692
Epoch 7: Loss = 0.2434
Epoch 8: Loss = 0.2098
Epoch 9: Loss = 0.1765
Epoch 10: Loss = 0.1515
Epoch 11: Loss = 0.1419
Epoch 12: Loss = 0.1228
Epoch 13: Loss = 0.1106
Epoch 14: Loss = 0.1022
Epoch 15: Loss = 0.0954
Epoch 16: Loss = 0.0868
Epoch 17: Loss = 0.0792
Epoch 18: Loss = 0.0741
Epoch 19: Loss = 0.0744
Epoch 20: Loss = 0.0682
Epoch 21: Loss = 0.0644
Epoch 22: Loss = 0.0583
Epoch 23: Loss = 0.0564
Epoch 24: Loss = 0.0511
Epoch 25: Loss = 0.0516


#### Inference

In [None]:
def predict_next(context_words, k=5):
    """
    Predict next word probabilities using hierarchical softmax tree
    Args:
        context_words: List of context words
        k: Number of top predictions to return
    Returns:
        List of (word, probability) tuples
    """
    context_idxs = torch.tensor([[word2idx[w] for w in context_words]])
    
    with torch.no_grad():
        # Get context embedding
        ctx_emb = model.embeddings(context_idxs)   # [1, n-1, m]
        x = ctx_emb.view(1, -1)                    # [1, (n-1)*m]
        
        def get_node_prob(node, x):
            """Compute probability of taking right path at node"""
            if not node or not ("left" in node or "right" in node):
                return 1.0
            
            # Get node parameters
            alpha = node["alpha"]
            N = node["N"]
            
            # Compute hidden state
            h_in = model.W(x) + model.U(N) + model.c
            h_out = torch.tanh(h_in)
            
            # Return probability of going right
            return torch.sigmoid(alpha + (h_out @ model.beta))
        
        def get_word_prob(word, x):
            """Compute probability of word by multiplying path probabilities"""
            path = model.paths[word]
            curr_node = model.tree
            log_prob = 0.0
            
            for bit in path:
                p = get_node_prob(curr_node, x)
                log_prob += torch.log(p if bit == "1" else (1 - p))
                curr_node = curr_node["right"] if bit == "1" else curr_node["left"]
            prob = torch.exp(log_prob).item()                 
            
            return prob
        
        # Compute probabilities for all words
        word_probs = [(word, get_word_prob(word, x)) for word in words]
        
        # Sort by probability and get top k
        word_probs.sort(key=lambda x: x[1], reverse=True)
        predictions = word_probs[:k]
    
    return predictions

# Test predictions
test_contexts = [
    ["the", "cat", "sat"],
    ["a", "cat", "chased"],
    ["a", "dog", "barked"],
    ["the", "stars", "glows"]
]

for context in test_contexts:
    predictions = predict_next(context)
    print(f"\nContext: {context}")
    for i, (word, prob) in enumerate(predictions, 1):
        print(f"  {i}. '{word}' ({prob:.4f})")

KeyboardInterrupt: 