In [1]:
import lightning as L

  from .autonotebook import tqdm as notebook_tqdm


In [8]:

from torch import optim
from transformers import AutoModelForSequenceClassification

class LitLLM(L.LightningModule):
    """Lightning module that fine-tunes a wrapped model on the loss it reports.

    The wrapped model is called with every batch unpacked into keyword
    arguments, and the object it returns must expose a ``loss`` attribute.
    """

    def __init__(self, llm):
        """Keep a reference to the model to fine-tune.

        Args:
            llm: Callable model invoked as ``llm(**batch)``; its output must
                provide ``.loss``.
        """
        super().__init__()
        self.llm = llm

    def training_step(self, batch, batch_idx):
        """Run the model on one training batch and log the loss as ``train_loss``.

        Args:
            batch: Mapping unpacked into the model call. Presumably it carries
                the labels the model needs to compute a loss (TODO confirm).
            batch_idx: Position of the batch in the epoch; not used here.

        Returns:
            The loss reported by the model for this batch.
        """
        train_loss = self.llm(**batch).loss
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Run the model on one validation batch and log the loss as ``val_loss``.

        Args:
            batch: Mapping unpacked into the model call.
            batch_idx: Position of the batch in the loader; not used here.
        """
        val_loss = self.llm(**batch).loss
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters (lr=5e-5)."""
        optimizer = optim.AdamW(self.parameters(), lr=5e-5)
        return optimizer
    
# Seed the Python, NumPy and PyTorch RNGs before the model is built. The
# checkpoint has no classifier head for this task, so `classifier.weight` and
# `classifier.bias` are freshly (randomly) initialized by `from_pretrained`;
# without a fixed seed every Restart & Run All would start from different
# weights. 42 matches the seed used for the dataset shuffles in this notebook.
L.seed_everything(42)

# Wrap a BERT sequence classifier with a 5-way classification head.
model = LitLLM(
    AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
)
        

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Tokenizer that matches the checkpoint being fine-tuned.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

# Yelp review dataset, fetched by name through `datasets.load_dataset`.
dataset = load_dataset("yelp_review_full")

def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch with the notebook-level tokenizer.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw inputs.

    Returns:
        Whatever the tokenizer returns when called with padding to the maximum
        length and truncation enabled.
    """
    texts = examples["text"]
    return tokenizer(texts, padding="max_length", truncation=True)

# Build the model-ready dataset in one chain:
#   1. tokenize every split in batches,
#   2. drop the raw "text" column (the model does not accept raw text),
#   3. rename "label" to "labels" (the argument name the model expects).
tokenized_datasets = (
    dataset.map(tokenize_function, batched=True)
    .remove_columns(["text"])
    .rename_column("label", "labels")
)

# Have the dataset return PyTorch tensors instead of Python lists.
tokenized_datasets.set_format("torch")

In [10]:
# Take a reproducibly shuffled 1000-example subset of the train split and of
# the test split (same seed and size for both).
small_train_dataset, small_eval_dataset = (
    tokenized_datasets[split].shuffle(seed=42).select(range(1000))
    for split in ("train", "test")
)

In [11]:
# Create a DataLoader for your training and test datasets so you can iterate over batches of data:
from torch.utils.data import DataLoader

# Both loaders use batches of 8; only the training loader shuffles.
train_dataloader = DataLoader(
    small_train_dataset,
    batch_size=8,
    shuffle=True,
)
eval_dataloader = DataLoader(
    small_eval_dataset,
    batch_size=8,
)

In [12]:
# min_epochs == max_epochs pins the run to exactly three epochs.
trainer = L.Trainer(max_epochs=3, min_epochs=3)

# Fine-tune on the training loader, validating against the eval loader.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=eval_dataloader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name | Type                          | Params
-------------------------------------------------------
0 | llm  | BertForSequenceClassification | 108 M 
-------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.256   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/echapman/projects/llm-finetuning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/echapman/projects/llm-finetuning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 125/125 [02:54<00:00,  0.72it/s, v_num=1]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 125/125 [02:56<00:00,  0.71it/s, v_num=1]


In [13]:
# Evaluate the fine-tuned model on the eval loader; as the cell's last
# expression, the returned metrics are displayed.
trainer.validate(model, dataloaders=eval_dataloader)

/Users/echapman/projects/llm-finetuning/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Validation DataLoader 0: 100%|██████████| 125/125 [00:40<00:00,  3.05it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            1.1710704565048218
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 1.1710704565048218}]