In [2]:
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
from tokenizer import Tokenizer
from colour_print import cprint

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-shakespeare-char' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 5 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
# Filter out Jupyter notebook arguments before executing configurator
import sys
original_argv = sys.argv.copy()
sys.argv = [arg for arg in sys.argv if not arg.startswith('--f=')]
exec(open('configurator.py').read()) # overrides from command line or config file
sys.argv = original_argv  # restore original arguments
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    # TODO want to make this more general to arbitrary encoder/decoder schemes
    stoi, itos = meta['stoi'], meta['itos']
    tokenizer = Tokenizer(stoi=stoi, itos=itos)
    encode = tokenizer.encode
    decode = tokenizer.decode
else:
    # ok let's assume gpt-2 encodings by default
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

print(f"val loss: {checkpoint['best_val_loss'].item()}")

number of parameters: 1.04M
Loading meta from data/shakespeare_char/meta.pkl...
val loss: 0.09665738046169281


In [31]:
def encode_list(x):
    y=[]
    for i in x:
        y.append(encode(i)[0])
    return (torch.tensor(y, dtype=torch.long, device=device)[None, ...])

def prompt_it(prompt):
    x = encode(prompt)
    x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
    y = model.generate(x, 100, temperature=0.1, top_k=top_k, end_tokens=encode('>'))
    response = decode(y[0].tolist())
    return response

prompt = '<two plus two equals '
print(prompt_it(prompt))

# answer = prompt_it(prompt).replace('<', '').replace('>', '')
# print(f'Question: {prompt.replace('<', '')[:-1]} \nAnswer: {answer.replace(prompt.replace('<', ''), '')}')

<two plus two equals four>


In [24]:
decode([0])

'\n'

In [8]:
# single
print(prompt_it("<six plus six equals "))

# multiple
prompts = [f"<six plus six equals {i}" for i in ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
            "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
            "seventeen", "eighteen", "nineteen"]]
outputs = prompt_it(prompts)
for o in outputs:
    print(o)


<six plus six equals twenty-two>
<six plus six equals zero




<six minus six equal
<six plus six equals one





<six times six equal
<six plus six equals two





<six times six equal
<six plus six equals three



<six minus six equal
<six plus six equals four




<six times six equal
<six plus six equals five




<six times six equal
<six plus six equals six





<six times six equal
<six plus six equals seven



<six minus six equal
<six plus six equals eight



<six times six equal
<six plus six equals nine




<six times six equal
<six plus six equals ten





<six times six equal
<six plus six equals eleven


<six times six equal
<six plus six equals twelve


<six times six equal
<six plus six equals thirteen
<times six equals th
<six plus six equals fourteen
<times six equals th
<six plus six equals fifteen

<six minus six equal
<six plus six equals sixteen

<six times six equal
<six plus six equals seventeen>
<six times six equ
<six plus six equals eighteen
<times six equals t

In [16]:
prompt = ['<six plus six equals', '<six plus two equals ', '<six plus seven equals ']
# print(prompt_it(prompt))
prompt_it(prompt)

'\n<forty-seven plus four equals fifty-one>'

In [4]:
import torch
from itertools import product

# === prompt and model inference ===
def prompt_it(prompt):
    x = encode(prompt)
    x = torch.tensor(x, dtype=torch.long, device=device)[None, ...]
    y = model.generate(x, 100, temperature=0.1, top_k=top_k, end_tokens=encode('>'))
    response = decode(y[0].tolist())
    return response

# === number to words (0–999) ===
def number_to_words(n):
    ones = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
            "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
            "seventeen", "eighteen", "nineteen"]
    tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]

    if n < 20:
        return ones[n]
    elif n < 100:
        t = n // 10
        o = n % 10
        return tens[t] if o == 0 else f"{tens[t]}-{ones[o]}"
    else:
        h = n // 100
        r = n % 100
        if r == 0:
            return f"{ones[h]} hundred"
        else:
            return f"{ones[h]} hundred {number_to_words(r)}"

