<a href="https://colab.research.google.com/github/eisbetterthanpi/transformer/blob/main/transformer_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/9cf2d4ead514e661e20d2070c9bf7324/transformer_tutorial.ipynb
%pip install portalocker
%pip install torchdata

In [None]:
# @title data
import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build the vocabulary from the training split; iterating train_iter here
# "consumes" it, which is why the splits are loaded again further down.
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
vocab = build_vocab_from_iterator((tokenizer(line) for line in train_iter), specials=['<unk>'])
# Tokens that are not in the vocabulary are mapped to the <unk> index.
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter):
    """Tokenize and numericalize every item of `raw_text_iter` into one flat tensor.

    Each item is run through the module-level `tokenizer` and `vocab`;
    items that produce no tokens are dropped and the remaining id tensors
    are concatenated into a single 1-D ``torch.long`` tensor.
    """
    encoded = (torch.tensor(vocab(tokenizer(line)), dtype=torch.long) for line in raw_text_iter)
    return torch.cat([ids for ids in encoded if ids.numel() > 0])

# Each split yields the wiki text line by line, e.g. " = Valkyria Chronicles III = ".
train_iter, val_iter, test_iter = WikiText2()
# Flat 1-D tensors of token ids (original note: train is [2049990]).
train_data, val_data, test_data = (
    data_process(split) for split in (train_iter, val_iter, test_iter)
)

# batch_size = 20
# eval_batch_size = 10
# # # text transposed, concat remove spaces
# train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
# # val_data = batchify(val_data, eval_batch_size)
# # test_data = batchify(test_data, eval_batch_size)

# print(train_data[:10])

def detoken(tgt_tokens):
    """Turn a 1-D tensor of token ids back into a space-joined string.

    Ids are looked up in the module-level `vocab`; any "<bos>" / "<eos>"
    markers are stripped from the joined text.
    """
    token_ids = list(tgt_tokens.cpu().numpy())
    text = " ".join(vocab.lookup_tokens(token_ids))
    for marker in ("<bos>", "<eos>"):
        text = text.replace(marker, "")
    return text



In [None]:
# @title Datasetme

