---
title:  "What is going on with low-rank updates?"
date:   2024-05-01
permalink: /posts/2024/05/low-rank-updates/
mathjax: true
---

# What is going on with low-rank updates?

Parameter-efficient fine-tuning (PEFT) methods like LoRA, DoRA, and QLoRA have proliferated over the past couple of years and have had a huge impact on the deep learning world. They let us fine-tune quickly and cheaply, they reduce memory overhead, and we don't even have to change the original model weights.

But what the hell are they actually doing? If I make a low-rank change to some set of weight matrices $\{W_i: 0<i<N\}$, what am I actually changing about the behavior of my network?

## PEFT methods
First, I'm going to describe what a PEFT method actually is. Then I'm going to 

Let's zoom in and think about the simple case. Assume we're making a rank-1 change to the network and let's look at a particular weight matrix and activation with no nonlinearities or anything complicated. We're just in linear land.

We have some weight matrix $W$ and activation $x$, which could be a token vector or something.

### $Wx = y$

We want to change the behavior of our network so that it does something a little different. Maybe we're applying a rank-one update so that the model moves the Space Needle to Boston. Whatever. $W$ is really big and we don't want to have to change it, so it remains static. Here's what we'll do instead:

### $W' = W + uv^\top$

We'll keep $W$ static and train $u$ and $v$ with gradient descent. To keep things simple, let's assume $W$ is a square matrix of size $(n, n)$. Then $u, v \in \mathbb{R}^n$, and $uv^\top \in \mathbb{R}^{n \times n}$. Once $u$ and $v$ are trained, a new forward pass looks like

### $W'x = z$

where $z$ is the new output produced by the updated model.
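Here's a minimal NumPy sketch of that setup. The width `n`, the random seed, and the random values are all made up for illustration. In a real fine-tune $u$ and $v$ would come out of training, not a random number generator.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4  # toy width, chosen only for illustration

W = rng.normal(size=(n, n))  # the frozen weight matrix
u = rng.normal(size=n)       # stand-in for a learned vector
v = rng.normal(size=n)       # stand-in for a learned vector
x = rng.normal(size=n)       # some input activation

delta = np.outer(u, v)       # uv^T: an (n, n) matrix built from 2n numbers
W_prime = W + delta          # the updated weights W'

y = W @ x                    # original output
z = W_prime @ x              # new output

print("rank of the update:", np.linalg.matrix_rank(delta))
print("entries in W:", W.size, "| trainable entries in u and v:", u.size + v.size)
```

The update matrix is as big as $W$, but it is pinned down by only $2n$ numbers instead of $n^2$.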

But we can expand this out algebraically to get some more intuition about what's going on:

\begin{align*}
W'x &= z \\
(W + uv^\top) x &= z \\
Wx + (uv^\top) x &= z \\
y + (uv^\top)x &= z
\end{align*}

The last step uses $Wx = y$.

We have another interesting rearrangement we can do here. Although $uv^\top$ is an $n \times n$ matrix, $v^\top x$ is just a scalar value: the dot product between $x$ and $v$. We can run with that idea:

\begin{align*}
y + (uv^\top) x &= z \\
y + u(v^\top x) &= z \\
y + \alpha u &= z
\end{align*}

Let's think about where we are and what's going on:  
- $x$ is the original input vector.
- $y$ is the original output vector.
- $u$ and $v$ are both learned.
- $z$ is the new output vector.

And we've just found that $\alpha = v^\top x$ is just a scalar value that depends on $x$, so it changes with every input that's passed in.
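The rearrangement is easy to check numerically. This is a sketch with made-up random values (the size `n` and the seed are arbitrary). It computes the new output two ways: once by building the full updated matrix, and once with just a dot product and a scaled vector.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5  # arbitrary toy size

W = rng.normal(size=(n, n))
u, v, x = rng.normal(size=(3, n))  # three random vectors of length n

y = W @ x                          # original output
alpha = v @ x                      # the scalar v^T x

z_full = (W + np.outer(u, v)) @ x  # builds the whole n x n update first
z_cheap = y + alpha * u            # never forms uv^T at all

print(np.allclose(z_full, z_cheap))
```

The second route only needs $O(n)$ extra work on top of the original forward pass, since it skips building $uv^\top$ entirely.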

So, a rank-1 update means: take your old output $y$ and some new vector $u$ you've learned from whatever fine-tuning setup you did. Now, scale $u$ by a value $\alpha$ that depends on your input and add it to $y$. That's your new output.
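One consequence worth seeing directly: no matter what input you feed in, the change to the output always lies along the same direction $u$. Only the scale differs. Here's a rough sketch on a batch of random inputs (the sizes and seed are arbitrary choices, and the random $u$, $v$ stand in for trained ones):

```python
import numpy as np

rng = np.random.default_rng(2)
n, batch = 6, 8  # arbitrary toy sizes

W = rng.normal(size=(n, n))
u = rng.normal(size=n)
v = rng.normal(size=n)
X = rng.normal(size=(batch, n))  # each row is one input vector

Y = X @ W.T                      # original outputs, one per row
Z = X @ (W + np.outer(u, v)).T   # outputs after the rank-1 update

changes = Z - Y                  # how much each output moved
alphas = X @ v                   # one scalar per input

# Every row of `changes` is the same vector u, scaled by that row's alpha.
print(np.allclose(changes, np.outer(alphas, u)))
```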

Let's think about what input values cause various effects.

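For inputs of a fixed length, the biggest change from the original output happens when $x$ points in the same direction as $v$, since that maximizes the dot product $v^\top x$. In the special case $x = v$, we get $\alpha = v^\top v = \lVert v \rVert^2$, so $u$ gets scaled by the squared magnitude of $x$. At the other extreme, any $x$ orthogonal to $v$ gives $\alpha = 0$ and the output doesn't change at all.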
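To make the comparison fair, it helps to hold the length of $x$ fixed and only vary its direction. The sketch below does that with unit-length inputs and a random stand-in for $v$ (the size, seed, and sample count are arbitrary). It compares an input parallel to $v$, an input orthogonal to $v$, and a pile of random directions.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 4  # arbitrary toy size

v = rng.normal(size=n)
v_hat = v / np.linalg.norm(v)      # unit vector pointing along v

# Build a unit vector orthogonal to v: take a random vector
# and subtract off its component along v.
r = rng.normal(size=n)
x_orth = r - (r @ v_hat) * v_hat
x_orth /= np.linalg.norm(x_orth)

alpha_parallel = v @ v_hat         # equals the length of v
alpha_orth = v @ x_orth            # zero up to floating-point error
alpha_x_equals_v = v @ v           # squared length of v

# Random unit-length inputs never beat the parallel one (Cauchy-Schwarz).
X = rng.normal(size=(1000, n))
X /= np.linalg.norm(X, axis=1, keepdims=True)
largest_random = np.abs(X @ v).max()

print(alpha_parallel, alpha_orth, largest_random)
```

So $v$ acts like a detector for which inputs get edited, and $u$ decides what the edit looks like.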

```python
from manim import *


class VectorArrow(Scene):
    def construct(self):
        # Mark the origin and draw an arrow from it to the point (2, 2).
        dot = Dot(ORIGIN)
        arrow = Arrow(ORIGIN, [2, 2, 0], buff=0)
        numberplane = NumberPlane()
        # Label both ends of the arrow with their coordinates.
        origin_text = Text('(0, 0)').next_to(dot, DOWN)
        tip_text = Text('(2, 2)').next_to(arrow.get_end(), RIGHT)
        self.add(numberplane, dot, arrow, origin_text, tip_text)
```
