# Faithful NPLM Implementation

This notebook provides a **complete, production-quality implementation** of the Neural Probabilistic Language Model (NPLM), ported from the original repository. The model code is NOT a simplified re-implementation: it comes from `models/nplm.py`, `models/embeddings.py`, `models/adaptive_softmax.py`, and related files. Only the configuration, synthetic dataset, and training/evaluation loops are new demo scaffolding.

The code is cherry-picked directly from the original implementation to maintain full architectural fidelity.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import uuid
import threading
from collections import defaultdict
from tqdm import tqdm
import sys
import pickle

In [None]:
# Utility Functions (from utils/__init__.py)

def left_shift(x, dim=-1, shift=1, fill=None):
    """Drop the first ``shift`` entries of ``x`` along ``dim``.

    When ``fill`` is given, ``x`` is first right-padded with ``shift`` slices of
    ``fill`` so the result keeps the original length along ``dim``. A falsy
    ``shift`` returns ``x`` unchanged.
    """
    if not shift:
        return x

    if fill is not None:
        x = right_pad(x, dim, shift, fill)

    sizes = list(x.shape)
    axis = dim % len(sizes)
    index = []
    for d, size in enumerate(sizes):
        start = shift if d == axis else 0
        index.append(slice(start, size + shift))
    return x[tuple(index)]


def right_shift(x, dim=-1, shift=1, fill=None):
    """Drop the last ``shift`` entries of ``x`` along ``dim``.

    When ``fill`` is given, ``x`` is first left-padded with ``shift`` slices of
    ``fill`` so the result keeps the original length along ``dim``. A falsy
    ``shift`` returns ``x`` unchanged.
    """
    if not shift:
        return x

    if fill is not None:
        x = left_pad(x, dim, shift, fill)

    sizes = list(x.shape)
    axis = dim % len(sizes)
    stops = [-shift if d == axis else size for d, size in enumerate(sizes)]
    return x[tuple(slice(stop) for stop in stops)]


def left_pad(x, dim=-1, count=1, fill=0):
    """Prepend ``count`` slices filled with ``fill`` to ``x`` along ``dim``.

    A falsy ``count`` returns ``x`` unchanged.
    """
    if not count:
        return x

    axis = dim % x.dim()
    pad_shape = list(x.shape)
    pad_shape[axis] = count
    padding = x.new_full(pad_shape, fill)
    return torch.cat([padding, x], axis)


def right_pad(x, dim=-1, count=1, fill=0):
    """Append ``count`` slices filled with ``fill`` to ``x`` along ``dim``.

    A falsy ``count`` returns ``x`` unchanged.
    """
    if not count:
        return x

    axis = dim % x.dim()
    pad_shape = list(x.shape)
    pad_shape[axis] = count
    padding = x.new_full(pad_shape, fill)
    return torch.cat([x, padding], axis)


def triu(inputs, diagonal=0, span=1, stride=1, offset=0):
    """Zero the leading entries of each row of ``inputs`` in place and return it.

    Row ``r`` has its first ``span * (diagonal + r // stride) + offset`` entries
    set to 0, giving an upper-triangular pattern whose step width (``span``),
    row grouping (``stride``) and start (``diagonal``/``offset``) are tunable.
    """
    for row_idx in range(len(inputs)):
        cutoff = span * (diagonal + row_idx // stride) + offset
        inputs[row_idx][:cutoff] = 0.
    return inputs


In [None]:
# TokenEmbedding (from models/embeddings.py)

class TokenEmbedding(nn.Module):
    """Token embedding with optional projection and adaptive (per-cluster) widths.

    With ``div_val == 1`` a single ``nn.Embedding(vocab, embedding_dim)`` is used,
    optionally followed by a projection to ``proj_dim`` (only when ``do_proj`` is
    set and the widths differ). With ``div_val != 1`` the vocabulary is split at
    ``cutoffs``; cluster ``i`` uses width ``embedding_dim // div_val**i`` and is
    projected to ``proj_dim``. Forward outputs are scaled by ``sqrt(proj_dim)``.
    """
    def __init__(self, num_embeddings, embedding_dim, proj_dim, cutoffs, emb_std=0.01, proj_std=0.02, div_val=1, padding_idx=0, do_proj=False):
        super(TokenEmbedding, self).__init__()

        self.vocab_size = num_embeddings
        self.embed_dim = embedding_dim
        self.proj_dim = proj_dim
        self.cutoffs = [0] + cutoffs + [self.vocab_size]
        self.div_val = div_val

        self.emb_scale = self.proj_dim ** 0.5
        self.emb_std = emb_std
        self.proj_std = proj_std
        self.do_proj = do_proj

        self.padding_idx = padding_idx

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ModuleList()
        if self.div_val == 1:
            self.emb_layers.append(
                nn.Embedding(self.vocab_size, self.embed_dim)
            )
            if self.proj_dim != self.embed_dim and self.do_proj:
                self.emb_projs.append(nn.Linear(self.proj_dim, self.embed_dim))
        else:
            for i in range(len(self.cutoffs) - 1):
                l_idx, r_idx = self.cutoffs[i], self.cutoffs[i+1]
                d_emb_i = self.embed_dim // (self.div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(nn.Linear(self.proj_dim, d_emb_i))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise embeddings/projections and zero the padding row."""
        for l in self.emb_layers:
            if self.emb_std is not None:
                nn.init.normal_(l.weight, mean=0, std=self.emb_std)
            else:
                nn.init.normal_(l.weight, mean=0, std=self.embed_dim ** -0.5)

        for p in self.emb_projs:
            if self.proj_std is not None:
                nn.init.normal_(p.weight, mean=0, std=self.proj_std)
            else:
                nn.init.normal_(p.weight, mean=0, std=self.embed_dim ** -0.5)
            nn.init.constant_(p.bias, 0.)

        # Only the first table holds padding_idx (ids are offset per cluster).
        nn.init.constant_(self.emb_layers[0].weight[self.padding_idx], 0)

    def forward(self, inputs, reverse=False):
        """Embed token ids, or score hidden states against the embedding table.

        Args:
            inputs: LongTensor of token ids; with ``reverse=True``, hidden states
                whose last dim equals ``embedding_dim``.
            reverse: if true, return ``inputs @ emb_layers[0].weight.T``, i.e.
                one logit per row of the first embedding table.

        Returns:
            Embeddings of shape ``inputs.shape + (width,)`` scaled by
            ``sqrt(proj_dim)``, where width is ``proj_dim`` when a projection is
            applied and ``embedding_dim`` otherwise.
        """
        if reverse:
            return F.linear(inputs, self.emb_layers[0].weight)

        if self.div_val == 1:
            embed = self.emb_layers[0](inputs)
            if self.proj_dim != self.embed_dim and self.do_proj:
                # emb_projs[0] is nn.Linear(proj_dim, embed_dim), so its weight is
                # (embed_dim, proj_dim); transpose it to map embed_dim -> proj_dim,
                # exactly as the div_val != 1 branch below does.
                embed = F.linear(embed, self.emb_projs[0].weight.t())
        else:
            param = next(self.parameters())
            inp_flat = inputs.contiguous().view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.proj_dim],
                dtype=param.dtype, device=param.device)

            for i in range(len(self.cutoffs)-1):
                l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # view(-1) keeps the index 1-D even when a single token matches
                # (squeeze() would produce a 0-d tensor).
                indices_i = mask_i.nonzero().view(-1)

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i].weight.t())

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inputs.size(), self.proj_dim)

        embed.mul_(self.emb_scale)

        return embed


In [None]:
# PositionEmbedding (from models/embeddings.py)

class PositionEmbedding(nn.Module):
    """Sinusoidal position embeddings shared through a per-thread cache.

    For position ``p`` and even index ``k`` the angle is
    ``p / freq ** (k / (dim - 2))``. When ``dim`` is even, each row holds the
    ``dim / 2`` sines followed by the ``dim / 2`` cosines.
    """
    # Per-thread cache shared by every instance: {device: {dim: table}}.
    # NOTE(review): the key does not include ``freq``, so two instances with the
    # same dim but different freq on one device would share a table.
    _embeddings = threading.local()

    def __init__(self, dim, freq=1e4):
        super().__init__()
        self.dim = dim
        self.freq = freq

    def forward(self, inputs):
        """Return a (1, seq_len, dim) table, where ``seq_len = inputs.shape[1]``."""
        seq_len = inputs.shape[1]
        device = inputs.device
        by_device = PositionEmbedding._embeddings.__dict__
        by_dim = by_device.get(device, {})

        table = by_dim.get(self.dim)
        if table is None or table.shape[0] < seq_len:
            pos = torch.arange(0., seq_len, device=device).unsqueeze(1)
            exponents = torch.arange(0., self.dim, 2., device=device).unsqueeze(0) / (self.dim - 2)
            angles = pos / torch.pow(self.freq, exponents)
            table = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1).view(-1, self.dim)
            by_dim[self.dim] = table

        by_device[device] = by_dim
        return table[:seq_len].unsqueeze(0)


In [None]:
# AdaptiveSoftmax (from models/adaptive_softmax.py)

class AdaptiveSoftmax(nn.Module):
    """Projected adaptive softmax returning per-token negative log-likelihoods.

    The vocabulary is split by ``cutoffs`` into a head, made of the
    ``cutoffs[0]`` shortlist tokens plus one logit per tail cluster, and tail
    clusters. A token in tail cluster ``i`` is scored as
    ``log p(cluster | h) + log p(token | cluster, h)``.

    Args:
        n_token: vocabulary size.
        d_embed: width of the output embedding matrices.
        d_proj: width of the incoming hidden states.
        cutoffs: vocabulary boundaries, presumably ascending. An empty list
            gives a plain (optionally projected) softmax.
        div_val: when != 1, cluster ``i`` uses width ``d_embed // div_val**i``
            with its own projection.
        keep_order: default for returning NLL in input order.
        init_std: std for the cluster weights.
        init_proj_std: std for projections and output layers.
    """
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False, init_std=0.02, init_proj_std=0.01):
        super(AdaptiveSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val
        self.init_std = init_std
        self.init_proj_std = init_proj_std

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ModuleList()

        if div_val == 1:
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(nn.Linear(d_proj, d_embed))
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(nn.Linear(d_proj, d_emb_i))
                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
            self.reset_parameters()
        self.keep_order = keep_order

    def reset_parameters(self):
        """Re-initialise cluster, projection and output parameters.

        The cluster parameters only exist when there is at least one tail
        cluster, so they are skipped otherwise. Previously this raised
        AttributeError for ``div_val != 1`` with empty ``cutoffs``.
        """
        if self.n_clusters > 0:
            nn.init.normal_(self.cluster_weight, 0., self.init_std)
            nn.init.constant_(self.cluster_bias, 0.)

        for i in range(len(self.out_projs)):
            if self.out_projs[i] is not None:
                nn.init.normal_(self.out_projs[i].weight, 0., self.init_proj_std)
                nn.init.constant_(self.out_projs[i].bias, 0.)
        for i in range(len(self.out_layers)):
            nn.init.normal_(self.out_layers[i].weight, 0., self.init_proj_std)
            nn.init.constant_(self.out_layers[i].bias, 0.)

    def _compute_logit(self, hidden, weight, bias, proj):
        """Optionally project ``hidden`` with ``proj``, then apply ``weight``/``bias``."""
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            proj_hid = proj(hidden)
            logit = F.linear(proj_hid, weight, bias=bias)
        return logit

    def forward(self, hidden, target, keep_order=False, return_rank=False, return_all_logprobs=False):
        """Score ``target`` tokens given ``hidden`` states.

        Args:
            hidden: (N, d_proj) hidden states.
            target: (N,) token ids.
            keep_order: return NLL in the order of ``target``. If false, NLL is
                grouped by cluster, unless ``self.keep_order`` is set.
            return_rank: also return, per token, how many vocabulary entries have
                a strictly higher log-probability. Requires ordered output.
            return_all_logprobs: return the full (N, n_token) log-prob matrix
                instead of the NLL.

        Returns:
            ``nll`` of shape (N,), ``(nll, rank)`` when ``return_rank`` is set, or
            the (N, n_token) log-probabilities when ``return_all_logprobs`` is set.

        Raises:
            RuntimeError: if ``hidden`` and ``target`` differ in size along dim 0.
        """
        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size in the batch dimension.')

        # getattr: modules pickled before ``keep_order`` existed lack the attribute.
        ordered = getattr(self, 'keep_order', False) or keep_order

        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            logprob = F.log_softmax(logit, dim=-1)
            if return_all_logprobs:
                return logprob
            nll = -logprob.gather(1, target.unsqueeze(1)).squeeze(1)
            if return_rank:
                rank = (-logprob < nll[:, None]).sum(-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    # The head also scores one "cluster token" per tail cluster.
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)
            cutoff_values = [0] + self.cutoffs

            if return_all_logprobs:
                all_logprobs = [head_logprob[:, :-self.n_clusters]]
                for i in range(1, len(cutoff_values) - 1):
                    tail_logit_i = self._compute_logit(hidden, weights[i], biases[i], self.out_projs[i])
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                    # Broadcast the (N,) cluster log-prob over every token in the cluster.
                    all_logprobs.append(head_logprob[:, -i, None] + tail_logprob_i)
                return torch.cat(all_logprobs, dim=1)

            nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)

            offset = 0
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                mask_i = (target >= l_idx) & (target < r_idx)
                # view(-1) keeps the index 1-D even when a single target matches.
                indices_i = mask_i.nonzero().view(-1)

                if indices_i.numel() == 0:
                    continue

                target_i = target.index_select(0, indices_i) - l_idx
                head_logprob_i = head_logprob.index_select(0, indices_i)

                if i == 0:
                    logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
                    hidden_i = hidden.index_select(0, indices_i)
                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                    logprob_i = head_logprob_i[:, -i] + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

                if ordered:
                    nll.index_copy_(0, indices_i, -logprob_i)
                else:
                    nll[offset:offset + logprob_i.size(0)].copy_(-logprob_i)

                offset += logprob_i.size(0)

            if return_rank:
                # Ranks compare against nll position-by-position, so nll must be ordered.
                assert ordered, 'return_rank requires keep_order=True'
                rank = (-head_logprob[:, :-self.n_clusters] < nll[:, None]).sum(-1)
                for i in range(1, len(cutoff_values) - 1):
                    tail_logit_i = self._compute_logit(hidden, weights[i], biases[i], self.out_projs[i])
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                    tail_logprob_i = head_logprob[:, -i].unsqueeze(-1) + tail_logprob_i
                    rank = rank + (-tail_logprob_i < nll[:, None]).sum(-1)

        if return_rank:
            return nll, rank
        return nll


In [None]:
# LabelSmoothingLoss (from models/utils.py)

class LabelSmoothingLoss(nn.Module):
    """KL divergence between the predicted distribution and a smoothed target.

    The target puts ``1 - smoothing`` on the true class and
    ``smoothing / num_classes`` on every other class. If ``ignore_index`` is a
    valid class id, that class column and every row whose target is
    ``ignore_index`` are zeroed out.
    """
    def __init__(self, smoothing=0.0, ignore_index=-1, reduction='sum'):
        super().__init__()
        self.reduction = reduction
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return ``F.kl_div`` of ``log_softmax(inputs, 1)`` vs. the smoothed targets.

        ``inputs`` carries class scores on dim 1; ``targets`` holds class ids with
        dim 1 removed.
        """
        vocab = inputs.shape[1]
        target_dist = inputs.new_full(inputs.shape, self.smoothing / vocab)
        target_dist.scatter_(1, targets.unsqueeze(1), 1 - self.smoothing)

        if 0 <= self.ignore_index < vocab:
            target_dist[:, self.ignore_index] = 0.
            ignored_rows = (targets == self.ignore_index).unsqueeze(1)
            target_dist.masked_fill_(ignored_rows, 0.)

        log_probs = inputs.log_softmax(1)
        return F.kl_div(log_probs, target_dist, reduction=self.reduction)


In [None]:
# NPLMFF (from models/nplm.py)

class NPLMFF(nn.Module):
    """Feed-forward block used by NPLM layers: ``output(relu(hidden(x)))``.

    If ``proj_dim != -1`` the block is ``input_dim -> proj_dim -> hidden_dim``
    (``output_proj`` is ignored in that case). Otherwise it is
    ``input_dim -> hidden_dim``, followed by ``hidden_dim -> input_dim`` when
    ``output_proj`` is true. Without an output layer, the ReLU activations are
    returned directly.
    """
    def __init__(self, input_dim, hidden_dim, init_std=0.02, output_proj=True, proj_dim=-1):
        super().__init__()

        self.init_std = init_std
        self.relu = nn.ReLU()

        use_bottleneck = proj_dim != -1
        inner_dim = proj_dim if use_bottleneck else hidden_dim
        self.hidden = nn.Linear(input_dim, inner_dim)
        if use_bottleneck:
            self.output = nn.Linear(proj_dim, hidden_dim)
        elif output_proj:
            self.output = nn.Linear(hidden_dim, input_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from N(0, init_std) and zero every bias."""
        layers = [self.hidden]
        if hasattr(self, 'output'):
            layers.append(self.output)
        for layer in layers:
            nn.init.normal_(layer.weight, 0., self.init_std)
            nn.init.constant_(layer.bias, 0.)

    def forward(self, inputs):
        activated = self.relu(self.hidden(inputs))
        return self.output(activated) if hasattr(self, 'output') else activated


In [None]:
# NPLMSublayer (from models/nplm.py)

class NPLMSublayer(nn.Module):
    """Residual wrapper computing ``norm(inputs + dropout(sublayer(*args)))``.

    With ``do_add`` false the residual is skipped. If ``inputs`` is wider than
    ``sublayer_shape`` on dim 2, it is split into chunks of that width and the
    chunks are summed before the residual add. ``no_layernorm`` disables the
    final LayerNorm.
    """
    def __init__(self, sublayer, do_add, no_layernorm, sublayer_shape, dropout_p=0.1, init_std=0.02):
        super(NPLMSublayer, self).__init__()
        self.init_std = init_std
        self.sublayer = sublayer
        self.sublayer_shape = sublayer_shape
        self.do_add = do_add
        self.norm = nn.LayerNorm(sublayer_shape) if not no_layernorm else None
        # Not in-place: when the sublayer ends in a ReLU (e.g. NPLMFF without an
        # output layer), autograd keeps that ReLU output for backward. An
        # in-place dropout on it raises at backward time.
        self.dropout = nn.Dropout(dropout_p)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the LayerNorm gain around 1 with std ``init_std``."""
        if self.norm is not None:
            nn.init.normal_(self.norm.weight, 1.0, self.init_std)

    def forward(self, inputs, *sublayer_args, **sublayer_kwargs):
        """Apply the sublayer to ``sublayer_args`` and combine it with ``inputs``.

        ``inputs`` is indexed as (batch, seq_len, width) when ``do_add`` is set.
        """
        if self.do_add:
            if inputs.size(2) != self.sublayer_shape:
                bsz, seq_len, _ = inputs.shape
                # reshape (not view) so non-contiguous inputs are also accepted.
                inputs = inputs.reshape(bsz, seq_len, -1, self.sublayer_shape).sum(dim=-2)
            out = inputs + self.dropout(self.sublayer(*sublayer_args, **sublayer_kwargs))
        else:
            out = self.dropout(self.sublayer(*sublayer_args, **sublayer_kwargs))
        return out if self.norm is None else self.norm(out)


In [None]:
# NPLMLayer (from models/nplm.py)

class NPLMLayer(nn.Module):
    """Implements a single decoder layer in a NPLM decoder stack.

    A layer whose index is in ``config.concat_layers`` builds, for every
    position, a concatenation of:
    - the ``ngm`` most recent token states, and
    - ``long_term_block`` aggregates of ``wsz``-sized windows of older states.

    It feeds that concatenation through ``ffn_nplm`` and stores the summed
    aggregates in ``self.global_mem``. All other layers are residual
    feed-forward sublayers that add the ``global_mem`` they are given.
    """
    def __init__(self, config, num_heads, dim, hidden_dim, layer_i, dropout_p=0.1):
        super(NPLMLayer, self).__init__()
        self.config = config
        self.uuid = uuid.uuid4()

        # ngm: n tokens that concat with full embs
        # wsz: window size to average for long term context
        self.ngm, self.wsz = config.context_config
        self.long_term_block = 0 if self.ngm > 0 and self.wsz == -1 else \
                                    (config.batch_length - self.ngm) // self.wsz
        self.long_term_block *= self.config.num_global_agg

        self.emb_dim = dim
        self.dim_concat_embs = self.ngm * dim + self.long_term_block * dim

        self.hidden_dim = hidden_dim
        self.num_layers = config.num_layers

        if layer_i in config.concat_layers:
            if self.config.global_aggregate == 'kernel':
                for i in range(self.long_term_block):
                    setattr(self, f'learned_global_kernels_l{layer_i}_b{i}',
                            nn.Parameter(torch.tensor(1./self.wsz).repeat(self.wsz)[None, None, :, None]\
                                .repeat(self.config.num_global_agg, 1, 1, 1),requires_grad=True))

            self.ffn_nplm = NPLMSublayer(
                        NPLMFF(self.dim_concat_embs,
                               self.emb_dim,
                               output_proj=False,
                               proj_dim=self.config.mid_dim),
                        True,
                        config.no_layernorm,
                        self.emb_dim, dropout_p)

        else:
            self.ffn = NPLMSublayer(
                        NPLMFF(self.emb_dim, self.hidden_dim),
                        True,
                        config.no_layernorm,
                        self.emb_dim, dropout_p)

    _kernels = threading.local()
    def _get_kernel(self, device, dtype=None):
        """Return a cached (1, 1, wsz, 1) uniform averaging kernel.

        The per-thread cache is keyed on (device, dtype, wsz), so layers with
        different window sizes or dtypes never share a stale kernel.
        """
        dtype = torch.get_default_dtype() if dtype is None else dtype
        kernel_store = NPLMLayer._kernels.__dict__
        key = (device, dtype, self.wsz)
        if key not in kernel_store:
            kernel_store[key] = torch.full((1, 1, self.wsz, 1), 1. / self.wsz,
                                           dtype=dtype, device=device)
        return kernel_store[key]

    _masks = threading.local()
    def mask(self, inputs):
        """Return a (1, L, L) mask with -inf strictly above the diagonal, L = inputs.shape[1]."""
        dim = inputs.shape[1]
        device = inputs.device
        mask_store = NPLMLayer._masks.__dict__
        cached = mask_store.get(device)
        # Rebuild when the cached mask is too small for this input length.
        if cached is None or cached.shape[0] < dim:
            mask = inputs.new_full((dim, dim), float('-inf'))
            mask_store[device] = triu(mask, 1, 1, 1)

        mask = mask_store[device]
        return mask[None, :dim, :dim]

    def reset_parameters(self):
        """Reset whichever feed-forward sublayer this layer owns."""
        if hasattr(self, 'ffn'):
            self.ffn.reset_parameters()
        if hasattr(self, 'ffn_nplm'):
            self.ffn_nplm.reset_parameters()

    @staticmethod
    def _causal_window(source, kernel, wsz, length):
        """Slide ``kernel`` (out, 1, wsz, 1) over the time axis of ``source`` (N, 1, T, D).

        The time axis is padded with ``wsz - 1`` zero rows on both sides, and
        only the first ``length`` outputs are kept. Output ``t`` is therefore a
        weighted sum of source rows ``t - wsz + 1 .. t``, with zeros before the
        start. Uses conv2d because the kernel and input are 4-D; modern PyTorch
        rejects 4-D tensors in conv1d.
        """
        return F.conv2d(source, kernel, padding=(wsz - 1, 0))[:, :, :length]

    def forward(self, inputs, layer_i=0, global_mem=0):
        """Run the layer on ``inputs['state']``.

        Args:
            inputs: dict with ``'state'`` indexed as (batch, seq_len, emb_dim) and
                an optional ``'cache'`` dict keyed by this layer's uuid.
            layer_i: index of this layer. It selects the concat path when it is in
                ``config.concat_layers``.
            global_mem: tensor (or 0) added to the output of non-concat layers.

        Returns:
            dict with the new ``'state'`` and the (possibly updated) ``'cache'``.
        """
        state = inputs['state']
        cache = inputs.get('cache')
        ngm, wsz = self.ngm, self.wsz
        dim = self.emb_dim

        # embedding concatenation layer
        if layer_i in self.config.concat_layers:
            bsz, L, emb_dim = state.shape
            num_agg = self.config.num_global_agg
            aggregate = self.config.global_aggregate

            state_ = state.new_full((bsz, L, self.dim_concat_embs), 0.)
            # Slot i at position t holds the state of position t - i (zero before the start).
            for i in range(ngm):
                state_[:, i:, i*emb_dim : (i+1)*emb_dim] = state[:, : L-i, :]

            ltb = min((L - ngm) // wsz, self.long_term_block // num_agg)
            if aggregate == 'average':
                ltb = min((L - ngm) // wsz * num_agg, self.long_term_block)

            for i in range(ltb):
                # Block i aggregates windows that end ngm + i*wsz positions back.
                # Slice with an explicit length: ":-ngm - i*wsz" is empty when it is 0.
                src_len = L - ngm - i * wsz
                if src_len <= 0:
                    break
                source = state[:, None, :src_len]

                if aggregate == 'average':
                    kernel = self._get_kernel(state.device, state.dtype)
                    conv_tmp = self._causal_window(source, kernel, wsz, src_len).squeeze(1)
                    state_[:, ngm + i * wsz:, (ngm+i) * dim: (ngm+i+1) * dim] = conv_tmp

                elif aggregate == 'kernel':
                    kernel = getattr(self, f'learned_global_kernels_l{layer_i}_b{i}')
                    conv_tmp = self._causal_window(source, kernel, wsz, src_len).squeeze(1)
                    # NOTE(review): this reshape only lines up with the target slice
                    # in special cases (e.g. a single block). Confirm against the
                    # reference implementation before relying on 'kernel' mode.
                    conv_tmp = conv_tmp.transpose(2, 1).contiguous().view(bsz, -1, dim * ltb * num_agg)
                    state_[:, ngm + i * wsz:, (ngm + i) * dim: ] = conv_tmp

            global_part = state_[:, :, ngm * dim:]
            if global_part.size(-1) == 0:
                # No long-term blocks (e.g. wsz == -1): nothing to share with later layers.
                self.global_mem = state.new_zeros(bsz, L, emb_dim)
            else:
                self.global_mem = global_part.reshape(bsz, L, -1, emb_dim).sum(dim=-2)

            state = self.ffn_nplm(state_, state_)

        else:
            # regular NPLM layer
            state = self.ffn(state, state)
            state = state + global_mem

        if cache is not None:
            cached = cache.get(self.uuid)
            if cached is not None:
                state = torch.cat((cached, state), 1)
            cache[self.uuid] = state

        return {'state': state, 'cache': cache}


In [None]:
# NPLM Model (from models/nplm.py)

class NPLM(nn.Module):
    """The neural probabilistic LM module.

    ``forward`` takes a LongTensor batch laid out (seq_len, batch). It is
    transposed to (batch, seq_len), and inputs/targets are built by shifting one
    step along the sequence axis. The layer stack reads ``state.shape[1]`` as
    the sequence length.
    """
    def __init__(self, config, dataset):
        super(NPLM, self).__init__()

        self.dataset = dataset

        self.adaptive = config.adaptive
        self.ngm, self.wsz = config.context_config
        self.long_term_block = 0 if self.ngm > 0 and self.wsz == -1 else \
                                    (config.batch_length - self.ngm) // self.wsz

        self.dim_concat_embs = self.ngm * config.embedding_size + self.long_term_block * config.embedding_size

        self.embedding = TokenEmbedding(
                dataset.vocab_size,
                config.embedding_size,
                config.model_size,
                config.cutoffs,
                emb_std=config.emb_std,
                proj_std = config.proj_std,
                div_val=config.div_val,
                padding_idx=self.padding_idx,
                do_proj=config.do_proj
            )

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size, config.embedding_size, config.embedding_size,
                                                    config.cutoffs, div_val=config.div_val)

            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs

            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[i].weight = self.embedding.emb_layers[i].weight

            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[i]

        self.layers = self.create_layers(config)
        self.position_embedding = PositionEmbedding(config.model_size)
        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

        self.dropout = nn.Dropout(config.dropout_p, inplace=True)
        self.config = config

    @classmethod
    def create_layers(cls, config):
        """Build ``config.num_layers`` NPLMLayer modules, indexed from 0."""
        kwargs = {'dropout_p': config.dropout_p}
        args = [config, config.num_heads, config.embedding_size, config.hidden_dim]

        layers = nn.ModuleList([
            NPLMLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

        return layers

    @property
    def padding_idx(self):
        return self.dataset.padding_idx

    @property
    def eos_idx(self):
        return self.dataset.eos_idx

    def reset_named_parameters(self, modules):
        """Reset the parameter groups named in ``modules`` ('layers', 'embeddings')."""
        if 'layers' in modules:
            for layer in self.layers:
                layer.reset_parameters()

        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch):
        """Compute per-token losses for a (seq_len, batch) LongTensor.

        Returns:
            ``(smoothed_nll, nll)``, or ``(rank, nll)`` when ``config.return_rank``
            is set. In the non-adaptive path ``smoothed_nll`` is summed per
            sequence and ``nll`` is flattened per token.
        """
        batch = batch.t()
        targets = left_shift(batch)
        decoded = self.decode(right_shift(batch))

        state = decoded['state']

        if not self.adaptive:
            logits = self.embedding(state, reverse=True).transpose(2, 1)
            dims = list(range(1, logits.dim()))
            nll = self.cross_entropy(logits, targets).view(-1)
            smoothed_nll = self.label_smoothing(logits, targets).sum(dims)

            if not self.config.return_rank:
                return smoothed_nll, nll

            logits = logits.transpose(2, 1)
            assert targets.shape[0] == 1
            targets = targets.squeeze(0)
            target_logits = logits[:, range(targets.shape[0]), targets]
            rank = (logits > target_logits.unsqueeze(-1)).sum(dim=-1)
            return rank, nll

        state = state.view(-1, state.shape[-1])
        targets = targets.contiguous().view(-1)

        if not self.config.return_rank:
            nll = self.adaptive_softmax(state, targets, keep_order=True)
            smoothed_nll = nll
            return smoothed_nll, nll

        nll, rank = self.adaptive_softmax(state, targets, keep_order=True, return_rank=True)
        return rank, nll

    def decode(self, batch, cache=None):
        """Embed ``batch`` and run it through the layer stack."""
        word_embedding = self.embed(batch, self.embedding)

        decoded = {
            'cache': cache,
            'state': word_embedding,
        }

        # concat layer
        decoded = self.layers[0](decoded, layer_i=0)
        # Only a concat layer sets global_mem; otherwise later layers add nothing.
        global_mem = getattr(self.layers[0], 'global_mem', 0)

        # regular layers
        for i, decoder in enumerate(self.layers[1:]):
            decoded = decoder(decoded, layer_i=i+1, global_mem=global_mem)

        state = decoded['state']
        if cache is not None:
            state = state[:, -1:]

        return {
            'cache': decoded.get('cache'),
            'state': state,
        }

    def embed(self, inputs, token_embedding):
        """Token embedding (+ sinusoidal positions when ``config.TFN``), then dropout."""
        if self.config.TFN:
            return self.dropout(token_embedding(inputs) + self.position_embedding(inputs))
        else:
            return self.dropout(token_embedding(inputs))


In [None]:
# Configuration Class

class SimpleConfig:
    """Simplified config for NPLM.

    Every attribute has the default shown below. Any of them can be
    overridden by keyword, e.g. ``SimpleConfig(num_layers=2)``. Unknown
    keywords raise ``TypeError`` so typos do not go unnoticed.
    """
    def __init__(self, **overrides):
        # Model architecture
        self.embedding_size = 256
        self.model_size = 256
        self.hidden_dim = 1024
        self.num_layers = 4
        self.num_heads = 4

        # NPLM specific
        self.context_config = (3, 4)  # (ngm, wsz): 3 tokens concat, window size 4
        self.concat_layers = [0]  # Which layers use concatenation
        self.global_aggregate = 'average'  # 'average' or 'kernel'
        self.num_global_agg = 1
        self.mid_dim = 512  # intermediate dimension for NPLM FF

        # Regularization
        self.dropout_p = 0.1
        self.label_smoothing = 0.0
        self.no_layernorm = False

        # Adaptive softmax
        self.adaptive = False
        self.cutoffs = []
        self.div_val = 1
        self.tie_weights = False
        self.tie_projs = False

        # Embedding
        self.emb_std = 0.01
        self.proj_std = 0.02
        self.do_proj = False

        # Training
        self.batch_size = 16
        self.batch_length = 32
        self.TFN = False  # Transformer-N variant
        self.return_rank = False

        # Optimizer
        self.base_lr = 0.0005
        self.final_lr = 0.0001
        self.warmup_steps = 100
        self.max_steps = 2000

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"SimpleConfig got an unexpected option {name!r}")
            setattr(self, name, value)


In [None]:
# Simple Synthetic Dataset

class SimpleDataset:
    """A simple synthetic dataset for demonstration.

    Tokens are drawn uniformly from ``[2, vocab_size)``, so they never equal
    ``padding_idx`` (0) or ``eos_idx`` (1).
    """
    def __init__(self, vocab_size=1000, seq_length=10000):
        self.vocab_size = vocab_size
        self.padding_idx = 0
        self.eos_idx = 1

        # Generate synthetic data (random tokens); in practice this would be real text.
        # A private RandomState seeded with 42 yields the same tokens that
        # np.random.seed(42) would, without resetting NumPy's global RNG.
        rng = np.random.RandomState(42)
        self.data = torch.from_numpy(
            rng.randint(2, vocab_size, size=seq_length)
        ).long()

    def get_batch(self, batch_size, batch_length):
        """Return ``batch_size`` random windows of ``batch_length`` tokens.

        Returns:
            LongTensor of shape (batch_length, batch_size), i.e. sequence-major.
            ``NPLM.forward`` transposes its input to (batch, seq) before
            shifting along the last axis, so it expects this layout.

        Raises:
            ValueError: if the corpus is too short for ``batch_length``.
        """
        max_start = len(self.data) - batch_length - 1
        if max_start <= 0:
            raise ValueError(
                f"batch_length={batch_length} is too long for a corpus of {len(self.data)} tokens"
            )
        starts = torch.randint(0, max_start, (batch_size,))

        return torch.stack(
            [self.data[start:start + batch_length] for start in starts],
            dim=1,
        )


In [None]:
# Training Function

def train_nplm(model, dataset, config, num_steps=100, device='cpu', clip_norm=0.25, log_every=20):
    """Simple training loop: Adam on random batches with gradient clipping.

    Args:
        model: module whose forward returns ``(smoothed_nll, nll)`` for a batch.
        dataset: object providing ``get_batch(batch_size, batch_length)``.
        config: supplies ``batch_size``, ``batch_length`` and ``base_lr``.
        num_steps: number of optimisation steps.
        device: device for the model and batches.
        clip_norm: maximum global gradient norm (0.25, as before).
        log_every: print progress every this many steps (0 disables printing).

    Returns:
        List of the mean per-token NLL at each step.
    """
    model = model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.base_lr)

    losses = []
    for step in tqdm(range(num_steps), desc="Training"):
        batch = dataset.get_batch(config.batch_size, config.batch_length).to(device)

        # Forward pass
        smoothed_nll, nll = model(batch)
        loss = smoothed_nll.sum()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        # Average over the tokens actually scored. NPLM shifts inputs/targets by
        # one, so there are batch_length - 1 predictions per sequence, not
        # batch_length. Any padded targets contribute 0 to the sum but are still counted.
        losses.append(nll.sum().item() / max(nll.numel(), 1))

        if log_every and step % log_every == 0:
            print(f"Step {step}, NLL: {losses[-1]:.4f}, PPL: {np.exp(losses[-1]):.4f}")

    return losses


In [None]:
# Evaluation Function

def evaluate_nplm(model, dataset, config, device='cpu', num_batches=10):
    """Average per-token NLL and perplexity over ``num_batches`` random batches.

    Args:
        model: module whose forward returns ``(smoothed_nll, nll)``.
        dataset: object providing ``get_batch(batch_size, batch_length)``.
        config: supplies ``batch_size`` and ``batch_length``.
        device: device for the model and batches.
        num_batches: number of batches to evaluate (10, as before).

    Returns:
        ``(avg_nll, ppl)``.
    """
    model = model.to(device)
    model.eval()

    total_nll = 0.0
    total_tokens = 0

    with torch.no_grad():
        for _ in range(num_batches):
            batch = dataset.get_batch(config.batch_size, config.batch_length).to(device)

            _, nll = model(batch)
            total_nll += nll.sum().item()
            # Count scored tokens (batch_length - 1 per sequence after the shift).
            total_tokens += nll.numel()

    avg_nll = total_nll / max(total_tokens, 1)
    ppl = np.exp(avg_nll)

    print(f"Evaluation - NLL: {avg_nll:.4f}, PPL: {ppl:.4f}")
    return avg_nll, ppl


In [None]:
# Main Execution

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyper-parameters for the demo run
config = SimpleConfig()
print("Configuration created")

# Synthetic corpus: 10k random token ids over a 1k vocabulary
dataset = SimpleDataset(vocab_size=1000, seq_length=10000)
print(f"Dataset created with vocab_size={dataset.vocab_size}")

# Build the model and report its size
model = NPLM(config, dataset)
print(
    f"Model created with {sum(p.numel() for p in model.parameters())} parameters",
    model,
    sep="\n",
)


In [None]:
# Train the model
print("\nStarting training...")
losses = train_nplm(model, dataset, config, num_steps=200, device=device)

# Plot the per-step training NLL (matplotlib is only needed from here on)
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses)
ax.set_xlabel('Step')
ax.set_ylabel('NLL')
ax.set_title('Training Loss')
ax.grid(True)
plt.show()


In [None]:
# Evaluate the trained model on fresh random batches
print("\nEvaluating model...")
eval_result = evaluate_nplm(model, dataset, config, device=device)
avg_nll, ppl = eval_result


In [None]:
# Inspect Model Architecture: print embedding sizes, per-layer NPLM settings and parameter counts

print("\n=== NPLM Architecture Details ===\n")

print(
    "Embedding Layer:",
    f"  - Vocab size: {model.embedding.vocab_size}",
    f"  - Embedding dim: {model.embedding.embed_dim}",
    f"  - Projection dim: {model.embedding.proj_dim}",
    sep="\n",
)

print("\nNPLM Layers:")
for i, layer in enumerate(model.layers):
    print(
        f"  Layer {i}:",
        f"    - Context config (ngm, wsz): ({layer.ngm}, {layer.wsz})",
        f"    - Dim concat embs: {layer.dim_concat_embs}",
        f"    - Has ffn_nplm: {hasattr(layer, 'ffn_nplm')}",
        f"    - Has regular ffn: {hasattr(layer, 'ffn')}",
        sep="\n",
    )

print(
    f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}",
    f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}",
    sep="\n",
)


## Summary

This notebook demonstrates a **faithful, production-quality implementation** of the NPLM architecture with:

1. **TokenEmbedding** - Adaptive embedding with optional projection
2. **PositionEmbedding** - Sinusoidal position encoding
3. **AdaptiveSoftmax** - Hierarchical softmax for large vocabularies
4. **NPLMFF** - Feed-forward network with configurable dimensions
5. **NPLMSublayer** - Residual connection wrapper with layer normalization
6. **NPLMLayer** - Core NPLM layer with embedding concatenation and global aggregation
7. **NPLM** - Complete model that returns per-token (and optionally label-smoothed) NLL; training and evaluation are driven by the `train_nplm` / `evaluate_nplm` helpers

### Key Features:
- Context concatenation (ngm tokens)
- Global aggregation with averaging or learned kernels
- Configurable layer depths and dimensions
- Label smoothing loss
- Adaptive softmax support

### Configuration:
The model uses `context_config = (3, 4)`, meaning:
- 3 recent tokens are concatenated with full embeddings
- Distant context is aggregated using windows of size 4
- Only layer 0 uses concatenation (specified in `concat_layers`)
- Other layers are regular feed-forward layers

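The model classes are the **actual implementation** from the repository, not a simplified version. Only the configuration, the synthetic dataset, and the training/evaluation loops are lightweight demo scaffolding.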