In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import pandas as pd
import re
from transformers import BertTokenizer
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel
import torch.optim as optim
import numpy as np
from tqdm import tqdm

In [None]:
# Read both training shards, then stack them row-wise into one frame.
# Row labels are left exactly as loaded (no index reset).
text1 = pd.read_parquet(path='train_data1.parquet')
text2 = pd.read_parquet(path='train_data2.parquet')
text = pd.concat(objs=(text1, text2), axis="index")

In [None]:
def clean_text(text):
    """Normalise one raw text sample before tokenisation.

    The sample is lowercased, every character other than ASCII letters,
    digits, whitespace and the punctuation ``. , ! ? '`` is removed, runs of
    whitespace are collapsed to a single space, and the ends are stripped.

    Args:
        text: The raw sample. Anything that is not a ``str`` (for example the
            ``None``/``NaN`` a missing row yields) is treated as empty
            instead of raising ``AttributeError`` on ``.lower()``.

    Returns:
        The cleaned string; ``""`` when nothing usable remains.
    """
    # Missing values arrive as None / float('nan'); map them to the empty
    # string so the caller's "drop empty rows" filter removes them.
    if not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9.,!?\'\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [None]:
# Clean every sample, then keep only the rows whose text is still non-empty.
text['text'] = text['text'].apply(clean_text)
has_content = text['text'].ne("")
cleaned_text = text[has_content]
text = cleaned_text

In [None]:
# --- Runtime -----------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible

# --- Data / batching ---------------------------------------------------------
batch_size = 32     # samples per DataLoader batch
block_size = 128    # side length of the causal mask in Head; context window kept by generate()
vocab_size = 30522  # output width of lm_head; presumably the bert-base-uncased vocabulary size (confirm)

# --- Optimisation ------------------------------------------------------------
learning_rate = 2e-5  # step size handed to AdamW
max_iters = 10000     # upper bound of the training loop's range()
eval_interval = 1000  # not read by any cell visible in this notebook
eval_iters = 500      # not read by any cell visible in this notebook

# --- Model -------------------------------------------------------------------
n_embd = 768   # feature width of the blocks/LSTM/GRU; BERT's hidden states are fed in directly, so it must match
n_head = 12    # attention heads per Block (head width is n_embd // n_head)
n_layer = 2    # number of stacked Blocks
dropout = 0.3  # dropout probability used by every nn.Dropout in the custom layers


In [None]:
# Tokenizer for the "bert-base-uncased" checkpoint, looked up by name
# (presumably downloaded or read from the local cache on first use).
# It is used both to encode the training texts and to decode generated ids.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
# Cleaned samples as a plain Python list, plus the two accumulators that the
# tokenisation loop fills with encoded rows.
text_list = text['text'].to_list()
input_ids_inputs = []
input_ids_targets = []

In [None]:
# Tokenise the corpus chunk by chunk and build next-token training pairs.

# Start from empty accumulators so re-running this cell does not append a
# second copy of the data.
input_ids_inputs, input_ids_targets = [], []

tokenize_chunk_size = 1000  # texts encoded per tokenizer call
tokenize_max_length = 50    # every text is truncated/padded to this many token ids

for i in tqdm(range(0, len(text_list), tokenize_chunk_size), desc="Tokenizing", unit="batch"):
    # The slice width must equal the range() stride; a narrower slice would
    # silently skip every text between the end of the slice and the next stride.
    batch_texts = text_list[i:i + tokenize_chunk_size]
    tokens = tokenizer.batch_encode_plus(
        batch_texts,
        truncation=True,
        padding="max_length",
        max_length=tokenize_max_length,
        return_tensors="np",
        add_special_tokens=False
    )
    ids = tokens["input_ids"]  # shape: (texts in chunk, tokenize_max_length)
    # Shift along the token axis (axis 1) so the target at position t is the
    # token at position t + 1 of the SAME text. Slicing axis 0 instead would
    # pair each text with the following, unrelated text.
    input_ids_inputs.extend(ids[:, :-1])
    input_ids_targets.extend(ids[:, 1:])

In [None]:
# Stack the accumulated id rows into one array per side (inputs / targets).
input_ids_inputs_np = np.asarray(input_ids_inputs)
input_ids_targets_np = np.asarray(input_ids_targets)

In [None]:
# Wrap the numpy arrays as int64 tensors (token ids must be integer indices).
input_ids_inputs_tensor = torch.from_numpy(input_ids_inputs_np).to(torch.long)
input_ids_targets_tensor = torch.from_numpy(input_ids_targets_np).to(torch.long)