class Datasetme(torch.utils.data.Dataset):
    """Pre-batched language-modelling dataset.

    The raw text is numericalized, folded into ``batch_size`` parallel
    columns (``self.data`` is ``[full_seq_len, batch_size]``) and then cut
    into windows of at most ``bptt`` rows.  Every item is therefore already
    a complete batch: ``(data, target)`` where ``data`` is
    ``[seq_len, batch_size]`` and ``target`` is the same window shifted by
    one row and flattened to ``[seq_len * batch_size]``.

    Args:
        raw_data: iterable of raw text items (processed by `data_process`).
        batch_size: number of parallel columns.
        bptt: maximum window length in rows (defaults to the previous
            hard-coded value of 35).
    """

    def __init__(self, raw_data, batch_size, bptt=35):
        int_data = self.data_process(raw_data) # original note: list of int, [2049990]
        self.batch_size = batch_size
        self.data = self.batchify(int_data, batch_size)
        self.bptt = bptt
        # Start row of every window; the last row is excluded because it has
        # no "next token" to serve as a target.
        self.ind = torch.arange(0, self.data.size(0) - 1, step=self.bptt)

    def data_process(self, raw_text_iter):
        """Numericalize the text with the module-level `tokenizer`/`vocab` into one flat long tensor."""
        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    def batchify(self, data, bsz):
        """Trim `data` to a multiple of `bsz` and reshape it to ``[len(data) // bsz, bsz]``."""
        seq_len = data.size(0) // bsz
        data = data[:seq_len * bsz]
        data = data.view(bsz, seq_len).t().contiguous()
        return data

    def get_batch(self, source, i): # source: [full_seq_len, batch_size], i: int start row
        """Return the window starting at row `i` and its one-step-ahead, flattened targets."""
        seq_len = min(self.bptt, len(source) - 1 - i)
        data = source[i:i+seq_len]
        target = source[i+1:i+1+seq_len].reshape(-1)
        return data, target

    def detoken(self, tgt_tokens):
        """Map a 1-D tensor of token ids back to a space-joined string."""
        return " ".join(vocab.lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

    def __len__(self):
        # One item per window start.  The previous
        # `data.size(0) // batch_size` had no relation to the number of
        # windows, so index-based consumers saw the wrong dataset size.
        return len(self.ind)

    def __getitem__(self, index):
        # self.ind[index] raises IndexError past the last window, which is
        # what terminates plain `for batch in dataset` iteration.
        return self.get_batch(self.data, int(self.ind[index]))

train_iter, val_iter, test_iter = WikiText2() # line by line of wiki  = Valkyria Chronicles III =
batch_size=128
# This Datasetme variant already yields whole [seq_len, batch_size] batches,
# so the dataset objects themselves are used as the "loaders".
train_loader = Datasetme(train_iter, batch_size)
val_loader = Datasetme(val_iter, batch_size)
test_loader = Datasetme(test_iter, batch_size)

# A DataLoader(train_iter, ..., collate_fn=collate_fn) call used to follow
# here.  It referenced `collate_fn` before any cell defines it (NameError on
# Restart & Run All) and replaced train_loader with a loader over the raw
# text iterator, so it was dropped.
batch_first=False
drop_last=True



In [None]:

# @title Datasetme

class Datasetme(torch.utils.data.Dataset):
    """Map-style dataset of fixed-length next-token-prediction windows.

    The raw text is numericalized into one flat tensor of token ids.  Item
    ``idx`` is ``(data, target)`` where ``data`` is the ``bptt`` tokens
    starting at ``idx * bptt`` and ``target`` is the same window shifted
    one token to the right.  Both are 1-D tensors of length ``bptt``.

    Args:
        raw_data: iterable of raw text items (processed by `data_process`).
        bptt: window length in tokens (defaults to the previous hard-coded 35).
    """

    def __init__(self, raw_data, bptt=35):
        self.data = self.data_process(raw_data) # original note: list of int, [2049990]
        self.bptt = bptt
        # Window start offsets; kept for backward compatibility, not used by __getitem__.
        self.ind = torch.arange(0, self.data.size(0) - 1, step=self.bptt)

    def data_process(self, raw_text_iter):
        """Numericalize the text with the module-level `tokenizer`/`vocab` into one flat long tensor."""
        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    def __len__(self):
        # A window needs bptt inputs plus one extra token for the shifted
        # target, hence (N - 1) // bptt.  The former N // bptt produced a
        # final item whose target was one token short whenever N was an
        # exact multiple of bptt, which breaks torch.stack in collate_fn.
        return max(0, (len(self.data) - 1) // self.bptt)

    def __getitem__(self, idx):
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            # Without this, out-of-range indices silently returned empty
            # tensors and plain iteration over the dataset never stopped.
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")
        i = idx * self.bptt
        seq_len = self.bptt
        data = self.data[i:i+seq_len]
        target = self.data[i+1:i+1+seq_len].reshape(-1)
        return data, target

batch_size=128
# WikiText2() yields the (train, valid, test) splits line by line of wiki
# text; each split is wrapped in a windowed Datasetme under the same names.
train_iter, val_iter, test_iter = (Datasetme(split) for split in WikiText2())

def collate_fn(data):
    """Collate a list of (input, target) windows into one sequence-first batch.

    Args:
        data: list of ``(x, y)`` pairs of equal-length 1-D tensors.

    Returns:
        ``(x, y)`` where ``x`` is ``[seq_len, batch]`` (batch_first would
        stack on dim 0 instead) and ``y`` is the targets laid out the same
        way and flattened to ``[seq_len * batch]``.
    """
    inputs, targets = zip(*data)
    batch_x = torch.stack(inputs, dim=1)
    # Stacking on dim 1 and flattening gives position-major order: all
    # samples at position 0 first, then all samples at position 1, ...
    batch_y = torch.stack(targets, dim=1).reshape(-1)
    return batch_x, batch_y

# All three loaders share the same batching options; incomplete final
# batches are dropped.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
train_loader = torch.utils.data.DataLoader(train_iter, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_iter, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_iter, **loader_kwargs)

# Sanity check on the first batch only: decode the first three input
# columns and the first 100 entries of the flattened targets.
for x, y in train_loader:
    for col in range(3):
        print(detoken(x[:, col]))
    print(detoken(y[:100]))
    break



= valkyria chronicles iii = senjō no valkyria 3 <unk> chronicles ( japanese 戦場のヴァルキュリア3 , lit . valkyria of the battlefield 3 ) , commonly referred to as valkyria chronicles iii outside japan , is
a tactical role @-@ playing video game developed by sega and media . vision for the playstation portable . released in january 2011 in japan , it is the third game in the valkyria series
. <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the nameless , a penal military unit serving the
valkyria tactical <unk> of the character s , chronicles valkyria a partially each characters game the during . movement health are to also her . to = and militia army , calamity on members darcsen allies effort evidence manpower . their as asking . in one like with ' the was of the unique memory the honjou a the games early the on the , ' the worked . , of written ( game , between unpopularity the valkyria last with found pr

In [None]:
# @title Transformer Model
import math
import os
import torch
from torch import nn
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Language modelling: assign a probability to how likely a given word (or sequence of words) is to follow a sequence of words.
# A square (causal) attention mask is required because the self-attention layers in nn.TransformerEncoder are only allowed to attend to earlier positions in the sequence.
# Log-softmax is not applied here because nn.CrossEntropyLoss is used later, and it expects unnormalized logits as input.


# positional encodings have the same dimension as the embeddings so that the two can be summed
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    Even feature indices carry ``sin(pos * freq)`` and odd indices carry
    ``cos(pos * freq)``.  The table is precomputed for ``max_len``
    positions and stored as a (non-trainable) buffer.

    Args:
        d_model: embedding dimension (even or odd).
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so
        # the cosine half must drop the last frequency; assigning the full
        # `angles` here raised a shape-mismatch error.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x): # x: [seq_len, batch_size, embedding_dim]
        # pe is [max_len, 1, d_model]; the middle dim broadcasts over the batch.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """Causal language model: embedding -> positional encoding -> TransformerEncoder -> linear head.

    Args:
        ntoken: vocabulary size (also the size of the output logits).
        d_model: embedding / model dimension.
        nhead: number of attention heads.
        d_hid: feed-forward dimension inside each encoder layer.
        nlayers: number of encoder layers.
        dropout: dropout probability for the positional encoding and encoder layers.
    """

    def __init__(self, ntoken, d_model, nhead, d_hid, nlayers, dropout = 0.5):
        super().__init__()
        self.embedding = nn.Embedding(ntoken, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output layer; zero the output bias."""
        initrange = 0.1 # original note: gpt 0.02
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask = None): # src: [seq_len, batch_size], src_mask: [seq_len, seq_len]
        """Return unnormalized logits of shape ``[seq_len, batch_size, ntoken]``.

        When `src_mask` is omitted a square subsequent (causal) mask is
        generated so each position only attends to earlier positions.
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Put the mask on the same device as the activations rather than
            # the notebook-global `device`, so the module also works when it
            # has been moved elsewhere (or that global does not exist).
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(src.device)
        output = self.transformer_encoder(src, src_mask) # float [seq_len, batch_size, d_model]
        # No log-softmax: CrossEntropyLoss expects unnormalized logits.
        output = self.linear(output)
        return output # [seq_len, batch_size, ntoken]

# The vocabulary size fixes both the embedding table and the output layer width.
ntokens = len(vocab)
model = TransformerModel(ntoken=ntokens, d_model=512, nhead=8, d_hid=512, nlayers=6, dropout=0.1)
model = model.to(device)

# nhead, d_model, nlayers = 12,768,12
# pw_ff 3072 d_hid
# https://pytorch.org/hub/huggingface_pytorch-transformers/
# gpt paper https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf
# bert paper https://arxiv.org/pdf/1810.04805.pdf
# https://vitalflux.com/bert-vs-gpt-differences-real-life-examples/
# Toronto BookCorpus (800M words) and English Wikipedia (2,500M words), BookCorpus

# https://www.analyticsvidhya.com/blog/2022/10/generative-pre-training-gpt-for-natural-language-understanding/




In [None]:
# @title wandb
# https://docs.wandb.ai/quickstart
!pip install wandb
import wandb
wandb.login() # 487a2109e55dce4e13fc70681781de9f50f27be7
run = wandb.init(
    project="transformer_tut",
    config={
        "model": "adam 1e-3",
        "optim": "adam",
        # "learning_rate": 5,
    })


In [None]:
# @title train eval
import torch
# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Module-level gradient scaler; strain() uses it to scale the loss before
# backward() and to step the optimizer.
scaler = torch.cuda.amp.GradScaler()

# train function with automatic mixed precision
def strain(model, dataloader, optimizer, loss_fn, scheduler=None):
    """Train for one epoch with automatic mixed precision.

    Uses the module-level `device` and `scaler` (GradScaler).

    Args:
        model: network returning logits of shape [seq_len, batch_size, ntoken].
        dataloader: yields ``(data, targets)`` with flattened targets.
        optimizer: optimizer over ``model.parameters()``.
        loss_fn: criterion applied to ``[N, ntoken]`` logits and ``[N]`` targets.
        scheduler: optional LR scheduler, stepped once per batch.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    total_loss = 0.
    for data, targets in dataloader:
        data, targets = data.to(device), targets.to(device)
        with torch.cuda.amp.autocast(): # automatic mixed precision
            output = model(data) # [seq_len, batch_size, ntoken]
            output_flat = output.view(-1, output.size(-1))
            loss = loss_fn(output_flat, targets)
        # Order matters: clear old grads, backprop the scaled loss, unscale,
        # and only then clip.  Clipping used to run before zero_grad() and
        # backward(), so it acted on the previous step's gradients and the
        # gradients actually applied were never clipped.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer) # clip true gradients, not loss-scaled ones
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:
            scheduler.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        # Logging is best-effort: the wandb cell may not have been run.
        # NOTE(review): the logged value is divided by len(targets) although
        # the default CrossEntropyLoss is already a per-token mean -- confirm
        # this scaling is intended.
        try: wandb.log({"train loss": batch_loss/len(targets)})
        except NameError: pass
    return total_loss / len(dataloader)


def batchify(data, bsz): # data: [N]
    """Fold a flat token tensor into `bsz` parallel columns on `device`.

    Trailing tokens that do not fill a complete row are discarded.
    Returns a contiguous tensor of shape ``[N // bsz, bsz]``.
    """
    rows = data.size(0) // bsz
    usable = data.narrow(0, 0, rows * bsz)
    columns = usable.view(bsz, rows).t().contiguous()
    return columns.to(device) # [N // bsz, bsz]

# Reload the raw splits (this rebinds train_iter / val_iter / test_iter to
# the raw text iterators) and fold the training tokens into batch columns.
train_iter, val_iter, test_iter = WikiText2()
train_data = batchify(data_process(train_iter), batch_size)  # [seq_len, batch_size]


bptt = 35
def get_batch(source, i): # source: [full_seq_len, batch_size], i: int start row
    """Slice one training window and its next-token targets out of `source`.

    The window starts at row `i` and spans at most `bptt` rows, shortened
    near the end so a shifted target row always exists.

    Returns:
        ``(data, target)``: ``data`` is ``[seq_len, batch_size]`` and
        ``target`` is rows ``i+1 .. i+seq_len`` flattened to
        ``[seq_len * batch_size]``.
    """
    stop = i + min(bptt, len(source) - 1 - i)
    return source[i:stop], source[i + 1:stop + 1].reshape(-1)

def train(model, dataloader, optimizer, loss_fn):
    """Run one full-precision training epoch and return the mean batch loss.

    Uses the module-level `device` and `ntokens`.  Each batch from
    `dataloader` is ``(data, targets)``; per the original debug notes these
    are int tensors of shape [35, 128] and [4480], and the model output is
    [35, 128, 28782].
    """
    model.train()
    running_loss = 0.
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        # Flatten to [seq_len * batch, ntokens] to line up with the flat labels.
        batch_loss = loss_fn(logits.view(-1, ntokens), labels)
        optimizer.zero_grad()
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(dataloader)

def test(model, loader, loss_fn):
    """Evaluate `model` on `loader` and return the average per-position loss.

    Uses the module-level `device`.  Each batch loss is weighted by its
    sequence length and the total is divided by the total number of
    sequence positions seen, so batches of different lengths are averaged
    correctly.

    Previously the weighted sum was divided by the number of batches, which
    reported a value about `seq_len` times too large (e.g. ~233 instead of
    ~6.7 for 35-token windows) and not comparable to the training loss.
    """
    model.eval()
    total_loss = 0.
    total_positions = 0
    with torch.no_grad():
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            seq_len = data.size(0)
            output = model(data)
            output_flat = output.view(-1, output.size(-1))
            total_loss += seq_len * loss_fn(output_flat, targets).item()
            total_positions += seq_len
    return total_loss / total_positions


In [None]:
# @title generate
def generate(model, src_sentence):
    """Greedily extend a prompt and return prompt plus continuation as flat ids.

    `src_sentence` is reshaped to a single-column sequence ``[seq_len, 1]``
    on the module-level `device`.  The loop runs ``seq_len + 5`` times; each
    step feeds the whole sequence back in and appends the arg-max token
    predicted at the last position.
    """
    model.eval()
    tokens = src_sentence.view(-1, 1).to(device)
    steps = tokens.shape[0] + 5
    with torch.no_grad():
        for _ in range(steps):
            logits = model(tokens)  # original debug note: [5, 1, 28782]
            # Most likely next token after the final position, kept as [1, 1].
            next_token = logits[-1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat((tokens, next_token), 0)
    return tokens.flatten()

def data_process(raw_text_iter):
    """Tokenize and numericalize every item of `raw_text_iter` into one flat tensor.

    Same definition as in the data cell, repeated so this cell can be run
    after only the data and model cells.  Items that produce no tokens are
    dropped before concatenation.
    """
    encoded = (torch.tensor(vocab(tokenizer(line)), dtype=torch.long) for line in raw_text_iter)
    return torch.cat([ids for ids in encoded if ids.numel() > 0])
def detoken(tgt_tokens):
    """Turn a 1-D tensor of token ids back into a space-joined string.

    Ids are looked up in the module-level `vocab`; any "<bos>" / "<eos>"
    markers are stripped from the joined text.
    """
    token_ids = list(tgt_tokens.cpu().numpy())
    text = " ".join(vocab.lookup_tokens(token_ids))
    for marker in ("<bos>", "<eos>"):
        text = text.replace(marker, "")
    return text

# Smoke test: greedily continue a short prompt and print the decoded text.
prompt_ids = data_process(["he stood by the water "])
print(detoken(generate(model, prompt_ids)))


he stood by the water and <unk> . the <unk> of the <unk> <unk> <unk>


In [None]:
# @title gen
def gen(model, test_loader):
    """Return the model's logits for the first batch of `test_loader`.

    Uses the module-level `device`.  Per the original debug notes the input
    batch is int [35, 128] and the returned logits are float
    [35, 128, 28782].  Returns None if the loader yields no batches.
    """
    # Switch to eval mode so dropout does not perturb the predictions;
    # generate() and test() already do this, gen() did not.
    model.eval()
    with torch.no_grad():
        for data, targets in test_loader:
            data = data.to(device)
            return model(data)
    return None
            # return output, data, targets
# output, data, targets = gen(model, test_loader)
# print(test_loader.detoken(output.argmax(-1).T[0]))
# print(test_loader.detoken(data.argmax(-1)))
# print(test_loader.detoken(targets))
# with torch.no_grad():
#     print(detoken(model(data_process(["he stood by the water "])).argmax(-1).T[1]))
#     print(detoken(model(data_process(["he stood by the water "])).argmax(-1)[0]))
#     # print(model(data_process(["he stood by the water "])).shape) # [5, 5, 28782]


# scheduler.step()
# print(optimizer.param_groups[0]['lr'])

In [None]:
# @title wwwwwwwwwwwww
import time
# 1e-4, 1e-3, 1e-2, 1e-1
loss_fn = nn.CrossEntropyLoss()
# Learning rates tried so far: 5., 0.001 and 1e-4 .. 1e-1.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

for epoch in range(20):
    start = time.time()
    # Mixed-precision epoch without an LR scheduler, then validation.
    train_loss = strain(model, train_loader, optimizer, loss_fn)
    val_loss = test(model, val_loader, loss_fn)
    elapsed = time.time() - start
    print(f'{epoch+1:3d} train loss: {train_loss:5.2f}, valid loss: {val_loss:5.2f}, time: {elapsed:5.2f}s')
    # Qualitative check: arg-max predictions for the first test sequence,
    # followed by a greedy continuation of a fixed prompt.
    output=gen(model, test_loader)
    first_sequence = output.argmax(-1).T[0]
    print(detoken(first_sequence))
    prompt_ids = data_process(["he stood by the water "])
    print(detoken(generate(model, prompt_ids)))

# lr1e-1 20epoch train loss: 3.92, valid loss: 98.34, time: 41.68s
# 40 train loss:  1.56, valid loss: 41.93, time: 41.72s
# 60 train loss:  0.94, valid loss: 28.62, time: 41.77s

# 1e-1 20 train loss:  5.75, valid loss: 196.86, time: 42.24s



  1 train loss:  7.43, valid loss: 233.15, time: 42.64s
= <unk> <unk> = <unk> <unk> <unk> <unk> <unk> <unk> the <unk> <unk> , <unk> the the <unk> <unk> , <unk> the <unk> the <unk> <unk> , <unk> , the of the <unk> the ,
he stood by the water of the <unk>
  2 train loss:  6.77, valid loss: 223.85, time: 42.19s
= = <unk> = = <unk> <unk> <unk> , , and , <unk> , <unk> the is a <unk> , <unk> of of the <unk> , , <unk> , the of the <unk> a ,
he stood by the water of the <unk>
  3 train loss:  6.58, valid loss: 219.36, time: 42.32s
= = <unk> = = <unk> a <unk> , , and , <unk> , <unk> the is been <unk> of <unk> of of the <unk> , , <unk> , the of the <unk> a of
he stood by the water of the <unk>
  4 train loss:  6.47, valid loss: 216.07, time: 42.17s
= = = = = = a <unk> , , and , <unk> , <unk> the is been <unk> of <unk> of of the <unk> , of <unk> of the of the <unk> also by
he stood by the water of the <unk>
  5 train loss:  6.38, valid loss: 213.46, time: 42.32s
= = = = = = a <unk> <unk> , and , 

In [None]:
# @title save
from google.colab import drive
drive.mount('/content/gdrive')
# Checkpoint location on Google Drive.  For Colab-local storage use
# PATH="/content/" and name='model.pth' instead.
PATH="/content/gdrive/MyDrive/torch_save/"
name='transformer_tut_1e1.pth'

# To save the current weights: torch.save(model.state_dict(), PATH+name)

# Restore the weights, remapping tensors onto whichever device is in use.
checkpoint = torch.load(PATH+name, map_location=device)
model.load_state_dict(checkpoint)


In [None]:
# print(scheduler.lr)
# optimizer.param_groups[0]['lr']=2.
# print(len(train_loader))
# Decode the arg-max prediction of every sequence in the batch.
# test_loader is a DataLoader at this point and has no `detoken` method, so
# the module-level detoken() is used instead.
for x in output.argmax(-1).T:
    print(detoken(x))


In [None]:
# print(test_data.shape)
# for x in test_data:
#     print(detoken(x))
# print(test_data[:30])
# # target_seq = detoken(test_data[:30])
# # print(target_seq)
# with torch.no_grad():
#     data = test_data[:30].unsqueeze(0).to(device)
#     output = model(data)
#     output_flat = output.view(-1, ntokens)
#     print(output_flat)
# #     out = detoken(output_flat)
# # print(out)

eval_data = test_loader
# Disable dropout so the inspected predictions are deterministic.
model.eval()
with torch.no_grad():
    # Only the first batch is needed for inspection.
    for data, targets in eval_data:
        data, targets = data.to(device), targets.to(device)
        output = model(data) # [seq_len, batch_size, ntoken]
        break

# Show input, target and prediction for the first sequence of the batch.
# test_loader is a DataLoader and has no `detoken` method, so the
# module-level detoken() is used.
print(detoken(data.T[0]))
# targets is flat in [seq_len, batch] order; restore the grid to pick column 0.
print(detoken(targets.reshape(data.shape).T[0]))
print(detoken(output.argmax(-1).T[0]))
