Attention is All You Need (Vaswani et al., 2017)

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
import sys
!{sys.executable} -m pip install sentencepiece

Defaulting to user installation because normal site-packages is not writeable


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from datasets import load_dataset
from transformers import T5Tokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

Load dataset, split train & test

In [4]:
dataset = load_dataset("sciq")
train_data = dataset["train"]
test_data = dataset["test"]
val_data = dataset["validation"]

Tokenization, padding, truncation

In [5]:
tokenizer = T5Tokenizer.from_pretrained("t5-small")

MAX_INPUT_LEN = 64
MAX_TARGET_LEN = 16

def preprocess(example):
    input_text = f"question: {example['question']}"
    target_text = example["correct_answer"]

    input_enc = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=MAX_INPUT_LEN
    )
    target_enc = tokenizer(
        target_text,
        padding="max_length",
        truncation=True,
        max_length=MAX_TARGET_LEN
    )

    return {
        "input_ids": torch.tensor(input_enc.input_ids),
        "attention_mask": torch.tensor(input_enc.attention_mask),
        "labels": torch.tensor(target_enc.input_ids),
    }

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


PyTorch dataset wrapper

In [6]:
class SciQDataset(Dataset):
    def __init__(self, hf_dataset):
        self.data = [preprocess(x) for x in hf_dataset]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

Create data loaders

In [7]:
train_dataset = SciQDataset(train_data)
val_dataset = SciQDataset(val_data)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

Positional encoding

In [8]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Masking function

In [9]:
def generate_causal_mask(size):
    mask = torch.tril(torch.ones(size, size)).unsqueeze(0).unsqueeze(0)
    return mask

Attention layer

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x, mask=None, kv=None):
        if kv is None:
            kv = x

        B, T_q, _ = x.size()
        B, T_kv, _ = kv.size()

        q = self.q_proj(x)
        k = self.k_proj(kv)
        v = self.v_proj(kv)

        q = q.view(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T_kv, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T_kv, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T_q, -1)
        return self.out_proj(out)

Feed forward layer

In [11]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_ff, d_model),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.net(x)

Encoder

In [12]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

Decoder

In [13]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        causal_mask = generate_causal_mask(x.size(1)).to(x.device)
        x = x + self.self_attn(self.norm1(x), mask=causal_mask)
        x = x + self.cross_attn(self.norm2(x), kv=enc_out)
        x = x + self.ff(self.norm3(x))
        return x

Tranformer model

In [14]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, num_heads=2, num_layers=2, d_ff=128, max_len=64):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.dropout = nn.Dropout(0.1)
        self.enc_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
        self.dec_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.dropout(self.pos_enc(self.tok_emb(src)))
        tgt = self.dropout(self.pos_enc(self.tok_emb(tgt)))
        for layer in self.enc_layers:
            src = layer(src)
        for layer in self.dec_layers:
            tgt = layer(tgt, src)
        return self.out(tgt)

