- **Author**: Jaehyuk Heo
- **Paper**: LoRA: Low-Rank Adaptation of Large Language Models [ [link](https://arxiv.org/pdf/2106.09685.pdf) ]

<p>
    <img width='400' src='https://user-images.githubusercontent.com/37654013/168552268-c764cb0c-2684-4082-a633-ba2d8cf340a3.png'>
</p>

**How to install**
---
```bash
pip install loralib
```
---

**Example of LoRA Linear**

---
```python
class LoRALayer:
    """Mixin that stores the LoRA settings shared by LoRA-enabled layers.

    Args:
        r: Rank of the low-rank update (used by ``Linear`` to size
            ``lora_A`` / ``lora_B``).
        lora_alpha: Scaling numerator; ``Linear`` scales the update by
            ``lora_alpha / r``.
        lora_dropout: Dropout probability applied to the input of the LoRA
            path. Values that are not strictly positive disable dropout.
        merge_weights: Whether the owning layer is allowed to fold the LoRA
            update into its main weight.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.merge_weights = merge_weights
        # A fresh layer always starts with the update kept apart from the weight.
        self.merged = False
        self.lora_alpha = lora_alpha
        self.r = r
        # Dropout is optional: fall back to a pass-through callable when disabled.
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        )
        

class Linear(nn.Linear, LoRALayer):
    """Dense layer with a trainable low-rank (LoRA) update on a frozen weight.

    When ``r > 0`` two extra parameters are created, ``lora_A`` of shape
    ``(r, in_features)`` and ``lora_B`` of shape ``(out_features, r)``, and the
    pre-trained ``weight`` is frozen. The layer then computes
    ``F.linear(x, W, bias) + dropout(x) @ lora_A.T @ lora_B.T * (lora_alpha / r)``.

    With ``merge_weights=True`` the scaled product ``lora_B @ lora_A`` is added
    into ``weight`` when the layer enters evaluation mode and subtracted again
    when it returns to training mode, so inference needs a single matmul.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        r: LoRA rank; ``0`` disables LoRA and leaves a plain linear layer.
        lora_alpha: Numerator of the LoRA scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability on the LoRA input path.
        fan_in_fan_out: Set to True if the layer being replaced stores its
            weight as ``(fan_in, fan_out)``.
        merge_weights: Allow folding the LoRA update into ``weight`` in
            evaluation mode.
        **kwargs: Forwarded to ``nn.Linear.__init__``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained weight matrix; only the LoRA factors train.
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A / lora_B exist.
        if hasattr(self, 'lora_A'):
            # A uses the nn.Linear default init and B is zero, so the initial
            # LoRA update B @ A is exactly zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _maybe_transpose(self, w):
        """Return ``w`` transposed when the weight is stored (fan_in, fan_out)."""
        return w.T if self.fan_in_fan_out else w

    def _lora_delta(self):
        """Return the scaled LoRA update laid out like ``self.weight``."""
        return self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling

    def _merge_lora(self):
        """Fold the LoRA update into ``weight`` (no-op if already merged)."""
        if self.merge_weights and not self.merged:
            if self.r > 0:
                self.weight.data += self._lora_delta()
            self.merged = True

    def _unmerge_lora(self):
        """Remove a previously merged LoRA update from ``weight``."""
        if self.merge_weights and self.merged:
            if self.r > 0:
                self.weight.data -= self._lora_delta()
            self.merged = False

    def train(self, mode: bool = True):
        """Switch training/evaluation mode and keep the merge state in sync.

        ``nn.Module.eval()`` on a parent module reaches its children through
        ``train(False)``, never through the children's ``eval()``. The merge
        therefore has to happen here, otherwise ``model.eval()`` would leave
        every LoRA layer un-merged.

        Args:
            mode: True for training mode (weights un-merged), False for
                evaluation mode (weights merged when ``merge_weights`` is set).

        Returns:
            ``self``, matching the ``nn.Module.train`` contract.
        """
        nn.Linear.train(self, mode)
        if mode:
            self._unmerge_lora()
        else:
            self._merge_lora()
        return self

    def eval(self):
        """Enter evaluation mode; equivalent to ``self.train(False)``."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        """Apply the linear map, adding the LoRA path unless it is merged."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            # Un-merged: add the low-rank path explicitly.
            result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return result
```
---

In [47]:
%reload_ext autoreload
%autoreload 2

import loralib as lora
import torch.nn as nn
import torch
import numpy as np
from timm import create_model

In [5]:
class Config:
    """LoRA hyper-parameters shared by the cells below."""
    # Rank passed as `r` when building the LoRA linear layer.
    lora_r = 8
    # Passed as `lora_alpha`; the LoRA Linear above scales its update by lora_alpha / r.
    lora_alpha = 8
    
# Single settings object read by later cells (args.lora_r, args.lora_alpha).
args = Config()

# torch.nn.Linear

In [20]:
# Reference dense layer: 16 input features mapped to 32 output features.
torch_linear = nn.Linear(in_features=16, out_features=32)

# nn.Linear keeps its weight as (out_features, in_features).
print('torch.nn.linear.weight.size(): ', torch_linear.weight.shape)

torch.nn.linear.weight.size():  torch.Size([32, 16])


In [43]:
# Push one random 16-dim row through the layer's weight (bias left out).
sample = torch.randn(1, 16)
result = nn.functional.linear(sample, torch_linear.weight)
print('result size: ', result.shape)

result size:  torch.Size([1, 32])


## Load a Model

In [56]:
# Baseline without LoRA: build the 'vit_base_patch16_224' model through timm's
# create_model, asking for a 100-class head and pretrained weights.
model = create_model(
    'vit_base_patch16_224', 
    num_classes = 100, 
    pretrained  = True
)

# Progress marker only; nothing about the model is inspected here.
print('load a model')

load a model


In [57]:
# Parameter census for the baseline model: all parameters versus those that
# still have requires_grad set.
total_params = np.sum([weight.numel() for weight in model.parameters()])
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)

print('# of total parameters: ', total_params)
print('# of trainable parameters: ', trainable_params)

# of total parameters:  85875556
# of trainable parameters:  85875556


# LoRA Linear

In [58]:
# LoRA counterpart of the 16 -> 32 dense layer; rank and alpha come from Config.
lora_linear = lora.Linear(16, 32, r=args.lora_r, lora_alpha=args.lora_alpha)

# Transposed factors, in the order they are applied to an input: x @ A^T @ B^T.
lora_A_t = lora_linear.lora_A.T
lora_B_t = lora_linear.lora_B.T

print('lora_A.size(): ', lora_linear.lora_A.shape)
print('lora_B.size(): ', lora_linear.lora_B.shape)

print('(lora_A.T @ lora_B.T).size(): ', (lora_A_t @ lora_B_t).shape)

# The low-rank path alone maps a (1, 16) input to a (1, 32) output.
result = torch.randn(1, 16) @ lora_A_t @ lora_B_t
print('result size: ', result.shape)

lora_A.size():  torch.Size([8, 16])
lora_B.size():  torch.Size([32, 8])
(lora_A.T @ lora_B.T).size():  torch.Size([16, 32])
result size:  torch.Size([1, 32])


## Load a Model with LoRA

In [59]:
# Same 'vit_base_patch16_224' model as the baseline, built with LoRA enabled.
# Rank and alpha come from Config so this cell cannot drift from the
# lora.Linear example.
# NOTE(review): apply_lora / lora_r / lora_alpha are presumably handled by a
# LoRA-patched timm build; confirm the installed create_model accepts them.
model_lora = create_model(
    'vit_base_patch16_224',
    num_classes = 100,
    apply_lora  = True,
    lora_r      = args.lora_r,
    lora_alpha  = args.lora_alpha,
    pretrained  = True
)

# Leave only the LoRA parameters trainable.
# NOTE(review): the trainable count printed below (442368) suggests the
# 100-class head ends up frozen as well; confirm that is intended before
# fine-tuning.
lora.mark_only_lora_as_trainable(model_lora)
print('load a model')

load a model


In [60]:
# Parameter census for the LoRA model: all parameters versus those left
# trainable after mark_only_lora_as_trainable.
total_params = np.sum([weight.numel() for weight in model_lora.parameters()])
trainable_params = sum(
    weight.numel() for weight in model_lora.parameters() if weight.requires_grad
)

print('# of total parameters: ', total_params)
print('# of trainable parameters: ', trainable_params)

# of total parameters:  86317924
# of trainable parameters:  442368
