<a href="https://colab.research.google.com/github/lazarmiovcic/nanoBertic/blob/master/nanoBerti%C4%87.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only: mount Google Drive at /content/drive; later cells read the corpus
# and write checkpoints under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

# Confirm that the directory holding the text batches is visible on the mounted Drive
data_root = '/content/drive/Othercomputers/My_Laptop/all_text_batches'
print(os.path.exists(data_root))

# Peek at the first few training batch files (train_dir is reused by later cells)
train_dir = '/content/drive/Othercomputers/My_Laptop/all_text_batches/train'
if not os.path.exists(train_dir):
    print(f"Directory not found: {train_dir}")
else:
    first_files = os.listdir(train_dir)[:5]
    print(first_files)

True
['train_batch_1229.txt', 'train_batch_1230.txt', 'train_batch_1231.txt', 'train_batch_1234.txt', 'train_batch_1242.txt']


In [3]:
import bisect
import math
import random
from pathlib import Path

import numpy as np
import torch
import tqdm
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

In [4]:
# --- Model size ---
n_embd = 64      # embedding / hidden width (d_model)
n_layers = 2     # number of encoder blocks
n_heads = 2      # attention heads per block
dropout = 0.1    # dropout probability

# --- Data ---
max_len = 128    # tokens per sample after truncation / padding
batch_size = 64

