In [17]:
import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import dataset

In [11]:
from src.grokking.dataset import ModularArithmetic, ModularArithmeticDataset, collate_fn

In [9]:
from torch.utils.data import Dataset, DataLoader

In [12]:
# Build the task object and its train / validation dataset splits.
ma = ModularArithmetic()
train_ds = ModularArithmeticDataset(ma, train=True)
valid_ds = ModularArithmeticDataset(ma, train=False)

# Only the training split is wrapped in a (shuffled) loader in this cell.
batch_size = 512
train_dataloader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Sanity check: print a single batch drawn from the loader.
print(next(iter(train_dataloader)))

(tensor([[ 3,  0, 54,  1],
        [28,  0, 86,  1],
        [50,  0, 90,  1],
        ...,
        [84,  0, 95,  1],
        [78,  0, 43,  1],
        [ 2,  0, 63,  1]]), tensor([48, 41, 59, 21, 97, 66, 43, 14, 63, 53, 12, 76, 32, 49, 34, 28, 63, 67,
        75, 62, 68, 67, 35, 92, 59, 16, 58, 61,  6, 59, 55,  2, 87, 93, 46, 13,
        37, 28, 85, 72, 76, 63, 63, 62, 86, 33, 93, 12, 11, 96, 41, 45, 51, 53,
        88, 83, 89, 68, 31, 88, 61, 16, 44, 34, 48, 13, 55, 59, 71, 27, 20,  6,
        52,  2, 82, 25, 59, 96, 44,  6, 46,  5, 98, 18, 73, 28, 28,  9, 82, 84,
        16, 39,  4, 10, 76,  7, 71, 62, 83, 53,  9, 45, 98, 57, 86,  9, 31, 54,
        10, 96, 81, 37, 23, 31, 11, 84, 77, 44, 41, 86, 25,  7,  6, 77, 68, 14,
         2, 84, 80, 56, 15, 50, 22, 20, 81, 18, 89, 63, 37, 89,  5, 48, 14, 92,
        34, 98, 76, 85, 86, 59, 52,  7, 63, 67, 83, 97, 51, 70, 13, 98, 41, 54,
        60, 24, 62, 98, 83, 59,  6, 25,  9, 15, 86, 14, 88, 23, 58, 65, 50, 37,
         2, 94, 55, 22, 86, 

In [13]:
# Draw one batch from the training loader so its contents can be inspected.
test_data = next(iter(train_dataloader))

In [15]:
# Shape of the batch's first tensor (one row per example in the batch).
test_data[0].shape

torch.Size([512, 4])

In [16]:
# Shape of the batch's second tensor (one value per example in the batch).
test_data[1].shape

torch.Size([512])

In [None]:
# Reference note only (not executable code): argument list pasted from the
# forward() signature of torch's transformer decoder modules. It is kept as a
# comment so this cell no longer raises a SyntaxError when the notebook is run
# from top to bottom.
#
#   tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
#   memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
#   memory_key_padding_mask: Optional[Tensor] = None) -> Tensor

In [18]:
class TransformerModel(nn.Module):
    """Token-level transformer built on ``nn.TransformerDecoder``.

    Pipeline: token embedding (scaled by ``sqrt(d_model)``) -> positional
    encoding -> ``nlayers`` stacked ``TransformerDecoderLayer`` modules that
    also attend to ``memory`` -> linear projection to ``ntoken`` logits.

    Args:
        ntoken: vocabulary size (embedding rows and output logits).
        d_model: embedding / model width.
        nhead: number of attention heads per attention module.
        d_hid: hidden width of each layer's feed-forward sub-block.
        nlayers: number of stacked decoder layers.
        dropout: dropout probability used by the positional encoding and
            the decoder layers.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        decoder_layers = TransformerDecoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        # NOTE: despite the names, `encoder` is the input token embedding and
        # `decoder` is the output projection to vocabulary logits.
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Initialise embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None) -> Tensor:
        """Compute per-position vocabulary logits.

        The decoder layers are created without ``batch_first``, so the
        sequence-first layout is used throughout. Batches laid out as
        ``[batch, seq_len]`` must be transposed by the caller first.

        Args:
            src: integer token ids, shape ``[seq_len, batch_size]``.
            memory: sequence the decoder cross-attends to, shape
                ``[mem_len, batch_size, d_model]``.
            tgt_mask: optional additive self-attention mask of shape
                ``[seq_len, seq_len]`` (e.g. a causal mask).
            memory_mask: optional additive cross-attention mask of shape
                ``[seq_len, mem_len]``.

        Returns:
            Tensor of shape ``[seq_len, batch_size, ntoken]`` with logits.
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        tgt = self.encoder(src) * math.sqrt(self.d_model)
        tgt = self.pos_encoder(tgt)
        output = self.transformer_decoder(
            tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask
        )
        return self.decoder(output)


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an ``sz x sz`` additive mask: ``-inf`` strictly above the diagonal, ``0`` elsewhere."""
    neg_inf = torch.full((sz, sz), float('-inf'))
    # Keep only the part above the main diagonal; everything else becomes 0.
    causal_mask = torch.triu(neg_inf, diagonal=1)
    return causal_mask

In [19]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions
    and stored as the non-trainable buffer ``pe`` of shape
    ``[max_len, 1, d_model]``.

    Args:
        d_model: feature width of the inputs (even or odd).
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the frequencies to match (a no-op when d_model is even).
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            # Fail with a clear message instead of an opaque broadcast error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional table"
            )
        # pe has a singleton batch dimension and broadcasts across the batch.
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [4]:
# Device that the model and mask tensors are moved to via .to(device).
device = 'cpu'

In [20]:
# Model hyper-parameters.
ntokens = 99   # vocabulary size (embedding rows and output logits)
emsize = 128   # embedding width, passed to the model as d_model
d_hid = 128    # hidden width of each layer's feed-forward sub-block
nlayers = 2    # number of stacked transformer layers
nhead = 4      # number of heads in nn.MultiheadAttention
dropout = 0.0  # dropout probability

# Build the model with explicit keyword arguments, then move it to the device.
model = TransformerModel(
    ntoken=ntokens,
    d_model=emsize,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout,
)
model = model.to(device)

In [21]:
# Display the module tree of the model built above.
model

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=128, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
        

In [None]:
# Loss, optimizer and learning-rate schedule for training `model`.
criterion = nn.CrossEntropyLoss()
lr = 0.0001  # learning rate
# NOTE(review): `weight_decay` was passed to AdamW below but never defined
# anywhere in the notebook, so this cell raised NameError. 0.01 is AdamW's
# documented default and is used as a neutral placeholder - confirm the
# intended value for this experiment.
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.98), weight_decay=weight_decay)
# StepLR's step_size is an integer count of scheduler steps; decay lr by 5% on every step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

In [22]:
# Additive causal mask for a length-4 sequence, moved to the same device as the model.
src_mask = generate_square_subsequent_mask(4).to(device)

In [23]:
# Display the mask values: 0 on and below the diagonal, -inf above it.
src_mask

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])