In [2]:
from transformers import BertTokenizerFast
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from datasets import Dataset as HFDataset
import torch
import torch.nn.functional as F
import math
import torch.nn as nn
import torch.optim as optim

In [3]:
!pip install datasets



In [4]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
sentence = "By the way, we ball"
tokens = tokenizer.tokenize(sentence)
print(tokens)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

['by', 'the', 'way', ',', 'we', 'ball']


In [5]:
mnli_dataset = load_dataset("multi_nli")

README.md:   0%|          | 0.00/8.89k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

(…)alidation_matched-00000-of-00001.parquet:   0%|          | 0.00/4.94M [00:00<?, ?B/s]

(…)dation_mismatched-00000-of-00001.parquet:   0%|          | 0.00/5.10M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

In [6]:
print(mnli_dataset['train'][0])

{'promptID': 31193, 'pairID': '31193n', 'premise': 'Conceptually cream skimming has two basic dimensions - product and geography.', 'premise_binary_parse': '( ( Conceptually ( cream skimming ) ) ( ( has ( ( ( two ( basic dimensions ) ) - ) ( ( product and ) geography ) ) ) . ) )', 'premise_parse': '(ROOT (S (NP (JJ Conceptually) (NN cream) (NN skimming)) (VP (VBZ has) (NP (NP (CD two) (JJ basic) (NNS dimensions)) (: -) (NP (NN product) (CC and) (NN geography)))) (. .)))', 'hypothesis': 'Product and geography are what make cream skimming work. ', 'hypothesis_binary_parse': '( ( ( Product and ) geography ) ( ( are ( what ( make ( cream ( skimming work ) ) ) ) ) . ) )', 'hypothesis_parse': '(ROOT (S (NP (NN Product) (CC and) (NN geography)) (VP (VBP are) (SBAR (WHNP (WP what)) (S (VP (VBP make) (NP (NP (NN cream)) (VP (VBG skimming) (NP (NN work)))))))) (. .)))', 'genre': 'government', 'label': 1}


In [7]:
class MNLIDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.data)
    def  __getitem__(self, idx):
        example = self.data[idx]
        premise = example['premise']
        hypothesis = example['hypothesis']
        label = example['label']
        encoded_pair = self.tokenizer.encode_plus(premise, hypothesis, max_length=self.max_length, padding='max_length', truncation=True,return_tensors='pt')
        input_ids = encoded_pair['input_ids'].squeeze(0)
        attention_mask = encoded_pair['attention_mask'].squeeze(0)
        # Keeping this commented out for now, maybe not very essential for encoder only models? Investigate further...
        # token_type_ids = encoded_pair.get('token_type_ids', torch.zeros_like(input_ids))
        return {'input_ids': input_ids,'attention_mask': attention_mask,# 'token_type_ids': token_type_ids,
'labels': torch.tensor(label)}


In [8]:
train_data = mnli_dataset["train"]
max_seq_length = 128
train_dataset = MNLIDataset(train_data, tokenizer, max_seq_length)
print(f"Size of training dataset: {len(train_dataset)}")

Size of training dataset: 392702


In [9]:
sample = train_dataset[0]
print(sample['input_ids'].shape)
print(sample['attention_mask'].shape)
print(sample['labels'])

torch.Size([128])
torch.Size([128])
tensor(1)


In [10]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, dropout_rate):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        assert self.head_dim * self.num_heads == hidden_size

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout_rate)
        self.output = nn.Linear(hidden_size, hidden_size)
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len_q, seq_len_k, seq_len_v = query.size(1), key.size(1), value.size(1)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        query = query.view(batch_size, seq_len_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = key.view(batch_size, seq_len_k, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(batch_size, seq_len_v, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        attention_scores = torch.matmul(query, key.transpose(-2,-1))
        attention_scores = attention_scores/(self.head_dim**0.5)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask==0, float('-inf'))
        attention_weights = F.softmax(attention_scores,dim=-1)
        scaled_attention = torch.matmul(attention_weights, value)
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len_q, self.num_heads * self.head_dim)
        output = self.output(scaled_attention)
        return output





In [11]:
class FeedForwardNetwork(nn.Module):
    def __init__(self, hidden_size, intermediate_size, dropout_rate):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = F.relu(self.dense1(x))
        x = self.dropout(x)
        x = self.dense2(x)
        return x

In [12]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_rate):
        super().__init__()
        self.self_attention = MultiHeadSelfAttention(hidden_size, num_attention_heads, dropout_rate)
        self.feed_forward = FeedForwardNetwork(hidden_size, intermediate_size, dropout_rate)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        attention_output = self.self_attention(x,x, x, mask)
        normed_output1 = self.norm1(attention_output + x)
        ff_output = self.feed_forward(normed_output1)
        final_output = self.norm2(ff_output + normed_output1)
        return final_output

