# CodeLlama quantisation weirdness

When messing around with quantisation to see what kind of models I could fit on an RTX 3090 (24 GB of VRAM), I came across 
some strange behaviour. I was comparing three sizes of CodeLlama model, with different quantisations:

* codellama/CodeLlama-7b-Instruct-hf in full-fat (bfloat16), 8-bit and 4-bit
* codellama/CodeLlama-13b-Instruct-hf in 8-bit and 4-bit
* codellama/CodeLlama-34b-Instruct-hf in 4-bit

The quality of the response to my test question was reasonable with all of these apart from 
codellama/CodeLlama-34b-Instruct-hf in 4-bit, which was often (but not always) heavily glitched, with 
missing tokens -- that is, it was *worse* than codellama/CodeLlama-7b-Instruct-hf in 4-bit. That 
surprised me! In the run captured in this notebook it generates Java instead of the requested Python, 
and I've also seen it output glitched JavaScript.

I was expecting quantisation to worsen the results, but not to make a larger model worse than a smaller one 
*at the same level of quantisation*. I've put this repo up to see if anyone can reproduce these results, and to 
find out whether anyone has any idea why it's happening.

In [1]:
```python
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
```

## The prompt

A simple programming question:

In [2]:
```python
instruction = (
    "[INST] "
    "Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error "
    "cases (eg. n<0)"
    " [/INST]"
)
```
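
The prompt uses the Llama 2 style instruction format that the CodeLlama Instruct models expect: the user message sits between `[INST]` and `[/INST]`, and the tokenizer prepends the `<s>` BOS token itself when encoding. A minimal sketch of a helper that builds the same string for any question (`build_instruction` is a hypothetical name, not part of transformers):

```python
def build_instruction(question: str) -> str:
    """Wrap a user question in [INST] ... [/INST] tags, matching `instruction`.

    Hypothetical helper for illustration only.
    """
    # No "<s>" here: the tokenizer adds the BOS token when it encodes the text
    return f"[INST] {question.strip()} [/INST]"


question = (
    "Please write a function in Python to calculate the factorial of a number.  "
    "Use recursion, and consider error cases (eg. n<0)"
)
print(build_instruction(question))
```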

## The test harness

For each model, we load it, check how much GPU memory it uses, then time a single generation with the above instruction to see what it comes up with. Afterwards we delete the model and empty the CUDA cache, so each run starts with a clean GPU.
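
The harness samples with `do_sample=True` and `temperature=0.7`, so the same model can give a different answer on each run. In transformers, temperature is applied by dividing the next-token logits by it before the softmax (other sampling settings from the model's generation config, such as top-k, may also apply). A plain-Python sketch of just the temperature step, using made-up logits:

```python
import math


def softmax_with_temperature(logits, temperature):
    """Turn raw logits into sampling probabilities, scaled by temperature.

    A temperature below 1 sharpens the distribution towards the most likely
    token; above 1 it flattens it.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate tokens
print(softmax_with_temperature(logits, 1.0))
print(softmax_with_temperature(logits, 0.7))
```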


In [3]:
```python
def test_model(model_id, kwarg_dict):
    print(f"\n\n================== RUNNING {model_id} with {kwarg_dict} =======================")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", **kwarg_dict)

    # Tokenise the prompt, and put both the IDs and the mask on the model's device
    encoded = tokenizer(instruction, return_tensors="pt", return_attention_mask=True)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    # GPU memory currently held by tensors, which at this point is mostly the weights
    print(f"Using {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")

    start = time.time()
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=2000,  # total length, including the prompt tokens
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
    )
    end = time.time()

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(generated_text)

    print(f"\n\nTook {end - start} seconds")

    # Release the model's GPU memory before the next one is loaded
    del model
    torch.cuda.empty_cache()
```

## 7b parameters, no quantisation

As expected, this uses about two bytes per parameter, since bfloat16 stores each weight in 16 bits. It performs pretty well!
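
The two-bytes-per-parameter rule generalises to bits / 8 bytes per parameter. A back-of-the-envelope sketch using the nominal parameter counts from the model names (the real counts differ slightly, and this only covers the weights):

```python
def estimate_weight_gib(n_params: float, bits_per_param: int) -> float:
    """Rough weight memory in GiB: params * bits / 8 bytes.

    Uses nominal parameter counts and ignores quantisation metadata, layers
    left unquantised, activations and the KV cache.
    """
    return n_params * bits_per_param / 8 / 1024**3


configs = [
    ("7b, bfloat16", 7e9, 16),
    ("7b, 8-bit", 7e9, 8),
    ("7b, 4-bit", 7e9, 4),
    ("13b, 8-bit", 13e9, 8),
    ("13b, 4-bit", 13e9, 4),
    ("34b, 4-bit", 34e9, 4),
]
for name, n_params, bits in configs:
    print(f"{name:>13}: ~{estimate_weight_gib(n_params, bits):.2f} GiB")
```

The figures measured in this notebook land in the same ballpark. The 4-bit runs sit noticeably above the estimate (4.17 GiB measured against ~3.26 GiB for 7b), presumably because not everything ends up stored in 4 bits, though that breakdown isn't measured here.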

In [4]:
```python
test_model("codellama/CodeLlama-7b-Instruct-hf", dict(torch_dtype=torch.bfloat16))
```






Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Using 12.80 GiB
[INST] Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error cases (eg. n<0) [/INST]  Sure! Here's an example of a function in Python that calculates the factorial of a number using recursion, and considers error cases:
```
def factorial(n):
    if n < 0:
        raise ValueError("factorial is not defined for negative numbers")
    elif n == 0:
        return 1
    else:
        return n * factorial(n-1)
```
This function uses the recursive formula for the factorial, which is:

n! = n * (n-1) * (n-2) * ... * 1

The function first checks if the input `n` is negative. If it is, the function raises a `ValueError` exception, as the factorial is not defined for negative numbers.

If `n` is not negative, the function checks if it is equal to 0. If it is, the function returns 1, as the factorial of 0 is defined to be 1.

If `n` is not 0, the function multiplies `n` by the result of recursively calling the function with `n-

## 7b parameters, 8-bit quantisation

This uses about half as much memory as the previous one, and also does a decent job. It does make one mistake with its O(1) claim (we're not using tail recursion, so that can't be right), but that's not the end of the world.
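
On the tail-recursion point: even a tail-recursive version wouldn't get O(1) stack space in Python, because CPython doesn't do tail-call elimination. Each recursive call adds a stack frame, so a recursive `factorial` like the ones in these answers fails once `n` exceeds the interpreter's recursion limit (1000 by default), whereas a loop uses constant stack space. A sketch comparing the two (function names are mine):

```python
import math
import sys


def factorial_recursive(n: int) -> int:
    # Same shape as the models' answers: one stack frame per call
    if n < 0:
        raise ValueError("factorial is not defined for negative numbers")
    if n == 0:
        return 1
    return n * factorial_recursive(n - 1)


def factorial_iterative(n: int) -> int:
    # Constant stack usage: a loop instead of recursive calls
    if n < 0:
        raise ValueError("factorial is not defined for negative numbers")
    result = 1
    for k in range(2, n + 1):
        result *= k
    return result


n = sys.getrecursionlimit() + 10  # deeper than the interpreter's stack limit allows
try:
    factorial_recursive(n)
except RecursionError:
    print(f"factorial_recursive({n}) hit the recursion limit")

assert factorial_iterative(n) == math.factorial(n)
print(f"factorial_iterative({n}) worked")
```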

In [5]:
```python
test_model("codellama/CodeLlama-7b-Instruct-hf", dict(load_in_8bit=True))
```






Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Using 6.88 GiB
[INST] Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error cases (eg. n<0) [/INST]  Sure! Here's an example of a function in Python that calculates the factorial of a number using recursion:
```
def factorial(n):
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers")
    elif n == 0:
        return 1
    else:
        return n * factorial(n-1)
```
This function uses the recursive formula for the factorial: `n! = n * (n-1) * (n-2) * ... * 1` to calculate the factorial of a positive integer `n`. It checks for negative numbers and raises a `ValueError` if the number is negative. If the number is zero or positive, it recursively calls itself with the argument `n-1` until it reaches the base case, where the result is returned.

You can test this function with different inputs:
```
print(factorial(-1)) # raises ValueError
print(factorial(0)) # returns 1
print(factorial(1)) # returns 1
pri
```

## 7b parameters, 4-bit quantisation

Another memory usage reduction, and still a decent response.

In [6]:
```python
test_model("codellama/CodeLlama-7b-Instruct-hf", dict(load_in_4bit=True))
```






Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Using 4.17 GiB




[INST] Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error cases (eg. n<0) [/INST]  Sure! Here is an example function in Python that calculates the factorial of a number using recursion:
```
def factorial(n):
    if n < 0:
        raise ValueError("n must be non-negative")
    elif n == 0:
        return 1
    else:
        return n * factorial(n-1)
```
This function uses recursion to calculate the factorial of a number by calling itself with the argument `n-1` until it reaches the base case, where `n` is 0. The base case returns 1, and the recursive cases multiply the result by `n` and then multiply it by the result of the recursive call.

Here's an example of how you could use this function:
```
print(factorial(5))  # Output: 120
print(factorial(0))  # Output: 1
print(factorial(-1))  # Output: ValueError: n must be non-negative
```
I hope this helps! Let me know if you have any questions.


Took 4.67816948890686 seconds
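
These timings aren't directly comparable between runs, because each sampled answer has a different length. A hypothetical helper that normalises them to generated tokens per second, assuming `test_model` were extended to record `input_ids.shape[1]` (the prompt length) and `output.shape[1]` (prompt plus generated tokens, since `generate` returns both for a decoder-only model):

```python
def tokens_per_second(prompt_len: int, total_len: int, seconds: float) -> float:
    """Generated tokens per second, excluding the prompt.

    Hypothetical helper: total_len is the length of the sequence returned by
    generate(), which includes the prompt for a decoder-only model.
    """
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    generated = total_len - prompt_len
    if generated < 0:
        raise ValueError("total_len must be at least prompt_len")
    return generated / seconds


# Illustrative numbers only, not measurements from this notebook
print(tokens_per_second(prompt_len=40, total_len=240, seconds=4.0))
```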


## 13b parameters, 8-bit quantisation

Memory usage is as expected (about one byte per parameter), and the response is solid.

In [7]:
```python
test_model("codellama/CodeLlama-13b-Instruct-hf", dict(load_in_8bit=True))
```






Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Using 12.80 GiB
[INST] Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error cases (eg. n<0) [/INST]  Sure! Here is a function in Python to calculate the factorial of a number using recursion:
```
def factorial(n):
    if n < 0:
        raise ValueError("n must be non-negative")
    elif n == 0:
        return 1
    else:
        return n * factorial(n-1)
```
This function uses the following logic:

* If `n` is negative, it raises a `ValueError` as the factorial is not defined for negative numbers.
* If `n` is 0, the factorial is defined to be 1.
* Otherwise, the function calls itself recursively with `n-1` as the argument, and multiplies the result by `n`.

Here is an example of how the function can be used:
```
print(factorial(5))  # Output: 120
print(factorial(0))  # Output: 1
print(factorial(-1))  # Raises ValueError
```
I hope this helps! Let me know if you have any questions.


Took 16.790033102035522 seconds


## 13b parameters, 4-bit quantisation

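Once again, memory usage is as expected and the response is solid, though I'm not a huge fan of that `return None`: the prompt asked it to consider error cases, and returning `None` hides the error rather than reporting it.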
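
Returning `None` for bad input means the failure only surfaces later, when a caller tries to use the result, and far from its cause. A sketch contrasting the two styles (the function names are mine; the bodies mirror this model's answer and the earlier ones):

```python
def factorial_returns_none(n):
    # This model's approach: signal bad input with None
    if n < 0:
        return None
    if n == 0:
        return 1
    return n * factorial_returns_none(n - 1)


def factorial_raises(n):
    # The other models' approach: fail immediately with a clear message
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        return 1
    return n * factorial_raises(n - 1)


try:
    factorial_returns_none(-1) + 1  # the problem only shows up here, as a TypeError
except TypeError as e:
    print("None version:", e)

try:
    factorial_raises(-1)
except ValueError as e:
    print("raising version:", e)
```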

In [8]:
```python
test_model("codellama/CodeLlama-13b-Instruct-hf", dict(load_in_4bit=True))
```






Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Using 7.65 GiB
[INST] Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error cases (eg. n<0) [/INST]  Here's an example of a function in Python that calculates the factorial of a number using recursion:
```
def factorial(n):
    if n < 0:
        return None
    elif n == 0:
        return 1
    else:
        return n * factorial(n-1)
```
This function uses recursion to calculate the factorial of a number. The `if` statement checks if the input is negative, and if it is, the function returns `None` to indicate an error. If the input is 0, the function returns 1. Otherwise, the function returns the product of the input and the factorial of the input minus 1.

Here's an example of how you could call this function:
```
print(factorial(5))  # prints 120
print(factorial(-1))  # prints None
```
This will calculate the factorial of 5, which is 120, and will return `None` if the input is negative.


Took 7.670776128768921 seconds


## 34b parameters, 4-bit quantisation

Memory usage is as expected, but the response is total nonsense: tokens are missing, and it's in Java! What's going on?
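
Because the glitch only happens some of the time, a single run doesn't say much on its own. One way to put a number on it would be to sample several answers per model and check each one mechanically. A minimal sketch of such a check, assuming a good answer contains at least one fenced code block that parses as Python (`looks_like_valid_python_answer` is a hypothetical helper, and it checks syntax only, not correctness):

```python
import ast
import re

FENCE = "`" * 3  # built this way so the pattern doesn't clash with Markdown fences

# An opening fence (optionally followed by a language tag), then everything
# up to the next fence
FENCE_RE = re.compile(re.escape(FENCE) + r"[^\n]*\n(.*?)" + re.escape(FENCE), re.DOTALL)


def looks_like_valid_python_answer(response: str) -> bool:
    """True if at least one fenced block in the response parses as Python."""
    for block in FENCE_RE.findall(response):
        try:
            ast.parse(block)
        except SyntaxError:
            continue
        return True
    return False


glitched = "Sure! Here is a function in Python that calcul        public class FactorialRecursion {"
fenced = FENCE + "\ndef factorial(n):\n    return 1 if n == 0 else n * factorial(n - 1)\n" + FENCE
print(looks_like_valid_python_answer(glitched))
print(looks_like_valid_python_answer(fenced))
```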

In [9]:
```python
test_model("codellama/CodeLlama-34b-Instruct-hf", dict(load_in_4bit=True))
```






Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Using 18.77 GiB
[INST] Please write a function in Python to calculate the factorial of a number.  Use recursion, and consider error cases (eg. n<0) [/INST]  Sure! Here is a function in Python that calcul        public class FactorialRecursion {
            public static int factorial(int n) {
                if (n == 0) {
                    return 1;
                } else if (n < 0) {
                    return 0;
                } else {
                    return n * factorial(n-1);
                }
            }
        }

It works by using a recursive approach, where the function calls itself with a smaller input (n-1) until it reaches the base case (n==0), where it returns 1.

The function also considers error cases, such as n<0, and returns 0 in those cases.

You can test it with:

print(factorial(5)) # should print 120
print(factorial(0)) # should print 1
print(factorial(-1)) # should print 0

Note that this implementation is not the most efficient one, but it's a simple and 
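
This notebook doesn't test it, but one variable worth isolating is the 4-bit scheme itself. With a bare `load_in_4bit=True`, current transformers versions use the bitsandbytes defaults, which are the `fp4` quantisation type and a float32 compute dtype. `BitsAndBytesConfig` lets you choose `nf4` weights and bfloat16 compute instead. A sketch of the arguments you could pass to `test_model` for comparison; nothing here shows that it changes the 34b result:

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 weights, with the matrix multiplications computed in bfloat16
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Passed through to AutoModelForCausalLM.from_pretrained by test_model, e.g.
# test_model("codellama/CodeLlama-34b-Instruct-hf",
#            dict(quantization_config=nf4_config, torch_dtype=torch.bfloat16))
kwargs = dict(quantization_config=nf4_config, torch_dtype=torch.bfloat16)
```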