In [127]:
# testing to see if models are capable of counting task 
# w/ performance > random chance
import pandas as pd
import numpy as np
from pprint import pprint
import importlib
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('..')
from utils.evals import load_model, create_counting_prompt
from tqdm import tqdm

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [148]:
PERSISTENT_MODEL_DIR = "/workspace/models/"
DATA_DIR = "/workspace/data/synthetic-data-1.json"
MODEL_PREFIX = [
    'Qwen3-1.7B',
    'Qwen3-4B',
    'Qwen3-8B',
    'Qwen3-14B',
    'Phi-3-mini-4k-instruct' # ideally, we'll later use this to check for generality
][1]

# import benchmarking dataset 
count_df = pd.read_json(DATA_DIR)

# Formatting 
Get test data into appropriate form, including instruct formatting. 

Create a torch dataset to hold all test examples + use dataloader or efficient batched inference during evaluation.

In [149]:
tokenizer, model = load_model(PERSISTENT_MODEL_DIR + MODEL_PREFIX)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [153]:
count_df

Unnamed: 0,sample_ix,category,list_length,clean_list,clean_category_indices,clean_category_count,corrupt_list,corrupt_category_indices,corrupt_category_count
0,0,colors,3,"[red, banana, computer]",[0],1,"[apple, banana, computer]",[],0
1,1,vehicles,3,"[car, bus, banana]","[0, 1]",2,"[apple, bus, banana]",[1],1
2,2,shapes,3,"[square, circle, triangle]","[0, 1, 2]",3,"[square, apple, triangle]","[0, 2]",2
3,3,continents,3,"[computer, asia, apple]",[1],1,"[computer, banana, apple]",[],0
4,4,metals,3,"[banana, gold, silver]","[1, 2]",2,"[banana, apple, silver]",[2],1
...,...,...,...,...,...,...,...,...,...
4984,44,street_signs,7,"[rose, car, yield, speed, apple, stop, pedestr...","[2, 3, 5, 6]",4,"[rose, car, yield, speed, apple, dog, pedestrian]","[2, 3, 6]",3
4985,45,continents,7,"[asia, dog, africa, europe, apple, antarctica,...","[0, 2, 3, 5, 6]",5,"[asia, dog, africa, europe, apple, car, americas]","[0, 2, 3, 6]",4
4986,46,time_units,7,"[second, year, car, day, month, minute, hour]","[0, 1, 3, 4, 5, 6]",6,"[apple, year, car, day, month, minute, hour]","[1, 3, 4, 5, 6]",5
4987,47,elements,7,"[hydrogen, carbon, gold, silver, helium, oxyge...","[0, 1, 2, 3, 4, 5, 6]",7,"[hydrogen, apple, gold, silver, helium, oxygen...","[0, 2, 3, 4, 5, 6]",6


In [154]:
count_df['clean_list_string'] = [' '.join(ele) for ele in count_df['clean_list']]

count_df['prompt'] = count_df.apply(
    lambda row: create_counting_prompt(
        entity_type = row['category'],
        word_list = row['clean_list_string'],
        tokenizer = tokenizer
    ), axis = 1
)

# pprint(count_df['prompt'][0])

In [156]:
# load dataset into a dataloader – this will help handle batching
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, sample_index, tokenized_prompts):
        self.sample_index = sample_index
        self.input_ids = tokenized_prompts['input_ids']
        self.attention_mask = tokenized_prompts['attention_mask']

    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return {
            'sample_index': self.sample_index[idx],
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx]
        }

# we're tokenizing everything here, but the dataloader will handle batching later :) "tokenizing is very cheap"
tokenized_prompts = tokenizer(count_df['prompt'].tolist(), add_special_tokens = False, max_length = 100, padding = 'max_length', truncation = True, return_tensors = 'pt')
print(tokenized_prompts['attention_mask'].sum(dim = 1).max()) # Must be under max length to confirm nothing was truncated (since attention mask applies a 1 to indicate "some guy was here")

count_dl = DataLoader(TextDataset(
    count_df['sample_ix'].tolist(),
    tokenized_prompts # don't move to gpu yet (or will have big mem problems)
), batch_size = 4, shuffle = False)

# pprint(next(iter(count_dl)), width = 80)

tensor(75)


# Model evaluation

