##### Companion notebook to this blog article Improving LoRA: Implementing Weight-Decomposed Low-Rank Adaptation (DoRA) from Scratch.

In [None]:
# !pip install watermark
# !pip install torchvision

In [1]:
%load_ext watermark

In [24]:
import time
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch
from tqdm import tqdm

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() or torch.backend.mps.is_available() else 'cpu'

In [4]:
device = torch.device(DEVICE)
batch_size = 64

In [5]:
train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 38949210.20it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 62505517.97it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 38439770.98it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 21429166.22it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [12]:
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

for images, labels in train_loader:
    print(images.shape, images.view(batch_size, -1).shape, labels.shape, mlp(images).shape)
    break

torch.Size([64, 1, 28, 28]) torch.Size([64, 784]) torch.Size([64]) torch.Size([64, 12])


In [20]:
# nn.RE

In [47]:
##########################
### MODEL
##########################

# Hyperparameters
random_seed = 123
learning_rate = 0.005
num_epochs = 2

# Architecture
num_features = 784
num_hidden_1 = 128
num_hidden_2 = 256
num_classes = 10

In [48]:
class MultilayerMLP(nn.Module):
    def __init__(self, input_dim, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        self.input_dim = input_dim

        self.layers = nn.Sequential(
            nn.Linear(input_dim, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes)
        )

    def forward(self, images):
        images = images.view(-1, self.input_dim)
        return self.layers(images)

In [49]:
torch.manual_seed(random_seed)


<torch._C.Generator at 0x1066494d0>

In [50]:
model_pretrained = MultilayerMLP(
    input_dim=num_features,
    num_hidden_1=num_hidden_1,
    num_hidden_2=num_hidden_2, 
    num_classes=num_classes
)

In [51]:
model_pretrained = model_pretrained.to(device)
optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=learning_rate)

In [41]:
# optimizer_pretrained.zero_grad(set_to_none=True)

In [42]:
# F.cross_entropy??

In [43]:
# nn.LogSoftmax??

In [52]:
def compute_accuracy(model, data_loader, device):
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            predictions = model(features)
            _, predicted_labels = torch.max(predictions, 1)
            num_examples += features.size(0)
            correct_pred += (predicted_labels == targets).sum()
        
        return (correct_pred.float() / num_examples) * 100

In [53]:
def train(model, num_epochs, train_loader, optimizer, device):
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()

        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.to(device)
            targets = targets.to(device)
            predictions = model(features)
            loss = F.cross_entropy(predictions, targets)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # logging
            if not batch_idx % 400:
                print(f"Epoch: {epoch:02d}, batch: {batch_idx}, loss: {loss:0.2f}")
            
        with torch.set_grad_enabled(False):
            print(f"epoch: {epoch:02d}, train accuracy: {compute_accuracy(model, train_loader, device)}")
    
    print(f"Total training time: {time.time() - start_time}")

In [54]:
train(model_pretrained, num_epochs, train_loader, optimizer_pretrained, device)

Epoch: 00, batch: 0, loss: 2.30
Epoch: 00, batch: 400, loss: 0.15
Epoch: 00, batch: 800, loss: 0.14
epoch: 00, train accuracy: 95.02666473388672
Epoch: 01, batch: 0, loss: 0.06
Epoch: 01, batch: 400, loss: 0.07
Epoch: 01, batch: 800, loss: 0.07
epoch: 01, train accuracy: 97.49333190917969
Total training time: 18.91064763069153


In [55]:
compute_accuracy(model_pretrained, test_loader, device)

tensor(96.6600, device='mps:0')

### LORA and DORA Layers

In [66]:
l = nn.Linear(1, 2)
l.in_features, l.out_features

(1, 2)

In [69]:
l.weight.shape

torch.Size([2, 1])

In [57]:
# torch.randn??

In [73]:
torch.tensor.norm??

Object `torch.tensor.norm` not found.


