In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import IterableDataset

import torchtext
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import matplotlib.pyplot as plt
import spacy
import numpy as np

from copy import deepcopy
import random
import math
import time
from tqdm.notebook import tqdm

from typing import Tuple

In [2]:
# One seed shared by every random number generator the notebook touches, so
# that repeated runs start from the same state.
SEED = 1234

# Python's `random`, NumPy, PyTorch (CPU) and PyTorch (current CUDA device).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

In [3]:
# copied from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

# Build the vocabulary from the tokenised training split. '<unk>' is the only
# special token and is also the index returned for out-of-vocabulary tokens.
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=['<unk>'],
)
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: IterableDataset) -> Tensor:
    """Encode an iterable of raw text items as one flat tensor of token ids.

    Each item is tokenised with the module-level ``tokenizer`` and mapped to
    ids with the module-level ``vocab``. Items that produce no tokens are
    dropped, and the remaining id sequences are concatenated in order.

    Args:
        raw_text_iter: iterable yielding raw text strings.

    Returns:
        1-D ``torch.long`` tensor holding all token ids.
    """
    encoded = []
    for line in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(line)), dtype=torch.long)
        # Skip items with no tokens; they contribute nothing to the stream.
        if ids.numel() > 0:
            encoded.append(ids)
    return torch.cat(tuple(encoded))

# Building the vocab exhausted the training iterator, so request fresh
# iterators for all three splits before encoding them.
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = map(data_process, (train_iter, val_iter, test_iter))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def batchify(data: Tensor, bsz: int) -> Tensor:
    """Fold a flat token stream into ``bsz`` parallel columns.

    Tokens at the end of the stream that do not fill a complete row are
    discarded. The result is moved to the module-level ``device``.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    rows = data.size(0) // bsz
    trimmed = data[:rows * bsz]
    # Lay the stream out as bsz contiguous runs, then transpose so that each
    # column holds one run and each row is one time step.
    columns = trimmed.view(bsz, rows).t().contiguous()
    return columns.to(device)

# Batch size for training, and a separate batch size for validation/test.
batch_size, eval_batch_size = 64, 10

# Each result has shape [seq_len, batch size].
# NOTE: this cell overwrites its own inputs. batchify expects the flat 1-D
# tensors, so re-create them before running this cell a second time.
train_data = batchify(train_data, batch_size)
val_data, test_data = (batchify(split, eval_batch_size) for split in (val_data, test_data))

In [5]:
# Maximum number of time steps in one chunk returned by get_batch.
bptt = 150

def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """Cut one input chunk and its next-token targets out of ``source``.

    The chunk starts at row ``i`` and spans at most ``bptt`` rows; it is
    shortened near the end so that every input row still has a following row
    to serve as its target.

    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int, index of the first row of the chunk

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    steps = min(bptt, len(source) - 1 - i)
    stop = i + steps
    data = source[i:stop]
    # Targets are the inputs shifted by one row; they are transposed before
    # flattening, so the flat order is batch-major.
    shifted = source[i + 1:stop + 1]
    target = shifted.T.reshape(-1)
    return data, target

___

In [6]:
import sys
# Add the parent directory to the module search path before the project
# imports below (presumably that is where the sparse_transformer package
# lives — verify). The path is relative, so it resolves against the kernel's
# current working directory.
sys.path.append("../")

from sparse_transformer.layers import encoder
from sparse_transformer.factorized_attn.fixed import FixedFA

In [7]:
# Model hyper-parameters; they are passed positionally to encoder.Encoder.
vocab_size = len(vocab)   # number of entries in the vocabulary
max_length = bptt         # same value as the chunk length used by get_batch
d_model = 256
d_ff = 2 * d_model
n_head = 8
dropout_p = 0.1
attn_typ = 'fixed'
# NOTE(review): `l` and `c` are presumably the stride and summary-column
# parameters of the fixed factorized attention pattern — confirm against
# the Encoder / FixedFA signatures.
l = 30
c = 1
n_enc_layer = 3

In [8]:
# Arguments are positional; keep this order in sync with Encoder's signature.
model = encoder.Encoder(
    vocab_size, max_length, d_model, d_ff, n_head,
    dropout_p, attn_typ, l, c, n_enc_layer,
)
model = model.to(device)

In [9]:
# Count the entries equal to 0 in the attention mask generated for a
# bptt-long sequence. NOTE(review): whether 0 marks an attended or a masked
# position depends on the mask convention of model.fa — confirm.
(model.fa.generate_multi_head_attn_mask(bptt) == 0).sum()

tensor(11140)

In [9]:
# Smoke test: run the untrained model on the first training chunk and
# inspect the output shape.
i = 0

