In [1]:
from functools import partial

import torch
from minlora import (
    LoRAParametrization,
    add_lora,
    apply_to_lora,
    merge_lora,
)
from torch import nn
from torch.nn.utils import parametrize

# Seed torch's RNG so the random LoRA initializations below are reproducible on Restart & Run All.
torch.manual_seed(0)
# The notebook only inspects and reassigns weights; no backward pass is needed, so disable autograd globally.
# The trailing `;` hides the returned object's repr without overwriting IPython's `_` output variable.
torch.set_grad_enabled(False);

## Adding LoRA to layers other than nn.Linear

In [2]:
## add_lora accepts an optional `lora_config` argument of type Dict[Type[nn.Module], Dict[str, Callable]].
## For each layer type, it maps a parameter name (e.g. "weight") to the factory that builds its LoRA parametrization.

## Currently, nn.Embedding, nn.Linear, and nn.Conv2d are supported.

LORA_RANK = 4  # rank of the low-rank update used by every LoRA-adapted layer in this notebook

lora_config = {
    nn.Embedding: {
        "weight": partial(LoRAParametrization.from_embedding, rank=LORA_RANK),
    },
    nn.Linear: {
        "weight": partial(LoRAParametrization.from_linear, rank=LORA_RANK),
    },
}

model = nn.Sequential(
    nn.Embedding(num_embeddings=3, embedding_dim=2),
    nn.Linear(in_features=2, out_features=3),
)
add_lora(model, lora_config=lora_config)
model

Sequential(
  (0): ParametrizedEmbedding(
    3, 2
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (1): ParametrizedLinear(
    in_features=2, out_features=3, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
)

## Tying weights

In [3]:
# Build a fresh Linear/Embedding pair (both have a 3x2 weight) and tie them by
# having the embedding reuse the linear layer's Parameter object.
linear, embedding = (
    nn.Linear(in_features=2, out_features=3),
    nn.Embedding(num_embeddings=3, embedding_dim=2),
)
embedding.weight = linear.weight
# one shared Parameter means identical values, so tying works so far
print(torch.allclose(embedding.weight, linear.weight))

True


In [4]:
# Attach LoRA to the linear layer only; from here on its `weight` is a computed property.
add_lora(linear)
# Fill lora_B with ones so the LoRA update is no longer zero.
set_lora_B_to_ones = apply_to_lora(lambda lora: nn.init.ones_(lora.lora_B))
linear.apply(set_lora_B_to_ones)
# The embedding still exposes the plain weight, so the two layers now disagree.
print(torch.allclose(embedding.weight, linear.weight))

False


In [5]:
# With LoRA attached, `linear.weight` is computed on every access and comes back as a plain
# torch.Tensor, while the embedding still exposes its registered Parameter.
print(*(type(layer.weight) for layer in (linear, embedding)))

<class 'torch.Tensor'> <class 'torch.nn.parameter.Parameter'>


In [6]:
# To keep the layers tied, the embedding needs LoRA as well.
# Its original (pre-LoRA) weight is still the Parameter shared with the linear layer,
# so only the LoRA factors have to be tied here.
add_lora(embedding, lora_config=lora_config)
# Tie the LoRA factors. The embedding's fan_in and fan_out are the reverse of the linear layer's,
# so the embedding's lora_A/lora_B take the linear layer's lora_B/lora_A, and vice versa.
# You can do it the other way around as well, but the surviving initialization will differ.
embedding.parametrizations.weight[0].lora_A = linear.parametrizations.weight[0].lora_B
embedding.parametrizations.weight[0].lora_B = linear.parametrizations.weight[0].lora_A
# Re-randomize BOTH shared factors through the linear layer (previously lora_B was updated twice
# and lora_A never), then check that the embedding's effective weight follows.
linear.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_B)))
linear.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_A)))
assert torch.allclose(linear.weight, embedding.weight)

In [7]:
# The effective weights have the same shape, but the LoRA factors do not: the embedding's
# A/B are the linear layer's B/A, so their shapes appear swapped.
emb_lora = embedding.parametrizations.weight[0]
lin_lora = linear.parametrizations.weight[0]
print(emb_lora.lora_A.shape, lin_lora.lora_A.shape, emb_lora.lora_B.shape, lin_lora.lora_B.shape)

torch.Size([3, 4]) torch.Size([4, 2]) torch.Size([4, 2]) torch.Size([3, 4])


In [8]:
# Re-randomize the shared LoRA factors through the linear layer first, then through the embedding.
# Because the factors are the same Parameters, the effective weights must match after each round.
for updated_layer in (linear, embedding):
    updated_layer.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_B)))
    updated_layer.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_A)))
    print(torch.allclose(linear.weight, embedding.weight))

True
True


In [9]:
# we can put the logic of tying the weights in a function
def tie_weights(linear: nn.Linear, embedding: nn.Embedding):
    """Make an embedding layer share the linear layer's LoRA-parametrized weight.

    Both layers must already carry a LoRA parametrization on ``weight``. After
    the call, the embedding reuses the linear layer's original weight and its two
    LoRA factors. The factors are bound crosswise (embedding A <- linear B,
    embedding B <- linear A) because the two layers' fan_in/fan_out are reversed.

    Args:
        linear: layer whose original weight and LoRA factors are kept.
        embedding: layer rebound to reuse ``linear``'s tensors.
    """
    lin_params = linear.parametrizations.weight
    emb_params = embedding.parametrizations.weight
    # redundant when the original weights were already tied before LoRA was added
    emb_params.original = lin_params.original
    lin_lora = lin_params[0]
    emb_lora = emb_params[0]
    emb_lora.lora_A = lin_lora.lora_B
    emb_lora.lora_B = lin_lora.lora_A
# minlora ships this helper (together with `untie_weights`); the import below rebinds `tie_weights` to the library version:
from minlora import tie_weights, untie_weights

In [10]:
# Now back to our first model with LoRA: model[0] is the Embedding and model[1] is the Linear.
# tie_weights takes (linear, embedding), so the Linear goes first.
tie_weights(model[1], model[0])
# update the LoRA weights of the linear layer
apply_to_lora(lambda x: nn.init.uniform_(x.lora_B))(model[1])
# and the weights are still the same, because the embedding shares those LoRA factors
assert torch.allclose(model[0].weight, model[1].weight)
expected_merged = model[1].weight.clone()
# Caveat: calling merge_lora on BOTH tied layers would apply the shared delta twice. The first merge
# writes W + dW into the shared original weight in place, and the second layer's LoRA then adds dW
# on top of that. Instead, merge once through the Linear, then drop the Embedding's now-redundant
# parametrization without applying it.
merge_lora(model[1])
parametrize.remove_parametrizations(model[0], "weight", leave_parametrized=False)
# after merging, the weights are still the same AND equal the pre-merge effective weight
assert torch.allclose(model[0].weight, model[1].weight)
assert torch.allclose(model[1].weight, expected_merged)