### Vectorized implementation - Hierarchical Probabilistic Neural Network Language Model, Morin & Y Bengio (2005)

In [15]:
# =========================
# Hierarchical Probabilistic Neural 
# Network Language Model, 
# Morin & Bengio (2005) 
# =========================

# 1. Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# 2. Config
n = 4        # context size = n-1 previous words + 1 target
m = 10       # embedding dimension of each word [m x 1]
h = 16       # hidden layer dimension [h x 1]
d_node = 16  # dimension of each node in the tree [d_node x 1]

epochs = 25
lr = 0.01


#### Toy dataset

In [16]:
# -----------------------------
# 3. Toy Corpus (~20 sentences)
# -----------------------------
corpus = [
    "the cat sat down",
    "the cat ate food",
    "the dog sat down",
    "the dog ate food",
    "a cat chased a mouse",
    "the dog chased the cat",
    "a dog barked loudly",
    "the cat meowed softly",
    "the bird sang sweetly",
    "a bird flew away",
    "the fish swam fast",
    "a fish jumped high",
    "the boy played ball",
    "the girl sang song",
    "a boy read book",
    "a girl wrote letter",
    "the sun shines bright",
    "the moon glows softly",
    "the stars twinkle bright",
    "a cat slept quietly"
]

words = sorted(set(" ".join(corpus).split()))

print("Vocabulary:", words)
print("Vocabulary Size:", len(words))

# 4. Preprocessing
tokens = set(" ".join(corpus).split())
word2idx = {word: i for i, word in enumerate(sorted(tokens))}
idx2word = {i: word for word, i in word2idx.items()}
V = len(word2idx)   # vocabulary size |V|

# make context-target pairs for n-gram model
def make_ngrams(corpus, n):
    X, y = [], []
    for sentence in corpus:
        words = sentence.split()
        for i in range(len(words) - n):
            context = words[i:i+n]
            target = words[i+n]
            X.append([word2idx[w] for w in context])
            y.append(word2idx[target])
    return torch.tensor(X), torch.tensor(y)

X, y = make_ngrams(corpus, n-1)


print("Vocabulary size:", V)
print("Number of training samples:", len(X))
print('Example context:', X[0].tolist(), '-> target:', y[0].item())
print('Example context words:', [idx2word[i] for i in X[0].tolist()])
print('Example target word:', idx2word[y[0].item()])


Vocabulary: ['a', 'ate', 'away', 'ball', 'barked', 'bird', 'book', 'boy', 'bright', 'cat', 'chased', 'dog', 'down', 'fast', 'fish', 'flew', 'food', 'girl', 'glows', 'high', 'jumped', 'letter', 'loudly', 'meowed', 'moon', 'mouse', 'played', 'quietly', 'read', 'sang', 'sat', 'shines', 'slept', 'softly', 'song', 'stars', 'sun', 'swam', 'sweetly', 'the', 'twinkle', 'wrote']
Vocabulary Size: 42
Vocabulary size: 42
Number of training samples: 22
Example context: [39, 9, 30] -> target: 12
Example context words: ['the', 'cat', 'sat']
Example target word: down


#### Data loader

In [17]:
# 5. Dataset/Dataloader
class NGramDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_loader = DataLoader(NGramDataset(X, y), batch_size=4, shuffle=True)


#### Forming the tree

In [18]:
def build_strict_balanced(words, d):
    counter = [0]  # Mutable counter for naming internal nodes
    
    def build_recursive(words):
        if not words:
            return None
        if len(words) == 1:
            return {"name": words[0]}  # Leaf node - no parameters needed
            
        # Internal node with parameters
        node_name = f"n{counter[0]}"
        counter[0] += 1
        
        mid = len(words)//2
        node = {
            "name": node_name,
            "left": build_recursive(words[:mid]),
            "right": build_recursive(words[mid:])
        }
        return node
    
    return build_recursive(words)

# Create balanced tree with parameters
balanced_tree = build_strict_balanced(words, d_node)

In [19]:
def analyze_tree(tree, prefix="", paths=None, depths=None, freqs=None):
    """
    Analyze tree paths, depths and frequencies recursively
    If freqs is None, uses weight=1 for each word
    """
    # Initialize dictionaries
    paths = {} if paths is None else paths
    depths = {} if depths is None else depths
    freqs = {} if freqs is None else freqs
    
    # If leaf node, record path, depth and frequency
    if not tree.get('left') and not tree.get('right'):
        word = tree['name']
        freq = tree.get('freq', 1)  # Default frequency = 1 if not specified
        paths[word] = prefix
        depths[word] = len(prefix)
        freqs[word] = freq
        return paths, depths, freqs
    
    # Recurse on children
    if tree.get('left'):
        analyze_tree(tree['left'], prefix + "0", paths, depths, freqs)
    if tree.get('right'):
        analyze_tree(tree['right'], prefix + "1", paths, depths, freqs)
    
    return paths, depths, freqs

