In [8]:
import pandas as pd
import numpy as np
import random
import transformers
import torch
import dotenv
import os
import matplotlib.pyplot as plt
import re
import string
from tqdm import tqdm

dotenv.load_dotenv()

True

In [9]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    token=os.getenv('HF_TOKEN')
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.04it/s]


# Rule 110 Cellular Automaton

In [None]:
def make_cellular_problem_set(size, steps, num_problems, boundary='wrap'):
    if steps < 3:
        raise ValueError("Need 3 steps to have a solution and 2 intermediates")
    def int_to_binary_list(n, min_length=8):
        binary = bin(n)[2:]  # Convert to binary string and remove '0b' prefix
        binary_list = [int(b) for b in binary.zfill(min_length)]  # Pad with zeros if necessary
        return binary_list
    
    dict_110 = {
        (0, 0, 0): 0,
        (0, 0, 1): 1,
        (0, 1, 0): 1,
        (0, 1, 1): 1,
        (1, 0, 0): 0,
        (1, 0, 1): 1,
        (1, 1, 0): 1,
        (1, 1, 1): 0
    }

    def rule_110(prev):
        next_state = []
        for i in range(len(prev)):
            left = prev[(i-1) % len(prev)] if boundary == 'wrap' or (i > 0 and i < len(prev)-1) else boundary
            center = prev[i]
            right = prev[(i+1) % len(prev)] if boundary == 'wrap' or (i > 0 and i < len(prev)-1) else boundary
            pattern = (left, center, right)
            next_state.append(dict_110[pattern])
        return next_state
    
    def make_rule_110_problem(initial_state, steps):
        current_state = initial_state
        states = [current_state]
        for _ in range(steps):
            current_state = rule_110(current_state)
            states.append(current_state)
        return (''.join(str(x) for x in initial_state),
            ''.join(str(x) for x in states[-1]),
            ''.join(str(x) for x in states[1]),
            ''.join(str(x) for x in states[-2]))

    return pd.DataFrame(
        [make_rule_110_problem(int_to_binary_list(((i+1)*33581)%(2**(size))), steps) for i in range(num_problems)]
        , columns=['problem', 'correct_solution', 'intermediate_1', 'intermediate_2'])

In [None]:
make_cellular_problem_set(10, 3, 1000, boundary=0)

# SAT

In [None]:
from pysat.formula import CNF
from pysat.solvers import Glucose3

def solve_nsat(clauses):
    # Create a CNF formula
    cnf = CNF()
    for clause in clauses:
        cnf.append(clause)

    # Create a SAT solver
    with Glucose3(bootstrap_with=cnf) as solver:
        # Check if the formula is satisfiable
        if solver.solve():
            return solver.get_model()
        else:
            return None

In [None]:
def make_nsat_problem_set(vars_per_clause, num_clauses, num_problems):
    def make_nsat_problem(vars_per_clause, num_clauses):
        text_variables = [string.ascii_lowercase[i] for i in range(vars_per_clause)]
        text_problem = []
        pysat_problem = []
        for _ in range(num_clauses):
            clause = random.sample(range(vars_per_clause), 3)
            signs = [random.choice([-1, 1]) for _ in range(3)]
            pysat_clause = [signs[i]*(var+1) for i, var in enumerate(clause)]
            pysat_problem.append(pysat_clause)
            pysat_solution = solve_nsat(pysat_problem)
            if pysat_solution is None:
                text_solution = None
            else:
                text_solution_letters = [f"{'¬' if var <0 else ''}{text_variables[abs(var)-1]}" for i, var in enumerate(pysat_solution)]
                text_solution = f"{' ^ '.join(text_solution_letters)}"

            text_clause = [f"{'¬' if signs[i] == -1 else ''}{text_variables[var]}" for i, var in enumerate(clause)]
            text_problem.append(f"({' v '.join(text_clause)})")
        return ' ^ '.join(text_problem), text_solution, None, None

    return pd.DataFrame(
        [make_nsat_problem(vars_per_clause, num_clauses) for _ in range(num_problems)],
        columns=['problem', 'correct_solution', 'intermediate_1', 'intermediate_2'])

