Implementation of LoRA (Low-Rank Adaptation), as presented in https://r4j4n.github.io/blogs/posts/lora/. Original paper: https://arxiv.org/abs/2106.09685.

In [None]:
import math
import torch
import torch.nn as nn

In [None]:
# Problem dimensions for the synthetic linear-regression task.
n = 10_000         # number of samples
d_in = 1001        # input feature dimension
d_out = 1000       # output dimension
hidden_dim = 1000  # width of the model's hidden layer

# Seed the global RNG so that the sampled data (and everything downstream that
# draws from the global generator) is identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Sample with the CPU generator and then move, so the data does not depend on
# whether a GPU is present.
theta = torch.randn(d_in, d_out).to(device)  # ground-truth weights, (d_in, d_out)
X = torch.randn(n, d_in).to(device)          # inputs, (n, d_in)
# X and theta already live on `device`, so their product does too.
y = torch.matmul(X, theta)                   # noiseless targets, (n, d_out)

print(theta.shape, X.shape, y.shape)

torch.Size([1001, 1000]) torch.Size([10000, 1001]) torch.Size([10000, 1000])


In [None]:
class LinRegModel(nn.Module):
  """Two stacked bias-free linear layers with no activation in between."""

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    # Attribute names are part of the interface: they determine the parameter
    # names ("layer1.weight", "layer2.weight").
    self.layer1 = nn.Linear(in_features=input_dim, out_features=hidden_dim, bias=False)
    self.layer2 = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=False)

  def forward(self, x):
    """Apply ``layer1`` and then ``layer2`` to ``x``."""
    return self.layer2(self.layer1(x))


def train(model, X, y, batch_size=128, epochs=50):
  """Fit ``model`` to ``(X, y)`` by minimizing MSE with Adam on shuffled mini-batches.

  The optimizer is Adam with its default settings over ``model.parameters()``.
  The sample order is reshuffled at the start of every epoch. On every 10th
  epoch (0, 10, 20, ...) the MSE over the full ``X``/``y`` is printed.
  Returns nothing; ``model`` is updated in place.
  """
  optimizer = torch.optim.Adam(model.parameters())
  mse = torch.nn.functional.mse_loss
  num_samples = X.shape[0]

  for epoch in range(epochs):
    # Fresh random visiting order for this epoch.
    shuffled = torch.randperm(num_samples)

    for start in range(0, num_samples, batch_size):
      batch_idx = shuffled[start:start + batch_size]
      optimizer.zero_grad()
      batch_loss = mse(model(X[batch_idx]), y[batch_idx])
      batch_loss.backward()
      optimizer.step()

    if epoch % 10 != 0:
      continue

    # Periodic progress report on the full data set; no gradients needed.
    with torch.no_grad():
      full_loss = mse(model(X), y)
    print(f"Epoch: {epoch}/{epochs} | Loss:{full_loss.item()}")

# Build the base model and fit it on the original distribution (X, y) using
# train()'s default batch size and epoch count.
model = LinRegModel(input_dim=d_in, hidden_dim=hidden_dim, output_dim=d_out)
model = model.to(device)
train(model, X, y)

Epoch: 0/50 | Loss:864.793212890625
Epoch: 10/50 | Loss:18.980276107788086
Epoch: 20/50 | Loss:1.2586461305618286
Epoch: 30/50 | Loss:0.15425437688827515
Epoch: 40/50 | Loss:0.02819485031068325


In [None]:
# Draw data from a different distribution: every entry of the ground-truth
# weight matrix is shifted by +1.
# The shift is "low signal": theta2 - theta is an all-ones matrix (rank 1), so we
# should not need to train on 10k examples at full capacity to learn it.
# LoRA will let us train only a small, "dumbed down" low-rank correction on this
# data; that should be enough to learn the shift.
theta2 = theta + 1
X2 = torch.randn(n, d_in).to(device)
# X2 and theta2 are already on `device`, so the product needs no extra transfer.
y2 = torch.matmul(X2, theta2)


# Try the new data with the old model. This is evaluation only, so skip building
# an autograd graph over the full data set.
with torch.no_grad():
  loss = torch.nn.functional.mse_loss(model(X2), y2)
print(f"Loss on different distribution: {loss}")

Loss on different distribution: 1009.7973022460938


