In [21]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import string

In [22]:
class ELMoPretrainDataset(Dataset):

    def __init__(self, dataset_path: str, seq_len: int):
        text = open(dataset_path, "r").readlines()
        self.seq_len = seq_len
        self.text, self.char2idx, self.idx2char, self.word2idx, self.idx2word = self.preprocess(text)
        self.tokenized_chars, self.tokenized_words = self.tokenize(self.text)

    def tokenize(self, text: list[str]) -> list[list[int]]:
        char_tokenize = [[self.char2idx[char] for char in word] for word in text]
        word_tokenize = [self.word2idx[word] for word in text]
        return char_tokenize, word_tokenize
    
    def pad(self, sequence: list[list[int]]):
        max_word_len = len(max(sequence, key=lambda x: len(x)))
        for i in range(len(sequence)):
            pad_len = max_word_len - len(sequence[i])
            front_pad = pad_len // 2
            back_pad = pad_len - front_pad
            sequence[i] = ([0] * front_pad) + sequence[i] + ([0] * back_pad)

    def preprocess(self, text: list[str]) -> list[str]:
        to_remove = []
        for i, line in enumerate(text):
            if line == " \n":
                to_remove.append(i)
            elif "=" in line:
                to_remove.append(i)

        for idx in to_remove[::-1]:
            del text[idx]
        
        text = " ".join(text).lower()
        text = text.translate(str.maketrans('', '', '!"#$%&\'()*+-./:=?@[\\]^_`{|}~'))
        text = "".join([i for i in text if (not i.isdigit()) and i.isascii()])

        char2idx = {char: i + 1 for i, char in enumerate(sorted(list(set(list(text)))))}
        char2idx["<pad>"] = 0
        idx2char = {value: key for key, value in char2idx.items()}
        text = text.split()

        word2idx = {word: i for i, word in enumerate(sorted(list(set(text))))}
        idx2word = {value: key for key, value in word2idx.items()}
        return text, char2idx, idx2char, word2idx, idx2word
    
    def __len__(self) -> int:
        return len(self.text) - self.seq_len
    
    def __getitem__(self, idx):
        src, tgt = self.tokenized_chars[idx: idx + self.seq_len], self.tokenized_words[idx+1: idx + self.seq_len + 1]
        self.pad(src)
        return torch.tensor(src), torch.tensor(tgt)

In [23]:
class PadCollate:

    def __init__(self, dim=0):
        self.dim = dim

    def __call__(self, batch):
        max_len = max(batch, key=lambda x: x[0].shape[-1])[0].shape[-1]
        seq_len = batch[0][1].shape[0]
        srcs, tgts = [], []
        for src, tgt in batch:
            pad_len = max_len - src.shape[-1]
            front_pad = pad_len // 2
            back_pad = pad_len - front_pad
            srcs.append(torch.cat([torch.zeros(seq_len, front_pad), src, torch.zeros(seq_len, back_pad)], dim=-1))
            tgts.append(tgt)
        return torch.stack(srcs).type(torch.long), torch.stack(tgts)

In [24]:
class CharacterConvolutions(nn.Module):

    def __init__(self, in_size, embed_dim):
        super(CharacterConvolutions, self).__init__()
        self.embedding = nn.Embedding(in_size, embed_dim)
        conv_layer_params = [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]]
        self.conv_layers = nn.ModuleList([nn.Conv2d(embed_dim, out_dim, (1, ksize)) for ksize, out_dim in conv_layer_params])
        self.relu = nn.ReLU()

    # x has shape [batch_size, seq_len, characters]
    def forward(self, x):
        x = self.embedding(x) # [batch_size, seq_len, characters, char_embed_dim]
        x = x.permute(0, 3, 1, 2) # [batch_size, char_embed_dim, seq_len, characters]
        convs = [c(x) for c in self.conv_layers]
        pools = [self.relu(F.max_pool2d(c, kernel_size=(1, c.shape[-1]))).squeeze(-1).permute(0, 2, 1) for c in convs]
        return torch.cat(pools, dim=-1)

