In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import re, json, pickle
import random

In [2]:
def preprocess(text, contractions_json="contractions.json"):
    contractions = json.load(open(contractions_json))
    filtered_tokenized = []
    for word in re.findall(r"[\w']+|['.,!?;\n]", text):
        if word.lower() not in contractions:
            if len(word) == 1 and word[0] == "'":
                filtered_tokenized.append(word)
            elif word[0] == "'" and word[-1] == "'":
                filtered_tokenized.append("'")
                filtered_tokenized.append(word[1:-1])
                filtered_tokenized.append("'")
            elif word[0] == "'":
                filtered_tokenized.append("'")
                filtered_tokenized.append(word[1:])
            elif word[-1] == "'":
                filtered_tokenized.append(word[:-1])
                filtered_tokenized.append("'")
            else:
                filtered_tokenized.append(word)
        else:
            split_up = contractions[word.lower()][0].split()
            if word[0].isupper():
                split_up[0] = split_up[0][0].upper() + split_up[0][1:]
            filtered_tokenized.extend(split_up)
    vocab = sorted(list(set(filtered_tokenized)))
    return filtered_tokenized, vocab


class SequenceDataset(Dataset):

    def __init__(self, txt_path, glove_path, contractions_path, context_len):
        self.context_len = context_len
        self.tokenized, vocab = preprocess("".join(open(txt_path).readlines()), contractions_path)
        self.glove = pickle.load(open(glove_path, "rb"))
        loop = tqdm(self.tokenized, total=len(self.tokenized), leave=True, position=0)
        loop.set_description("Loading dataset...")
        self.embedded = [torch.tensor(self.glove[word][:300]).unsqueeze(0) if word in self.glove else torch.randn(1, 300) for word in loop]
        self.embedded = torch.concat(self.embedded, dim=0)

        self.word2idx = {"<PAD>": 0}
        for i, word in enumerate(vocab):
            self.word2idx[word] = i + 1
        self.idx2word = {value: key for key, value in self.word2idx.items()}


    def __len__(self):
        return self.embedded.shape[0] - (self.context_len - 1)
    
    def pad_and_mask(self, seq):
        seq_len, embed_dim = seq.shape
        mask = torch.concat([torch.ones(seq_len), torch.zeros(self.context_len - seq_len)], dim=0).type(torch.bool)
        seq = torch.concat([seq, torch.zeros(self.context_len - seq_len, embed_dim)])
        return seq, mask

    def __getitem__(self, idx):
        seq_len = random.randint(1, self.context_len)
        src = self.embedded[idx:idx+seq_len]
        tgt = torch.tensor([self.word2idx[self.tokenized[idx+seq_len]]])
        src, mask = self.pad_and_mask(src)
        return src, mask, tgt

In [3]:
class PositionalEncoder(nn.Module):

    def __init__(self, d_model, seq_len, device, p=0.1):
        super(PositionalEncoder, self).__init__()
        self.pe = torch.arange(seq_len).unsqueeze(-1).repeat(1, d_model).type(torch.float32)
        even_pos = torch.arange(0, d_model, 2)
        self.pe[:, ::2] = torch.sin(self.pe[:, ::2] / (10000 ** (even_pos/d_model)))
        self.pe[:, 1::2] = torch.cos(self.pe[:, 1::2] / (10000 ** ((even_pos + 1)/d_model)))
        self.pe = self.pe.unsqueeze(0).to(device)
        self.dropout = nn.Dropout(p=p)

    # x has shape [batch, seq_len, embed_dim]
    def forward(self, x):
        return self.dropout(x + self.pe)


