# Model

In [1]:
import torch
import transformers
print("torch", torch.__version__)
print("transformers", transformers.__version__)

torch 2.1.2+cu121
transformers 4.39.2


In [2]:
args_seed = 14

model_id = "databricks/dbrx-instruct"

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


TiktokenTokenizerWrapper(name_or_path='databricks/dbrx-instruct', vocab_size=100277, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|pad|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	100257: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100277: AddedToken("<|pad|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100278: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100279: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [4]:
from datasets import Dataset
# sharegpt format
test_dataset = Dataset.from_parquet("dbrx_instruct_pruning_test_dataset.parquet")
test_dataset

Dataset({
    features: ['conversations'],
    num_rows: 2010
})

## chat_template

In [5]:
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200

def preprocess_chat_template(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False):
    conv_roles = ["system", "user", "assistant"]
    roles = {"system": conv_roles[0], "human": conv_roles[1], "gpt": conv_roles[2]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # assert role == conv_roles[j % 2], f"{i}"
            messages.append({"role": role, "content": sentence["value"]})
        conversations.append(tokenizer.apply_chat_template(messages, tokenize=False, add_special_tokens=False))

    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            add_special_tokens=False
        ).input_ids
    targets = input_ids.clone()

    # Mask targets
    sep = "<|im_start|>assistant\n"
    for conversation, target in zip(conversations, targets):
        conv_sep = "<|im_end|>\n"
        rounds = conversation.split(conv_sep)
        cur_len = 0
        instruction_len = 0
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            if sep not in rou:
                if has_image:
                    instruction_len = len(tokenizer_image_token(rou + conv_sep, tokenizer))
                else:
                    instruction_len = len(tokenizer(rou + conv_sep, add_special_tokens=False).input_ids)
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += instruction_len
    return dict(
        input_ids=input_ids,
        labels=targets,
    )

In [9]:
sources = test_dataset[1]["conversations"]

item = preprocess_chat_template([sources], tokenizer)

print("input:", tokenizer.decode(item['input_ids'][item['input_ids'].ne(IMAGE_TOKEN_INDEX)]))
print()
print("label", tokenizer.decode(item['labels'][item['labels'].ne(IGNORE_INDEX | IMAGE_TOKEN_INDEX)]))

input: <|im_start|> system
You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.
YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.
You assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).
(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)
This is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.
YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSEL

# Pruning

In [4]:
import torch
import transformers
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnb
print("torch", torch.__version__)
print("transformers", transformers.__version__)

device_map = "auto"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
    device_map="cpu",
)
model

torch 2.1.2+cu121
transformers 4.39.2


Loading checkpoint shards:   0%|          | 0/61 [00:00<?, ?it/s]

DbrxForCausalLM(
  (transformer): DbrxModel(
    (wte): Embedding(100352, 6144)
    (blocks): ModuleList(
      (0-39): 40 x DbrxBlock(
        (norm_attn_norm): DbrxNormAttentionNorm(
          (norm_1): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
          (attn): DbrxFlashAttention2(
            (Wqkv): Linear(in_features=6144, out_features=8192, bias=False)
            (out_proj): Linear(in_features=6144, out_features=6144, bias=False)
            (rotary_emb): DbrxRotaryEmbedding()
          )
          (norm_2): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
        )
        (ffn): DbrxFFN(
          (router): DbrxRouter(
            (layer): Linear(in_features=6144, out_features=16, bias=False)
          )
          (experts): DbrxExperts(
            (mlp): DbrxExpertGLU()
          )
        )
      )
    )
    (norm_f): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=6144, out_features=100352, bias=False)
)

In [5]:
import torch
from torch import nn
from typing import Any, Callable, Dict, Optional, Tuple, Union

class DbrxFFNPruningWrapper(nn.Module):
    def __init__(self, model, r = None):
        super().__init__()
        self.model = model
        self.top_experts_list = []

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        weights, top_weights, top_experts = self.model.router(x)
        out = self.model.experts(x, weights, top_weights, top_experts)
        self.top_experts_list.append(top_experts.detach().to('cpu', non_blocking=True))
        return out, weights

In [6]:
for block in model.transformer.blocks:
    block.ffn = DbrxFFNPruningWrapper(block.ffn)

In [None]:
from tqdm.auto import tqdm