In [13]:
class FactorizedEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.projection = nn.Linear(embedding_dim, hidden_size)

    def forward(self, input_ids):
        factor_embeds = self.word_embeddings(input_ids)
        project_embeds = self.projection(factor_embeds)
        return project_embeds

In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self, hidden_size, max_seq_length, dropout_rate):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)

        position = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size))
        pe = torch.zeros(max_seq_length, 1, hidden_size)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe) # Store as a buffer (not a learnable parameter)

    def forward(self, x):
        seq_length = x.size(1)
        pe = self.pe[:seq_length].squeeze(1)
        x = x + pe.unsqueeze(0)
        x = self.dropout(x)
        return x

In [15]:
class AtomBERT(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, num_attention_heads, intermediate_size, num_classes, max_seq_length, dropout_rate):
        super().__init__()
        self.embedding = FactorizedEmbedding(vocab_size, embedding_dim, hidden_size)
        self.positional_encoding = PositionalEncoding(hidden_size, max_seq_length, dropout_rate)
        self.encoder_layer = TransformerEncoderLayer(hidden_size, num_attention_heads, intermediate_size, dropout_rate)
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        embeddings = self.embedding(input_ids)
        embeddings = self.positional_encoding(embeddings)

        # Share the encoder layer across all layers
        encoder_output = embeddings
        for _ in range(self.num_layers):
            encoder_output = self.encoder_layer(encoder_output, attention_mask)

        pooled_output = encoder_output[:, 0, :] # Shape: (batch_size, hidden_size)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [16]:
class AtomBERTConfig:
    def __init__(
        self,
        vocab_size,
        embedding_dim=128,
        hidden_size=768,
        num_layers=4,
        num_attention_heads=4,
        intermediate_size=1500,
        num_classes=2, # For SOP
        max_seq_length=128,
        dropout_rate=0.1,
    ):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_classes = num_classes
        self.max_seq_length = max_seq_length
        self.dropout_rate = dropout_rate

In [17]:
class AtomBERTForPretraining(nn.Module): # Renamed for clarity
    def __init__(self, config):
        super().__init__()
        self.bert = AtomBERT(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            num_classes=2, # For SOP
            max_seq_length=config.max_seq_length,
            dropout_rate=config.dropout_rate
        )
        self.mlm_head = nn.Sequential(
        nn.Linear(config.hidden_size, config.hidden_size),
        nn.LayerNorm(config.hidden_size),
        nn.Linear(config.hidden_size, config.vocab_size)
        )
        self.sop_head = nn.Linear(config.hidden_size, 2)



    def forward(self, input_ids, attention_mask, masked_labels=None, sentence_order_labels=None):
        outputs = self.bert(input_ids, attention_mask) # Shape: (batch_size, seq_len, hidden_size)

        # MLM Prediction
        prediction_logits_mlm = self.mlm_head(outputs) # Shape: (batch_size, seq_len, vocab_size)

        # SOP Prediction (using the pooled output - we might need to adjust this)
        pooled_output = outputs[:, 0, :] # Taking the [CLS] token representation
        prediction_logits_sop = self.sop_head(pooled_output) # Shape: (batch_size, 2)

        total_loss = None
        if masked_labels is not None and sentence_order_labels is not None:
            loss_fct_mlm = nn.CrossEntropyLoss()
            masked_loss = loss_fct_mlm(prediction_logits_mlm.view(-1, self.config.vocab_size), masked_labels.view(-1))

            loss_fct_sop = nn.CrossEntropyLoss()
            sop_loss = loss_fct_sop(prediction_logits_sop.view(-1, 2), sentence_order_labels.view(-1))

            total_loss = masked_loss + sop_loss # You might want to weigh these differently

        elif masked_labels is not None:
            loss_fct_mlm = nn.CrossEntropyLoss()
            total_loss = loss_fct_mlm(prediction_logits_mlm.view(-1, self.config.vocab_size), masked_labels.view(-1))

        elif sentence_order_labels is not None:
            loss_fct_sop = nn.CrossEntropyLoss()
            total_loss = loss_fct_sop(prediction_logits_sop.view(-1, 2), sentence_order_labels.view(-1))

        return total_loss, prediction_logits_mlm, prediction_logits_sop

In [22]:
wiki_text_dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
print(wiki_text_dataset)

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/157M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/157M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})


