In [None]:
# Imports
import torch
from utils.memory import check_memory, profile_memory, clear_all_cuda_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm 

# Device that the model and every tokenized input batch are moved to.
main_device = 'cuda:0'

# Seed for PyTorch's random number generators. It is applied right away so that any
# stochastic op further down the notebook is repeatable after Restart Kernel -> Run All.
seed = 1234
torch.manual_seed(seed)

# Call the project's memory helper before any model is loaded, as a baseline for the
# check_memory() call that follows model loading.
check_memory()

## Load Model

In [None]:
# Hugging Face Hub id of the MoE checkpoint under study.
hf_model_id = 'deepseek-ai/DeepSeek-V2-Lite'

# BOS/EOS are not added automatically by the tokenizer; batches are padded on the left.
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
)

# trust_remote_code = True lets transformers load the custom modeling code shipped in the Hub repo.
# The model is moved to `main_device` (rather than a bare .cuda()) so that the model and the
# inputs, which are also sent to `main_device`, are always controlled by the same setting.
model = AutoModelForCausalLM.from_pretrained(
    hf_model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True,
).to(main_device)

check_memory()

In [None]:
# Smoke-test inputs for the forward pass

# Expert routing has to be captured with hooks for this model; the gating code lives at
# https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py
sample_prompts = ['Hi this is dog', 'Where is the beef']
inputs = tokenizer(
    sample_prompts,
    return_tensors = 'pt',
    padding = 'max_length',
    truncation = True,
    max_length = 512,
)
inputs = inputs.to(model.device)

def attach_moe_gate_hooks(model):
    """
    Register a forward hook on every MoE gating module in 'model' so that, after a forward pass,
    the 'topk_idx' selected by each gate can be inspected.

    A decoder layer is treated as MoE when its 'mlp' has a 'gate' attribute. Layers without one
    (dense layers such as layer 0) are skipped instead of being assumed from their position.

    Args:
        model: Model that exposes its decoder layers as 'model.model.layers'.

    Returns:
        all_expert_ids: A list, empty on return, that is filled at runtime: every call of a hooked
            gate appends its detached 'topk_idx' tensor (expected to be BN x topk), in call order.
            Entries are bare tensors, not (layer_index, tensor) pairs.
        handles: A dictionary of {layer_index: hook_handle}; call .remove() on each handle once the
            forward pass is done, otherwise the hooks keep appending on later forward passes.

    Raises:
        TypeError: At forward time, from inside the hook, if a gate returns something other than
            a tuple/list.
    """
    all_expert_ids = []
    handles = {}

    def gate_forward_hook(module, inputs, output):
        """
        Triggered after the gate's forward(...).
        'output' should be a tuple whose first element is topk_idx, e.g.
        (topk_idx, topk_weight, aux_loss). Only topk_idx is kept.
        """
        if not isinstance(output, (tuple, list)):
            raise TypeError(
                f"Expected the MoE gate to return a tuple starting with topk_idx, got {type(output).__name__}"
            )
        # Index instead of unpacking a fixed number of values, so gates that return a
        # different number of trailing items are still supported.
        all_expert_ids.append(output[0].detach())

    for layer_idx, layer in enumerate(model.model.layers):
        gate = getattr(layer.mlp, 'gate', None)
        if gate is None:
            # Dense (non-MoE) layer: nothing to hook
            continue
        # attach an attribute so we know which layer this gating belongs to
        gate._layer_id = layer_idx
        handles[layer_idx] = gate.register_forward_hook(gate_forward_hook)

    return all_expert_ids, handles


# Capture the routing decisions of a single forward pass.
all_expert_ids, hook_handles = attach_moe_gate_hooks(model)

try:
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    # Always detach the hooks, even if the forward pass raises (e.g. out of memory); otherwise
    # they stay registered on the model and keep appending tensors on every later forward pass.
    for h in hook_handles.values():
        h.remove()