In [157]:
k = 1 
results = []
with torch.no_grad():
    for batch in tqdm(count_dl, total = len(count_dl)):
        sample_ix = batch["sample_index"]
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)

        probs = model(input_ids, attention_mask=attn_mask).logits.softmax(-1)

        last_pos = attn_mask.sum(1) - 1
        b_idx = torch.arange(input_ids.size(0), device = device)

        tok_ids = input_ids[b_idx, last_pos]
        tok_probs = probs[b_idx, last_pos, :]

        
        topk_probs, topk_idx = torch.topk(tok_probs, k = k, dim = 1)
        tokens_flat = tokenizer.convert_ids_to_tokens(topk_idx.cpu().flatten().tolist())
        topk_tokens = [tokens_flat[i * k:(i + 1) * k] for i in range(len(tokens_flat) // k)]

        for s_ix, toks, ps in zip(sample_ix.tolist(), topk_tokens, topk_probs.cpu()):
            results.append(
                {
                    "sample_ix": s_ix,
                    "output_token": toks[0], # this code only supports k = 1 for now            
                    "output_prob": ps.tolist()[0]
                }
            )

100%|██████████| 1248/1248 [03:16<00:00,  6.35it/s]


In [161]:
results = pd.DataFrame(results)

# join with the original dataframe to get accuracy 
combined = pd.merge(results, count_df, on = 'sample_ix', how = 'inner')

# identify the # correct 
combined['is_correct'] = np.where(combined['output_token'] == combined['clean_category_count'], 1, 0)

# # filter for rows where is_correct = 1 – oh boy it never gets this right for 1.7
# combined.query("is_correct == 1")
# combined['is_correct'].unique()

# formatting accuracy (is answer returned as int)
combined["guess_int"] = pd.to_numeric(combined["output_token"], errors = "coerce")
combined["is_integer"]  = combined["guess_int"].notna()

# plausible counts (≤ list length)
combined["is_plausible"] = combined["is_integer"] & (
    combined["guess_int"] <= combined["list_length"]
)

In [170]:
##  accuracy metrics 
# todo: want to check accuracy for count + instruction following (does it output what we expect) – generally answer format following seems to be pretty good (even w/ 1.7B)
# todo: also want to check answer plausibility (basically, does the model say "4" when the full list is only 3)
# chance accuracy is expected prob. of a uniform random guess 
# being correct, averaged over dataset (since lists are of varying lengths)
n_rows = len(combined)
n_correct = combined['is_correct'].sum()

# overall accuracy 
acc = combined['is_correct'].mean()

# expected accuracy under a uniform random guess
chance = (1/combined['list_length']).mean()
accuracy_above_chance = acc - chance 

pct_integer = combined["is_integer"].mean()
pct_plausible = combined["is_plausible"].mean()
acc_on_plausible = combined.loc[combined["is_plausible"],
                                   "is_correct"].mean()

# param count 
params = list(model.parameters())
total = sum(p.numel() for p in params)
                                

metrics = pd.Series({
    'model': MODEL_PREFIX,
    'params': total/1e9, # in billions
    'total_rows': n_rows,
    'total_correct': n_correct,
    'accuracy': acc,
    'chance_accuracy': chance,
    'accuracy_above_chance': accuracy_above_chance,
    'pct_integer': pct_integer, # correct format
    'pct_plausible': pct_plausible, # % of rows where guess was ≤ total list length
    'acc_on_plausible': acc_on_plausible
})

metrics

model                    Qwen3-4B
params                   4.022468
total_rows                 497833
total_correct                   0
accuracy                      0.0
chance_accuracy          0.210803
accuracy_above_chance   -0.210803
pct_integer                   1.0
pct_plausible            0.946084
acc_on_plausible              0.0
dtype: object

In [None]:
# write data for test results + agg. metrics 
combined.to_csv(f"/workspace/data/results_{MODEL_PREFIX}.csv", index = False)
metrics.to_csv(f"/workspace/data/agg_metrics_{MODEL_PREFIX}.csv", index = False)

In [143]:
# remove model object from environment after benchmarking is complete 
del model

## Plot 
dims of interest: correct length, total length, accuracy – also want to see how performance changes across model sizes.

1. Model size (params) v. accuracy 
2. Total length, accuracy (check to see how accuracy changes as list length changes)
3. Correct length (x), total length (y) – a heatmap