class MultiHeadAttention(nn.Module):

    def __init__(self, input_dim, output_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert output_dim % num_heads == 0    # output_dim must be divisible by num_heads
        self.num_heads, self.head_dim = num_heads, output_dim // num_heads
        self.qkv_linear = nn.Linear(input_dim, output_dim * 3)
        self.out_linear = nn.Linear(output_dim, output_dim)

    # x has shape [batch_size, seq_len, input_dim]
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        if mask is not None:
            mask_batch_size, num_heads, d1, d2 = mask.shape
            assert d1 == seq_len and d2 == seq_len
            assert mask_batch_size == batch_size and num_heads == self.num_heads

        # computing q, k and v across multiple heads with a single linear layer
        qkv = self.qkv_linear(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim * 3)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        attn_output = self.scaled_dot_product(q, k, v, mask)
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
        return self.out_linear(attn_output)

    # q, k and v have shape [batch_size, num_heads, seq_len, head_dim]
    def scaled_dot_product(self, q, k, v, mask):
        d_k = k.shape[-1]
        qk = q.matmul(k.transpose(-1, -2)) / d_k
        if mask is not None:
            qk = qk.masked_fill(~mask, -torch.inf)
        attn_weights = qk.softmax(dim=-1)
        return attn_weights.matmul(v)
    

class TransformerDecoder(nn.Module):

    def __init__(self, d_model, num_heads, p=0.1):
        super(TransformerDecoder, self).__init__()
        self.linear_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

        self.layers = nn.ModuleList([
            MultiHeadAttention(d_model, d_model, num_heads),
            nn.Dropout(p=p),
            nn.LayerNorm(d_model),
            MultiHeadAttention(d_model, d_model, num_heads),
            nn.Dropout(p=p),
            nn.LayerNorm(d_model),
            self.linear_layer,
            nn.Dropout(p=p),
            nn.LayerNorm(d_model)
        ])

    # x has shape [batch_size, seq_len, embed_dim]
    def forward(self, x, mask=None):
        prev = x

        for layer in self.layers:
            if isinstance(layer, nn.LayerNorm):
                x = layer(x + prev)
                prev = x
            elif isinstance(layer, MultiHeadAttention):
                x = layer(x, mask)
            else:
                x = layer(x)
        return x


class TransformerModel(nn.Module):

    def __init__(self, input_dim, d_model, output_dim, num_heads, context_len, device):
        super(TransformerModel, self).__init__()
        self.num_heads = num_heads
        self.layers = nn.ModuleList([
            nn.Embedding(input_dim, d_model),
            PositionalEncoder(d_model, context_len, device),
            TransformerDecoder(d_model, self.num_heads),
            TransformerDecoder(d_model, self.num_heads),
            TransformerDecoder(d_model, self.num_heads),
        ])
        self.out_proj = nn.Linear(d_model, output_dim)

    def forward(self, x, mask=None):
        for layer in self.layers:
            if isinstance(layer, TransformerDecoder):
                x = layer(x, mask)
            else:
                x = layer(x)
        return self.out_proj(x.view(-1, x.shape[-1]))

In [4]:
def predict(embedded, model, dataset, device):
    embedded, mask = dataset.pad_and_mask(embedded)
    embedded, mask = embedded.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)
    mask = mask.unsqueeze(1).unsqueeze(-2).repeat(1, model.num_heads, dataset.context_len, 1)
    output = model(embedded, mask).view(-1).softmax(dim=0)
    return torch.multinomial(output, 1)

def sample(sentence, generate_len, max_seq_len, model, dataset, device):
    tokens, _ = preprocess(sentence)
    embedded = torch.tensor([dataset.char2idx[c] for c in tokens])
    for _ in range(generate_len):
        output = predict(embedded[(len(embedded) - max_seq_len):], model, dataset, device)
        embedded = torch.concat([embedded, output.cpu()])
    return "".join([dataset.idx2char[c.item()] for c in embedded])

In [5]:
INPUT_DIM = 62
MODEL_DIM = 512
NUM_HEADS = 8
CONTEXT_LEN = 500
BATCH_SIZE = 64
SAMPLE_GENERATE_LEN = 50
EPOCHS = 100

LR = 1e-10
BETAS = [0.9, 0.98]
DEV = torch.device("mps")

In [6]:
model = TransformerModel(
    input_dim=154254,
    d_model=512,
    output_dim=154254,
    num_heads=8,
    context_len=200,
    device=DEV
).to(DEV)

In [8]:
x = torch.arange(0, 200).unsqueeze(0).repeat(32, 1).to(DEV)
for i in tqdm(range(100), total=100):
    model(x)

 17%|█▋        | 17/100 [00:07<00:35,  2.35it/s]


KeyboardInterrupt: 

In [6]:
dataset = SequenceDataset("/content/drive/MyDrive/transformer-files/shakespeare-sonnet.txt", CONTEXT_LEN)
loader = DataLoader(dataset, BATCH_SIZE, shuffle=True)

TypeError: __init__() missing 2 required positional arguments: 'contractions_path' and 'context_len'

In [None]:
model = TransformerModel(
    input_dim=INPUT_DIM,
    d_model=MODEL_DIM,
    output_dim=len(dataset.char2idx),
    num_heads=NUM_HEADS,
    context_len=CONTEXT_LEN,
    device=DEV
).to(DEV)

In [24]:
opt = optim.Adam(model.parameters(), lr=LR, betas=BETAS)
crit = nn.CrossEntropyLoss()

In [25]:
for e in range(EPOCHS):
    loop = tqdm(enumerate(loader), total=len(loader), leave=True, position=0)
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
    total_loss = 0
    for i, (src, mask, tgt) in loop:
        src, mask, tgt = src.to(DEV), mask.to(DEV), tgt.to(DEV)
        mask = mask.unsqueeze(1).unsqueeze(-2).repeat(1, model.num_heads, dataset.context_len, 1)

        opt.zero_grad()
        yhat = model(src, mask)
        loss = crit(yhat, tgt.view(-1))
        loss.backward()
        opt.step()

        total_loss += loss.item()
        loop.set_postfix(loss = total_loss/(i + 1))

Epoch : [0/100]:   0%|          | 3/1463 [00:30<4:05:50, 10.10s/it, loss=4.28]


KeyboardInterrupt: 