In [74]:
l = nn.Linear(5, 10)
l.in_features, l.out_features

(5, 10)

In [76]:
l.weight.shape

torch.Size([10, 5])

In [79]:
l.weight.norm(p=2, dim=0, keepdim=True).shape

torch.Size([1, 5])

In [105]:
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
    
    def forward(self, x):
        # x = [B, num_features], A = [num_features, rank], B = [rank, out_dim] -> [B, out_dim]
        # print(f"x shape: {x.shape}, A shape: {self.A.shape}, B shape: {self.B.shape}")
        x = self.alpha * (x @ self.A @ self.B)
        return x


In [106]:
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
    
    def forward(self, x):
        return self.lora(x) + self.linear(x) 

In [107]:
class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
    
    def forward(self, x):
        # self.linear.weight is of dim [out_features, in_features]
        # lora is of dim [in_features, out_features]
        lora = self.lora.alpha * (self.lora.A @ self.lora.B)
        weights = self.linear.weight + lora.T
        return F.linear(x, weights, self.linear.bias)


In [108]:
class LinearWithDoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        # shape is (1, in_features)
        self.m = nn.Parameter(
            self.linear.weight.norm(p=2, dim=0, keepdim=True)
        )
    
    def forward(self, x):
        # self.linear.weight is of dim [out_features, in_features]
        # lora is of dim [in_features, out_features]
        # approximating the directional vector
        lora = self.lora.alpha * (self.lora.A @ self.lora.B)
        weights = self.linear.weight + lora.T
        weights_norm = weights.norm(p=2, dim=0, keepdim=True)
        # shape is (out_features, in_features)
        directional_component = weights / weights_norm
        combined_weight = self.m * directional_component
        return F.linear(x, combined_weight, self.linear.bias)

##### Test the lora and dora layers

In [109]:
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1066494d0>

In [110]:
layer = nn.Linear(10, 2)
x = torch.randn((1, 10))
layer(x)

tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)

In [111]:
# should output 0 as expected as matrix B is all 0s
lora_layer = LoRALayer(10, 2, 4, 0.1)
lora_layer(x)

tensor([[0., 0.]], grad_fn=<MulBackward0>)

In [112]:
linear_lora_layer = LinearWithLoRA(layer, 4, 0.1)
linear_lora_layer(x)

tensor([[0.6639, 0.4487]], grad_fn=<AddBackward0>)

In [113]:
linear_lora_merged = LinearWithLoRAMerged(layer, 4, 0.1)
linear_lora_merged(x)


tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)

In [114]:
linear_dora_merged = LinearWithDoRAMerged(layer, 4, 0.1)
linear_dora_merged(x)

tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)

In [115]:
model_pretrained

