# Text Classification with BERT

In [22]:
import torch
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model
from ml_collections import ConfigDict

from src.models import PeftModelForSequenceClassification
from src.data import SequenceClassificationDataModule
from src.config import get_config

In [23]:
config = get_config("testing", 1)
model_path = "experiments\lora_bert_sst2_sanity\checkpoints\epoch=0-step=12627.ckpt"

model = PeftModelForSequenceClassification.load_from_checkpoint(model_path)

[32m2024-06-16 09:30:22.492[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36m__init__[0m:[36m22[0m - [34m[1mLoading model bert-base-uncased...[0m
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[32m2024-06-16 09:30:23.583[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36m__init__[0m:[36m34[0m - [34m[1mCreating LoRA model...[0m
[32m2024-06-16 09:30:23.601[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36m__init__[0m:[36m37[0m - [34m[1mSeeding model parameters...[0m
[32m2024-06-16 09:30:23.601[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36mseed_model_params[0m:[36m45[0m - [34m[1mInitial weights already exist. Loading...[0m


In [5]:
# Model and Data Configurations
MODEL_NAME = "bert-base-uncased"
DATASET_NAME = ("glue", "sst2")
NUM_WORKERS = 1
BATCH_SIZE = 4
INIT_PATH = "cache/models/lora_bert_cls_init/model_state_0.pt"

# Lora Configuration
config_dict = {
        "r": 3,
        "lora_alpha": 8,
        "target_modules": ["key", "query", "value"],
        "modules_to_save": ["classifier"],
    }


In [6]:
data_module = SequenceClassificationDataModule(
    model_name=MODEL_NAME,
    dataset_name=DATASET_NAME,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
)

data_module.setup()
trainloader = data_module.train_dataloader()
valloader = data_module.val_dataloader()
testloader = data_module.test_dataloader()

In [27]:
callbacks = [
    pl.callbacks.ModelCheckpoint(
        dirpath=f"experiments/lora_bert_sst2_sanity/checkpoints",
        save_top_k=1,
        monitor="val/acc",
        mode="max",
    ),
    # EarlyStopping(
    #     monitor="val/loss",
    #     patience=config.training.early_stopping_patience,
    #     mode="min",
    # ),
]

# loggers = [
#     CSVLogger(f"experiments/{exp_name}/logs/"),
#     WandbLogger(
#         project="LoRA-Ensembling",
#         name=exp_name,
#         log_model=True,
#         save_dir="experiments",
#     ),
# ]

# Trainer
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu",
    precision="32",
    val_check_interval=3,
    callbacks=callbacks,
    # logger=loggers,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
model.config = config
trainer.fit(model, trainloader, valloader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[32m2024-06-16 09:31:27.963[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36mconfigure_optimizers[0m:[36m108[0m - [34m[1mConfiguring optimizer and lr scheduler...[0m

  | Name       | Type                          | Params
-------------------------------------------------------------
0 | accuracy   | MulticlassAccuracy            | 0     
1 | base_model | BertForSequenceClassification | 109 M 
2 | model      | PeftModel                     | 109 M 
-------------------------------------------------------------
56.8 K    Trainable params
109 M     Non-trainable params
109 M     Total params
438.162   Total estimated model params size (MB)


Epoch 0:   4%|▍         | 672/16837 [27:47<11:08:24,  0.40it/s, v_num=2, train/loss=0.0424, train/acc=1.000, val/loss=0.312, val/acc=0.859]

In [6]:
# PEFT Config
lora_config = LoraConfig(**config_dict)

# Load BERT and PEFT Model
bert = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir="cache/models")
lora_bert = get_peft_model(bert, lora_config)

# Save the initial model state
model_state = lora_bert.state_dict()
torch.save(model_state, INIT_PATH)

# Load the initial model state
model_state = torch.load(INIT_PATH)
lora_bert.load_state_dict(model_state)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [11]:
batch = next(iter(trainloader))
output = lora_bert(**batch)
output

SequenceClassifierOutput(loss=tensor(0.8282, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.6967, -0.1019],
        [ 0.9714,  0.0601],
        [ 1.0416, -0.2116],
        [ 1.0746, -0.2415]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [51]:
state_dict = lora_bert.state_dict()

for name in lora_bert.state_dict():
    if "lora" not in name:
        state_dict.pop(name)

In [48]:
torch.save(state_dict, "testing.pt")
torch.save(lora_bert.state_dict(), "testing_full.pt")

In [49]:
lora_bert_3 = get_peft_model(
    AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3, cache_dir="cache/models"),
    lora_config,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
lora_bert_3.load_state_dict(torch.load("testing.pt"), strict=False)

_IncompatibleKeys(missing_keys=['base_model.model.bert.embeddings.word_embeddings.weight', 'base_model.model.bert.embeddings.position_embeddings.weight', 'base_model.model.bert.embeddings.token_type_embeddings.weight', 'base_model.model.bert.embeddings.LayerNorm.weight', 'base_model.model.bert.embeddings.LayerNorm.bias', 'base_model.model.bert.encoder.layer.0.attention.self.query.base_layer.weight', 'base_model.model.bert.encoder.layer.0.attention.self.query.base_layer.bias', 'base_model.model.bert.encoder.layer.0.attention.self.key.base_layer.weight', 'base_model.model.bert.encoder.layer.0.attention.self.key.base_layer.bias', 'base_model.model.bert.encoder.layer.0.attention.self.value.base_layer.weight', 'base_model.model.bert.encoder.layer.0.attention.self.value.base_layer.bias', 'base_model.model.bert.encoder.layer.0.attention.output.dense.weight', 'base_model.model.bert.encoder.layer.0.attention.output.dense.bias', 'base_model.model.bert.encoder.layer.0.attention.output.LayerNorm.w