In [None]:
%%capture
!pip install transformers
!pip install einops
# !pip install vncorenlp
# !pip3 install fairseq

In [None]:
# Mount Google Drive (Colab only) so the dataset, tokenizer and checkpoint files stored
# under MyDrive can be opened from this notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%capture
# Work from the Drive root so relative paths such as 'Dataset/...' and 'Custom Loss/...' resolve.
# NOTE(review): %%capture swallows the `!ls` listing, so it is only useful if the capture is removed.
%cd '/content/drive/MyDrive'
!ls

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import RobertaTokenizerFast
import os
import torch
from torch.utils.data.dataset import Dataset
from transformers.tokenization_utils import PreTrainedTokenizer
from filelock import FileLock
from transformers.utils import logging
from typing import Dict, List, Optional
import pickle
import random
import time
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import numpy as np
from einops import rearrange
import math

In [None]:
class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v.

    The einsum pattern fixes q and k to rank 4, indexed as
    (batch, heads, length, head_dim).
    """

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """Attend from q over k and mix v with the resulting weights.

        Args:
            q: queries, (batch, heads, len_q, d).
            k: keys, (batch, heads, len_k, d).
            v: values, (batch, heads, len_k, d_v).
            mask: optional tensor broadcastable to (batch, heads, len_q, len_k);
                positions where ``mask == 0`` receive (numerically) zero weight.
            e: unused; kept only so existing callers passing it keep working.

        Returns:
            (output, weights): output is (batch, heads, len_q, d_v), weights are
            the softmax-normalised scores (batch, heads, len_q, len_k).
        """
        d_tensor = k.size(-1)

        score = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(d_tensor)

        if mask is not None:
            # Masked scores must be a very large negative number so softmax gives them
            # ~0 weight. The previous fill value (-e == -1e-12) is effectively 0 and
            # therefore did not mask anything.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        score = self.softmax(score)

        v = score @ v

        return v, score

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention in which every head works at the full d_model width.

    Each projection maps d_model -> d_model * n_head, so after splitting every head
    has size d_model; ``w_concat`` maps the concatenated heads back to d_model.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        widened = d_model * n_head
        # Construction order (q, k, v, concat) is kept so parameter init consumes the RNG identically.
        self.w_q = nn.Linear(d_model, widened)
        self.w_k = nn.Linear(d_model, widened)
        self.w_v = nn.Linear(d_model, widened)
        self.w_concat = nn.Linear(widened, d_model)

    def _split_heads(self, t):
        # (b, n, h*d) -> (b, h, n, d)
        return rearrange(t, 'b n (h d) -> b h n d', h=self.n_head)

    def forward(self, x, mask=None):
        """Self-attend over x (batch, seq, d_model) and return a tensor of the same shape."""
        queries = self._split_heads(self.w_q(x))
        keys = self._split_heads(self.w_k(x))
        values = self._split_heads(self.w_v(x))

        context, _weights = self.attention(queries, keys, values, mask=mask)

        # Concatenate the heads back along the feature axis, then project to d_model.
        merged = rearrange(context, 'b h n d -> b n (h d)')
        return self.w_concat(merged)

In [None]:
class SelfAttentionLstm(nn.Module):
    """Encode a span of token embeddings into a single vector.

    The span goes through multi-head self-attention and then a batch-first LSTM;
    the LSTM output at the last time step represents the whole span.
    """

    def __init__(self, input_size, hidden_size, num_layers, n_head):
        super(SelfAttentionLstm, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Previously n_head was ignored and 4 was hard-coded; honour the argument.
        self.multi_attention = MultiHeadAttention(d_model=input_size, n_head=n_head)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, input_size).
            mask: optional attention mask forwarded to the self-attention layer.

        Returns:
            (batch, hidden_size): LSTM output at the final time step.
        """
        x = self.multi_attention(x, mask=mask)

        # Zero initial states on x's own device/dtype (was hard-coded to "cuda").
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        out, _ = self.lstm(x, (h0, c0))  # (batch_size, seq_length, hidden_size)

        # The last time step's output stands for the whole span.
        return out[:, -1, :]

In [None]:
# Training and validation text files, relative to the Drive working directory.
train_path, test_path = (
    'Dataset/train_dataset_24_03.txt',
    'Dataset/valid_dataset_24_03.txt',
)