MultilayerMLP(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [116]:
import copy 

original_list = [[1, 2, 3], [4, 5, 6]]
copy_list = copy.deepcopy(original_list)

In [117]:
copy_list

[[1, 2, 3], [4, 5, 6]]

In [118]:
copy_list[0][0] = 5
copy_list, original_list

([[5, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]])

In [119]:
model_lora = copy.deepcopy(model_pretrained)
model_dora = copy.deepcopy(model_pretrained)

In [120]:
model_lora.layers[0]

Linear(in_features=784, out_features=128, bias=True)

In [121]:
model_lora.layers[0] = LinearWithLoRA(model_lora.layers[0], rank=4, alpha=8)
model_lora.layers[2] = LinearWithLoRA(model_lora.layers[2], rank=4, alpha=8)
model_lora.layers[4] = LinearWithLoRA(model_lora.layers[4], rank=4, alpha=8)

model_lora.to(device)
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
model_lora

MultilayerMLP(
  (layers): Sequential(
    (0): LinearWithLoRA(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRA(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRA(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [122]:
model_dora.layers[0] = LinearWithDoRAMerged(model_dora.layers[0], rank=4, alpha=8)
model_dora.layers[2] = LinearWithDoRAMerged(model_dora.layers[2], rank=4, alpha=8)
model_dora.layers[4] = LinearWithDoRAMerged(model_dora.layers[4], rank=4, alpha=8)

model_dora.to(device)
optimizer_dora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
model_dora

MultilayerMLP(
  (layers): Sequential(
    (0): LinearWithDoRAMerged(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithDoRAMerged(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithDoRAMerged(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [123]:
compute_accuracy(model_pretrained, test_loader, device)

tensor(96.6600, device='mps:0')

In [124]:
compute_accuracy(model_lora, test_loader, device)

tensor(96.6600, device='mps:0')

In [126]:
print(compute_accuracy(model_dora, test_loader, device))

tensor(96.6600, device='mps:0')


#### Train models with LoRA and DoRA layers

In [128]:
# freeze pretrained layer weights

In [141]:
list(list(model_lora.children())[0][0].children())

[Linear(in_features=784, out_features=128, bias=True), LoRALayer()]

In [142]:
def freeze_linear_layers(model):
    for module in model.children():
        if isinstance(module, nn.Linear):
            for p in module.parameters():
                p.requires_grad = False
        
        # recurse otherwise
        freeze_linear_layers(module)

In [143]:
freeze_linear_layers(model_lora)

In [145]:
for name, param in model_lora.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [None]:
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)

In [147]:
train??

[0;31mSignature:[0m [0mtrain[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0mnum_epochs[0m[0;34m,[0m [0mtrain_loader[0m[0;34m,[0m [0moptimizer[0m[0;34m,[0m [0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
[0;32mdef[0m [0mtrain[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0mnum_epochs[0m[0;34m,[0m [0mtrain_loader[0m[0;34m,[0m [0moptimizer[0m[0;34m,[0m [0mdevice[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mstart_time[0m [0;34m=[0m [0mtime[0m[0;34m.[0m[0mtime[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m    [0;32mfor[0m [0mepoch[0m [0;32min[0m [0mrange[0m[0;34m([0m[0mnum_epochs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mmodel[0m[0;34m.[0m[0mtrain[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m        [0;32mfor[0m [0mbatch_idx[0m[0;34m,[0m [0;34m([0m[0mfeatures[0m[0;34m,[0m [0mtargets

In [148]:
train(model_lora, num_epochs, train_loader, optimizer_lora, device)

Epoch: 00, batch: 0, loss: 0.17
Epoch: 00, batch: 400, loss: 0.07
Epoch: 00, batch: 800, loss: 0.02
epoch: 00, train accuracy: 97.50166320800781
Epoch: 01, batch: 0, loss: 0.12
Epoch: 01, batch: 400, loss: 0.14
Epoch: 01, batch: 800, loss: 0.02
epoch: 01, train accuracy: 97.86000061035156
Total training time: 23.97957491874695


In [149]:
compute_accuracy(model_lora, test_loader, device)

tensor(97., device='mps:0')

In [150]:
freeze_linear_layers(model_dora)

In [151]:
for name, param in model_dora.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.m: True
layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.m: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.m: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [152]:
optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)

In [153]:
train(model_dora, num_epochs, train_loader, optimizer_dora, device)

Epoch: 00, batch: 0, loss: 0.13
Epoch: 00, batch: 400, loss: 0.20
Epoch: 00, batch: 800, loss: 0.05
epoch: 00, train accuracy: 97.84833526611328
Epoch: 01, batch: 0, loss: 0.02
Epoch: 01, batch: 400, loss: 0.03
Epoch: 01, batch: 800, loss: 0.05
epoch: 01, train accuracy: 97.92333221435547
Total training time: 33.13446497917175


In [154]:
compute_accuracy(model_dora, test_loader, device)

tensor(96.9500, device='mps:0')

In [155]:
compute_accuracy(model_dora, train_loader, device)

tensor(97.9233, device='mps:0')