In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
import math
from tqdm import tqdm
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from transformers import AutoTokenizer, BertTokenizer
from sklearn.model_selection import train_test_split
import gc
from tokenizers import ByteLevelBPETokenizer, processors

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Load the pretrained PhoBERT tokenizer; it is used below both to encode the
# training text and to decode generated token ids.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")

# Show the special tokens used for padding and for marking the end of a sequence
# (the separator token doubles as the end-of-sequence marker here).
print("Padding token:", tokenizer.pad_token)
print("EOS token:", tokenizer.sep_token)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Padding token: <pad>
EOS token: </s>


In [161]:
# Read the article summaries and discard every row that has a missing value,
# then preview the first rows.
data = pd.read_csv("article_summaries.csv").dropna()
print(data.head())


                          Title  \
0              Dicamptus neavei   
1  Molophilus lackschewitzianus   
2             Sisyropa argyrata   
3      Cột Chúa Ba Ngôi ở Praha   
4          Androsace ludlowiana   

                                             Summary  
0  Dicamptus neavei là một loài tò vò trong họ Ic...  
1  Molophilus lackschewitzianus là một loài ruồi ...  
2  Sisyropa argyrata là một loài ruồi trong họ Ta...  
3  Cột Holy Trinity, hay Cột Chúa Ba Ngôi ở Praha...  
4  Androsace ludlowiana là một loài thực vật có h...  


In [129]:
# Context window: number of tokens per training example. It is also passed to the
# model as the size of its learned position-embedding table.
BLOCK_SIZE = 50
# NOTE(review): the tokenizer warns that special tokens were added to the vocabulary;
# vocab_size may not count those -- confirm against len(tokenizer) so that no token id
# can fall outside the embedding table.
VOCAB_SIZE = tokenizer.vocab_size

In [159]:
def get_batch(data, block_size, batch_size, tok=None):
    """Sample a batch of next-token-prediction pairs from random summaries.

    For each of ``batch_size`` randomly drawn rows of ``data['Summary']`` the text
    is tokenized, right-padded with the pad token if it is too short, and a random
    window of ``block_size + 1`` consecutive tokens is cut out of it. The first
    ``block_size`` tokens of the window are the inputs and the last ``block_size``
    tokens (the same window shifted by one) are the targets.

    Args:
        data: DataFrame with a ``'Summary'`` column of strings. It must contain at
            least ``batch_size`` rows because rows are sampled without replacement.
        block_size: Number of tokens per example.
        batch_size: Number of examples in the batch.
        tok: Object providing ``encode(text) -> list[int]`` and ``pad_token_id``.
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        Tuple ``(x, y)`` of CPU ``torch.long`` tensors, each of shape
        ``(batch_size, block_size)``, where ``y[i, t]`` is the token that follows
        ``x[i, t]`` in the source sequence.
    """
    tok = tokenizer if tok is None else tok
    window_len = block_size + 1  # block_size inputs + 1 extra token for the shifted targets

    x = torch.zeros((batch_size, block_size), dtype=torch.long)
    y = torch.zeros((batch_size, block_size), dtype=torch.long)
    samples = data['Summary'].sample(n=batch_size)

    for i, sample in enumerate(samples):
        summary_ids = tok.encode(sample)
        missing = window_len - len(summary_ids)
        if missing > 0:
            summary_ids = summary_ids + [tok.pad_token_id] * missing
        # The last valid start is len - window_len (inclusive). Stopping one short of
        # it would mean the final token of a long summary is never seen as a target.
        random_start = random.randint(0, len(summary_ids) - window_len)
        window = torch.tensor(summary_ids[random_start:random_start + window_len], dtype=torch.long)
        x[i] = window[:-1]
        y[i] = window[1:]

    return x, y
        


# Smoke test: draw a single example and inspect it.
a, b = get_batch(data, block_size=BLOCK_SIZE, batch_size=1)

print(a.shape, b.shape)

# Decode the input row and then the target row; the target should read as the
# input shifted left by one token.
print(*(tokenizer.decode(row[0].tolist()) for row in (a, b)), sep="\n")

torch.Size([1, 50]) torch.Size([1, 50])
<s> Ornitopia là một chi bướm đêm thuộc họ Noctuidae. </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Ornitopia là một chi bướm đêm thuộc họ Noctuidae. </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


In [156]:
# Model hyperparameters, passed to LanguageModel below.
# Embedding / hidden width (d_model of every encoder layer).
N_EMB = 400
# Number of stacked Transformer encoder layers.
N_LAYERS = 6
# Attention heads per layer; N_EMB is a multiple of it (400 / 5 = 80 per head).
N_HEADS = 5
# Dropout probability used inside the encoder layers.
DROPOUT = 0.1

# Sanity check of the vocabulary size that sizes the embedding table and output head.
print(tokenizer.vocab_size)

