In [26]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [27]:
from transformers import BertTokenizerFast
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from datasets import Dataset as HFDataset
import torch
import torch.nn.functional as F
import math
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandbpass")

In [28]:
import wandb
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [29]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
sentence = "By the way, we ball"
tokens = tokenizer.tokenize(sentence)
print(tokens)

['by', 'the', 'way', ',', 'we', 'ball']


In [30]:
mnli_dataset = load_dataset("multi_nli")

In [31]:
print(mnli_dataset['train'][0])

{'promptID': 31193, 'pairID': '31193n', 'premise': 'Conceptually cream skimming has two basic dimensions - product and geography.', 'premise_binary_parse': '( ( Conceptually ( cream skimming ) ) ( ( has ( ( ( two ( basic dimensions ) ) - ) ( ( product and ) geography ) ) ) . ) )', 'premise_parse': '(ROOT (S (NP (JJ Conceptually) (NN cream) (NN skimming)) (VP (VBZ has) (NP (NP (CD two) (JJ basic) (NNS dimensions)) (: -) (NP (NN product) (CC and) (NN geography)))) (. .)))', 'hypothesis': 'Product and geography are what make cream skimming work. ', 'hypothesis_binary_parse': '( ( ( Product and ) geography ) ( ( are ( what ( make ( cream ( skimming work ) ) ) ) ) . ) )', 'hypothesis_parse': '(ROOT (S (NP (NN Product) (CC and) (NN geography)) (VP (VBP are) (SBAR (WHNP (WP what)) (S (VP (VBP make) (NP (NP (NN cream)) (VP (VBG skimming) (NP (NN work)))))))) (. .)))', 'genre': 'government', 'label': 1}


In [32]:
class MNLIDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.data)
    def  __getitem__(self, idx):
        example = self.data[idx]
        premise = example['premise']
        hypothesis = example['hypothesis']
        label = example['label']
        encoded_pair = self.tokenizer.encode_plus(premise, hypothesis, max_length=self.max_length, padding='max_length', truncation=True,return_tensors='pt')
        input_ids = encoded_pair['input_ids'].squeeze(0)
        attention_mask = encoded_pair['attention_mask'].squeeze(0)
        # Keeping this commented out for now, maybe not very essential for encoder only models? Investigate further...
        # token_type_ids = encoded_pair.get('token_type_ids', torch.zeros_like(input_ids))
        return {'input_ids': input_ids,'attention_mask': attention_mask,# 'token_type_ids': token_type_ids,
'labels': torch.tensor(label)}


In [33]:
train_data = mnli_dataset["train"]
max_seq_length = 128
train_dataset = MNLIDataset(train_data, tokenizer, max_seq_length)
print(f"Size of training dataset: {len(train_dataset)}")

Size of training dataset: 392702


In [34]:
sample = train_dataset[0]
print(sample['input_ids'].shape)
print(sample['attention_mask'].shape)
print(sample['labels'])

torch.Size([128])
torch.Size([128])
tensor(1)


In [35]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, dropout_rate):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        assert self.head_dim * self.num_heads == hidden_size

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout_rate)
        self.output = nn.Linear(hidden_size, hidden_size)
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len_q, seq_len_k, seq_len_v = query.size(1), key.size(1), value.size(1)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        query = query.view(batch_size, seq_len_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = key.view(batch_size, seq_len_k, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(batch_size, seq_len_v, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        attention_scores = torch.matmul(query, key.transpose(-2,-1))
        attention_scores = attention_scores/(self.head_dim**0.5)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask==0, float('-inf'))
        attention_weights = F.softmax(attention_scores,dim=-1)
        scaled_attention = torch.matmul(attention_weights, value)
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len_q, self.num_heads * self.head_dim)
        output = self.output(scaled_attention)
        return output

In [36]:
class FeedForwardNetwork(nn.Module):
    def __init__(self, hidden_size, intermediate_size, dropout_rate):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = F.relu(self.dense1(x))
        x = self.dropout(x)
        x = self.dense2(x)
        return x

In [37]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_rate):
        super().__init__()
        self.self_attention = MultiHeadSelfAttention(hidden_size, num_attention_heads, dropout_rate)
        self.feed_forward = FeedForwardNetwork(hidden_size, intermediate_size, dropout_rate)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        attention_output = self.self_attention(x,x, x, mask)
        normed_output1 = self.norm1(attention_output + x)
        ff_output = self.feed_forward(normed_output1)
        final_output = self.norm2(ff_output + normed_output1)
        return final_output

