In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [8]:
class LinearLoRA(nn.Module):
    """Linear layer with a frozen base weight plus a trainable low-rank (LoRA) update.

    The effective weight is ``W + (B @ A) * alpha / rank``, where ``W`` is the frozen
    ``nn.Linear`` weight, ``A`` (``lora_a``) has shape ``(rank, in_dim)`` and ``B``
    (``lora_b``) has shape ``(out_dim, rank)``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Rank of the low-rank update; must be positive.
        alpha: Scaling numerator; the update is multiplied by ``alpha / rank``.
        merge: If True, fold the LoRA update into the base weight and run a single
            ``F.linear``. If False, compute the base and LoRA branches separately
            and add them. Both paths give the same result while dropout is inactive.
        dropout: Optional dropout probability applied only to the input of the LoRA
            branch. A falsy value (``None`` or ``0``) disables dropout.
    """

    def __init__(self, in_dim, out_dim, rank, alpha, merge=True, dropout=None):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.rank = rank
        self.alpha = alpha
        self.merge = merge

        assert self.rank > 0, "The rank is not positive"
        self.linear = nn.Linear(in_dim, out_dim)
        # Only the base weight is frozen; the base bias stays trainable.
        self.linear.weight.requires_grad = False
        # Shapes follow F.linear's (out_features, in_features) convention, so
        # lora_b @ lora_a has the same shape as self.linear.weight.
        self.lora_a = nn.Parameter(torch.empty(rank, in_dim))
        self.lora_b = nn.Parameter(torch.empty(out_dim, rank))
        self.scale = self.alpha / self.rank
        self.dropout = nn.Dropout(p=dropout) if dropout else nn.Identity()

        self.init_weights()

    def init_weights(self):
        """Initialise A from a standard normal and B with zeros, so B @ A starts at 0."""
        nn.init.normal_(self.lora_a, mean=0.0, std=1.0)
        nn.init.zeros_(self.lora_b)

    def _dropout_active(self):
        """Return True when dropout would actually perturb the LoRA branch input."""
        return (
            self.training
            and isinstance(self.dropout, nn.Dropout)
            and self.dropout.p > 0
        )

    def forward(self, x):
        """Apply the base linear layer plus the scaled low-rank update.

        Args:
            x: Tensor whose last dimension is ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension
            ``out_dim``.
        """
        # Dropout on the LoRA input cannot be folded into a single merged weight,
        # so the merged fast path is only taken when dropout is a no-op.
        if self.merge and not self._dropout_active():
            # y = x (W + B A * scale)^T + bias
            weight = self.linear.weight + (self.lora_b @ self.lora_a) * self.scale
            return F.linear(x, weight, self.linear.bias)
        # Unmerged path: y = base(x) + dropout(x) A^T B^T * scale.
        lora_out = F.linear(F.linear(self.dropout(x), self.lora_a), self.lora_b)
        return self.linear(x) + lora_out * self.scale

In [9]:
net = LinearLoRA(5, 10, 2, 0.01)

In [11]:
x = torch.rand(20, 5)

In [12]:
out = net(x)

In [13]:
out.shape

torch.Size([20, 10])