# === words to number (0–999) ===
def words_to_number(words):
    words = words.replace("-", " ")
    tokens = words.split()
    num = 0
    temp = 0
    word_to_num = {
        "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
        "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
        "eleven": 11, "twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15,
        "sixteen": 16, "seventeen": 17, "eighteen": 18, "nineteen": 19,
        "twenty": 20, "thirty": 30, "forty": 40, "fifty": 50,
        "sixty": 60, "seventy": 70, "eighty": 80, "ninety": 90
    }

    for w in tokens:
        if w == "hundred":
            temp *= 100
        elif w in word_to_num:
            temp += word_to_num[w]
    num += temp
    return num

# === expression verification ===
def verify_arithmetic_expression(expression):
    expr = expression.strip().replace(">", "").replace("<", "")
    parts = expr.split()

    try:
        if "plus" in parts:
            op = "plus"
        elif "minus" in parts:
            op = "minus"
        elif "times" in parts:
            op = "times"
        elif "divided" in parts:
            op = "divided"
        else:
            return False, "No recognized operator"

        eq_index = parts.index("equals")
        op_index = parts.index(op)

        num1 = words_to_number(" ".join(parts[:op_index]))
        num2 = words_to_number(" ".join(parts[op_index + 1:eq_index]))
        result = words_to_number(" ".join(parts[eq_index + 1:]))

        if op == "plus":
            actual = num1 + num2
        elif op == "minus":
            actual = num1 - num2
        elif op == "times":
            actual = num1 * num2
        elif op == "divided":
            if num2 == 0:
                return False, "Division by zero"
            actual = num1 // num2 if num1 % num2 == 0 else None

        if actual is None:
            return False, "Non-integer division"

        return (actual == result), f"{num1} {op} {num2} = {actual}, got {result}"
    except Exception as e:
        return False, f"Error parsing: {e}"

# === main testing loop ===
print("Testing verbal arithmetic expressions from 0–999:")
print("=" * 70)

# operations = ["plus", "minus", "times", "divided by"]
# correct_count = 0
# total_count = 0
# log_every = 500

# for i, j in product(range(99), range(99)):
#     for op in operations:
#         # filter valid cases
#         if op == "plus" and (i + j <= 999):
#             valid = True
#         elif op == "minus" and (i >= j):
#             valid = True
#         elif op == "times" and (i * j <= 999):
#             valid = True
#         elif op == "divided by" and (j != 0 and i % j == 0 and i // j <= 999):
#             valid = True
#         else:
#             valid = False

#         if not valid:
#             continue

#         num1_word = number_to_words(i)
#         num2_word = number_to_words(j)
#         prompt = f"<{num1_word} {op} {num2_word} equals "
#         expression = prompt_it(prompt).replace('<', '').replace('>', '')

#         is_correct, msg = verify_arithmetic_expression(expression)
#         total_count += 1
#         if is_correct:
#             correct_count += 1
#         else:
#             cprint(f"❌ {expression}", "red")
#             cprint(f"   → {msg}", "yellow")

#         if total_count % log_every == 0:
#             acc = correct_count / total_count * 100
#             cprint(f"Progress: {total_count} tests | Accuracy: {acc:.2f}%", "cyan")

# # === summary ===
# print("=" * 70)
# acc = correct_count / total_count * 100 if total_count > 0 else 0
# cprint("Final Results", "blue")
# cprint(f"Total tests: {total_count}", "blue")
# cprint(f"Correct: {correct_count}", "blue")
# cprint(f"Accuracy: {acc:.2f}%", "blue")

Testing verbal arithmetic expressions from 0–999:


In [5]:
from itertools import product

val = []
operations = ["plus", "minus", "times", "divided by"]

