In [10]:
import torch
import torch.nn as nn
import math

In [31]:
class LoRA(nn.Module):
    """Linear layer with frozen base weights plus a trainable low-rank update.

    The layer computes ``Y = X @ (W + scale * A @ B)^T + b`` where ``W`` and ``b``
    belong to a frozen ``nn.Linear`` and ``A`` / ``B`` are the trainable factors:

    * ``X``: ``(batch_size, in_features)``
    * ``W`` (``self.linear.weight``): ``(out_features, in_features)``
    * ``A`` (``self.lora_a``): ``(out_features, rank)``, Kaiming-normal initialised
    * ``B`` (``self.lora_b``): ``(rank, in_features)``, zero initialised

    Because ``B`` starts at zero, the low-rank term is zero right after
    construction and the layer initially matches the wrapped linear layer.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        rank: rank of the low-rank update. A value ``<= 0`` disables the update;
            the layer then behaves as the frozen linear layer only.
        lora_alpha: numerator of the scaling factor ``lora_alpha / rank`` applied
            to the low-rank term.
    """

    def __init__(self, in_features, out_features, rank, lora_alpha):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.lora_alpha = lora_alpha
        # Only divide when the adapter is enabled: rank == 0 is a supported
        # "no adapter" configuration and must not raise ZeroDivisionError.
        self.scale = lora_alpha / rank if rank > 0 else 0.0

        self.linear = nn.Linear(in_features, out_features)  # Y = X @ W^T + b

        if self.rank > 0:
            # nn.Parameter is required so that model.parameters() includes
            # these tensors.
            # torch.empty is enough here: kaiming_normal_ overwrites every entry.
            self.lora_a = nn.Parameter(torch.empty(out_features, rank))
            nn.init.kaiming_normal_(self.lora_a, a=0.01)
            self.lora_b = nn.Parameter(torch.zeros(rank, in_features))
        else:
            self.lora_a = None
            self.lora_b = None

        # Freeze the wrapped linear layer (weight and bias are not trainable).
        self.linear.requires_grad_(False)

    def forward(self, X):
        """Return ``linear(X)`` plus the scaled low-rank update (when ``rank > 0``)."""
        base = self.linear(X)
        if self.rank <= 0:
            return base
        # X @ (A @ B)^T == (X @ B^T) @ A^T. The right-hand form keeps the
        # intermediate at (..., rank) instead of building the dense
        # (out_features, in_features) product on every call.
        low_rank = (X @ self.lora_b.T) @ self.lora_a.T
        return base + self.scale * low_rank

In [40]:
# Smoke test: one forward/backward pass through a LoRA layer, then inspect
# which parameters received gradients.

# Seed the RNG so the random input and the random LoRA initialisation (and
# therefore every printed value) are reproducible on Restart & Run All.
torch.manual_seed(0)

batch_size = 8
in_features = 10
out_features = 20
rank = 2
lora_alpha = 1

X = torch.randn(batch_size, in_features)

lora = LoRA(in_features, out_features, rank, lora_alpha)

Y = lora(X)
assert Y.shape == (batch_size, out_features)

print(f"Shape of Y: {Y.shape}")

# Reduce to a scalar so that backward() can be called without arguments.
s = Y.sum()

print(f"s: {s}")

s.backward()

# With s = sum(Y) and 1 denoting an all-ones (out_features, batch_size) matrix:
#   ds/dA = scale * 1 @ X @ B^T    -> shape (out_features, rank)
#   ds/dB = scale * A^T @ 1 @ X    -> shape (rank, in_features)
# B is zero-initialised, so ds/dA is all zeros on this first backward pass.
print("grad sum of lora.lora_a: ", lora.lora_a.grad.sum())
print("grad sum of lora.lora_b: ", lora.lora_b.grad.sum())
# The wrapped linear layer is frozen (requires_grad=False), so autograd never
# populates its .grad attributes; both of the following print None.
print("grad of lora.linear.bias: ", lora.linear.bias.grad)
print("grad of lora.linear.weight: ", lora.linear.weight.grad)


