In [None]:
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm
import re
import json
import matplotlib.pyplot as plt
import json
import math

import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

# Load Dataset

In [None]:
# Read the local wiki.*.tokens files with the "text" builder; the cells below
# access the resulting rows through their 'text' column.
wiki_files = {"train": ["./wiki.train.tokens"], "test": "./wiki.test.tokens"}
dataset = load_dataset("text", data_files=wiki_files)


In [None]:
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# Report the split shapes, test split first.
for split in (test_dataset, train_dataset):
  print(split.shape)

# Peek at the first 100 raw training rows.
train_dataset[0:100]

# Clean Dataset

In [None]:
def clean_data(data, min_chars=70):
  """Normalize a batch of raw lines and merge them into longer passages.

  Written for ``Dataset.map(clean_data, batched=True)``: ``data['text']`` is a
  list of strings and the returned ``'text'`` list usually has a different
  length than the input.

  Per line: lower-case, replace URLs with a space, drop the whitespace in
  front of punctuation, remove every character outside ``a-z0-9.,!?`` / space,
  and collapse whitespace. Consecutive cleaned lines are then concatenated
  until the running passage holds at least ``min_chars`` characters.

  Args:
    data: mapping with a ``'text'`` list of raw strings.
    min_chars: a passage is emitted once it has reached this many characters.

  Returns:
    ``{"text": [...]}`` with one entry per merged passage. The final,
    possibly shorter, passage of the batch is included as well.
  """
  cleaned_texts = []
  for text in data['text']:
    x = text.lower()
    # remove urls
    x = re.sub(r'https?://\S+|www\.\S+', ' ', x)
    # fix formatting: "hello , world" -> "hello, world"
    x = re.sub(r'\s([.,!?";:])', r'\1', x)
    # keep basic punctuation but remove every other symbol
    x = re.sub(r'[^a-zA-Z0-9.,!? \n]', '', x)
    # collapse runs of whitespace (including newlines) into one space
    x = re.sub(r'\s+', ' ', x).strip()
    cleaned_texts.append(x)

  cleaned = []
  sentence = ""
  for text in cleaned_texts:
    if len(sentence.strip()) < min_chars:
      sentence += " " + text
    else:
      cleaned.append(sentence.strip())
      sentence = text

  # Flush whatever is still buffered; without this the tail of every batch
  # (and every input that never reaches min_chars) was silently dropped.
  tail = sentence.strip()
  if tail:
    cleaned.append(tail)
  return {"text": cleaned}

In [None]:
len(test_dataset),len(train_dataset)

In [None]:
# Clean both splits batch-wise. clean_data merges lines, so each split ends up
# with a different number of rows than it started with.
test_dataset = test_dataset.map(clean_data,batched=True)
train_dataset = train_dataset.map(clean_data,batched=True)
# Row counts after cleaning (test, train).
len(test_dataset),len(train_dataset)


In [None]:
def _has_text(example):
  """Return True when the example's 'text' field is non-empty."""
  return len(example['text'])>0

# Drop rows whose text is empty.
train_dataset = train_dataset.filter(_has_text)
test_dataset = test_dataset.filter(_has_text)

len(test_dataset),len(train_dataset)

Filter:   0%|          | 0/11618 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2471 [00:00<?, ? examples/s]

(1, 1)

In [None]:
# clean_data was already mapped over both splits earlier in this notebook.
# Mapping it a second time is not idempotent (passages are re-merged batch by
# batch), so on a Restart & Run All this cell would change the data again.
# It therefore only reports the current sizes.
len(test_dataset),len(train_dataset)


Map:   0%|          | 0/2891 [00:00<?, ? examples/s]

Map:   0%|          | 0/13695 [00:00<?, ? examples/s]

(1, 1)

In [None]:
# Write one cleaned passage per line as plain text. The previous CSV export
# wrapped every passage containing a comma in double quotes, and those quote
# characters were then read back as part of the text by the "text" loader.
# clean_data collapses all whitespace, so passages never contain a newline.
for split, path in ((train_dataset, "cleaned_train.txt"), (test_dataset, 'cleaned_test.txt')):
  with open(path, "w", encoding="utf-8") as f:
    f.writelines(passage + "\n" for passage in split["text"])

In [None]:
# Both names are still plain Dataset objects here (they only become dicts with
# a 'train' key after the reload in the next cell), so indexing with ['train']
# would fail on a fresh top-to-bottom run.
len(train_dataset),len(test_dataset)