In [None]:
# Custom RoBERTa-style tokenizer stored on Drive.
tokenizer = RobertaTokenizerFast.from_pretrained("Custom Loss/Tokenizer_26_03", max_len=512)
# Give the line break its own token so line and stanza boundaries are visible to the model.
tokenizer.add_tokens('\n')
# `vocab_size` does not count added tokens, so the '\n' token is counted by hand.
# NOTE(review): if '\n' were already in the vocabulary, add_tokens adds nothing and this
# over-counts by one; len(tokenizer) would be exact but may not match saved checkpoints.
vocab_size = tokenizer.vocab_size + 1

In [None]:
def add_padding(list_token: list, block_size: int, pad_token_id: int = 1):
    """Right-pad a list of token ids to exactly ``block_size`` entries.

    Args:
        list_token: token ids to pad; the input list is not modified.
        block_size: required output length.
        pad_token_id: id used for padding. Defaults to 1, the id this notebook has
            always padded with (presumably the tokenizer's pad id — confirm).

    Returns:
        A new list of length ``block_size``. Inputs longer than ``block_size`` are
        truncated so every example can be stacked into a batch (previously they
        came back longer than ``block_size``).
    """
    padded = list(list_token[:block_size])
    padded.extend([pad_token_id] * (block_size - len(padded)))
    return padded

