In [127]:
# testing to see if models are capable of counting task 
# w/ performance > random chance
import pandas as pd
import numpy as np
from pprint import pprint
import importlib
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('..')
from utils.evals import load_model, create_counting_prompt
from tqdm import tqdm

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [64]:
PERSISTENT_MODEL_DIR = "/workspace/models/"
DATA_DIR = "/workspace/data/synthetic-data.json"
MODEL_PREFIX = [
    'Qwen3-1.7B',
    'Qwen3-4B',
    'Qwen3-8B',
    'Qwen3-14B',
    'Phi-3-mini-4k-instruct' # ideally, we'll later use this to check for generality
][0]

# import benchmarking dataset 
count_df = pd.read_json(DATA_DIR)

# Formatting 
Get test data into appropriate form, including instruct formatting. 

Create a torch dataset to hold all test examples + use dataloader or efficient batched inference during evaluation.

In [None]:
tokenizer, model = load_model(PERSISTENT_MODEL_DIR + MODEL_PREFIX[0])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [65]:
count_df['full_list_string'] = [' '.join(ele) for ele in count_df['full_list']]

count_df['prompt'] = count_df.apply(
    lambda row: create_counting_prompt(
        entity_type = row['category'],
        word_list = row['full_list_string'],
        tokenizer = tokenizer
    ), axis = 1
)

# pprint(count_df['prompt'][0])

In [None]:
# load dataset into a dataloader – this will help handle batching
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, sample_index, tokenized_prompts):
        self.sample_index = sample_index
        self.input_ids = tokenized_prompts['input_ids']
        self.attention_mask = tokenized_prompts['attention_mask']

    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return {
            'sample_index': self.sample_index[idx],
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx]
        }

# we're tokenizing everything here, but the dataloader will handle batching later :) "tokenizing is very cheap"
tokenized_prompts = tokenizer(count_df['prompt'].tolist(), add_special_tokens = False, max_length = 100, padding = 'max_length', truncation = True, return_tensors = 'pt')
print(tokenized_prompts['attention_mask'].sum(dim = 1).max()) # Must be under max length to confirm nothing was truncated (since attention mask applies a 1 to indicate "some guy was here")

count_dl = DataLoader(TextDataset(
    count_df['sample_ix'].tolist(),
    tokenized_prompts # don't move to gpu yet (or will have big mem problems)
), batch_size = 4, shuffle = False)

# pprint(next(iter(count_dl)), width = 80)

tensor(69)


# Model evaluation

In [116]:
k = 1 
results = []
with torch.no_grad():
    for batch in tqdm(count_dl, total = len(count_dl)):
        sample_ix = batch["sample_index"]
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)

        probs = model(input_ids, attention_mask=attn_mask).logits.softmax(-1)

        last_pos = attn_mask.sum(1) - 1
        b_idx = torch.arange(input_ids.size(0), device = device)

        tok_ids = input_ids[b_idx, last_pos]
        tok_probs = probs[b_idx, last_pos, :]

        
        topk_probs, topk_idx = torch.topk(tok_probs, k = k, dim = 1)
        tokens_flat = tokenizer.convert_ids_to_tokens(topk_idx.cpu().flatten().tolist())
        topk_tokens = [tokens_flat[i * k:(i + 1) * k] for i in range(len(tokens_flat) // k)]

        for s_ix, toks, ps in zip(sample_ix.tolist(), topk_tokens, topk_probs.cpu()):
            results.append(
                {
                    "sample_ix": s_ix,
                    "output_token": toks[0], # this code only supports k = 1 for now            
                    "output_prob": ps.tolist()[0]
                }
            )

100%|██████████| 125/125 [00:11<00:00, 10.70it/s]


In [None]:
results = pd.DataFrame(results)

# join with the original dataframe to get accuracy 
combined = pd.merge(results, count_df, on = 'sample_ix', how = 'inner')

# identify the # correct 
combined['is_correct'] = np.where(combined['output_token'] == combined['category_length'], 1, 0)

# # filter for rows where is_correct = 1 – oh boy it never gets this right for 1.7
# combined.query("is_correct == 1")
# combined['is_correct'].unique()

In [None]:
##  accuracy metrics 
# todo: want to check accuracy for count + instruction following (does it output what we expect) – generally answer format following seems to be pretty good (even w/ 1.7B)
# todo: also want to check answer plausibility (basically, does the model say "4" when the full list is only 3)
# chance accuracy is expected prob. of a uniform random guess 
# being correct, averaged over dataset (since lists are of varying lengths)

# overall accuracy 
acc = combined['is_correct'].mean()

# check what % are returning integers 

# expected accuracy under a uniform random guess
chance = (1/combined['full_list_length']).mean()

print(f"Accuracy : {acc:.3%}  ({df['is_correct'].sum()}/{len(df)})")
print(f"Chance accuracy: {chance:.3%}")
print(f"Accuracy lift  : {acc - chance:+.3%}")




0.17511708028636683


## Plot 
1. Model size (params) v. accuracy 
2. Total length, accuracy (check to see how accuracy changes as list length changes)
3. Correct length (x), total length (y) – a heatmap

correct length, total length, accuracy 