**Description**: precursor to larger/actual experiment w/ LoRA to determine whether
there's savings in peak memory during training w/ SGD (not Adam) w/ LoRA vs w/o LoRA for
the same weights.

In [1]:
import torch
from torch import nn

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

A linear model isn't enough b/c I want to show that the intermediate gradient that's
stored after the forward pass is the same. An MLP is enough to show that.

In [3]:
class MLP(nn.Module):
    def __init__(self, d: int = 768) -> None:
        super().__init__()
        self.W1 = nn.Linear(d, d, device=DEVICE)
        self.relu = nn.ReLU()
        self.W2 = nn.Linear(d, d, device=DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.W1(x)
        x = self.relu(x)
        x = self.W2(x)
        return x

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)

In [4]:
class MLPLoRA(nn.Module):
    def __init__(self, d: int = 768, r: int = 2) -> None:
        super().__init__()
        self.W_frozen1 = torch.randn(size=(d, d), requires_grad=False, device=DEVICE)
        self.B1 = torch.randn(size=(d, r), requires_grad=True, device=DEVICE)
        self.A1 = torch.randn(size=(r, d), requires_grad=True, device=DEVICE)

        self.relu = nn.ReLU()

        self.W_frozen2 = torch.randn(size=(d, d), requires_grad=False, device=DEVICE)
        self.B2 = torch.randn(size=(d, r), requires_grad=True, device=DEVICE)
        self.A2 = torch.randn(size=(r, d), requires_grad=True, device=DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.W_frozen1 @ x
        x = x + (self.B1 @ self.A1 @ x)

        x = self.relu(x)

        x = self.W_frozen2 @ x
        x = x + (self.B2 @ self.A2 @ x)

        return x

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)

In [5]:
mlp = MLP()
# mlp = MLPLoRA()

In [6]:
x = torch.randn(size=(768,), requires_grad=False, device=DEVICE)  # input

In [7]:
torch.cuda.memory_allocated() / 1e6

4.727808

In [8]:
y = mlp(x)
loss = y.sum()

In [9]:
torch.cuda.memory_allocated() / 1e6

13.254144

In [10]:
loss.backward()

In [11]:
torch.cuda.memory_allocated() / 1e6

26.495488

```
# MLP
4.727808
13.254144
26.495488

# MLPLoRA
4.74624
15.634944
21.81376
```