In [1]:
import torch
import torch.nn as nn
import math

In [2]:
# LoRA (low-rank adaptation) wrapper for a model's nn.Linear layers

class LoRALinear(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) update.

    The output is ``orig(x) + scaling * (dropout(x) @ A.T @ B.T)``, where
    ``A`` has shape ``(r, in_features)`` and ``B`` has shape
    ``(out_features, r)``. ``B`` starts at zero, so a freshly wrapped layer
    reproduces ``orig`` exactly. The wrapped layer's own parameters are
    frozen; only ``A`` and ``B`` are trainable.

    Args:
        orig_params: The layer to adapt. Must expose ``in_features`` and
            ``out_features`` (e.g. ``nn.Linear``).
        r: Rank of the update. ``0`` disables LoRA: no ``A``/``B`` are
            created and the module behaves like ``orig``.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.
            Defaults to ``r`` (i.e. a scaling of 1).
        dropout: Dropout probability applied to the input of the LoRA path
            only (the original path never sees dropout).
        merge: If True, ``orig`` is treated as already containing the LoRA
            update, so ``forward`` skips the low-rank path and
            ``merge_weights`` is a no-op.
        init_scale: ``A`` is initialised as ``randn(r, in_features) * init_scale``.

    Raises:
        ValueError: If ``r`` is negative.
    """

    def __init__(self, orig_params, r=0, alpha=None, dropout=0.1, merge=False, init_scale=0.1):
        super().__init__()
        if r < 0:
            raise ValueError(f"LoRA rank r must be >= 0, got {r}")

        self.orig = orig_params
        self.r = r
        self.merge = merge
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        d_in = orig_params.in_features
        d_out = orig_params.out_features

        if r > 0:
            if alpha is None:
                alpha = r
            self.scaling = alpha / r

            # Create the adapter on the same device/dtype as the wrapped weight
            # so forward() and merge_weights() do not mix devices or dtypes.
            ref = getattr(orig_params, "weight", None)
            factory = {} if ref is None else {"device": ref.device, "dtype": ref.dtype}

            # LoRA params: A (r x d_in), B (d_out x r). B = 0 => no initial change.
            self.A = nn.Parameter(torch.randn(r, d_in, **factory) * init_scale)
            self.B = nn.Parameter(torch.zeros(d_out, r, **factory))
        else:
            # r == 0: no adapter; register None so the attributes exist.
            self.scaling = 0.0
            self.register_parameter("A", None)
            self.register_parameter("B", None)

        # Freeze the original weights; only the LoRA factors are trained.
        for p in self.orig.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return ``orig(x)`` plus the scaled low-rank update.

        The low-rank path is skipped when the update has been merged into
        ``orig`` or when ``r == 0``.
        """
        # x: [batch, ..., d_in]
        orig_out = self.orig(x)

        if self.merge or self.r == 0:
            # Merged: orig already contains the LoRA contribution.
            # r == 0: there is no adapter to apply.
            return orig_out

        lora_inter = self.dropout(x) @ self.A.t()  # shape: [..., r]
        lora_out = lora_inter @ self.B.t()         # shape: [..., d_out]
        return orig_out + self.scaling * lora_out

    def merge_weights(self):
        """Fuse LoRA into the original weight: ``W = W + scaling * (B @ A)``.

        Marks the module as merged, so later ``forward`` calls use only
        ``orig``. Does nothing if already merged or if ``orig`` has no
        ``weight`` attribute. ``A`` and ``B`` are kept (not deleted).
        """
        if self.merge or not hasattr(self.orig, "weight"):
            return
        if self.r > 0:
            with torch.no_grad():
                delta = (self.B @ self.A) * self.scaling
                self.orig.weight.add_(delta.to(self.orig.weight.dtype))
        self.merge = True