for i, j in product(range(99), range(99)):
    for op in operations:
        # apply same validity rules as your test loop
        if op == "plus":
            valid = (i + j <= 999)
            if not valid: 
                continue
            result = i + j
        elif op == "minus":
            valid = (i >= j)
            if not valid:
                continue
            result = i - j
        elif op == "times":
            valid = (i * j <= 999)
            if not valid:
                continue
            result = i * j
        elif op == "divided by":
            valid = (j != 0 and i % j == 0 and (i // j) <= 999)
            if not valid:
                continue
            result = i // j
        # build the verbal test string
        val.append(f"<{number_to_words(i)} {op} {number_to_words(j)} equals")


In [6]:
from collections import defaultdict

grouped = defaultdict(list)

for expr in val:
    length = len(expr.split())
    grouped[length].append(expr)

# turn it into a list of groups sorted by length
grouped_val = [grouped[k] for k in sorted(grouped.keys())]


In [7]:
def prompt_it(prompts):
    # Handle both single prompt (string) and multiple prompts (list)
    if isinstance(prompts, str):
        prompts = [prompts]
    
    # Encode all prompts
    encoded = [torch.tensor(encode(prompt), dtype=torch.long) for prompt in prompts]
    
    # Pad sequences to same length
    min_len = min(len(t) for t in encoded)
    encoded = [t[:min_len] for t in encoded]
    x = torch.stack([torch.tensor(t, dtype=torch.long) for t in encoded]).to(device)
    
    # Generate for all prompts in batch
    with torch.no_grad():
        with ctx:
            special_end_token = encode('>')
            y = model.generate(x, 60, temperature=0.8, top_k=top_k, end_tokens=special_end_token)
    
    # Decode all outputs
    responses = []
    for sample in y:
        responses.append(decode(sample.tolist()).split('>')[0].replace('<', ''))
    
    # Return single response if input was single prompt, otherwise return list
    return responses[0] if len(responses) == 1 else responses

prompt = ['<six times six equals ', '<six plus two equals ', '<six plus seven equals ']
# print(prompt_it(prompt))
prompt_it(grouped_val[0][:10])

  x = torch.stack([torch.tensor(t, dtype=torch.long) for t in encoded]).to(device)


['zero plus zero equals two',
 'zero minus zero equals zero',
 'zero times zero equals zero',
 'zero plus one equals one',
 'zero times one equals zero',
 'zero plus two equals two',
 'zero times two equals zero',
 'zero plus three equals three',
 'zero times three equals zero',
 'zero plus four equals four']

In [21]:
def process_lists(input_list):
    result = []
    for item in input_list:
        # Split each list into chunks of 1000 items
        for i in range(0, len(item), 2000):
            result.append(item[i:i+2000])
    return result

In [22]:
correct_count = 0
total_count = 0
log_every = 500

grouped_val = process_lists(grouped_val)
for group in grouped_val:
    result = prompt_it(group)
    for expression in result:
        is_correct, msg = verify_arithmetic_expression(expression)
        total_count += 1
        if is_correct:
            correct_count += 1
        else:
            cprint(f"❌ {expression}", "red")
            cprint(f"   → {msg}", "yellow")

        if total_count % log_every == 0:
            acc = correct_count / total_count * 100
            cprint(f"Progress: {total_count} tests | Accuracy: {acc:.2f}%", "cyan")
print("=" * 70)
acc = correct_count / total_count * 100 if total_count > 0 else 0
cprint("Final Results", "blue")
cprint(f"Total tests: {total_count}", "blue")
cprint(f"Correct: {correct_count}", "blue")
cprint(f"Accuracy: {acc:.2f}%", "blue")

  x = torch.stack([torch.tensor(t, dtype=torch.long) for t in encoded]).to(device)


[31m❌ zero plus zero equals two[0m
[33m   → 0 plus 0 = 0, got 2[0m
[36mProgress: 500 tests | Accuracy: 99.80%[0m
[31m❌ two times fifty-five equals one hundred[0m
[33m   → 2 times 55 = 110, got 100[0m
[31m❌ three times thirteen equals forty-nine[0m
[33m   → 3 times 13 = 39, got 49[0m
[31m❌ three times thirty-seven equals one hundred fourteen[0m
[33m   → 3 times 37 = 111, got 114[0m
[31m❌ three times forty-eight equals one hundred fifty-four[0m
[33m   → 3 times 48 = 144, got 154[0m
[31m❌ three times forty-nine equals one hundred fifty-seven[0m
[33m   → 3 times 49 = 147, got 157[0m
[31m❌ three times sixty-eight equals two hundred fourteen[0m
[33m   → 3 times 68 = 204, got 214[0m
[36mProgress: 1000 tests | Accuracy: 99.30%[0m
[31m❌ five plus nine equals fourten[0m
[33m   → 5 plus 9 = 14, got 0[0m
[31m❌ five times seventeen equals ninety-five[0m
[33m   → 5 times 17 = 85, got 95[0m
[31m❌ five times thirty-seven equals one hundred ninety-five[0m
[33m

In [None]:
# Accuracy: 65.05%
# Accuracy: 78.59%
# Accuracy: 80.35%

In [None]:
verify_arithmetic_expression('one plus sixteen equals seventeen')

(False, '1 + 16 = 17, expected: 1')

In [None]:
num_samples = 5 # number of samples to draw
# run generation
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            special_end_token = encode('<|end|>')
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k, end_tokens=special_end_token)
            response = decode(y[0].tolist())
            print(response)
            try:
                response_split = response.split('+')
                x_ = int(response_split[0].strip('\n<|start|>'))
                y_ = int(response_split[1].split('=<|answer|>')[0])
                z_ = int(response.split('=<|answer|>')[-1].replace('<|end|>', '').strip())
                cprint(f"{x_} + {y_} == {z_}", "green") if x_ + y_ == z_ else cprint(f"{x_} + {y_} != {z_}", "red")
            except:
                cprint("Invalid response format", "yellow")
            print('---------------')


<|start|>flpqv-flhffx<|answer|>fuuqv<|end|>
[33mInvalid response format[0m
---------------

<|start|>fhhpg-fvhfpx<|answer|>gglpq<|end|>
[33mInvalid response format[0m
---------------

<|start|>fuqlo-fvovpx<|answer|>guilg<|end|>
[33mInvalid response format[0m
---------------

<|start|>flfhi-fqqllx<|answer|>ggqpq<|end|>
[33mInvalid response format[0m
---------------

<|start|>fgoii-fuihox<|answer|>fvuou<|end|>
[33mInvalid response format[0m
---------------


In [1]:
from itertools import product
from tqdm.auto import tqdm
import random

# Define ranges to test
ranges = [
    (0, 10, "Single-digit"),
    (10, 100, "Two-digit"),
    (100, 1000, "Three-digit")
]

# Initialize counters for each range
results = {}
for start, end, name in ranges:
    results[name] = {"correct": 0, "wrong": 0}

# Calculate total number of tests to run
total_tests = 0
for start, end, name in ranges:
    if name == "Single-digit":
        total_tests += (end - start) * (end - start)  # All combinations (100 tests)
    else:
        total_tests += 1000  # Sample 1000 pairs for larger ranges

# Use a single progress bar for all tests
total_correct = 0
total_tests_done = 0
with tqdm(total=total_tests, desc="Testing model accuracy") as pbar:
    for start, end, name in ranges:
        # Generate test pairs for this range
        if name == "Single-digit":
            # Test all combinations for single digits
            pairs = list(product(range(start, end), range(start, end)))
        else:
            # Sample random pairs for larger ranges
            random.seed(1337)  # For reproducibility
            pairs = [(random.randint(start, end-1), random.randint(start, end-1)) 
                    for _ in range(1000)]
        
        # Test each pair in this range
        for i, j in pairs:
            line = f'<|start|>{i:05d}+{j:05d}=<|answer|>'
            custom_ids = encode(line)
            custom_x = (torch.tensor(custom_ids, dtype=torch.long, device=device)[None, ...])
            special_end_token = encode('<|end|>')

            with torch.no_grad():
                with ctx:
                    logits = model.generate(custom_x, max_new_tokens, temperature=temperature, top_k=top_k, end_tokens=special_end_token)
                    decoded = decode(logits[0].tolist())
                    # Print random lines
                    print_bool = random.random() < 0.1  # Print 10% of the lines
                    if print_bool:
                        print(decoded)
                    try:
                        response_split = decoded.split('+')
                        x_ = int(response_split[0].strip('\n<|start|>'))
                        y_ = int(response_split[1].split('=<|answer|>')[0])
                        z_ = int(decoded.split('=<|answer|>')[-1].replace('<|end|>', '').strip())
                        if print_bool:
                            cprint(f"{x_} + {y_} == {z_}", "green") if x_ + y_ == z_ else cprint(f"{x_} + {y_} != {z_}", "red")
                            print('---------------')
                        if x_ + y_ == z_:
                            results[name]["correct"] += 1
                            total_correct += 1
                        else:
                            results[name]["wrong"] += 1
                    except:
                        results[name]["wrong"] += 1

            # Update progress bar with overall accuracy across all ranges
            total_tests_done += 1
            overall_accuracy = (total_correct / total_tests_done) * 100
            pbar.update(1)
            pbar.set_postfix(range=name, overall_acc=f"{overall_accuracy:.2f}%")
# Display final accuracy for each range
print("\nFinal model accuracy by range:")
for start, end, name in ranges:
    correct = results[name]["correct"]
    total = correct + results[name]["wrong"]
    print(f"{name} numbers ({start}-{end-1}): {correct}/{total} ({correct/total*100:.2f}%)")

# Calculate overall accuracy across all ranges
total_correct = sum(results[name]["correct"] for _, _, name in ranges)
total_predictions = sum(results[name]["correct"] + results[name]["wrong"] for _, _, name in ranges)
overall_accuracy = total_correct / total_predictions * 100 if total_predictions > 0 else 0

print(f"Overall model accuracy: {total_correct}/{total_predictions} ({overall_accuracy:.2f}%)")

Testing model accuracy:   0%|          | 0/2100 [00:00<?, ?it/s]

NameError: name 'encode' is not defined

In [None]:
response_split = decoded.split('+')
x_ = int(response_split[0].strip('\n<|start|>'))
y_ = int(response_split[1].split('=<|answer|>')[0])
z_ = int(response.split('=<|answer|>')[-1].replace('<|end|>', '').strip())
cprint(f"{x_} + {y_} == {z_}", "green") if x_ + y_ == z_ else cprint(f"{x_} + {y_} != {z_}", "red")

NameError: name 'decoded' is not defined

In [None]:
start, answer, end = ['<|start|>', '<|answer|>', '<|end|>']
# Sample those that are missing
file_path = 'data/shakespeare_char/input.txt'
# Read the file content into memory
with open(file_path, 'r') as f:
    input_data = f.read()

for _ in range(1):
    for i, j in product(range(1000), range(1000)):
        line = f"{start}{i}+{j}={answer}{i+j}{end}\n"
        if line not in input_data:
            print(f"Missing line: {line.strip()}")

In [None]:
import random
from tqdm.auto import tqdm

start, answer, end = ['<|start|>', '<|answer|>', '<|end|>']

# Use a more efficient approach - sample random pairs instead of checking all combinations

# Set random seed for reproducibility
random.seed(1337)

# Define test parameters
total_tests = 1000
correct = 0
wrong = 0

# Use tqdm to show progress
with tqdm(total=total_tests, desc="Testing model accuracy") as pbar:
    # Generate random number pairs from 0-999
    pairs = [(random.randint(0, 999), random.randint(0, 999)) for _ in range(total_tests)]
    
    for i, j in product(range(1000), range(1000)):
        line = f"{start}{i}+{j}={answer}{i+j}{end}\n"
        if line not in input_data:
            custom_ids = encode(line)[:-2]
            custom_x = (torch.tensor(custom_ids, dtype=torch.long, device=device)[None, ...])
            special_end_token = encode(end)

            with torch.no_grad():
                with ctx:
                    logits = model.generate(custom_x, max_new_tokens=max_new_tokens, end_tokens=special_end_token)
                    decoded = decode(logits[0].tolist())
                    try:
                        response_split = decoded.split('+')
                        x_ = int(response_split[0].strip('\n' + start))
                        y_ = int(response_split[1].split('=' + answer)[0])
                        z_ = int(decoded.split('=' + answer)[-1].replace(end, '').strip())
                        
                        # Print every 100th example
                        if random.random() < 0.1:
                            print(decoded)
                            if x_ + y_ == z_:
                                cprint(f"{x_} + {y_} == {z_}", "green")
                            else:
                                cprint(f"{x_} + {y_} != {z_}", "red")
                            print('---------------')
                        
                        if x_ + y_ == z_:
                            correct += 1
                        else:
                            wrong += 1
                    except:
                        wrong += 1
            
            # Update progress bar
            pbar.update(1)
            pbar.set_postfix(accuracy=f"{correct/(correct+wrong)*100:.2f}%")

# Display final accuracy
print(f"Model accuracy on 0-999 range: {correct}/{total_tests} ({correct/total_tests*100:.2f}%)")

Testing model accuracy:   0%|          | 0/1000 [00:00<?, ?it/s]

<|start|>20+30=<|answer|>50<|end|>
[32m20 + 30 == 50[0m
---------------
<|start|>20+60=<|answer|>80<|end|>
[32m20 + 60 == 80[0m
---------------
<|start|>40+40=<|answer|>80<|end|>
[32m40 + 40 == 80[0m
---------------
<|start|>50+60=<|answer|>110<|end|>
[32m50 + 60 == 110[0m
---------------
<|start|>70+40=<|answer|>110<|end|>
[32m70 + 40 == 110[0m
---------------
<|start|>70+70=<|answer|>140<|end|>
[32m70 + 70 == 140[0m
---------------
<|start|>80+70=<|answer|>150<|end|>
[32m80 + 70 == 150[0m
---------------
<|start|>80+90=<|answer|>170<|end|>
[32m80 + 90 == 170[0m
---------------
<|start|>90+30=<|answer|>120<|end|>
[32m90 + 30 == 120[0m
---------------


KeyboardInterrupt: 

In [None]:
from itertools import product
from tqdm.auto import tqdm
import random

# Set random seed for reproducibility
random.seed(1337)

# Define the range to test (0-1000)
total_tests = 1000
correct = 0
wrong = 0

# Use tqdm to show progress
with tqdm(total=total_tests, desc="Testing model accuracy") as pbar:
    # Generate random number pairs from 0-1000
    pairs = [(random.randint(0, 1000), random.randint(0, 1000)) for _ in range(total_tests)]
    
    for i, j in pairs:
        line = f'<|start|>{i}+{j}=<|answer|>'
        custom_ids = encode(line)
        custom_x = (torch.tensor(custom_ids, dtype=torch.long, device=device)[None, ...])
        special_end_token = encode('<|end|>')

        with torch.no_grad():
            with ctx:
                logits = model.generate(custom_x, max_new_tokens, temperature=temperature, top_k=top_k, end_tokens=special_end_token)
                decoded = decode(logits[0].tolist())
                try:
                    z_ = int(decoded.split('=<|answer|>')[-1].replace('<|end|>', '').strip())
                    if i + j == z_:
                        correct += 1
                    else:
                        wrong += 1
                except:
                    wrong += 1
        
        # Update progress bar
        pbar.update(1)
        pbar.set_postfix(accuracy=f"{correct/(correct+wrong)*100:.2f}%")

# Display final accuracy
print(f"Model accuracy on 0-1000 range: {correct}/{total_tests} ({correct/total_tests*100:.2f}%)")


Final model accuracy by range:
Single-digit numbers (0-9): 61/100 (61.00%)
Two-digit numbers (10-99): 1000/1000 (100.00%)
Three-digit numbers (100-999): 995/1000 (99.50%)

Final model accuracy by range:
Single-digit numbers (0-9): 65/100 (65.00%)
Two-digit numbers (10-99): 1000/1000 (100.00%)
Three-digit numbers (100-999): 994/1000 (99.40%)

In [None]:
# Overall model accuracy on single-digit addition: 1361/10000 (13.61%)

In [None]:
from itertools import product
from tqdm import tqdm

correct = 0
wrong = 0

# Use tqdm to show a progress bar for all pairs (0-9 x 0-9)
for _ in tqdm(range(1)):
    for i, j in product(range(10,100), range(10,100)):
        line = f'<|start|>{i}+{j}=<|answer|>'
        custom_ids = encode(line)
        custom_x = (torch.tensor(custom_ids, dtype=torch.long, device=device)[None, ...])

        special_end_token = encode('<|end|>')

        with torch.no_grad():
            with ctx:
                logits = model.generate(custom_x, max_new_tokens, temperature=temperature, top_k=top_k, end_tokens=special_end_token)
                decoded = decode(logits[0].tolist())
                try:
                    z_ = int(decoded.split('=<|answer|>')[-1].replace('<|end|>', '').strip())
                    if i + j == z_:
                        correct += 1
                    else:
                        wrong += 1
                except:
                    wrong += 1

print(f"Model accuracy on single-digit addition: {correct}/{correct+wrong} ({correct/(correct+wrong)*100:.2f}%)")

100%|██████████| 1/1 [00:32<00:00, 32.51s/it]

Model accuracy on single-digit addition: 8099/8100 (99.99%)





In [None]:
x_, y_ = 55,43
line = f'<|start|>{x_}+{y_}=<|answer|>'
custom_ids = encode(line)
custom_x = (torch.tensor(custom_ids, dtype=torch.long, device=device)[None, ...])

special_end_token = encode('<|end|>')

with torch.no_grad():
    with ctx:
        logits = model.generate(custom_x, max_new_tokens, temperature=temperature, top_k=top_k, end_tokens=special_end_token)
        decoded = decode(logits[0].tolist())
        print(decoded)
        z_ = int(decoded.split('<|answer|>')[-1].replace('<|end|>', '').strip())
        cprint(f"{x_} + {y_} == {z_}", "green") if x_ + y_ == z_ else cprint(f"{x_} + {y_} != {z_}", "red")

<|start|>55+43=<|answer|>98<|end|>
[32m55 + 43 == 98[0m


In [None]:
import os
import pickle
import requests
import numpy as np

# download the tiny shakespeare dataset
special_tokens = ['<|start|>', '<|answer|>', '<|end|>']
include_special_tokens = False

input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# Create a unified character set
all_chars = sorted(list(set(data))) if include_special_tokens else sorted(list(set([c for c in data if c not in ''.join(special_tokens)])))
vocab_size = len(special_tokens) + len(all_chars)
print("Regular characters:", all_chars)
print(f"Total vocab size (characters + special tokens): {vocab_size:,}")

# Create mappings that include both individual characters and special tokens
stoi = {}
# First add special tokens
for i, token in enumerate(special_tokens):
    stoi[token] = i
# Then add all individual characters
for i, ch in enumerate(all_chars):
    stoi[ch] = i + len(special_tokens)

itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    # Prioritize special tokens during encoding
    encoded = []
    i = 0
    while i < len(s):
        # Check if any special token starts at the current position
        is_special = False
        for token in special_tokens:
            if s[i:].startswith(token):
                encoded.append(stoi[token])
                i += len(token)
                is_special = True
                break
        if not is_special:
            # If not a special token, encode as individual character
            encoded.append(stoi[s[i]])
            i += 1
    return encoded

def decode(l):
    return ''.join([itos[i] for i in l])

# create the train and test splits
# Find the position of the last complete record in the first 90% of data
split_point = int(len(data) * 0.9)
# Look for the next <|end|> token after this point
end_token = "<|end|>\n"
next_end = data.find(end_token, split_point)
if next_end != -1:
    # Split after the end token (including the newline)
    split_point = next_end + len(end_token)
else:
    # Fallback if no end token is found
    print("Warning: Couldn't find clean split point. Using approximate split.")

train_data = data[:split_point]
val_data = data[split_point:]

# Verify the split is clean
print(f"Train data ends with: {train_data[-20:].replace('\n', '\\n')}")
print(f"Val data begins with: {val_data[:20].replace('\n', '\\n')}")

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
# train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
# val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
# with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
#     pickle.dump(meta, f)

length of dataset in characters: 342,895
Regular characters: ['\n', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Total vocab size (characters + special tokens): 15
Train data ends with: |answer|>106<|end|>\n
Val data begins with: <|start|>90+17<|answ
train has 101,234 tokens
val has 11,661 tokens


In [None]:
line.split('+')[0].strip('<|start|>').strip()
line.split('+')[1].strip('=<|answer|>').strip()

'10'

In [None]:
response = '<|start|>5+93=<|answer|>107<|end|>'
response_split = response.split('+')
x=int(response_split[0].strip('<|start|>'))
y=int(response_split[1].split('=<|answer|>')[0])
z = int(response.split('=<|answer|>')[-1].replace('<|end|>', '').strip())
x+y==z, (x,y,z)

(False, (5, 93, 107))

In [None]:
'<|start|>5+93=<|answer|>107<|end|>'.split('+')[1].split('=<|answer|>')[0]

'93'