### How to Implement LoRA Using nn.Linear and nn.Conv1d
In some LoRA implementations, there is a module called `MergedLinear`, which contains an `nn.Linear` as `lora_A` and an `nn.Conv1d` as `lora_B`.

In [36]:
# PyTorch and its submodules are imported
import torch.nn as nn
import torch
import torch.nn.functional as F

# Seed the global RNG so the randomly initialised weights below are reproducible.
# NOTE: the order of the layer constructions and torch.rand calls determines
# which random values each weight receives, so keep it stable.
torch.manual_seed(0)

# Rank of each low-rank update
r = 8

# One flag per equally sized group of output columns; True enables LoRA for it
enable_lora = [True, False, True]

# Input width and total output width (three groups of 4096 columns)
in_features = 4096
out_features = 4096 * 3

# Target layer: a linear projection without bias
target = nn.Linear(in_features, out_features, bias=False)

# Derived sizes shared by the mask and both LoRA modules
n_groups = len(enable_lora)
n_enabled = sum(enable_lora)
group_width = out_features // n_groups

# Flat boolean mask over the output columns: each group's flag is repeated
# across that group's `group_width` columns
lora_ind = (
    torch.tensor(enable_lora, dtype=torch.bool)
    .unsqueeze(1)
    .expand(n_groups, group_width)
    .reshape(-1)
)

# lora_A projects the input down to r features per enabled group
lora_A = nn.Linear(in_features, r * n_enabled, bias=False)

# lora_B is a kernel-size-1 grouped convolution with one group per enabled
# output group, mapping r channels to group_width channels within each group
lora_B = nn.Conv1d(
    in_channels=r * n_enabled,
    out_channels=group_width * n_enabled,
    kernel_size=1,
    groups=n_enabled,
    bias=False,
)

# Replace both LoRA weights with uniform random values (A first, then B)
for module in (lora_A, lora_B):
    module.weight = nn.Parameter(torch.rand(module.weight.shape))

# Show the weight shapes of the target, lora_A and lora_B, in that order
for module in (target, lora_A, lora_B):
    print(module.weight.shape)

torch.Size([12288, 4096])
torch.Size([16, 4096])
torch.Size([8192, 8, 1])


In [37]:
# Batch size and sequence (token) length of the demo input
x_batch_size, x_token_len = 2, 500

def zero_pad(x, mask=None):
    """Scatter the columns of ``x`` into a zero-filled, full-width tensor.

    Args:
        x: Tensor of shape ``(..., k)``, where ``k`` is the number of True
            entries in ``mask``.
        mask: 1-D boolean tensor with one entry per output column; True marks
            the columns that receive the values of ``x`` (in order). Defaults
            to the module-level ``lora_ind``.

    Returns:
        Tensor of shape ``(..., mask.numel())`` with the same dtype and device
        as ``x``. Columns selected by ``mask`` hold ``x``; all others are zero.
    """
    if mask is None:
        mask = lora_ind
    width = mask.numel()

    # Only the leading dimensions match the input; the last one is widened to
    # the full output width.
    padded = x.new_zeros((*x.shape[:-1], width))

    # Collapse the leading dimensions so one boolean column index addresses
    # every row. `rows` is a view, so writing to it fills `padded` in place.
    rows = padded.view(-1, width)
    rows[:, mask] = x.reshape(-1, x.shape[-1])
    return padded

# All-ones input of shape (batch, tokens, in_features)
x = torch.ones(x_batch_size, x_token_len, in_features)
print('x', x.shape)

# Down-projection: cast the input to lora_A's weight dtype, then apply lora_A
x_cast = x.to(lora_A.weight.dtype)
after_A = lora_A(x_cast)
print(after_A.shape)

# nn.Conv1d treats axis 1 as channels, so move the feature axis there before
# applying lora_B and move it back to the last position afterwards
channels_first = after_A.transpose(1, 2)
after_B = lora_B(channels_first).transpose(1, 2)
print(after_B.shape)

# Scatter the LoRA output into the full output width; columns of groups
# without LoRA stay zero
result = zero_pad(after_B)
print(result.shape)

x torch.Size([2, 500, 4096])
torch.Size([2, 500, 16])
torch.Size([2, 500, 8192])
torch.Size([2, 500, 12288])


In [38]:
# Check that the grouped Conv1d path equals one dense low-rank product
# (B @ A) per enabled group.
# NOTE(review): the _Q/_V suffixes presumably refer to the query/value slices
# of a fused QKV projection -- confirm against the model this targets.

# Number of output columns in one group. This is the correct split point for
# lora_B's rows and for `result`'s columns; it only coincides with in_features
# because out_features == 3 * in_features in this configuration.
group_out = out_features // len(enable_lora)

# lora_A's weight stacks r rows per enabled group
delta_W_A_Q = lora_A.weight[:r, :]
delta_W_A_V = lora_A.weight[r:2 * r, :]

# lora_B's weight is (enabled_groups * group_out, r, 1); drop the kernel axis
# and take each group's block of rows
delta_W_B_Q = lora_B.weight[:group_out, :, 0]
delta_W_B_V = lora_B.weight[group_out:2 * group_out, :, 0]

# Dense equivalent: x @ (B @ A).T for each enabled group
after_AB_Q = F.linear(x, (delta_W_B_Q @ delta_W_A_Q))
after_AB_V = F.linear(x, (delta_W_B_V @ delta_W_A_V))

# With enable_lora == [True, False, True], the first enabled group fills the
# first block of output columns and the second fills the third block
print(torch.allclose(after_AB_Q, result[:, :, :group_out]))
print(torch.allclose(after_AB_V, result[:, :, 2 * group_out:3 * group_out]))

True
True