In [38]:
class FactorizedEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, type_vocab_size=2):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, embedding_dim)
        self.projection = nn.Linear(embedding_dim, hidden_size)

    def forward(self, input_ids, token_type_ids=None):
        factor_embeds = self.word_embeddings(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_type_embeds = self.token_type_embeddings(token_type_ids)
        # Sum the embeddings
        embeds = factor_embeds + token_type_embeds        
        project_embeds = self.projection(embeds)
        return project_embeds

In [39]:
class PositionalEncoding(nn.Module):
    def __init__(self, hidden_size, max_seq_length, dropout_rate):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)

        position = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size))
        pe = torch.zeros(max_seq_length, 1, hidden_size)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe) # Store as a buffer (not a learnable parameter)

    def forward(self, x):
        seq_length = x.size(1)
        pe = self.pe[:seq_length].squeeze(1)
        x = x + pe.unsqueeze(0)
        x = self.dropout(x)
        return x

In [40]:
class AtomBERT(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, num_attention_heads, intermediate_size, num_classes, max_seq_length, dropout_rate, type_vocab_size=2):
        super().__init__()
        self.embedding = FactorizedEmbedding(vocab_size, embedding_dim, hidden_size, type_vocab_size)
        self.positional_encoding = PositionalEncoding(hidden_size, max_seq_length, dropout_rate)
        self.encoder_layer = TransformerEncoderLayer(hidden_size, num_attention_heads, intermediate_size, dropout_rate)
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout_rate)
        #self.classifier = nn.Linear(hidden_size, num_classes)
        self.hidden_size = hidden_size

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        embeddings = self.embedding(input_ids, token_type_ids)
        embeddings = self.positional_encoding(embeddings)

        # Share the encoder layer across all layers
        encoder_output = embeddings
        for _ in range(self.num_layers):
            encoder_output = self.encoder_layer(encoder_output, attention_mask)

        pooled_output = encoder_output[:, 0, :] # Shape: (batch_size, hidden_size)
        pooled_output = self.dropout(pooled_output)
        #logits = self.classifier(pooled_output)
        return encoder_output, pooled_output

In [41]:
class AtomBERTForPretraining(nn.Module): # Renamed for clarity
    def __init__(self, config):
        super().__init__()
        self.config = config # Store the config object
        self.bert = AtomBERT(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            num_classes=2, # For SOP
            max_seq_length=config.max_seq_length,
            dropout_rate=config.dropout_rate,
            type_vocab_size=config.type_vocab_size if hasattr(config, 'type_vocab_size') else 2
        )
        self.mlm_head = nn.Sequential(
        nn.Linear(config.hidden_size, config.hidden_size),
        nn.LayerNorm(config.hidden_size),
        nn.Linear(config.hidden_size, config.vocab_size)
        )
        self.sop_head = nn.Linear(config.hidden_size, 2)



    def forward(self, input_ids, attention_mask, masked_labels=None, sentence_order_labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask, token_type_ids) # Shape: (batch_size, seq_len, hidden_size)
        encoder_output = outputs[0]
        pooled_output = outputs[1]
        # MLM Prediction
        prediction_logits_mlm = self.mlm_head(encoder_output) # Shape: (batch_size, seq_len, vocab_size)

        # SOP Prediction (using the pooled output - we might need to adjust this)
        #pooled_output = outputs[:, 0, :] # Taking the [CLS] token representation
        prediction_logits_sop = self.sop_head(pooled_output) # Shape: (batch_size, 2)

        total_loss = None
        if masked_labels is not None and sentence_order_labels is not None:
            loss_fct_mlm = nn.CrossEntropyLoss()
            masked_loss = loss_fct_mlm(prediction_logits_mlm.view(-1, self.config.vocab_size), masked_labels.view(-1))

            loss_fct_sop = nn.CrossEntropyLoss()
            sop_loss = loss_fct_sop(prediction_logits_sop.view(-1, 2), sentence_order_labels.view(-1))

            total_loss = masked_loss + sop_loss # You might want to weigh these differently

        elif masked_labels is not None:
            loss_fct_mlm = nn.CrossEntropyLoss()
            total_loss = loss_fct_mlm(prediction_logits_mlm.view(-1, self.config.vocab_size), masked_labels.view(-1))

        elif sentence_order_labels is not None:
            loss_fct_sop = nn.CrossEntropyLoss()
            total_loss = loss_fct_sop(prediction_logits_sop.view(-1, 2), sentence_order_labels.view(-1))

        return total_loss, prediction_logits_mlm, prediction_logits_sop

