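# Text Classification with BERT

This notebook wraps `bert-base-uncased` with LoRA adapters for sequence classification on GLUE SST-2, saves and reloads the initial model state, then runs one training batch through the model and computes accuracy on it.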

In [2]:
import os

import torch
from ml_collections import ConfigDict
from peft import LoraConfig, get_peft_model
from torchmetrics import Accuracy
from transformers import AutoModelForSequenceClassification

from src.data import SequenceClassificationDataModule

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Model / data settings -------------------------------------------------
MODEL_NAME = "bert-base-uncased"
DATASET_NAME = ("glue", "sst2")
NUM_WORKERS = 1
BATCH_SIZE = 4
# File the initial (pre-training) model state is written to and read from.
INIT_PATH = "cache/models/lora_bert_cls_init/model_state_0.pt"

# --- LoRA settings (unpacked into LoraConfig further down) -----------------
config_dict = dict(
    r=3,
    lora_alpha=8,
    target_modules=["key", "query", "value"],
    modules_to_save=["classifier"],
)


In [4]:
# Build the data module from the notebook-level settings and prepare it.
datamodule_kwargs = dict(
    model_name=MODEL_NAME,
    dataset_name=DATASET_NAME,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
)
data_module = SequenceClassificationDataModule(**datamodule_kwargs)
data_module.setup()

# One loader per split, requested in train / val / test order.
trainloader, valloader, testloader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

Map: 100%|██████████| 1821/1821 [00:00<00:00, 2159.92 examples/s]


In [5]:
# Show the summary (features and row count) of the training split.
train_split = data_module.dataset["train"]
train_split

Dataset({
    features: ['sentence', 'label', 'idx'],
    num_rows: 67349
})

In [6]:
# Build the LoRA adapter configuration from the settings in config_dict.
lora_config = LoraConfig(**config_dict)

# Load the BERT sequence classifier and wrap it with the LoRA adapters.
bert = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir="cache/models")
lora_bert = get_peft_model(bert, lora_config)

# Save the initial model state. torch.save does not create missing parent
# directories, so make sure the target folder exists first; otherwise this
# cell fails on a fresh checkout.
init_dir = os.path.dirname(INIT_PATH)
if init_dir:
    os.makedirs(init_dir, exist_ok=True)
model_state = lora_bert.state_dict()
torch.save(model_state, INIT_PATH)

# Load the initial model state back. The file holds a plain state dict, so
# weights_only=True is sufficient and avoids unpickling arbitrary objects.
model_state = torch.load(INIT_PATH, weights_only=True)
lora_bert.load_state_dict(model_state)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [11]:
# Pull a single mini-batch from the training loader and run it through the
# LoRA-wrapped model; the bare `output` on the last line displays the result.
batch_iter = iter(trainloader)
batch = next(batch_iter)
output = lora_bert(**batch)
output

SequenceClassifierOutput(loss=tensor(0.8282, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.6967, -0.1019],
        [ 0.9714,  0.0601],
        [ 1.0416, -0.2116],
        [ 1.0746, -0.2415]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [17]:
# Two-class accuracy for the batch evaluated above. The metric is placed on
# the same device as the logits rather than a hard-coded "cuda", so the cell
# also runs on machines without a GPU and always matches its inputs.
# `Accuracy` is imported in the notebook's top import cell.
acc_metric = Accuracy(task="multiclass", num_classes=2).to(output.logits.device)
acc_metric(output.logits, batch["labels"])


tensor(0.5000, device='cuda:0')