In [1]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset


class SelfAttentionHead(torch.nn.Module):
    def __init__(self, embedding_dim, query_dim, key_dim, value_dim):
        super().__init__()
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.Wq = torch.nn.Linear(embedding_dim, query_dim, bias=False)
        self.Wk = torch.nn.Linear(embedding_dim, key_dim, bias=False)
        self.Wv = torch.nn.Linear(embedding_dim, value_dim, bias=False)

    called = False

    def forward(self, x):
        if not SelfAttentionHead.called:
            print('my forward called')
            SelfAttentionHead.called = True
            
        q = torch.matmul(x, torch.transpose(self.Wq.weight, 0, 1))
        k = torch.matmul(x, torch.transpose(self.Wk.weight, 0, 1))
        v = torch.matmul(x, torch.transpose(self.Wv.weight.data, 0, 1))

        energy = torch.matmul(q, k.transpose(1, 2))
        normalized_energy = torch.softmax(energy / math.sqrt(self.key_dim), dim=2)
        out = torch.matmul(normalized_energy, v)

        return out


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, nheads, embedding_dim, query_dim, key_dim, value_dim):
        super().__init__()
        self.attention_heads = torch.nn.ModuleList([
            SelfAttentionHead(embedding_dim, query_dim, key_dim, value_dim)
            for _ in range(nheads)
        ])
        self.Wo = torch.nn.Linear(nheads * value_dim, embedding_dim)

    def forward(self, x):
        output = torch.cat(tuple(ah(x) for ah in self.attention_heads), dim=2)
        output = self.Wo(output)
        return output


