# Using PEFT with custom models

`peft` allows us to fine-tune models efficiently with LoRA. In this notebook we
train a simple multilayer perceptron (MLP) using `peft`.

In [2]:
import copy
import os

# Set before `peft` is imported in the next cell.
# NOTE(review): presumably silences the bitsandbytes welcome banner — confirm.
os.environ['BITSANDBYTES_NOWELCOME'] = '1'

In [3]:
import peft
import torch
from torch import nn
import torch.nn.functional as F

In [4]:
# Seed PyTorch's global RNG so the random data and weight initialisation below are repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x10c1a3890>

In [6]:
# Synthetic binary-classification data: 1000 samples with 20 uniform features each.
X = torch.rand(1000, 20)
# A sample gets label 1 when its features sum to more than 10, otherwise 0.
y = X.sum(dim=1).gt(10).long()

In [7]:
n_train = 800  # rows of X/y used for training; the remaining rows form the evaluation split
batch_size = 64  # mini-batch size shared by the train and eval dataloaders

In [8]:
# The first `n_train` rows are the training split; everything after is held out.
train_split = torch.utils.data.TensorDataset(X[:n_train], y[:n_train])
eval_split = torch.utils.data.TensorDataset(X[n_train:], y[n_train:])

# Only the training batches are shuffled; evaluation keeps the original row order.
train_dataloader = torch.utils.data.DataLoader(train_split, batch_size=batch_size, shuffle=True)
eval_dataloader = torch.utils.data.DataLoader(eval_split, batch_size=batch_size)

In [None]:
class MLP(nn.Module):
    """Three-layer perceptron that returns per-class log-probabilities.

    The layers live in ``self.seq`` so that they are addressable by name as
    ``seq.0`` (input Linear), ``seq.2`` (hidden Linear) and ``seq.4`` (output
    Linear); ``seq.1``/``seq.3`` are ReLUs and ``seq.5`` is the LogSoftmax.

    Args:
        num_units_hidden: Width of both hidden layers.
        num_features: Size of each input sample (defaults to 20, the width of
            the data used in this notebook).
        num_classes: Number of output classes (defaults to 2).
    """

    def __init__(self, num_units_hidden=2000, num_features=20, num_classes=2):
        super().__init__()
        # Layer order is kept fixed so the `seq.<index>` names stay stable.
        self.seq = nn.Sequential(
            nn.Linear(num_features, num_units_hidden),
            nn.ReLU(),
            nn.Linear(num_units_hidden, num_units_hidden),
            nn.ReLU(),
            nn.Linear(num_units_hidden, num_classes),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, X):
        """Return log-probabilities of shape ``(..., num_classes)`` for input ``X``."""
        return self.seq(X)

In [10]:
lr = 0.002
batch_size = 64
max_epochs = 30