In [20]:
def preprocess_tree(tree, words):
    """
    Converts tree into tensors for vectorized HierarchicalSoftmaxLM.
    Returns: path_nodes, path_bits, num_nodes, unk_node_id, word2idx
    """
    node2id = {}
    counter = [0]

    # Assign IDs only to internal nodes
    def assign_ids(node):
        if not node.get("left") and not node.get("right"):
            return
        node2id[node["name"]] = counter[0]
        counter[0] += 1
        if node.get("left"):  assign_ids(node["left"])
        if node.get("right"): assign_ids(node["right"])

    assign_ids(tree)
    num_nodes = counter[0]

    # Special UNK node for padding
    unk_node_id = num_nodes
    num_nodes += 1  # include UNK

    # Get bit paths as strings
    paths, depths, freqs = analyze_tree(tree)

    # Build word→idx
    word2idx = {w: i for i, w in enumerate(words)}

    # Convert paths to node/bit sequences
    max_len = max(len(p) for p in paths.values())
    path_nodes = torch.full((len(words), max_len), fill_value=unk_node_id, dtype=torch.long)
    path_bits  = torch.full((len(words), max_len), fill_value=-1, dtype=torch.long)

    def traverse_for_nodes(node, prefix=""):
        if not node.get("left") and not node.get("right"):
            word = node["name"]
            wid = word2idx[word]
            for i, bit in enumerate(prefix):
                # find the node ID for this decision
                # step through prefix until this depth
                curr = tree
                for b in prefix[:i]:
                    curr = curr["left"] if b == "0" else curr["right"]
                nid = node2id[curr["name"]]
                path_nodes[wid, i] = nid
                path_bits[wid, i] = int(bit)
            return
        if node.get("left"):  traverse_for_nodes(node["left"], prefix + "0")
        if node.get("right"): traverse_for_nodes(node["right"], prefix + "1")

    traverse_for_nodes(tree)

    return path_nodes, path_bits, num_nodes, unk_node_id, word2idx


