In [1]:
import torch
from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora, get_lora_state_dict
_ = torch.set_grad_enabled(False)

[2024-04-13 22:35:37,592] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [7]:
'''Timing trials; monarch (ABx) vs (ABI)x '''
import time
import torch
from src.models.layers.monarch_linear import MonarchLinear

n_runs = 10_000
in_features=1000
out_features=500
device="cuda"
monarch = MonarchLinear(nblocks=4, adapt=False, in_features=in_features, out_features=out_features,
              bias=False, device=device, dtype=None).to(device)
x = torch.rand((2*in_features,in_features),device=device)

# start = time.time()
# for _ in range(n_runs):
#     monarch.forward_matmul(x)
# print(time.time()-start)

# start = time.time()
# for _ in range(n_runs):
#     monarch.forward_matmul(torch.eye(in_features,device=device)) @ x
# print(time.time()-start)

In [6]:
x = torch.rand((2*in_features,in_features),device=device)
monarch.forward_matmul(x).shape

torch.Size([2000, 1000])

In [2]:
# a simple model
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=100, out_features=200),
    torch.nn.Linear(in_features=200, out_features=40),
)

x = torch.randn(1, 100)
y = model(x)
print(y)
Y0 = y

tensor([[-0.0967,  0.0831, -0.0068, -0.4254, -0.1586,  0.4504, -0.1301,  0.0330,
         -0.0103,  0.1027,  0.4141, -0.2607, -0.3417,  0.1599, -0.0211, -0.0440,
         -0.1365, -0.2341,  0.1662,  0.5770, -0.1103,  0.2123, -0.4326,  0.0049,
         -0.4706,  0.1912,  0.2688,  0.0759,  1.2353,  0.0486,  0.5336, -0.3987,
         -0.0479, -0.7208,  0.2263,  0.4558,  0.0864, -0.2018,  0.2305,  0.5532]])


In [3]:
# add lora to the model
# becase B is initialized to 0, the output is the same as before
add_lora(model)
y = model(x)
assert torch.allclose(y, Y0)

for name, param in model.state_dict().items():
    print(name, name_is_lora(name))

0.bias False
0.parametrizations.weight.original False
0.parametrizations.weight.0.monarch.bias False
0.parametrizations.weight.0.monarch.blkdiag1 True
0.parametrizations.weight.0.monarch.blkdiag2 True
1.bias False
1.parametrizations.weight.original False
1.parametrizations.weight.0.monarch.bias False
1.parametrizations.weight.0.monarch.blkdiag1 True
1.parametrizations.weight.0.monarch.blkdiag2 True


In [4]:
# to make the output different, we need to initialize B to something non-zero
# model.apply(apply_to_lora(lambda x: torch.nn.init.ones_(x.lora_B)))
model.apply(apply_to_lora(lambda x: torch.nn.init.ones_(x.monarch.blkdiag2)))
y = model(x)
print(y)
assert not torch.allclose(y, Y0)
Y1 = y

tensor([[-0.1344, -0.0259,  0.0673, -0.3784, -0.1344,  0.5476, -0.0929,  0.0768,
         -0.0510,  0.0317,  0.3515, -0.2797, -0.4501,  0.1968, -0.0089, -0.1401,
         -0.2008, -0.3028,  0.0180,  0.5884, -0.1402,  0.1190, -0.4799, -0.0457,
         -0.4249,  0.0741,  0.2738,  0.1504,  1.1864, -0.0498,  0.6305, -0.3022,
          0.0044, -0.6683,  0.1302,  0.4625,  0.1276, -0.1924,  0.2804,  0.6926]])


In [5]:
# now let's try to disable lora, the output is the same as before lora is added
disable_lora(model)
y = model(x)
assert torch.allclose(y, Y0)

In [6]:
# enable lora again
enable_lora(model)
y = model(x)
assert torch.allclose(y, Y1)

In [7]:
# let's save the state dict for later use
state_dict_to_save = get_lora_state_dict(model)
state_dict_to_save.keys()

dict_keys(['0.parametrizations.weight.0.monarch.blkdiag1', '0.parametrizations.weight.0.monarch.blkdiag2', '1.parametrizations.weight.0.monarch.blkdiag1', '1.parametrizations.weight.0.monarch.blkdiag2'])

In [8]:
name_is_lora

<function minlora.utils.name_is_lora(name)>

In [9]:
# you can remove lora from the model
remove_lora(model)

In [10]:
# lets try to load the lora back
# first we need to add lora to the model
add_lora(model)
# then we can load the lora parameters
# strict=False is needed because we are loading a subset of the parameters
_ = model.load_state_dict(state_dict_to_save, strict=False) 
y = model(x)
assert torch.allclose(y, Y1)

In [11]:
# we can merge it to make it a normal linear layer, so there is no overhead for inference
merge_lora(model)
y = model(x)
assert torch.allclose(y, Y1)

In [12]:
# model now has no lora parameters
model

Sequential(
  (0): Linear(in_features=100, out_features=200, bias=True)
  (1): Linear(in_features=200, out_features=40, bias=True)
)

## Training a model

In [14]:
model = torch.nn.Linear(in_features=5, out_features=3)
# Step 1: Add LoRA to the model
add_lora(model)

# Step 2: Collect the parameters, pass them to the optimizer

parameters = [
    {"params": list(get_lora_params(model))},
]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# Step 3: Train the model
# ...
# simulate training, update the LoRA parameters
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.monarch.blkdiag1)))
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.monarch.blkdiag2)))

# Step 4: export the LoRA parameters
state_dict = model.state_dict()
lora_state_dict = {k: v for k, v in state_dict.items() if name_is_lora(k)}

## Loading and Inferencing with LoRA

In [15]:
# Step 1: Add LoRA to your model
add_lora(model)

# Step 2: Load the LoRA parameters
_ = model.load_state_dict(lora_state_dict, strict=False)

# Step 3: Merge the LoRA parameters into the model
merge_lora(model)

## Inferencing with multiple LoRA models

In [16]:
# to avoid re-adding lora to the model when rerun the cell, remove lora first 
remove_lora(model)
# Step 1: Add LoRA to your model
add_lora(model)

# Step 2: Load the LoRA parameters

# fake 3 sets of LoRA parameters
lora_state_dict_0 = lora_state_dict
lora_state_dict_1 = {k: torch.ones_like(v) for k, v in lora_state_dict.items()}
lora_state_dict_2 = {k: torch.zeros_like(v) for k, v in lora_state_dict.items()}
lora_state_dicts = [lora_state_dict_0, lora_state_dict_1, lora_state_dict_2]

load_multiple_lora(model, lora_state_dicts)

# Step 3: Select which LoRA to use at inference time
Y0 = select_lora(model, 0)(x)
Y1 = select_lora(model, 1)(x)
Y2 = select_lora(model, 2)(x)

AttributeError: 'LoRAParametrization' object has no attribute 'lora_A'

In [None]:
Y0, Y1, Y2

In [None]:
remove_lora(model)
init_state_dict = model.state_dict()
# verify that it's the same as if we load the lora parameters one by one
for state_dict in lora_state_dicts:
    remove_lora(model)
    _ = model.load_state_dict(init_state_dict, strict=False)
    add_lora(model)
    _ = model.load_state_dict(state_dict, strict=False)
    merge_lora(model)
    y = model(x)
    print(y)