def estimate_loss(model, val_data, block_size, batch_size):
    """Estimate the model's loss on one random batch drawn from ``val_data``.

    The model is put into evaluation mode (dropout disabled) and run without
    gradient tracking. Afterwards it is returned to whatever mode it was in
    before the call, so evaluating a model that is already in eval mode does not
    silently switch it back to training mode.

    Args:
        model: Module called as ``model(x, y)`` and returning ``(logits, loss)``.
        val_data: DataFrame handed to ``get_batch`` to sample the batch from.
        block_size: Sequence length of each sampled example.
        batch_size: Number of examples in the sampled batch.

    Returns:
        The batch loss as a Python float. This is a single-batch estimate and is
        therefore noisy.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x, y = get_batch(val_data, block_size, batch_size)
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
    finally:
        # Restore the caller's mode even if sampling or the forward pass fails.
        model.train(was_training)
    return loss.item()

def generate_square_subsequent_mask(sz):
    """Return the additive causal attention mask for a sequence of length ``sz``.

    The result is a float32 tensor of shape ``(sz, sz)``: entry ``(i, j)`` is
    ``0.0`` when ``j <= i`` (position ``i`` may attend to ``j``) and ``-inf``
    when ``j > i`` (future positions are blocked).
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32)
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return torch.triu(blocked, diagonal=1)
    
class LanguageModel(nn.Module):
    """Causally masked Transformer language model.

    Token embeddings and learned position embeddings are summed, passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks under a causal mask, added back
    to the embeddings as a residual, and mapped to vocabulary logits by a
    feed-forward block followed by a linear head.
    """

    def __init__(self, vocab_size, n_emb, block_size, n_layers, n_heads, dropout=0.2):
        """Build the model.

        Args:
            vocab_size: Number of token ids; sizes the embedding table and the output head.
            n_emb: Embedding / hidden width. Must be divisible by ``n_heads``.
            block_size: Maximum sequence length (size of the position-embedding table).
            n_layers: Number of encoder layers.
            n_heads: Attention heads per encoder layer.
            dropout: Dropout probability inside the encoder layers.
        """
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_emb)
        self.position_embedding_table = nn.Embedding(block_size, n_emb)

        encoder_layer = nn.TransformerEncoderLayer(d_model=n_emb, nhead=n_heads, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.feed_forward = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb),
            nn.ReLU(),
            nn.Linear(4 * n_emb, n_emb)
        )

        self.lm_head = nn.Linear(n_emb, vocab_size)

    @staticmethod
    def _causal_mask(size, device):
        """Additive (size, size) mask: 0.0 on/below the diagonal, -inf above it."""
        blocked = torch.full((size, size), float('-inf'), device=device)
        return torch.triu(blocked, diagonal=1)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Long tensor of token ids, shape ``(B, T)`` with ``T`` no larger
                than the ``block_size`` the model was built with.
            targets: Optional long tensor of shape ``(B, T)`` with the token that
                follows each position of ``idx``.

        Returns:
            ``(logits, None)`` with logits of shape ``(B, T, vocab_size)`` when
            ``targets`` is None; otherwise ``(logits, loss)`` where logits are
            flattened to ``(B * T, vocab_size)`` and ``loss`` is the mean
            cross-entropy over all positions (padding positions included).
        """
        B, T = idx.shape
        # Follow the input's device instead of a notebook-level global so the model
        # keeps working wherever it and its inputs are placed.
        dev = idx.device

        token_emb = self.token_embedding_table(idx)
        position_emb = self.position_embedding_table(torch.arange(T, device=dev))
        x = token_emb + position_emb

        # The encoder layers use the default batch_first=False, so they expect
        # (T, B, n_emb); permute in and back out.
        mask = self._causal_mask(T, dev)
        x_transform = self.transformer_encoder(x.permute(1, 0, 2), mask=mask)
        x = x + x_transform.permute(1, 0, 2)

        x = self.feed_forward(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, block_size, temperature=1.0, stop_token=False, stop_token_id=None):
        """Autoregressively sample continuations of ``idx``.

        Runs without gradient tracking. The method does not change the module's
        train/eval mode; call ``model.eval()`` first to disable dropout.

        Args:
            idx: Long tensor of prompt token ids, shape ``(B, T0)``.
            max_new_tokens: Upper bound on the number of tokens appended.
            block_size: Only the last ``block_size`` tokens are fed to the model.
            temperature: Positive divisor applied to the logits before softmax.
            stop_token: If True, stop as soon as every sequence in the batch has
                just sampled the stop token.
            stop_token_id: Id of the stop token. When omitted and ``stop_token``
                is True, the notebook-level ``tokenizer.sep_token_id`` is used.

        Returns:
            Long tensor of shape ``(B, T0 + n)`` with ``n <= max_new_tokens``.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if stop_token and stop_token_id is None:
            stop_token_id = tokenizer.sep_token_id

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)

            # Scale the logits of the last position by the temperature
            logits = logits[:, -1, :] / temperature

            probs = F.softmax(logits, dim=-1)
            idx_new = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_new], dim=-1)
            # .all() keeps batch size 1 behaving as before and avoids .item(),
            # which raises for batches larger than one.
            if stop_token and bool((idx_new == stop_token_id).all()):
                break
        return idx

