## Expert Usage for Mixtral 8x7B

After experimenting with the expert hooks in Mixtral 8x7B, I want to see how many tokens from different datasets are routed to each expert.

In [11]:
import torch as t
import pandas as pd
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import altair as alt
from tqdm import tqdm

import os
from collections import defaultdict
from dotenv import load_dotenv
from huggingface_hub import login

In [2]:
load_dotenv()
# Read the token from the environment / .env file instead of hardcoding it,
# and fail fast with a clear message if it is missing.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise RuntimeError("HF_TOKEN is not set; add it to your environment or .env file")
# Authenticating this session is enough to download gated models, so the token
# is not additionally written into the (plaintext) git credential store.
login(token=hf_token)

Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /home/ubuntu/.cache/huggingface/token
Login successful


## Datasets

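I will be using 3 different datasets and recording the fraction of tokens handled by each expert in the first and last layers of the model, to see how expert routing differs between datasets. Because every token is routed to its top-2 experts, the per-expert fractions within a layer add up to roughly 2 rather than 1. The datasets are: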
- `stanfordnlp/imdb` - this is a classification task for movie reviews.
- `databricks/databricks-dolly-15k` - plain 'ol question answering
- `bigcode/bigcodebench` - code generation from prompt

In [3]:
# (dataset id, split) for each corpus: movie reviews, instruction QA, code prompts.
imdb_dataset, qa_dataset, code_dataset = (
    load_dataset(path, split=split)
    for path, split in (
        ("stanfordnlp/imdb", "test"),
        ("databricks/databricks-dolly-15k", "train"),
        ("bigcode/bigcodebench", "v0.1.0_hf"),
    )
)

## Model Setup

In [4]:
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# 4-bit weights, fp16 compute, with nested (double) quantization enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=t.float16,
)

# AutoModel yields the bare MixtralModel (no LM head), which is all we need to
# run the routers; device_map="auto" lets accelerate place the shards.
model = AutoModel.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token (only relevant if prompts are ever batched).
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 19/19 [02:18<00:00,  7.29s/it]


In [5]:
# Model geometry: number of decoder layers, and experts per MoE block (read from layer 0).
NUM_LAYERS, NUM_EXPERTS = (
    len(model.layers),
    len(model.layers[0].block_sparse_moe.experts),
)
print(model)

MixtralModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x MixtralDecoderLayer(
      (self_attn): MixtralSdpaAttention(
        (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): MixtralRotaryEmbedding()
      )
      (block_sparse_moe): MixtralSparseMoeBlock(
        (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
        (experts): ModuleList(
          (0-7): 8 x MixtralBlockSparseTop2MLP(
            (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
            (w3): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (act_fn): SiLU()
        

In [6]:
# Running per-expert token counters, filled in by the router hooks.
first_layer_usage, last_layer_usage = defaultdict(int), defaultdict(int)

# Normalized usages for every dataset, keyed by dataset name.
all_first_layer_usages, all_last_layer_usages = {}, {}

In [7]:
def _record_top_k(router_logits: t.Tensor, usage: dict, k: int = 2) -> None:
    """Add one count to ``usage`` for each of the top-``k`` experts picked per token.

    The expert axis is taken to be the LAST dimension of ``router_logits``, so a
    flattened ``[tokens, num_experts]`` output and an unflattened
    ``[batch, seq, num_experts]`` output are both handled.

    Args:
        router_logits: router gate output, experts on the last axis.
        usage: mutable mapping ``expert index -> count``; updated in place.
        k: how many experts each token is routed to.
    """
    _, topk_index = t.topk(router_logits, k, dim=-1)
    # Row-major flatten: token 0's experts, then token 1's, ... (same order as a
    # per-token loop over (expert_1, expert_2)).
    for expert in topk_index.reshape(-1).tolist():
        usage[expert] += 1


def first_layer_update(module, input, output):
    """Forward hook for the first layer's router gate; tallies into ``first_layer_usage``."""
    # The global is resolved on every call, so rebinding ``first_layer_usage`` to a
    # fresh defaultdict in a later cell resets the counts as intended.
    _record_top_k(output, first_layer_usage)


def last_layer_update(module, input, output):
    """Forward hook for the last layer's router gate; tallies into ``last_layer_usage``."""
    _record_top_k(output, last_layer_usage)


I will only register hooks on the router gates of the first and last layers, so we can visualize expert usage at those two depths afterwards.

In [8]:
# Tap only the router gates of the first and last decoder layers; keep the
# handles so the hooks can be removed once the experiments are done.
hooks = [
    model.layers[0].block_sparse_moe.gate.register_forward_hook(first_layer_update),
    model.layers[-1].block_sparse_moe.gate.register_forward_hook(last_layer_update),
]

## Running the Experiment

In [51]:
def normalize_dataset(usage: dict, total_num_tokens: int) -> dict:
    """Turn raw per-expert token counts into fractions of ``total_num_tokens``.

    Args:
        usage: mapping ``expert index -> token count`` (as filled by the router hooks).
        total_num_tokens: denominator used for every expert.

    Returns:
        A shallow copy of ``usage`` (same mapping type) whose values are
        ``count / total_num_tokens``; ``usage`` itself is not modified.

    Raises:
        ZeroDivisionError: if ``total_num_tokens`` is 0 and ``usage`` is non-empty.
    """
    normalized = usage.copy()
    normalized.update(
        (expert, count / total_num_tokens) for expert, count in usage.items()
    )
    return normalized

In [52]:
first_layer_usage = defaultdict(int)
last_layer_usage = defaultdict(int)

dataset = qa_dataset["instruction"]
# Count exactly the token positions fed to the model (including any special
# tokens the tokenizer adds per call). Re-tokenizing one "".join(...) string
# merges words across samples and misses the per-sample special tokens the hooks see.
num_tokens = 0
with t.no_grad():  # inference only: no autograd graph needed
    for instruction in tqdm(dataset):
        tok_instruction = tokenizer(instruction, return_tensors="pt")
        num_tokens += tok_instruction["input_ids"].numel()
        model(**tok_instruction)

all_first_layer_usages["qa"] = normalize_dataset(first_layer_usage, num_tokens)
all_last_layer_usages["qa"] = normalize_dataset(last_layer_usage, num_tokens)

100%|██████████| 50/50 [00:13<00:00,  3.62it/s]


In [55]:
first_layer_usage = defaultdict(int)
last_layer_usage = defaultdict(int)

dataset = code_dataset["instruct_prompt"]
# Count exactly the token positions fed to the model (including any special
# tokens the tokenizer adds per call). Re-tokenizing one "".join(...) string
# merges words across samples and misses the per-sample special tokens the hooks see.
num_tokens = 0
with t.no_grad():  # inference only: no autograd graph needed
    for instruction in tqdm(dataset):
        tok_instruction = tokenizer(instruction, return_tensors="pt")
        num_tokens += tok_instruction["input_ids"].numel()
        model(**tok_instruction)

all_first_layer_usages["code"] = normalize_dataset(first_layer_usage, num_tokens)
all_last_layer_usages["code"] = normalize_dataset(last_layer_usage, num_tokens)

100%|██████████| 50/50 [00:14<00:00,  3.40it/s]


In [56]:
first_layer_usage = defaultdict(int)
last_layer_usage = defaultdict(int)

dataset = imdb_dataset["text"]
preprompt = "Classify this as a negative or positive review"
# Count the prompt exactly as fed (preprompt + review + any special tokens),
# so the normalization matches what the router actually saw.
num_tokens = 0
with t.no_grad():  # inference only: no autograd graph needed
    for review in tqdm(dataset):
        # Format this loop's `review`; a name left over from another cell's loop
        # would send the same stale prompt on every iteration.
        tok_instruction = tokenizer(f"{preprompt}:{review}", return_tensors="pt")
        num_tokens += tok_instruction["input_ids"].numel()
        model(**tok_instruction)

# Use the same dataset key in both dicts so the first- and last-layer charts line up.
all_first_layer_usages["imdb"] = normalize_dataset(first_layer_usage, num_tokens)
all_last_layer_usages["imdb"] = normalize_dataset(last_layer_usage, num_tokens)

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:14<00:00,  3.35it/s]


In [None]:
# Detach both router hooks so later forward passes are no longer tallied.
for handle in hooks:
    handle.remove()

## Plotting 

Now we can plot these usage charts per dataset. The first and last layers are plotted separately. Essentially, I want to see how much each expert gets used in each dataset.

In [57]:
def get_usage_chart(usage: dict, title: str = 'Values for Each Category'):
    """Build a faceted bar chart of per-expert usage, one panel per dataset.

    Args:
        usage: nested mapping ``dataset name -> {expert index: normalized usage}``,
            e.g. ``all_first_layer_usages``.
        title: chart title. Pass something layer-specific (e.g. "Layer 0 expert
            usage") so the figure stands on its own; the default is the original title.

    Returns:
        An Altair chart with experts on the x-axis, usage on the y-axis, and one
        column facet per dataset.
    """
    # Long-format records: one row per (dataset, expert) pair. Explicit columns
    # keep the frame well-formed even when ``usage`` is empty.
    df = pd.DataFrame(
        [
            {"Category": category, "Key": key, "Value": value}
            for category, values in usage.items()
            for key, value in values.items()
        ],
        columns=["Category", "Key", "Value"],
    )

    chart = alt.Chart(df).mark_bar().encode(
        x=alt.X('Key:O', title='Expert'),
        y=alt.Y('Value:Q', title='Fraction of tokens routed'),
        color=alt.Color('Category:N', scale=alt.Scale(scheme='category10')),
        column=alt.Column('Category:N', title=None)
    ).properties(
        width=300,
        height=400,
        title=title
    )
    return chart

In [58]:
# Same chart builder for both layers, so the two figures are directly comparable.
first_layer_chart, last_layer_chart = map(
    get_usage_chart, (all_first_layer_usages, all_last_layer_usages)
)

In [59]:
# Layer-0 router: per-expert share of tokens, one facet column per dataset.
first_layer_chart.show()

In [61]:
# Last-layer router: per-expert share of tokens, one facet column per dataset.
last_layer_chart.show()