In [None]:
logger = logging.get_logger(__name__)
class CusTextDataset(Dataset):
    """Fixed-length examples built from groups of four blank-line-separated blocks (stanzas).

    The input file is split on "\\n\\n"; consecutive groups of four pieces are
    re-joined with "\\n\\n", tokenized, kept only if their encoded length is within
    [min_tokens, min(max_tokens, block_size)], and padded to ``block_size`` with
    ``add_padding``. The result is pickled next to the input file (or in
    ``cache_dir``) and reused on later runs unless ``overwrite_cache`` is set.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        file_path: str,
        block_size: int,
        overwrite_cache=False,
        cache_dir: Optional[str] = None,
        min_tokens: int = 129,
        max_tokens: int = 140,
    ):
        """
        Args:
            tokenizer: tokenizer whose ``encode`` produces the token ids.
            file_path: UTF-8 text file to read.
            block_size: length every example is padded to.
            overwrite_cache: rebuild the cache even if it exists.
            cache_dir: directory for the cache file (defaults to the input's directory).
            min_tokens, max_tokens: inclusive bounds on an example's encoded length.
                Examples longer than ``block_size`` are always dropped.

        NOTE(review): the cache file name does not encode min_tokens/max_tokens; pass
        overwrite_cache=True after changing them.
        """
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"

        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else directory,
            "cached_lm_{}_{}_{}".format(
                tokenizer.__class__.__name__,
                str(block_size),
                filename,
            ),
        )

        # Only one process builds the cache; the others wait on the lock and then read it.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not overwrite_cache:
                start = time.time()
                # NOTE: pickle.load can execute arbitrary code; only load caches you created.
                with open(cached_features_file, "rb") as handle:
                    self.examples = pickle.load(handle)
                # %-style only: an f-string path containing '%' would break logging's formatting.
                logger.info(
                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )
            else:
                logger.info(f"Creating features from dataset file at {directory}")

                self.examples = self._build_examples(
                    tokenizer, file_path, block_size, min_tokens, max_tokens
                )

                start = time.time()
                with open(cached_features_file, "wb") as handle:
                    pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    @staticmethod
    def _build_examples(tokenizer, file_path, block_size, min_tokens, max_tokens):
        """Tokenize four-block groups from ``file_path`` and pad those whose length is in range."""
        with open(file_path, encoding="utf-8") as f:
            pieces = f.read().split("\n\n")

        groups = ["\n\n".join(pieces[i:i + 4]) for i in range(0, len(pieces), 4)]

        # Never keep an example that would not fit in block_size.
        upper = min(max_tokens, block_size)
        examples = []
        for ids in (tokenizer.encode(text) for text in groups):
            if min_tokens <= len(ids) <= upper:
                examples.append(add_padding(ids, block_size=block_size))
        return examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        return torch.tensor(self.examples[i], dtype=torch.long)

In [None]:
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling, LineByLineWithSOPTextDataset

def load_dataset(train_path, test_path, tokenizer, block_size: int = 140):
    """Build the (train, validation) CusTextDataset pair with one shared block size.

    Args:
        train_path: training text file.
        test_path: validation text file.
        tokenizer: tokenizer passed to both datasets.
        block_size: padded example length (previously hard-coded to 140 twice).

    Returns:
        (train_dataset, test_dataset)
    """
    train_dataset = CusTextDataset(tokenizer=tokenizer, file_path=train_path, block_size=block_size)
    test_dataset = CusTextDataset(tokenizer=tokenizer, file_path=test_path, block_size=block_size)
    return train_dataset, test_dataset

train_dataset, test_dataset = load_dataset(train_path, test_path, tokenizer)

In [None]:
# Sanity check: examples are padded to the dataset's block_size (140 here).
len(train_dataset[0])

140

In [None]:
# Inspect one padded example as raw token ids. The frequent 11982 is presumably the
# added '\n' token (it is the id treated as a line break elsewhere) — TODO confirm.
print(train_dataset[2])

tensor([    0,  1536,  5469,   417,  3707,   731,   705, 11982,   657,   546,
         3218,  2175,   508,   992,   955,   469, 11982,  7143,   693,   846,
          749,   705,   914, 11982,  1309,  2109,  7031,   785,   395,  1011,
         1659,   483, 11982, 11982,   829,   584,   609,  2885,   719,   866,
        11982,   354,   834,   504,  1982,  1890,  2518,  2345,   427,   705,
        11982,   638,   618,   502,  3017,   503,   413, 11982,  1209,   392,
         4391,  1352,  3354,   760,  1946,  1735, 11982, 11982,  8024,  3224,
         3460,  1986,  3907,  3541, 11982,  4305,   755,  4156,  2625,   888,
          601,   516,   495, 11982,  2256,  2496,  3247,  1911,   878,  1117,
        11982,  1469,   394,   501,   927,  1768,  2559,   493,   417, 11982,
        11982,  3052,   854,  1578,  2232,   985,   849, 11982,  6127,   437,
          705,  1270,   586,   950,  1097,   392, 11982,  2190,   693,  1019,
         2586,   395,   372, 11982,  3285,  1140,   600,  3695, 

# Initialize Model

In [None]:
# Training batches are reshuffled on every pass. A fixed-seed generator keeps runs
# reproducible; previously the training loader was unshuffled (file order).
# The validation loader keeps its file order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [None]:
from transformers import Trainer, TrainingArguments, GPT2Config, GPT2LMHeadModel,GPT2Model

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# A 6-layer GPT-2 LM sized to the custom tokenizer's vocabulary, placed on `device`
# (was hard-coded to "cuda", which fails on CPU-only runtimes).
configuration = GPT2Config(vocab_size=vocab_size, n_layer=6)
poem = GPT2LMHeadModel(configuration).to(device)

In [None]:
# Random dummy batch of 32 sequences x 140 tokens for a shape/timing check, moved to `device`
# (`.cuda()` fails on CPU-only runtimes).
x = torch.randint(0, vocab_size, (32, 140)).to(device)

In [None]:
# Time one forward pass of the GPT-2 trunk on the dummy batch.
# - no_grad: no autograd graph is needed for a timing check.
# - CUDA kernels run asynchronously, so synchronize before reading the clock;
#   otherwise only the kernel launch time is measured.
with torch.no_grad():
    start = time.perf_counter()
    outputs = poem.transformer(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
print(end - start)
print(outputs[0].shape)

0.12488460540771484
torch.Size([32, 140, 768])


In [None]:
# Number of training batches. len(DataLoader) is ceil(len(dataset) / batch_size) with the
# default drop_last=False, so there is no need to load and collate every batch to count them.
count = len(train_loader)

print(count)

2249


# Train GPT-2


In [None]:
# Optimisation setup for the GPT-2 language model.
# NOTE(review): CrossEntropyLoss keeps its default ignore_index=-100, while add_padding pads
# with id 1. Pad positions are therefore scored unless the training loop masks them — confirm.
lr_rate = 1e-5
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(poem.parameters(), lr=lr_rate)

In [None]:
def save_checkpoint(state, filename= "Custom Loss/gpt_2_custom_loss_v2.pth.tar"):
    """Write a checkpoint dict to ``filename`` with ``torch.save``.

    NOTE(review): the default path is ..._v2, while the notebook loads ..._v3. Pass
    ``filename`` explicitly to avoid mixing up versions.
    """
    print("Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(state, model=None, opt=None):
    """Restore model weights and optimizer state from a checkpoint dict.

    Args:
        state: dict with 'state_dict' (model weights) and 'optimizer' (optimizer state).
        model: module to restore; defaults to the global ``poem``.
        opt: optimizer to restore; defaults to the global ``optimizer``.

    Raises:
        KeyError: if either key is missing from ``state``.
    """
    print("Load checkpoint")
    # The conditional only touches the globals when no explicit object is passed.
    model = poem if model is None else model
    opt = optimizer if opt is None else opt
    model.load_state_dict(state['state_dict'])
    opt.load_state_dict(state['optimizer'])


In [None]:
# map_location restores tensors onto `device`, so a GPU-saved checkpoint also loads on CPU.
load_checkpoint(torch.load("Custom Loss/gpt_2_custom_loss_v3.pth.tar", map_location=device))

Load checkpoint


In [None]:
# Attention + LSTM head over GPT-2 hidden states. input_size must equal GPT-2's hidden width
# (configuration.n_embd, 768 by default). The head is placed on `device` instead of a hard-coded 'cuda'.
head_gpt = SelfAttentionLstm(
    input_size=configuration.n_embd, hidden_size=1000, num_layers=2, n_head=4
).to(device)

In [None]:
def load_checkpoint_lstm(state, model=None):
    """Load attention-LSTM head weights from a checkpoint dict.

    Args:
        state: dict holding the weights under 'state_dict'.
        model: module to restore; defaults to the global ``head_gpt``.
    """
    print("Load checkpoint")
    target = head_gpt if model is None else model
    target.load_state_dict(state['state_dict'])

In [None]:
# map_location restores tensors onto `device`, so a GPU-saved checkpoint also loads on CPU.
# (The "attetion" spelling is the actual file name on disk.)
load_checkpoint_lstm(torch.load("Custom Loss/self_attetion_lstm.pth.tar", map_location=device))

Load checkpoint


In [None]:
# Teacher-forced sanity check on one training example: decode the argmax prediction at
# every position. The model is put in eval mode (disabling dropout and other train-only
# behaviour) with no_grad for this pass. Its previous train/eval mode is restored afterwards.
test_input = train_dataset[16].to(device).unsqueeze(0)
was_training = poem.training
poem.eval()
try:
    with torch.no_grad():
        lm_logits = poem(test_input).logits
finally:
    poem.train(was_training)
token = torch.argmax(lm_logits, dim=2)
# Logit i predicts token i + 1, so prepend id 0 to line the prediction up with the input
# (the decoded text starts with <s>, so id 0 is presumably the BOS token).
poem_output = tokenizer.decode([0] + token[0].tolist(), skip_special_tokens=False)
print(poem_output)

<s>em thôi anh thấy hả ông
mần tui xí hụt phải lồng rách quần
bửa giờ chuẩn bị ông khuân
lên tui dẹo miết làm ngần roài đa

liếc ngang nốt bả cười khì
õng a õng ẹo kéo ghì cái lưng
giả đò làm bộ ngắm hun
cái mình cá lắc run run cặp đùi

còn mà cũng thấy vui vui
vợ hiền vợ ngắm lại thời nhõng nheo
tánh mình lam thích vợ cam
còn vợ mình thích đi chiều mình ôm

còn phanh có chả có ai
ngực căng e ấp bờ vai tròn tròn
còn gì vợ gái có con
chẳng làm quan vẫn làm mòn hết ai</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [None]:
def custom_index(list_token:list):
    """Collapse a flat boundary list into [first, fourth] pairs, one per group of four.

    ``[a0, a1, a2, a3, b0, b1, b2, b3, ...]`` becomes ``[[a0, a3], [b0, b3], ...]``.
    A trailing group with fewer than four entries raises IndexError.
    """
    return [[list_token[pos], list_token[pos + 3]] for pos in range(0, len(list_token), 4)]

In [None]:
def get_idx_two_line(lm_logits, newline_id: int = 11982):
    """Locate the two couplet spans of every well-formed stanza in a predicted sequence.

    The argmax prediction for batch item 0 is scanned for the newline token. Two
    consecutive newline tokens end a stanza. Inside a stanza exactly three single
    newlines (i.e. four lines) are expected; stanzas with any other count are skipped.

    Args:
        lm_logits: logits indexed as (batch, seq_len, vocab); only batch item 0 is used.
        newline_id: token id of the newline. Defaults to 11982, the id this notebook
            has always used for it — TODO confirm against the tokenizer.

    Returns:
        One entry per well-formed stanza, ``[[s1, e1], [s2, e2]]``, with token positions:
        ``[s1, e1)`` covers lines 1-2 and ``[s2, e2)`` covers lines 3-4.

    NOTE(review): the last stanza ends at the sequence length. Its second span therefore
    also covers whatever is predicted after the final line (e.g. end/pad tokens).
    """
    token = torch.argmax(lm_logits, dim=2)[0].tolist()

    # Stanza boundaries: a double newline at i closes one stanza at i and opens the next at i + 2.
    boundaries = [0]
    for i in range(len(token)):
        if token[i:i + 2] == [newline_id, newline_id]:
            boundaries.extend((i, i + 2))
    boundaries.append(len(token))
    stanzas = list(zip(boundaries[0::2], boundaries[1::2]))

    # Within each stanza, record [start, nl1, nl1+1, nl2, nl2+1, nl3, nl3+1, end];
    # the couplets are then [start, nl2) and [nl2+1, end).
    couplet_pairs = []
    for start, end in stanzas:
        bounds = [start]
        for pos in range(start, end):
            if token[pos] == newline_id:
                bounds.extend((pos, pos + 1))
        bounds.append(end)
        if len(bounds) != 8:
            continue
        couplet_pairs.append([[bounds[0], bounds[3]], [bounds[4], bounds[7]]])

    return couplet_pairs

In [None]:
# Couplet spans found in the logits from the single-example check: one
# [[start, end], [start, end]] entry (token positions, end exclusive) per well-formed stanza.
a = get_idx_two_line(lm_logits) 
a

[[[0, 15], [16, 31]],
 [[33, 48], [49, 64]],
 [[66, 81], [82, 97]],
 [[99, 114], [115, 140]]]

In [None]:
def loss_kho_tho(lm_logits, embedding, head=None):
    """Stanza ("khổ thơ") consistency loss for ONE sequence.

    For each well-formed stanza found by ``get_idx_two_line``, both couplet spans of
    ``embedding`` are encoded by ``head``. The two encodings are compared with mean
    squared error, and the per-stanza losses are summed.

    Args:
        lm_logits: logits for a single sequence without batch dim (seq_len, vocab).
        embedding: hidden states for the same sequence without batch dim (seq_len, hidden).
        head: span encoder; defaults to the global ``head_gpt``.

    Returns:
        0-dim tensor. It is zero when no well-formed stanza is found (previously the
        plain int 0).
    """
    head = head_gpt if head is None else head
    pair_list = get_idx_two_line(torch.unsqueeze(lm_logits, 0))
    embedding = torch.unsqueeze(embedding, 0)

    total_loss = embedding.new_zeros(())
    for (first_start, first_end), (second_start, second_end) in pair_list:
        embedd_one = head(embedding[:, first_start:first_end, :])
        embedd_two = head(embedding[:, second_start:second_end, :])
        total_loss = total_loss + F.mse_loss(embedd_one, embedd_two)

    return total_loss

In [None]:
# Use the second training batch (index 1) as a fixed example.
for i, batch in enumerate(train_loader):
    if i == 1:
        break
else:
    raise ValueError("train_loader has fewer than two batches")

# One forward pass yields both the final hidden states and the LM logits. Both therefore
# come from the same dropout sample, and the model is not run twice.
# NOTE(review): hidden_states[-1] is expected to equal poem.transformer(...)[0] — verify
# for the installed transformers version.
batch_outputs = poem(batch.to(device), output_hidden_states=True)
embedding = batch_outputs.hidden_states[-1]
lm_logits = batch_outputs.logits

start = time.time()
print(sum(loss_kho_tho(lm_logits[j], embedding[j]) for j in range(lm_logits.shape[0])))
end = time.time()
print(end - start)

tensor(0.0036, device='cuda:0', grad_fn=<AddBackward0>)
0.40593647956848145


In [None]:
# Make every parameter of the head and of GPT-2 trainable.
# The tensor attribute is `requires_grad`. The previous `require_grad` spelling only created
# an unused Python attribute and had no effect.
for module in (head_gpt, poem):
    for param in module.parameters():
        param.requires_grad_(True)

In [None]:
# Persist only the attention-LSTM head's weights (no optimizer state).
checkpoint = dict(state_dict=head_gpt.state_dict())
save_checkpoint(checkpoint, filename="Custom Loss/self_attetion_lstm.pth.tar")

Saving checkpoint


In [None]:
class TextGenerator():
    """Autoregressive sampler that extends a prompt token by token with a causal LM.

    Args:
        max_tokens: number of new tokens to generate.
        start_tokens: prompt token ids.
        maxlen: context window fed to the model. Once the sequence is longer, only
            the most recent ``maxlen`` tokens are used; shorter contexts are
            right-padded with id 0.
        model: callable returning an object with ``.logits`` indexed as
            (batch, position, vocab).
        tokenizer: object with ``decode(list_of_ids)``.
        device: device for the input tensor.
        topk: sample among the ``topk`` highest-scoring tokens (1 = greedy).
    """

    def __init__(self, max_tokens, start_tokens, maxlen, model, tokenizer,device, topk):
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.maxlen = maxlen
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.k = topk

    def sample_from(self, logits):
        """Draw one token id from the top-k entries of a 1-D logits vector.

        Candidates are weighted by their softmax probability. Previously every top-k
        candidate was equally likely, which ignored the model's scores.
        """
        top_logits, indices = torch.topk(logits, k=self.k, sorted=True)
        # float64 softmax so the probabilities sum to 1 within numpy's tolerance.
        probs = torch.softmax(top_logits.double(), dim=-1).cpu().numpy()
        return np.random.choice(indices.cpu().numpy(), p=probs)

    def gen_poem(self):
        """Generate exactly ``max_tokens`` tokens after the prompt, print the text and return it.

        The model runs under no_grad and, if it exposes ``training``/``eval``, in eval
        mode; its previous mode is restored afterwards.
        """
        tokens = list(self.start_tokens)
        tokens_generated = []
        was_training = getattr(self.model, "training", None)
        if was_training is not None:
            self.model.eval()
        try:
            with torch.no_grad():
                # `<` (was `<=`, which produced one token too many).
                while len(tokens_generated) < self.max_tokens:
                    # Most recent `maxlen` tokens (was the FIRST maxlen, which froze
                    # the context once the sequence outgrew the window).
                    window = tokens[-self.maxlen:]
                    sample_index = len(window) - 1
                    x = window + [0] * (self.maxlen - len(window))
                    x = torch.tensor([x], device=self.device)
                    y = self.model(x).logits
                    sample_token = self.sample_from(y[0][sample_index])
                    tokens_generated.append(sample_token)
                    tokens.append(sample_token)
        finally:
            if was_training is not None:
                self.model.train(was_training)

        poem = self.tokenizer.decode(list(self.start_tokens) + tokens_generated)
        print(f"generated text:\n{poem}\n")
        return poem

In [None]:
# Continue a short prompt greedily (topk=1) with the trained LM.
num_token_generated = 30
hint = 'mùa thu'
# Drop the final special token (presumably </s>) so the model keeps writing after the hint.
start_tokens = tokenizer.encode(hint)[:-1]
generator = TextGenerator(
    max_tokens=num_token_generated,
    start_tokens=start_tokens,
    maxlen=300,
    model=poem,
    tokenizer=tokenizer,
    device=device,
    topk=1,
)
generator.gen_poem();

generated text:
<s>mùa thu gợi lắm những gì
còn đang hò hẹn những khi sững sờ
mùa thu vắng tiếng trẻ thơ
đôi bờ môi khép sững sờ mắt nhau



