# Text Classification with BERT + LoRA (GLUE SST-2 and MNLI)

In [19]:
from pathlib import Path

import pytorch_lightning as pl
import torch
from datasets import load_dataset
from ml_collections import ConfigDict
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

from src.config import get_config
from src.data import SequenceClassificationDataModule
from src.models import PeftModelForSequenceClassification

In [None]:
# Load the experiment config (project helper; the meaning of "testing" / 1 is
# defined in src.config).
config = get_config("testing", 1)
# NOTE(review): model_path is never used in this cell — presumably the intent
# was to load this checkpoint; confirm.
model_path = "experiments/lora-bert-sst2_kqv/checkpoints/epoch=0-step=8416.ckpt"

# NOTE(review): `trainer` is only created in the Trainer cell further down, so
# this line raises NameError on a fresh kernel (Restart & Run All); it only
# worked through leftover kernel state. Also confirm that
# save_best_model_state_dict returns a model: its result is later given
# `.config` and passed to trainer.fit(...).
model = PeftModelForSequenceClassification.save_best_model_state_dict(trainer, config)

In [None]:
# ---- Model and data configuration ----
MODEL_NAME = "bert-base-uncased"
DATASET_NAME = ("glue", "sst2")  # (dataset, config) pair handed to the SST-2 data module
NUM_WORKERS = 1
BATCH_SIZE = 4
# File holding the snapshot of the initial (untrained) LoRA + classifier weights.
INIT_PATH = "cache/models/lora_bert_cls_init/model_state_0.pt"

# ---- LoRA configuration (unpacked into peft.LoraConfig below) ----
# NOTE(review): the Lightning model trained further down appears to take its
# LoRA settings from `config` instead (its 56.8K trainable params match r=1,
# not r=3); this dict only feeds the standalone PEFT cell — confirm.
config_dict = dict(
    r=3,  # adapter rank
    lora_alpha=8,  # LoRA scaling hyper-parameter (see peft.LoraConfig)
    target_modules=["key", "query", "value"],  # BERT self-attention projections
    modules_to_save=["classifier"],  # keep the classification head trainable
)


In [None]:
# SST-2 data pipeline: one data module, three loaders (train / val / test).
# NOTE(review): unlike the MNLI module further down, prepare_data() is not
# called before setup() here — presumably setup() handles download/caching
# itself; confirm.
data_module = SequenceClassificationDataModule(
    dataset_name=DATASET_NAME,
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
data_module.setup()

trainloader, valloader, testloader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

In [4]:
# Keep only the single best checkpoint, ranked by validation accuracy.
CHECKPOINT_DIR = "experiments/lora_bert_sst2_sanity/checkpoints"

best_val_acc_ckpt = pl.callbacks.ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    save_top_k=1,
    monitor="val/acc",
    mode="max",
)
callbacks = [best_val_acc_ckpt]

