In [1]:
%load_ext autoreload

In [2]:
import torch
from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora, get_lora_state_dict
_ = torch.set_grad_enabled(False)

In [3]:
# a simple model
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=5, out_features=7),
    torch.nn.Linear(in_features=7, out_features=3),
)

x = torch.randn(1, 5)
y = model(x)
print(y)
Y0 = y

tensor([[-0.0028, -0.3906, -0.6502]])


In [4]:
# add lora to the model
# becase B is initialized to 0, the output is the same as before
add_lora(model)
y = model(x)
assert torch.allclose(y, Y0)

torch.Size([7, 5])


  return X + (torch.matmul(self.lora_B, self.dropout_fn(self.lora_A)).T).view(X.shape) * self.scaling


RuntimeError: shape '[7, 5]' is invalid for input of size 16

In [None]:
# to make the output different, we need to initialize B to something non-zero
model.apply(apply_to_lora(lambda x: torch.nn.init.ones_(x.lora_B)))
y = model(x)
print(y)
assert not torch.allclose(y, Y0)
Y1 = y

In [None]:
# now let's try to disable lora, the output is the same as before lora is added
disable_lora(model)
y = model(x)
assert torch.allclose(y, Y0)

In [None]:
# enable lora again
enable_lora(model)
y = model(x)
assert torch.allclose(y, Y1)

In [None]:
# let's save the state dict for later use
state_dict_to_save = get_lora_state_dict(model)
state_dict_to_save.keys()

In [None]:
# you can remove lora from the model
remove_lora(model)

In [None]:
# lets try to load the lora back
# first we need to add lora to the model
add_lora(model)
# then we can load the lora parameters
# strict=False is needed because we are loading a subset of the parameters
_ = model.load_state_dict(state_dict_to_save, strict=False) 
y = model(x)
assert torch.allclose(y, Y1)

In [None]:
# we can merge it to make it a normal linear layer, so there is no overhead for inference
merge_lora(model)
y = model(x)
assert torch.allclose(y, Y1)

In [None]:
# model now has no lora parameters
model

## Training a model

In [None]:
model = torch.nn.Linear(in_features=5, out_features=3)
# Step 1: Add LoRA to the model
add_lora(model)

# Step 2: Collect the parameters, pass them to the optimizer

parameters = [
    {"params": list(get_lora_params(model))},
]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# Step 3: Train the model
# ...
# simulate training, update the LoRA parameters
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_A)))
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_B)))

# Step 4: export the LoRA parameters
state_dict = model.state_dict()
lora_state_dict = {k: v for k, v in state_dict.items() if name_is_lora(k)}

## Loading and Inferencing with LoRA

In [None]:
# Step 1: Add LoRA to your model
add_lora(model)

# Step 2: Load the LoRA parameters
_ = model.load_state_dict(lora_state_dict, strict=False)

# Step 3: Merge the LoRA parameters into the model
merge_lora(model)

## Inferencing with multiple LoRA models

In [None]:
# to avoid re-adding lora to the model when rerun the cell, remove lora first 
remove_lora(model)
# Step 1: Add LoRA to your model
add_lora(model)

# Step 2: Load the LoRA parameters

# fake 3 sets of LoRA parameters
lora_state_dict_0 = lora_state_dict
lora_state_dict_1 = {k: torch.ones_like(v) for k, v in lora_state_dict.items()}
lora_state_dict_2 = {k: torch.zeros_like(v) for k, v in lora_state_dict.items()}
lora_state_dicts = [lora_state_dict_0, lora_state_dict_1, lora_state_dict_2]

load_multiple_lora(model, lora_state_dicts)

# Step 3: Select which LoRA to use at inference time
Y0 = select_lora(model, 0)(x)
Y1 = select_lora(model, 1)(x)
Y2 = select_lora(model, 2)(x)

In [None]:
Y0, Y1, Y2

In [None]:
remove_lora(model)
init_state_dict = model.state_dict()
# verify that it's the same as if we load the lora parameters one by one
for state_dict in lora_state_dicts:
    remove_lora(model)
    _ = model.load_state_dict(init_state_dict, strict=False)
    add_lora(model)
    _ = model.load_state_dict(state_dict, strict=False)
    merge_lora(model)
    y = model(x)
    print(y)

In [5]:
x = torch.tensor([1, 2, 3])

In [8]:
torch.nn.functional.pad(x, 4)

TypeError: pad(): argument 'pad' (position 2) must be tuple of ints, not int