In [1]:
import time
import math
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
from torchinfo import summary

In [2]:
text = open("nietzsche.txt").read()
chars = sorted(list(set(text)))
chars.insert(0, '\0')

In [3]:
char_to_index = {v: i for i, v in enumerate(chars)}
index_to_char = {i: v for i, v in enumerate(chars)}

In [4]:
total_index = [char_to_index[char] for char in text]
pred_num = 25
xin = [[total_index[j + i] for j in range(0, len(total_index) - 1 - pred_num, pred_num)] for i in range(pred_num)]
y = [total_index[i + pred_num] for i in range(0, len(total_index) - 1 - pred_num, pred_num)]

In [5]:
X = np.stack([np.stack(xin[i][:-2]) for i in range(pred_num)],1)
Y = np.stack(y[:-2])

In [6]:
X_tensor = torch.tensor(X, dtype=torch.long)
Y_tensor = torch.tensor(Y, dtype=torch.long)

In [7]:
class MultiHeadSelfAttention(nn.Module):
    def __init__( self, in_size: int, out_size: int, n_heads: int, n_timesteps: int, dropout_prob: float = 0, device = 'cpu' ):
        super(MultiHeadSelfAttention, self).__init__()
        
        self.Wk = nn.Linear(in_size, in_size, bias=True)
        self.Wq = nn.Linear(in_size, in_size, bias=True)
        self.Wv = nn.Linear(in_size, in_size, bias=True)
        self.residual_proj = nn.Linear(in_size, out_size, bias=True)
        
        # Create lower triangular mask
        mask = torch.tril(torch.ones(n_timesteps, n_timesteps))
        self.register_buffer('mask', mask)
        
        self.att_dropout = nn.Dropout(dropout_prob)
        self.residual_dropout = nn.Dropout(dropout_prob)
        self.softmax = nn.Softmax(dim=-1)
        
        # Store head_size and verify that it's an integer
        self.H = in_size // n_heads
        if in_size % n_heads != 0:
            raise ValueError("Embedding dimension not divisible in equal heads.")
        
        # Move to specified device
        self.to(device)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        H = self.H
        nh = D // H  # Num heads
        
        # Get key, queries and values from the input
        k = self.Wk(x)  # (B, T, D)
        q = self.Wq(x)  # (B, T, D)
        v = self.Wv(x)  # (B, T, D)
        
        # Reshape into different heads
        k = k.reshape(B, T, nh, H).transpose(1, 2)  # (B, nh, T, H)
        q = q.reshape(B, T, nh, H).transpose(1, 2)  # (B, nh, T, H)
        v = v.reshape(B, T, nh, H).transpose(1, 2)  # (B, nh, T, H)
        
        # Compute attention activation
        kT = k.transpose(-2, -1)  # (B, nh, H, T)
        att = torch.matmul(q, kT)  # (B, nh, T, T)
        
        # Scale attention scores
        att = att / (H ** 2)
        
        # Apply mask (to block out future characters)
        mask = self.mask[:T, :T]  # Get appropriate size mask
        mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimensions
        att = att.masked_fill(mask == 0, float('-inf'))
        
        # Apply softmax and dropout
        att = self.softmax(att)
        att = self.att_dropout(att)
        
        # Compute weighted sum between values
        out = torch.matmul(att, v)  # (B, nh, T, H)
        
        # Restack heads in D dimension
        out = out.transpose(1, 2).reshape(B, T, D)  # (B, T, D)
        
        # Apply final projection (Dense layer) and dropout
        out = self.residual_proj(out)  # (B, T, out_size)
        out = self.residual_dropout(out)
        
        return out


