In [61]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

from functools import partial

from lora import *

### Examples of how the LoRA helpers work

In [62]:
# A small toy network: one nested sub-block followed by four 5x5 linear layers.
# The leaf names 'linear' and 'Q' occur both inside the sub-block and at top level.
model = nn.Sequential()

model.add_module('dumb_module', nn.Sequential())
for layer_name, layer in (
    ('linear', nn.Linear(5, 5)),
    ('relu', nn.ReLU()),
    ('Q', nn.Linear(5, 5)),
):
    model.dumb_module.add_module(layer_name, layer)

for layer_name in ('linear', 'Q', 'K', 'A'):
    model.add_module(layer_name, nn.Linear(5, 5))

# Run one random sample through the network and keep a second reference
# to that output in Y0.
x = torch.randn(1, 5)
y = model(x)
print(y)
Y0 = y

tensor([[ 0.5813,  0.4768, -0.1592, -0.3883, -0.2509]],
       grad_fn=<AddmmBackward0>)


In [63]:
# List every module path in the hierarchy, one per line
# (the root module comes first and has an empty name).
print(*(module_name for module_name, _ in model.named_modules()), sep="\n")


dumb_module
dumb_module.linear
dumb_module.relu
dumb_module.Q
linear
Q
K
A


In [66]:
def construct_lorafa_config(model, rank, target_names=('Q', 'K')):
    """Build a LoRA-FA parametrization config keyed by full layer name.

    Each layer has a unique dotted name within the module hierarchy, so that
    name is used to identify the layers to parametrize. A layer is selected
    when it is an ``nn.Linear`` and the last component of its dotted name is
    listed in ``target_names``.

    Args:
        model: module whose hierarchy is scanned with ``named_modules()``.
        rank: rank forwarded to ``LoRAFAParametrization.from_linear``.
        target_names: leaf names of the linear layers to adapt; a single
            string is treated as one name. Defaults to ``('Q', 'K')``.

    Returns:
        dict mapping layer name -> ``{nn.Linear: {"weight": factory}}``, where
        ``factory`` is a ``functools.partial`` of
        ``LoRAFAParametrization.from_linear`` carrying ``rank``, the ``"svd"``
        init method and the layer's current weight data.
    """
    # A bare string would otherwise be split into single characters.
    if isinstance(target_names, str):
        target_names = (target_names,)
    targets = frozenset(target_names)

    config = {}
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if name.rsplit('.', 1)[-1] not in targets:
            continue
        config[name] = {
            nn.Linear: {
                "weight": partial(
                    LoRAFAParametrization.from_linear,
                    rank=rank,
                    init_method="svd",  # set svd as the initialization method
                    original_weights=module.weight.data,
                ),
            }
        }
    return config

# Build the rank-2 config and show which layers were selected.
lorafa_config = construct_lorafa_config(model=model, rank=2)
[layer_name for layer_name in lorafa_config]

['dumb_module.Q', 'Q', 'K']

In [67]:
# Attach the LoRA-FA parametrizations described by `lorafa_config` to `model`,
# then run the same input `x` through it again.
# The tensor printed here differs from the one printed before LoRA was added.
# NOTE(review): presumably because the SVD-initialized factors are not a
# zero update at initialization — confirm against `lora`'s from_linear.
add_lora_by_layer_names(model, lorafa_config)
y = model(x)
print(y)

tensor([[ 0.6452,  0.5182, -0.1465, -0.3162, -0.2925]],
       grad_fn=<AddmmBackward0>)


In [69]:
# Show the lora_B factor of every LoRA-parametrized layer.
for n, p in model.named_parameters():
    if not name_is_lora(n):
        continue
    if n.rsplit('.', 1)[-1] == 'lora_B':
        print(n, p)

dumb_module.Q.parametrizations.weight.0.lora_B Parameter containing:
tensor([[ 0.7124,  0.3005],
        [ 0.1960, -0.2111],
        [-0.0140,  0.7604],
        [ 0.2219,  0.3728],
        [-0.6361,  0.3848]], requires_grad=True)
