### LoRA
![LoRA](./pics/LoRA.png)

### · coding

In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class linearLoRALayer(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        merge = False,
        rank = 8,
        loar_alpha = 16,
        dropout = 0.1,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.merge = merge
        self.rank = rank

        # linear weight shape : (in_features, out_features)
        self.linear = nn.Linear(in_features, out_features)
        # input x shape : (batch_size, seq_len, in_features)
        # 计算过程是: x @ weight.T
        # 所以weight shape:(out_features, in_features)

        if rank > 0:
            self.lora_a = nn.Parameter(
                torch.zeros(out_features, rank)
            )
            # 高斯分布，使用凯明初始化
            nn.init.kaiming_uniform_(self.lora_a, a = 0.01)

            self.lora_b = nn.Parameter(
                torch.zeros(rank, in_features)
            )
            self.scale = loar_alpha / rank

            self.linear.weight.requires_grad = False
            self.linear.bias.requires_grad = False

        self.dropout = nn.Dropout(
            dropout
        ) if dropout > 0 else nn.Identity()

        # merge推理, 会把 lora_a 和 lora_b 两个小矩阵的参数直接放到 linear.weight 中
        if merge:
            self.merge_weight()


    def merge_weight(self, ):
        if self.merge and self.rank > 0:
            self.linear.weight.data += self.scale * (self.lora_a @ self.lora_b)

    
    def unmerge_weight(self, ):
        if self.merge and self.rank > 0:
            self.linear.weight.data -= self.scale * (self.lora_a @ self.lora_b)


    
    def forward(self, x):
        # x shape :(batch_size, seq_len, in_features)
        # lora_a 是out_features * rank

        if self.rank > 0 and not self.merge:
            output = self.linear(x) + self.scale * (x @ (self.lora_a @ self.lora_b).T)
        elif self.rank > 0 and self.merge:
            output = self.linear(x)
        else:
            output = self.linear(x)

        return self.dropout(output)
        
# Test
batch_size = 32
seq_len = 128
in_features = 768
out_features = 512
rank = 8
lora_alpha = 16
dropout = 0.1

x = torch.randn(batch_size, seq_len, in_features)

# no merge
lora_layer = linearLoRALayer(
    in_features = in_features,
    out_features = out_features,
    merge = False,
    rank = rank,
    loar_alpha = lora_alpha,
    dropout = dropout,
)

# Forward pass
output = lora_layer(x)
print(f"Output shape(no merage): {output.shape}")

# merge mode
lora_merge_mode = linearLoRALayer(
    in_features = in_features,
    out_features = out_features,
    merge = True,
    rank = rank,
    loar_alpha = lora_alpha,
    dropout = dropout,
)

# Forward pass with merged weights
output_merged = lora_merge_mode(x)
print(f"Output shape(merge): {output_merged.shape}")

# Test weight merging/unmerging
lora_layer.merge_weight()
output_after_merge = lora_layer(x)
lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)

print("Max difference after merge/unmerge cycle:", 
            torch.max(torch.abs(output - output_after_unmerge)).item()
)

Output shape(no merage): torch.Size([32, 128, 512])
Output shape(merge): torch.Size([32, 128, 512])
Max difference after merge/unmerge cycle: 3.1814584732055664
