In [60]:
# Test whether the models can perform the counting task
# with accuracy above random chance.
import pandas as pd
from pprint import pprint
import importlib
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('..')
from utils.evals import load_model, create_counting_prompt
from tqdm import tqdm

# Choose the compute backend in order of preference: CUDA GPU, then Apple MPS, then CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [64]:
# Directory prefix that a checkpoint name is appended to (note the trailing slash).
PERSISTENT_MODEL_DIR = "/workspace/models/"
# Despite the name, this is the path of the benchmark JSON file itself, not a directory.
DATA_DIR = "/workspace/data/synthetic-data.json"
# Candidate checkpoints. The trailing `[0]` selects ONE entry, so MODEL_PREFIX is a
# single string (currently 'Qwen3-1.7B'), not a list -- change the index to switch
# models, and do not index MODEL_PREFIX again downstream.
MODEL_PREFIX = [
    'Qwen3-1.7B',
    'Qwen3-4B',
    'Qwen3-8B',
    'Qwen3-14B',
    'Phi-3-mini-4k-instruct' # ideally, we'll later use this to check for generality
][0]

# Load the benchmarking dataset into a DataFrame.
count_df = pd.read_json(DATA_DIR)

# Formatting 
Get test data into appropriate form, including instruct formatting. 

Create a torch dataset to hold all test examples, and use a dataloader for efficient batched inference during evaluation.

In [None]:
# MODEL_PREFIX is already a single checkpoint name (the config cell indexes the
# candidate list with `[0]`), so it must not be indexed again here: `MODEL_PREFIX[0]`
# would be just its first character and the load path would point at a
# non-existent "/workspace/models/Q".
tokenizer, model = load_model(PERSISTENT_MODEL_DIR + MODEL_PREFIX)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [65]:
# Flatten each word list into one space-separated string for the prompt template.
count_df['full_list_string'] = count_df['full_list'].map(' '.join)

# Render one prompt per row from its category and its flattened word list.
count_df['prompt'] = [
    create_counting_prompt(
        entity_type = category,
        word_list = words,
        tokenizer = tokenizer
    )
    for category, words in zip(count_df['category'], count_df['full_list_string'])
]

# Uncomment to eyeball the first rendered prompt:
# pprint(count_df['prompt'][0])

In [None]:
# load dataset into a dataloader – this will help handle batching
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """Map-style dataset that pairs each pre-tokenized prompt with its sample id.

    Args:
        sample_index: indexable collection holding one identifier per example.
        tokenized_prompts: mapping exposing 'input_ids' and 'attention_mask',
            each indexable by example position.
    """

    _FIELDS = ('sample_index', 'input_ids', 'attention_mask')

    def __init__(self, sample_index, tokenized_prompts):
        self.sample_index = sample_index
        self.input_ids = tokenized_prompts['input_ids']
        self.attention_mask = tokenized_prompts['attention_mask']

    def __len__(self):
        """Number of examples, taken from the length of `input_ids`."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the id, token ids and attention mask stored at position `idx`."""
        return {field: getattr(self, field)[idx] for field in self._FIELDS}

# Tokenize every prompt once up front (tokenizing is cheap); the DataLoader below
# takes care of batching.
MAX_PROMPT_TOKENS = 100  # every prompt is padded (or truncated) to exactly this length
BATCH_SIZE = 4

# add_special_tokens is off -- presumably because create_counting_prompt already emits
# the chat-template markers (the rendered prompts start with "<|im_start|>"); confirm.
# NOTE(review): the padding side is whatever the tokenizer returned by load_model is
# configured with; any code that looks up "the last real token" must account for it.
tokenized_prompts = tokenizer(
    count_df['prompt'].tolist(),
    add_special_tokens = False,
    max_length = MAX_PROMPT_TOKENS,
    padding = 'max_length',
    truncation = True,
    return_tensors = 'pt'
)

# attention_mask is 1 on real tokens, so its row sums are the un-padded prompt lengths.
# With truncation enabled, a prompt that reaches MAX_PROMPT_TOKENS may have been cut,
# so fail loudly instead of relying on someone reading the printed number.
longest_prompt = tokenized_prompts['attention_mask'].sum(dim = 1).max()
print(longest_prompt)
assert int(longest_prompt) < MAX_PROMPT_TOKENS, (
    f"longest prompt has {int(longest_prompt)} tokens; raise MAX_PROMPT_TOKENS "
    f"(currently {MAX_PROMPT_TOKENS}) so that no prompt is truncated"
)

# NOTE(review): the rendered count_df shows sample_ix values repeating (e.g. row 493
# carries sample_ix 45), so sample_ix alone may not identify a row -- confirm before
# joining results back onto count_df by this key.
count_dl = DataLoader(
    TextDataset(
        count_df['sample_ix'].tolist(),
        tokenized_prompts # don't move to gpu yet (or will have big mem problems)
    ),
    batch_size = BATCH_SIZE,
    shuffle = False # keep dataset order so results line up with count_df rows
)

# pprint(next(iter(count_dl)), width = 80)

tensor(69)


# Model evaluation