Q.parametrizations.weight.0.lora_B Parameter containing:
tensor([[ 0.5842, -0.0712],
        [ 0.3016,  0.8204],
        [-0.0405, -0.4805],
        [ 0.5626, -0.1995],
        [ 0.4996, -0.2264]], requires_grad=True)
K.parametrizations.weight.0.lora_B Parameter containing:
tensor([[-0.2473,  0.6015],
        [ 0.5891,  0.4978],
        [-0.1533,  0.4535],
        [ 0.3463,  0.3257],
        [ 0.6696, -0.2804]], requires_grad=True)


In [42]:
# Names of the parameters that were generated by LoRA.
[*get_lora_state_dict(model).keys()]

['dumb_module.Q.parametrizations.weight.0.lora_A',
 'dumb_module.Q.parametrizations.weight.0.lora_B',
 'Q.parametrizations.weight.0.lora_A',
 'Q.parametrizations.weight.0.lora_B',
 'K.parametrizations.weight.0.lora_A',
 'K.parametrizations.weight.0.lora_B']

In [43]:
# TODO: get_lora_named_parameters(model) reports requires_grad=False for
# every parameter; the cause has not been found yet.
# get_lora_params(model) works, so it is used here to check that the
# A factors are frozen.
for lora_param in get_lora_params(model):
    print(lora_param.shape, lora_param.requires_grad)

torch.Size([2, 5]) False
torch.Size([5, 2]) True
torch.Size([2, 5]) False
torch.Size([5, 2]) True
torch.Size([2, 5]) False
torch.Size([5, 2]) True


### Better example with complete pipeline

Define a small toy model

In [22]:
# Two named blocks; layers are created in the order Q, K (block1), Q, A (block2).
model = nn.Sequential()

# block1: 1 -> 10 -> 10
model.add_module('block1', nn.Sequential())
for layer_name, layer in (
    ('Q', nn.Linear(1, 10)),
    ('relu', nn.ReLU()),
    ('K', nn.Linear(10, 10)),
):
    model.block1.add_module(layer_name, layer)

# block2: 10 -> 3 -> 1
model.add_module('block2', nn.Sequential())
for layer_name, layer in (
    ('Q', nn.Linear(10, 3)),
    ('relu', nn.ReLU()),
    ('A', nn.Linear(3, 1)),
):
    model.block2.add_module(layer_name, layer)

In [23]:
# Display the model; imagine it is an LLM with Q and K layers inside its blocks.
model

Sequential(
  (block1): Sequential(
    (Q): Linear(in_features=1, out_features=10, bias=True)
    (relu): ReLU()
    (K): Linear(in_features=10, out_features=10, bias=True)
  )
  (block2): Sequential(
    (Q): Linear(in_features=10, out_features=3, bias=True)
    (relu): ReLU()
    (A): Linear(in_features=3, out_features=1, bias=True)
  )
)

Define the LoRA-FA config for the target layers

In [41]:
def construct_lorafa_config(model, rank, target_names=('Q', 'K')):
    """Build a LoRA-FA parametrization config keyed by full layer name.

    Each layer has a unique dotted name within the module hierarchy, so that
    name is used to identify the layers to parametrize. A layer is selected
    when it is an ``nn.Linear`` and the last component of its dotted name is
    listed in ``target_names``.

    Args:
        model: module whose hierarchy is scanned with ``named_modules()``.
        rank: rank forwarded to ``LoRAFAParametrization.from_linear``.
        target_names: leaf names of the linear layers to adapt; a single
            string is treated as one name. Defaults to ``('Q', 'K')``.

    Returns:
        dict mapping layer name -> ``{nn.Linear: {"weight": factory}}``, where
        ``factory`` is a ``functools.partial`` of
        ``LoRAFAParametrization.from_linear`` carrying ``rank``, the ``"svd"``
        init method and the layer's weight.
    """
    # A bare string would otherwise be split into single characters.
    if isinstance(target_names, str):
        target_names = (target_names,)
    targets = frozenset(target_names)

    config = {}
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if name.rsplit('.', 1)[-1] not in targets:
            continue
        config[name] = {
            nn.Linear: {
                "weight": partial(
                    LoRAFAParametrization.from_linear,
                    rank=rank,
                    init_method="svd",  # set svd as the initialization method
                    original_weights=module.weight,  # pass weights for svd init
                ),
            }
        }
    return config