In [42]:
class AtomBERTConfig:
    def __init__(
        self,
        vocab_size,
        embedding_dim=128,
        hidden_size=768,
        num_layers=4,
        num_attention_heads=4,
        intermediate_size=1500,
        num_classes=2, # For SOP
        max_seq_length=128,
        dropout_rate=0.1,
        type_vocab_size = 2
    ):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_classes = num_classes
        self.max_seq_length = max_seq_length
        self.dropout_rate = dropout_rate
        self.type_vocab_size = type_vocab_size

In [43]:
wiki_text_dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
print(wiki_text_dataset)

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})


In [44]:
vocab_size = tokenizer.vocab_size
learning_rate = 5e-5

# Move the model to the device (GPU if available)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Recall or re-initialize your tokenizer if needed
vocab_size = tokenizer.vocab_size

# Create the configuration
config = AtomBERTConfig(vocab_size=vocab_size) # You can customize other parameters here if needed

# Instantiate the AtomBERTForPretraining model
modelpretrain = AtomBERTForPretraining(config)

print(f"Model instantiated with {config}")

Model instantiated with <__main__.AtomBERTConfig object at 0x7c6735acfb20>


In [45]:
def count_parameters(model):
    total_params = 0
    trainable_params = 0
    non_trainable_params = 0

    for name, param in model.named_parameters():
        num_params = param.numel()
        # if requires_grad is True, then it is a trainable parameter
        if param.requires_grad:
            trainable_params += num_params
        else:
            non_trainable_params += num_params
        total_params += num_params

    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Non-trainable Parameters: {non_trainable_params:,}")
print("Basic Model: \n")
count_parameters(modelpretrain.bert)
print("Total Pretraining Model: \n")
count_parameters(modelpretrain)


Basic Model: 

Total Parameters: 8,677,852
Trainable Parameters: 8,677,852
Non-trainable Parameters: 0
Total Pretraining Model: 

Total Parameters: 32,742,936
Trainable Parameters: 32,742,936
Non-trainable Parameters: 0


In [46]:
# Start a new wandb run to track this script.
import wandb
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="hexager-manipal",
    # Set the wandb project where this run will be logged.
    project="MRM-transform",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": 0.02,
        "architecture": "AlBERT-32M",
        "dataset": "WikiText-103",
        "epochs": 4,
    },
)

In [None]:
# Define parameters (ensure max_seq_length and batch_size are defined)
max_seq_length = 128  # Example value, adjust as needed
batch_size = 128      # Example value, adjust as needed
num_epochs = 4
learning_rate = 1e-4
weight_decay = 0.01
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Load dataset
wiki_text_dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

# Tokenize dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding='max_length', max_length=max_seq_length)
tokenized_train_dataset = wiki_text_dataset["train"].map(tokenize_function, batched=True)

# Prepare MLM data
def prepare_mlm_data(examples):
    input_ids = torch.tensor(examples["input_ids"])
    batch_size_local = len(input_ids)
    sequence_length = len(input_ids[0])
    mask_prob = 0.15
    mask_token_id = tokenizer.mask_token_id
    pad_token_id = tokenizer.pad_token_id
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    rand = torch.rand(batch_size_local, sequence_length)
    mask = (rand < mask_prob) & (input_ids != cls_token_id) & (input_ids != sep_token_id) & (input_ids != pad_token_id)

    labels = input_ids.clone()
    labels[~mask] = -100
    inputs = input_ids.clone()
    inputs[mask] = mask_token_id

    return {"input_ids": inputs, "labels": labels}