In [8]:
class FullyConnected(nn.Module):
    def __init__(self, in_size: int, out_size: int, dropout_prob: float = 0, device: str = 'cpu', bias: bool = True):
        super(FullyConnected, self).__init__()
        
        self.l1 = nn.Linear(in_size, in_size * 2, bias=bias)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(in_size * 2, out_size)
        self.dropout = nn.Dropout(dropout_prob)
        
        # Move to specified device
        self.to(device)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.l1(x)
        z = self.relu(z)
        z = self.l2(z)
        z = self.dropout(z)
        return z


In [9]:
class Block(nn.Module):
    def __init__( self, in_size: int, out_size: int, n_heads: int, n_timesteps: int, dropout_prob: float = 0, device: str = 'cpu' ):
        super(Block, self).__init__()
        
        self.att = MultiHeadSelfAttention( in_size, in_size, n_heads, n_timesteps, dropout_prob, device )
        self.ln1 = nn.LayerNorm(in_size)
        self.fcc = FullyConnected(in_size, out_size, dropout_prob, device, True)
        self.ln2 = nn.LayerNorm(out_size)
        
        self.to(device)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = x + self.att(self.ln1(x))
        z = z + self.fcc(self.ln2(z))
        
        return z


In [10]:
class PositionalEmbedding(nn.Module):
    def __init__(self, input_size: int, embed_size: int):
        super(PositionalEmbedding, self).__init__()
        self.E = nn.Parameter(torch.randn(input_size, embed_size))
    
    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        _, T = idx.shape
        
        positions = torch.arange(T, device=idx.device) # Create indices for the positions (0 to T-1)
        x = self.E[positions]
        
        batch_size = idx.shape[0]
        x = x.unsqueeze(0).expand(batch_size, -1, -1)
        return x

In [11]:
class Transformer(nn.Module):
    """
    Transformer model translated from TypeScript
    """
    def __init__(self, vocab_size, hidden_size, n_timesteps, n_heads, dropout_p=0.2, device=None):
        super(Transformer, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = PositionalEmbedding(n_timesteps, hidden_size)
        self.b1 = Block(hidden_size, hidden_size, n_heads, n_timesteps, dropout_p, device)
        self.b2 = Block(hidden_size, hidden_size, n_heads, n_timesteps, dropout_p, device)
        self.ln = nn.LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
        if device:
            self.to(device)
    
    def forward(self, x):
        # Embedding and positional encoding
        z = self.embed(x)
        z = z + self.pos_embed(x)  # Simplified addition
        
        z = self.b1(z)
        z = self.b2(z)
        
        z = self.ln(z)
        last_token_output = z[:, -1, :]

        out = self.linear(last_token_output)
        
        return out

In [20]:
class SimpleTransformer(nn.Module):
    """
    Simple transformer in pure Pytorch
    """
    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers, hidden_dim):
        super(SimpleTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer_encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward=hidden_dim,dropout=0.5),num_layers=num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.embedding_dim = embedding_dim 
        self.pos_encoder = PositionalEncoding(embedding_dim)


    def forward(self, x):
        embedded = self.embedding(x)
        embedded = self.pos_encoder(embedded)
        embedded = embedded.transpose(0, 1)

        transformer_out = self.transformer_encoder(embedded) 
        transformer_out = transformer_out.transpose(0, 1) 
        
        last_token_output = transformer_out[:, -1, :]

        out = self.fc(last_token_output) # (batch_size, vocab_size)
        return out


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) 
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


In [21]:
vocab_size = 86
embedding_dim = 42
num_heads = 2  # Number of attention heads
num_layers = 2 # Number of Transformer encoder layers
hidden_dim = 128 # Hidden dimension for feedforward network in Transformer

In [22]:
model = SimpleTransformer(vocab_size, embedding_dim, num_heads, num_layers, hidden_dim)
# model = Transformer( vocab_size=vocab_size, hidden_size=hidden_dim, n_timesteps=25, n_heads=num_heads, dropout_p=0.2, device='cpu' )



In [23]:
summary(model)

