In [1]:
# Imports
import torch
from utils.memory import check_memory, profile_memory, clear_all_cuda_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm 

main_device = 'cuda:0'
seed = 1234
clear_all_cuda_memory()
check_memory()

All CUDA memory cleared on all devices.
Device 0: NVIDIA A40
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 44.45 GB



## Load Model

In [3]:
hf_model_id = 'Qwen/Qwen1.5-MoE-A2.7B-Chat'

tokenizer = AutoTokenizer.from_pretrained(hf_model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left')
model = AutoModelForCausalLM.from_pretrained(hf_model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda()

check_memory()

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Device 0: NVIDIA A40
  Allocated: 31.11 GB
  Reserved: 37.00 GB
  Total: 44.45 GB



In [4]:
# Test forward pass

# No need for hooks! https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
inputs = tokenizer(['Hi this is dog', 'Where is the beef'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_router_logits = True)

all_topk_experts = ()
for l, layer_router_logits in enumerate(outputs.router_logits):
    # layer_router_logits is shape [B*N, num_experts]
    gating_probs = torch.softmax(layer_router_logits, dim = -1)
    _, topk_experts = torch.topk(gating_probs, k = model.config.num_experts_per_tok, dim = -1)
    all_topk_experts += (topk_experts,) 

all_topk_experts[0]

tensor([[ 5,  9, 47, 54],
        [ 5,  9, 47, 54],
        [ 5,  9, 47, 54],
        ...,
        [50,  4, 41, 30],
        [40, 37, 11, 35],
        [ 8, 46, 42,  5]], device='cuda:0')

## Get MMLU

In [5]:
""""
Get MMLU data and domains to test
"""
from datasets import load_dataset

mmlu_ds = load_dataset("cais/mmlu", 'all', split = 'test')
print(mmlu_ds[0])

# Only retain domains for high school subjects
domains_to_keep = ['college_biology', 'college_medicine', 'college_computer_science', 'college_mathematics']
domains = [x for x in list(set([x['subject'] for x in mmlu_ds])) if x in domains_to_keep]
print(domains)

# Now let's put the MMLU questions into a list gruoped by domain
all_domain_questions = [
    {
        'domain': domain,
        'questions': 
            [
                {'question': q['question'], 'choices': q['choices'], 'answer_index': q['answer'], 'answer_char': chr(65 + q['answer'])}
                for q in mmlu_ds
                if q['subject'] == domain 
            ]
    }
    for domain in tqdm(domains)
]

all_domain_questions[0]

{'question': 'Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.', 'subject': 'abstract_algebra', 'choices': ['0', '4', '2', '6'], 'answer': 1}
['college_biology', 'college_computer_science', 'college_mathematics', 'college_medicine']


100%|██████████| 4/4 [00:02<00:00,  1.74it/s]


{'domain': 'college_biology',
 'questions': [{'question': 'Based on the characteristic population curves that result from plotting population growth of a species, the most effective means of controlling the mosquito population is to',
   'choices': ['maintain the population at a point corresponding to the midpoint of its logistic curve',
    'opt for zero population control once the K value of the curve has been reached',
    'reduce the carrying capacity cif the environment to lower the K value',
    'increase the mortality rate'],
   'answer_index': 2,
   'answer_char': 'C'},
  {'question': 'A frameshift mutation is created when',
   'choices': ['telomeric sequences are removed from DNA',
    "a codon's nucleotide sequence changes so that it calls for production of a different amino acid than the original one",
    'a base pair is either inserted or deleted in a gene',
    "a codon's nucleotide sequence is changed so that instead of coding for a given amino acid it acts to terminate 

In [6]:
"""
Create function to map MMLU data into questions
"""
def prep_question(question, choices):
    
    prompt = f"Question: {question}\nChoices:\n"
    
    for i, option in enumerate(choices):
        letter = chr(65 + i)
        prompt += f"({letter}) {option}\n"

    return prompt

# print(prep_question(mmlu_ds[0]['question'], mmlu_ds[0]['choices']))

print(tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'},
        {'role': 'user', 'content': prep_question(mmlu_ds[0]['question'], mmlu_ds[0]['choices'])},
        {'role': 'assistant', 'content': 'The correct answer is ' + chr(65 + mmlu_ds[0]['answer'])}
    ],
    tokenize = False, add_generation_prompt = False, continue_final_message = True
))

<|im_start|>system
You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.<|im_end|>
<|im_start|>user
Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
Choices:
(A) 0
(B) 4
(C) 2
(D) 6
<|im_end|>
<|im_start|>assistant
The correct answer is B


In [7]:
import pandas as pd
from utils.store_topk import convert_topk_to_df

cat_results = []

for this_domain in all_domain_questions:

    domain_questions = this_domain['questions']

    count_correct = 0
    count_incorrect = 0
    topk_dfs = []
    results = []
    for question_ix, q in tqdm(enumerate(domain_questions[3:203])):

        input_prompt = tokenizer.apply_chat_template(
            [
                {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'},
                {'role': 'user', 'content': prep_question(domain_questions[0]['question'], domain_questions[0]['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is ' + domain_questions[0]['answer_char']},
                {'role': 'user', 'content': prep_question(domain_questions[1]['question'], domain_questions[1]['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is ' + domain_questions[1]['answer_char']},
                {'role': 'user', 'content': prep_question(domain_questions[2]['question'], domain_questions[2]['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is ' + domain_questions[2]['answer_char']},
                {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is'},
            ],
            tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
        )
        inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        with torch.no_grad():
            outputs = model(input_ids = input_ids, attention_mask = attention_mask, output_router_logits = True)
            all_topk_experts = ()
            for l, layer_router_logits in enumerate(outputs.router_logits):
                # layer_router_logits is shape [B*N, num_experts]
                gating_probs = torch.softmax(layer_router_logits, dim = -1)
                _, topk_experts = torch.topk(gating_probs, k = model.config.num_experts_per_tok, dim = -1)
                all_topk_experts += (topk_experts,) 

            topk_df = convert_topk_to_df(all_topk_experts, input_ids).assign(domain = this_domain['domain'], question_ix = question_ix).drop(columns = 'sequence_ix')
            topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id] # Filter out rows with attention_mask
            topk_dfs.append(topk_df)

            next_token_logits = outputs['logits'][0, -1, :]
            next_token_id = torch.argmax(next_token_logits).item()
            predicted_text = tokenizer.decode([next_token_id]).strip()

        predicted_letter = None
        for c in predicted_text:
            if c.upper() in [chr(65 + i) for i in range(len(q['choices']))]:
                predicted_letter = c.upper()
                break

        result = {
            'domain': this_domain['domain'],
            'question_ix': question_ix, 
            'model_output': predicted_text,
            'model_choice': predicted_letter,
            'correct_choice': q['answer_char'],
            'is_correct': 1 if predicted_letter == q['answer_char'] else 0
        }

        if result['is_correct'] == 1:
            count_correct += 1
        else:
            count_incorrect += 1

        results.append(result)
    

    cat_results.append({
        'answer_df': pd.DataFrame(results),
        'topks_df': pd.concat(topk_dfs)
    })

    print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {(count_correct / (count_correct + count_incorrect)) * 100:.1f}%')

141it [00:52,  2.68it/s]


college_biology | Correct: 95 | Incorrect: 46 | Accuracy: 67.4%


97it [00:38,  2.49it/s]


college_computer_science | Correct: 44 | Incorrect: 53 | Accuracy: 45.4%


97it [00:36,  2.63it/s]


college_mathematics | Correct: 33 | Incorrect: 64 | Accuracy: 34.0%


170it [01:05,  2.59it/s]

college_medicine | Correct: 98 | Incorrect: 72 | Accuracy: 57.6%





In [8]:
vocab_map =\
    pd.DataFrame([{"token": token.replace('Ġ', ' '), "token_id": token_id} for token, token_id in tokenizer.get_vocab().items()])\
    .sort_values(by = 'token_id')\
    .reset_index()

display(vocab_map)

all_answers = pd.concat([cat['answer_df'] for cat in cat_results])
display(all_answers)

all_topks = pd.concat([cat['topks_df'] for cat in cat_results]).merge(vocab_map, how = 'left', on = 'token_id')
display(all_topks)

Unnamed: 0,index,token,token_id
0,120481,!,0
1,137205,"""",1
2,123677,#,2
3,148195,$,3
4,118638,%,4
...,...,...,...
151641,72624,âºŁ,151641
151642,86374,â½Ĺ,151642
151643,139398,<|endoftext|>,151643
151644,37768,<|im_start|>,151644


Unnamed: 0,domain,question_ix,model_output,model_choice,correct_choice,is_correct
0,college_biology,0,B,B,B,1
1,college_biology,1,D,D,A,0
2,college_biology,2,A,A,A,1
3,college_biology,3,B,B,B,1
4,college_biology,4,B,B,B,1
...,...,...,...,...,...,...
165,college_medicine,165,A,A,A,1
166,college_medicine,166,C,C,A,0
167,college_medicine,167,C,C,D,0
168,college_medicine,168,C,C,C,1


Unnamed: 0,token_ix,token_id,layer_ix,expert_1,expert_2,expert_3,expert_4,domain,question_ix,index,token
0,0,151644,0,33,9,6,45,college_biology,0,37768,<|im_start|>
1,1,8948,0,33,14,16,51,college_biology,0,55191,system
2,2,198,0,5,28,43,38,college_biology,0,49933,Ċ
3,3,2610,0,5,6,58,45,college_biology,0,14049,You
4,4,686,0,27,11,47,4,college_biology,0,27363,will
...,...,...,...,...,...,...,...,...,...,...,...
5959003,418,198,23,21,5,58,49,college_medicine,169,49933,Ċ
5959004,419,785,23,57,32,9,13,college_medicine,169,56819,The
5959005,420,4396,23,36,22,32,1,college_medicine,169,104232,correct
5959006,421,4226,23,24,1,57,25,college_medicine,169,11644,answer


In [None]:
model.config

In [9]:
all_answers.to_csv('qwen_all_answers.csv', index = False)
all_topks.to_csv('qwen_all_topks.csv', index = False)

In [None]:
# from datasets import load_dataset

# mmlu_ds = load_dataset("TIGER-Lab/MMLU-Pro", split = 'test')