In [None]:
make_nsat_problem_set(3, 20, 1000)

# Dot Product

In [10]:
def make_dot_product_problem_set(vec_len, vec_mag, num_problems):
    if vec_len < 2:
        raise ValueError("Need vectors of length 2 or greater to have two intermediates")
    def make_dot_product_problem(vec_len):
        a = np.random.randint(-vec_mag, vec_mag, vec_len)
        b = np.random.randint(-vec_mag, vec_mag, vec_len)
        return (f"[{', '.join([str(x) for x in a])}] ⋅ [{', '.join([str(x) for x in b])}]"
        , np.dot(a, b)
        , a[0]*b[0],
        a[-1]*b[-1])

    return pd.DataFrame(
        [make_dot_product_problem(vec_len) for _ in range(num_problems)],
        columns=['problem', 'correct_solution', 'intermediate_1', 'intermediate_2'])

In [None]:
make_dot_product_problem_set(3, 10, 1000)

In [17]:
def solve_problem_cot(problem, sys_prompt, cot_prompt):
    messages = [
    {"role": "system", "content": sys_prompt + ' ' + cot_prompt},
    {"role": "user", "content": problem},
    ]

    outputs = pipeline(
    messages,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    pad_token_id=pipeline.tokenizer.eos_token_id
    )

    return outputs[0]['generated_text'][-1]['content']

def solve_problem_memo(problem, sys_prompt, memo_prompt, max_toks=10): # setting max_toks to 1 because we're doing small digit numbers
    messages = [
    {"role": "system", "content": sys_prompt + ' ' + memo_prompt},
    {"role": "user", "content": problem},
    ]

    outputs = pipeline(
    messages,
    max_new_tokens=max_toks,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    pad_token_id=pipeline.tokenizer.eos_token_id
    )

    return outputs[0]['generated_text'][-1]['content']

In [12]:
problem_prompt = "What is the dot product of these two vectors?"
cot_prompt = "Show your work."
memo_prompt = "Answer with only a number."
n=100

In [13]:
def dot_test(dot_problems, problem_prompt=problem_prompt, cot_prompt=cot_prompt, memo_prompt=memo_prompt, n=n):
    cot_solutions = []
    cot_correct = 0
    for i, row in tqdm(list(dot_problems.iterrows())[:n]):
        cot_solutions.append(solve_problem_cot(row['problem'], problem_prompt, cot_prompt))
        cot_correct += str(row['correct_solution']) in cot_solutions[-1]

    memo_solutions = []
    memo_correct = 0
    for i, row in tqdm(list(dot_problems.iterrows())[:n]):
        memo_solutions.append(solve_problem_memo(row['problem'], problem_prompt, memo_prompt))
        memo_correct += str(row['correct_solution']) in memo_solutions[-1]

    print(cot_correct, memo_correct)
    return cot_solutions, memo_solutions

In [14]:
results = {}

In [None]:
results[(3, 10)] = dot_test(make_dot_product_problem_set(3, 10, n))

In [None]:
results[(2, 10)] = dot_test(make_dot_product_problem_set(2, 10, n))

In [None]:
results[(2, 20)] = dot_test(make_dot_product_problem_set(2, 20, n))

In [None]:
results[(5, 10)] = dot_test(make_dot_product_problem_set(5, 10, n))

In [39]:
def repeated_solve_problem_memo(problem, sys_prompt, memo_prompt, max_toks=10, return_retries=False):
    soln = solve_problem_memo(problem, sys_prompt, memo_prompt, max_toks=max_toks)
    i = 0
    while sum([len([n for n in soln if n.isnumeric()]) > 3]):
        if i >= 50: # tap out
            return soln if not return_retries else (soln, i)
        i += 1
        print('retrying', i, soln)
        soln = solve_problem_memo(problem, sys_prompt, memo_prompt, max_toks=max_toks)
    return soln if not return_retries else (soln, i)