In [25]:
class ELMoLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, dropout):
        super(ELMoLSTM, self).__init__()
        self.in_proj = nn.Linear(input_size, hidden_size)
        self.lstm1 = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.middle_proj = nn.Linear(hidden_size * 2, input_size)
        self.lstm2 = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.out_proj = nn.Linear(hidden_size * 2, input_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.in_proj(x)
        layer1 = self.middle_proj(self.lstm1(x)[0] + x.repeat(1, 1, 2))
        layer2 = self.out_proj(self.lstm2(self.dropout(layer1))[0])
        return layer1, layer2

In [26]:
class ELMo(nn.Module):

    def __init__(self, input_size, char_embed_dim=16, output_size=128, hidden_size=1024, dropout=0.5):
        super(ELMo, self).__init__()
        self.character_convolutions = CharacterConvolutions(input_size, char_embed_dim)
        self.highway = nn.Linear(2048, 2048)
        self.highway_gate = nn.Linear(2048, 2048)
        self.in_proj = nn.Linear(2048, output_size)
        self.elmo_lstm = ELMoLSTM(output_size, hidden_size=hidden_size, dropout=dropout)

    def forward(self, x):
        x = self.character_convolutions(x)
        h = self.highway(x)
        h_gate = F.sigmoid(self.highway_gate(x))
        layer1 = self.in_proj((h * h_gate) + (x * (1 - h_gate)))
        layer2, layer3 = self.elmo_lstm(layer1)
        return layer1, layer2, layer3

In [27]:
class ELMoPretrainModel(nn.Module):

    def __init__(self, num_chars, output_size, num_words, dropout=0.5):
        super(ELMoPretrainModel, self).__init__()
        self.elmo = ELMo(num_chars, output_size=output_size)
        self.dropout = nn.Dropout(p=dropout)
        self.out_proj = nn.Linear(output_size, num_words)

    def forward(self, x):
        _, _, x = self.elmo(x)
        return self.out_proj(self.dropout(x))

In [28]:
dataset = ELMoPretrainDataset("wikitext-2/wiki.train.tokens", seq_len=100)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=PadCollate())

In [29]:
model = ELMoPretrainModel(len(dataset.char2idx), 128, len(dataset.word2idx)).to("mps")
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), 3e-6)
EPOCHS = 100

In [30]:
for e in range(EPOCHS):
    loop = tqdm(loader, total=len(loader), position=0)
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
    for src, tgt in loop:
        src, tgt = src.to("mps"), tgt.to("mps")
        opt.zero_grad()
        yhat = model(src)
        loss = crit(yhat.view(-1, yhat.shape[-1]), tgt.view(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()
        print(loss.item())
        loop.set_postfix(loss = loss.item())

Epoch : [0/100]:   0%|          | 1/55723 [00:02<32:24:43,  2.09s/it, loss=10.2]

10.233040809631348


Epoch : [0/100]:   0%|          | 2/55723 [00:03<25:59:49,  1.68s/it, loss=10.2]

10.225953102111816


Epoch : [0/100]:   0%|          | 3/55723 [00:04<23:17:06,  1.50s/it, loss=10.2]

10.229717254638672


Epoch : [0/100]:   0%|          | 4/55723 [00:06<22:08:29,  1.43s/it, loss=10.2]

10.229687690734863


Epoch : [0/100]:   0%|          | 5/55723 [00:07<21:36:31,  1.40s/it, loss=10.2]

10.226344108581543


Epoch : [0/100]:   0%|          | 6/55723 [00:08<21:14:49,  1.37s/it, loss=10.2]

10.229728698730469


Epoch : [0/100]:   0%|          | 6/55723 [00:09<25:34:45,  1.65s/it, loss=10.2]


KeyboardInterrupt: 