In [None]:
class AdaptedLinear(nn.Module):
    """A frozen ``nn.Linear`` plus a trainable low-rank (LoRA) correction.

    The forward pass computes ``linear(x) + (x @ A @ B) * scaling``, where
    ``A`` has shape ``(in_features, r)`` and ``B`` has shape
    ``(r, out_features)``. ``B`` starts at zero, so a freshly wrapped layer
    produces exactly the same output as the original layer.

    Args:
        linear: the ``nn.Linear`` to wrap. Its parameters are frozen in place
            (``requires_grad_(False)``); the same object is kept as
            ``self.linear``.
        r: rank of the low-rank update.
        scaling: constant multiplier applied to the low-rank term
            (``alpha / r`` in LoRA).
    """

    def __init__(self, linear, r, scaling ) -> None:
        super().__init__()
        linear.requires_grad_(False)
        self.linear = linear
        # Create the factors next to the wrapped weight so the module works
        # even when `linear` already lives on a GPU or uses a non-default dtype.
        factory = {"device": linear.weight.device, "dtype": linear.weight.dtype}
        self.A = nn.Parameter(torch.randn(linear.in_features, r, **factory))   # (in_features, r)
        self.B = nn.Parameter(torch.zeros(r, linear.out_features, **factory))  # (r, out_features)
        self.scaling = scaling

    def forward(self, x):
        # Multiply through the rank-r bottleneck, (x @ A) @ B, rather than
        # forming the dense (in_features, out_features) matrix A @ B on every call.
        low_rank = torch.matmul(torch.matmul(x, self.A), self.B)
        return self.linear(x) + low_rank * self.scaling


class LoraAdapter(nn.Module):
  """Wrap every direct ``nn.Linear`` child of ``model`` in an ``AdaptedLinear``.

  Only the direct children of ``model`` are inspected (``model.children()``),
  and ``forward`` applies them one after another in that order, so the wrapped
  model is treated as a plain sequential chain of its children.

  Args:
    model: the module whose linear children get adapters. The original linear
      layers are reused (and frozen by ``AdaptedLinear``), not copied.
    r: rank of each low-rank update; must be positive.
    alpha: LoRA alpha; every low-rank update is scaled by ``alpha / r``.

  Raises:
    ValueError: if ``r`` is not positive.
  """

  def __init__(self, model, r=16, alpha=1):
    super().__init__()
    if r <= 0:
      raise ValueError(f"r must be a positive integer, got {r}")
    self.module_list=nn.ModuleList()
    self.scaling=alpha/r  # LoRA scaling factor (alpha / r), not a learning rate
    self.original_linears=[]

    # Go through the layers of the model and add adapters to all linear layers.
    for layer in model.children():
      if isinstance(layer, nn.Linear):
        # Keep a reference to the original layer.
        self.original_linears.append(layer)
        self.module_list.append(AdaptedLinear(layer, r, self.scaling))
      else:
        self.module_list.append(layer)

  def forward(self, x):
    """Apply the (adapted) children in order."""
    for layer in self.module_list:
      x=layer(x)
    return x

  def update_original_weights(self):
    """Merge each learned low-rank update into its wrapped linear layer.

    For every ``AdaptedLinear`` in ``module_list``, adds
    ``(A @ B * scaling).T`` to the wrapped layer's weight in place and then
    resets ``B`` to zero. Resetting ``B`` keeps this adapter's output unchanged
    by the merge (the update now lives in the base weight instead of being
    counted twice) and makes repeated calls harmless.
    """
    with torch.no_grad():
      # Walk module_list itself rather than zipping it with original_linears:
      # the two lists have different lengths as soon as the model contains a
      # non-linear child, which would pair adapters with the wrong layers.
      for layer in self.module_list:
        if not isinstance(layer, AdaptedLinear):
          continue
        delta_theta = torch.matmul(layer.A, layer.B) * layer.scaling
        # nn.Linear stores its weight as (out_features, in_features).
        layer.linear.weight.add_(delta_theta.t())
        layer.B.zero_()

In [None]:
# Wrap the trained base model with rank-1 adapters and fit on the shifted data
# (X2, y2) for 100 epochs.
lora_model = LoraAdapter(model, r=1)
lora_model = lora_model.to(device)
train(lora_model, X2, y2, epochs=100)

Epoch: 0/100 | Loss:1003.4911499023438
Epoch: 10/100 | Loss:684.5513916015625
Epoch: 20/100 | Loss:307.9579162597656
Epoch: 30/100 | Loss:109.58228302001953
Epoch: 40/100 | Loss:32.0892333984375
Epoch: 50/100 | Loss:7.14031457901001
Epoch: 60/100 | Loss:1.1468110084533691
Epoch: 70/100 | Loss:0.2185971438884735
Epoch: 80/100 | Loss:0.10957492142915726
Epoch: 90/100 | Loss:0.08853045850992203


In [None]:
# Evaluate the base model (not lora_model) on the shifted data, before the
# adapter update is merged into its weights.
loss = torch.nn.functional.mse_loss(input=model(X2), target=y2)
print(f"Loss on different distribution: {loss}")

In [None]:
# Fold the learned low-rank updates into the base model's own weights, then
# evaluate the base model on the shifted data again.
lora_model.update_original_weights()
loss = torch.nn.functional.mse_loss(input=model(X2), target=y2)
print(f"Loss on different distribution: {loss}")