# LoRA

We're going to train a very simple LoRA (low-rank adapter) that, when applied, makes our model always predict " Paris" as the next token, no matter what the prompt is.

In [1]:
import torch
import torch.nn as nn
from nnsight import LanguageModel

# Wrap GPT-2 in nnsight's LanguageModel so its modules can be traced and intervened on.
# NOTE(review): device_map='auto' is presumably forwarded to the Hugging Face loader for device placement — confirm against nnsight docs.
model = LanguageModel('openai-community/gpt2', device_map='auto')


from nnsight.envoy import Envoy # 

# We will define a LORA class.
# The operations in the LORA class's __call__ method are simply traced, just like you would write them in a .trace.
class LORA(nn.Module):
    """Low-rank adapter that is applied to a model module from inside an nnsight tracing context.

    Calling an instance inside ``.trace`` / ``.generate`` replaces the wrapped module's output
    ``h`` with ``h + alpha * (x @ WA @ WB)``, where ``x`` is the module's first positional input.
    """

    def __init__(self, module: "Envoy", dim: int, r: int) -> None:
        """Create the adapter weights.

        Args:
            module (Envoy): Which model module we are adding the LORA to.
            dim (int): Dimension of the layer we are adding to. It is used as both the input and
                the output width of the update, so the module's input and output widths must match.
                (This could potentially be auto-populated if the user scanned first, so we know the shape.)
            r (int): Inner (rank) dimension of the LORA.
        """
        super().__init__()
        self.r = r
        self.module = module
        # WA is random and WB starts at zero, so the update (x @ WA) @ WB is zero before training.
        # NOTE(review): .save() is presumably what makes nnsight keep these values available
        # outside the session — confirm against the nnsight docs.
        self.WA = torch.nn.Parameter(torch.randn(dim, self.r), requires_grad=True).save()
        self.WB = torch.nn.Parameter(torch.zeros(self.r, dim), requires_grad=True).save()

    # The call method defines how to actually apply the LORA.
    def __call__(self, alpha: float = 1.0) -> None:
        """Apply the LORA to the wrapped module's output.

        The output ``h`` is replaced with ``h + alpha * (x @ WA @ WB)``.

        Args:
            alpha (float, optional): How much of the LORA update to apply. Can be altered after
                training for inference; ``0.0`` leaves the original output unchanged. Defaults to 1.0.
        """
        # We apply WA to the first positional arg (the hidden states).
        # NOTE(review): input[0][0] assumes `.input` is an (args, kwargs) pair; this depends on the nnsight version.
        x = self.module.input[0][0]
        A_x = torch.matmul(x, self.WA)
        BA_x = torch.matmul(A_x, self.WB)

        # LORA is additive, and alpha scales only the low-rank update. Scaling the whole sum would
        # also scale the base output, so alpha=0 would zero the layer instead of disabling the LORA.
        # (An in-place variant would be `self.module.output[:] = ...`.)
        self.module.output = self.module.output + BA_x * alpha

    def parameters(self, recurse: bool = True):
        """Return the trainable LORA weights ``[WA, WB]``.

        This overrides ``nn.Module.parameters`` to hand exactly these two weights to an optimizer.
        NOTE(review): the override is presumably needed because the ``.save()``-wrapped weights are
        not registered as ordinary module parameters. ``recurse`` is accepted only for signature
        compatibility with ``nn.Module.parameters`` and is ignored.
        """
        return [self.WA, self.WB]

  from .autonotebook import tqdm as notebook_tqdm


Let's define the variables used in LoRA training.

In [3]:
# We need the token id of the correct answer. The training loss supervises a single next token,
# so the answer must encode to exactly one id. Fail loudly rather than silently training on
# only the first sub-token.
answer = " Paris"
answer_ids = model.tokenizer.encode(answer)
assert len(answer_ids) == 1, f"{answer!r} encodes to {len(answer_ids)} tokens; expected exactly 1"
answer_token = answer_ids[0]
# Inner LORA dimension (the rank r)
lora_dim = 4
# Module to train LORA on: the MLP of the last transformer block
module = model.transformer.h[-1].mlp

We can use the `.scan()` method to get the shape of the module's output without having to fully run the model.

In [4]:
# Scan with a placeholder prompt (" ") just to learn the module's output shape.
with model.scan(" "):
    # Last dimension of the module's output: the width the LORA weights must match.
    dim = module.output.shape[-1]