# --- Hardware ---
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
class BertDataset_v1(Dataset):
    """Map-style dataset producing BERT pre-training samples (MLM + NSP).

    Every ``*.txt`` file found recursively under ``path_to_data_dir`` is treated
    as a sequence of sentences, one per line. A sample is built from line
    ``index`` (sentence A) plus a second sentence B that is either the following
    line (label 1) or a random line of the same file (label 0).

    Args:
        path_to_data_dir: directory that is searched recursively for ``*.txt`` files.
        tokenizer: callable returning ``{'input_ids': [...]}`` with one leading and
            one trailing special token, and exposing a ``vocab`` mapping that
            contains ``[PAD]``, ``[CLS]``, ``[SEP]`` and ``[MASK]``.
        seq_len: fixed length every returned sequence is truncated / padded to.

    Raises:
        ValueError: if the files contain no lines at all.
    """

    def __init__(self, path_to_data_dir, tokenizer, seq_len,):
        # Sorted so the index -> (file, line) mapping is reproducible between runs;
        # the order returned by glob() is filesystem dependent.
        self.paths = sorted(str(x) for x in Path(path_to_data_dir).glob('**/*.txt'))
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.pad_token_id = self.tokenizer.vocab['[PAD]']  # Store PAD ID for convenience

        self.total_samples = 0
        self.file_line_counts = []  # Stores (file_path, num_lines_in_file)
        self.cumulative_line_offsets = [0]  # cumulative_line_offsets[i] = global index of file i's first line

        print("\n[Dataset Init] Counting lines across all batch files...")
        for p in self.paths:
            with open(p, 'r', encoding='utf-8') as f:
                num_lines_in_file = sum(1 for _ in f)  # Count lines without keeping them in memory
            self.file_line_counts.append((p, num_lines_in_file))
            self.total_samples += num_lines_in_file
            self.cumulative_line_offsets.append(self.total_samples)

        if self.total_samples == 0:
            raise ValueError(f"No lines found in any files in {path_to_data_dir}. Check your data.")
        print(f"[Dataset Init] Found {self.total_samples} total samples across {len(self.paths)} files.")

    def __len__(self):
        """Total number of lines across all files (one sample per line)."""
        return self.total_samples

    def __getitem__(self, index):
        """Build one training sample.

        Returns:
            dict of tensors with keys ``input_ids``, ``input_label`` and
            ``segment_label`` (each of length ``seq_len``) and the scalar ``is_next``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not (0 <= index < self.total_samples):
            raise IndexError(f"Index {index} is out of bounds for dataset of size {self.total_samples}")

        # Step 1: sentence pair, either consecutive (is_next_label == 1) or random (0)
        t1_str, t2_str, is_next_label = self.get_sent(index)

        # The tokenizer wraps a single sentence in [CLS] ... [SEP]; [1:-1] drops both
        # so they can be re-added by hand after masking.
        t1_token_ids = self.tokenizer(t1_str)['input_ids'][1:-1]
        t2_token_ids = self.tokenizer(t2_str)['input_ids'][1:-1]

        # Step 2: replace random tokens with [MASK] / random tokens
        t1_random, t1_label = self.random_word(t1_token_ids)
        t2_random, t2_label = self.random_word(t2_token_ids)

        # Step 3: add [CLS] / [SEP]; their label positions get PAD (nothing to predict)
        cls_id = self.tokenizer.vocab['[CLS]']
        sep_id = self.tokenizer.vocab['[SEP]']
        pad_id = self.pad_token_id

        t1_final = [cls_id] + t1_random + [sep_id]
        t2_final = t2_random + [sep_id]  # sentence B carries no [CLS] of its own

        t1_label_final = [pad_id] + t1_label + [pad_id]
        t2_label_final = t2_label + [pad_id]

        # Step 4: concatenate both sentences and cut to seq_len.
        # Segment ids: 1 for sentence A, 2 for sentence B, 0 for padding.
        segment_label = [1] * len(t1_final) + [2] * len(t2_final)

        bert_input = (t1_final + t2_final)[:self.seq_len]
        bert_label = (t1_label_final + t2_label_final)[:self.seq_len]
        segment_label = segment_label[:self.seq_len]

        # Never negative: bert_input was truncated to seq_len just above.
        padding_length = self.seq_len - len(bert_input)

        bert_input.extend([pad_id] * padding_length)
        bert_label.extend([pad_id] * padding_length)
        # Segment padding is always 0 - it is a segment id, not a token id, so it
        # must not follow the tokenizer's [PAD] id.
        segment_label.extend([0] * padding_length)

        output = {"input_ids": bert_input,
                  "input_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, token_ids):
        """Apply BERT-style masking to a list of token ids.

        Each token is selected with probability 0.15. A selected token becomes
        ``[MASK]`` (80%), a random vocabulary id (10%) or stays unchanged (10%).

        Returns:
            (output, output_label): the corrupted ids, and per position the
            original id for selected tokens or the PAD id otherwise.
        """
        output = []
        output_label = []
        for tok_id in token_ids:
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15  # rescale to [0, 1) to choose the corruption type
                if prob < 0.8:  # 80% chance change token to mask token
                    output.append(self.tokenizer.vocab['[MASK]'])
                elif prob < 0.9:  # 10% chance change token to random token
                    output.append(random.randrange(len(self.tokenizer.vocab)))
                else:  # 10% chance keep the current token
                    output.append(tok_id)
                output_label.append(tok_id)  # original id is the prediction target
            else:  # 85% chance: token left unchanged
                output.append(tok_id)
                output_label.append(self.pad_token_id)  # PAD label = do not predict
        return output, output_label

    def get_file_lines(self, path):
        """Read ``path`` (UTF-8) and return its lines, newlines included."""
        with open(path, 'r', encoding="utf-8") as f:
            return f.readlines()

    def get_sent(self, index):
        """Return ``(sentence_a, sentence_b, is_next)`` for global line ``index``.

        With probability 0.5 sentence B is the line following A (the first line of
        the next file when A closes its file); otherwise, and for the very last
        line of the corpus, B is a random line of A's file and ``is_next`` is 0.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not (0 <= index < self.total_samples):
            raise IndexError(f"Index {index} is out of bounds for dataset of size {self.total_samples}")

        # Binary search over the cumulative offsets; bisect_right also steps over
        # empty files, which show up as repeated offsets.
        file_idx = bisect.bisect_right(self.cumulative_line_offsets, index) - 1
        line_idx_in_file = index - self.cumulative_line_offsets[file_idx]

        lines = self.get_file_lines(self.paths[file_idx])
        t1_str = lines[line_idx_in_file]

        prob = random.random()
        if prob > 0.5:  # Positive pair
            if line_idx_in_file < self.file_line_counts[file_idx][1] - 1:
                # Next line of the same file
                return t1_str, lines[line_idx_in_file + 1], 1
            if file_idx < len(self.paths) - 1:
                # Last line of this file: continue with the first line of the next one
                with open(self.paths[file_idx + 1], 'r', encoding='utf-8') as next_file:
                    t2_str = next_file.readline()
                return t1_str, t2_str, 1
            # Last line of the entire corpus: no successor exists, fall back to a negative pair
            return t1_str, lines[random.randrange(len(lines))], 0

        # Negative pair
        return t1_str, lines[random.randrange(len(lines))], 0