In [None]:
# For every prompt, record the model's most likely next token (and its probability)
# at the position of the last real prompt token.
k = 1  # number of candidates pulled per prompt; only the top-1 is recorded below
results = []
with torch.no_grad():
    # NOTE(review): assumes load_model already placed `model` on `device` -- confirm.
    for batch in tqdm(count_dl, total = len(count_dl)):
        sample_ix = batch["sample_index"]
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)

        logits = model(input_ids, attention_mask = attn_mask).logits

        # Position of the last attended token in each row. Found by locating the
        # first 1 in the reversed mask, which is correct for both right- and
        # left-padded batches (mask.sum() - 1 is only right for right padding).
        last_pos = attn_mask.size(1) - 1 - attn_mask.flip(dims = [1]).argmax(dim = 1)
        b_idx = torch.arange(input_ids.size(0), device = device)

        # Pick the one needed position per row BEFORE the softmax, so only a
        # (batch, vocab) slice is normalised rather than every position.
        tok_probs = logits[b_idx, last_pos, :].softmax(-1)

        topk_probs, topk_idx = torch.topk(tok_probs, k = k, dim = 1)
        # One list of k token strings per row.
        topk_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in topk_idx.cpu().tolist()]

        for s_ix, toks, ps in zip(sample_ix.tolist(), topk_tokens, topk_probs.cpu().tolist()):
            results.append(
                {
                    "sample_ix": s_ix,
                    "output_token": toks[0], # this code only supports k = 1 for now
                    "output_prob": ps[0]
                }
            )

100%|██████████| 125/125 [00:11<00:00, 10.67it/s]


In [None]:

# Tabulate the evaluation records: one row per prompt with the keys written by the
# loop above (sample_ix, output_token, output_prob).
# NOTE(review): the saved output below shows columns named `tokens` / `probs`, which
# do not match those keys -- it looks stale; re-run this cell to refresh it.
pd.DataFrame(results)

Unnamed: 0,sample_ix,tokens,probs
0,0,1,0.999864
1,1,0,0.708974
2,2,1,0.999998
3,3,1,0.983762
4,4,2,0.886190
...,...,...,...
493,45,1,0.999995
494,46,1,1.000000
495,47,1,0.990946
496,48,1,0.999947


In [98]:
# Display the benchmark frame, now including the derived `full_list_string` and
# `prompt` columns, for side-by-side inspection with the model outputs.
count_df

Unnamed: 0,sample_ix,category,category_list,category_length,noncategory_list,noncategory_length,full_list,full_list_length,full_list_category_indices,full_list_string,prompt
0,0,colors,[],0,"[cat, dog, table, car]",4,"[dog, table, car, cat]",4,[],dog table car cat,<|im_start|>user\nCount the number of words in...
1,1,planets,[mars],1,"[house, dog, apple]",3,"[house, dog, mars, apple]",4,[2],house dog mars apple,<|im_start|>user\nCount the number of words in...
2,2,vehicles,"[car, bus]",2,"[blue, table, dog]",3,"[dog, blue, bus, table, car]",5,"[2, 4]",dog blue bus table car,<|im_start|>user\nCount the number of words in...
3,3,clothing,"[shirt, pants, hat]",3,"[rock, cat, guitar]",3,"[pants, cat, guitar, hat, rock, shirt]",6,"[0, 3, 5]",pants cat guitar hat rock shirt,<|im_start|>user\nCount the number of words in...
4,4,instruments,"[piano, guitar, drums, violin]",4,"[apple, river, dog]",3,"[violin, piano, apple, dog, guitar, drums, river]",7,"[0, 1, 4, 5]",violin piano apple dog guitar drums river,<|im_start|>user\nCount the number of words in...
...,...,...,...,...,...,...,...,...,...,...,...
493,45,natural_disasters,"[earthquake, tsunami]",2,"[apple, cat, table, rose, violin]",5,"[cat, violin, table, tsunami, apple, earthquak...",7,"[3, 5]",cat violin table tsunami apple earthquake rose,<|im_start|>user\nCount the number of words in...
494,46,yoga_poses,[downwarddog],1,"[apple, car, table, rose, violin]",5,"[downwarddog, apple, table, violin, car, rose]",6,[0],downwarddog apple table violin car rose,<|im_start|>user\nCount the number of words in...
495,47,cloud_types,[],0,"[cirrus, cumulus, stratus, nimbus]",4,"[cirrus, stratus, cumulus, nimbus]",4,[],cirrus stratus cumulus nimbus,<|im_start|>user\nCount the number of words in...
496,48,illnesses,"[flu, cold, measles]",3,"[apple, car, rose]",3,"[measles, car, apple, flu, rose, cold]",6,"[0, 3, 5]",measles car apple flu rose cold,<|im_start|>user\nCount the number of words in...


In [None]:
# Chance accuracy = the expected probability that a uniform random guess is correct,
# averaged over the dataset (lists vary in length, so the chance level differs per example).


# TODO: check accuracy for both the count itself and instruction following (does the model output what we expect?)

In [None]:
# plotting: performance across model family