# Early stopping (on "val/loss") and the CSV / W&B ("LoRA-Ensembling") loggers
# are disabled for this sanity run, so the Trainer uses its default logger and
# default accelerator selection.
# NOTE(review): val_check_interval=3 runs validation every 3 training batches,
# which is the likely cause of the ~11 h epoch ETA seen in the fit output —
# confirm this is intended.
trainer = pl.Trainer(
    max_epochs=1,
    precision="32",
    val_check_interval=3,
    callbacks=callbacks,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Attach the experiment config to the Lightning module, then train on SST-2.
# NOTE(review): `model` is whatever the save_best_model_state_dict cell bound,
# and `trainer` comes from the Trainer cell — both must have run first.
model.config = config
trainer.fit(
    model,
    train_dataloaders=trainloader,
    val_dataloaders=valloader,
)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[32m2024-06-16 09:31:27.963[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36mconfigure_optimizers[0m:[36m108[0m - [34m[1mConfiguring optimizer and lr scheduler...[0m

  | Name       | Type                          | Params
-------------------------------------------------------------
0 | accuracy   | MulticlassAccuracy            | 0     
1 | base_model | BertForSequenceClassification | 109 M 
2 | model      | PeftModel                     | 109 M 
-------------------------------------------------------------
56.8 K    Trainable params
109 M     Non-trainable params
109 M     Total params
438.162   Total estimated model params size (MB)


Epoch 0:   4%|▍         | 672/16837 [27:47<11:08:24,  0.40it/s, v_num=2, train/loss=0.0424, train/acc=1.000, val/loss=0.312, val/acc=0.859]

In [None]:
# Build a standalone LoRA-wrapped BERT from `config_dict` and snapshot its
# freshly initialised weights to INIT_PATH, so later runs can start from the
# same initial weights.
lora_config = LoraConfig(**config_dict)

bert = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir="cache/models")
lora_bert = get_peft_model(bert, lora_config)

# Write the snapshot only once. The classifier head is newly (randomly)
# initialised on every load (see the warning in the output), so overwriting an
# existing file on re-run would silently change the "fixed" initialisation.
# Delete the file to regenerate it after changing `config_dict`.
init_path = Path(INIT_PATH)
if not init_path.exists():
    init_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create parent dirs
    torch.save(lora_bert.state_dict(), init_path)

# Reload the snapshot. weights_only=True restricts unpickling to tensors and
# plain containers, which is all a state_dict needs.
model_state = torch.load(init_path, weights_only=True)
lora_bert.load_state_dict(model_state)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [1]:
# Raw GLUE/MNLI splits, loaded for inspection only. They are kept under their
# own name so they are not clobbered when `mnli` is bound to the MNLI data
# module below.
mnli_raw = load_dataset("glue", "mnli", cache_dir="cache/data")
mnli_raw

Downloading data: 100%|██████████| 52.2M/52.2M [00:29<00:00, 1.79MB/s]
Downloading data: 100%|██████████| 1.21M/1.21M [00:02<00:00, 480kB/s]
Downloading data: 100%|██████████| 1.25M/1.25M [00:02<00:00, 491kB/s]
Downloading data: 100%|██████████| 1.22M/1.22M [00:02<00:00, 501kB/s]
Downloading data: 100%|██████████| 1.26M/1.26M [00:02<00:00, 514kB/s]
Generating train split: 100%|██████████| 392702/392702 [00:00<00:00, 1829589.32 examples/s]
Generating validation_matched split: 100%|██████████| 9815/9815 [00:00<00:00, 1088673.34 examples/s]
Generating validation_mismatched split: 100%|██████████| 9832/9832 [00:00<00:00, 1638524.99 examples/s]
Generating test_matched split: 100%|██████████| 9796/9796 [00:00<00:00, 1348233.04 examples/s]
Generating test_mismatched split: 100%|██████████| 9847/9847 [00:00<00:00, 1545302.93 examples/s]


DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})

In [5]:
from src.mnli import MNLIDataModule

# NOTE(review): positional args — presumably (model_name, batch_size,
# num_workers), matching MODEL_NAME / BATCH_SIZE=4 / NUM_WORKERS=1 above;
# confirm the order in src.mnli and switch to keyword arguments.
mnli = MNLIDataModule("bert-base-uncased", 4, 1)

In [6]:
# Lightning-style data hooks: prepare_data() first, then setup(). The "Map"
# progress bars in the output come from tokenising every split here.
mnli.prepare_data()
mnli.setup()

Map: 100%|██████████| 392702/392702 [00:56<00:00, 6892.57 examples/s]
Map: 100%|██████████| 9815/9815 [00:01<00:00, 6694.98 examples/s] 
Map: 100%|██████████| 9832/9832 [00:01<00:00, 7535.50 examples/s]
Map: 100%|██████████| 9796/9796 [00:01<00:00, 7229.60 examples/s]
Map: 100%|██████████| 9847/9847 [00:02<00:00, 4010.88 examples/s]


In [8]:
# Inspect the processed splits: tokenizer columns (input_ids, token_type_ids,
# attention_mask) now sit alongside the raw GLUE fields.
mnli.dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9847
    })
})

In [15]:
# MNLI training loader under its own name. Reusing `trainloader` would clobber
# the SST-2 loader that `trainer.fit` consumes, so re-running that cell would
# silently train on MNLI instead.
mnli_trainloader = mnli.train_dataloader()
mnli_trainloader

<torch.utils.data.dataloader.DataLoader at 0x22ced3dfe10>

In [20]:

# Build a fresh Lightning LoRA model from project config "testing" / 2; the log
# shows it loads existing initial weights.
# NOTE(review): this rebinds the notebook-wide `config` and `model` used by the
# SST-2 training cell above; re-running that cell afterwards would train this
# model instead. Consider distinct names.
config = get_config("testing", 2)
model = PeftModelForSequenceClassification(config)
model

[32m2024-06-17 11:25:51.174[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36m__init__[0m:[36m22[0m - [34m[1mLoading model bert-base-uncased...[0m
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[32m2024-06-17 11:25:53.669[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36m__init__[0m:[36m34[0m - [34m[1mCreating LoRA model...[0m
[32m2024-06-17 11:25:53.707[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36m__init__[0m:[36m37[0m - [34m[1mSeeding model parameters...[0m
[32m2024-06-17 11:25:53.709[0m | [34m[1mDEBUG   [0m | [36msrc.models[0m:[36mseed_model_params[0m:[36m45[0m - [34m[1mInitial weights already exist. Loading...[0m


PeftModelForSequenceClassification(
  (accuracy): MulticlassAccuracy()
  (base_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): lora.Linear(
                  (base_layer): Linear(in_features=768, out_features=768, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Identity()
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=768, out_features=1, bias=Fals

In [26]:
# Peek at the first batch of the first validation loader. val_dataloader() is
# indexed, so it presumably returns one loader per validation split ([0] is
# likely validation_matched — confirm).
# next(iter(...)) fails loudly on an empty loader instead of printing nothing.
batch = next(iter(mnli.val_dataloader()[0]))
print(batch.keys())
print(batch["labels"])

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
tensor([1, 2, 0, 2])


In [37]:
# The "mnli_matched" GLUE config only exposes the matched validation/test
# splits (see output); loaded here to inspect its structure.
mnli_matched_raw = load_dataset("glue", "mnli_matched", cache_dir="cache/data")
mnli_matched_raw

Downloading data: 100%|██████████| 1.21M/1.21M [00:03<00:00, 367kB/s]
Downloading data: 100%|██████████| 1.22M/1.22M [00:02<00:00, 436kB/s]
Generating validation split: 100%|██████████| 9815/9815 [00:00<00:00, 124445.95 examples/s]
Generating test split: 100%|██████████| 9796/9796 [00:00<00:00, 652191.34 examples/s]


DatasetDict({
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
})