In [11]:
class PositionEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal position encodings.

    For position ``pos`` and even channel ``i`` the table holds
    ``sin(pos / 10000**(i / d_model))`` at channel ``i`` and the cosine of the
    *same* angle at channel ``i + 1``, so each sin/cos pair shares one frequency.

    Args:
        d_model: width of the encoding; odd values are supported (the last
            channel then only holds the sine term).
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=128):
        super().__init__()

        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        even_channels = torch.arange(0, d_model, 2, dtype=torch.float32)     # (ceil(d_model / 2),)
        # 1 / 10000**(i / d_model), computed in log-space for numerical stability
        inv_freq = torch.exp(even_channels * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq                                        # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine channel fewer than sine channels
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: follows .to(device) and is stored in the state_dict, but is not a parameter
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, seq_len=None):
        """Return the encoding table of shape ``(1, max_len, d_model)``.

        Args:
            seq_len: if given, only the first ``seq_len`` positions are returned.
        """
        if seq_len is None:
            return self.pe
        return self.pe[:, :seq_len]


class BERTEmbeddings(nn.Module):
    """Sum of token, segment and positional embeddings followed by dropout.

    Args:
        d_model: embedding width.
        vocab_size: number of rows of the token embedding table.
        seq_len: maximum sequence length the positional table is built for.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, d_model, vocab_size, seq_len, dropout=0.1):
        super().__init__()
        # Row 0 is the padding row of both tables: it is initialised to zeros and
        # receives no gradient (padding_idx=0).
        self.token = nn.Embedding(vocab_size, d_model, padding_idx=0)    # table: (vocab_size, d_model)
        self.segment = nn.Embedding(3, d_model, padding_idx=0)           # table: (3, d_model) - ids 0 / 1 / 2
        self.position = PositionEmbedding(d_model, seq_len)              # table: (1, seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, ids, segment_label):
        """Embed a batch.

        Args:
            ids: integer tensor indexed as (batch, time) holding token ids.
            segment_label: integer tensor of the same shape holding segment ids.

        Returns:
            Tensor of shape (batch, time, d_model).
        """
        # Slice the positional table to the actual input length so batches shorter
        # than the configured maximum are accepted as well.
        positions = self.position()[:, :ids.size(1)]
        x = self.token(ids) + self.segment(segment_label) + positions
        return self.dropout(x)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a key mask.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.n_heads = n_heads
        self.head_size = d_model // n_heads

        # All four projections map d_model -> n_heads * head_size (== d_model)
        self.key = nn.Linear(d_model, d_model)
        self.query = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch):
        """(B, T, n_heads * head_size) -> (B, n_heads, T, head_size)."""
        return projected.view(batch, -1, self.n_heads, self.head_size).transpose(1, 2)

    def forward(self, x, mask):
        """Attend over ``x``.

        Args:
            x: tensor of shape (B, T, D).
            mask: tensor broadcastable to (B, n_heads, T, T); positions where it
                equals 0 are excluded from the softmax.

        Returns:
            Tensor of shape (B, T, n_heads * head_size).
        """
        batch, steps, _ = x.shape

        keys = self._split_heads(self.key(x), batch)
        queries = self._split_heads(self.query(x), batch)
        values = self._split_heads(self.value(x), batch)

        # (B, n_heads, T, head_size) @ (B, n_heads, head_size, T) -> (B, n_heads, T, T)
        scale = 1.0 / math.sqrt(self.head_size)
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Masked positions get -inf so they receive zero weight after the softmax
        scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = self.dropout(nn.functional.softmax(scores, -1))

        # (B, n_heads, T, T) @ (B, n_heads, T, head_size) -> (B, n_heads, T, head_size)
        context = weights @ values
        merged = context.transpose(1, 2).contiguous().view(batch, steps, self.n_heads * self.head_size)

        return self.proj(merged)


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * d_model, GELU, project back, dropout."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        hidden_width = 4 * d_model
        layers = [
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; the shape is preserved."""
        return self.mlp(x)