Setup training

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = tokenizer.vocab_size
model = Transformer(vocab_size=vocab_size).to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [16]:
def evaluate(model, val_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            output = model(input_ids, labels[:, :-1])
            logits = output[:, :, :]
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
            total_loss += loss.item()
    return total_loss / len(val_loader)

Training

In [17]:
EPOCHS = 50

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        decoder_input = labels[:, :-1]
        decoder_target = labels[:, 1:]

        optimizer.zero_grad()
        output = model(input_ids, decoder_input)
        loss = loss_fn(output.view(-1, output.size(-1)), decoder_target.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(train_loss=loss.item())

    val_loss = evaluate(model, val_loader)
    print(f"\nEpoch {epoch+1} | Train Loss: {total_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")

Epoch 1: 100%|██████████| 730/730 [00:57<00:00, 12.76it/s, train_loss=5.82]



Epoch 1 | Train Loss: 6.4039 | Val Loss: 5.5893


Epoch 2: 100%|██████████| 730/730 [00:52<00:00, 13.95it/s, train_loss=5.69]



Epoch 2 | Train Loss: 5.4183 | Val Loss: 5.4089


Epoch 3: 100%|██████████| 730/730 [00:55<00:00, 13.10it/s, train_loss=5.2] 



Epoch 3 | Train Loss: 5.2564 | Val Loss: 5.2817


Epoch 4: 100%|██████████| 730/730 [00:53<00:00, 13.59it/s, train_loss=5.05]



Epoch 4 | Train Loss: 5.1243 | Val Loss: 5.1946


Epoch 5: 100%|██████████| 730/730 [00:53<00:00, 13.72it/s, train_loss=5.81]



Epoch 5 | Train Loss: 4.9995 | Val Loss: 5.0832


Epoch 6: 100%|██████████| 730/730 [00:55<00:00, 13.22it/s, train_loss=4.95]



Epoch 6 | Train Loss: 4.8850 | Val Loss: 4.9967


Epoch 7: 100%|██████████| 730/730 [00:54<00:00, 13.37it/s, train_loss=4.95]



Epoch 7 | Train Loss: 4.7773 | Val Loss: 4.9072


Epoch 8: 100%|██████████| 730/730 [00:53<00:00, 13.60it/s, train_loss=5.61]



Epoch 8 | Train Loss: 4.6739 | Val Loss: 4.8334


Epoch 9: 100%|██████████| 730/730 [00:55<00:00, 13.10it/s, train_loss=4.42]



Epoch 9 | Train Loss: 4.5794 | Val Loss: 4.7633


Epoch 10: 100%|██████████| 730/730 [00:56<00:00, 13.01it/s, train_loss=5.18]



Epoch 10 | Train Loss: 4.4866 | Val Loss: 4.7056


Epoch 11: 100%|██████████| 730/730 [00:56<00:00, 12.99it/s, train_loss=3.81]



Epoch 11 | Train Loss: 4.3991 | Val Loss: 4.6524


Epoch 12: 100%|██████████| 730/730 [00:54<00:00, 13.49it/s, train_loss=4.56]



Epoch 12 | Train Loss: 4.3290 | Val Loss: 4.6096


Epoch 13: 100%|██████████| 730/730 [00:55<00:00, 13.12it/s, train_loss=3.77]



Epoch 13 | Train Loss: 4.2583 | Val Loss: 4.5515


Epoch 14: 100%|██████████| 730/730 [00:56<00:00, 12.95it/s, train_loss=4.66]



Epoch 14 | Train Loss: 4.1974 | Val Loss: 4.5359


Epoch 15: 100%|██████████| 730/730 [00:56<00:00, 12.87it/s, train_loss=3.72]



Epoch 15 | Train Loss: 4.1369 | Val Loss: 4.4928


Epoch 16: 100%|██████████| 730/730 [00:54<00:00, 13.35it/s, train_loss=4.52]



Epoch 16 | Train Loss: 4.0749 | Val Loss: 4.4621


Epoch 17: 100%|██████████| 730/730 [01:02<00:00, 11.71it/s, train_loss=4.54]



Epoch 17 | Train Loss: 4.0197 | Val Loss: 4.4349


Epoch 18: 100%|██████████| 730/730 [01:04<00:00, 11.31it/s, train_loss=4.09]



Epoch 18 | Train Loss: 3.9747 | Val Loss: 4.4089


Epoch 19: 100%|██████████| 730/730 [01:05<00:00, 11.09it/s, train_loss=4.46]



Epoch 19 | Train Loss: 3.9279 | Val Loss: 4.3989


Epoch 20: 100%|██████████| 730/730 [01:05<00:00, 11.13it/s, train_loss=3.37]



Epoch 20 | Train Loss: 3.8744 | Val Loss: 4.3800


In [18]:
def evaluate_accuracy(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            decoder_input = labels[:, :-1]
            target = labels[:, 1:]

            output = model(input_ids, decoder_input)
            pred_ids = output.argmax(dim=-1)
            correct += (pred_ids == target).float().sum().item()
            total += target.numel()
    return correct / total

In [19]:
torch.save(model.state_dict(), "transformer_qa.pt")

In [20]:
# model.load_state_dict(torch.load("transformer_qa.pt"))
# model.eval()

In [25]:
def generate_answer(question_text, max_len=MAX_TARGET_LEN):
    model.eval()
    input_text = f"question: {question_text}"
    enc = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_INPUT_LEN
    )
    input_ids = enc["input_ids"].to(device)

    decoder_input = torch.full((1, 1), tokenizer.pad_token_id, dtype=torch.long).to(device)

    for _ in range(max_len):
        with torch.no_grad():
            logits = model(input_ids, decoder_input)

        next_token_logits = logits[:, -1, :]
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)

        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(decoder_input[0][1:], skip_special_tokens=True)

In [26]:
model.eval()

for i in range(5):
    q = test_data[i]["question"]
    a = generate_answer(q)
    print(f"Q: {q}")
    print(f"Model: {a}")
    print(f"True: {test_data[i]['correct_answer']}\n")

Q: Compounds that are capable of accepting electrons, such as o 2 or f2, are called what?
Model: 
True: oxidants

Q: What term in biotechnology means a genetically exact copy of an organism?
Model: 
True: clone

Q: Vertebrata are characterized by the presence of what?
Model: 
True: backbone

Q: What is the height above or below sea level called?
Model: 
True: elevation

Q: Ice cores, varves and what else indicate the environmental conditions at the time of their creation?
Model: 
True: tree rings