# Pick the best available accelerator instead of assuming Apple's MPS backend,
# so the notebook also runs on CUDA machines and on CPU-only machines.
_mps_backend = getattr(torch.backends, 'mps', None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [14]:
def train(model, optimizer, criterion, train_dataloader, eval_dataloader, epochs, device=None):
    """Train ``model`` for ``epochs`` epochs, printing mean train/eval loss per epoch.

    Args:
        model: Module to optimise; called as ``model(xb)`` on each batch.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss callable invoked as ``criterion(outputs, yb)``.
        train_dataloader: Iterable of ``(xb, yb)`` training batches.
        eval_dataloader: Iterable of ``(xb, yb)`` evaluation batches.
        epochs: Number of passes over ``train_dataloader``.
        device: Device the batches are moved to. When ``None`` (default), the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function does not depend on a global
            ``device`` variable.

    Returns:
        None. One line per epoch is printed with the batch-averaged losses.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for xb, yb in train_dataloader:
            xb = xb.to(device)
            yb = yb.to(device)
            outputs = model(xb)
            loss = criterion(outputs, yb)
            # Accumulate a detached copy so the running sum does not keep graphs alive.
            train_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        eval_loss = 0
        # No gradients are needed for evaluation, neither for the forward pass nor the loss.
        with torch.no_grad():
            for xb, yb in eval_dataloader:
                xb = xb.to(device)
                yb = yb.to(device)
                outputs = model(xb)
                loss = criterion(outputs, yb)
                eval_loss += loss.detach().float()

        # Losses are averaged over the number of batches, not the number of samples.
        eval_loss_total = (eval_loss / len(eval_dataloader)).item()
        train_loss_total = (train_loss / len(train_dataloader)).item()
        print(f'{epoch=:<2} {train_loss_total=:.4f} {eval_loss_total=:.4f}')

In [15]:
# Training without PEFT
# Baseline run: every parameter of a fresh MLP is handed to the optimizer.
criterion = nn.CrossEntropyLoss()
module = MLP()
module = module.to(device)
optimizer = torch.optim.Adam(params=module.parameters(), lr=lr)

In [None]:
# Baseline: train the plain model for `max_epochs` epochs; %time reports how long the call takes.
%time train(module, optimizer, criterion, train_dataloader, eval_dataloader, epochs=max_epochs)

epoch=0  train_loss_total=1.1008 eval_loss_total=0.6648
epoch=1  train_loss_total=0.6641 eval_loss_total=0.6243
epoch=2  train_loss_total=0.5475 eval_loss_total=0.5355
epoch=3  train_loss_total=0.3885 eval_loss_total=0.3534
epoch=4  train_loss_total=0.2576 eval_loss_total=0.4362
epoch=5  train_loss_total=0.2201 eval_loss_total=0.2771
epoch=6  train_loss_total=0.1514 eval_loss_total=0.2032
epoch=7  train_loss_total=0.1173 eval_loss_total=0.2428
epoch=8  train_loss_total=0.1057 eval_loss_total=0.2615
epoch=9  train_loss_total=0.1137 eval_loss_total=0.2047
epoch=10 train_loss_total=0.1334 eval_loss_total=0.3663
epoch=11 train_loss_total=0.0925 eval_loss_total=0.3073
epoch=12 train_loss_total=0.0603 eval_loss_total=0.2559
epoch=13 train_loss_total=0.0520 eval_loss_total=0.1886
epoch=14 train_loss_total=0.0511 eval_loss_total=0.2894
epoch=15 train_loss_total=0.0324 eval_loss_total=0.2197
epoch=16 train_loss_total=0.0177 eval_loss_total=0.2164
epoch=17 train_loss_total=0.0156 eval_loss_total

In [17]:
# Now training with PEFT
# List every submodule name together with its type; the names are what the LoRA config refers to.
[(layer_name, type(layer)) for layer_name, layer in MLP().named_modules()]

[('', __main__.MLP),
 ('seq', torch.nn.modules.container.Sequential),
 ('seq.0', torch.nn.modules.linear.Linear),
 ('seq.1', torch.nn.modules.activation.ReLU),
 ('seq.2', torch.nn.modules.linear.Linear),
 ('seq.3', torch.nn.modules.activation.ReLU),
 ('seq.4', torch.nn.modules.linear.Linear),
 ('seq.5', torch.nn.modules.activation.LogSoftmax)]

In [18]:
config = peft.LoraConfig(
    # LoRA rank; the lora_A/lora_B parameter counts printed further down follow from it.
    r=8,
    # 'seq.0' and 'seq.2' are the first two Linear layers in the module listing above;
    # these are the layers that receive lora_A/lora_B weights.
    target_modules=['seq.0', 'seq.2'],
    # 'seq.4' (the output Linear layer) gets no LoRA weights; per the comparison output
    # further down, a separate copy of it is updated while the original stays unchanged.
    modules_to_save=['seq.4']
)

In [None]:
# Start from a fresh, randomly initialised MLP on the target device.
module = MLP().to(device)
# Snapshot taken before wrapping: reused later as the "before" reference when checking
# which parameters changed, and as the base model when the adapter is reloaded.
module_copy = copy.deepcopy(module)
peft_model = peft.get_peft_model(module, config)
# NOTE(review): all parameters are passed to Adam, including those reported as
# non-trainable below; presumably they carry no gradients and are skipped — confirm.
optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Prints the trainable vs. total parameter counts.
peft_model.print_trainable_parameters()

trainable params: 52,162 || all params: 4,100,164 || trainable%: 1.2722


In [21]:
# Same training loop and epoch budget as the baseline, now applied to the PEFT-wrapped model.
%time train(peft_model, optimizer, criterion, train_dataloader, eval_dataloader, epochs=max_epochs)

epoch=0  train_loss_total=0.6777 eval_loss_total=0.6316
epoch=1  train_loss_total=0.5815 eval_loss_total=0.5184
epoch=2  train_loss_total=0.4233 eval_loss_total=0.3412
epoch=3  train_loss_total=0.2959 eval_loss_total=0.3356
epoch=4  train_loss_total=0.2465 eval_loss_total=0.3502
epoch=5  train_loss_total=0.1848 eval_loss_total=0.2093
epoch=6  train_loss_total=0.1341 eval_loss_total=0.2015
epoch=7  train_loss_total=0.0906 eval_loss_total=0.2163
epoch=8  train_loss_total=0.0855 eval_loss_total=0.2263
epoch=9  train_loss_total=0.0631 eval_loss_total=0.2437
epoch=10 train_loss_total=0.0471 eval_loss_total=0.2159
epoch=11 train_loss_total=0.0299 eval_loss_total=0.2250
epoch=12 train_loss_total=0.0273 eval_loss_total=0.3061
epoch=13 train_loss_total=0.0669 eval_loss_total=0.7617
epoch=14 train_loss_total=0.0843 eval_loss_total=0.5253
epoch=15 train_loss_total=0.0643 eval_loss_total=0.2451
epoch=16 train_loss_total=0.0136 eval_loss_total=0.2509
epoch=17 train_loss_total=0.0061 eval_loss_total

In [22]:
# Report only the parameters whose name contains 'lora', together with their element counts.
lora_entries = [
    (param_name, tensor)
    for param_name, tensor in peft_model.base_model.named_parameters()
    if 'lora' in param_name
]
for param_name, tensor in lora_entries:
    print(f'new params: {param_name:<13} | {tensor.numel():>5} parameters | updated')

new params: model.seq.0.lora_A.default.weight |   160 parameters | updated
new params: model.seq.0.lora_B.default.weight | 16000 parameters | updated
new params: model.seq.2.lora_A.default.weight | 16000 parameters | updated
new params: model.seq.2.lora_B.default.weight | 16000 parameters | updated


In [None]:
# Compare every non-LoRA parameter of the trained PEFT model against the snapshot
# taken before training to see which weights actually changed.
params_before = dict(module_copy.named_parameters())

# Name fragments inserted by the PEFT wrappers; stripping them (in this order)
# maps a wrapped parameter name back to its name in the plain MLP.
wrapper_fragments = ("base_layer.", "original_", "module.", "modules_to_save.default.")

for name, param in peft_model.base_model.named_parameters():
    if 'lora' not in name:
        # Drop everything up to and including the first dot (the leading prefix).
        _, _, name_before = name.partition(".")
        for fragment in wrapper_fragments:
            name_before = name_before.replace(fragment, "")

        param_before = params_before[name_before]
        unchanged = torch.allclose(param, param_before)
        message = (
            f"Parameter {name_before:<13} | {param.numel():>7} parameters | not updated"
            if unchanged
            else f"Parameter {name_before:<13} | {param.numel():>7} parameters | updated"
        )
        print(message)

Parameter seq.0.weight  |   40000 parameters | not updated
Parameter seq.0.bias    |    2000 parameters | not updated
Parameter seq.2.weight  | 4000000 parameters | not updated
Parameter seq.2.bias    |    2000 parameters | not updated
Parameter seq.4.weight  |    4000 parameters | not updated
Parameter seq.4.bias    |       2 parameters | not updated
Parameter seq.4.weight  |    4000 parameters | updated
Parameter seq.4.bias    |       2 parameters | updated


In [25]:
# The repo id has the form "<namespace>/<repo name>"; change `user` to your own
# namespace when running this notebook yourself.
user = 'rodmosc'
model_name = 'peft-lora-with-custom-model'
model_id = '/'.join((user, model_name))
peft_model.push_to_hub(model_id)

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

CommitInfo(commit_url='https://huggingface.co/rodmosc/peft-lora-with-custom-model/commit/a5a87a86b6235e8eb5a4edf90f71d2134a9b0a23', commit_message='Upload model', commit_description='', oid='a5a87a86b6235e8eb5a4edf90f71d2134a9b0a23', pr_url=None, repo_url=RepoUrl('https://huggingface.co/rodmosc/peft-lora-with-custom-model', endpoint='https://huggingface.co', repo_type='model', repo_id='rodmosc/peft-lora-with-custom-model'), pr_revision=None, pr_num=None)

In [26]:
# Rebuild a PEFT model from the uploaded adapter, using the pre-training snapshot
# `module_copy` as the base network.
model = peft.PeftModel.from_pretrained(module_copy, model_id)
# Display the resulting wrapper class.
type(model)

adapter_config.json:   0%|          | 0.00/903 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/209k [00:00<?, ?B/s]

peft.peft_model.PeftModel

In [27]:
# Check that the reloaded model reproduces the trained model's outputs.
# The inputs are moved to the device once, and the forward passes run without
# gradient tracking because only the output values are compared.
X_device = X.to(device)
with torch.no_grad():
    y_peft = peft_model(X_device)
    y_model = model(X_device)
torch.allclose(y_peft, y_model)

True