print(dim)

768


It's time to run the LoRA training loop! We use the **Session** and **Iterator** contexts to achieve this.

In [5]:
from torch.utils.data import DataLoader

# The LORA object itself isn't transmitted to the server; only the operations traced by its
# __call__ method are. The parameters are created remotely and are only ever retrieved, never sent.
with model.session() as session:

    # Create a dataset of 100 pairs: a placeholder prompt ("_") and the " Paris" token id.
    dataset = [["_", answer_token]] * 100

    # Create a dataloader from it (10 batches of 10 pairs).
    dataloader = DataLoader(dataset, batch_size=10)

    # Create our LORA on the chosen module (the last MLP).
    lora = LORA(module, dim, lora_dim)

    # Create an optimizer over the LORA parameters.
    # NOTE(review): lr=3 is very high — presumably chosen so that 10 steps are enough for this toy target.
    optimizer = torch.optim.AdamW(lora.parameters(), lr=3)

    # Iterate over the dataloader using .iter; `iterator` gives access to per-step helpers such as .log.
    with session.iter(dataloader, return_context=True) as (batch, iterator):

        prompt = batch[0]
        correct_token = batch[1]

        # Run .trace with the prompt.
        with model.trace(prompt):

            # Apply the LORA to the intervention graph just by calling it within .trace.
            lora()

            # Get logits.
            logits = model.lm_head.output

            # Cross entropy between the last predicted token and correct_token.
            # (logits[:, -1] assumes (batch, seq, vocab) logits — i.e. the prediction at the last position.)
            loss = torch.nn.functional.cross_entropy(logits[:, -1], correct_token)
            # Call backward.
            loss.backward()

        # Call methods on the optimizer. Graphs that aren't from .trace (here, the session and the
        # iterator each have their own graph) are executed sequentially.
        # The graph of the iterator here will be:
        # 1.) Index batch at 0 for prompt
        # 2.) Index batch at 1 for correct_token
        # 3.) Execute the .trace using the prompt
        # 4.) Call .step() on the optimizer
        optimizer.step()
        # 5.) Call .zero_grad() on the optimizer
        optimizer.zero_grad()
        # 6.) Log the LORA WA weights to show they are indeed changing
        iterator.log(lora.WA)


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Parameter containing:
tensor([[ 0.5262, -0.6452,  0.8448,  0.7407],
        [-0.4497, -0.7200, -1.0452,  0.0630],
        [ 0.7231,  1.0991,  0.3883,  0.1719],
        ...,
        [ 0.0024, -1.1490, -0.5580, -0.9070],
        [-0.1946,  0.8469, -1.8173,  0.8333],
        [ 0.1722, -1.8518, -1.5542, -1.3361]], requires_grad=True)
Parameter containing:
tensor([[ 0.6813, -0.4550,  0.9903,  0.5476],
        [-0.3310, -0.5932, -0.9087, -0.0441],
        [ 0.7201,  1.0849,  0.3954,  0.1480],
        ...,
        [ 0.1580, -0.9589, -0.3856, -1.0354],
        [-0.3153,  0.6950, -1.8893,  0.9347],
        [ 0.4812, -1.4821, -1.1935, -1.6101]], requires_grad=True)
Parameter containing:
tensor([[-1.2552, -2.3574, -0.9555,  2.4472],
        [-2.2370, -2.4913, -2.7973,  1.8731],
        [-1.2148, -0.8610, -1.5298,  2.0569],
        ...,
        [-1.7628, -2.8462, -2.2901,  0.9117],
        [ 1.6102,  2.5902,  0.0834, -1.0093],
        [-1.4495, -3.3539, -3.0740,  0.3545]], requires_grad=True)
Para

Now `WA` and `WB` are optimized! To generate with the LoRA, we simply call `lora()` inside the `.generate` context, save the output, and then de-tokenize it.

In [6]:
# Generate twice from the same prompt: first with the LORA applied (should produce "Hello Paris"),
# then with the unmodified model (should produce "Hello,").
for use_lora in (True, False):
    with model.generate("Hello") as generator:

        # Only the first pass calls lora(), so only it carries the LORA intervention.
        if use_lora:
            lora()

        out = model.generator.output.save()

    print(model.tokenizer.batch_decode(out.value))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['Hello Paris']
['Hello,']
