In [14]:
from datasets import load_dataset
import torch
import torch.nn as nn
import numpy as np

# Load the 'ptb-text-only/ptb_text_only' dataset through the `datasets` library.
# NOTE(review): trust_remote_code=True permits the dataset's own loading script
# to run locally — confirm the source is trusted before re-running.
ptb = load_dataset('ptb-text-only/ptb_text_only', trust_remote_code=True)

# Bind the three splits to short module-level names.
train, val, test = ptb['train'], ptb['validation'], ptb['test']

In [15]:
def load_glove_embeddings(glove_file):
    """Read a GloVe-style text file into a ``{word: vector}`` dictionary.

    Every non-blank line is split on whitespace; the first field is taken as
    the token and the remaining fields are parsed as the vector components.

    Args:
        glove_file: Path of the embeddings text file to read.

    Returns:
        dict mapping each token (str) to a 1-D ``np.float32`` array. If a
        token occurs on several lines, the last occurrence wins.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a vector component cannot be parsed as a float.
    """
    glove_embeddings = {}

    # Explicit encoding so the result does not depend on the platform locale.
    # Assumes the embeddings file is UTF-8 text — confirm for custom files.
    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.strip().split()
            if not values:
                # Blank line (e.g. trailing newline at end of file): nothing to parse.
                continue
            word = values[0]
            # Keyword must be `dtype`; the previous `dtypre` spelling made
            # np.array raise TypeError on the very first line.
            vector = np.asarray(values[1:], dtype=np.float32)
            glove_embeddings[word] = vector

    return glove_embeddings

In [10]:
class LearnedPositionalEncoding(nn.Module):
    """Adds a trainable, position-specific vector to every element of a sequence.

    A lookup table with ``max_len`` rows of width ``d_model`` is stored in
    ``pos_embedding``; row ``i`` is added to whatever sits at position ``i``.
    """

    def __init__(self, d_model, max_len=5000):
        """Create the position table.

        Args:
            d_model: Width of each position vector.
            max_len: Number of distinct positions the table can index.
        """
        super().__init__()
        self.pos_embedding = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model)

    def forward(self, x):
        """Return ``x`` plus the position vectors for indices ``0 .. x.size(1) - 1``.

        Args:
            x: Tensor laid out as (batch, seq_len, d_model).

        Returns:
            Tensor with the same layout as ``x``.
        """
        seq_len = x.size(1)
        # Index row on the same device as the input, shaped (1, seq_len).
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        # The (1, seq_len, d_model) lookup broadcasts across the batch dimension.
        return x + self.pos_embedding(positions)

In [13]:
class TransformerLM(nn.Module):
    """Causally-masked Transformer encoder stack used as a language model.

    Token ids are embedded, combined with positional encodings, passed through
    ``num_layers`` self-attention layers under a causal mask, and projected to
    one score per vocabulary entry by ``fc_out``.
    """

    def __init__(self, vocab_size, d_model, n_head, num_layers, dropout=0.1,
                 dim_feedforward=2048):
        """Build the embedding, encoder stack and output projection.

        Args:
            vocab_size: Number of token ids; also the size of the output layer.
            d_model: Embedding / hidden width.
            n_head: Attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            dropout: Dropout probability inside each encoder layer.
            dim_feedforward: Hidden width of each layer's feed-forward block.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = LearnedPositionalEncoding(d_model)

        # batch_first=True is required: forward() feeds (batch, seq_len, d_model)
        # tensors and sizes the mask from dimension 1. With the default
        # (sequence-first) layout the layers would treat the batch axis as time.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output layer
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x):
        """Compute per-position vocabulary logits.

        Args:
            x: Integer token ids laid out as (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, vocab_size).
        """
        seq_len = x.size(1)

        # Embedding and positional encoding
        hidden = self.embedding(x)          # (batch, seq_len, d_model)
        hidden = self.pos_encoding(hidden)  # (batch, seq_len, d_model)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere,
        # so position i can only attend to positions <= i. Built directly so it
        # does not depend on nn.Transformer.generate_square_subsequent_mask
        # being callable on the class.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'),
                       dtype=hidden.dtype, device=hidden.device),
            diagonal=1,
        )
        hidden = self.transformer_encoder(hidden, mask=mask)  # (batch, seq_len, d_model)

        # Project hidden states onto the vocabulary; previously fc_out was
        # constructed but never applied.
        return self.fc_out(hidden)