mlm_tokenized_train_dataset = tokenized_train_dataset.map(
    prepare_mlm_data,
    batched=True,
    remove_columns=["text", "token_type_ids"]
)

# Prepare SOP data
import re
from datasets import Dataset as HFDataset

def split_into_sentences_robust(text):
    sentence_pattern = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s+'
    sentences = re.split(sentence_pattern, text)
    return [s.strip() for s in sentences if s.strip()]

sop_texts_limited = []
sop_labels_limited = []
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
max_pairs_per_document = 5

for example in wiki_text_dataset["train"]:
    text = example["text"]
    sentences = split_into_sentences_robust(text)
    pairs_count = 0
    for i in range(len(sentences) - 1):
        if pairs_count >= max_pairs_per_document:
            break
        sentence_a = sentences[i]
        sentence_b = sentences[i + 1]
        sop_texts_limited.append(cls_token + " " + sentence_a + " " + sep_token + " " + sentence_b)
        sop_labels_limited.append(0)
        pairs_count += 1
        if pairs_count >= max_pairs_per_document:
            break
        sop_texts_limited.append(cls_token + " " + sentence_b + " " + sep_token + " " + sentence_a)
        sop_labels_limited.append(1)
        pairs_count += 1

sop_dataset = HFDataset.from_dict({"text": sop_texts_limited, "sentence_order_labels": sop_labels_limited})

def tokenize_sop_function(examples):
    return tokenizer(examples["text"], truncation=True, padding='max_length', max_length=max_seq_length)
tokenized_sop_dataset = sop_dataset.map(tokenize_sop_function, batched=True)

# Set format to PyTorch tensors
mlm_tokenized_train_dataset.set_format("torch")
tokenized_sop_dataset.set_format("torch")

# DataLoader for MLM dataset
mlm_dataloader = DataLoader(
    mlm_tokenized_train_dataset,
    batch_size=batch_size,
    shuffle=True
)

# DataLoader for SOP dataset
sop_dataloader = DataLoader(
    tokenized_sop_dataset,
    batch_size=batch_size,
    shuffle=True
)


Map:   0%|          | 0/1801350 [00:00<?, ? examples/s]

In [None]:
modelpretrain.to(device)
from torch.optim import AdamW
optimizer = AdamW(modelpretrain.parameters(), lr=learning_rate, weight_decay=weight_decay)
wandb.watch(modelpretrain)

