# Experts Hook for Mixtral 8x7B

This experiment is storing all the tokens of a datasets into a database where I can view which tokens go to which experts and plot paths. Maybe given multiple datasets, I could see how they differ.

In [47]:
import torch as t
import pandas as pd
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import altair as alt

import os
from collections import defaultdict
from dotenv import load_dotenv
from huggingface_hub import login

In [48]:
load_dotenv()
login(os.getenv("HF_TOKEN"))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/ubuntu/.cache/huggingface/token
Login successful


## Dataset 

We will use the `Salesforce/wikitext` dataset. Its very simple, unassuming and easily understand.  

In [49]:
dataset = load_dataset("Salesforce/wikitext", 'wikitext-2-raw-v1', split="test")
data = "\n".join(dataset["text"][:8])

Using the latest cached version of the dataset since Salesforce/wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-2-raw-v1' at /home/ubuntu/.cache/huggingface/datasets/Salesforce___wikitext/wikitext-2-raw-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Thu Jun 20 17:23:04 2024).


## Model Setup

We would need to add a forward hook into the layers to extract which experts a token gets send to which experts. To store the information, we will use 2 simple databases right now. 
- expert_to_token: 32 keys representing the Layers, each contain a dictionary of where keys are the exprts and values are sets of tokens. Tracks which tokens are sent to the expert.
- token_to_expert: keys are tokens, values are 32-length lists of 2-tuples. Tracks which experts processed a token through the decoder layer. 

In [50]:
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=t.float16,
    bnb_4bit_use_double_quant=True
)

model = AutoModel.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [51]:
NUM_LAYERS = len(model.layers)
token_to_expert = defaultdict(list)
expert_to_token = {f"Layer {l}": defaultdict(set) for l in range(NUM_LAYERS)}

In [52]:
print(model)

MixtralModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x MixtralDecoderLayer(
      (self_attn): MixtralSdpaAttention(
        (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): MixtralRotaryEmbedding()
      )
      (block_sparse_moe): MixtralSparseMoeBlock(
        (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
        (experts): ModuleList(
          (0-7): 8 x MixtralBlockSparseTop2MLP(
            (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
            (w3): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (act_fn): SiLU()
        

Right now, we just define a hook that prints out which experts were chose for a specific output of the attention heads. We need to understand which tokens are directed where. So I will try and figure out which ones go where.

In [53]:
def get_print_token(prompt: str, layer_num:int):
    ids = tokenizer.encode(prompt, add_special_tokens=False) 
    tokens = tokenizer.convert_ids_to_tokens(ids)
    pos_to_token = {i: token for i, token in enumerate(tokens)}

    def print_tokens(module, input, output):
        _, topk_index = t.topk(output, 2, dim=1) 
        # topk_list of of the shape [S_l, 2] where S_l is the length of the sequence
        topk_list = topk_index.tolist()

        for i, topk in enumerate(topk_list):
            print(i)
            token = pos_to_token[i]
            topk = tuple(topk)
            token_to_expert[token].append(topk)
            expert_1, expert_2 = topk
            expert_to_token[f"Layer {layer_num}"][expert_1].add(token)
            expert_to_token[f"Layer {layer_num}"][expert_2].add(token)
    return print_tokens



In [54]:
hooks = []
for layer_num, layers in enumerate(model.layers):
    hook = layers.block_sparse_moe.gate.register_forward_hook(get_print_token(data, layer_num))
    hooks.append(hook)

In [55]:
tokenized = tokenizer(data, return_tensors="pt", max_length=128, truncation=True, add_special_tokens=False)

with t.no_grad():
  outputs = model(**tokenized.to(model.device), output_hidden_states=True, return_dict=True)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

In [56]:
expert_to_token

{'Layer 0': defaultdict(set,
             {6: {'-',
               '0',
               '<0x0A>',
               'dy',
               'ed',
               's',
               '▁',
               '▁=',
               '▁an',
               '▁episode',
               '▁followed',
               '▁performed',
               '▁play',
               '▁was'},
              2: {'2',
               'dy',
               'ed',
               'ons',
               'oul',
               's',
               '▁',
               '▁Bill',
               '▁Craig',
               '▁English',
               '▁Robert',
               '▁Royal',
               '▁Simon',
               '▁Ted',
               '▁by',
               '▁episode',
               '▁guest',
               '▁had',
               '▁in',
               '▁play',
               '▁role',
               '▁television',
               '▁the'},
              5: {'-',
               '2',
               '<0x0A>',
               '@',
             

In [57]:
token_to_expert

defaultdict(list,
            {'▁': [(6, 2),
              (2, 5),
              (5, 4),
              (5, 4),
              (5, 4),
              (5, 6),
              (3, 0),
              (0, 2),
              (0, 7),
              (5, 3),
              (0, 7),
              (0, 3),
              (7, 1),
              (4, 0),
              (0, 3),
              (0, 2),
              (0, 2),
              (2, 0),
              (6, 4),
              (4, 5),
              (6, 5),
              (5, 6),
              (6, 5),
              (7, 3),
              (3, 0),
              (3, 2),
              (2, 1),
              (2, 3),
              (2, 4),
              (2, 0),
              (3, 5),
              (1, 6),
              (3, 2),
              (0, 4),
              (4, 2),
              (0, 4),
              (3, 2),
              (0, 7),
              (5, 7),
              (5, 7),
              (5, 6),
              (4, 7),
              (3, 5),
              (5, 2),
         