def test_memo_prompt(memo_prompt):
    dot_problems = make_dot_product_problem_set(3, 10, n)
    memo_solutions = []
    memo_correct = 0
    for i, row in tqdm(list(dot_problems.iterrows())[:n]):
        memo_solutions.append(repeated_solve_problem_memo(row['problem'], problem_prompt, memo_prompt=memo_prompt, max_toks=20))
        memo_correct += str(row['correct_solution']) in memo_solutions[-1]
    print(sum([len([n for n in x if n.isnumeric()]) <= 2 for x in memo_solutions]))
    print('\n'.join(memo_solutions))

test_memo_prompt("""Answer with only a number. Do not do any calculations. For example:
                 Problem: [2, 9, -3] ⋅ [-6, -2, -9]
                 Solution: -3
                 """)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:08<00:00, 12.48it/s]

100
-32
35
-18
-17
-3
-62
-54
-26
-17
-6
-16
-24
-23
-52
32
6
35
-64
-12
-12
-24
27
14
-26
-30
-38
-10
49
57
-16
-43
-66
-34
-6
-12
34
-72
-6
0
48
-12
-78
54
24
64
-38
-38
12
35
-65
35
-38
24
-1
-33
-13
-10
-1
-15
-31
6
-95
-7
-34
-30
-12
-73
21
-26
-12
-12
6
6
72
0
18
81
-73
-44
-27
6
-10
-34
32
12
0
-27
-38
15
-56
85
-3
-14
-13
-14
-10
3
-40
-44
-13





In [40]:
one_shot_true_memo_prompt = """Answer with only a number. Do not do any calculations. For example:
                 Problem: [2, 9, -3] ⋅ [-6, -2, -9]
                 Solution: -3
                 """

one_shot_false_memo_prompt = """Answer with only a number. Do not do any calculations. For example:
                 Problem: [2, 9, -3] ⋅ [-6, -2, -9]
                 Solution: 6
                 """

def test_memo_prompt_correctness(memo_prompt, n):
    dot_problems = make_dot_product_problem_set(3, 10, n)
    memo_solutions = []
    memo_correct = 0
    max_retries = 0
    for i, row in tqdm(list(dot_problems.iterrows())[:n]):
        sol, retries = repeated_solve_problem_memo(row['problem'], problem_prompt, memo_prompt=memo_prompt, max_toks=20, return_retries=True)
        memo_solutions.append(sol)
        memo_correct += str(row['correct_solution']) in memo_solutions[-1]
        max_retries = max(max_retries, retries)
    print(f"Max retries: {max_retries}")
    print(f"Correctness: {memo_correct}")