for topk_idx_tensor in all_expert_ids:
    print(f"topk_idx shape = {topk_idx_tensor.shape}")
    # e.g. shape is [B*N, top_k]

In [None]:
# Test forward pass

# NOTE: the "no hooks needed" shortcut applies to OLMoE
# (https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py),
# not to this checkpoint: the expert ids below are still collected through the gate hooks.
inputs = tokenizer(['Hi this is dog', 'Where is the beef'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
try:
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    # Remove the hooks even when the forward pass fails, so they never outlive this cell
    for h in hook_handles.values():
        h.remove()

# Show the first recorded gate output
all_topk_experts[0]

## Get MMLU

In [None]:
""""
Get MMLU data and domains to test
"""
from datasets import load_dataset

mmlu_ds = load_dataset("cais/mmlu", 'all', split = 'test')
print(mmlu_ds[0])

# Group the high school questions by subject in ONE pass over the dataset (instead of rescanning
# the whole dataset once per domain). Questions keep their dataset order within each domain.
questions_by_domain = {}
for q in tqdm(mmlu_ds):
    # Only retain domains for high school subjects
    if 'high_school_' not in q['subject']:
        continue
    questions_by_domain.setdefault(q['subject'], []).append({
        'question': q['question'],
        'choices': q['choices'],
        'answer_index': q['answer'],
        'answer_char': chr(65 + q['answer']),  # 0 -> 'A', 1 -> 'B', ...
    })

# Sorted so the domain order is identical on every kernel restart. Iterating a set of strings
# gives a run-dependent order, which would silently change which domains a later slice of
# all_domain_questions picks up.
domains = sorted(questions_by_domain)
print(domains)

# Now let's put the MMLU questions into a list grouped by domain
all_domain_questions = [
    {'domain': domain, 'questions': questions_by_domain[domain]}
    for domain in domains
]

all_domain_questions[0]

In [None]:
"""
Create function to map MMLU data into questions
"""
def prep_question(question, choices):
    """
    Render one multiple-choice question as prompt text.

    Args:
        question: The question text.
        choices: Iterable of answer options; they are lettered (A), (B), ... in order.

    Returns:
        A string made of a "Question: ..." line, a "Choices:" line, and one "(X) option" line
        per choice, each line terminated by a newline.
    """
    header = f"Question: {question}\nChoices:\n"
    option_lines = [f"({chr(ord('A') + idx)}) {text}\n" for idx, text in enumerate(choices)]
    return header + "".join(option_lines)

# print(prep_question(mmlu_ds[0]['question'], mmlu_ds[0]['choices']))

# Preview how one full system / user / assistant exchange is rendered by the chat template,
# using the first MMLU row together with its gold answer letter.
first_row = mmlu_ds[0]
preview_messages = [
    {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'},
    {'role': 'user', 'content': prep_question(first_row['question'], first_row['choices'])},
    {'role': 'assistant', 'content': 'The correct answer is ' + chr(65 + first_row['answer'])},
]
preview_text = tokenizer.apply_chat_template(
    preview_messages,
    tokenize = False,
    add_generation_prompt = False,
    continue_final_message = True,
)
print(preview_text)

In [None]:
import pandas as pd
from utils.store_topk import convert_topk_to_df

# Evaluation settings
N_DOMAINS = 3    # how many domains, from the front of all_domain_questions, are evaluated
N_FEW_SHOT = 3   # leading questions of each domain shown as solved examples in every prompt
N_EVAL = 100     # questions scored per domain, taken right after the few-shot examples

SYSTEM_PROMPT = 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'

# One entry per evaluated domain: {'answer_df': per-question results, 'topks_df': per-token routing}
cat_results = []

for this_domain in all_domain_questions[0:N_DOMAINS]:

    domain_questions = this_domain['questions']

    # The system prompt and the few-shot turns are the same for every question of this domain,
    # so they are built once here instead of once per question.
    few_shot_messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for example in domain_questions[0:N_FEW_SHOT]:
        few_shot_messages.append({'role': 'user', 'content': prep_question(example['question'], example['choices'])})
        few_shot_messages.append({'role': 'assistant', 'content': 'The correct answer is ' + example['answer_char']})

    # Questions used as few-shot examples are excluded from scoring
    eval_questions = domain_questions[N_FEW_SHOT:N_FEW_SHOT + N_EVAL]

    count_correct = 0
    count_incorrect = 0
    topk_dfs = []
    results = []
    # tqdm wraps the list (not the enumerate object) so that it knows the total and can show progress
    for question_ix, q in enumerate(tqdm(eval_questions, desc = this_domain['domain'])):

        input_prompt = tokenizer.apply_chat_template(
            few_shot_messages + [
                {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is'},
            ],
            tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
        )
        inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        with torch.no_grad():
            all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
            try:
                outputs = model(input_ids = input_ids, attention_mask = attention_mask)
            finally:
                # Detach the hooks even if the forward pass raises, so they cannot pile up on the model
                for h in hook_handles.values():
                    h.remove()

            topk_df = convert_topk_to_df(all_topk_experts, input_ids).assign(domain = this_domain['domain'], question_ix = question_ix).drop(columns = 'sequence_ix')
            # Drop rows whose token id is the pad token id.
            # NOTE(review): each prompt is tokenized on its own, so no padding is added here; this
            # filter can only remove real tokens whose id happens to equal pad_token_id - confirm
            # that this is intended.
            topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id]
            topk_dfs.append(topk_df)

            # Greedy prediction for the token that follows "The correct answer is"
            next_token_logits = outputs['logits'][0, -1, :]
            next_token_id = torch.argmax(next_token_logits).item()
            predicted_text = tokenizer.decode([next_token_id]).strip()

        # The first character of the decoded token that is a valid choice letter (case-insensitive) wins
        valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
        predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

        result = {
            'domain': this_domain['domain'],
            'question_ix': question_ix, 
            'model_output': predicted_text,
            'model_choice': predicted_letter,
            'correct_choice': q['answer_char'],
            'is_correct': 1 if predicted_letter == q['answer_char'] else 0
        }

        if result['is_correct'] == 1:
            count_correct += 1
        else:
            count_incorrect += 1

        results.append(result)

    if not results:
        # Nothing was scored: avoid pd.concat([]) and the division by zero in the accuracy below
        print(f'{this_domain["domain"]} | No questions left to evaluate after the few-shot examples, skipping')
        continue

    cat_results.append({
        'answer_df': pd.DataFrame(results),
        'topks_df': pd.concat(topk_dfs)
    })

    print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {(count_correct / (count_correct + count_incorrect)) * 100:.1f}%')

In [None]:
# Lookup table token_id -> token text, with the byte-level BPE space marker 'Ġ' shown as a space.
# drop = True discards the pre-sort index; without it reset_index() adds an 'index' column
# that would be carried through the merge below into all_topks and its CSV export.
vocab_map = (
    pd.DataFrame(
        [{"token": token.replace('Ġ', ' '), "token_id": token_id} for token, token_id in tokenizer.get_vocab().items()]
    )
    .sort_values(by = 'token_id')
    .reset_index(drop = True)
)

display(vocab_map)

# Stack the per-domain results; ignore_index avoids repeated row labels across domains
all_answers = pd.concat([cat['answer_df'] for cat in cat_results], ignore_index = True)
display(all_answers)

# Left merge keeps every routing row, even if its token_id is missing from vocab_map
all_topks = pd.concat([cat['topks_df'] for cat in cat_results], ignore_index = True).merge(vocab_map, how = 'left', on = 'token_id')
display(all_topks)

In [None]:
# Write both result tables to CSV files in the working directory, without the DataFrame index
for frame, csv_path in ((all_answers, 'all_answers.csv'), (all_topks, 'all_topks.csv')):
    frame.to_csv(csv_path, index = False)

In [None]:
# from datasets import load_dataset

# mmlu_ds = load_dataset("TIGER-Lab/MMLU-Pro", split = 'test')