In [None]:
# Pair each input row with its target row and serve them in shuffled batches.
dataset = TensorDataset(input_ids_inputs_tensor, input_ids_targets_tensor)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
class Head(nn.Module):
    """A single causal self-attention head.

    The input is projected to key, query and value spaces of width
    ``head_size``. Attention from a position to any later position is masked
    out with a lower-triangular buffer, and the attention-weighted values are
    returned.
    """

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections from the model width (global n_embd) to this head's width.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular matrix of ones, registered as a buffer so it moves
        # with the module (.to(device)) without becoming a trainable parameter.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked scaled dot-product attention to a 3-D input ``x``.

        The middle dimension is treated as the sequence axis and must not
        exceed ``block_size``, the size of the stored mask.
        """
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        # Scale the scores by 1/sqrt(head width).
        scale = keys.shape[-1] ** -0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale
        # Zero entries of the mask mark later positions; give them -inf so
        # softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        values = self.value(x)
        return torch.matmul(weights, values)

class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run side by side, then mixed by a linear layer.

    The head outputs are concatenated on the last axis, projected back to
    the model width (global ``n_embd``) and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        parallel_heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(parallel_heads)
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected, dropout-regularised concatenation of all heads."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x the width, ReLU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        # Layer order (and therefore the Sequential indices) is part of the
        # state_dict layout, so it is kept as Linear, ReLU, Linear, Dropout.
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps size ``n_embd``."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each behind a residual."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Split the model width evenly across the heads.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """LayerNorm is applied before each sub-layer; its output is added back to the input."""
        after_attention = x + self.sa(self.ln1(x))
        after_mlp = after_attention + self.ffwd(self.ln2(after_attention))
        return after_mlp


In [None]:
class BERT_LSTM_GRU(nn.Module):
    """Token-level language model stacked on a pretrained BERT encoder.

    Pipeline: ``bert-base-uncased`` hidden states -> ``n_layer`` causal
    attention ``Block``s -> LSTM -> GRU -> LayerNorm -> linear head producing
    ``vocab_size`` logits per position.
    """

    def __init__(self):
        super().__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased")
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.lstm = nn.LSTM(n_embd, n_embd, batch_first=True)
        self.gru = nn.GRU(n_embd, n_embd, batch_first=True)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Initialise only the freshly created layers. self.apply(...) would
        # also recurse into self.embedding and overwrite the pretrained BERT
        # Linear/Embedding weights that were just loaded with random values.
        for child_name, child in self.named_children():
            if child_name != "embedding":
                child.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: Integer token ids, indexed as (batch, time).
            targets: Optional token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None``. With targets,
            ``logits`` is returned flattened to ``(B*T, vocab_size)`` and
            ``loss`` is the mean cross-entropy over all positions.
        """
        # NOTE(review): no attention_mask is passed, so BERT also attends to
        # padding positions, and BertModel is presumably bidirectional, which
        # would let position t see token t+1 - confirm this is intended for a
        # next-token objective.
        embedding = self.embedding(idx).last_hidden_state
        x = self.blocks(embedding)
        x, _ = self.lstm(x)  # keep the output sequence, drop the final state
        x, _ = self.gru(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits against (N,) class indices.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # NOTE(review): every position counts towards the loss, padding
            # included; consider ignore_index if padded targets should be skipped.
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [None]:
# Build the model, move it to the selected device, and report its size.
model = BERT_LSTM_GRU()
model = model.to(device)
param_count = sum(param.numel() for param in model.parameters())
print(param_count / 1e6, 'M parameters')

In [None]:
# AdamW over every model parameter (pretrained encoder included).
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

In [None]:
# Training: one full pass over the dataloader.
#
# The previous form iterated over range(max_iters) but unconditionally broke
# out after the first pass; that one-pass behaviour is kept, just stated
# explicitly. The loop variable is named `epoch` so the builtin `iter` is not
# shadowed for every later cell in the notebook.
num_epochs = min(1, max_iters)
log_every = 50  # print a running mean every this many batches

model.train()  # make sure dropout is active even if the model was switched to eval earlier

for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0

    # Running window used for the periodic progress line.
    batch_loss = 0.0
    batch_count = 0

    for i, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()

        # .item() copies the scalar to the host; read it once per step.
        step_loss = loss.item()
        total_loss += step_loss
        num_batches += 1
        batch_loss += step_loss
        batch_count += 1

        if i > 0 and i % log_every == 0:
            batch_mean_loss = batch_loss / batch_count
            print(f"Epoch: {epoch + 1} Batch: {i} Loss: {batch_mean_loss:.4f}")
            batch_loss = 0.0
            batch_count = 0

    if num_batches == 0:
        # An empty dataloader would otherwise raise ZeroDivisionError below.
        print(f"Epoch: {epoch + 1}, no batches were produced by the dataloader")
    else:
        mean_loss = total_loss / num_batches
        print(f"Epoch: {epoch + 1}, Mean Loss: {mean_loss:.4f}")



In [None]:
def generate(self, idx, max_new_tokens):
    """Sample ``max_new_tokens`` continuation tokens and print the decoded text.

    Args:
        self: The model to sample from; calling it with a batch of ids must
            return ``(logits, loss)`` with logits indexed (batch, time, vocab).
        idx: Integer tensor of prompt token ids, indexed (batch, time).
        max_new_tokens: Number of tokens to append to every row.

    Returns:
        The id tensor with the sampled tokens appended along the time axis.
        One decoded line per batch row is also printed.
    """
    # Sample in eval mode (no dropout) and without recording gradients, then
    # restore whatever train/eval mode the model was in.
    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Only the last block_size tokens fit under the attention mask.
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]  # logits for the final position only
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
    finally:
        self.train(was_training)

    # idx is 2-D, so tolist() yields one list of ids per row; decode() expects
    # a flat id sequence, hence the per-row loop.
    for row in idx.tolist():
        print(tokenizer.decode(row))
    return idx