with torch.inference_mode():
    hidden_states = {}
    model.transformer.wte = model.transformer.wte.to("cuda")
    for i, sources in tqdm(enumerate(test_dataset["conversations"])):
        item = preprocess_chat_template([sources], tokenizer)
        inputs_embeds = model.transformer.wte.forward(item['input_ids'].to("cuda"))

        # hidden_states = inputs_embeds
        hidden_states[i] = inputs_embeds.detach().to('cpu', non_blocking=True)

In [None]:
import gc
import pandas as pd

def release_list(a):
   del a[:]
   del a

layer_dict = []
with torch.inference_mode():
    for block in tqdm(model.transformer.blocks):
        block = block.to("cuda")
        for i, sources in tqdm(enumerate(test_dataset["conversations"])):
            past_seen_tokens = 0
            block_outputs = block(
                hidden_states[i].to("cuda"),
                torch.arange(  # type: ignore
                        past_seen_tokens,
                        past_seen_tokens + hidden_states[i].shape[1],
                        device="cuda").unsqueeze(0).to("cuda")
            )
            hidden_states[i] = block_outputs[0].detach().to('cpu', non_blocking=True)

        layer_nums = []
        for x in block.ffn.top_experts_list[0]:
            for v in x.tolist():
                layer_nums.append(int(v))
        layer_nums = pd.Series(layer_nums)
        layer_dict.append(layer_nums.value_counts().reset_index())
        release_list(block.ffn.top_experts_list)

        block.to("meta")
        gc.collect()
        torch.cuda.empty_cache()

In [18]:
torch.save(layer_dict, "dbrx_instruct_pruning_layer_stats.pt")

# Model pruning

In [4]:
dbrx_ru_pruning_layer_stats = torch.load("dbrx_instruct_pruning_layer_stats.pt")
dbrx_ru_pruning_layer_stats[39]

Unnamed: 0,index,count
0,7,437
1,15,395
2,0,308
3,8,293
4,4,248
5,13,243
6,5,183
7,3,177
8,14,173
9,11,165


In [13]:
import json
from safetensors import safe_open
from safetensors.torch import save_file
from pathlib import Path
from huggingface_hub import hf_hub_download

model_dir = hf_hub_download(repo_id=model_id, filename="modeling_dbrx.py")
model_dir = Path(model_dir).parent

In [124]:
# https://huggingface.co/databricks/dbrx-instruct/discussions/10

NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
FFN_HIDDEN_SIZE = 10752

TARGET_NUM_EXPERTS = 8

def change_tensor(tensor, reverse=False):
    output = [x.contiguous() if not reverse else x.t().contiguous() for x in tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)]
    return output

def change_mlp(tensors):
    keys = list(tensors.keys())
    for k in keys:
        if "router" in k:
            block_idx = int(k.rsplit('.')[2])
            name = f"transformer.blocks.{block_idx}.ffn.router.layer.weight"
            experts_to_reserve = sorted(list(dbrx_ru_pruning_layer_stats[block_idx][:TARGET_NUM_EXPERTS]["index"].values))
            tensor = tensors.pop(k)
            t = tensor[experts_to_reserve]
            tensors[name] = t
        if any([x in k for x in ['w1', 'v1', 'w2']]):
            prefix,dtype = k.rsplit('.', 1)

            block_idx = int(prefix.rsplit('.')[2])
            experts_to_reserve = sorted(list(dbrx_ru_pruning_layer_stats[block_idx][:TARGET_NUM_EXPERTS]["index"].values))

            tensor = tensors.pop(k)
            output_tensor = change_tensor(tensor, dtype=='w2')
            for i in range(TARGET_NUM_EXPERTS):
                t = output_tensor[experts_to_reserve[i]]
                name = f'{prefix}.{i}.{dtype}.weight'
                tensors[name] = t
    return tensors

In [None]:
output_dir = Path("./dbrx-instruct-8X").absolute()

for file in model_dir.glob('*.safetensors'):
    print(file)
    tensors = {}
    with safe_open(file, 'pt') as f:
        metadata = f.metadata()
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    tensors = change_mlp(tensors)
    save_file(tensors, (output_dir / file.name).as_posix(), metadata)

In [126]:
with open(model_dir / 'model.safetensors.index.json') as f:
    weight_map = json.load(f)

In [128]:
weight_keys = list(weight_map['weight_map'])
for k in weight_keys:
    if any([x in k for x in ['w1', 'v1', 'w2']]):
        prefix,dtype = k.rsplit('.', 1)
        value = weight_map['weight_map'].pop(k)
        for i in range(TARGET_NUM_EXPERTS):
            weight_map['weight_map'][f'{prefix}.{i}.{dtype}.weight'] = value