#### Model architecture

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HierarchicalSoftmaxLM(nn.Module):
    def __init__(self, vocab_size, n, m, h, d_node,
                 word2idx, path_nodes, path_bits, num_nodes, unk_node_id):
        super().__init__()
        self.vocab_size = vocab_size
        self.n = n
        self.m = m
        self.h = h
        self.d_node = d_node
        self.word2idx = word2idx
        self.path_nodes = path_nodes      # [vocab_size, L]
        self.path_bits = path_bits        # [vocab_size, L]
        self.num_nodes = num_nodes
        self.unk_node_id = unk_node_id
        self.max_path_len = path_nodes.size(1)

        # word embeddings
        self.embeddings = nn.Embedding(vocab_size, m)

        # shared transforms
        self.W = nn.Linear((n-1)*m, h, bias=False)   # context → hidden
        self.U = nn.Linear(d_node, h, bias=False)    # node embedding → hidden
        self.c = nn.Parameter(torch.zeros(h))
        self.beta = nn.Parameter(torch.randn(h))

        # node parameters (vectorized)
        self.N = nn.Embedding(num_nodes, d_node)   # node embeddings
        self.alpha = nn.Embedding(num_nodes, 1)    # node biases

    def forward(self, context_idxs, target_idxs):
        B = context_idxs.size(0)

        # ===== Context representation x =====
        ctx_emb = self.embeddings(context_idxs)      # [B, n-1, m]
        x = ctx_emb.view(B, -1)                      # [B, (n-1)*m]

        # ===== Path lookup =====
        path_nodes_batch = self.path_nodes[target_idxs]   # [B, L]
        path_bits_batch  = self.path_bits[target_idxs]    # [B, L]

        # Node embeddings & biases for this batch
        N_paths    = self.N(path_nodes_batch)        # [B, L, d_node]
        alpha_path = self.alpha(path_nodes_batch)    # [B, L, 1]

        # ===== Hidden computation =====
        h_x = self.W(x)                              # [B, h]
        h_N = self.U(N_paths)                        # [B, L, h]
        h = torch.tanh(h_x.unsqueeze(1) + h_N + self.c)   # [B, L, h]

        # ===== Probabilities =====
        p = torch.sigmoid(alpha_path.squeeze(-1) + (h @ self.beta))  # [B, L]

        # ===== Loss =====
        mask = (path_bits_batch != -1)    # ignore padding
        target = path_bits_batch.float()
        loss = F.binary_cross_entropy(p, target, reduction="none")

        # [B, L]
        masked_loss = loss * mask

        # average over valid path positions for each sample
        loss_per_sample = masked_loss.sum(dim=1) / mask.sum(dim=1)   # [B]

        # mean over batch
        loss = loss_per_sample.mean()


        return loss
    
    def forward_naive(self, context_idxs, target_idxs):
        """
        Naive forward: loop over batch and over path positions.
        Equivalent to vectorized forward(), useful for debugging.
        """
        B = context_idxs.size(0)
        
        # ===== Context representation x =====
        ctx_emb = self.embeddings(context_idxs)      # [B, n-1, m]
        x = ctx_emb.view(B, -1)                      # [B, (n-1)*m]

        losses = []
        for i in range(B):
            # one sample in the batch
            target_idx = target_idxs[i]
            path_nodes_i = self.path_nodes[target_idx]   # [L]
            path_bits_i  = self.path_bits[target_idx]    # [L]

            h_x = self.W(x[i].unsqueeze(0))  # [1, h]
            loss_i = 0.0
            count = 0

            for j, node_id in enumerate(path_nodes_i):
                if path_bits_i[j] == -1:  # padding
                    continue

                # node parameters
                N_j = self.N(node_id)           # [d_node]
                alpha_j = self.alpha(node_id)   # [1]

                # hidden
                h_N = self.U(N_j.unsqueeze(0))  # [1, h]
                h = torch.tanh(h_x + h_N + self.c)  # [1, h]

                # prob
                p = torch.sigmoid(alpha_j + (h @ self.beta))  # [1]
                
                # target bit
                b_val = path_bits_i[j].float()
                loss_j = -(b_val * torch.log(p + 1e-9) +
                           (1 - b_val) * torch.log(1 - p + 1e-9))
                loss_i += loss_j
                count += 1

            if count > 0:
                loss_i = loss_i / count
            losses.append(loss_i)

        return torch.stack(losses).mean()
    
    @torch.no_grad()
    def predict_next(self, context_idxs):
        """
        Vectorized next-word probability computation using hierarchical softmax.
        Returns [B, V] probabilities.
        """
        B = context_idxs.size(0)

        # ===== Context representation =====
        ctx_emb = self.embeddings(context_idxs)  # [B, n-1, m]
        x = ctx_emb.view(B, -1)                 # [B, (n-1)*m]
        h_x = self.W(x)                         # [B, h]

        # Expand across all vocab words
        path_nodes_batch = self.path_nodes.unsqueeze(0).expand(B, -1, -1)  # [B, V, L]
        path_bits_batch  = self.path_bits.unsqueeze(0).expand(B, -1, -1)   # [B, V, L]

        # Node params
        N_paths    = self.N(path_nodes_batch)     # [B, V, L, d_node]
        alpha_path = self.alpha(path_nodes_batch) # [B, V, L, 1]

        # Hidden computation
        h_N = self.U(N_paths)  # [B, V, L, h]
        h   = torch.tanh(h_x.unsqueeze(1).unsqueeze(2) + h_N + self.c)  # [B, V, L, h]

        # Node probabilities
        p = torch.sigmoid(alpha_path.squeeze(-1) + (h @ self.beta))  # [B, V, L]

        # Mask padding
        mask = (path_bits_batch != -1).float()  # [B, V, L]
        bits = path_bits_batch.float()

        # Log-prob along path
        logp = bits * torch.log(p + 1e-9) + (1 - bits) * torch.log(1 - p + 1e-9)
        logp = (logp * mask).sum(dim=-1)  # sum over path length → [B, V]

        probs = torch.exp(logp)  # true hierarchical softmax probabilities
        return probs  # [B, V]

    @torch.no_grad()
    def predict_next_naive(self, context_idxs):
        """
        Naive loop-based next-word probability computation.
        """
        B = context_idxs.size(0)

        # Context representation
        ctx_emb = self.embeddings(context_idxs)   # [B, n-1, m]
        x = ctx_emb.view(B, -1)                   # [B, (n-1)*m]

        probs = []
        for i in range(B):
            h_x = self.W(x[i].unsqueeze(0))  # [1, h]
            word_probs = []

            for target_idx in range(self.vocab_size):
                path_nodes_i = self.path_nodes[target_idx]  # [L]
                path_bits_i  = self.path_bits[target_idx]   # [L]

                logp = 0.0
                for j, node_id in enumerate(path_nodes_i):
                    if path_bits_i[j] == -1:  # padding
                        continue

                    N_j = self.N(node_id)
                    alpha_j = self.alpha(node_id)

                    h_N = self.U(N_j.unsqueeze(0))  # [1, h]
                    h = torch.tanh(h_x + h_N + self.c)  # [1, h]

                    p = torch.sigmoid(alpha_j + (h @ self.beta))
                    b_val = float(path_bits_i[j])
                    logp += b_val * torch.log(p + 1e-9) + (1 - b_val) * torch.log(1 - p + 1e-9)

                word_probs.append(logp)

            logps = torch.stack(word_probs).squeeze()  # [V]
            probs.append(torch.exp(logps))            # convert log → prob

        probs = torch.stack(probs, dim=0)  # [B, V]
        return probs    



