In [1]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from numpy import sqrt, log

class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * sqrt(self.emb_size)


In [2]:
class TokenTypeTransformer(nn.Module):
    def __init__(self,
                 num_layers: int,
                 emb_size: int,
                 nhead: int,
                 vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(TokenTypeTransformer, self).__init__()
        encoder_layer = TransformerEncoderLayer(d_model=emb_size,
                                                nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout,
                                                batch_first=True
                                                )
        self.transformer: TransformerEncoder = TransformerEncoder(encoder_layer=encoder_layer,
                                              num_layers=num_layers)
        self.generator = nn.Linear(emb_size, vocab_size)
        self.tok_emb = TokenEmbedding(vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)
    

    def forward(self,
                src: Tensor,
                src_mask: Tensor = None,
                src_padding_mask: Tensor = None,
                is_causal: bool = None):
        src_emb = self.positional_encoding(self.tok_emb(src))
        outs = self.transformer(src=src_emb, mask=src_mask, src_key_padding_mask=src_padding_mask, is_causal=is_causal)
        return self.generator(outs)

In [3]:
from dataset import TokenTypesDataset


train_dataset = TokenTypesDataset(folder="../rnn_tokentype_data/train")
val_dataset = TokenTypesDataset(folder="../rnn_tokentype_data/validation", train=False, vocabs=(train_dataset.token2idx, train_dataset.idx2token), max_length=train_dataset.max_length)
test_dataset = TokenTypesDataset(folder="../rnn_tokentype_data/test", train=False, vocabs=(train_dataset.token2idx, train_dataset.idx2token), max_length=train_dataset.max_length)
assert val_dataset.vocab_size == train_dataset.vocab_size == test_dataset.vocab_size
assert val_dataset.max_length == train_dataset.max_length == test_dataset.max_length


In [4]:
def generate_square_subsequent_mask(sz, device):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_masks(src, pad_idx, device):
    src_seq_len = src.shape[1]

    src_mask = generate_square_subsequent_mask(src_seq_len, device=device)

    src_padding_mask = (src == pad_idx)
    return src_mask, src_padding_mask

In [5]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

def train_epoch(model: TokenTypeTransformer, optimizer, loss_fn, train_dataloader: DataLoader, device):
    model.train()
    losses = 0

    for src in tqdm(train_dataloader, leave=False):
        src = src.to(device)

        src_mask, src_padding_mask = create_masks(src, pad_idx=train_dataloader.dataset.pad_id, device=device)

        logits = model.forward(src, src_mask, src_padding_mask, is_causal=True)[:, :-1, :]

        src_out = src[:, 1:]

        optimizer.zero_grad()
        
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), src_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))


def evaluate(model, loss_fn, val_dataloader, device):
    model.eval()
    losses = 0

    for src in tqdm(val_dataloader, leave=False):
        src = src.to(device)
        src_mask, src_padding_mask = create_masks(src, pad_idx=val_dataloader.dataset.pad_id, device=device)

        logits = model.forward(src, src_mask, src_padding_mask, is_causal=True)[:, :-1, :]

        src_out = src[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), src_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [6]:
import torch.nn as nn
from torch.utils.data import DataLoader
from timeit import default_timer as timer

NUM_EPOCHS = 15
BATCH_SIZE = 256

torch.manual_seed(42)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = TokenTypeTransformer(
                            num_layers=3,
                            emb_size=128,
                            nhead=4,
                            vocab_size=train_dataset.vocab_size
                            )

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = model.to(device)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_dataset.pad_id)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer, loss_fn, train_loader, device)
    end_time = timer()
    val_loss = evaluate(transformer, loss_fn, val_loader, device)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

  0%|          | 0/231 [00:00<?, ?it/s]



  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.350, Val loss: 1.010, Epoch time = 22.735s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 2, Train loss: 1.038, Val loss: 0.922, Epoch time = 22.397s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 3, Train loss: 0.977, Val loss: 0.888, Epoch time = 22.380s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 4, Train loss: 0.944, Val loss: 0.867, Epoch time = 22.390s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 5, Train loss: 0.923, Val loss: 0.853, Epoch time = 22.399s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 6, Train loss: 0.908, Val loss: 0.842, Epoch time = 22.382s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 7, Train loss: 0.896, Val loss: 0.832, Epoch time = 22.397s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 8, Train loss: 0.886, Val loss: 0.823, Epoch time = 22.405s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 9, Train loss: 0.876, Val loss: 0.816, Epoch time = 22.415s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 10, Train loss: 0.868, Val loss: 0.817, Epoch time = 22.384s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 11, Train loss: 0.860, Val loss: 0.810, Epoch time = 22.423s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 12, Train loss: 0.853, Val loss: 0.813, Epoch time = 22.413s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 13, Train loss: 0.846, Val loss: 0.802, Epoch time = 22.451s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 14, Train loss: 0.840, Val loss: 0.809, Epoch time = 22.436s


  0%|          | 0/231 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 15, Train loss: 0.835, Val loss: 0.803, Epoch time = 22.415s


In [23]:
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
for batch in train_loader:
    src_mask, src_pad_mask = (create_masks(batch, pad_idx=0, device=device))
    src_mask = src_mask.to(device)
    src_pad_mask = src_pad_mask.to(device)
    batch = batch.to(device)
    is_bos = (batch == train_loader.dataset.bos_id)

    labels = batch[:, 1:]
    logits_without_last = transformer(batch, src_mask, src_pad_mask)[:, :-1, :]
    predictions = logits_without_last.argmax(dim=-1)

    print(labels)

    print("--------")

    print(predictions)

    print("--------")

    correct_predictions = ((predictions == labels) & labels != train_loader.dataset.pad_id)
    print(correct_predictions)

    print("--------")
    
    print(transformer(batch, src_mask, src_pad_mask).shape)
    break
    

tensor([[ 4,  4,  5,  4,  4,  4,  5,  4,  6,  5,  4,  5, 17, 10,  2,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,

In [None]:
def compute_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Compute the accuracy of predictions for an RNN language model.
    
    :param logits: Logits output by the model of shape [batch_size, sequence_length, vocab_size]
    :param labels: Ground truth labels of shape [batch_size, sequence_length]
    :return: Accuracy as a float
    """
    # Find the argmax of the logits along the last dimension to get the most likely token indices
    predictions = logits.argmax(dim=-1)
    
    # Compute the number of correct predictions
    correct_predictions = (predictions == labels).float().sum()
    
    # Calculate the accuracy
    accuracy = correct_predictions / labels.numel()
    
    return accuracy.item()

def calc_accuracy(model, loader) -> float:
    with torch.no_grad():
        model.eval()
        for indices, lengths in tqdm(loader, desc=tqdm_desc):
            optimizer.zero_grad()
            indices = indices[:, :lengths.max()].to(device)
            logits = model(indices[:, :-1], lengths - 1)
            loss = criterion(logits.transpose(1, 2), indices[:, 1:])
            loss.backward()
            optimizer.step()
        
            train_acc += compute_accuracy(logits, indices[:, 1:]) * indices.shape[0]
        
            train_loss += loss.item() * indices.shape[0]
        
        train_loss /= len(loader.dataset)
        train_acc = train_acc / len(loader.dataset)
        return train_loss, train_acc