The code here is adapted from the brev.dev Mistral fine-tuning tutorial by Harper Carroll at

https://github.com/brevdev/notebooks/blob/main/mistral-finetune-own-data.ipynb

Install libraries

In [None]:
# only if pytorch needs to be installed in your conda env
!conda install --yes pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

In [1]:
!pip install -U bitsandbytes



In [2]:
!conda install --yes transformers

Channels:
 - defaults
 - pytorch
 - nvidia
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.



In [3]:
!pip install -q -U git+https://github.com/huggingface/peft.git

In [4]:
!pip install -q -U git+https://github.com/huggingface/accelerate.git

In [5]:
!conda install --yes datasets

Channels:
 - defaults
 - pytorch
 - nvidia
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.



In [6]:
!conda install --yes scipy

Channels:
 - defaults
 - pytorch
 - nvidia
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.



In [7]:
!conda install --yes ipywidgets

Channels:
 - defaults
 - pytorch
 - nvidia
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.



In [8]:
!conda install --yes matplotlib

Channels:
 - defaults
 - pytorch
 - nvidia
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.



In [2]:
import random
import os
from typing import Tuple

In [3]:
import torch

In [4]:
print(torch.__version__, torch.__file__)

2.1.1 /home/oblaat/.conda/envs/jack_conda_env/lib/python3.11/site-packages/torch/__init__.py


In [5]:
# file paths where we are going to store our results
filename_template = "results_struct_prompt/{num1}_digit/{num2}_digit/t_{idx:02}.txt"
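Note that `num1` and `num2` in this template are digit counts (the batch loop below passes `i` and `j`), not the operands themselves. A quick check of how one path expands, using illustrative values:

```python
# Same template as in the cell above, repeated so this snippet runs on its own
filename_template = "results_struct_prompt/{num1}_digit/{num2}_digit/t_{idx:02}.txt"

# Illustrative values: 6-digit LHS, 4-digit RHS, trial index 3 (zero-padded to 2 digits)
path = filename_template.format(num1=6, num2=4, idx=3)
print(path)  # results_struct_prompt/6_digit/4_digit/t_03.txt
```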

WARNING: you need at least 16GB CPU RAM for this step!

In [6]:
# load base model - mistralai/Mistral-7B-Instruct-v0.1 - using 4-bit quantization
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-Instruct-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
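As a rough back-of-envelope check on why the model is loaded in 4-bit, the sketch below estimates weight storage alone. The round figure of 7e9 parameters is an assumption for illustration, and the estimate ignores quantization constants (which `bnb_4bit_use_double_quant` compresses further), activations, the KV cache, and any layers bitsandbytes leaves unquantized, so real usage will be higher.

```python
def approx_weight_gb(n_params: float, bits_per_param: float) -> float:
    """Rough weight-storage estimate in GB (1 GB = 1e9 bytes).

    Counts only the raw weights: no quantization constants, activations,
    KV cache, or unquantized layers.
    """
    return n_params * bits_per_param / 8 / 1e9

# Assumed round parameter count for a "7B" model
for label, bits in [("bf16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label}: ~{approx_weight_gb(7e9, bits):.1f} GB")
# prints "bf16: ~14.0 GB", then "4-bit (nf4): ~3.5 GB"
```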

In [7]:
# set up the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token
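`padding_side="left"` matters for batched generation with a decoder-only model: new tokens are appended on the right, so each prompt's real tokens should end at the right edge, with any padding at the start. The pure-Python sketch below only illustrates the idea with toy token ids; the `left_pad` helper is hypothetical and is not how transformers implements padding. `pad_id=2` mirrors `pad_token = eos_token`, since the generation warning in this notebook reports Mistral's `eos_token_id` as 2.

```python
from typing import List, Tuple

def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id sequences to a common length.

    Returns (input_ids, attention_mask); the mask is 0 on padding, 1 on real tokens.
    """
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask

# Toy token ids (hypothetical); the shorter prompt gets padding on the left
ids, mask = left_pad([[1, 10, 11, 12], [1, 20]], pad_id=2)
print(ids)   # [[1, 10, 11, 12], [2, 2, 1, 20]]
print(mask)  # [[1, 1, 1, 1], [0, 0, 1, 1]]
```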

In [8]:
# test base Mistral on an example
eval_prompt = " Answer the following:\n38 + 11\na:"

In [9]:
# Re-init the tokenizer so it doesn't add padding or eos token
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
)

model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=256, repetition_penalty=1.15)[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


 Answer the following:
38 + 11
a: 49


Some functions to help with batch testing:

In [10]:
# Helper functions for generating random n-digit numbers
def generate_random_n_digit_number(n):
    """
    Generates a random n-digit number.

    :param n: Number of digits in the number.
    :return: A random n-digit number.
    """
    # Ensure the first digit is not zero
    first_digit = random.randint(1, 9)

    # Generate the remaining n-1 digits, which can include zero
    remaining_digits = [random.randint(0, 9) for _ in range(n - 1)]

    # Combine the digits to form the number
    digits = [first_digit] + remaining_digits
    return int(''.join(map(str, digits)))

def get_random_pair(used_pairs: set, i: int, j: int) -> Tuple[int, int]:
    """
    Get pair of random numbers
    
    :param used_pairs: A set of all pairs of numbers used so far
    :param i: length in digits of first number
    :param j: length in digits of second number
    :return: A tuple of 2 numbers
    """
    nums_used = True
    while nums_used:
        num1 = generate_random_n_digit_number(i)
        num2 = generate_random_n_digit_number(j)
        pair = (num1, num2)
        nums_used = pair in used_pairs
    used_pairs.add(pair)
    return pair
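Two optional refinements, sketched under the assumption that the rest of the notebook stays the same. First, a uniform first digit in 1-9 followed by uniform digits 0-9 is exactly a uniform draw over `[10**(n-1), 10**n - 1]`, so a single `random.randint` call suffices. Second, the `while` loop in `get_random_pair` never terminates once every pair of the requested shape is already in `used_pairs`; that cannot happen here (even a 6-digit by 1-digit run has 900000 * 9 possible pairs), but a bounded variant fails loudly instead. Both function names below are hypothetical.

```python
import random
from typing import Set, Tuple

def random_n_digit(n: int) -> int:
    """Uniform random integer with exactly n digits (n >= 1)."""
    if n < 1:
        raise ValueError("n must be at least 1")
    return random.randint(10 ** (n - 1), 10 ** n - 1)

def get_random_pair_bounded(used_pairs: Set[Tuple[int, int]], i: int, j: int) -> Tuple[int, int]:
    """Draw an unused (i-digit, j-digit) pair, raising once none are left.

    Assumes used_pairs only holds pairs of this (i, j) shape, as in the batch
    loop, which starts a fresh set for each j.
    """
    total = (9 * 10 ** (i - 1)) * (9 * 10 ** (j - 1))  # count of possible pairs
    if len(used_pairs) >= total:
        raise ValueError(f"all {total} pairs of {i}- and {j}-digit numbers are used")
    while True:
        pair = (random_n_digit(i), random_n_digit(j))
        if pair not in used_pairs:
            used_pairs.add(pair)
            return pair

# Usage: 1-digit x 1-digit has only 9 * 9 = 81 possible pairs
used = set()
pairs = [get_random_pair_bounded(used, 1, 1) for _ in range(81)]
print(len(set(pairs)))  # 81
try:
    get_random_pair_bounded(used, 1, 1)
except ValueError as err:
    print(err)  # all 81 pairs of 1- and 1-digit numbers are used
```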

In [24]:
import re
import random
from typing import Tuple

def extract_numbers_from_string(content):
    """
    Extract the two operands and the result from answer string.
    
    :param content: response from the LLM as a string
    :return: A tuple with operand 1, operand 2, and addition result
    """
    # Regex patterns to match numbers in the specified format
    addition_pattern = r'\n\nInput:\nNumber 1: (\d+),\s*Number 2: (\d+)'
    a_number_pattern = r'(\d+)\.$'
    
    # Find and extract numbers for the addition
    addition_match = re.search(addition_pattern, content)
    if addition_match:
        num1, num2 = addition_match.groups()
    else:
        num1, num2 = None, None

    # Find and extract the number given for the final answer
    a_number_match = re.search(a_number_pattern, content)
    if a_number_match:
        a_number = a_number_match.group(1)
    else:
        a_number = None

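    # Use None for any value that was not found, instead of crashing on int(None)
    return tuple(int(x) if x is not None else None for x in (num1, num2, a_number))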


def run_problem(prompt):
    """
    Run the prompt through the LLM and return the decoded response as a string.

    Uses the global `model` and `tokenizer` loaded above.
    """
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output_ids = model.generate(**model_input, max_new_tokens=650, repetition_penalty=1.15)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
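The two regexes in `extract_numbers_from_string` can be checked on a synthetic response. The `\n\nInput:` prefix means the operand pattern skips the `BEGIN EXAMPLE\nInput:` block in the prompt printed below (only one newline precedes that `Input:`), and the `$` anchor means the answer pattern only matches a number followed by a period at the very end of the response. The strings here are made up for illustration; the actual prompt files are not shown.

```python
import re

# Same patterns as in extract_numbers_from_string
addition_pattern = r'\n\nInput:\nNumber 1: (\d+),\s*Number 2: (\d+)'
a_number_pattern = r'(\d+)\.$'

# Hypothetical responses: one ending in the sum, one with trailing chatter
ok = "...\n\nInput:\nNumber 1: 123456, Number 2: 7890\n\nResponse:\n...The final sum is 131346."
trailing_text = ok + " Let me know if you need more help."

print(re.search(addition_pattern, ok).groups())    # ('123456', '7890')
print(re.search(a_number_pattern, ok).group(1))    # 131346
print(re.search(a_number_pattern, trailing_text))  # None: the answer is not at the end
```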

One example to make sure those functions work:

In [19]:
used_pairs = set()
num1, num2 = get_random_pair(used_pairs, 6, 4)
with open("prompt_6_4.txt") as f:
    prompt_template = f.read()
prompt = prompt_template.format(num1=num1, num2=num2)
print(prompt)
response = run_problem(prompt)
print(response)
_, _, answer = extract_numbers_from_string(response)
print(num1, num2, answer, answer == (num1 + num2))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Perform the calculation: Add two numbers by breaking the task into smaller parts. Decompose this addition into smaller parts, adding each corresponding set of digits separately. Show each step of your calculation clearly, and provide the final sum as a single, definitive answer. Ensure that each step is accurately calculated and presented in a clear, orderly manner. An example of this task is shown here:

BEGIN EXAMPLE
Input:
Number 1: 695142, Number 2: 3876

Response:
To calculate the sum of 695142 and 3876, we will align the digits correctly and decompose the addition into smaller parts. The numbers aligned are:

695142
+  3876
------
The process of calculating each subtotal is as follows:

Units place: 2 (from 695142) + 6 (from 3876) = 8
Tens place: 4 (from 695142) + 7 (from 3876) = 11 (carry 1 to the next place)
Hundreds place: 1 (from 695142) + 8 (from 3876) + 1 (carried over) = 10 (carry 1 to the next place)
Thousands place: 5 (from 695142) + 3 (from 3876) + 1 (carried over) = 9


In [20]:
from collections import defaultdict

In [21]:
num_trials = 5  # The first time I ran these trials I used 50 as num_trials here; using 5 now for brevity

In [25]:
# i - number of LHS digits we'll use
results = {}
for i in [6]:
    # j - number of RHS digits - we'll go from 1 to 6
    results[i] = defaultdict(list)
    for j in range(1, 7):
        used_pairs = set()
        with open(f"prompt_{i}_{j}.txt") as f:
            prompt_template = f.read()
        for k in range(num_trials):
            num1, num2 = get_random_pair(used_pairs, i, j)
            prompt = prompt_template.format(num1=num1, num2=num2)
            response = run_problem(prompt)
            _, _, answer = extract_numbers_from_string(response)
            is_correct = answer == (num1 + num2)
            results[i][j].append((num1, num2, answer, is_correct))
            if k % 2 == 0:
                print(num1, num2, answer, is_correct)
            filename = filename_template.format(num1=i, num2=j, idx=k)
            # create the output directory if it doesn't exist yet
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, 'w') as f_out:
                f_out.write(response)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.

124753 166453 46 False
132577 383107 45 False
784510 514272 902222 False


Above is the batch testing: a 6-digit LHS with RHS lengths of 1 to 6 digits, and `num_trials` trials for each RHS digit length.

Now the results!

In [28]:
for i in [6]:
    accuracies = []
    for j in range(1, i+1):
        num_correct = 0
        for k in range(num_trials):
            filename = filename_template.format(num1=i, num2=j, idx=k)
            #print(i, j, k, filename)
            with open(filename, 'r') as f:
                content = f.read()
            num1, num2, answer = extract_numbers_from_string(content)
            is_correct = None not in (num1, num2, answer) and num1 + num2 == answer
            #print(num1, num2, answer, is_correct)
            if is_correct:
                num_correct += 1
        accuracies.append(num_correct)
    accuracies = [(float(x) / num_trials) * 100 for x in accuracies]
    str_acc = [str(int(x)) for x in accuracies]
    print(f"{i}: {' '.join(str_acc)}")

6: 0 0 0 0 0 0
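Since the batch cell already stores `is_correct` for every trial in `results`, the same per-length accuracies can also be computed in memory; re-reading the files, as above, is still useful when the results come from an earlier session. A minimal sketch with a hypothetical helper, exercised on made-up records (not real results):

```python
from typing import Dict, List, Optional, Tuple

# (num1, num2, answer, is_correct), the tuple shape stored in results[i][j]
Trial = Tuple[int, int, Optional[int], bool]

def accuracy_by_rhs_length(results_i: Dict[int, List[Trial]]) -> Dict[int, float]:
    """Percentage of correct trials for each RHS digit length j."""
    return {
        j: 100.0 * sum(1 for *_, ok in trials if ok) / len(trials)
        for j, trials in sorted(results_i.items())
        if trials  # skip lengths with no trials to avoid dividing by zero
    }

# Made-up records purely to exercise the function
fake = {
    1: [(100000, 5, 100005, True), (200000, 7, 3, False)],
    2: [(300000, 12, 300012, True), (400000, 34, 400034, True)],
}
print(accuracy_by_rhs_length(fake))  # {1: 50.0, 2: 100.0}
```

With the real data this would be called as `accuracy_by_rhs_length(results[6])`.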


In [29]:
# NOTE: the first time I ran this I got accuracies (in percent) of 6: [26, 14, 8, 0, 0, 0]; these are the values reflected in the chart on the poster.

Disappointing results, to say the least. Not only did this approach fail to improve our accuracy, it made it much worse!