class EncoderLayer(torch.nn.Module):
    def __init__(self, nheads, embedding_dim, query_dim, key_dim, value_dim):
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(nheads, embedding_dim, query_dim, key_dim, value_dim)
        self.norm1 = torch.nn.LayerNorm(embedding_dim)

        self.fully_connected = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim * 4, embedding_dim * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim * 4, embedding_dim),
        )

        self.norm2 = torch.nn.LayerNorm(embedding_dim)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        x = x + self.multi_head_attention(x)
        x = self.norm1(x)
        x = x + self.fully_connected(x)
        x = self.norm2(x)

        return x

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        encoder_layers = EncoderLayer(nheads=nhead, embedding_dim=d_model, query_dim=64, key_dim=64, value_dim=64)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.classifier = nn.Linear(d_model * 256, 2)
        self.init_weights()
        
    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        src_mask = torch.ones(src.size(0), src.size(0), device=src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.classifier(output.flatten(start_dim=1))
        return output

class Encoder(torch.nn.Module):
    def __init__(self, n_tokens, num_encoder_layers, nheads, embedding_dim, query_dim, key_dim, value_dim, max_len):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = torch.nn.Embedding(n_tokens, embedding_dim)
        self.encoder_layers = torch.nn.ModuleList([
            EncoderLayer(nheads, embedding_dim, query_dim, key_dim, value_dim)
            for _ in range(num_encoder_layers)
        ])
        
        self.pe = PositionalEncoding(d_model=embedding_dim, max_len=embedding_dim // 2)
        self.classifier = torch.nn.Linear(embedding_dim, 2)

    def forward(self, x):
        # print('applying embedding to:', x.shape)
        x = self.embedding(x) * math.sqrt(self.embedding_dim)
        # print('applying pos encoding to:', x.shape)
        x = self.pe(x)
        # print('applying encoder layers to:', x.shape)
        for layer in self.encoder_layers:
            x = layer(x)
        # print('applying classifier to:', x.shape)
        x = torch.softmax(x[0], dim=1)
        return x

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [107]:
import datasets
from transformers import AutoTokenizer

torch.manual_seed(42)

tok = AutoTokenizer.from_pretrained('distilbert-base-uncased')
ntokens = len(tok.get_vocab().keys())
model = TransformerModel(ntoken=ntokens, nlayers=2, nhead=4, d_model=48, d_hid=20)
# model = Encoder(n_tokens=ntokens, num_encoder_layers=3, nheads=8, embedding_dim=200, query_dim=64, key_dim=64, value_dim=64, max_len=256)
data = datasets.load_dataset('imdb').with_format('torch').shuffle(seed=42)
train_data = data['train']
test_data = data['test']


Found cached dataset imdb (/home/ubuntu/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-01ad04b69ceba701.arrow
Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-f85c25fbaaf4c75e.arrow
Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-bfdb1e053999b3b1.arrow


In [110]:
import time
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def evaluate(model):
    print('======')
    print('Evaluating...')
    loss = 0
    correct = 0
    examples = 1000
    batch = 100
    num_batches = 10
    for i in range(0, examples, batch):
        tokenized = tok.batch_encode_plus([test_data[i]['text'] for i in range(100)], max_length=256,
                                    padding='max_length', truncation=True, return_tensors='pt')['input_ids']
        result = model(tokenized)
        actual = torch.tensor([test_data[i]['label'].item() for i in range(0, 100)])
        actual_oh = torch.nn.functional.one_hot(actual, num_classes=2).float()
        loss += torch.nn.CrossEntropyLoss()(result, actual_oh).item()
        correct += sum(result.argmax(dim=1) == actual).item()
    print('Loss: ', loss / num_batches)
    print('Accuracy: ', correct * 100 / examples, '%')
    print('======')

MAX = 8001
def train(model: torch.nn.Module, epochs: int) -> None:
    model.train()  # turn on train mode

    for epoch in range(1, epochs + 1):
        total_loss = 0.
        log_interval = 25
        batch_size = 10
        eval_interval = 400
        start_time = time.time()

        for i in range(0, MAX, batch_size):
            batch = train_data[i:i+batch_size]
            tokenized = tok.batch_encode_plus(list(batch['text']), max_length=256,
                                            padding='max_length', truncation=True, return_tensors='pt')['input_ids']
            output = model(tokenized)
            loss = criterion(output, torch.nn.functional.one_hot(batch['label'], num_classes=2).float())

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss += loss.item()
            batch_number = i // batch_size
            if batch_number % log_interval == 0 and i > 0:
                lr = scheduler.get_last_lr()[0]
                ms_per_batch = (time.time() - start_time) * 1000 / log_interval
                cur_loss = total_loss / log_interval
                ppl = math.exp(cur_loss)
                print(f'| epoch {epoch:3d} | {batch_number:5.0f}/{MAX//batch_size:5.0f} batches | '
                    f'lr {lr:02.6f} | ms/batch {ms_per_batch:5.2f} | '
                    f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
                total_loss = 0
                start_time = time.time()
            if batch_number % eval_interval == 0 and i > 0:
                evaluate(model)
                scheduler.step()

        evaluate(model)

In [111]:
train(model, 3)

| epoch   1 |    25/  800 batches | lr 0.000015 | ms/batch 41.97 | loss  0.75 | ppl     2.12
| epoch   1 |    50/  800 batches | lr 0.000015 | ms/batch 40.79 | loss  0.70 | ppl     2.02
| epoch   1 |    75/  800 batches | lr 0.000015 | ms/batch 40.89 | loss  0.73 | ppl     2.07
| epoch   1 |   100/  800 batches | lr 0.000015 | ms/batch 40.73 | loss  0.72 | ppl     2.06
| epoch   1 |   125/  800 batches | lr 0.000015 | ms/batch 40.97 | loss  0.75 | ppl     2.11
| epoch   1 |   150/  800 batches | lr 0.000015 | ms/batch 41.04 | loss  0.74 | ppl     2.09
| epoch   1 |   175/  800 batches | lr 0.000015 | ms/batch 41.14 | loss  0.72 | ppl     2.06
| epoch   1 |   200/  800 batches | lr 0.000015 | ms/batch 41.01 | loss  0.72 | ppl     2.05
| epoch   1 |   225/  800 batches | lr 0.000015 | ms/batch 40.97 | loss  0.71 | ppl     2.04
| epoch   1 |   250/  800 batches | lr 0.000015 | ms/batch 40.74 | loss  0.73 | ppl     2.09
| epoch   1 |   275/  800 batches | lr 0.000015 | ms/batch 41.03 | los

KeyboardInterrupt: 

In [42]:
predicted

tensor([0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1,
        0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0,
        1, 1, 0, 1])