#### Model definition

In [22]:
balanced_tree = build_strict_balanced(words, d_node)
path_nodes, path_bits, num_nodes, unk_node_id, word2idx = preprocess_tree(balanced_tree, words)

model = HierarchicalSoftmaxLM(
    vocab_size=len(words),
    n=n, m=m, h=h, d_node=d_node,
    word2idx=word2idx,
    path_nodes=path_nodes,
    path_bits=path_bits,
    num_nodes=num_nodes,
    unk_node_id=unk_node_id
)

# Print model architecture and parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
print("Total parameters:", count_parameters(model))


def parameter_report(model):
    report = []
    total = 0
    for name, p in model.named_parameters():
        if p.requires_grad:
            num = p.numel()
            total += num
            report.append(f"{name:30} {str(list(p.shape)):20} {num}")
    report.append(f"{'Total':30} {'':20} {total}")
    return "\n".join(report)

print(parameter_report(model))



Total parameters: 1902
c                              [16]                 16
beta                           [16]                 16
embeddings.weight              [42, 10]             420
W.weight                       [16, 30]             480
U.weight                       [16, 16]             256
N.weight                       [42, 16]             672
alpha.weight                   [42, 1]              42
Total                                               1902


#### Comparing the naive and vectorized implementations

In [23]:
losses_naive = 0
losses_vec = 0
for context_idxs, target_idxs in train_loader:
    with torch.no_grad():
        # Naive loss
        loss_naive = model.forward_naive(context_idxs, target_idxs)
        # Vectorized loss
        loss_vec = model.forward(context_idxs, target_idxs)

        losses_naive += loss_naive.item()
        losses_vec += loss_vec.item()
        print(f"Naive loss: {loss_naive.item():.4f}, Vectorized loss: {loss_vec.item():.4f}")

print("Naive loss (Sum):      ", losses_naive)
print("Vectorized loss (Sum):", losses_vec)
print("Difference:     ", abs(losses_naive - losses_vec))

Naive loss: 0.8713, Vectorized loss: 0.8713
Naive loss: 0.9398, Vectorized loss: 0.9398
Naive loss: 1.0094, Vectorized loss: 1.0094
Naive loss: 1.5997, Vectorized loss: 1.5997
Naive loss: 1.3727, Vectorized loss: 1.3727
Naive loss: 1.5239, Vectorized loss: 1.5239
Naive loss (Sum):       7.316938698291779
Vectorized loss (Sum): 7.316938281059265
Difference:      4.172325134277344e-07


#### Training loop

In [24]:
# =============================
# Training
# =============================