class Block(nn.Module):
    """Pre-LayerNorm transformer encoder block (self-attention + feed-forward).

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        dropout: dropout probability; it is forwarded to the attention and
            feed-forward sub-layers as well, so a non-default value takes
            effect everywhere instead of only on the residual branches.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ffwd = FeedForward(d_model, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block.

        Args:
            x: tensor of shape (batch_size, max_len, d_model).
            mask: attention mask passed through unchanged to the self-attention
                layer (BERT.forward builds it as (batch_size, 1, 1, max_len)).

        Returns:
            Tensor of shape (batch_size, max_len, d_model).
        """
        x = x + self.dropout(self.sa(self.ln1(x), mask))
        x = x + self.dropout(self.ffwd(self.ln2(x)))
        # NOTE(review): this drops elements of the residual stream itself, on top of
        # the per-branch dropout above; kept to preserve the existing training behaviour.
        return self.dropout(x)


class MaskedLanguageModel(nn.Module):
    """MLM head: linear projection to the vocabulary followed by log-softmax."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.lin = nn.Linear(d_model, vocab_size)
        self.softmax = nn.LogSoftmax(-1)

    def forward(self, x):
        """Map (batch_size, seq_len, d_model) to log-probabilities (batch_size, seq_len, vocab_size)."""
        vocab_logits = self.lin(x)
        log_probs = self.softmax(vocab_logits)
        return log_probs

class NextSentencePrediction(nn.Module):
    """NSP head: binary classifier applied to the first ([CLS]) position."""

    def __init__(self, d_model):
        super().__init__()
        self.lin = nn.Linear(d_model, 2)
        self.softmax = nn.LogSoftmax(-1)

    def forward(self, x):
        """Map (B, T, D) to log-probabilities of shape (B, 2)."""
        cls_state = x[:, 0, :]  # only the first token, i.e. [CLS], is used
        class_logits = self.lin(cls_state)
        return self.softmax(class_logits)


class BERT(nn.Module):
    """Small BERT encoder with masked-language-model and next-sentence heads.

    Args:
        n_layers: number of encoder blocks.
        vocab_size: size of the token vocabulary.
        d_model: model width.
        n_heads: attention heads per block.
        max_len: maximum sequence length (size of the positional table).
        dropout: dropout probability.
    """

    def __init__(self, n_layers, vocab_size, d_model, n_heads, max_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embed = BERTEmbeddings(d_model, vocab_size, max_len, dropout)
        encoder_layers = [Block(d_model, n_heads, dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(encoder_layers)
        self.mlm = MaskedLanguageModel(d_model, vocab_size)
        self.nsp = NextSentencePrediction(d_model)

    def forward(self, ids, segment_label):
        """Encode a batch and apply both heads.

        Args:
            ids: token ids of shape (batch_size, seq_len); id 0 is treated as padding.
            segment_label: segment ids of the same shape.

        Returns:
            (mlm_output, nsp_output) with shapes (B, T, vocab_size) and (B, 2).
        """
        # Key-padding mask: True wherever the token is not id 0.
        # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
        not_padding = ids != 0
        mask = not_padding.unsqueeze(1).unsqueeze(2)

        hidden = self.embed(ids, segment_label)
        for layer in self.blocks:
            hidden = layer(hidden, mask)

        mlm_log_probs = self.mlm(hidden)
        nsp_log_probs = self.nsp(hidden)
        return mlm_log_probs, nsp_log_probs


class ScheduledOptim:
    """Wrapper for optimizer with warmup scheduling and checkpoint support.

    The learning rate follows
    ``d_model**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5)``:
    linear warmup for ``n_warmup_steps`` steps, then inverse-square-root decay.
    The learning rate the wrapped optimizer was constructed with is overwritten
    on the first call to :meth:`step_and_update_lr`.

    Args:
        optimizer: object exposing ``param_groups``, ``step()``, ``zero_grad()``,
            ``state_dict()`` and ``load_state_dict()``.
        d_model: model width, used to scale the peak learning rate.
        n_warmup_steps: number of warmup steps; ``0`` (or less) disables warmup.
    """

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self._step = 0  # current step count
        # Built-in float on purpose: a NumPy scalar here would leak into
        # param_group['lr'] and from there into every optimizer state_dict /
        # checkpoint, which restricted unpicklers (torch.load with
        # weights_only=True) may refuse to load.
        self.init_lr = float(d_model) ** -0.5

    def step_and_update_lr(self):
        """Increment step, update LR, then step optimizer."""
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        """Return the schedule multiplier for the current step."""
        step = max(self._step, 1)  # avoid 0 ** -0.5 if queried before the first step
        decay = step ** -0.5
        if self.n_warmup_steps <= 0:
            return decay  # no warmup phase configured
        return min(decay, self.n_warmup_steps ** -1.5 * step)

    def _update_learning_rate(self):
        """Update LR on each step (warmup + decay)."""
        self._step += 1
        lr = self.init_lr * self._get_lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

    def state_dict(self):
        """Return scheduler state for checkpointing."""
        return {
            '_step': self._step,
            'warmup_steps': self.n_warmup_steps,
            'optimizer': self._optimizer.state_dict()
        }

    def load_state_dict(self, state_dict):
        """Load scheduler state (step count, warmup, and optimizer state)."""
        self._step = state_dict['_step']
        self.n_warmup_steps = state_dict['warmup_steps']
        self._optimizer.load_state_dict(state_dict['optimizer'])


class BertTrainer:
    """Runs MLM + NSP pre-training / evaluation epochs for a BERT model.

    Args:
        model: module returning ``(mlm_log_probs, nsp_log_probs)`` from
            ``model(input_ids, segment_label)`` and exposing ``d_model``.
        train_dataloader: loader yielding dicts with the keys ``input_ids``,
            ``input_label``, ``segment_label`` and ``is_next``.
        test_dataloader: optional loader of the same format for :meth:`test`.
        lr: learning rate Adam is constructed with. The warmup schedule
            overwrites it on the first optimisation step.
        weight_decay: Adam weight decay.
        betas: Adam betas.
        warmup_steps: warmup steps of the learning-rate schedule.
        log_freq: stored for compatibility; progress is reported through the
            tqdm postfix on every step.
        device: device the model and the batches are moved to.
        checkpoint_path: file the training checkpoint is written to.
        checkpoint_freq: save a checkpoint every this many training steps of an
            epoch; ``0`` or ``None`` disables checkpointing.
    """

    def __init__(
            self,
            model,
            train_dataloader,
            test_dataloader=None,
            lr=1e-4,
            weight_decay=0.01,
            betas=(0.9, 0.999),
            warmup_steps=10000,
            log_freq=10,
            device='cuda',
            checkpoint_path='/content/drive/MyDrive/bert-checkpoint/bert_checkpoint.pt',
            checkpoint_freq=30
    ):
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.log_freq = log_freq
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.checkpoint_freq = checkpoint_freq

        self.model.to(self.device)

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(self.optim, self.model.d_model, warmup_steps)

        self.mlm_criterion = nn.NLLLoss(ignore_index=0)  # label 0 ([PAD]) marks positions that are not predicted
        self.nsp_criterion = nn.NLLLoss()

    def train(self, epoch):
        """Run one training epoch; returns the metrics dict of :meth:`iteration`."""
        return self.iteration(epoch, self.train_dataloader)

    def test(self, epoch):
        """Run one evaluation epoch; returns the metrics dict of :meth:`iteration`."""
        return self.iteration(epoch, self.test_dataloader, train=False)

    def _save_checkpoint(self, epoch):
        """Write model, optimizer and scheduler state to ``self.checkpoint_path``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            # Also contained in 'scheduler_state_dict'; kept as a separate key for
            # compatibility with code that reads existing checkpoints.
            'optimizer_state_dict': self.optim_schedule._optimizer.state_dict(),
            'scheduler_state_dict': self.optim_schedule.state_dict(),
            'epoch': epoch,
            'step': self.optim_schedule._step
        }
        # torch.save does not create missing directories
        Path(self.checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, self.checkpoint_path)

    def iteration(self, epoch, data_loader, train=True):
        """Run one pass over ``data_loader``.

        Args:
            epoch: epoch number, used for logging and stored in checkpoints.
            data_loader: loader to iterate over.
            train: if True, back-propagate and step the optimizer; otherwise run
                in eval mode without building autograd graphs.

        Returns:
            dict with ``avg_loss``, ``nsp_acc`` and ``mlm_acc`` (accuracies in
            percent) accumulated over the whole pass.

        Raises:
            ValueError: if ``data_loader`` is None.
        """
        mode = 'train' if train else 'test'
        if data_loader is None:
            raise ValueError(f"No data loader configured for mode '{mode}'.")

        # Running statistics for the current epoch
        total_mlm_loss = 0
        total_nsp_loss = 0
        total_loss = 0
        n_batches = 0

        total_nsp_correct = 0
        total_nsp_elements = 0  # number of sentence pairs seen

        total_mlm_correct = 0
        total_mlm_elements = 0  # number of positions with a non-zero label

        # Dropout is active only while training
        if train:
            self.model.train()
        else:
            self.model.eval()

        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}",
            leave=True
        )

        # Gradients are tracked only in training mode; evaluation then skips the
        # memory and time cost of building autograd graphs.
        with torch.set_grad_enabled(train):
            for i, batch in data_iter:
                batch = {key: value.to(self.device) for key, value in batch.items()}

                if train:
                    self.optim_schedule.zero_grad()

                mlm_output, nsp_output = self.model(batch["input_ids"], batch["segment_label"])

                # mlm_output is (B, T, V); NLLLoss wants the class dimension second: (B, V, T)
                mlm_loss = self.mlm_criterion(mlm_output.transpose(1, 2), batch["input_label"])

                # nsp_output: (B, 2), target: (B,)
                nsp_loss = self.nsp_criterion(nsp_output, batch["is_next"])

                loss = mlm_loss + nsp_loss

                if train:
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()

                n_batches += 1
                total_mlm_loss += mlm_loss.item()
                total_nsp_loss += nsp_loss.item()
                total_loss += loss.item()

                # NSP accuracy
                predicted_next_sentence_labels = nsp_output.argmax(dim=-1)
                total_nsp_correct += (predicted_next_sentence_labels == batch["is_next"]).sum().item()
                total_nsp_elements += batch["is_next"].nelement()

                # MLM accuracy, only over positions that carry a label (label != 0)
                predicted_mlm_labels = mlm_output.argmax(dim=-1)  # (B, T)
                masked_positions = (batch["input_label"] != 0)
                total_mlm_correct += (
                    predicted_mlm_labels[masked_positions] == batch["input_label"][masked_positions]
                ).sum().item()
                total_mlm_elements += masked_positions.sum().item()

                # Running averages over the epoch so far
                current_avg_loss = total_loss / n_batches
                current_nsp_acc = (total_nsp_correct / total_nsp_elements * 100) if total_nsp_elements > 0 else 0
                current_mlm_acc = (total_mlm_correct / total_mlm_elements * 100) if total_mlm_elements > 0 else 0

                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "loss": loss.item(),  # current batch loss
                    "avg_loss": current_avg_loss,
                    "mlm_loss": mlm_loss.item(),
                    "nsp_loss": nsp_loss.item(),
                    "nsp_acc": current_nsp_acc,
                    "mlm_acc": current_mlm_acc
                }
                data_iter.set_postfix(post_fix)

                if train and self.checkpoint_freq and (i + 1) % self.checkpoint_freq == 0:
                    self._save_checkpoint(epoch)
                    # write() keeps the message from garbling the progress bar
                    data_iter.write(f"\nSaved checkpoint at epoch {epoch}, step {i+1}")

        # Average over the batches actually processed; guards against an empty loader
        final_avg_loss = total_loss / n_batches if n_batches > 0 else 0.0
        final_nsp_acc = (total_nsp_correct / total_nsp_elements * 100) if total_nsp_elements > 0 else 0
        final_mlm_acc = (total_mlm_correct / total_mlm_elements * 100) if total_mlm_elements > 0 else 0

        print(
            f"EP{epoch}, {mode}: "
            f"avg_loss={final_avg_loss:.4f}, "
            f"nsp_acc={final_nsp_acc:.2f}%, "
            f"mlm_acc={final_mlm_acc:.2f}%"
        )

        return {
            "avg_loss": final_avg_loss,
            "nsp_acc": final_nsp_acc,
            "mlm_acc": final_mlm_acc
        }


