In [1]:
# Set up the RunPod environment: uncomment the line below to install dependencies
# !pip install -U bitsandbytes wandb flash_attn accelerate datasets trl transformers tokenizers peft sentencepiece

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch

# Checkpoint under evaluation. Other checkpoints this notebook has been pointed at:
#   modelpath="mistralai/Mistral-7B-v0.1"
#   modelpath="alpindale/Mistral-7B-v0.2-hf"
modelpath = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

# Causal LM with bfloat16 weights; device_map="auto" delegates device placement, and
# attn_implementation="flash_attention_2" needs the flash_attn package (see the setup cell).
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    attn_implementation="flash_attention_2",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# use_fast=False selects the slow (Python) tokenizer implementation instead of the Rust one.
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)

# MMLU from the Hub, "all" config (subjects combined); its "test" split is scored further down.
dataset = load_dataset("cais/mmlu", "all")

In [3]:
# Option labels "A".."Z", indexed by answer position (letters[0] == "A").
letters = [chr(code) for code in range(ord("A"), ord("Z") + 1)]
# Vocabulary id for each of the four answer labels "A".."D", in label order. The last id of
# each encoding is kept. NOTE(review): with SentencePiece tokenizers this is presumably the
# space-prefixed piece (e.g. "▁A") the model would emit after "Answer:" — confirm per tokenizer.
choices_tok = [tokenizer(letter, add_special_tokens=False)["input_ids"][-1] for letter in letters[:4]]

# fshots = "The following are multiple choice questions (with answers) about  abstract algebra.\n\nFind all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer: B\n\nStatement 1 | If aH is an element of a factor group, then |aH| divides |a|. Statement 2 | If H and K are subgroups of G then HK is a subgroup of G.\nA. True, True\nB. False, False\nC. True, False\nD. False, True\nAnswer: B\n\nStatement 1 | Every element of a group generates a cyclic subgroup of the group. Statement 2 | The symmetric group S_10 has 10 elements.\nA. True, True\nB. False, False\nC. True, False\nD. False, True\nAnswer: C\n\nStatement 1| Every function from a finite set onto itself must be one to one. Statement 2 | Every subgroup of an abelian group is abelian.\nA. True, True\nB. False, False\nC. True, False\nD. False, True\nAnswer: A"

def format_mmlu(entry, include_answer = True):
    """Render one MMLU record as a prompt.

    Layout: the question, then one "<letter>. <choice>" line per option, then "Answer:".
    If ``include_answer`` is true, a space and the gold letter are appended
    (e.g. to build few-shot exemplars).

    Args:
        entry: mapping with "question", "choices" (sequence of option texts) and, when
            ``include_answer`` is true, "answer" (integer index into the choices).
        include_answer: whether to append the correct answer letter.

    Returns:
        The formatted prompt string.
    """
    option_lines = []
    for idx, option in enumerate(entry["choices"]):
        option_lines.append(f"{letters[idx]}. {option}")
    joined_options = "\n".join(option_lines)

    prompt = f"{entry['question']}\n{joined_options}\nAnswer:"
    if not include_answer:
        return prompt
    return prompt + f" {letters[entry['answer']]}"

def mmlu_accuracy(entries, model, tokenizer, batch_size = 8):
    """Zero-shot MMLU accuracy from next-token logits restricted to the answer letters.

    Each entry is rendered with ``format_mmlu(entry, include_answer=False)``. The model's
    prediction is the argmax, at each prompt's last real (non-pad) token, of the logits
    over the token ids in the module-level ``choices_tok``.

    Args:
        entries: sequence of MMLU records supporting ``len``-free iteration and integer
            indexing (e.g. a ``datasets.Dataset`` split), each with "question", "choices"
            and an integer "answer".
        model: causal LM whose forward output exposes ``.logits``; inputs are sent to
            ``model.device``.
        tokenizer: tokenizer with a pad token set. Left or right padding both work, because
            the last real token is located from the attention mask.
        batch_size: number of prompts per forward pass.

    Returns:
        ``(total, correct, accuracy_percent)``; ``(0, 0, 0.0)`` when ``entries`` is empty.
    """
    prompts = [format_mmlu(entry, include_answer = False) for entry in entries]
    batches = [prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]

    total, correct = 0, 0
    with tqdm(total=len(batches)) as pbar:
        for batch in batches:
            batch_tok = tokenizer(batch, return_tensors = "pt", padding = True).to(model.device)

            # Position of the last attended token in each row: the largest index whose mask
            # entry is 1. This is independent of tokenizer.padding_side.
            mask = batch_tok["attention_mask"]
            positions = torch.arange(mask.shape[1], device=mask.device)
            last_idx = (positions * mask).amax(dim=1)

            with torch.no_grad():
                logits = model(**batch_tok).logits

            # With device_map the logits may live on a different device than the inputs.
            rows = torch.arange(logits.shape[0], device=logits.device)
            last_logits = logits[rows, last_idx.to(logits.device)]   # (batch, vocab)
            preds = last_logits[:, choices_tok].argmax(dim=-1).tolist()  # index into A..D

            for pred in preds:
                correct += int(pred == entries[total]["answer"])
                total += 1

            pbar.update()
            pbar.set_postfix_str(f"acc={round(correct/total*100,2)}")

    accuracy = correct / total * 100 if total else 0.0
    return total, correct, accuracy

# Batched `padding=True` needs a pad token. Reuse the unk token, falling back to EOS for
# tokenizers that define no unk token (otherwise pad_token would end up None).
tokenizer.pad_token = tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
# Left padding puts each prompt's final token in the last sequence position.
tokenizer.padding_side = "left"

# Cap on evaluated test questions: None scores the whole test split; e.g. 1000 for a quick run.
EVAL_LIMIT = None

test_split = dataset["test"]
if EVAL_LIMIT is not None:
    test_split = test_split.select(range(min(EVAL_LIMIT, len(test_split))))

total, correct, acc = mmlu_accuracy(
    test_split,
    model,
    tokenizer,
    batch_size = 16
)
print("acc", acc)
print("total", total)
print("correct", correct)

100%|██████████| 878/878 [05:11<00:00,  2.82it/s, acc=25.02]

acc 25.01780373166216
total 14042
correct 3513



