In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer


In [22]:
params = {
    "epochs": 100,
    "learning_rate": 3e-4,
    "batch_size": 32,
    "embedding_dim": 512,
    "nhead": 8,
    "num_layers": 3,
    "dropout": 0.1,
    "block_size": 13,
    "dim_feedforward": 4,
}
load_model = False
save_model = True
model_filename = "models/hafez.pt"


In [21]:
class Head(nn.Module):
    """
    Self-attention head layer.
    """

    def __init__(self, embedding_dim, head_size, dropout=0.0):
        super().__init__()

        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, v, k, q, mask=None):
        _, _, C = q.shape
        value, key, query = self.value(v), self.key(k), self.query(q)
        weights = query @ key.transpose(-2, -1) * C**-0.5

        if mask is not None:
            weights = weights.masked_fill(mask, float("-inf"))

        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        out = weights @ value
        return out


class MultiheadAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads, dropout=0.0):
        super().__init__()

        assert (
            embedding_dim % num_heads == 0
        ), f"{embedding_dim=} must be divisible by {num_heads=}"
        head_size = embedding_dim // num_heads

        self.ln = nn.LayerNorm(embedding_dim)
        self.heads = nn.ModuleList(
            [Head(embedding_dim, head_size, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, v, k, q, mask=None):
        v, k, q = self.ln(v), self.ln(k), self.ln(q)
        out = torch.cat([head(v, k, q, mask) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, embedding_dim, dim_feedforward, dropout=0.0):
        super().__init__()

        # feed-forward network
        self.ffn = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, dim_feedforward * embedding_dim),
            nn.ReLU(),
            nn.Linear(dim_feedforward * embedding_dim, embedding_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = self.ffn(x)
        return out


def generate_square_subsequent_mask(sz, device):
    return torch.tril(torch.ones(sz, sz).to(device)) == 0


class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()

        # multi-head self attention with triangular mask. Nodes communicate only
        # with previous nodes.
        self.attn = MultiheadAttention(embedding_dim, num_heads, dropout)
        self.ffn = FeedForward(embedding_dim, dim_feedforward)

    def forward(self, x, mask):
        out = x
        out = out + self.attn(out, out, out, mask)
        out = out + self.ffn(out)
        return out


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        num_layers,
        block_size,
        embedding_dim,
        nhead,
        dim_feedforward,
        dropout,
        device,
    ):
        super().__init__()
        self.device = device

        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos = nn.Embedding(block_size, embedding_dim)

        self.decoders = nn.ModuleList(
            [
                DecoderBlock(
                    embedding_dim,
                    nhead,
                    dim_feedforward,
                    dropout,
                )
                for _ in range(num_layers)
            ]
        )

        self.proj = nn.Linear(embedding_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        _, T = x.shape
        positions = torch.arange(T).unsqueeze(0).to(self.device)
        out = self.emb(x) + self.pos(positions)
        out = self.dropout(out)

        mask = generate_square_subsequent_mask(T, self.device)

        for decoder in self.decoders:
            out = decoder(out, mask)

        out = self.proj(out)
        out = self.dropout(out)

        return out


In [27]:
def get_tokenizer(filenames):
    tk = Tokenizer(BPE(unk_token="[UNK]"))
    tk.enable_padding(pad_id=3)

    trainer = BpeTrainer(special_tokens=["[SOS]", "[EOS]", "[UNK]", "[PAD]"])
    tk.train(filenames, trainer)

    return tk


def split(data, train_ratio=0.8, val_ratio=0.1):
    ntrain = int(train_ratio * len(data))
    nval = int(val_ratio * len(data))

    train = data[:ntrain]
    val = data[ntrain : ntrain + nval]
    test = data[ntrain + nval :]

    return train, val, test


def get_batch(data, batch_size, block_size, device):
    """
    Generates a batch of examples.
    """
    indices = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in indices]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in indices]).to(device)

    return x, y