In [133]:
weight_map["metadata"]['total_size'] = int(weight_map["metadata"]['total_size'] / (NUM_EXPERTS / TARGET_NUM_EXPERTS))
weight_map["metadata"]['total_size']

131596523520

In [134]:
sorted_map = sorted(list(weight_map['weight_map'].items()))
weight_map['weight_map'] = dict(sorted_map)

with open(output_dir / 'model.safetensors.index.json', 'w') as f:
    json.dump(weight_map, f, indent=4)

# Model_2

In [2]:
model_id = "dbrx-instruct-8X"

In [3]:
from transformers import AutoTokenizer

# fix 'eos_token': '<|im_end|>'
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


TiktokenTokenizerWrapper(name_or_path='dbrx-instruct-8X', vocab_size=100277, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|im_end|>', 'unk_token': '<|im_end|>', 'pad_token': '<|pad|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	100257: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100277: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100278: AddedToken("<|pad|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100279: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [4]:
import torch
import transformers
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnb
print("torch", torch.__version__)
print("transformers", transformers.__version__)

max_memory = {0:"24GiB", 1: "24GiB"}

qConfig = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
    device_map="auto",
    max_memory = max_memory,
    quantization_config = qConfig
)
model

torch 2.1.2+cu121
transformers 4.39.2


Loading checkpoint shards:   0%|          | 0/61 [00:00<?, ?it/s]

DbrxForCausalLM(
  (transformer): DbrxModel(
    (wte): Embedding(100352, 6144)
    (blocks): ModuleList(
      (0-39): 40 x DbrxBlock(
        (norm_attn_norm): DbrxNormAttentionNorm(
          (norm_1): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
          (attn): DbrxFlashAttention2(
            (Wqkv): Linear4bit(in_features=6144, out_features=8192, bias=False)
            (out_proj): Linear4bit(in_features=6144, out_features=6144, bias=False)
            (rotary_emb): DbrxRotaryEmbedding()
          )
          (norm_2): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
        )
        (ffn): DbrxFFN(
          (router): DbrxRouter(
            (layer): Linear4bit(in_features=6144, out_features=8, bias=False)
          )
          (experts): DbrxExperts(
            (mlp): ModuleList(
              (0-7): 8 x DbrxMLP(
                (w1): Linear4bit(in_features=6144, out_features=10752, bias=False)
                (v1): Linear4bit(in_features=6144, out_features=

# MMLU

In [5]:
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

# in the form to fit the prompt headline
subcategories_en2ru = {
    "abstract_algebra": "абстрактной_алгебре",
    "anatomy": "анатомии",
    "astronomy": "астрономии",
    "business_ethics": "деловой_этике",
    "clinical_knowledge": "медицинским_знаниям",
    "college_biology": "биологии_в_вузе",
    "college_chemistry": "химии_в_вузе",
    "college_computer_science": "компьютерным_наукам_в_вузе",
    "college_mathematics": "математике_в_вузе",
    "college_medicine": "медицине_в_вузе",
    "college_physics": "физике_в_вузе",
    "computer_security": "компьютерной_безопасности",
    "conceptual_physics": "теоретической_физике",
    "econometrics": "эконометрике",
    "electrical_engineering": "электротехнике",
    "elementary_mathematics": "элементарной_математике",
    "formal_logic": "формальной_логике",
    "global_facts": "фактам_о_мире",
    "high_school_biology": "биологии_в_старшей_школе",
    "high_school_chemistry": "химии_в_старшей_школе",
    "high_school_computer_science": "информатике_в_старшей_школе",
    "high_school_european_history": "истории_Европы_в_старшей_школе",
    "high_school_geography": "географии_в_старшей_школе",
    "high_school_government_and_politics": "государству_и_политике_в_старшей_школе",
    "high_school_macroeconomics": "макроэкономике_в_старшей_школе",
    "high_school_mathematics": "математике_в_старшей_школе",
    "high_school_microeconomics": "микроэкономике_в_старшей_школе",
    "high_school_physics": "физике_в_старшей_школе",
    "high_school_psychology": "психологии_в_старшей_школе",
    "high_school_statistics": "статистике_в_старшей_школе",
    "high_school_us_history": "истории_США_в_старшей_школе",
    "high_school_world_history": "всемирной_истории_в_старшей_школе",
    "human_aging": "старению_человека",
    "human_sexuality": "человеческой_сексуальности",
    "international_law": "международному_праву",
    "jurisprudence": "юриспруденции",
    "logical_fallacies": "логическим_ошибкам",
    "machine_learning": "машинному_обучению",
    "management": "менеджменту",
    "marketing": "маркетингу",
    "medical_genetics": "медицинской_генетике",
    "miscellaneous": "разным_темам",
    "moral_disputes": "нравственным_спорам",
    "moral_scenarios": "нравственным_сценариям",
    "nutrition": "правильному_питанию",
    "philosophy": "философии",
    "prehistory": "доисторической_эпохе",
    "professional_accounting": "профессиональному_бухгалтерскому_учету",
    "professional_law": "профессиональному_праву",
    "professional_medicine": "профессиональной_медицине",
    "professional_psychology": "профессиональной_психологии",
    "public_relations": "связям_с_общественностью",
    "security_studies": "исследованиям_в_области_безопасности",
    "sociology": "социологии",
    "us_foreign_policy": "внешней_политике_США",
    "virology": "вирусологии",
    "world_religions": "мировым_религиям",
}

In [6]:
import abc
import typing as tp

class Conversation(abc.ABC):
    """
    Inspired by https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
    """
    def __init__(self, system_prompt: str = "", roles: tp.Tuple[str, str] = ("user", "assistant")):
        self.system_prompt = system_prompt
        self.roles = roles
        self.messages: tp.List[tp.List[str, str]] = []

    def get_prompt(self) -> str:
        pass

    def update_last_message(self, text: str) -> None:
        self.messages[-1] = (self.messages[-1][0], text)

    def append_message(self, role: str, text: str) -> None:
        self.messages.append({"role":role, "content":text})

class EmptyConversation(Conversation):

    #"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
    def __init__(self):
        super().__init__(
            system_prompt="",
            roles=("user", "assistant"),
        )

    def get_prompt(self) -> str:
        prompt = self.system_prompt
        for m in self.messages:
            prompt += str(m)
        return prompt

conversation_classes = {
    "empy_prompt_conv": EmptyConversation,
}

In [7]:
import argparse
import json
import logging
import os
import pathlib
import typing as tp

import pandas as pd
import datasets
import peft
import transformers
import torch
from tqdm.auto import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


LANGUAGE_CONFIG: tp.Dict[str, tp.Dict[str, str]] = {
    "en": {
        "headline_prefix": "The following are multiple choice questions (with answers) about",
        "answer_prefix": "Answer:",
    },
    "ru": {
        "headline_prefix": "Ниже приведены вопросы с множественным выбором (с ответами) по",
        "answer_prefix": "Ответ:",
    },
}

In [8]:
def get_df_in_hendrycks_format(subject: str, split: str, lang: str) -> pd.DataFrame:
    dataset = datasets.load_dataset("NLPCoreTeam/mmlu_ru", name=subject, split=split)
    wanted_cols = {
        "en": ["question_en", "choices_en", "answer"],
        "ru": ["question_ru", "choices_ru", "answer"],
    }[lang]
    df = dataset.to_pandas()[wanted_cols]
    int2str = dataset.features["answer"].int2str
    df[df.columns[2]] = df[df.columns[2]].apply(lambda x: int2str(x))
    df = pd.concat([
        df[[df.columns[0]]],
        pd.DataFrame(df[df.columns[1]].tolist()),
        df[[df.columns[2]]],
    ], axis=1)
    df.columns = range(len(df.columns))
    return df

In [9]:
def format_subject(subject: str) -> str:
    l = subject.split("_")
    s = ""
    for entry in l:
        s += " " + entry
    return s.strip()

def get_pretty_subject(subject: str, lang: str) -> str:
    return format_subject({
        "en": subject,
        "ru": subcategories_en2ru[subject],  # predefined map
    }[lang])

def get_prompt_from_dataframes(dev_df: pd.DataFrame, test_df: pd.DataFrame,
                               k: int, test_iloc_idx: int, lang: str, subject: str, conversation_type: str):
    assert 0 <= k <= 5
    headline_prefix = LANGUAGE_CONFIG[lang]["headline_prefix"]
    headline_postfix = get_pretty_subject(subject=subject, lang=lang)
    headline = f"{headline_prefix} {headline_postfix}.\n\n"

    answer_prefix = LANGUAGE_CONFIG[lang]["answer_prefix"]

    conv = conversation_classes[conversation_type]()

    is_already_taken_headline = False
    for row_idx, row in dev_df.head(k).iterrows():
        q = row[0]
        options = row[1:5].tolist()
        lettered_options = [f"({x}) {y}" for x, y in zip(["A", "B", "C", "D"], options)]
        q_with_lettered_options = "\n".join([q, " ".join(lettered_options)])
        if row_idx == 0:
            q_with_lettered_options = headline + q_with_lettered_options
            is_already_taken_headline = True
        conv.append_message(conv.roles[0], q_with_lettered_options)
        a = row[5]
        
        # if is not instruct, needed to be manually separated for mmlu examples
        conv.append_message(conv.roles[1], f"{answer_prefix}{a}")

    row = test_df.iloc[test_iloc_idx]
    q = row[0]
    options = row[1:5].tolist()
    lettered_options = [f"({x}) {y}" for x, y in zip(["A", "B", "C", "D"], options)]
    q_with_lettered_options = "\n".join([q, " ".join(lettered_options)])
    if not is_already_taken_headline:
        q_with_lettered_options = headline + q_with_lettered_options
        is_already_taken_headline = True
    conv.append_message(conv.roles[0], q_with_lettered_options)
    a = row[5]
    conv.append_message(conv.roles[1], answer_prefix)
    # prompt = f"{conv.get_prompt()}{answer_prefix}"
    return conv.messages

def calculate_token_interest_probs(
    input_ids,
    tokenizer: transformers.PreTrainedTokenizerBase,
    model: tp.Union[transformers.PreTrainedModel, peft.peft_model.PeftModelForCausalLM],
) -> tp.Dict[str, float]:    
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    logits = outputs.logits  # shape (batch_size, sequence_length, vocab_size)
    next_token_logits = logits[:, -1, :]  # shape (batch_size, vocab_size)

    next_token_logits = next_token_logits.flatten()
    assert next_token_logits.shape == torch.Size((model.config.vocab_size, ))

    next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1).cpu()  # all probs over vocab
    # assert torch.isclose(next_token_probs.sum(), torch.tensor(1.0).to(next_token_probs.dtype), atol=1e-03)  # dtype for half/nothalf, -03 for float16
    
    tokens_of_interest = [
        tokenizer("A", add_special_tokens=False).input_ids[-1],
        tokenizer("B", add_special_tokens=False).input_ids[-1],
        tokenizer("C", add_special_tokens=False).input_ids[-1],
        tokenizer("D", add_special_tokens=False).input_ids[-1],
    ]

    probs = next_token_probs[tokens_of_interest].tolist()
    res = dict(zip(["A", "B", "C", "D"], probs))
    return res

def append_to_jsonl(data: list, filename: str) -> None:
    with open(filename, "a") as f:
        f.write(json.dumps(data) + "\n")

def evaluate_subject(
    subject: str,
    lang: str,
    k_shot: int,
    jsonl_filepath: str,
    maxlen: int,
    convtype: str,
    tokenizer: transformers.PreTrainedTokenizerBase,
    model: tp.Union[transformers.PreTrainedModel, peft.peft_model.PeftModelForCausalLM],
) -> None:

    dev_df = get_df_in_hendrycks_format(subject=subject, split="dev", lang=lang)
    test_df = get_df_in_hendrycks_format(subject=subject, split="test", lang=lang)

    for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc=subject):

        current_k_shot = k_shot
        skip_too_lengthy = False
        while True:
            if current_k_shot < 0:
                logger.info("Skip too lengthy.")
                skip_too_lengthy = True
                break
            input_messages = get_prompt_from_dataframes(
                dev_df=dev_df,
                test_df=test_df,
                k=current_k_shot,
                test_iloc_idx=idx,
                lang=lang,
                subject=subject,
                conversation_type=convtype,
            )
            input_prompt = tokenizer.apply_chat_template(input_messages, tokenize=False, add_special_tokens=False)[:-len(tokenizer.eos_token)]
            input_ids = tokenizer.encode(input_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

            if input_ids.shape[-1] > maxlen and current_k_shot >= 0:
                logger.info("Takes smaller current_k_shot since maxlen.")
                current_k_shot -= 1
            elif current_k_shot < 0:
                logger.info("Skip too lengthy.")
                skip_too_lengthy = True
            else:
                break
        if skip_too_lengthy:
            continue

        label = row[5]

        preds = calculate_token_interest_probs(
            input_ids=input_ids,
            tokenizer=tokenizer,
            model=model,
        )

        append_to_jsonl(data=[input_prompt, label, preds], filename=jsonl_filepath)

In [10]:
%%time
lang = "en"
subject = "abstract_algebra"
convtype = "empy_prompt_conv"
current_k_shot = 5
idx = 0
dev_df = get_df_in_hendrycks_format(subject=subject, split="dev", lang=lang)
test_df = get_df_in_hendrycks_format(subject=subject, split="test", lang=lang)

input_messages = get_prompt_from_dataframes(
    dev_df=dev_df,
    test_df=test_df,
    k=current_k_shot,
    test_iloc_idx=idx,
    lang=lang,
    subject=subject,
    conversation_type=convtype,
)
input_prompt = tokenizer.apply_chat_template(input_messages, tokenize=False, add_special_tokens=False)[:-len(tokenizer.eos_token)]
print(input_prompt)
input_ids = tokenizer.encode(input_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=10, eos_token_id=tokenizer.eos_token_id)
ouput_str = tokenizer.decode(output_ids[0]).strip()
ouput_str

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:100277 for open-end generation.


<|im_start|>system
You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.
YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.
You assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).
(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)
This is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.
YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS

"<|im_start|>system\nYou are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF

In [11]:
output_dir = "mmlu_en_dbrx-instruct_8X_bnb4q"
lang = "en"
k_shot = 5
maxlen = 8192
convtype = "empy_prompt_conv"

subjects = list(subcategories.keys())
for each_subject in subjects:
    jsonl_filepath = str(pathlib.Path(output_dir) / f"{each_subject}.jsonl")
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
    logger.info(f"Filepath JSONL: {jsonl_filepath}")
    if pathlib.Path(jsonl_filepath).exists():
        logger.info(f"File already exists! Please manually verify that it wasn't partially interrupted.")
        continue
    evaluate_subject(
            subject=each_subject,
            lang=lang,
            k_shot=k_shot,
            jsonl_filepath=jsonl_filepath,
            maxlen=maxlen, convtype=convtype,
            tokenizer=tokenizer,
            model=model,
    )

In [12]:
import numpy as np

category_to_main_category = {value: key for key, sublist in categories.items() for value in sublist}
subcategories2categories = {key: category_to_main_category[value[0]] for key, value in subcategories.items()}

def calculate_accuracy_from_directory(dirpath: str) -> tp.Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    assert pathlib.Path(dirpath).exists()
    filepaths = [str(x) for x in pathlib.Path(dirpath).glob('*.jsonl')]
    # assert len(filepaths) == 57
    res = {}
    for each_filepath in filepaths:
        df = pd.read_json(each_filepath, lines=True)
        df.columns = ['prompt', 'label', 'preds']
        cors = []
        for idx, row in df.iterrows():
            preds = row['preds']
            best_idx = np.argmax(list(preds.values()))
            y_pred = list(preds.keys())[best_idx]
            y_true = row['label']
            y_pred = y_pred.strip()
            y_true = y_true.strip()
            cors.append(y_true == y_pred)
        acc = np.mean(cors)
        res[pathlib.Path(each_filepath).stem] = acc * 100
    
    df = pd.DataFrame({pathlib.Path(dirpath).stem: res}).reset_index()
    df = df.rename(columns={'index': 'subcategory'})
    subcategories_df = df.copy()
    
    df = subcategories_df.copy()
    df['subcategory'] = df['subcategory'].map(subcategories2categories)
    df = df.rename(columns={'subcategory': 'category'})
    df = df.groupby('category').mean().reset_index()
    categories_df = df.copy()
    
    total_df = pd.DataFrame({pathlib.Path(dirpath).stem: [categories_df[pathlib.Path(dirpath).stem].mean()]})
    
    # assert subcategories_df.shape == (57, 2)
    # assert categories_df.shape == (4, 2)
    # assert total_df.shape == (1, 1)
    return (subcategories_df, categories_df, total_df)

subcategories_df, categories_df, total_df = calculate_accuracy_from_directory(dirpath=output_dir)
print(total_df.shape)
total_df

(1, 1)


Unnamed: 0,mmlu_en_dbrx-instruct_8X_bnb4q
0,24.333333


In [13]:
categories_df

Unnamed: 0,category,mmlu_en_dbrx-instruct_8X_bnb4q
0,STEM,22.0
1,"other (business, health, misc.)",26.666667


In [14]:
subcategories_df

Unnamed: 0,subcategory,mmlu_en_dbrx-instruct_8X_bnb4q
0,abstract_algebra,22.0
1,anatomy,26.666667