In [42]:
# Build the rank-3 LoRA-FA config for the Q/K linear layers and apply it to `model`.
# NOTE(review): add_lora_by_layer_names appears to modify `model` in place
# (the parameter listing in the next cell shows new parametrization entries),
# so re-running this cell on an already adapted model may stack a second
# parametrization — confirm against the `lora` module.
lorafa_config = construct_lorafa_config(model=model, rank=3)
add_lora_by_layer_names(model, lorafa_config)

In [43]:
# Check the newly added LoRA weights: list every parameter name of the model.
# Note: the loop variables `name` and `params` remain in the notebook's
# global namespace after this cell has run.
for name, params in model.named_parameters():
    print(name)

block1.Q.bias
block1.Q.parametrizations.weight.original
block1.Q.parametrizations.weight.0.lora_A
block1.Q.parametrizations.weight.0.lora_B
block1.K.bias
block1.K.parametrizations.weight.original
block1.K.parametrizations.weight.0.lora_A
block1.K.parametrizations.weight.0.lora_B
block2.Q.bias
block2.Q.parametrizations.weight.original
block2.Q.parametrizations.weight.0.lora_A
block2.Q.parametrizations.weight.0.lora_B
block2.A.weight
block2.A.bias


And now the training pipeline

In [44]:
def freeze_nonlora(model):
    """Disable gradients for every parameter that is not a LoRA parameter.

    Walks ``model.named_parameters()`` and sets ``requires_grad = False`` on
    each parameter whose name is rejected by ``name_is_lora``. Parameters
    with LoRA names are left as they are. Works in place; returns ``None``.
    """
    for param_name, param in model.named_parameters():
        if name_is_lora(param_name):
            continue
        param.requires_grad = False

def get_trainable_lorafa_B_weights(model):
    """Yield the ``lora_B`` parameters of ``model``.

    A parameter is yielded when ``name_is_lora`` accepts its name and the
    last dotted component of that name is ``"lora_B"``.

    Args:
        model: module whose ``named_parameters()`` are scanned.

    Yields:
        The matching parameter objects, in ``named_parameters()`` order.
    """
    for n, p in model.named_parameters():
        # Test the loop's own name `n`. Reading a notebook-global `name` here
        # would make the filter depend on whatever an earlier cell left behind.
        if name_is_lora(n) and n.rsplit(".", 1)[-1] == "lora_B":
            yield p

# step 1. freeze nonlora layers
freeze_nonlora(model)

# step 2. pass trainable parameters to optimizer
parameters = [{"params": get_trainable_lorafa_B_weights(model)}]
optimizer = torch.optim.Adam(parameters, lr=1e-3)


# step 3. training
# Synthetic regression data: y = 3x + 15 plus Gaussian noise (std 0.1),
# stored as rows of (x, y) so that one DataLoader batch carries both.
criterion = torch.nn.functional.mse_loss
x = np.linspace(start=0, stop=10, num=300)
y = 3*x + 15 + np.random.normal(0, 1e-1, size=x.shape)
dataset = torch.from_numpy(np.stack((x, y),axis=1)).to(torch.float32)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

epochs=10
for e in range(epochs):
    # l_cum: sum of per-sample losses seen this epoch; n: number of samples seen.
    l_cum, n = 0.0, 0
    for batch in loader:
        # Column 0 is the input and column 1 the target; keep a trailing feature axis.
        x, y = batch[:, 0][:, None], batch[:, 1][:, None]
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        # mse_loss averages over the batch, so weight it by the batch size
        # (the last batch is smaller than the others).
        l_cum += loss.item() * len(batch)
        n += len(batch)
    # A value that never moves between epochs means no weights are being updated.
    print(f"epoch {e + 1}/{epochs}: mean train MSE = {l_cum / n:.4f}")

In [45]:
# Sanity check: every lora_A factor is still flagged as frozen.
# This inspects only the requires_grad flag; it does not compare values
# before and after training.
for name, parameters in model.named_parameters():
    if not (name_is_lora(name) and name.rsplit('.', 1)[-1] == 'lora_A'):
        continue
    assert not parameters.requires_grad