In [23]:
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding='max_length', max_length= max_seq_length)
tokenized_train_dataset = wiki_text_dataset["train"].map(tokenize_function, batched=True)

print(tokenized_train_dataset[2])


Map:   0%|          | 0/1801350 [00:00<?, ? examples/s]

{'text': '', 'input_ids': [101, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [33]:
print(tokenized_train_dataset[1])

{'text': ' = Valkyria Chronicles III = \n', 'input_ids': [101, 1027, 11748, 4801, 4360, 11906, 3523, 1027, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [34]:
print(tokenized_sop_dataset[1])

{'text': '[CLS] Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . [SEP] Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit .', 'sentence_order_labels': 1, 'input_ids': [101, 101, 11748, 4801, 4360, 1997, 1996, 11686, 1017, 1007, 1010, 4141, 3615, 2000, 2004, 11748, 4801, 4360, 11906, 3523, 2648, 2900, 1010, 2003, 1037, 8608, 2535, 1030, 1011, 1030, 2652, 2678, 2208, 2764, 2011, 16562, 1998, 2865, 1012, 4432, 2005, 1996, 9160, 12109, 1012, 102, 12411, 5558, 2053, 11748, 4801, 4360, 1017, 1024, 4895, 2890, 27108, 5732, 11906, 1006, 2887, 1024, 1856, 1806, 1671, 30222, 30218, 30259, 30227, 30255, 30258, 30219, 2509, 1010, 5507, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 

In [24]:
def prepare_mlm_data(examples):
    input_ids = torch.tensor(examples["input_ids"])
    batch_size = len(input_ids)
    sequence_length = len(input_ids[0])
    mask_prob = 0.15
    mask_token_id = tokenizer.mask_token_id
    pad_token_id = tokenizer.pad_token_id
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Create a random mask of the same shape as input_ids
    rand = torch.rand(batch_size, sequence_length)
    # Mask tokens with probability mask_prob
    mask = (rand < mask_prob) & (input_ids != cls_token_id) & (input_ids != sep_token_id) & (input_ids != pad_token_id)

    # Duplicate the input_ids to create labels, and set labels for non-masked tokens to -100
    labels = input_ids.clone()
    labels[~mask] = -100

    # Replace masked input_ids with the mask token
    inputs = input_ids.clone()
    inputs[mask] = mask_token_id

    return {"input_ids": inputs, "labels": labels}

In [25]:
mlm_tokenized_train_dataset = tokenized_train_dataset.map(
    prepare_mlm_data,
    batched=True,
    remove_columns=["text", "token_type_ids"] # We can remove these columns as we won't need them further
)

print(mlm_tokenized_train_dataset[0])

Map:   0%|          | 0/1801350 [00:00<?, ? examples/s]

{'input_ids': [101, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -

In [26]:
import re

def split_into_sentences_robust(text):
    # Pattern to split by '.', '?', or '!' followed by whitespace or end of string
    sentence_pattern = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s+'
    sentences = re.split(sentence_pattern, text)
    return [s.strip() for s in sentences if s.strip()]

In [27]:
first_text = wiki_text_dataset["train"][3]["text"]
sentences = split_into_sentences_robust(first_text)
print(sentences)

['Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit .', 'Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable .', 'Released in January 2011 in Japan , it is the third game in the Valkyria series .', 'Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " .']


In [28]:
sop_texts_limited = []
sop_labels_limited = []
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
max_pairs_per_document = 5 # You can adjust this number

for example in wiki_text_dataset["train"]:
    text = example["text"]
    sentences = split_into_sentences_robust(text)
    pairs_count = 0 # Initialize a counter for each document
    for i in range(len(sentences) - 1):
        if pairs_count >= max_pairs_per_document:
            break # Stop if we've reached the limit for this document

        sentence_a = sentences[i]
        sentence_b = sentences[i + 1]

        # Original order
        sop_texts_limited.append(cls_token + " " + sentence_a + " " + sep_token + " " + sentence_b)
        sop_labels_limited.append(0)
        pairs_count += 1

        if pairs_count >= max_pairs_per_document:
            break # Stop after adding the original pair if we've reached the limit

        # Swapped order
        sop_texts_limited.append(cls_token + " " + sentence_b + " " + sep_token + " " + sentence_a)
        sop_labels_limited.append(1)
        pairs_count += 1

print(f"Number of limited SOP examples generated: {len(sop_texts_limited)}")
print(f"Number of limited SOP labels generated: {len(sop_labels_limited)}")

Number of limited SOP examples generated: 3189245
Number of limited SOP labels generated: 3189245


In [29]:
sop_dataset = HFDataset.from_dict({"text": sop_texts_limited, "sentence_order_labels": sop_labels_limited})

# 2. Define the tokenization function
def tokenize_sop_function(examples):
    return tokenizer(examples["text"], truncation=True, padding='max_length', max_length= max_seq_length)

# 3. Apply the tokenization function using map
tokenized_sop_dataset = sop_dataset.map(tokenize_sop_function, batched=True)

print(tokenized_sop_dataset[0])

Map:   0%|          | 0/3189245 [00:00<?, ? examples/s]

{'text': '[CLS] Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . [SEP] Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable .', 'sentence_order_labels': 0, 'input_ids': [101, 101, 12411, 5558, 2053, 11748, 4801, 4360, 1017, 1024, 4895, 2890, 27108, 5732, 11906, 1006, 2887, 1024, 1856, 1806, 1671, 30222, 30218, 30259, 30227, 30255, 30258, 30219, 2509, 1010, 5507, 1012, 102, 11748, 4801, 4360, 1997, 1996, 11686, 1017, 1007, 1010, 4141, 3615, 2000, 2004, 11748, 4801, 4360, 11906, 3523, 2648, 2900, 1010, 2003, 1037, 8608, 2535, 1030, 1011, 1030, 2652, 2678, 2208, 2764, 2011, 16562, 1998, 2865, 1012, 4432, 2005, 1996, 9160, 12109, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 

In [44]:
from torch.utils.data import Dataset

class HFDatasetWrapper(Dataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        return self.hf_dataset[idx]

# Wrap the Hugging Face datasets
wrapped_mlm_dataset = HFDatasetWrapper(mlm_tokenized_train_dataset)
wrapped_sop_dataset = HFDatasetWrapper(tokenized_sop_dataset)

batch_size = 16 # Ensure this is defined

# DataLoader for MLM dataset
mlm_dataloader = DataLoader(
    wrapped_mlm_dataset,
    batch_size=batch_size,
    shuffle=True
)

# DataLoader for SOP dataset
sop_dataloader = DataLoader(
    wrapped_sop_dataset,
    batch_size=batch_size,
    shuffle=True
)

In [31]:
vocab_size = tokenizer.vocab_size
learning_rate = 5e-5

# Move the model to the device (GPU if available)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Recall or re-initialize your tokenizer if needed
vocab_size = tokenizer.vocab_size

# Create the configuration
config = AtomBERTConfig(vocab_size=vocab_size) # You can customize other parameters here if needed

# Instantiate the AtomBERTForPretraining model
modelpretrain = AtomBERTForPretraining(config)

print(f"Model instantiated with {config}")

Model instantiated with <__main__.AtomBERTConfig object at 0x7ae19346ba90>


In [32]:
def count_parameters(model):
    total_params = 0
    trainable_params = 0
    non_trainable_params = 0

    for name, param in model.named_parameters():
        num_params = param.numel()
        # if requires_grad is True, then it is a trainable parameter
        if param.requires_grad:
            trainable_params += num_params
        else:
            non_trainable_params += num_params
        total_params += num_params

    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Non-trainable Parameters: {non_trainable_params:,}")
print("Basic Model: \n")
count_parameters(modelpretrain.bert)
print("Total Pretraining Model: \n")
count_parameters(modelpretrain)


Basic Model: 

Total Parameters: 8,679,134
Trainable Parameters: 8,679,134
Non-trainable Parameters: 0
Total Pretraining Model: 

Total Parameters: 32,744,218
Trainable Parameters: 32,744,218
Non-trainable Parameters: 0


In [19]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [21]:
%cd drive/MyDrive/Colab Notebooks/

/content/drive/MyDrive/Colab Notebooks


In [35]:
!pip install wandb



In [36]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhexager[0m ([33mhexager-manipal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [38]:
# Start a new wandb run to track this script.
import wandb
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="hexager-manipal",
    # Set the wandb project where this run will be logged.
    project="my-awesome-project",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": 0.02,
        "architecture": "AlBERT-32M",
        "dataset": "WikiText-103",
        "epochs": 4,
    },
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhexager[0m ([33mhexager-manipal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [45]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader  # Make sure this is imported
from tqdm import tqdm  # For a nice progress bar

# Define training parameters (you might want to adjust these)
num_epochs = 4
learning_rate = 1e-4
weight_decay = 0.01
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
modelpretrain.to(device)
optimizer = AdamW(modelpretrain.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Assuming mlm_dataloader and sop_dataloader are already defined

for epoch in range(num_epochs):
    modelpretrain.train()
    total_mlm_loss = 0
    total_sop_loss = 0
    total_steps = min(len(mlm_dataloader), len(sop_dataloader))  # Iterate for the length of the smaller dataloader

    progress_bar = tqdm(range(total_steps), desc=f"Epoch {epoch+1}/{num_epochs}")

    for step in progress_bar:
        try:
            batch_mlm = next(iter(mlm_dataloader))
        except StopIteration:
            print("MLM dataloader finished early.")
            break
        try:
            batch_sop = next(iter(sop_dataloader))
        except StopIteration:
            print("SOP dataloader finished early.")
            break
        # Move MLM batch to device
        mlm_inputs = {k: batch_mlm[k].to(device) for k in batch_mlm}

        # Move SOP batch to device
        sop_inputs = {k: batch_sop[k].to(device) for k in batch_sop}


        # Zero gradients
        optimizer.zero_grad()

        # MLM forward pass
        outputs_mlm = modelpretrain(**{
            'input_ids': mlm_inputs['input_ids'],
            'attention_mask': mlm_inputs['attention_mask'],
            'labels': mlm_inputs['labels']
        })
        mlm_loss = outputs_mlm.loss

        # SOP forward pass
        outputs_sop = modelpretrain(**{
            'input_ids': sop_inputs['input_ids'],
            'attention_mask': sop_inputs['attention_mask'],
            'token_type_ids': sop_inputs.get('token_type_ids'), # Token type IDs might not always be present
            'next_sentence_label': sop_inputs['sentence_order_labels'] # Assuming your model uses this name
        })
        sop_loss = outputs_sop.loss

        # Combine losses (you might want to weigh them differently)
        total_loss = mlm_loss + sop_loss

        # Backward pass
        total_loss.backward()

        # Update parameters
        optimizer.step()

        total_mlm_loss += mlm_loss.item()
        total_sop_loss += sop_loss.item()

        progress_bar.set_postfix(mlm_loss=f"{mlm_loss.item():.4f}", sop_loss=f"{sop_loss.item():.4f}")
    avg_mlm_loss = total_mlm_loss / total_steps if total_steps > 0 else 0
    avg_sop_loss = total_sop_loss / total_steps if total_steps > 0 else 0
    run.log({"Total Unsupervised Loss": total_loss})

    print(f"Epoch {epoch+1}/{num_epochs} finished, Average MLM Loss: {avg_mlm_loss:.4f}, Average SOP Loss: {avg_sop_loss:.4f}")

print("Pretraining finished!")
run.finish()

Epoch 1/4:   0%|          | 0/112585 [00:00<?, ?it/s]


AttributeError: 'list' object has no attribute 'to'

In [43]:
# Inspect the first MLM batch
first_mlm_batch = next(iter(mlm_dataloader))
print("First MLM batch keys:", first_mlm_batch.keys())
for key, value in first_mlm_batch.items():
    print(f"Key: {key}, Type: {type(value)}, Length: {len(value) if isinstance(value, list) or isinstance(value, torch.Tensor) else value}")
    if isinstance(value, list) and len(value) > 0:
        print(f"First element of value: {value[0]}, Type of first element: {type(value[0])}")
    elif isinstance(value, torch.Tensor) and value.numel() > 0:
        print(f"First element of value: {value[0]}, Data type: {value.dtype}")

# Inspect the first SOP batch
first_sop_batch = next(iter(sop_dataloader))
print("\nFirst SOP batch keys:", first_sop_batch.keys())
for key, value in first_sop_batch.items():
    print(f"Key: {key}, Type: {type(value)}, Length: {len(value) if isinstance(value, list) or isinstance(value, torch.Tensor) else value}")
    if isinstance(value, list) and len(value) > 0:
        print(f"First element of value: {value[0]}, Type of first element: {type(value[0])}")
    elif isinstance(value, torch.Tensor) and value.numel() > 0:
        print(f"First element of value: {value[0]}, Data type: {value.dtype}")

First MLM batch keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
Key: input_ids, Type: <class 'list'>, Length: 128
First element of value: tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101]), Type of first element: <class 'torch.Tensor'>
Key: attention_mask, Type: <class 'list'>, Length: 128
First element of value: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), Type of first element: <class 'torch.Tensor'>
Key: labels, Type: <class 'list'>, Length: 128
First element of value: tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]), Type of first element: <class 'torch.Tensor'>

First SOP batch keys: dict_keys(['text', 'sentence_order_labels', 'input_ids', 'token_type_ids', 'attention_mask'])
Key: text, Type: <class 'list'>, Length: 16
First element of value: [CLS] They also re @-@ issued the album in October 1997 , with " High " included and the " Stereo World " b 