In [None]:
for epoch in range(num_epochs):
    modelpretrain.train()
    total_mlm_loss = 0
    total_sop_loss = 0
    total_steps = min(len(mlm_dataloader), len(sop_dataloader))

    progress_bar = tqdm(range(total_steps), desc=f"Epoch {epoch+1}/{num_epochs}")

    for step in progress_bar:
        try:
            batch_mlm = next(iter(mlm_dataloader))
        except StopIteration:
            print("MLM dataloader finished early.")
            break
        try:
            batch_sop = next(iter(sop_dataloader))
        except StopIteration:
            print("SOP dataloader finished early.")
            break

        # Move MLM batch to device
        mlm_inputs = {k: batch_mlm[k].to(device) for k in batch_mlm}

        # Move SOP batch to device
        sop_inputs = {
                    k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch_sop.items()
                    }

        # Zero gradients
        optimizer.zero_grad()

        # MLM forward pass
        returned_loss, prediction_logits_mlm, prediction_logits_sop = modelpretrain(
            input_ids=mlm_inputs['input_ids'].to(device),
            attention_mask=mlm_inputs['attention_mask'].to(device),
            masked_labels=mlm_inputs['labels'].to(device)
        )

        mlm_loss = returned_loss # Now you can use 'returned_loss' which holds the total loss

        # SOP forward pass
        outputs_sop = modelpretrain(**{
            'input_ids': sop_inputs['input_ids'].to(device),
            'attention_mask': sop_inputs['attention_mask'].to(device),
            'token_type_ids' : sop_inputs.get('token_type_ids', torch.tensor([], device=device)).to(device),
            'sentence_order_labels': sop_inputs['sentence_order_labels'].to(device)
        })
        sop_loss = outputs_sop[0]

        # Combine losses
        total_loss = mlm_loss + sop_loss

        # Backward pass
        total_loss.backward()

        # Update parameters
        optimizer.step()

        total_mlm_loss += mlm_loss.item()
        total_sop_loss += sop_loss.item()
        wandb.log({"mlm_loss": mlm_loss.item(), "sop_loss": sop_loss.item(), "total_loss": total_loss.item()})
        progress_bar.set_postfix(mlm_loss=f"{mlm_loss.item():.4f}", sop_loss=f"{sop_loss.item():.4f}")
        
    avg_mlm_loss = total_mlm_loss / total_steps if total_steps > 0 else 0
    avg_sop_loss = total_sop_loss / total_steps if total_steps > 0 else 0
    run.log({"Total Unsupervised Loss": total_loss})
    save_path = '/kaggle/working/pretrained_model.pth'  # Replace with your desired path
    wandb.log({"avg_mlm_loss": avg_mlm_loss, "avg_sop_loss": avg_sop_loss, "epoch": epoch + 1})
    # Save the state_dict of the model
    torch.save(modelpretrain.state_dict(), save_path)
    wandb.save(save_path)
    print(f"Pretrained model weights saved to: {save_path}")
    print(f"Epoch {epoch+1}/{num_epochs} finished, Average MLM Loss: {avg_mlm_loss:.4f}, Average SOP Loss: {avg_sop_loss:.4f}")

print("Pretraining finished!")
run.finish()

In [None]:
class AtomBERTForSequenceClassification(nn.Module):
    def __init__(self, pretrained_config, num_classes):
        super().__init__()
        self.bert = AtomBERT(
            vocab_size=pretrained_config.vocab_size,
            embedding_dim=pretrained_config.embedding_dim,
            hidden_size=pretrained_config.hidden_size,
            num_layers=pretrained_config.num_layers,
            num_attention_heads=pretrained_config.num_attention_heads,
            intermediate_size=pretrained_config.intermediate_size,
            num_classes=num_classes, # For MNLI this will be 3
            max_seq_length=pretrained_config.max_seq_length,
            dropout_rate=pretrained_config.dropout_rate,
            type_vocab_size=pretrained_config.type_vocab_size if hasattr(pretrained_config, 'type_vocab_size') else 2
        )
        self.dropout = nn.Dropout(pretrained_config.dropout_rate)
        self.classifier = nn.Linear(pretrained_config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled_output = outputs[1] # We typically use the pooled output for classification
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [None]:


# Define the path to your saved pretrained weights file
pretrained_weights_path = '/kaggle/working/pretrained_model.pth' # Replace with the actual path to your saved file

# Load the configuration of your pretrained model
# Make sure these values match what you used for pretraining
pretrained_config = AtomBERTConfig(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=128,
    hidden_size=768,
    num_layers=4,
    num_attention_heads=4,
    intermediate_size=1500,
    max_seq_length=128,
    dropout_rate=0.1,
    type_vocab_size=2
)

# Instantiate the model for sequence classification (for MNLI with 3 classes)
num_classes_mnli = 3
model_for_mnli = AtomBERTForSequenceClassification(pretrained_config, num_classes_mnli)

# Load the state dictionary of the pretrained model
pretrained_state_dict = torch.load(pretrained_weights_path, map_location=torch.device(device)) # Ensure device consistency

# Create a new dictionary to store only the 'bert' weights
pretrained_bert_state_dict = {}
for name, param in pretrained_state_dict.items():
    if name.startswith('bert.'):
        pretrained_bert_state_dict[name[len('bert.'):]] = param

# Load the pretrained weights into the 'bert' part of the sequence classification model
model_for_mnli.bert.load_state_dict(pretrained_bert_state_dict)

print("Pretrained weights loaded successfully into the sequence classification model.")

# Move the fine-tuning model to the device
model_for_mnli.to(device)