# Experts Hook for Mixtral 8x7B

This experiment stores every token of a dataset in a database so I can see which experts each token is routed to and plot each token's routing path through the layers. With multiple datasets, I could then compare how their routing differs.

In [1]:
import torch as t
import pandas as pd
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import altair as alt
from mlx_lm import load, generate

import os
from dotenv import load_dotenv
from huggingface_hub import login

## Dataset 

We will use the `Salesforce/wikitext` dataset (the `wikitext-2-raw-v1` test split). It's very simple, unassuming, and easy to understand.

In [2]:
# Test split of the raw WikiText-2 corpus, fetched from the Hugging Face Hub.
dataset = load_dataset(
    path="Salesforce/wikitext",
    name="wikitext-2-raw-v1",
    split="test",
)
# Flatten every row of the split into one newline-separated string.
data = "\n".join(dataset["text"])

## Model Setup

We need to register a forward hook on the layers to extract which experts each token is sent to.

In [4]:
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Quantize the weights to 4 bits (with double quantization) to reduce the
# memory needed to load the checkpoint; compute still runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=t.float16,
    bnb_4bit_use_double_quant=True,
)

# `quantization_config` must be passed explicitly: previously `bnb_config` was
# built but never handed to `from_pretrained`, so it had no effect and the
# model was loaded unquantized.
model = AutoModel.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)
# Reuse EOS as the padding token so batches can be padded
# (presumably the tokenizer ships without a pad token — verify).
tokenizer.pad_token = tokenizer.eos_token

model.safetensors.index.json:   0%|          | 0.00/92.7k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

model-00001-of-00019.safetensors:   0%|          | 0.00/4.89G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
def dump_tokens(module, input, output, top_k=2):
    """Forward hook that prints the top-k experts chosen for each token.

    Meant to be registered via ``module.register_forward_hook(dump_tokens)``;
    PyTorch then calls it as ``hook(module, input, output)`` after ``module``
    runs its forward pass.

    Args:
        module: The module the hook is attached to; only its class name is used.
        input: The positional inputs passed to ``module`` (unused).
        output: Tensor of per-expert scores. Assumed to be router/gate logits
            with the expert axis last, e.g. ``(num_tokens, num_experts)`` —
            TODO confirm against the module this hook is registered on.
        top_k: Number of experts to report per token (default 2).

    Returns:
        None. A forward hook that returns a non-None value replaces the
        module's output, so this hook deliberately returns nothing.
    """
    print(module.__class__.__name__)
    # Select along the last axis so the expert dimension is used for both
    # 2-D (tokens, experts) and 3-D (batch, seq, experts) score tensors;
    # a hard-coded dim=1 would pick over the sequence axis for 3-D input.
    topk_index = t.topk(output, top_k, dim=-1).indices
    print("Chosen experts: ", topk_index)


In [None]:
for layer in model.m