In [1]:
from transformers import AutoTokenizer, AutoConfig, AutoModel, DistilBertTokenizer
from datasets import load_dataset
from datasets import Dataset
import pandas as pd
import re
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers_tutorial.networks.attention_head import MultiHeadAttention, FeedForward
from transformers_tutorial.networks.trainer import Trainer

# Load data and preprocess

In [2]:
df_poem_raw = pd.json_normalize(pd.read_json("../data/verse_202412132333.json").iloc[:,0])

def preprocess(df):
    df_ = df.copy()
    diacritics_pattern = r'[\u064E\u064F\u0650\u0651\u0652\u0640]'
    
    df_['text'] = df_['text'].apply(lambda x: re.sub(diacritics_pattern, '', x))
    df_['verse_index'] = (df_['vorder']-1) // 2

    df_output = (
        df_.sort_values("position", ascending=True)
        .groupby(["poem_id", "verse_index"])["text"]
        .agg(lambda x: " - ".join(x.tolist())
            ).reset_index()
    )
    
    # df_output['text_reverse'] = df_output['text'].apply(lambda x: " ".join(reversed(x.strip().split(" "))))

    return df_output

In [3]:
df_prep = preprocess(df_poem_raw)

In [4]:
df_poem_raw.iloc[2:4]

Unnamed: 0,poem_id,vorder,position,text
2,700000,3,0,همچو شاهین به هوا جلوه کنان می گذرم
3,700000,4,1,تیزرو بالی و تازنده پری داده مرا


In [5]:
df_prep[['text']].iloc[1].values

array(['همچو شاهین به هوا جلوه کنان می گذرم - تیزرو  بالی و تازنده پری داده مرا'],
      dtype=object)

# Load tokenizer

In [6]:
tokenizer = AutoTokenizer.from_pretrained("mitra-mir/BERT-Persian-Poetry")

In [7]:
def encode_inputs(df_):
    return tokenizer(df_['text'].values.tolist(), padding=False)

def decode_tokens(tokens_, skip_special_tokens=False):
    decoded = tokenizer.batch_decode(tokens_, skip_special_tokens=skip_special_tokens)
    return decoded

In [8]:
tokens = encode_inputs(df_prep)

Make sure that tokens orders are correct.

In [9]:
tokens['input_ids'][1][:5], decode_tokens(tokens['input_ids'][1][:20]), decode_tokens(tokens['input_ids'][1:2])

([2, 2164, 1112, 10880, 1923],
 ['[CLS]',
  'همچ',
  '##و',
  'شاهین',
  'به',
  'هوا',
  'جلوه',
  'کنان',
  'می',
  'گذر',
  '##م',
  '-',
  'تیزر',
  '##و',
  'بالی',
  'و',
  'تاز',
  '##نده',
  'پری',
  'داده'],
 ['[CLS] همچو شاهین به هوا جلوه کنان می گذرم - تیزرو بالی و تازنده پری داده مرا [SEP]'])

## Generate training dataset

In [10]:
def generate_sequences(tokens_):
    output = []
    target = []
    for seq in tokens_['input_ids']:
        for ix in range(1, len(seq)):
            output += [torch.tensor(seq[:ix])]
            target.append(seq[ix])

    padded_tensor = pad_sequence(output, batch_first=True, padding_value=0)
    attention_mask_tensor = (padded_tensor != 0).int()
          
    return {"input_ids": padded_tensor, "attention_mask": attention_mask_tensor}, torch.tensor(target)

In [11]:
full_tokens, targets = generate_sequences(tokens)

In [12]:
full_tokens['input_ids'].shape

torch.Size([71097, 39])

Check if target is correct

In [13]:
decode_tokens(full_tokens['input_ids'][:10], skip_special_tokens=True), decode_tokens(targets[:20], skip_special_tokens=True)

(['',
  'خواب',
  'خواب دیدم',
  'خواب دیدم که',
  'خواب دیدم که خدا',
  'خواب دیدم که خدا بال',
  'خواب دیدم که خدا بال و',
  'خواب دیدم که خدا بال و پری',
  'خواب دیدم که خدا بال و پری داده',
  'خواب دیدم که خدا بال و پری داده مرا'],
 ['خواب',
  'دیدم',
  'که',
  'خدا',
  'بال',
  'و',
  'پری',
  'داده',
  'مرا',
  '-',
  'در',
  'هوا',
  'قوت',
  'سیر',
  'و',
  'سفری',
  'داده',
  'مرا',
  '',
  'همچ'])

In [14]:
full_tokens['attention_mask'][:10,:10]

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)

## Train / validation split

In [15]:
N_FULL_DATASET = targets.shape[0]
TRAIN_FRAC = 0.9
TRAIN_SIZE = int(N_FULL_DATASET * TRAIN_FRAC)

SEQ_LEN = full_tokens['input_ids'].shape[1]

In [16]:
train_data = Dataset.from_dict({key: val[:TRAIN_SIZE] for key, val in full_tokens.items()}).add_column("label", targets[:TRAIN_SIZE].numpy())
validation_data = Dataset.from_dict({key: val[TRAIN_SIZE:] for key, val in full_tokens.items()}).add_column("label", targets[TRAIN_SIZE:].numpy()) 

In [17]:
validation_data.set_format("pt"), train_data.set_format("pt")