In [42]:
test_memo_prompt_correctness(one_shot_true_memo_prompt, n=1000)
test_memo_prompt_correctness(one_shot_false_memo_prompt, n=1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

 28%|██▊       | 278/1000 [00:22<01:40,  7.20it/s]

retrying 1 8 + (-6) + (-4) = -2


100%|██████████| 1000/1000 [01:21<00:00, 12.25it/s]


Max retries: 1
Correctness: 42


  0%|          | 1/1000 [00:00<10:47,  1.54it/s]

retrying 1 9*7 + (-2)*6 + (-2)*(-1) = 63 - 


 20%|█▉        | 199/1000 [00:16<01:00, 13.22it/s]

retrying 1 -8 - 27 + 18 = -17


 20%|██        | 201/1000 [00:17<02:23,  5.56it/s]

retrying 2 -8 + 27 + 18 = 37


 27%|██▋       | 267/1000 [00:22<00:55, 13.28it/s]

retrying 1 -3 + 18 - 80 = -65
retrying 2 -3 - 18 + 80 = 59
retrying 3 -3 + 18 - 80 = -65
retrying 4 -3 - 18 + 80 = 59
retrying 5 -3 + 18 - 80 = -65
retrying 6 -3 + 18 - 80 = -65
retrying 7 -3 + 18 - 80 = -65
retrying 8 -3 - 18 + 80 = 59
retrying 9 -3 + 18 - 80 = -65
retrying 10 -3 + 18 - 80 = -65
retrying 11 -3 - 18 + 80 = 59
retrying 12 -3 - 18 + 80 = 59
retrying 13 -3 + 18 - 80 = -65
retrying 14 -3 + 18 - 80 = -65
retrying 15 -3 + 18 - 80 = -65


 27%|██▋       | 269/1000 [00:28<11:09,  1.09it/s]

retrying 16 -3 + 18 - 80 = -65


 46%|████▌     | 457/1000 [00:43<01:31,  5.91it/s]

retrying 1 -7* -1 + (-3)* -6 + (-2)* -2 = 7


 51%|█████     | 509/1000 [00:48<01:09,  7.02it/s]

retrying 1 -10 + (-35) + (-7) = -52


 55%|█████▍    | 547/1000 [00:51<01:02,  7.23it/s]

retrying 1 -36 + 40 + 40 = 44


 62%|██████▏   | 621/1000 [00:57<00:28, 13.39it/s]

retrying 1 -10 - 48 + 0 = -58
retrying 2 -10 - 48 + 0 = -58
retrying 3 -10 - 48 + 0 = -58
retrying 4 -10 - 48 + 0 = -58
retrying 5 -10 - 48 + 0 = -58
retrying 6 -10*[-1] + 8*[-6] + [-7]*[0]
retrying 7 -10 - 48 + 0 = -58


 62%|██████▏   | 623/1000 [01:00<03:20,  1.88it/s]

retrying 8 -10 - 48 + 0 = -58


 73%|███████▎  | 733/1000 [01:09<00:33,  8.06it/s]

retrying 1 -24 + 8 + 40 = 44


 89%|████████▉ | 891/1000 [01:22<00:18,  5.84it/s]

retrying 1 -1*[-5] - 2*[3] + 8*[3] = -


 95%|█████████▍| 947/1000 [01:28<00:08,  6.07it/s]

retrying 1 7*5 + 2*(-9) + 1*0 = -13


 98%|█████████▊| 983/1000 [01:31<00:02,  7.43it/s]

retrying 1 -7 - 0 + 54 = 47


100%|██████████| 1000/1000 [01:32<00:00, 10.79it/s]

Max retries: 16
Correctness: 35





# Found a useable memo prompt
Correct example elicits more reliably
Correct example doesn't seem to significantly boost performance

In [43]:
results = {}

In [44]:
results[(3, 10)] = dot_test(make_dot_product_problem_set(3, 10, n), memo_prompt=one_shot_true_memo_prompt)

100%|██████████| 100/100 [05:14<00:00,  3.15s/it]
100%|██████████| 100/100 [00:08<00:00, 12.23it/s]

81 5





In [45]:
results[(2, 10)] = dot_test(make_dot_product_problem_set(2, 10, n), memo_prompt=one_shot_true_memo_prompt)

100%|██████████| 100/100 [04:06<00:00,  2.47s/it]
100%|██████████| 100/100 [00:07<00:00, 13.13it/s]

86 12





In [46]:
results[(3, 5)] = dot_test(make_dot_product_problem_set(3, 5, n), memo_prompt=one_shot_true_memo_prompt)

100%|██████████| 100/100 [05:19<00:00,  3.20s/it]
100%|██████████| 100/100 [00:07<00:00, 12.61it/s]

86 12





In [47]:
make_dot_product_problem_set(3, 10, 1000).to_csv('dot_product_problems_3_5.csv')