# get_batch returns data as [seq_len, batch_size]; transpose it to
# [batch_size, seq_len] before feeding the model.
model(get_batch(train_data, i)[0].T).shape

torch.Size([64, 150, 28782])

In [10]:
# GPU memory (in bytes) currently occupied by tensors on device cuda:0.
# The device is hard-coded, so this is only meaningful when running on CUDA.
torch.cuda.memory_allocated('cuda:0')

85969408

___

# Train

In [10]:
# Debug aid: uncomment to have autograd report the forward op that produced a
# NaN/Inf gradient (slows training down considerably).
# torch.autograd.set_detect_anomaly(True)

# Token-level cross entropy over the flattened logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2.5e-4, weight_decay=1e-3)

# Cosine annealing moves the learning rate from the optimizer's initial lr
# towards eta_min over T_max scheduler steps. eta_min therefore has to be
# smaller than the initial lr (2.5e-4); the previous value of 1e-3 made the
# schedule *increase* the learning rate instead of decaying it.
# NOTE(review): the scheduler is stepped once per epoch, so T_max should
# normally match the number of epochs actually trained.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5)

In [11]:
def train(model, data, optimizer, criterion, clip):
    """Run one training epoch over ``data`` and return the mean batch loss.

    The data is walked in chunks of ``bptt`` rows (module-level constant) via
    ``get_batch``; every chunk triggers one optimizer step with gradient-norm
    clipping.

    Args:
        model: module mapping token ids of shape [batch_size, seq_len] to
            logits of shape [batch_size, seq_len, vocab_size].
        data: Tensor, shape [full_seq_len, batch_size].
        optimizer: optimizer that updates ``model``'s parameters.
        criterion: loss called as ``criterion(logits, targets)`` with logits
            of shape [batch_size * seq_len, vocab_size] and targets of shape
            [batch_size * seq_len].
        clip: maximum total gradient norm passed to ``clip_grad_norm_``.

    Returns:
        float, the average of the per-chunk losses.
    """
    model.train()

    epoch_loss = 0.0
    n_batches = 0

    for i in tqdm(range(0, data.shape[0] - 1, bptt), desc='train'):
        # Read from the `data` argument rather than the global `train_data`,
        # so the function trains on whatever tensor the caller passes in.
        src, trg = get_batch(data, i)
        src = src.T  # [batch_size, seq_len]

        optimizer.zero_grad()
        logits = model(src)  # [batch_size, seq_len, vocab_size]
        # Flatten batch-major so rows line up with the flattened targets; the
        # vocabulary size is taken from the tensor itself, not from a global.
        logits = logits.reshape(-1, logits.size(-1))

        loss = criterion(logits, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1
    return epoch_loss / n_batches

def evaluate(model, data, criterion):
    """Compute the average per-token loss of ``model`` on ``data``.

    The data is walked in chunks of ``bptt`` rows (module-level constant) via
    ``get_batch`` with gradients disabled. The final chunk is usually shorter
    than the others, so each chunk's loss is weighted by its number of target
    tokens; a plain mean over chunks would over-weight the short chunk.

    Args:
        model: module mapping token ids of shape [batch_size, seq_len] to
            logits of shape [batch_size, seq_len, vocab_size].
        data: Tensor, shape [full_seq_len, batch_size].
        criterion: loss called as ``criterion(logits, targets)``; assumed to
            return the mean over the targets (the default reduction of
            ``nn.CrossEntropyLoss``).

    Returns:
        float, the token-weighted average loss over all chunks.
    """
    model.eval()

    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for i in tqdm(range(0, data.shape[0] - 1, bptt), desc='valid'):
            src, trg = get_batch(data, i)
            src = src.T  # [batch_size, seq_len]

            logits = model(src)  # [batch_size, seq_len, vocab_size]
            # Flatten batch-major so rows line up with the flattened targets.
            logits = logits.reshape(-1, logits.size(-1))

            n_tokens = trg.numel()
            total_loss += criterion(logits, trg).item() * n_tokens
            total_tokens += n_tokens

    return total_loss / total_tokens

def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.

    Returns:
        tuple (minutes, seconds) of ints; fractional seconds are truncated.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover = span - (whole_minutes * 60)
    return whole_minutes, int(leftover)

In [12]:
# Number of passes over the training data and the max gradient norm handed
# to train() for clipping.
N_EPOCHS = 20
CLIP = 1

# Lowest validation loss seen so far; decides when the checkpoint is rewritten.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    # One full pass over the training data, then a pass over the validation data.
    train_loss = train(model, train_data, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_data, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Keep only the weights of the best epoch (by validation loss) on disk.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'gpt.pt')
        
    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    
    # Perplexity is reported as exp(loss).
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

train:   0%|          | 0/214 [00:00<?, ?it/s]

KeyboardInterrupt: 