Layer (type:depth-idx)                                            Param #
SimpleTransformer                                                 --
├─Embedding: 1-1                                                  3,612
├─TransformerEncoder: 1-2                                         --
│    └─ModuleList: 2-1                                            --
│    │    └─TransformerEncoderLayer: 3-1                          18,314
│    │    └─TransformerEncoderLayer: 3-2                          18,314
├─Linear: 1-3                                                     3,698
├─PositionalEncoding: 1-4                                         --
│    └─Dropout: 2-2                                               --
Total params: 43,938
Trainable params: 43,938
Non-trainable params: 0

In [24]:
print("Trainable Parameters:",sum(p.numel() for p in model.parameters() if p.requires_grad),"Total:",sum(p.numel() for p in model.parameters()))

Trainable Parameters: 43938 Total: 43938


In [25]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

In [26]:
epochs = 100
batch_size = 64

In [None]:
for epoch in range(epochs):
    current_time = time.time()
    for i in range(0, len(X), batch_size):

        X_batch = X_tensor[i:i + batch_size]
        Y_batch = Y_tensor[i:i + batch_size]

        optimizer.zero_grad()

        outputs = model(X_batch)
        loss = criterion(outputs, Y_batch)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}, Time: {time.time() - current_time:.2f}s')

Epoch [1/100], Loss: 2.7215, Time: 3.30s
Epoch [2/100], Loss: 2.7262, Time: 3.16s
Epoch [3/100], Loss: 2.6038, Time: 3.14s
Epoch [4/100], Loss: 2.7319, Time: 3.20s
Epoch [5/100], Loss: 2.5058, Time: 3.13s
Epoch [6/100], Loss: 2.5415, Time: 3.14s
Epoch [7/100], Loss: 2.4053, Time: 3.13s
Epoch [8/100], Loss: 2.4882, Time: 3.21s
Epoch [9/100], Loss: 2.5397, Time: 3.56s
Epoch [10/100], Loss: 2.4738, Time: 3.36s
Epoch [11/100], Loss: 2.5524, Time: 3.41s
Epoch [12/100], Loss: 2.4964, Time: 3.09s
Epoch [13/100], Loss: 2.4651, Time: 2.98s
Epoch [14/100], Loss: 2.5220, Time: 2.99s
Epoch [15/100], Loss: 2.5625, Time: 2.98s
Epoch [16/100], Loss: 2.5308, Time: 3.01s
Epoch [17/100], Loss: 2.5118, Time: 2.96s
Epoch [18/100], Loss: 2.5505, Time: 2.99s
Epoch [19/100], Loss: 2.4972, Time: 2.99s
Epoch [20/100], Loss: 2.5287, Time: 3.00s
Epoch [21/100], Loss: 2.5252, Time: 3.00s
Epoch [22/100], Loss: 2.3908, Time: 2.98s
Epoch [23/100], Loss: 2.4932, Time: 2.98s
Epoch [24/100], Loss: 2.4242, Time: 2.98s
E

In [None]:
def predict_next_char(inp):
    model.eval() # Set the model to evaluation mode

    index = [char_to_index[i] for i in inp]
    arr = np.expand_dims(np.array(index), axis=0)
    input_tensor = torch.tensor(arr, dtype=torch.long) # Convert to tensor
    with torch.no_grad(): # Disable gradient calculation during inference
        prediction = model(input_tensor)
    predicted_index = torch.argmax(prediction).item() # get the index of the maximum log-probability
    return index_to_char[predicted_index],inp+index_to_char[predicted_index]

In [None]:
print(predict_next_char('those w'))
print(predict_next_char(' th'))
print(predict_next_char(' an'))
print(predict_next_char('does th'))
print(predict_next_char('woma'))
print(predict_next_char('philosoph'))

('e', 'those we')
('e', ' the')
('e', ' ane')
('e', 'does the')
('n', 'woman')
(' ', 'philosoph ')


In [35]:
torch.save(model.state_dict(), 'simpleTransformer_3pred.pth')