(None, None)

In [18]:
validation_data.shape, train_data.shape

((7110, 3), (63987, 3))

# Decoder Transformers

In [131]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, vocab_size, n_heads, intermediate_dim, device, p_dropout=0.2, seq_len=None):
        super().__init__()

        self.device = device

        config = AutoConfig.from_pretrained("bert-base-uncased")
        config.vocab_size = vocab_size
        config.hidden_dropout_prob = p_dropout

        self.seq_len = seq_len
        
        if seq_len:
            config.max_position_embeddings = seq_len
            
        self.embeddings = AutoModel.from_config(config).embeddings

        hidden_dim = config.hidden_size
        
        self.multi_head_attention = MultiHeadAttention(
            emb_dim=config.hidden_size, hidden_dim=hidden_dim, n_heads=n_heads, is_decoder=True,
        )
        self.ff = FeedForward(
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim,
            p_dropout=p_dropout,
        )
        self.layer_norm_1 = nn.LayerNorm(hidden_dim)
        self.layer_norm_2 = nn.LayerNorm(hidden_dim)

        self.linear = nn.Linear(hidden_dim, vocab_size)


    def forward(self, input_):
        data = {
            k: input_[k].to(self.device)
            for k in input_.keys()
            if k in ["attention_mask", "input_ids"]
        }

        x = self.embeddings(data['input_ids'])
        
        residual = x
        
        x = residual + self.multi_head_attention(x, data["attention_mask"])
        x = self.layer_norm_1(x)

        residual = x
        x = residual + self.ff(x)
        x = self.layer_norm_2(x)       
        
        logits = self.linear(x)
        
        # Use hidden layer corresponding to last non [PAD] token.
        
        last_non_padded = data['attention_mask'].sum(dim=1) - 1 # To get index
        batch_size = last_non_padded.shape[0]
        return logits[torch.arange(batch_size), last_non_padded]

In [132]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [133]:
VOCAB_SIZE = tokenizer.vocab_size
INTERMEDIATE_DIM = 512 * 4
N_HEADS = 12

trasnformer_decoder = TransformerDecoderLayer(
    vocab_size=VOCAB_SIZE, 
    # hidden_dim=HIDDEN_DIM, 
    n_heads=N_HEADS, 
    intermediate_dim=INTERMEDIATE_DIM,
    seq_len=SEQ_LEN,
    device=device,
    p_dropout=0.1,
).to(device)

In [134]:
trasnformer_decoder

TransformerDecoderLayer(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(42000, 768, padding_idx=0)
    (position_embeddings): Embedding(39, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (multi_head_attention): MultiHeadAttention(
    (heads): ModuleList(
      (0-11): 12 x AttentionHead(
        (q): Linear(in_features=768, out_features=64, bias=True)
        (k): Linear(in_features=768, out_features=64, bias=True)
        (v): Linear(in_features=768, out_features=64, bias=True)
      )
    )
    (dense): Linear(in_features=768, out_features=768, bias=True)
  )
  (ff): FeedForward(
    (layers): Sequential(
      (layer_1): Linear(in_features=768, out_features=2048, bias=True)
      (gelu): GELU(approximate='none')
      (layer_2): Linear(in_features=2048, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
 

# Training

In [135]:
# for i in trasnformer_decoder.named_parameters():
#     if i[1].requires_grad:
#         print(i[0])

In [136]:
with torch.no_grad():
    _ = trasnformer_decoder(validation_data[:2])

decode_tokens(torch.softmax(_, dim=-1).argmax(-1))

['دینامو', '##قهر']

In [137]:
optimizer = torch.optim.AdamW(
    params={p for p in trasnformer_decoder.parameters() if p.requires_grad}, 
    lr=1e-5, weight_decay=0.01
)

loss_fn = nn.CrossEntropyLoss()


trainer = Trainer(optimizer=optimizer, loss=loss_fn, model=trasnformer_decoder)

In [None]:
BATCH_SIZE = 16
N_EPOCHS = 1

torch.cuda.empty_cache()

_ = trainer.train(train_data=train_data.select(range(0, 2048)), 
                  validation_data=validation_data.select(range(0, 1024)), 
                  n_epochs=N_EPOCHS,
                  batch_size=BATCH_SIZE)

Epoch 0:   5%|█▍                             | 6/128 [00:04<01:26,  1.41batch/s]

In [125]:
def generate_text(model_, tokenizer_, initial_text, max_length = 100):
    output = initial_text.split(" ")
    
    for _ in range(0, min(model_.seq_len, max_length)):
        current_text = " ".join(output)
        print(initial_text, end="")
        
        with torch.no_grad():
            tokens_ = tokenizer_(current_text, padding='max_length', truncation=True, max_length=model_.seq_len, return_tensors="pt")

            chosen_token = torch.softmax(model_(tokens_), dim=-1).argmax()
            next_word = tokenizer_.decode(chosen_token, skip_special_tokens=True)
            print(next_word, end="")
            output += next_word
            
    return " ".join(output)

In [129]:
_ = generate_text(trasnformer_decoder, tokenizer, "ابر")

ابروابروابروابروابروابر-ابر-ابر-ابر-ابروابروابروابروابروابروابرابرابروابرابرابرابرابرابرابرابرابرابروابرابرابرابرابرابرابرابرابرابرابر