# Instantiate the language model with the hyperparameters above and move it to
# the selected device, then report how many parameters will be trained.
model = LanguageModel(
    vocab_size=VOCAB_SIZE,
    n_emb=N_EMB,
    block_size=BLOCK_SIZE,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    dropout=DROPOUT,
).to(device)

print(f'Number of parameters {sum(p.numel() for p in model.parameters() if p.requires_grad)}')


64000
Number of parameters 66270288


In [6]:
# Replace the freshly initialised model with a previously saved one.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints that come from a trusted source.
# NOTE(review): presumably this file holds a whole pickled model object (as produced by
# torch.save(model, ...)); recent PyTorch releases default to weights_only=True and then
# need weights_only=False for such files -- confirm for the installed version.
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
model = torch.load("good_wiki_transformer_2.pth", map_location=device)

In [162]:
# Training configuration.
# NOTE(review): validation runs every EVAL_INTERVAL steps, so N_EPOCHS = 1000 steps give
# at most 10 evaluations and a patience of 50 can never be exhausted -- lower EARLY_STOP
# or raise N_EPOCHS if early stopping is meant to trigger.
EARLY_STOP = 50          # evaluations without improvement tolerated before stopping
N_EPOCHS = 1000          # number of optimisation steps (one random batch per step)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EVAL_INTERVAL = 100      # steps between validation checks
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Hold out part of the data so the validation loss is measured on summaries the
# model is not trained on; evaluating on the training frame would only track
# training loss. The held-out set is at least one batch large because get_batch
# samples rows without replacement.
train_data, val_data = train_test_split(
    data, test_size=max(BATCH_SIZE, len(data) // 10), random_state=0
)

last_val_loss = 1e9
early_stop = EARLY_STOP

model.train()
for steps in range(N_EPOCHS):
    xb, yb = get_batch(train_data, block_size=BLOCK_SIZE, batch_size=BATCH_SIZE)
    xb = xb.to(device)
    yb = yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if steps % EVAL_INTERVAL == 0:
        print('Step:', steps, 'Training Loss:', loss.item())
        val_loss = estimate_loss(model, val_data, block_size=BLOCK_SIZE, batch_size=BATCH_SIZE)
        print('Validation loss:', val_loss)
        if val_loss >= last_val_loss:
            # No improvement: spend one unit of patience.
            early_stop -= 1
            if early_stop == 0:
                print('Early stop!')
                break
        else:
            # Improvement: remember the best loss and reset the patience counter.
            early_stop = EARLY_STOP
            last_val_loss = val_loss

Step: 0 Training Loss: 4.087509632110596
Validation loss: 4.27831506729126


In [125]:
# Save the whole model object (not just its state_dict); loading this file back
# requires the LanguageModel class to be defined in the loading session.
torch.save(model, 'viet_wiki_summary.pth')

In [159]:
# Prompt text to continue; an empty string lets the model start from scratch.
starting_tokens = ''

# encode() wraps the text in <s> ... </s>; drop the trailing </s> so the model
# continues the prompt instead of seeing an already finished sequence.
encoded_start = tokenizer.encode(starting_tokens)[:-1]
len_starting_tokens = len(encoded_start)

# Shape (1, len_starting_tokens): a batch holding the single prompt.
idx = torch.tensor(encoded_start, dtype=torch.long, device=device).reshape(1, len_starting_tokens)
model.eval()
N_SAMPLES = 100
# Sampling needs no gradients; disabling them avoids building autograd graphs
# for every generated token.
with torch.no_grad():
    for _ in range(N_SAMPLES):
        generation = model.generate(idx, max_new_tokens=500, block_size=BLOCK_SIZE, temperature=1, stop_token=True)[0].tolist()
        print(tokenizer.decode(generation))
        print('------------------')




[CLS] time the plot shows [SEP]
------------------
[CLS] i ’ m kinda sure this nothing....mner like guys? [SEP]
------------------
[CLS] congress hits 10 day in return from wisdom [SEP]
------------------
[CLS] 23k sell on!! [SEP]
------------------
[CLS] keep in 2021 that afternoon your outlook, just missed buying you? [SEP]
------------------
[CLS] if you love you guys, why's my brrroy [SEP]
------------------
[CLS] or smart he?????? [UNK] [SEP]
------------------
[CLS] do not get “?? [SEP]
------------------
[CLS] bb ) is a wsb adventure, growing [SEP]
------------------
[CLS] another german full retardedевич making sense over $ 1000. aapl powell [SEP]
------------------
[CLS] holding gme this kid and scrolling stocks [SEP]
------------------
[CLS] be interesting! [SEP]
------------------
[CLS] people hit losses on sounds [unused772] amc " trading of waiting for my anthem on our needs right of funds to = [SEP]
------------------
[CLS] loadedრ is near the moon [UNK] [SEP]
-----------