(11618, 2471)

In [None]:
# Reload the cleaned text files. Each file is loaded on its own, and the cells
# below read the rows of BOTH objects through the 'train' key
# (including test_dataset['train']).
train_dataset = load_dataset("text", data_files="cleaned_train.txt")
test_dataset = load_dataset("text",data_files="cleaned_test.txt")


In [None]:
len(train_dataset['train']),len(test_dataset['train'])

(13695, 2891)

# BERT INPUT PREPARATION - NSP & MLM USING CUSTOM TOKENIZER


In [None]:
SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[SEP]","[EOS]" ,"[MASK]"]
# Ids are the positions in SPECIAL_TOKENS. The previous hard-coded MASK = 3
# actually pointed at "[EOS]"; "[MASK]" lives at index 4.
PAD, CLS, SEP, EOS, MASK = range(len(SPECIAL_TOKENS))

def build_vocab(sentences):
    """Map every whitespace-separated, lower-cased word to an integer id.

    Special tokens take ids ``0 .. len(SPECIAL_TOKENS) - 1``; words follow in
    sorted order starting at ``len(SPECIAL_TOKENS)``.

    Args:
        sentences: iterable of strings.

    Returns:
        dict mapping token string -> id.
    """
    words = {word for sent in sentences for word in sent.lower().split()}
    offset = len(SPECIAL_TOKENS)
    vocab = {word: i + offset for i, word in enumerate(sorted(words))}
    # Special tokens are written last so they always keep their reserved ids.
    vocab.update({token: i for i, token in enumerate(SPECIAL_TOKENS)})
    return vocab

# The vocabulary is built from the training passages only; words that are not
# in it are mapped to the [PAD] id by prepare_bert_batch_manual.
vocab = build_vocab(train_dataset['train']['text'])
# Vocabulary size, special tokens included.
len(vocab)

67275

In [None]:
def prepare_bert_batch_manual(examples, max_len=128):
  """Turn a batch of passages into padded NSP + MLM training features.

  Relies on the module-level ``vocab`` and ``SPECIAL_TOKENS``.

  For every passage ``i`` a second passage is chosen: the following one
  (NSP label 1) with probability ~0.5, otherwise a random passage of the same
  batch (label 0). The pair is truncated to ``max_len - 3`` words, wrapped as
  ``[CLS] s1 [SEP] s2 [EOS]`` and padded to ``max_len``.

  About 15% of the non-special tokens become MLM targets: 80% of those are
  replaced by ``[MASK]``, 10% by a random vocabulary word, 10% stay as is.
  All other positions carry the label ``-100``.

  Args:
    examples: batch mapping with a ``'text'`` list of strings.
    max_len: length every sequence is padded / truncated to.

  Returns:
    dict of tensors: ``input_ids``, ``token_type_ids``, ``attention_mask``,
    ``labels`` (MLM) and ``next_sentence_label``.
  """
  texts = examples['text']
  n = len(texts)
  vocab_size = len(vocab)
  pad_id = vocab["[PAD]"]
  # Look the id up by name instead of trusting a separately maintained constant.
  mask_id = vocab["[MASK]"]
  first_word_id = len(SPECIAL_TOKENS)
  batch_input_ids, batch_segments, batch_masks, batch_mlm, batch_nsp = [], [], [], [], []
  for i in range(n):
    # NSP
    is_next = 1 if torch.rand(1) > 0.5 and i < n - 1 else 0
    # NOTE(review): the random partner can coincide with i or i + 1 and is
    # still labelled "not next"; rare for large batches, but it is label noise.
    partner = i + 1 if is_next else torch.randint(0, n, (1,)).item()
    s1 = texts[i].lower().split()
    s2 = texts[partner].lower().split()
    # Trim the longer side word by word until the pair fits.
    while len(s1) + len(s2) > max_len - 3:
      if len(s1) > len(s2):
        s1.pop()
      else:
        s2.pop()
    tokens = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[EOS]"]
    segments = [0] * (len(s1) + 2) + [1] * (len(s2) + 1)

    # MLM
    # NOTE(review): out-of-vocabulary words fall back to the [PAD] id, both as
    # input and as MLM label; a dedicated unknown token would be cleaner.
    input_ids = [vocab.get(t, pad_id) for t in tokens]
    mlm_labels = [-100] * len(input_ids)
    for idx, token in enumerate(tokens):
      if token in SPECIAL_TOKENS:
        continue
      if torch.rand(1) < 0.15:  # 15%
        mlm_labels[idx] = vocab.get(token, pad_id)
        rand_val = torch.rand(1)
        if rand_val < 0.8:      # 80% of 15%
          input_ids[idx] = mask_id
        elif rand_val < 0.9 and vocab_size > first_word_id:  # other 10% of 15%
          # Sample from the word ids of the vocabulary (the upper bound used
          # to be the batch size, which is unrelated to token ids).
          input_ids[idx] = torch.randint(first_word_id, vocab_size, (1,)).item()
        # remaining 10% of 15%: token is left unchanged

    padding_len = max_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * padding_len
    input_ids += [pad_id] * padding_len
    segments += [0] * padding_len
    mlm_labels += [-100] * padding_len
    batch_input_ids.append(input_ids)
    batch_masks.append(attention_mask)
    batch_segments.append(segments)
    batch_mlm.append(mlm_labels)
    batch_nsp.append(is_next)

  return {
    "input_ids": torch.tensor(batch_input_ids),
    "token_type_ids": torch.tensor(batch_segments),
    "attention_mask": torch.tensor(batch_masks),
    "labels": torch.tensor(batch_mlm),
    "next_sentence_label": torch.tensor(batch_nsp)
  }