device = torch.device("cpu")  # force CPU for small data
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    total_loss = 0
    for context, target in train_loader:
        context, target = context.to(device), target.to(device)

        optimizer.zero_grad()
        loss = model(context, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")


Epoch 1: Loss = 1.1841
Epoch 2: Loss = 0.8585
Epoch 3: Loss = 0.7073
Epoch 4: Loss = 0.6045
Epoch 5: Loss = 0.5201
Epoch 6: Loss = 0.4887
Epoch 7: Loss = 0.4365
Epoch 8: Loss = 0.3897
Epoch 9: Loss = 0.3371
Epoch 10: Loss = 0.2944
Epoch 11: Loss = 0.2619
Epoch 12: Loss = 0.2294
Epoch 13: Loss = 0.2022
Epoch 14: Loss = 0.1779
Epoch 15: Loss = 0.1547
Epoch 16: Loss = 0.1382
Epoch 17: Loss = 0.1211
Epoch 18: Loss = 0.1058
Epoch 19: Loss = 0.0965
Epoch 20: Loss = 0.0849
Epoch 21: Loss = 0.0760
Epoch 22: Loss = 0.0683
Epoch 23: Loss = 0.0625
Epoch 24: Loss = 0.0570
Epoch 25: Loss = 0.0507


#### Inference

In [25]:
def predict_next_words_batch(model, context_words_batch, word2idx, idx2word, k=5, method="vectorized"):
    """
    Predict next word probabilities for a batch of contexts using hierarchical softmax.
    
    Args:
        model: HierarchicalSoftmaxLM instance
        context_words_batch: List of contexts, where each context is a list of n-1 words
        word2idx: dict mapping word -> index
        idx2word: dict mapping index -> word
        k: number of top predictions per context
        method: "vectorized" or "naive"
    
    Returns:
        List of lists, where each inner list contains (word, probability) tuples for one context
    """
    # Convert all contexts to indices
    batch_size = len(context_words_batch)
    context_idxs = torch.tensor([
        [word2idx.get(w, model.unk_node_id) for w in context]
        for context in context_words_batch
    ], dtype=torch.long)  # [B, n-1]

    # Get probabilities for all contexts at once
    if method == "vectorized":
        probs = model.predict_next(context_idxs)  # [B, V]
    elif method == "naive":
        probs = model.predict_next_naive(context_idxs)  # [B, V]
    else:
        raise ValueError("method must be 'vectorized' or 'naive'")

    # Get top-k predictions for each context
    topk_probs, topk_idx = torch.topk(probs, k, dim=1)  # [B, k]
    
    # Convert to list of word-probability pairs for each context
    results = []
    for i in range(batch_size):
        context_results = [
            (idx2word[idx.item()], prob.item())
            for idx, prob in zip(topk_idx[i], topk_probs[i])
        ]
        results.append(context_results)

        # Print predictions for this context
        print(f"\nContext: {context_words_batch[i]}")
        print(f"{method.capitalize()} top-{k}:")
        for j, (word, prob) in enumerate(context_results):
            print(f"  {j}. '{word}' ({prob:.4f})")
        print("-" * 50)        

    return results


In [26]:

# Example usage:
test_contexts = [
    ["the", "cat", "sat"],
    ["a", "cat", "chased"],
    ["a", "dog", "barked"],
    ["the", "stars", "glows"]
]


In [27]:

# Get predictions - naive
predictions_naive = predict_next_words_batch(model, test_contexts, word2idx, idx2word, k=5, method="naive")



Context: ['the', 'cat', 'sat']
Naive top-5:
  0. 'down' (0.7875)
  1. 'away' (0.0666)
  2. 'bird' (0.0518)
  3. 'a' (0.0245)
  4. 'food' (0.0244)
--------------------------------------------------

Context: ['a', 'cat', 'chased']
Naive top-5:
  0. 'a' (0.8077)
  1. 'away' (0.0612)
  2. 'down' (0.0409)
  3. 'letter' (0.0313)
  4. 'bird' (0.0180)
--------------------------------------------------

Context: ['a', 'dog', 'barked']
Naive top-5:
  0. 'loudly' (0.7393)
  1. 'food' (0.0655)
  2. 'letter' (0.0509)
  3. 'mouse' (0.0385)
  4. 'quietly' (0.0381)
--------------------------------------------------

Context: ['the', 'stars', 'glows']
Naive top-5:
  0. 'food' (0.3163)
  1. 'bright' (0.1905)
  2. 'book' (0.1835)
  3. 'high' (0.0573)
  4. 'cat' (0.0474)
--------------------------------------------------


In [28]:

# Get predictions - vectorized
predictions_vectorized = predict_next_words_batch(model, test_contexts, word2idx, idx2word, k=5, method="vectorized")



Context: ['the', 'cat', 'sat']
Vectorized top-5:
  0. 'down' (0.7875)
  1. 'away' (0.0666)
  2. 'bird' (0.0518)
  3. 'a' (0.0245)
  4. 'food' (0.0244)
--------------------------------------------------

Context: ['a', 'cat', 'chased']
Vectorized top-5:
  0. 'a' (0.8077)
  1. 'away' (0.0612)
  2. 'down' (0.0409)
  3. 'letter' (0.0313)
  4. 'bird' (0.0180)
--------------------------------------------------

Context: ['a', 'dog', 'barked']
Vectorized top-5:
  0. 'loudly' (0.7393)
  1. 'food' (0.0655)
  2. 'letter' (0.0509)
  3. 'mouse' (0.0385)
  4. 'quietly' (0.0381)
--------------------------------------------------

Context: ['the', 'stars', 'glows']
Vectorized top-5:
  0. 'food' (0.3163)
  1. 'bright' (0.1905)
  2. 'book' (0.1835)
  3. 'high' (0.0573)
  4. 'cat' (0.0474)
--------------------------------------------------
