# LoRA Layer

In [1]:
import torch


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.W_a = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.W_b = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.W_a @ self.W_b)
        return x


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

In [2]:
torch.manual_seed(123)

# a simple linear layer with 10 inputs and 1 output
# requires_grad=False makes it non-trainable
linear_layer = torch.nn.Linear(10, 1)
linear_layer.requires_grad_=True
print(linear_layer)

# a simple example input
x = torch.rand((1, 10))
print(x)

# 执行线性层并输出
linear_layer(x)

Linear(in_features=10, out_features=1, bias=True)
tensor([[0.6871, 0.0756, 0.1966, 0.3164, 0.4017, 0.1186, 0.8274, 0.3821, 0.6605,
         0.8536]])


tensor([[-0.5745]], grad_fn=<AddmmBackward0>)

Replace linear layer with LoRA layer:

In [3]:
# 用我们的新LinearWithLoRA层替换其预训练Linear层
# replace its pretrained Linear layers with our new LinearWithLoRA layer
lora_layer = LinearWithLoRA(linear=linear_layer, rank=8, alpha=1)
# 仅替换后 再次执行lora_layer
lora_layer(x)
# 可以看到，结果相比执行线性层linear_layer没有变化，因为模型还没有训练

tensor([[-0.5745]], grad_fn=<AddBackward0>)

Note that the LoRA layer will not change the linear layer output until it's trained because its `W_b` weight matrix is initialized to all zeros. For LoRA to take effect, we have to train the model. Below is a simple toy example.

Let's simulate a simple weight update:

In [4]:
# 模拟训练的过程，更新lora_layer权重
# 后续 我们只需将这一步自动化
lora_layer.lora.W_b = torch.nn.Parameter(lora_layer.lora.W_b + 0.01 * x[0])

We can now see that the output has changed:

In [9]:
print(x)
lora_layer(x)

tensor([[0.6871, 0.0756, 0.1966, 0.3164, 0.4017, 0.1186, 0.8274, 0.3821, 0.6605,
         0.8536]])


tensor([[-0.5863, -0.5758, -0.5779, -0.5800, -0.5814, -0.5766, -0.5887, -0.5811,
         -0.5859, -0.5892]], grad_fn=<AddBackward0>)