def get_loss(logits, y, ignore_index):
    """
    Computes cross-entropy loss, given logits and labels.
    """
    B, T, C = logits.shape
    # F.cross_entropy expects size C, (B, C), or (B, C, ...)
    # logits shape is (B, T, C), so we flatten the first two dimensions.
    return F.cross_entropy(
        logits.view(B * T, C), y.reshape(B * T), ignore_index=ignore_index
    )


def generate(first_mesra, tk, model, device):
    """
    Generates second mesra.
    """
    token_ids = tk.encode(first_mesra).ids
    x = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    while True:
        logits = model(x)
        # only consider the last logit
        logits = logits[:, -1, :]
        score = F.softmax(logits, dim=-1)
        next_token_id = score.multinomial(1)
        x = torch.cat((x, next_token_id), dim=1)
        if "\n" in tk.id_to_token(next_token_id):
            break

    x = x.view(-1)
    return " ".join([tk.id_to_token(t) for t in x])


In [None]:
!mkdir -p data models
!wget https://raw.githubusercontent.com/eissana/poetGPT/master/data/hafez.txt -O data/hafez.txt

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"running on {device}")

text_file = "data/hafez.txt"

tk = get_tokenizer([text_file])
vocab_size = tk.get_vocab_size()
print(f"vocab size: {vocab_size}")


running on cpu



vocab size: 30000


In [23]:
model = Transformer(
    vocab_size=vocab_size,
    num_layers=params["num_layers"],
    block_size=params["block_size"],
    embedding_dim=params["embedding_dim"],
    nhead=params["nhead"],
    dim_feedforward=params["dim_feedforward"],
    dropout=params["dropout"],
    device=device,
).to(device)

num_params = sum([p.nelement() for p in model.parameters()])
print(f"model parameters: {num_params}")


model parameters: 511776


In [8]:
with open(text_file) as f:
    text = f.read()

token_ids = torch.tensor(tk.encode(text).ids, dtype=torch.long)

train, val, _ = split(token_ids, 0.9, 0.1)


In [24]:
train_losses, val_losses = [], []

In [25]:
optimizer = torch.optim.AdamW(model.parameters(), lr=params["learning_rate"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10
)

if load_model:
    state = torch.load(model_filename)

    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])

first_mesra = "الا يا ايها الساقی ادر کاسا و ناولها"
print(f"\nfirst mesra: {first_mesra}")

for epoch in range(params["epochs"]):
    model.eval()
    with torch.no_grad():
      if epoch % 10 == 0:
        print(f"epoch {epoch+1} / {params['epochs']}")
        second_mesra = generate(first_mesra, tk, model, device)
        print(f"second mesra:\n{second_mesra}")

      x, y = get_batch(val, params["batch_size"], params["block_size"], device)

      logits = model(x)
      vloss = get_loss(logits, y, ignore_index=tk.token_to_id("[PAD]"))
      val_losses.append(vloss.item())

    model.train()
    x, y = get_batch(train, params["batch_size"], params["block_size"], device)

    logits = model(x)
    loss = get_loss(logits, y, ignore_index=tk.token_to_id("[PAD]"))
    train_losses.append(loss.item())

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()

    if save_model:
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        torch.save(checkpoint, model_filename)

    scheduler.step(train_losses[-1])



first mesra: الا يا ايها الساقی ادر کاسا و ناولها
epoch 1 / 10
second mesra:
يا ايها ال ساقی ادر  کاسا و  ناول ها مغبچه  طاووس  جان زير  ديدار  جان سوز



In [26]:
print(f"loss of a random model: {np.log(tk.get_vocab_size())}")
print(f"final training loss: {np.mean(train_losses)}")
print(f"final validation loss: {np.mean(val_losses)}")

loss of a random model: 10.308952660644293
final training loss: 10.964788627624511
final validation loss: 10.87977409362793


In [None]:
eval_size = 10
plt.plot(torch.tensor(train_losses).view(-1, eval_size).mean(axis=1));
plt.plot(torch.tensor(val_losses).view(-1, eval_size).mean(axis=1));
plt.legend(['training loss', 'validation loss']);