In [7]:
# WordPiece vocabulary of the custom tokenizer, synced to Drive
vocab_path = '/content/drive/Othercomputers/My_Laptop/bert-it-1/bert-it-vocab.txt'
tokenizer = BertTokenizer.from_pretrained(vocab_path, local_files_only=True)

print("Tokenizer loaded successfully!")
vocab_size = len(tokenizer)
print(f"Vocab size: {vocab_size}")

# Smoke test: encode one short sentence and show the raw encoding
sample_encoding = tokenizer("ovo je rečenica")
print(sample_encoding)



Tokenizer loaded successfully!
Vocab size: 30000
{'input_ids': [1, 2054, 1865, 16573, 1008, 1941, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}


In [8]:
# Dataset over every line of the training batch files
train_dataset = BertDataset_v1(path_to_data_dir=train_dir, tokenizer=tokenizer, seq_len=max_len)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

# Pull a single batch as a sanity check of the pipeline
first_batch = next(iter(train_dataloader))
print(first_batch)


[Dataset Init] Counting lines across all batch files...
[Dataset Init] Found 33620000 total samples across 3362 files.
{'input_ids': tensor([[    1,  8101,     3,  ...,     0,     0,     0],
        [    1,  2341, 12720,  ...,     0,     0,     0],
        [    1, 21053,  5892,  ...,     0,     0,     0],
        ...,
        [    1,  2054, 28568,  ...,     0,     0,     0],
        [    1,     3, 22566,  ...,     0,     0,     0],
        [    1,  2573,    16,  ...,     0,     0,     0]]), 'input_label': tensor([[    0,     0,  6702,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0, 26442,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0]]), 'segment_label': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
      

In [12]:
# Build the model from the hyperparameters defined in the config cell
bert_model = BERT(
    n_layers=n_layers,
    vocab_size=len(tokenizer.vocab),
    d_model=n_embd,
    n_heads=n_heads,
    max_len=max_len,
    dropout=dropout,
)

bert_trainer = BertTrainer(bert_model, train_dataloader, device=device)

epochs = 20

# One full pass over the training data per epoch
for epoch in range(epochs):
    bert_trainer.train(epoch)

EP_train:0:   0%|| 30/525313 [00:43<186:15:12,  1.28s/it, epoch=0, iter=29, loss=11.9, avg_loss=12, mlm_loss=11.1, nsp_loss=0.782, nsp_acc=49.9, mlm_acc=0]

Saved checkpoint at epoch 0, step 30


EP_train:0:   0%|| 60/525313 [01:23<197:27:46,  1.35s/it, epoch=0, iter=59, loss=11.9, avg_loss=11.9, mlm_loss=11.1, nsp_loss=0.769, nsp_acc=50.8, mlm_acc=0.00329]

Saved checkpoint at epoch 0, step 60


EP_train:0:   0%|| 90/525313 [02:03<181:15:47,  1.24s/it, epoch=0, iter=89, loss=11.8, avg_loss=11.9, mlm_loss=11, nsp_loss=0.761, nsp_acc=50.5, mlm_acc=0.00221]

Saved checkpoint at epoch 0, step 90


EP_train:0:   0%|| 120/525313 [02:43<186:48:43,  1.28s/it, epoch=0, iter=119, loss=11.8, avg_loss=11.9, mlm_loss=11, nsp_loss=0.774, nsp_acc=50.6, mlm_acc=0.00166]

Saved checkpoint at epoch 0, step 120


EP_train:0:   0%|| 150/525313 [03:26<192:32:20,  1.32s/it, epoch=0, iter=149, loss=11.7, avg_loss=11.9, mlm_loss=11, nsp_loss=0.71, nsp_acc=50.1, mlm_acc=0.00266]

Saved checkpoint at epoch 0, step 150


EP_train:0:   0%|| 180/525313 [04:06<184:27:01,  1.26s/it, epoch=0, iter=179, loss=11.6, avg_loss=11.8, mlm_loss=10.9, nsp_loss=0.662, nsp_acc=50.1, mlm_acc=0.00332]

Saved checkpoint at epoch 0, step 180


EP_train:0:   0%|| 210/525313 [04:46<185:37:32,  1.27s/it, epoch=0, iter=209, loss=11.6, avg_loss=11.8, mlm_loss=10.8, nsp_loss=0.811, nsp_acc=50, mlm_acc=0.00284]

Saved checkpoint at epoch 0, step 210


EP_train:0:   0%|| 240/525313 [05:27<187:39:43,  1.29s/it, epoch=0, iter=239, loss=11.5, avg_loss=11.8, mlm_loss=10.7, nsp_loss=0.753, nsp_acc=50, mlm_acc=0.00249]

Saved checkpoint at epoch 0, step 240


EP_train:0:   0%|| 270/525313 [06:06<180:00:49,  1.23s/it, epoch=0, iter=269, loss=11.5, avg_loss=11.7, mlm_loss=10.7, nsp_loss=0.817, nsp_acc=49.9, mlm_acc=0.00221]

Saved checkpoint at epoch 0, step 270


EP_train:0:   0%|| 300/525313 [06:46<183:59:16,  1.26s/it, epoch=0, iter=299, loss=11.4, avg_loss=11.7, mlm_loss=10.6, nsp_loss=0.81, nsp_acc=49.8, mlm_acc=0.00199]

Saved checkpoint at epoch 0, step 300


EP_train:0:   0%|| 330/525313 [07:25<204:59:54,  1.41s/it, epoch=0, iter=329, loss=11.2, avg_loss=11.7, mlm_loss=10.5, nsp_loss=0.691, nsp_acc=49.9, mlm_acc=0.00785]

Saved checkpoint at epoch 0, step 330


EP_train:0:   0%|| 360/525313 [08:06<188:53:57,  1.30s/it, epoch=0, iter=359, loss=11.2, avg_loss=11.6, mlm_loss=10.4, nsp_loss=0.78, nsp_acc=49.7, mlm_acc=0.0188]

Saved checkpoint at epoch 0, step 360


EP_train:0:   0%|| 390/525313 [08:46<185:48:21,  1.27s/it, epoch=0, iter=389, loss=11.1, avg_loss=11.6, mlm_loss=10.3, nsp_loss=0.736, nsp_acc=49.6, mlm_acc=0.0509]

Saved checkpoint at epoch 0, step 390


EP_train:0:   0%|| 420/525313 [09:25<184:12:37,  1.26s/it, epoch=0, iter=419, loss=11.1, avg_loss=11.6, mlm_loss=10.4, nsp_loss=0.783, nsp_acc=49.6, mlm_acc=0.114]

Saved checkpoint at epoch 0, step 420


EP_train:0:   0%|| 450/525313 [10:05<183:04:05,  1.26s/it, epoch=0, iter=449, loss=11.1, avg_loss=11.5, mlm_loss=10.3, nsp_loss=0.732, nsp_acc=49.7, mlm_acc=0.222]

Saved checkpoint at epoch 0, step 450


EP_train:0:   0%|| 480/525313 [10:44<186:23:34,  1.28s/it, epoch=0, iter=479, loss=11, avg_loss=11.5, mlm_loss=10.2, nsp_loss=0.766, nsp_acc=49.6, mlm_acc=0.375]

Saved checkpoint at epoch 0, step 480


EP_train:0:   0%|| 509/525313 [11:22<195:30:33,  1.34s/it, epoch=0, iter=508, loss=10.8, avg_loss=11.4, mlm_loss=10.1, nsp_loss=0.722, nsp_acc=49.7, mlm_acc=0.523]


KeyboardInterrupt: 

In [None]:
# checkpoint_path = '/content/drive/MyDrive/bert-checkpoint/bert_checkpoint.pt'

# checkpoint = torch.load(checkpoint_path)

# bert_model = BERT(
#     n_layers,
#     len(tokenizer.vocab),
#     n_embd,
#     n_heads,
#     max_len,
#     dropout
# )
# bert_trainer = BertTrainer(bert_model, train_dataloader, device=device)


# bert_model.load_state_dict(checkpoint['model_state_dict'])
# bert_trainer.optim_schedule.load_state_dict(checkpoint['scheduler_state_dict'])
# epoch = checkpoint['epoch']

# print('Checkpoint loaded successfully!')
# print('Training continued...')

# for ep in range(epoch, epochs):
#     bert_trainer.train(ep)