# Easy GPT-Q + LoRA in JAX ([github](https://github.com/davisyoshida/easy-lora-and-gptq))

[Davis Yoshida](https://github.com/davisyoshida/)

This notebook shows how to combine two JAX tools/transforms I wrote: [Lorax](https://github.com/davisyoshida/lorax) and [JAX-GPTQ](https://github.com/davisyoshida/jax-gptq). I've been using the combination to run LLaMA finetunes on a single GPU.

They're both applicable to basically any JAX function, which conveniently includes many HuggingFace models!

The procedure is as follows:

1. Quantize the weights of the model we want to use
2. Use Lorax to transform the original model function `F(params, inputs)` to one that takes a tuple of the original params and the low rank LoRA params: `F_lora(param_tuple, inputs)`
3. Wrap `F_lora` in the `use_quantized` transform so that it knows how to handle arguments which are int8 matrices holding two 4-bit parameters per byte.
4. Train the model, updating only the low rank params and leaving the larger 4-bit model weights frozen.
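Step 3's "two parameters per byte" means that each 8-bit value stores a pair of 4-bit codes. Below is a minimal NumPy sketch of this kind of nibble packing. It assumes unsigned codes in the range 0..15 and a low-nibble-first order. JAX-GPTQ's actual storage layout (signedness, packing axis, nibble order) may differ.

```python
import numpy as np

def pack_int4(codes):
    """Pack unsigned 4-bit codes (0..15) along the last axis, two per byte."""
    codes = np.asarray(codes, dtype=np.uint8)
    assert codes.shape[-1] % 2 == 0, "last axis must have even length"
    low = codes[..., 0::2]   # even positions go in the low nibble
    high = codes[..., 1::2]  # odd positions go in the high nibble
    return (low | (high << 4)).astype(np.uint8)

def unpack_int4(packed):
    """Inverse of pack_int4: recover the 4-bit codes as uint8 values 0..15."""
    packed = np.asarray(packed, dtype=np.uint8)
    out = np.empty(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.uint8)
    out[..., 0::2] = packed & 0x0F  # low nibble
    out[..., 1::2] = packed >> 4    # high nibble
    return out

codes = np.random.default_rng(0).integers(0, 16, size=(4, 8))
packed = pack_int4(codes)
print(codes.shape, packed.shape)  # (4, 8) (4, 4)
```

Packing halves the last dimension, which is where the roughly 4x memory saving over 16-bit weights comes from.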

I'd love feedback on one or both of these tools, so please let me know on their GitHub repos if you have any suggestions. JAX-GPTQ in particular is still at a very early stage.

### Setup

In [None]:
!pip install git+https://github.com/davisyoshida/jax-gptq.git
!pip install jax-lorax==0.1.0
!pip install transformers

In [None]:
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from tqdm import trange

import lorax
import jax_gptq

gpu = jax.devices('gpu')[0]
cpu = jax.devices('cpu')[0]

## Toy Example

### Model/Data setup

First we'll define an MLP and make some parameters for it:

In [None]:
N_LAYER = 5
batch_size = 64
DIM = 512

def my_model(params, x):
  for layer in params:
    x = jax.nn.relu(x @ layer['w'] + layer['b'])

  return jnp.mean(x)

w_key, b_key, data_key = jax.random.split(jax.random.PRNGKey(0), 3)

w_keys = jax.random.split(w_key, N_LAYER)
b_keys = jax.random.split(b_key, N_LAYER)

# Make some params
params = [
    {
        'w': jax.random.normal(k1, (DIM, DIM)),
        'b': jax.random.normal(k2, (DIM,))
    }
    for k1, k2 in zip(w_keys, b_keys)
]


GPT-Q needs input data for quantization. For an actual model we'd use real data but here we'll just make some random inputs.
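GPT-Q needs data because it quantizes each weight matrix so as to minimize the squared change in that layer's outputs on calibration inputs X. That objective depends on the data only through the matrix H = 2XᵀX, written here for the row-vector convention `x @ W` that `my_model` uses. The NumPy sketch below uses random stand-in inputs. It accumulates H over calibration batches and checks that H alone reproduces the output error caused by a weight perturbation. It does not implement GPT-Q's column-by-column quantization update.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 16, 8

# Random stand-in calibration batches (the toy example uses random inputs too)
batches = [rng.normal(size=(32, d_in)) for _ in range(4)]

# Accumulate the proxy Hessian H = 2 X^T X over all calibration rows
H = np.zeros((d_in, d_in))
for X in batches:
    H += 2.0 * X.T @ X

# A weight perturbation (e.g. quantization error) changes the outputs by X @ dW
dW = 0.01 * rng.normal(size=(d_in, d_out))
X_all = np.concatenate(batches)
output_err = np.sum((X_all @ dW) ** 2)

# The same squared error is recovered from H without revisiting the data
quad_form = 0.5 * np.trace(dW.T @ H @ dW)
print(np.isclose(output_err, quad_form))  # True
```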

In [None]:
quant_data = [jax.random.normal(key, (batch_size, DIM)) for key in jax.random.split(data_key, 64)]

# We'll save an output for later comparison since the quantization process will delete the original params
original_output = my_model(params, quant_data[0])

### Run GPT-Q to get the quantized weights
That's all for the setup; we can now run GPT-Q (without any changes to the original model code):

In [None]:
# Note that this may free the buffers associated with some or all of the parameters and the data to save VRAM
# I'd also recommend putting the params on the CPU, since `quantize()` will move them to the GPU when necessary
quantized_params = jax_gptq.quantize(my_model, params, quant_data)

The matrices have been quantized but the biases have been left alone:

In [None]:
print(f'W type: {type(quantized_params[0]["w"])}')
print(f'B type: {type(quantized_params[0]["b"])}')

**Note**: The quantization procedure depends on the parameter being used in a matrix multiplication. Currently JAX-GPTQ supports general dot operations (including those on tensors with any number of dimensions greater than 1) and convolutions with kernels of spatial size 1.
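The spatial-size-1 restriction is less limiting than it sounds. A convolution whose kernel has spatial size 1 is a matrix multiplication applied independently at every position, so it has the same matmul structure that the quantizer relies on. The small NumPy check below illustrates that equivalence. It is not JAX-GPTQ code, and the channels-last layout is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
height, width, c_in, c_out = 5, 7, 3, 4
image = rng.normal(size=(height, width, c_in))  # channels-last, batch dim dropped
kernel = rng.normal(size=(1, 1, c_in, c_out))   # spatial size 1 ("1x1 conv")

# Direct 1x1 convolution: each output pixel depends only on the same input pixel
conv_out = np.zeros((height, width, c_out))
for i in range(height):
    for j in range(width):
        conv_out[i, j] = image[i, j] @ kernel[0, 0]

# The same result as one matmul against the squeezed (c_in, c_out) kernel
matmul_out = image @ kernel[0, 0]
print(np.allclose(conv_out, matmul_out))  # True
```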

### Applying the quantized weights
We can now run the quantized model without any code changes. All that's necessary is using `jax_gptq.use_quantized` to transform the function so it knows how to handle `QuantizedMatrix` values.
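At each matmul that touches a `QuantizedMatrix`, the computation has to start from the low-bit codes and their quantization parameters. The NumPy sketch below shows the simplest version of this. It uses plain round-to-nearest 4-bit quantization with one scale and zero point per column, then dequantizes and multiplies. This is an illustration under those assumptions. It is not GPT-Q's error-compensating algorithm, and it is not how JAX-GPTQ's transform is actually implemented.

```python
import numpy as np

def quantize_rtn_4bit(w):
    """Round-to-nearest 4-bit quantization with one scale/zero point per column."""
    w_min = w.min(axis=0, keepdims=True)
    w_max = w.max(axis=0, keepdims=True)
    scale = (w_max - w_min) / 15.0            # 16 levels: codes 0..15
    scale = np.where(scale == 0, 1.0, scale)  # guard against constant columns
    zero = np.round(-w_min / scale)
    codes = np.clip(np.round(w / scale) + zero, 0, 15).astype(np.uint8)
    return codes, scale, zero

def dequantize(codes, scale, zero):
    return (codes.astype(np.float32) - zero) * scale

def quantized_matmul(x, codes, scale, zero):
    # Reconstruct approximate weights from the codes, then multiply as usual
    return x @ dequantize(codes, scale, zero)

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32)).astype(np.float32)
x = rng.normal(size=(8, 64)).astype(np.float32)

codes, scale, zero = quantize_rtn_4bit(w)
exact = x @ w
approx = quantized_matmul(x, codes, scale, zero)
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
print(f'relative error: {rel_err:.3f}')
```

GPT-Q's advantage over this baseline is that it adjusts the remaining weights to compensate for each rounding error, using the calibration data.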

In [None]:
quantized_params = jax.device_put(quantized_params, gpu) # Move the params to the GPU

# Originally:
# my_model(params, inputs)
# After:
# jax_gptq.use_quantized(my_model)(params, inputs)
quant_output = jax_gptq.use_quantized(my_model)(quantized_params, quant_data[0])

print(f'Output of quantized network: {quant_output:.3e}')
print(f'Original output: {original_output:.3e}')

### Train with LoRA

Now that we've compressed our model to 4 bits (and change) per parameter, we can add full-precision LoRA parameters for finetuning.
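The "and change" is the overhead of the quantization parameters that are stored alongside the 4-bit codes. The helper below gives a back-of-the-envelope estimate. It assumes one 16-bit scale and one 16-bit zero point per group of weights in each column. JAX-GPTQ's actual grouping and dtypes may differ.

```python
def bits_per_param(n_rows, n_cols, group_size, bits=4, meta_bits=16):
    """Storage bits per weight, assuming one scale and one zero point
    (each `meta_bits` wide) per group of `group_size` weights in a column."""
    assert n_rows % group_size == 0, "group_size must divide n_rows"
    n_weights = n_rows * n_cols
    n_groups = (n_rows // group_size) * n_cols
    total_bits = n_weights * bits + n_groups * 2 * meta_bits
    return total_bits / n_weights

print(bits_per_param(4096, 4096, group_size=4096))  # 4.0078125 (one group per column)
print(bits_per_param(4096, 4096, group_size=128))   # 4.25 (4 + 32/128)
```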

The one gotcha about combining the two is that Lorax doesn't know that `QuantizedMatrix` values are pytree leaves, so you need to give the Lorax functions an `is_leaf` predicate.
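The predicate matters because JAX traverses into every registered pytree node, so a container that holds several arrays is normally flattened into those arrays. The sketch below shows the difference `is_leaf` makes. It uses a hypothetical `FakeQuantized` NamedTuple as a stand-in, which is not JAX-GPTQ's `QuantizedMatrix`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class FakeQuantized(NamedTuple):
    """Hypothetical stand-in for a quantized matrix. NamedTuples are pytree
    nodes, so JAX looks inside them unless told otherwise."""
    packed: jnp.ndarray
    scale: jnp.ndarray

example_tree = {
    'w': FakeQuantized(packed=jnp.zeros((4, 2), jnp.int8), scale=jnp.ones((2,))),
    'b': jnp.zeros((4,)),
}

def is_fake_quantized(x):
    return isinstance(x, FakeQuantized)

# Default traversal splits the container into its child arrays
print(len(jax.tree_util.tree_leaves(example_tree)))  # 3
# With is_leaf, each FakeQuantized counts as a single leaf
print(len(jax.tree_util.tree_leaves(example_tree, is_leaf=is_fake_quantized)))  # 2
```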

**Initialization:** The `init_lora` function expects a pytree describing which parameters should get LoRA parameters, which should be fully trained, and which should be left frozen. `lorax.simple_spec` is a helper function for making these specs.

In [None]:
def is_leaf(x):
  return isinstance(x, jax_gptq.QuantizedMatrix)

lora_spec = lorax.simple_spec(
    params=quantized_params,
    decision_fn=lambda pytree_path, arr: 4, # Just ignore the inputs and specify an inner rank of 4 for all params
    tune_vectors=False, # Tell Lorax to put all the biases in the frozen params tree instead of the tunable params tree
    is_leaf=is_leaf
)

# Lorax splits the parameters into two pytrees:
# freeze_params: Anything which received the value lorax.LORA_FREEZE in the spec
# train_params: Pairs of two narrow matrices for values which got positive integers as spec values, or the full parameter if the value lorax.LORA_FULL was in the spec
freeze_params, train_params = lorax.init_lora(quantized_params, lora_spec, jax.random.PRNGKey(1234), is_leaf=is_leaf)

def merge_quantized_with_lora(q_params, lora_freeze):
    return jax.tree_util.tree_map(
        lambda quant, from_lora: quant if isinstance(quant, jax_gptq.QuantizedMatrix) else from_lora,
        q_params,
        lora_freeze,
        is_leaf=lambda x: isinstance(x, jax_gptq.QuantizedMatrix) # Tell tree_map to treat QuantizedMatrix as a single value instead of a non-leaf node
    )

# Not needed here: init_lora received the quantized params directly (via `is_leaf`), so
# freeze_params already holds the QuantizedMatrix values. The GPT-2 section below uses this helper.

The `lorax.lora` transform converts a function from expecting a single pytree in the specified argument to expecting a tuple of two pytrees. It composes with other JAX transforms such as `jax_gptq.use_quantized`, so we can use both at once with no modifications to our model code.

In [None]:
combined_params = (freeze_params, train_params)

my_model_with_lora_and_quantized_weights = jax_gptq.use_quantized(lorax.lora(my_model))

# The differences from the original `my_model` function are:
# 1. The params argument now expects a tuple of (frozen_params, trainable_params)
# 2. It knows how to compute with quantized weights
quantized_plus_lorax_output = my_model_with_lora_and_quantized_weights(combined_params, quant_data[0])

print(f'GPTQ + Lorax output: {quantized_plus_lorax_output:.3e}')
print(f'GPTQ only: {quant_output:.3e}')

The two outputs are identical because LoRA initializes one matrix of each pair to zeros, so the low-rank update starts out as exactly zero.
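LoRA replaces each chosen weight W with W + AB, where A and B are narrow rank-r matrices, and the product AB is never materialized. The NumPy sketch below shows that computation for a single dense layer. The shapes and the init scale are assumptions chosen for illustration, and Lorax's own initialization and naming may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 512, 512, 4

W = rng.normal(size=(d_in, d_out))        # frozen (in practice, the quantized weight)
A = rng.normal(size=(d_in, rank)) * 0.01  # trainable down-projection
B = np.zeros((rank, d_out))               # trainable up-projection, zero-initialized
x = rng.normal(size=(8, d_in))

def lora_dense(x, W, A, B):
    # Equals x @ (W + A @ B) without building the d_in x d_out update
    return x @ W + (x @ A) @ B

# Zero-initialized B: the LoRA model starts out identical to the frozen one
print(np.allclose(lora_dense(x, W, A, B), x @ W))  # True

# Once B moves away from zero, the factored form still matches the merged weight
B = rng.normal(size=(rank, d_out)) * 0.01
print(np.allclose(lora_dense(x, W, A, B), x @ (W + A @ B)))  # True
```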

Let's look at the size of each pytree:

In [None]:
count_params = partial(jax.tree_util.tree_reduce,
  lambda acc, param: acc + (param.size if isinstance(param, jnp.ndarray) else 0),
  initializer=0
)

print(f'{count_params(freeze_params):.3e} frozen params')
print(f'{count_params(train_params):.3e} trainable params')

Training with this function is no different from training any other JAX function; just make sure to differentiate your loss with respect to the trainable parameters only (see the next section for an example).
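For reference, you can work out the logical parameter counts for the toy MLP by hand. This calculation assumes that each rank-4 LoRA pair consists of one `DIM × 4` matrix and one `4 × DIM` matrix. The frozen count printed earlier may not match the logical weight count, because `count_params` sums whatever arrays each `QuantizedMatrix` stores internally.

```python
n_layer, dim, rank = 5, 512, 4  # same values as N_LAYER, DIM and the rank passed to simple_spec

frozen_weights = n_layer * dim * dim  # the quantized matrices
frozen_biases = n_layer * dim         # tune_vectors=False keeps biases frozen
lora_params = n_layer * 2 * dim * rank  # one dim x rank and one rank x dim factor per layer

print(f'{frozen_weights:,} quantized weights, {frozen_biases:,} frozen biases')
print(f'{lora_params:,} trainable LoRA params '
      f'({lora_params / frozen_weights:.2%} of the weight count)')
```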

## GPT-Q-ing + LoRA-ing HuggingFace's Flax GPT-2
I developed these transforms for use with my Haiku models, but since all JAX models are pure functions at the end of the day, it shouldn't matter what framework you use. Lorax supports matmuls and other matmul-like operations such as embedding lookups and 1-D convs.
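An embedding lookup counts as "matmul-like" because indexing a table with token IDs is the same as multiplying one-hot rows by the table. The NumPy check below confirms this. It also shows that a low-rank update to the table can still be applied as a cheap lookup. This illustrates the idea and is not Lorax's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model, rank = 10, 6, 2
table = rng.normal(size=(vocab, d_model))
ids = np.array([3, 0, 7, 3])

# An embedding lookup is a matmul against one-hot rows
one_hot = np.eye(vocab)[ids]
print(np.allclose(table[ids], one_hot @ table))  # True

# A low-rank update to the table can be applied by looking up rows of A, then @ B
A = rng.normal(size=(vocab, rank))
B = rng.normal(size=(rank, d_model))
print(np.allclose((table + A @ B)[ids], table[ids] + A[ids] @ B))  # True
```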

This is a minimal example of applying the combination to `gpt2-medium`, but it's basically model agnostic.

First let's get the model:

In [None]:
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

In [None]:
model_name = 'gpt2-medium'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model, params = FlaxAutoModelForCausalLM.from_pretrained(model_name, _do_init=False)
params = jax.device_put(params, cpu)

# GPT-2 reuses the embedding table as the output projection, so GPT-Q will quantize it when it reaches that final matmul.
# A quantized table would seriously screw up the embedding lookup step, so we save a copy of the original here.
orig_embedding_table = np.asarray(params['transformer']['wte']['embedding'])

The GPT-Q paper used real text data for quantization, but for this demo I'll just generate random token IDs.

In [None]:
QUANT_BATCH_SIZE = 4
QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but it needs to be small to avoid crashing Colab

quantization_data = []
key = jax.random.PRNGKey(0)
for _ in range(32):
  batch = jax.random.randint(key, (QUANT_BATCH_SIZE, QUANT_EXAMPLE_LENGTH), 0, 50256)
  quantization_data.append(batch)
  key, = jax.random.split(key, 1)

HuggingFace's models don't have quite the right call signature, so we'll make a wrapper which takes `(params, inputs)` as arguments:

In [None]:
def apply_model(params, batch):
  return model(batch, params=params)

quantized_params = jax_gptq.quantize(apply_model, params, quantization_data)

In [None]:
# Replace the quantized embedding table with the original one
quantized_params['transformer']['wte']['embedding'] = jnp.asarray(orig_embedding_table)
quantized_params = jax.device_put(quantized_params, gpu)

### Finetuning GPT-2 with Lorax

As in the [toy example above](https://colab.research.google.com/drive/18rkULbWqk7mNZDx7Scx-JS3p_s45mgok#scrollTo=HKkhcjx9zJy6&line=3&uniqifier=1), we need to tell Lorax how to initialize the LoRA params. This time we build the spec from the original (pre-quantization) param structure, then merge the quantized params back in afterward.

In [None]:
# Get pre-quantization param tree (some nodes will just be abstract values)
orig_params_or_shapes = jax_gptq.utils.quantized_params_to_shaped_arrays(quantized_params)

# Tell Lorax which leaves should be frozen/fully trained/LoRA trained
spec = lorax.simple_spec(
    orig_params_or_shapes,
    lambda path, arr: 16 if any(pattern in path for pattern in ['c_attn', 'mlp']) else lorax.LORA_FREEZE,
    tune_vectors=True
)

# Initialize parameters
key, init_key = jax.random.split(key)
freeze_params, train_params = lorax.init_lora(
    orig_params_or_shapes,
    spec,
    init_key
)

# Put the quantized params back into the frozen param tree
freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)
combined_params = freeze_params, train_params

Now we can just transform the `apply_model` function, and it will use both LoRA and 4-bit quantized parameters:

In [None]:
quantized_plus_lora_fn = jax_gptq.use_quantized(lorax.lora(apply_model))

### Training
Training isn't actually any different from normal training, since you can just think of `freeze_params` as a constant argument, but here's a demo for completeness.

First I'll define a toy corpus which demonstrates Alan's love of cats and Grace's dislike of them.

In [None]:
CATS = ['lions', 'tigers', 'cheetahs', 'cats', 'ocelots', 'kittens']
DOGS = ['wolves', 'dogs', 'coyotes', 'huskies', 'poodles', 'puppies']

CAT_LOVER = 'Alan'
DOG_LOVER = 'Grace'

dataset = []
for name, polarity in [(CAT_LOVER, True), (DOG_LOVER, False)]:
  liked, disliked = (CATS, DOGS) if polarity else (DOGS, CATS)
  for kind in liked:
    dataset.append(f'{name}: {kind}? I love them!')
    dataset.append(f'{name}: Hey look at those {kind}, that\'s pretty cool')

  for kind in disliked:
    dataset.append(f'{name}: {kind}? I hate them!')
    dataset.append(f'{name}: Oh no, some {kind}! How scary!')

tokenized_data = [jnp.asarray(tokenizer.encode(ex)) for ex in dataset]
max_len = max(ex.shape[0] for ex in tokenized_data)
# Pad the data to speed up jitting. Not worrying about masking due to laziness.
tokenized_data = [jnp.pad(ex, (0, max_len - ex.shape[0])) for ex in tokenized_data]

jitted_model = jax.jit(quantized_plus_lora_fn)


In [None]:
def make_prediction(params, prefix):
  tokens = jnp.asarray(tokenizer.encode(prefix))
  logits = jitted_model(params, tokens[None]).logits
  
  probs = jax.nn.softmax(logits[0, -1])
  pred_probs, pred_words = jax.lax.top_k(probs, 5)

  print(f'Predictions for: "{prefix}"')
  for i, (word_id, prob) in enumerate(zip(pred_words, pred_probs), 1):
    print(f'{i}. {tokenizer.decode([word_id])} - {prob:.2%}')
  print()

test_examples = [
    f'{CAT_LOVER}: jaguars? I',
    f'{DOG_LOVER}: jaguars? I'
]

Let's look at the next word predictions of the unmodified model:

In [None]:
for ex in test_examples:
  make_prediction(combined_params, ex)

Next we set up a standard training loop. The only difference is that we keep the train/freeze params separate for the optimizer. No changes are needed to handle the quantization.

I'll just train with a batch size of 1 here since I don't want to bother with masking, but the transformed model function is fully compatible with vmap etc.
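Training with a batch size of 1 avoids both batching and padding masks. The minimal JAX sketch below shows one way to handle the two together. It uses a hypothetical toy model (`toy_logits`) in place of GPT-2, and it assumes you store each example's unpadded length alongside the example. `jax.vmap` maps the per-sequence loss over the batch while sharing the params (`in_axes=(None, 0, 0)`), and a mask removes targets that fall in the padding.

```python
import jax
import jax.numpy as jnp

def toy_logits(params, tokens):
    """Hypothetical stand-in for a language model: embed, then project to vocab logits."""
    return params['embed'][tokens] @ params['out']  # (seq_len, vocab)

def masked_loss(params, seq, length):
    """Next-token loss for one padded sequence, ignoring targets past `length`."""
    inputs, targets = seq[:-1], seq[1:]
    logprobs = jax.nn.log_softmax(toy_logits(params, inputs))
    token_lp = jnp.take_along_axis(logprobs, targets[:, None], axis=-1)[:, 0]
    # Target t sits at sequence position t + 1, so it is real only if t + 1 < length
    mask = (jnp.arange(targets.shape[0]) + 1) < length
    return -jnp.sum(token_lp * mask) / jnp.maximum(jnp.sum(mask), 1)

# Share params across the batch; map over sequences and their lengths
batched_loss = jax.vmap(masked_loss, in_axes=(None, 0, 0))

key_e, key_o = jax.random.split(jax.random.PRNGKey(0))
toy_vocab, toy_dim = 11, 8
toy_params = {'embed': jax.random.normal(key_e, (toy_vocab, toy_dim)),
              'out': jax.random.normal(key_o, (toy_dim, toy_vocab))}
toy_batch = jnp.array([[1, 2, 3, 4, 0, 0],
                       [5, 6, 7, 8, 9, 10]])
toy_lengths = jnp.array([4, 6])
print(batched_loss(toy_params, toy_batch, toy_lengths).mean())
```

The same pattern should apply to the quantized LoRA model, because `quantized_plus_lora_fn` is an ordinary JAX function.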

In [None]:
def loss_fn(train_params, freeze_params, seq):
  inputs = seq[:-1]
  targets = seq[1:]

  combined_params = (freeze_params, train_params)
  logits = quantized_plus_lora_fn(combined_params, inputs[None]).logits[0]
  logprobs = jax.nn.log_softmax(logits)
  losses = -jnp.take_along_axis(logprobs, targets[:, None], axis=-1)
  return jnp.mean(losses)

optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
opt_state = optimizer.init(combined_params[1])

@jax.jit
def update_fn(combined_params, opt_state, example):
  freeze_params, train_params = combined_params

  # The main thing is that we have to split up the params here so that JAX knows what to differentiate with respect to
  loss, grads = jax.value_and_grad(loss_fn)(train_params, freeze_params, example)

  updates, opt_state = optimizer.update(grads, opt_state, params=train_params)
  new_train_params = optax.apply_updates(train_params, updates)
  return (freeze_params, new_train_params), opt_state, loss

In [None]:
bar = trange(50)
for epoch in bar:
  key, = jax.random.split(key, 1)
  permutation = jax.random.permutation(key, jnp.arange(len(dataset)))
  total_loss = 0
  for index in permutation:
    example = tokenized_data[index]
    combined_params, opt_state, loss = update_fn(combined_params, opt_state, example)
    total_loss += loss
  bar.set_description(f'Epoch {epoch} - Loss: {total_loss / len(tokenized_data):.3e}')

The trained LoRA parameters give us a model which predicts that Alan will love jaguars, and Grace will hate them:

In [None]:
for example in test_examples:
  make_prediction(combined_params, example)
  print()