# Data Loader

In [None]:
# Replace the raw text column of each reloaded dataset with the BERT features.
train_text_columns = train_dataset['train'].column_names
test_text_columns = test_dataset['train'].column_names

processed_train_dataset = train_dataset.map(
    prepare_bert_batch_manual,
    batched=True,
    remove_columns=train_text_columns,
)['train']
processed_test_dataset = test_dataset.map(
    prepare_bert_batch_manual,
    batched=True,
    remove_columns=test_text_columns,
)['train']

Map:   0%|          | 0/11618 [00:00<?, ? examples/s]

Map:   0%|          | 0/2471 [00:00<?, ? examples/s]

In [None]:
# Have both processed datasets hand out torch tensors.
for processed in (processed_train_dataset, processed_test_dataset):
  processed.set_format(type='torch')

In [None]:
# Both loaders use the same batch size and shuffle their data.
BATCH_SIZE = 16
train_loader = DataLoader(processed_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(processed_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7c876935e9c0>

In [None]:
test_loader

# BERT : Architecture

## BERT - Embedding Layer

In [None]:
class PositionalEmbedding(torch.nn.Module):
  """Fixed (non-trainable) sinusoidal position table.

  For position ``pos`` and even dimension index ``d``:
  ``pe[pos, d] = sin(pos / 10000 ** (d / d_model))`` and
  ``pe[pos, d + 1] = cos(pos / 10000 ** (d / d_model))``.

  The table is stored as the buffer ``pe`` with shape
  ``(1, max_len, d_model)``, so it follows the module across devices.
  """
  def __init__(self, d_model=768, max_len=128):
    super().__init__()
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    # Even dimension indices 0, 2, 4, ... These already are the "2i" of the
    # original formula, so they must not be doubled again.
    even_dims = torch.arange(0, d_model, 2, dtype=torch.float)
    inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
    angles = position * inv_freq

    pe = torch.zeros(max_len, d_model).float()
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one cosine column fewer than sine columns.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Return the first ``x.size(1)`` rows of the table, shape (1, seq, d_model)."""
    return self.pe[:, :x.size(1)]

In [None]:
class BERTEmbedding(nn.Module):
  """Sum of token, position and segment embeddings, then LayerNorm + dropout.

  Args:
    vocab_size: number of rows in the token embedding table.
    hidden_size: embedding width.
    max_len: number of positions in the positional table.
    n_segments: number of distinct segment ids.
  """
  def __init__(self, vocab_size, hidden_size, max_len, n_segments):
    super().__init__()
    self.tok_embed = nn.Embedding(vocab_size,hidden_size, padding_idx=0)
    self.pos_embed = PositionalEmbedding(hidden_size,max_len)
    # No padding_idx here: segment id 0 is the first sentence, not padding.
    # With padding_idx=0 its embedding would never receive a gradient.
    self.seg_embed = nn.Embedding(n_segments,hidden_size)
    self.norm = nn.LayerNorm(hidden_size)
    self.dropout = nn.Dropout(0.1)

  def forward(self, x, seg):
    """Embed token ids ``x`` and segment ids ``seg`` (same shape as ``x``)."""
    # pos_embed returns a (1, seq, hidden) slice that broadcasts over the batch.
    embed = self.tok_embed(x) + self.pos_embed(x) + self.seg_embed(seg)
    # self.norm was created but never applied before; normalize the sum
    # before dropout.
    return self.dropout(self.norm(embed))

## BERT Class

In [None]:
class BERT(nn.Module):
  """BERT-style encoder with a masked-LM head and a next-sentence head.

  Args:
    vocab_size: size of the token vocabulary (also the MLM output width).
    hidden_size: model width.
    n_layers: number of Transformer encoder layers.
    n_heads: attention heads per layer.
    max_len: maximum sequence length supported by the position table.
  """
  def __init__(self,vocab_size, hidden_size=768, n_layers=12, n_heads=12, max_len=128):
    super().__init__()
    self.embedding = BERTEmbedding(vocab_size,hidden_size,max_len,n_segments=2)

    encoder_layer = nn.TransformerEncoderLayer(
        d_model = hidden_size,
        nhead = n_heads,
        batch_first = True,
        activation = 'gelu',
        norm_first = True
    )

    self.encoder = nn.TransformerEncoder(
        encoder_layer,
        num_layers=n_layers
    )

    # NSP head: linear + tanh on the first position, then a 2-way classifier.
    self.fc = nn.Linear(hidden_size,hidden_size)
    self.activ = nn.Tanh()
    self.classifier = nn.Linear(hidden_size,2)
    # MLM head: per-position projection onto the vocabulary.
    self.mlm_head = nn.Linear(hidden_size, vocab_size)

  def forward(self, x, seg, attention_mask):
    """Run the encoder and both heads.

    Args:
      x: token ids, (batch, seq).
      seg: segment ids, same shape as ``x``.
      attention_mask: same shape as ``x``; entries equal to 0 mark padding.

    Returns:
      ``(mlm_output, nsp_logits)``: per-position vocabulary logits and 2-way
      logits computed from the first position.
    """
    x = self.embedding(x,seg)
    # True marks positions the encoder must ignore. The comparison is already
    # boolean and lives on the mask's own device, so no global `device` is
    # needed here.
    src_key_padding_mask = attention_mask == 0
    encoded = self.encoder(x, src_key_padding_mask = src_key_padding_mask)
    # MLM
    mlm_output = self.mlm_head(encoded)
    # NSP
    cls_token = encoded[:,0]
    cls_output = self.activ(self.fc(cls_token))
    nsp_logits = self.classifier(cls_output)
    return mlm_output, nsp_logits

  def _init_weights(self,module):
    """Initializer meant to be passed to ``model.apply``.

    Linear and Embedding weights are drawn from N(0, 0.02); Linear biases are
    zeroed. The padding row of an Embedding is reset to zero afterwards,
    because the normal draw would otherwise overwrite it.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    if isinstance(module, nn.Embedding) and module.padding_idx is not None:
        module.weight.data[module.padding_idx].zero_()
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

# Training BERT

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = len(vocab)

# Build the model on the target device, then apply the custom initializer.
model = BERT(vocab_size).to(device)
model.apply(model._init_weights)

# -100 is the label value used for positions that are not MLM targets.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-5,weight_decay=0.01)



In [None]:
def train_and_save(model,train_loader,optimizer,device,epochs=5,criterion=None):
  """Pre-train ``model`` on MLM + NSP, then save weights and loss history.

  Args:
    model: called as ``model(input_ids, token_type_ids, attention_mask)`` and
      expected to return ``(mlm_logits, nsp_logits)``.
    train_loader: iterable of dict batches with the keys ``input_ids``,
      ``token_type_ids``, ``attention_mask``, ``labels`` and
      ``next_sentence_label``.
    optimizer: optimizer that steps ``model``'s parameters.
    device: device every batch tensor is moved to.
    epochs: number of passes over ``train_loader``.
    criterion: loss used for both heads; defaults to
      ``nn.CrossEntropyLoss(ignore_index=-100)``.

  Returns:
    dict with the per-epoch mean ``total_loss``, ``mlm_loss`` and ``nsp_loss``.

  Side effects:
    Writes ``bert_pretrained2.pth`` (state dict) and ``loss_history2.json``
    into the working directory once all epochs are done.
  """
  # Own default instead of a hidden dependency on a notebook-level global.
  if criterion is None:
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
  ignore_index = getattr(criterion, "ignore_index", -100)

  history = {
      "total_loss":[],
      "mlm_loss":[],
      "nsp_loss":[]
  }
  model.train()
  print("Starting training")
  # Guard the per-epoch averages against an empty loader.
  n_batches = max(len(train_loader), 1)

  for epoch in range(epochs):
    epoch_total,epoch_mlm,epoch_nsp = 0.0,0.0,0.0
    loop = tqdm(train_loader,leave=True)
    loop.set_description(f"Epoch {epoch+1}/{epochs}")
    for batch in loop:
      optimizer.zero_grad()
      input_ids = batch['input_ids'].to(device)
      token_type_ids = batch['token_type_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      mlm_labels = batch['labels'].to(device)
      nsp_labels = batch['next_sentence_label'].to(device)

      mlm_logits,nsp_logits = model(input_ids,token_type_ids,attention_mask)
      if (mlm_labels != ignore_index).any():
        # The class dimension comes from the logits, not a global vocab_size.
        loss_mlm = criterion(mlm_logits.view(-1,mlm_logits.size(-1)),mlm_labels.view(-1))
      else:
        # No masked token in this batch: the mean over zero targets would be
        # NaN and poison every weight, so contribute an explicit zero instead.
        loss_mlm = mlm_logits.sum() * 0.0
      loss_nsp = criterion(nsp_logits, nsp_labels)
      loss = loss_mlm + loss_nsp

      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
      optimizer.step()

      epoch_total += loss.item()
      epoch_mlm += loss_mlm.item()
      epoch_nsp += loss_nsp.item()
      loop.set_postfix(mlm=loss_mlm.item(), nsp=loss_nsp.item(),loss=loss.item())
    history['total_loss'].append(epoch_total/n_batches)
    history['mlm_loss'].append(epoch_mlm/n_batches)
    history['nsp_loss'].append(epoch_nsp/n_batches)

  torch.save(model.state_dict(),"bert_pretrained2.pth")
  with open("loss_history2.json", "w") as f:
    json.dump(history, f)
  print("Model saved")
  return history

In [None]:
history = train_and_save(model,train_loader,optimizer,device)

In [None]:
def plot_losses(history):
  """Plot the per-epoch MLM, NSP and total loss curves from ``history``.

  ``history`` is a dict with equally long lists under the keys
  ``"mlm_loss"``, ``"nsp_loss"`` and ``"total_loss"``.
  """
  n_epochs = len(history["total_loss"])
  epochs = range(1, n_epochs + 1)

  # (history key, line format, legend label), drawn in this order
  curves = (
      ("mlm_loss", 'b-o', 'MLM Loss'),
      ("nsp_loss", 'r-o', 'NSP Loss'),
      ("total_loss", 'g--', 'Total Loss'),
  )

  plt.figure(figsize=(10, 6))
  for key, fmt, label in curves:
    plt.plot(epochs, history[key], fmt, label=label)

  plt.title('BERT Pre-training Loss')
  plt.xlabel('Epochs')
  plt.ylabel('Loss')
  plt.legend()
  plt.grid(True)
  plt.show()

plot_losses(history)

In [None]:
# Rebuild the model with the vocabulary used throughout this notebook
# (no `tokenizer` object is ever created) and restore the artifacts written
# by train_and_save.
model = BERT(vocab_size=len(vocab), hidden_size=768, n_layers=12)
model.load_state_dict(torch.load("bert_pretrained2.pth", map_location=device))
model.to(device)
model.eval()

# The optimizer created earlier still points at the previous model's
# parameters; rebind it so any further training updates this model.
optimizer = torch.optim.Adam(model.parameters(),lr=1e-5,weight_decay=0.01)

with open("loss_history2.json", "r") as f:
    history = json.load(f)

In [None]:
history = train_and_save(model,train_loader,optimizer,device)

Starting training


  0%|          | 0/856 [00:00<?, ?it/s]

  0%|          | 0/856 [00:00<?, ?it/s]

  0%|          | 0/856 [00:00<?, ?it/s]

  0%|          | 0/856 [00:00<?, ?it/s]

  0%|          | 0/856 [00:00<?, ?it/s]

Model saved
