How does MergedLinear work? #69
Hi Peter, MergedLinear should do exactly what Linear does mathematically when multiple linear layers are "merged" into one, as in the GPT codebase. It's here simply to make the GPT integration easier. Hope this helps!
Hello @peterkim95. I still don't quite understand why there is a combination of a Linear layer (for matrix A) and a Conv1d layer (for matrix B). Why not make both Linear, or both Conv1d? @edwardjhu, could you briefly explain this, or maybe link to an article to read? Because I have no idea 🤷‍♂️. And by the way, thanks for your work 👍.
good issue
For instance, if you wish to add a rank-8 LoRA to the attention layer's three matrices (Q, K, V), you can use the following code:

```python
lora_A = nn.Linear(in_features, 8 * 3, bias=False)
lora_B = nn.Conv1d(8 * 3, out_features, kernel_size=1, groups=3, bias=False)
```

If you used nn.Linear for both A and B, you would have to handle Q, K, and V separately. By using nn.Conv1d with the groups parameter instead, the three components can be processed simultaneously without interfering with one another, as the sketch below shows.
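To make the equivalence concrete, here is a minimal, self-contained sketch (the sizes `in_features`, `out_features`, and the rank are illustrative, not taken from the repo). It checks that the fused `nn.Linear` + grouped `nn.Conv1d` pair from the comment above produces the same result as handling Q, K, and V with three separate low-rank pairs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, out_features, r = 16, 48, 8   # out_features covers Q, K, V together
d = out_features // 3                      # output dim of each individual matrix

lora_A = nn.Linear(in_features, r * 3, bias=False)   # A_q, A_k, A_v stacked row-wise
lora_B = nn.Conv1d(r * 3, out_features, kernel_size=1, groups=3, bias=False)

x = torch.randn(2, in_features)

# Fused path: one Linear for A, one grouped 1x1 conv for B.
after_A = lora_A(x)                                   # (batch, 3r)
fused = lora_B(after_A.unsqueeze(-1)).squeeze(-1)     # (batch, out_features)

# Reference path: slice out each block and apply Q, K, V separately.
outs = []
for i in range(3):
    A_i = lora_A.weight[i * r:(i + 1) * r]            # (r, in_features)
    B_i = lora_B.weight[i * d:(i + 1) * d, :, 0]      # (d, r)
    outs.append(x @ A_i.T @ B_i.T)
print(torch.allclose(fused, torch.cat(outs, dim=-1), atol=1e-6))  # True
```

Because `groups=3`, output channels `i*d:(i+1)*d` of the conv only ever see input channels `i*r:(i+1)*r`, so B_q, B_k, and B_v cannot interfere with each other.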
Hello @clalanliu. If that's the case, why isn't this approach also used for the lora_A matrix?
@Andrei-Aksionov Yes, you can check my note. There is no need to do so, because the inputs to the Q, K, and V matrices are all the same (that is, x).
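As a quick illustration of that point (the sizes here are arbitrary, chosen only for the example): because A_q, A_k, and A_v all multiply the same input x, stacking them row-wise into one plain `nn.Linear` gives exactly the same result as three separate A matrices, so no grouping is needed on the A side:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, r = 16, 4
x = torch.randn(2, in_features)

A_q, A_k, A_v = (nn.Linear(in_features, r, bias=False) for _ in range(3))
stacked = torch.cat([A_q.weight, A_k.weight, A_v.weight], dim=0)  # (3r, in_features)

separate = torch.cat([A_q(x), A_k(x), A_v(x)], dim=-1)  # three A's, same x
fused = x @ stacked.T                                   # one merged A
print(torch.allclose(separate, fused, atol=1e-6))       # True
```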
Oh boy, how did I miss that 🤣. Thanks |
I understand why we need MergedLinear, but is there a simple example of how the forward pass works? Specifically, this line: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L248. I'm struggling to understand what the 1d conv is doing there.
I would also appreciate a mathematical explanation. For the Linear case, I understand the simple matrix multiplication of deltaW * x = B * A * x. But for MergedLinear, what would be the equation for deltaW?
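For what it's worth, here is one way to write it down. Caveat: the shapes below are assumptions chosen to mirror the grouped-conv layout described earlier in the thread, not copied from `layers.py`. If A_i has shape (r, in_features) and B_i has shape (d, r) for i in {q, k, v}, then the merged update is the three low-rank products stacked along the output dimension:

deltaW = [B_q A_q; B_k A_k; B_v A_v]

so deltaW * x = concat(B_q A_q x, B_k A_k x, B_v A_v x). The 1d conv is just an efficient way to apply this block-diagonal B to the stacked A output. A small `F.conv1d` sketch in the same spirit as the linked line:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, in_features, r, d = 2, 5, 16, 4, 8   # illustrative sizes
n = 3                                            # Q, K, V

x = torch.randn(batch, seq, in_features)
lora_A = torch.randn(n * r, in_features)   # A_q, A_k, A_v stacked row-wise
lora_B = torch.randn(n * d, r)             # B_q, B_k, B_v stacked row-wise

after_A = F.linear(x, lora_A)              # (batch, seq, 3r)
# Treat the 3r features as conv channels; a grouped 1x1 conv applies each
# B_i only to its own r channels -- a block-diagonal matmul in disguise.
after_B = F.conv1d(after_A.transpose(-2, -1),
                   lora_B.unsqueeze(-1),   # kernel of shape (3d, r, 1)
                   groups=n).transpose(-2, -1)   # (batch, seq, 3d)

# The equivalent deltaW: vertically stacked blocks B_i @ A_i.
delta_W = torch.cat([lora_B[i*d:(i+1)*d] @ lora_A[i*r:(i+1)*r] for i in range(n)],
                    dim=0)                 # (3d, in_features)
print(torch.allclose(after_B, x @ delta_W.T, atol=1e-5))  # True
```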