# Analysis of Dynamic Behaviours

This notebook includes experiments listed below:
- Outputs of experts
- Norms of expert outputs and gate scores
- Intermediate states of experts
- Chosen experts

Each model has its own code block for every experiment. The overall logic of these blocks is the same across models; the minor differences stem from the specific settings of each model.

Usually, the figures are plotted in two ways: 'auto_colorbar' and 'full_colorbar'. With the former, matplotlib automatically decides the range of the color bar for each layer. With the latter, we manually set the range to the global minimum/maximum over all the layers.

In [None]:
import csv
import math
import ml_dtypes
import os
import pickle

import functools
import matplotlib as mlp
import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.functional import normalize
from transformers import AutoTokenizer

from mixtral_base.modeling_moe_mistral import MixtralForCausalLM, MistralDecoderLayer, FeedForward
from mixtral_instruct.modeling_mixtral_instruct import MixtralInstructForCausalLM
from mistral.modeling_mistral import MistralModel, MistralMLP
from deepseekmoe.modeling_deepseek import DeepseekForCausalLM, DeepseekDecoderLayer, DeepseekMLP, MoEGate
from grok.modeling_grok1 import Grok1ModelForCausalLM, DecoderLayer, MoeMLP

# Root directory under which every experiment writes its figures and pickled data.
WORK_DIR = "./outputs"

Run one or more cells below to load the models you need.

In [None]:
# Mixtral base MoE model in bf16; device_map="auto" lets Hugging Face place
# the weights across the available devices.
mixtral_model = MixtralForCausalLM.from_pretrained(
    "./ckpt/mixtral",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
mixtral_tok = AutoTokenizer.from_pretrained("./ckpt/mixtral")
# Inference only.
mixtral_model.eval()

In [None]:
# Instruction-tuned Mixtral checkpoint, loaded the same way as the base model.
mixtral_instruct_model = MixtralInstructForCausalLM.from_pretrained(
    "./ckpt/mixtral-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
mixtral_instruct_tok = AutoTokenizer.from_pretrained("./ckpt/mixtral-instruct")
# Inference only.
mixtral_instruct_model.eval()

In [None]:
# Dense Mistral backbone (bare MistralModel, no LM head). Its MLP outputs are
# recorded as the 'F' reference column in the Mixtral expert comparison.
mistral_model = MistralModel.from_pretrained(
    "./ckpt/mistral",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
mistral_tok = AutoTokenizer.from_pretrained("./ckpt/mistral")
# Inference only.
mistral_model.eval()

In [None]:
# DeepSeekMoE checkpoint in bf16, placed automatically across devices.
deepseek_model = DeepseekForCausalLM.from_pretrained(
    "./ckpt/deepseekmoe",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
deepseek_tok = AutoTokenizer.from_pretrained("./ckpt/deepseekmoe")
# Inference only.
deepseek_model.eval()

In [None]:
# Grok-1 checkpoint in bf16, placed automatically across devices.
grok_model = Grok1ModelForCausalLM.from_pretrained(
    "./ckpt/grok",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
grok_tok = AutoTokenizer.from_pretrained("./ckpt/grok")
# Inference only.
grok_model.eval()

## Outputs of Experts

We use both the short and the long sequence in this experiment. For the short sequence we plot the similarity heat map of each token, while for the long sequence only the averaged heat map is plotted. To use the long sequence as the input, set `use_short_input=False`.

### Mixtral and Mistral

In [None]:
# Input.
use_short_input = True # Set False to use the long sequence.
sentence_lst = []
if use_short_input:
    raw_input = "As an open source alternative to"
    sentence_lst.append(raw_input)
else:
    # Collect every non-empty line of the WikiText-103 test split, skipping
    # lines that start with '=' (presumably the '= Section =' headings).
    # The encoding is explicit because the corpus contains non-ASCII text and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open('./wikitext103_test.csv', encoding='utf-8') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter='\n')
        for row in csv_reader:
            if not row:
                # csv.reader yields [] for blank lines; row[0] would raise.
                continue
            for sent in row[0].split('\n'):
                sent = sent.strip()
                if sent and not sent.startswith('='):
                    sentence_lst.append(sent)

# Cosine similarity between two 1-D vectors (used for the short input).
cos = torch.nn.CosineSimilarity(dim=0)
matrices = [('w3', 'up_proj'), ('w1', 'gate_proj'), ('w2', 'down_proj')]
num_layers = mixtral_model.config.num_hidden_layers
num_experts = mixtral_model.config.num_experts

# One tick per Mixtral expert, plus 'F' for the dense Mistral FFN, which
# occupies the last row/column of every similarity matrix.
tick_labels = [str(i) for i in range(num_experts)] + ['F']
save_dir = os.path.join(WORK_DIR, 'mixtral/mixtral_experts_outsim')
if not use_short_input:
    save_dir += '_average'
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
os.makedirs(os.path.join(plot_dir, 'auto_colorbar'), exist_ok=True)
os.makedirs(os.path.join(plot_dir, 'full_colorbar'), exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


def plot_one_layer_short_seq(arr_lst, all_gate_indices, layer_idx, num_tokens, range_type, global_vmin=None, global_vmax=None):
    """Plot one expert-similarity heat map per token for a Mixtral layer and save it.

    Args:
        arr_lst: per-token (num_experts+1) x (num_experts+1) similarity matrices.
        all_gate_indices: per-layer lists of recorded top-2 gate indices; row
            ``i`` of ``all_gate_indices[layer_idx][0]`` titles token ``i``.
        layer_idx: layer being plotted (used in the y-label and file name).
        num_tokens: number of subplot columns (one per token).
        range_type: 'auto_colorbar' (one colour range shared by this layer's
            tokens, derived from the data) or 'full_colorbar' (fixed range
            ``global_vmin``..``global_vmax``).
        global_vmin, global_vmax: colour range used for 'full_colorbar'.

    The figure is written to ``plot_dir/<range_type>/layer_<layer_idx>.png``.

    Raises:
        ValueError: if ``range_type`` is not one of the two supported values.
    """
    if range_type not in ('auto_colorbar', 'full_colorbar'):
        raise ValueError(f'Unknown range_type: {range_type!r}')
    # squeeze=False keeps `axs` indexable even when there is only one token.
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(8.0, 2.5), squeeze=False)
    axs = axs[0]
    imlst = []
    for i, sim_arr in enumerate(arr_lst):
        if range_type == 'auto_colorbar':
            im = axs[i].imshow(sim_arr)
            imlst.append(im)
        else:
            im = axs[i].imshow(sim_arr, vmin=global_vmin, vmax=global_vmax)
        exp1, exp2 = all_gate_indices[layer_idx][0][i, 0], all_gate_indices[layer_idx][0][i, 1]
        axs[i].set_title(f'exp {exp1},{exp2}', fontsize=15)
        axs[i].set_xticks(np.arange(num_experts+1), labels=tick_labels, fontsize=15)
        axs[i].set_yticks(np.arange(num_experts+1), labels=tick_labels, fontsize=15)
        if i == 0:
            axs[i].set_ylabel(f'Layer {layer_idx}', labelpad=12., fontsize=20)
    if range_type == 'auto_colorbar':
        # Give every subplot of this layer the same norm so the single
        # colorbar below is valid for all of them.
        local_vmin = min(img.get_array().min() for img in imlst)
        local_vmax = max(img.get_array().max() for img in imlst)
        norm = colors.Normalize(vmin=local_vmin, vmax=local_vmax)
        for img in imlst:
            img.set_norm(norm)
    cbar = fig.colorbar(im, ax=axs, shrink=0.88)
    cbar.ax.tick_params(labelsize=15)
    fig.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close(fig)


def plot_one_layer_long_seq(avg_arr, layer_idx, range_type, global_vmin=None, global_vmax=None):
    """Plot the token-averaged expert-similarity heat map of one Mixtral layer and save it.

    Args:
        avg_arr: (num_experts+1) x (num_experts+1) averaged similarity matrix.
        layer_idx: layer being plotted (used in the y-label and file name).
        range_type: 'auto_colorbar' (matplotlib picks the range) or
            'full_colorbar' (fixed range ``global_vmin``..``global_vmax``).
        global_vmin, global_vmax: colour range used for 'full_colorbar'.

    Raises:
        ValueError: if ``range_type`` is not one of the two supported values.
    """
    if range_type not in ('auto_colorbar', 'full_colorbar'):
        raise ValueError(f'Unknown range_type: {range_type!r}')
    fig, ax = plt.subplots(ncols=1, layout='constrained', figsize=(3.5, 2.0))
    if range_type == 'auto_colorbar':
        im = ax.imshow(avg_arr)
    else:
        im = ax.imshow(avg_arr, vmin=global_vmin, vmax=global_vmax)
    ax.set_xticks(np.arange(num_experts+1), labels=tick_labels, fontsize=14.5)
    ax.set_yticks(np.arange(num_experts+1), labels=tick_labels, fontsize=14.5)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=12., fontsize=20)
    cbar = fig.colorbar(im, ax=ax, shrink=1.)
    cbar.ax.tick_params(labelsize=14.5)
    fig.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close(fig)


In [None]:
# Forward pass.

def get_angular_similarity(v1, v2):
    """Return the angular similarity ``1 - arccos(cos_sim) / pi`` along dim 2.

    ``v1`` and ``v2`` are broadcast against each other, so shapes (N, 1, D)
    and (1, N, D) give an (N, N) matrix. Results lie in [0, 1]: 1 for
    parallel vectors, 0.5 for orthogonal ones and 0 for opposite ones.
    """
    batch_cos = torch.nn.CosineSimilarity(dim=2)
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_sim = batch_cos(v1, v2).clamp(-1.0, 1.0)
    return 1 - (torch.acos(cos_sim) / math.pi)


def record_layer_output(module, input, output, layer_idx):
    """Forward hook: stash the first element of a decoder layer's output.

    ``output[0]`` is appended to the module-level list
    ``all_layer_output[layer_idx]``; ``module`` and ``input`` are ignored.
    """
    hidden_states = output[0]
    all_layer_output[layer_idx].append(hidden_states)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on a router gate: record the indices of its top-2 experts.

    ``output`` holds the gate scores. The indices of the two largest scores
    along the last dimension are appended, as a NumPy array, to
    ``all_gate_indices[layer_idx]``.
    """
    top2_indices = torch.topk(output, 2, dim=-1).indices
    all_gate_indices[layer_idx].append(top2_indices.cpu().detach().numpy())


def record_expert_output(module, input, output, layer_idx, expert_idx):
    """Forward hook on one expert: keep its output tensor.

    The output (expected shape [num_tokens, hidden_dim]) overwrites
    ``all_expert_output[layer_idx][expert_idx]``.
    """
    layer_store = all_expert_output[layer_idx]
    layer_store[expert_idx] = output


def record_ffn_output(module, input, output, layer_idx):
    """Forward hook on a dense FFN: store its first batch element under key -1.

    Key -1 of ``all_expert_output[layer_idx]`` is reserved for the dense FFN,
    so it sits next to the MoE experts recorded for the same layer.
    """
    first_sample = output[0]
    all_expert_output[layer_idx][-1] = first_sample


total_token_count = 0
# Short input: a list of per-token similarity matrices per layer (appended below).
# Long input: a running sum of similarity matrices per layer (averaged when saving).
all_sim_arr = [
    [] if use_short_input else np.zeros((num_experts+1, num_experts+1))
    for _ in range(num_layers)
]
# Row/column `num_experts` of each similarity matrix is the Mistral FFN, whose
# output record_ffn_output stores under key -1 of all_expert_output[layer].
slot_keys = list(range(num_experts)) + [-1]
for s, sent in enumerate(sentence_lst):
    if s == 10:
        # For the long sequence, use the first ten sentences only.
        break
    mix_enc_input = mixtral_tok.encode(sent, return_tensors='pt') # mix_enc_input is actually the same as mis_enc_input.
    mis_enc_input = mistral_tok.encode(sent, return_tensors='pt')
    assert mix_enc_input.shape[1] == mis_enc_input.shape[1]
    num_tokens = mix_enc_input.shape[1]
    total_token_count += num_tokens
    print(s, num_tokens, total_token_count)
    all_layer_output = [[] for _ in range(num_layers)]
    all_expert_output = [{} for _ in range(num_layers)]
    all_gate_indices = [[] for _ in range(num_layers)]
    handles = []

    # Obtain the original output feature vectors of experts
    # and gate choices when topk=2.
    for name, module in mixtral_model.named_modules():
        if isinstance(module, MistralDecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, torch.nn.Linear) and 'gate' in name:
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))
    try:
        # no_grad: the recorded layer outputs are only reused as inputs below,
        # so there is no reason to keep their autograd graphs alive.
        with torch.no_grad():
            mix_output = mixtral_model(mix_enc_input)
    finally:
        for h in handles:
            h.remove()
        handles = []

    # Modify the number of chosen experts to ALL.
    for i in range(num_layers):
        mixtral_model.model.layers[i].mlp.num_experts_per_token = num_experts
    try:
        # Iterate over the layers and register hooks for one layer at a time.
        for i in range(num_layers):
            for name, module in mixtral_model.named_modules():
                if isinstance(module, FeedForward):
                    layer_idx = int(name.split('.')[2])
                    if layer_idx == i:
                        expert_idx = int(name.split('.')[-1])
                        handles.append(module.register_forward_hook(
                            functools.partial(record_expert_output, layer_idx=layer_idx, expert_idx=expert_idx)
                        ))
                    elif layer_idx > i:
                        break
            with torch.no_grad():
                if i == 0:
                    mix_output = mixtral_model(mix_enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
                else:
                    # Feed the topk=2 output of previous layer as input.
                    mix_output = mixtral_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
            for h in handles:
                h.remove()
            handles = []
    finally:
        # Always drop leftover hooks and revert to the original top-k, so an
        # exception cannot leave the model routing every token to all experts.
        for h in handles:
            h.remove()
        handles = []
        for i in range(num_layers):
            mixtral_model.model.layers[i].mlp.num_experts_per_token = mixtral_model.config.num_experts_per_token

    # Obtain Mistral FFNs' output.
    for name, module in mistral_model.named_modules():
        if isinstance(module, MistralMLP):
            layer_idx = int(name.split('.')[1])
            handles.append(module.register_forward_hook(
                functools.partial(record_ffn_output, layer_idx=layer_idx)
            ))
    try:
        with torch.no_grad():
            mis_output = mistral_model(mis_enc_input)
    finally:
        for h in handles:
            h.remove()
        handles = []

    # Compute similarity between outputs of current sentence.
    if use_short_input:
        for i in range(num_layers):
            for j in range(num_tokens):
                sim_arr = np.ones((num_experts+1, num_experts+1))
                for k in range(num_experts+1):
                    # Mixtral and Mistral layers can be loaded on different GPUs, so put them on the same device manually.
                    out_k = all_expert_output[i][slot_keys[k]][j].to('cuda:0')
                    for l in range(k+1, num_experts+1):
                        out_l = all_expert_output[i][slot_keys[l]][j].to('cuda:0')
                        sim = cos(out_k, out_l).float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
                        sim_arr[k][l] = sim
                        sim_arr[l][k] = sim
                all_sim_arr[i].append(sim_arr)
    else:
        output_dim = all_expert_output[0][0][0].shape[0]
        for i in range(num_layers):
            for j in range(num_tokens):
                # Reorganize recorded data to compute similarity in parallel.
                expert_output_self = torch.empty(num_experts+1, 1, output_dim).cuda()
                expert_output_other = torch.empty(1, num_experts+1, output_dim).cuda()
                for k, key in enumerate(slot_keys):
                    expert_output_self[k, 0] = all_expert_output[i][key][j]
                    expert_output_other[0, k] = all_expert_output[i][key][j]
                # Self-similarity is 1 by definition; this also overwrites any NaN
                # that numerical instability may produce on the diagonal.
                sim = get_angular_similarity(expert_output_self, expert_output_other).fill_diagonal_(1.)
                sim = sim.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
                all_sim_arr[i] += sim

In [None]:
# Save and plot.
if use_short_input:
    # Global minimum/maximum over every layer and token; the 'full_colorbar'
    # figures use this range so that all layers share one colour scale.
    global_vmax, global_vmin = -1 * math.inf, math.inf
    for layer_arrs in all_sim_arr:
        # Iterate over the stored matrices themselves rather than
        # range(num_tokens), which only reflects the last processed sentence.
        for sim_arr in layer_arrs:
            curr_vmax = np.max(sim_arr)
            curr_vmin = np.min(sim_arr)
            if curr_vmin < global_vmin:
                global_vmin = curr_vmin
            if curr_vmax > global_vmax:
                global_vmax = curr_vmax

    output_dict = {'global_vmax':global_vmax, 'global_vmin':global_vmin}
    with open(os.path.join(output_dir, 'all_sim_arr'), 'wb') as f:
        pickle.dump(all_sim_arr, f)
    with open(os.path.join(output_dir, 'all_gate_indices'), 'wb') as f:
        pickle.dump(all_gate_indices, f)
    with open(os.path.join(output_dir, 'output_dict'), 'wb') as f:
        pickle.dump(output_dict, f)

    for i in range(num_layers):
        layer_num_tokens = len(all_sim_arr[i])
        plot_one_layer_short_seq(all_sim_arr[i], all_gate_indices, i, layer_num_tokens, 'auto_colorbar')
        plot_one_layer_short_seq(all_sim_arr[i], all_gate_indices, i, layer_num_tokens, 'full_colorbar', global_vmin, global_vmax)

else:
    global_vmax, global_vmin = -1 * math.inf, math.inf
    all_avg_arr = []
    for i in range(num_layers):
        # all_sim_arr[i] is a sum over all processed tokens; turn it into a mean.
        avg_arr = all_sim_arr[i] / total_token_count
        all_avg_arr.append(avg_arr)
        curr_vmax = np.max(avg_arr)
        curr_vmin = np.min(avg_arr)
        if curr_vmin < global_vmin:
            global_vmin = curr_vmin
        if curr_vmax > global_vmax:
            global_vmax = curr_vmax
    output_dict = {'global_vmax':global_vmax, 'global_vmin':global_vmin}
    with open(os.path.join(output_dir, 'all_avg_arr'), 'wb') as f:
        pickle.dump(all_avg_arr, f)
    with open(os.path.join(output_dir, 'output_dict'), 'wb') as f:
        pickle.dump(output_dict, f)

    for i in range(num_layers):
        plot_one_layer_long_seq(all_avg_arr[i], i, 'auto_colorbar')
        plot_one_layer_long_seq(all_avg_arr[i], i, 'full_colorbar', global_vmin, global_vmax)

### DeepSeek

In [None]:
# Input.
use_short_input = True # Set False to use the long sequence.
sentence_lst = []
if use_short_input:
    raw_input = "As an open source alternative to"
    sentence_lst.append(raw_input)
else:
    # Collect every non-empty line of the WikiText-103 test split, skipping
    # lines that start with '=' (presumably the '= Section =' headings).
    # The encoding is explicit because the corpus contains non-ASCII text and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open('./wikitext103_test.csv', encoding='utf-8') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter='\n')
        for row in csv_reader:
            if not row:
                # csv.reader yields [] for blank lines; row[0] would raise.
                continue
            for sent in row[0].split('\n'):
                sent = sent.strip()
                if sent and not sent.startswith('='):
                    sentence_lst.append(sent)

# Cosine similarity between two 1-D vectors (used for the short input).
cos = torch.nn.CosineSimilarity(dim=0)
matrices = ['up_proj', 'gate_proj', 'down_proj']
num_layers = deepseek_model.config.num_hidden_layers
num_routed_experts = deepseek_model.config.n_routed_experts
# Label every 8th routed expert, plus 'SE' for the shared experts that occupy
# the last row/column of each similarity matrix.
tick_pos = list(range(0, num_routed_experts, 8)) + [num_routed_experts]
tick_labels = [str(i) for i in range(0, num_routed_experts, 8)] + ['SE']
save_dir = os.path.join(WORK_DIR, 'deepseek/deepseek_experts_outsim')
if not use_short_input:
    save_dir += '_average'
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
os.makedirs(os.path.join(plot_dir, 'auto_colorbar'), exist_ok=True)
os.makedirs(os.path.join(plot_dir, 'full_colorbar'), exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


def plot_one_layer_short_seq(arr_lst, all_gate_indices, layer_idx, num_tokens, range_type, global_vmin=None, global_vmax=None):
    """Plot one expert-similarity heat map per token for a DeepSeek layer and save it.

    Args:
        arr_lst: per-token similarity matrices over the routed experts plus
            the shared experts (last row/column, labelled 'SE').
        all_gate_indices: per-layer lists of recorded gate indices; row ``i``
            of ``all_gate_indices[layer_idx][0]`` lists the experts chosen for
            token ``i`` and is used as that subplot's title.
        layer_idx: layer being plotted (used in the y-label and file name).
        num_tokens: number of subplot columns (one per token).
        range_type: 'auto_colorbar' (one colour range shared by this layer's
            tokens, derived from the data) or 'full_colorbar' (fixed range
            ``global_vmin``..``global_vmax``).
        global_vmin, global_vmax: colour range used for 'full_colorbar'.

    Raises:
        ValueError: if ``range_type`` is not one of the two supported values.
    """
    if range_type not in ('auto_colorbar', 'full_colorbar'):
        raise ValueError(f'Unknown range_type: {range_type!r}')
    # squeeze=False keeps `axs` indexable even when there is only one token.
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(14., 4), squeeze=False)
    axs = axs[0]
    num_chosen_experts = deepseek_model.config.num_experts_per_tok
    imlst = []
    for i, sim_arr in enumerate(arr_lst):
        if range_type == 'auto_colorbar':
            im = axs[i].imshow(sim_arr)
            imlst.append(im)
        else:
            im = axs[i].imshow(sim_arr, vmin=global_vmin, vmax=global_vmax)
        chosen_experts = ', '.join(
            str(all_gate_indices[layer_idx][0][i, j]) for j in range(num_chosen_experts)
        )
        axs[i].set_title(f'exp {chosen_experts}', fontsize=15)
        axs[i].set_xticks(tick_pos, labels=tick_labels, fontsize=15)
        axs[i].set_yticks(tick_pos, labels=tick_labels, fontsize=15)
        if i == 0:
            axs[i].set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=36)
    if range_type == 'auto_colorbar':
        # Give every subplot of this layer the same norm so the single
        # colorbar below is valid for all of them.
        local_vmin = min(img.get_array().min() for img in imlst)
        local_vmax = max(img.get_array().max() for img in imlst)
        norm = colors.Normalize(vmin=local_vmin, vmax=local_vmax)
        for img in imlst:
            img.set_norm(norm)
    cbar = fig.colorbar(im, ax=axs, shrink=1.)
    cbar.ax.tick_params(labelsize=15)
    fig.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close(fig)


def plot_one_layer_long_seq(avg_arr, layer_idx, range_type, global_vmin=None, global_vmax=None):
    """Plot the token-averaged expert-similarity heat map of one DeepSeek layer and save it.

    Args:
        avg_arr: averaged similarity matrix (routed experts + shared experts).
        layer_idx: layer being plotted (used in the y-label and file name).
        range_type: 'auto_colorbar' (matplotlib picks the range) or
            'full_colorbar' (fixed range ``global_vmin``..``global_vmax``).
        global_vmin, global_vmax: colour range used for 'full_colorbar'.

    Raises:
        ValueError: if ``range_type`` is not one of the two supported values.
    """
    if range_type not in ('auto_colorbar', 'full_colorbar'):
        raise ValueError(f'Unknown range_type: {range_type!r}')
    fig, ax = plt.subplots(ncols=1, layout='constrained', figsize=(7.5, 3.5))
    if range_type == 'auto_colorbar':
        im = ax.imshow(avg_arr)
    else:
        im = ax.imshow(avg_arr, vmin=global_vmin, vmax=global_vmax)
    ax.set_xticks(tick_pos, labels=tick_labels, fontsize=18)
    ax.set_yticks(tick_pos, labels=tick_labels, fontsize=18)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=18., fontsize=26)
    cbar = fig.colorbar(im, ax=ax, shrink=1.)
    cbar.ax.tick_params(labelsize=18)
    fig.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close(fig)

In [None]:
# Forward pass.


def get_angular_similarity(v1, v2):
    """Return the angular similarity ``1 - arccos(cos_sim) / pi`` along dim 2.

    ``v1`` and ``v2`` are broadcast against each other, so shapes (N, 1, D)
    and (1, N, D) give an (N, N) matrix. Results lie in [0, 1]: 1 for
    parallel vectors, 0.5 for orthogonal ones and 0 for opposite ones.
    """
    batch_cos = torch.nn.CosineSimilarity(dim=2)
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_sim = batch_cos(v1, v2).clamp(-1.0, 1.0)
    return 1 - (torch.acos(cos_sim) / math.pi)


def record_layer_output(module, input, output, layer_idx):
    """Forward hook: stash the first element of a decoder layer's output.

    ``output[0]`` is appended to the module-level list
    ``all_layer_output[layer_idx]``; ``module`` and ``input`` are ignored.
    """
    hidden_states = output[0]
    all_layer_output[layer_idx].append(hidden_states)


def record_expert_output(module, input, output, layer_idx, expert_idx):
    """Forward hook on a DeepSeek expert: keep its output tensor.

    ``expert_idx == -1`` marks the shared experts. Their output has a leading
    dimension of size 1 squeezed away (presumably the batch dimension), while
    routed experts are stored unchanged. The result overwrites
    ``all_expert_output[layer_idx][expert_idx]``.
    """
    stored = output.squeeze(dim=0) if expert_idx == -1 else output
    all_expert_output[layer_idx][expert_idx] = stored


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on MoEGate: record the indices of the experts it selected.

    ``output`` must unpack into exactly three items; the first holds the
    chosen expert indices and is appended, as a NumPy array, to
    ``all_gate_indices[layer_idx]``.
    """
    topk_idx, _topk_weight, _unused = output
    all_gate_indices[layer_idx].append(topk_idx.cpu().detach().numpy())


total_token_count = 0
# Short input: a list of per-token similarity matrices per layer (appended below).
# Long input: a running sum of similarity matrices per layer (averaged when saving).
all_sim_arr = [
    [] if use_short_input else np.zeros((num_routed_experts+1, num_routed_experts+1))
    for _ in range(num_layers)
]
# Row/column `num_routed_experts` of each matrix is the shared experts, whose
# output record_expert_output stores under key -1 of all_expert_output[layer].
slot_keys = list(range(num_routed_experts)) + [-1]
for s, sent in enumerate(sentence_lst):
    if s == 10:
        # For the long sequence, use the first ten sentences only.
        break
    enc_input = deepseek_tok.encode(sent, return_tensors='pt').cuda()
    num_tokens = enc_input.shape[1]
    total_token_count += num_tokens
    print(s, num_tokens, total_token_count)
    all_layer_output = [[] for _ in range(num_layers)]
    all_expert_output = [{} for _ in range(num_layers)]
    all_gate_indices = [[] for _ in range(num_layers)]
    handles = []

    # Obtain the original output feature vectors of experts
    # and gate choices when topk=6.
    for name, module in deepseek_model.named_modules():
        if isinstance(module, DeepseekDecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, MoEGate):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))
    try:
        with torch.no_grad():
            output = deepseek_model(enc_input)
    finally:
        for h in handles:
            h.remove()
        handles = []

    # Route every token to ALL routed experts. Layer 0 is skipped here and in
    # the similarity computation below (presumably DeepSeekMoE's dense first layer).
    for i in range(1, num_layers):
        curr_layer = deepseek_model.model.layers[i]
        if hasattr(curr_layer.mlp, 'num_experts_per_tok'):
            curr_layer.mlp.num_experts_per_tok = num_routed_experts
        if hasattr(curr_layer.mlp, 'gate'):
            curr_layer.mlp.gate.top_k = num_routed_experts
    try:
        # Iterate over the layers and register hooks for one layer at a time.
        for i in range(num_layers):
            for name, module in deepseek_model.named_modules():
                if not (isinstance(module, DeepseekMLP) and 'experts' in name):
                    continue
                layer_idx = int(name.split('.')[2])
                if layer_idx > i:
                    break
                if layer_idx == i:
                    # Use -1 to represent shared experts.
                    expert_idx = -1 if 'shared_experts' in name else int(name.split('.')[-1])
                    handles.append(module.register_forward_hook(
                        functools.partial(record_expert_output, layer_idx=layer_idx, expert_idx=expert_idx)
                    ))
            with torch.no_grad():
                if i == 0:
                    output = deepseek_model(enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
                else:
                    # Feed the topk=6 output of previous layer as input.
                    output = deepseek_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
            for h in handles:
                h.remove()
            # Start the next layer with an empty hook list.
            handles = []
    finally:
        # Always drop leftover hooks and revert to the original top-k, so an
        # exception cannot leave the model routing every token to all experts.
        for h in handles:
            h.remove()
        handles = []
        for i in range(1, num_layers):
            curr_layer = deepseek_model.model.layers[i]
            if hasattr(curr_layer.mlp, 'num_experts_per_tok'):
                curr_layer.mlp.num_experts_per_tok = deepseek_model.config.num_experts_per_tok
            if hasattr(curr_layer.mlp, 'gate'):
                curr_layer.mlp.gate.top_k = deepseek_model.config.num_experts_per_tok

    # Compute similarity between outputs of current sentence.
    if use_short_input:
        for i in range(1, num_layers):
            for j in range(num_tokens):
                sim_arr = np.empty((num_routed_experts+1, num_routed_experts+1))
                for k in range(num_routed_experts+1):
                    out_k = all_expert_output[i][slot_keys[k]][j]
                    for l in range(k, num_routed_experts+1):
                        out_l = all_expert_output[i][slot_keys[l]][j]
                        sim = cos(out_k, out_l).float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
                        sim_arr[k][l] = sim
                        sim_arr[l][k] = sim
                all_sim_arr[i].append(sim_arr)
    else:
        output_dim = all_expert_output[1][0][0].shape[0]
        for i in range(1, num_layers):
            for j in range(num_tokens):
                # Reorganize recorded data to compute similarity in parallel.
                expert_output_self = torch.empty(num_routed_experts+1, 1, output_dim).cuda()
                expert_output_other = torch.empty(1, num_routed_experts+1, output_dim).cuda()
                for k, key in enumerate(slot_keys):
                    expert_output_self[k, 0] = all_expert_output[i][key][j]
                    expert_output_other[0, k] = all_expert_output[i][key][j]
                # Self-similarity is 1 by definition; this also overwrites any NaN
                # that numerical instability may produce on the diagonal.
                sim = get_angular_similarity(expert_output_self, expert_output_other).fill_diagonal_(1.)
                sim = sim.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
                all_sim_arr[i] += sim

In [None]:
# Save and plot.
if use_short_input:
    # Global minimum/maximum over the analysed layers and tokens; the
    # 'full_colorbar' figures use this range so all layers share one scale.
    global_vmax, global_vmin = -1 * math.inf, math.inf
    # Layer 0 is never analysed (the similarity loops start at layer 1), so it
    # must not contribute to the global range.
    for i in range(1, num_layers):
        for sim_arr in all_sim_arr[i]:
            curr_vmax = np.max(sim_arr)
            curr_vmin = np.min(sim_arr)
            if curr_vmin < global_vmin:
                global_vmin = curr_vmin
            if curr_vmax > global_vmax:
                global_vmax = curr_vmax

    output_dict = {'global_vmax': global_vmax, 'global_vmin':global_vmin}
    with open(os.path.join(output_dir, 'all_sim_arr'), 'wb') as f:
        pickle.dump(all_sim_arr, f)
    with open(os.path.join(output_dir, 'all_gate_indices'), 'wb') as f:
        pickle.dump(all_gate_indices, f)
    with open(os.path.join(output_dir, 'output_dict'), 'wb') as f:
        pickle.dump(output_dict, f)

    for i in range(1, num_layers):
        layer_num_tokens = len(all_sim_arr[i])
        plot_one_layer_short_seq(all_sim_arr[i], all_gate_indices, i, layer_num_tokens, 'auto_colorbar')
        plot_one_layer_short_seq(all_sim_arr[i], all_gate_indices, i, layer_num_tokens, 'full_colorbar', global_vmin, global_vmax)

else:
    global_vmax, global_vmin = -1 * math.inf, math.inf
    # One single-element list per layer; layer 0 stays empty.
    all_avg_arr = [[] for _ in range(num_layers)]
    for i in range(1, num_layers):
        # all_sim_arr[i] is a sum over all processed tokens; turn it into a mean.
        avg_arr = all_sim_arr[i] / total_token_count
        all_avg_arr[i].append(avg_arr)
        curr_vmax = np.max(avg_arr)
        curr_vmin = np.min(avg_arr)
        if curr_vmin < global_vmin:
            global_vmin = curr_vmin
        if curr_vmax > global_vmax:
            global_vmax = curr_vmax

    output_dict = {'global_vmax':global_vmax, 'global_vmin':global_vmin}
    with open(os.path.join(output_dir, 'all_avg_arr'), 'wb') as f:
        pickle.dump(all_avg_arr, f)
    with open(os.path.join(output_dir, 'output_dict'), 'wb') as f:
        pickle.dump(output_dict, f)

    for i in range(1, num_layers):
        plot_one_layer_long_seq(all_avg_arr[i][0], i, 'auto_colorbar')
        plot_one_layer_long_seq(all_avg_arr[i][0], i, 'full_colorbar', global_vmin, global_vmax)


### Grok

In [None]:
# Input.
use_short_input = True # Set False to use the long sequence.
sentence_lst = []
if use_short_input:
    raw_input = "As an open source alternative to"
    sentence_lst.append(raw_input)
else:
    # Collect every non-empty line of the WikiText-103 test split, skipping
    # lines that start with '=' (presumably the '= Section =' headings).
    # The encoding is explicit because the corpus contains non-ASCII text and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open('./wikitext103_test.csv', encoding='utf-8') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter='\n')
        for row in csv_reader:
            if not row:
                # csv.reader yields [] for blank lines; row[0] would raise.
                continue
            for sent in row[0].split('\n'):
                sent = sent.strip()
                if sent and not sent.startswith('='):
                    sentence_lst.append(sent)

# Cosine similarity between two 1-D vectors.
cos = torch.nn.CosineSimilarity(dim=0)
matrices = ['linear_v', 'linear', 'linear_1'] # up, gate, down
# When True, record_expert_output L2-normalises each expert output before storing it.
normalize_output = False
num_layers = grok_model.config.num_hidden_layers
num_experts = grok_model.config.num_experts
# One tick per expert; unlike the Mixtral cells there is no extra 'F' column.
tick_labels = [str(i) for i in range(num_experts)]
save_dir = os.path.join(WORK_DIR, 'grok/grok_experts_outsim')
if not use_short_input:
    save_dir += '_average'
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
os.makedirs(os.path.join(plot_dir, 'auto_colorbar'), exist_ok=True)
os.makedirs(os.path.join(plot_dir, 'full_colorbar'), exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


def plot_one_layer_short_seq(arr_lst, all_gate_indices, layer_idx, num_tokens, range_type, global_vmin=None, global_vmax=None):
    """Plot one expert-similarity heat map per token for a Grok layer and save it.

    Args:
        arr_lst: per-token num_experts x num_experts similarity matrices.
        all_gate_indices: per-layer lists of recorded top-2 gate indices; row
            ``i`` of ``all_gate_indices[layer_idx][0]`` titles token ``i``.
        layer_idx: layer being plotted (used in the y-label and file name).
        num_tokens: number of subplot columns (one per token).
        range_type: 'auto_colorbar' (one colour range shared by this layer's
            tokens, derived from the data) or 'full_colorbar' (fixed range
            ``global_vmin``..``global_vmax``).
        global_vmin, global_vmax: colour range used for 'full_colorbar'.

    Raises:
        ValueError: if ``range_type`` is not one of the two supported values.
    """
    if range_type not in ('auto_colorbar', 'full_colorbar'):
        raise ValueError(f'Unknown range_type: {range_type!r}')
    # squeeze=False keeps `axs` indexable even when there is only one token.
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(8., 2.5), squeeze=False)
    axs = axs[0]
    imlst = []
    for i, sim_arr in enumerate(arr_lst):
        if range_type == 'auto_colorbar':
            im = axs[i].imshow(sim_arr)
            imlst.append(im)
        else:
            im = axs[i].imshow(sim_arr, vmin=global_vmin, vmax=global_vmax)
        exp1, exp2 = all_gate_indices[layer_idx][0][i, 0], all_gate_indices[layer_idx][0][i, 1]
        axs[i].set_title(f'exp {exp1},{exp2}', fontsize=16)
        axs[i].set_xticks(np.arange(num_experts), labels=tick_labels, fontsize=16)
        axs[i].set_yticks(np.arange(num_experts), labels=tick_labels, fontsize=16)
        if i == 0:
            axs[i].set_ylabel(f'Layer {layer_idx}', labelpad=12., fontsize=20)
    if range_type == 'auto_colorbar':
        # Give every subplot of this layer the same norm so the single
        # colorbar below is valid for all of them.
        local_vmin = min(img.get_array().min() for img in imlst)
        local_vmax = max(img.get_array().max() for img in imlst)
        norm = colors.Normalize(vmin=local_vmin, vmax=local_vmax)
        for img in imlst:
            img.set_norm(norm)
    cbar = fig.colorbar(im, ax=axs, shrink=.88)
    cbar.ax.tick_params(labelsize=16)
    fig.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close(fig)


def plot_one_layer_long_seq(avg_arr, layer_idx, range_type, global_vmin=None, global_vmax=None):
    """Plot the token-averaged expert-similarity heat map of one Grok layer and save it.

    Args:
        avg_arr: num_experts x num_experts averaged similarity matrix.
        layer_idx: layer being plotted (used in the y-label and file name).
        range_type: 'auto_colorbar' (matplotlib picks the range) or
            'full_colorbar' (fixed range ``global_vmin``..``global_vmax``).
        global_vmin, global_vmax: colour range used for 'full_colorbar'.

    Raises:
        ValueError: if ``range_type`` is not one of the two supported values.
    """
    if range_type not in ('auto_colorbar', 'full_colorbar'):
        raise ValueError(f'Unknown range_type: {range_type!r}')
    fig, ax = plt.subplots(ncols=1, layout='constrained', figsize=(4., 2.0))
    if range_type == 'auto_colorbar':
        im = ax.imshow(avg_arr)
    else:
        im = ax.imshow(avg_arr, vmin=global_vmin, vmax=global_vmax)
    ax.set_xticks(np.arange(num_experts), labels=tick_labels, fontsize=14.5)
    ax.set_yticks(np.arange(num_experts), labels=tick_labels, fontsize=14.5)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=12., fontsize=20)
    cbar = fig.colorbar(im, ax=ax, shrink=1.)
    cbar.ax.tick_params(labelsize=14.5)
    fig.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close(fig)


In [None]:
# Forward pass.

def get_angular_similarity(v1, v2):
    """Return the angular similarity ``1 - arccos(cos_sim) / pi`` along dim 2.

    ``v1`` and ``v2`` are broadcast against each other, so shapes (N, 1, D)
    and (1, N, D) give an (N, N) matrix. Results lie in [0, 1]: 1 for
    parallel vectors, 0.5 for orthogonal ones and 0 for opposite ones.
    """
    batch_cos = torch.nn.CosineSimilarity(dim=2)
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_sim = batch_cos(v1, v2).clamp(-1.0, 1.0)
    return 1 - (torch.acos(cos_sim) / math.pi)


def record_layer_output(module, input, output, layer_idx):
    """Forward hook: stash the first element of a Grok decoder layer's output.

    ``output[0]`` (expected shape (num_tokens, hidden_dim)) is appended to
    ``all_layer_output[layer_idx]``; ``module`` and ``input`` are ignored.
    """
    hidden_states = output[0]
    all_layer_output[layer_idx].append(hidden_states)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on a Grok router gate: record the indices of the top-2 experts.

    ``output`` holds the router logits with the experts on the last
    dimension. They are converted to routing probabilities with a softmax over
    that same dimension, and the indices of the two most probable experts
    (shape (..., 2)) are appended, as a NumPy array, to
    ``all_gate_indices[layer_idx]``.
    """
    router_logits = output
    # Softmax over the expert axis (-1), the same axis topk uses. dim=1 would
    # only coincide with it for 2-D (num_tokens, num_experts) logits; for 3-D
    # (batch, seq, num_experts) logits it would normalise across tokens and
    # could select the wrong experts.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    _, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    all_gate_indices[layer_idx].append(selected_experts.cpu().detach().numpy())


def record_expert_output(module, input, output, layer_idx, expert_idx):
    """Forward hook on an expert: store its output tensor, optionally L2-normalized.

    When the global ``normalize_output`` flag is truthy, each row (dim 1) is
    scaled to unit L2 norm before being stored in
    ``all_expert_output[layer_idx][expert_idx]``. The output is assumed to be
    (num_tokens, hidden_dim).
    """
    stored = normalize(output, p=2.0, dim=1) if normalize_output else output
    all_expert_output[layer_idx][expert_idx] = stored


# Accumulator for expert-output similarities, one entry per layer.
# - Short input: all_sim_arr[layer] is a list holding one (num_experts, num_experts)
#   matrix per token (read back as all_sim_arr[i][j] in the save cell).
# - Long input: all_sim_arr[layer] is a running (num_experts, num_experts) sum,
#   divided by total_token_count in the save cell.
# Both are sized by num_experts to match `sim_arr` / `sim` computed below and the
# plot tick labels (not by DeepSeek's `num_routed_experts + 1`).
total_token_count = 0
if use_short_input:
    all_sim_arr = [[] for _ in range(num_layers)]
else:
    all_sim_arr = [np.zeros((num_experts, num_experts)) for _ in range(num_layers)]
for s, sent in enumerate(sentence_lst):
    if s == 10:
        # For the long sequence, use the first ten sentences only.
        break
    enc_input = grok_tok.encode(sent, return_tensors='pt').cuda()
    attention_mask = torch.ones_like(enc_input)
    inputs = {
        "input_ids": grok_tok(sent, return_tensors='pt').input_ids.cuda(),
        "attention_mask": attention_mask,
        "max_new_tokens": 1,
    }
    num_tokens = enc_input.shape[1]
    total_token_count += num_tokens
    print(s, num_tokens, total_token_count)
    all_layer_output = [[] for _ in range(num_layers)]
    all_expert_output = [{} for _ in range(num_layers)]
    all_gate_indices = [[] for _ in range(num_layers)]
    handles = []

    # Obtain the original output feature vectors of experts
    # and gate choices when topk=2.
    for name, module in grok_model.named_modules():
        if isinstance(module, DecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, torch.nn.Linear) and 'gate' in name:
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))

    output = grok_model.generate(**inputs)
    for h in handles:
        h.remove()
    handles = []

    # Modify the number of chosen experts to ALL.
    for i in range(num_layers):
        grok_model.model.layers[i].moe_block.top_k = num_experts
    # Iterate over the layers and register a hook once a time.
    for i in range(num_layers):
        for name, module in grok_model.named_modules():
            if isinstance(module, MoeMLP):
                layer_idx = int(name.split('.')[2])
                if layer_idx == i:
                    expert_idx = int(name.split('.')[-1])
                    handles.append(module.register_forward_hook(
                        functools.partial(record_expert_output, layer_idx=layer_idx, expert_idx=expert_idx)
                    ))
                elif layer_idx > i:
                    break
        if i == 0:
            with torch.no_grad():
                output = grok_model(enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
        else:
            with torch.no_grad():
                # Feed the topk=2 output of previous layer as input.
                output = grok_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        for h in handles:
            h.remove()
        handles = []
    # Revert to the original value.
    for i in range(num_layers):
        grok_model.model.layers[i].moe_block.top_k = grok_model.config.num_experts_per_tok

    # Compute similarity between outputs of current sentence.
    if use_short_input:
        # Pairwise `cos` (defined in an earlier cell) between every two experts, per token.
        for i in range(num_layers):
            for j in range(num_tokens):
                sim_arr = np.ones((num_experts, num_experts))
                for k in range(num_experts):
                    for l in range(k+1, num_experts):
                        sim = cos(all_expert_output[i][k][j], all_expert_output[i][l][j]).float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
                        sim_arr[k][l] = sim
                        sim_arr[l][k] = sim
                all_sim_arr[i].append(sim_arr)
    else:
        output_dim = all_expert_output[0][0][0].shape[0]
        for i in range(num_layers):
            for j in range(num_tokens):
                # Reorganize recorded data to compute similarity in parallel.
                expert_output_self = torch.empty(num_experts, 1, output_dim).cuda()
                expert_output_other = torch.empty(1, num_experts, output_dim).cuda()
                for k in range(num_experts):
                    expert_output_self[k, 0] = all_expert_output[i][k][j]
                    expert_output_other[0, k] = all_expert_output[i][k][j]
                sim = get_angular_similarity(expert_output_self, expert_output_other).fill_diagonal_(1.) # Replace nan values due to numerical instability
                sim = sim.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
                all_sim_arr[i] += sim

In [None]:
# Save and plot.
if use_short_input:
    # Global color range over every token's matrix in every layer, used by the
    # 'full_colorbar' figures.
    global_vmax, global_vmin = -1 * math.inf, math.inf
    for i in range(num_layers):
        for j in range(num_tokens):
            sim_arr = all_sim_arr[i][j]
            global_vmax = max(global_vmax, np.max(sim_arr))
            global_vmin = min(global_vmin, np.min(sim_arr))

    output_dict = {'global_vmax':global_vmax, 'global_vmin':global_vmin}
    with open(os.path.join(output_dir, 'all_sim_arr'), 'wb') as f:
        pickle.dump(all_sim_arr, f)
    with open(os.path.join(output_dir, 'all_gate_indices'), 'wb') as f:
        pickle.dump(all_gate_indices, f)
    with open(os.path.join(output_dir, 'output_dict'), 'wb') as f:
        pickle.dump(output_dict, f)

    for i in range(num_layers):
        plot_one_layer_short_seq(all_sim_arr[i], all_gate_indices, i, num_tokens, 'auto_colorbar')
        plot_one_layer_short_seq(all_sim_arr[i], all_gate_indices, i, num_tokens, 'full_colorbar', global_vmin, global_vmax)

else:
    # Turn the accumulated per-layer sums into per-token averages.
    all_avg_arr = [all_sim_arr[i] / total_token_count for i in range(num_layers)]
    global_vmax, global_vmin = -1 * math.inf, math.inf
    for avg_arr in all_avg_arr:
        global_vmax = max(global_vmax, np.max(avg_arr))
        global_vmin = min(global_vmin, np.min(avg_arr))
    output_dict = {'global_vmax':global_vmax, 'global_vmin':global_vmin}
    with open(os.path.join(output_dir, 'all_avg_arr'), 'wb') as f:
        pickle.dump(all_avg_arr, f)
    with open(os.path.join(output_dir, 'output_dict'), 'wb') as f:
        pickle.dump(output_dict, f)

    # Plot the averages computed above (no need to recompute them).
    for i in range(num_layers):
        plot_one_layer_long_seq(all_avg_arr[i], i, 'auto_colorbar')
        plot_one_layer_long_seq(all_avg_arr[i], i, 'full_colorbar', global_vmin, global_vmax)

## Norms of Expert Outputs and Gate Scores

We use both the short and the long sequence in this experiment. For the short sequence, we plot the output norm and the gate score of every expert for each token. For the long sequence, we only plot the counts of (gate-score rank, output-norm rank) pairs. To use the long sequence as the input, set `use_short_input=False`.

### Mixtral

In [None]:
# Input.
use_short_input = True # Set False to use the long sequence.
if use_short_input:
    raw_input = "As an open source alternative to"
    sentence_lst = [raw_input]
else:
    sentence_lst = []
    with open('./wikitext103_test.csv') as csv_file:
        for row in csv.reader(csv_file, delimiter='\n'):
            for sent in row[0].split('\n'):
                sent = sent.strip()
                # Keep non-empty lines that do not start with '=' (presumably
                # wikitext section headings).
                if sent and not sent.startswith('='):
                    sentence_lst.append(sent)

num_layers = mixtral_model.config.num_hidden_layers
num_experts = mixtral_model.config.num_experts

tick_labels = list(map(str, range(num_experts)))
# Long-sequence runs store rank counts, so they get a separate '_count' directory.
save_dir = os.path.join(WORK_DIR, 'mixtral/mixtral_expert_norm') + ('' if use_short_input else '_count')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


def plot_one_layer_short_seq(all_gate_scores, all_gate_indices, all_expert_output, layer_idx, num_tokens):
    """Draw one subplot per token comparing expert output norms with gate scores.

    Blue bars show the recorded output norm of each expert. Orange bars, on a
    twin y-axis, show the softmax gate scores. The title names the two experts
    chosen by the gate. Saved to ``<plot_dir>/layer_<layer_idx>.png``.
    """
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(22.0, 2.8))
    x_pos = np.arange(num_experts) * 2
    for tok in range(num_tokens):
        ax = axs[tok]
        # Norm bars sit slightly left of each tick.
        norms = [all_expert_output[layer_idx][e][tok] for e in range(num_experts)]
        ax.bar(x_pos - 0.35, norms, label='Norm', width=0.6)
        # Gate scores share the x-axis but use their own y-scale.
        twin_ax = ax.twinx()
        twin_ax.bar(x_pos, all_gate_scores[layer_idx][0][tok, :], tick_label=tick_labels,
                    color='darkorange', align='edge', label='Score', width=0.5)
        ax.set_xticks(x_pos, labels=tick_labels, fontsize=18)
        if tok == 0:
            ax.set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=22)
        first, second = all_gate_indices[layer_idx][0][tok, :2]
        ax.set_title(f'exp {first},{second}', fontsize=18)
        ax.legend(loc='upper left', fontsize=12)
        twin_ax.legend(loc='upper right', fontsize=12)
    plt.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close()


def plot_one_layer_long_seq(rankings_counts, layer_idx):
    """Grouped bar chart of gate-score-rank vs output-norm-rank counts for one layer.

    Args:
        rankings_counts: (num_experts, num_experts) array. Entry [r, c] counts how
            often an expert had gate-score rank r and output-norm rank c for a
            token (rank 0 = smallest; labels are shown 1-based).
        layer_idx: Layer index, or 'ALL' for the sum over layers; used in the file name.

    Each x group is one norm rank and holds one bar per score rank. Saved to
    ``<plot_dir>/layer<layer_idx>.png``.
    """
    fig, ax = plt.subplots(layout='constrained', figsize=(6.5, 4.0))
    # One group spans 0.8 of the unit spacing (0.1 per bar for 8 experts).
    bar_width = 0.8 / num_experts
    x = np.arange(num_experts)
    for i in range(num_experts):
        ax.bar(x + bar_width * i, rankings_counts[i, :], bar_width, label=f'score rank {i+1}')
    # Center each tick under its group; a fixed 3.5-bar offset is only right for 8 experts.
    group_center = (num_experts - 1) / 2 * bar_width
    ax.set_xticks(x + group_center, [str(i+1) for i in range(num_experts)], fontsize=13)
    ax.tick_params(axis='y', labelsize=11)
    ax.legend(loc='best', fontsize=11)
    ax.set_xlabel('Expert output norm ranking', fontsize=15)
    ax.set_ylabel('Count of gate score ranking', fontsize=15)
    plt.savefig(os.path.join(plot_dir, f'layer{layer_idx}.png'))
    plt.close()

In [None]:
# Forward pass.

def record_layer_output(module, input, output, layer_idx):
    """Forward hook: append the first element of a decoder layer's output to all_layer_output[layer_idx]."""
    first_output = output[0]
    all_layer_output[layer_idx].append(first_output)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on the router: store the top-2 expert indices and the softmax gate scores.

    Indices are taken from the raw logits (sorted, highest first). Scores are
    the softmax over the last dim, stored as a bfloat16 numpy array.
    """
    logits = output
    top2 = torch.topk(logits, 2, dim=-1, sorted=True).indices
    all_gate_indices[layer_idx].append(top2.cpu().detach().numpy())
    probs = logits.softmax(dim=-1)
    all_gate_scores[layer_idx].append(probs.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16))


def record_expert_output(module, input, output, layer_idx, expert_idx):
    """Forward hook on an expert: store the per-row L2 norm of its output as bfloat16 numpy."""
    # output: [num_tokens, hidden_dim] -> norms: [num_tokens]
    norms = output.norm(dim=1)
    all_expert_output[layer_idx][expert_idx] = norms.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)


# rankings_counts[layer][score_rank, norm_rank] counts, over tokens and experts,
# how often an expert's gate score and output norm had those ranks among all
# experts of the token (rank 0 = smallest). Only filled for the long sequence.
token_count = 0
rankings_counts = [np.zeros((num_experts, num_experts)) for _ in range(num_layers)]
for s, sent in enumerate(sentence_lst):
    if s == 10:
        # Use the first ten sentences only.
        break
    enc_input = mixtral_tok.encode(sent, return_tensors="pt").cuda()
    num_tokens = enc_input.shape[1]
    token_count += num_tokens
    print(s, num_tokens, token_count)
    all_gate_scores = [[] for _ in range(num_layers)]
    all_gate_indices = [[] for _ in range(num_layers)]
    all_layer_output = [[] for _ in range(num_layers)]
    all_expert_output = [{} for _ in range(num_layers)]
    handles = []

    # Obtain the original output feature vectors of experts
    # and gate choices when topk=2.
    for name, module in mixtral_model.named_modules():
        if isinstance(module, MistralDecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, torch.nn.Linear) and 'gate' in name:
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))

    with torch.no_grad():
        mix_output = mixtral_model(enc_input)
    for h in handles:
        h.remove()
    handles = []

    # Modify the number of chosen experts to ALL.
    for i in range(num_layers):
        mixtral_model.model.layers[i].mlp.num_experts_per_token = num_experts
    # Iterate over the layers and register a hook once a time.
    for i in range(num_layers):
        for name, module in mixtral_model.named_modules():
            if isinstance(module, FeedForward):
                layer_idx = int(name.split('.')[2])
                if layer_idx == i:
                    expert_idx = int(name.split('.')[-1])
                    handles.append(module.register_forward_hook(
                        functools.partial(record_expert_output, layer_idx=layer_idx, expert_idx=expert_idx)
                    ))
                elif layer_idx > i:
                    break
        if i == 0:
            with torch.no_grad():
                mix_output = mixtral_model(enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
        else:
            with torch.no_grad():
                # Feed the topk=2 output of previous layer as input.
                mix_output = mixtral_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        for h in handles:
            h.remove()
        handles = []
    # Revert to the original value.
    for i in range(num_layers):
        mixtral_model.model.layers[i].mlp.num_experts_per_token = mixtral_model.config.num_experts_per_token

    if not use_short_input:
        # Count the norm-score ranking pairs.
        for i in range(num_layers):
            layer_scores = all_gate_scores[i][0]
            for j in range(num_tokens):
                curr_token_output = np.array(
                    [all_expert_output[i][k][j] for k in range(num_experts)], dtype=np.float64
                )
                curr_gate_score = layer_scores[j, :]
                # ranks[e] = position of expert e in ascending order. Build new
                # arrays: curr_gate_score is a view into all_gate_scores, so
                # writing ranks into it would overwrite the recorded scores.
                norm_ranks = np.empty(num_experts, dtype=np.int64)
                norm_ranks[np.argsort(curr_token_output)] = np.arange(num_experts)
                score_ranks = np.empty(num_experts, dtype=np.int64)
                score_ranks[np.argsort(curr_gate_score)] = np.arange(num_experts)
                np.add.at(rankings_counts[i], (score_ranks, norm_ranks), 1)

In [None]:
# Save and plot.
if use_short_input:
    with open(os.path.join(output_dir, 'all_gate_scores'), 'wb') as f:
        pickle.dump(all_gate_scores, f)
    with open(os.path.join(output_dir, 'all_gate_indices'), 'wb') as f:
        pickle.dump(all_gate_indices, f)
    with open(os.path.join(output_dir, 'all_expert_output'), 'wb') as f:
        pickle.dump(all_expert_output, f)

    for i in range(num_layers):
        plot_one_layer_short_seq(all_gate_scores, all_gate_indices, all_expert_output, i, num_tokens)

else:
    with open(os.path.join(output_dir, 'rankings_counts'), 'wb') as f:
        pickle.dump(rankings_counts, f)
    # Plot layer one by one.
    for layer in range(num_layers):
        plot_one_layer_long_seq(rankings_counts[layer], layer)
    # Plot all layers. np.sum allocates a new array; starting from
    # rankings_counts[0] and using += would overwrite layer 0's counts in place
    # (and corrupt the pickle/plots if this cell is re-run).
    total_rankings_counts = np.sum(rankings_counts, axis=0)
    plot_one_layer_long_seq(total_rankings_counts, 'ALL')

### DeepSeek

In [None]:
# Input.
use_short_input = True # Set False to use the long sequence.
if use_short_input:
    raw_input = "As an open source alternative to"
    sentence_lst = [raw_input]
else:
    sentence_lst = []
    with open('./wikitext103_test.csv') as csv_file:
        for row in csv.reader(csv_file, delimiter='\n'):
            for sent in row[0].split('\n'):
                sent = sent.strip()
                # Keep non-empty lines that do not start with '=' (presumably
                # wikitext section headings).
                if sent and not sent.startswith('='):
                    sentence_lst.append(sent)

cos = torch.nn.CosineSimilarity(dim=0)
num_layers = deepseek_model.config.num_hidden_layers
num_routed_experts = deepseek_model.config.n_routed_experts

# Only every 8th routed expert gets a tick label to keep the axis readable.
tick_labels = [str(i) for i in range(0, num_routed_experts, 8)]
tick_pos = [int(label) * 2 - 0.35 for label in tick_labels]
# Long-sequence runs store rank counts, so they get a separate '_count' directory.
save_dir = os.path.join(WORK_DIR, 'deepseek/deepseek_expert_norm') + ('' if use_short_input else '_count')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


def plot_one_layer_short_seq(all_gate_scores, all_gate_indices, all_expert_output, layer_idx, num_tokens):
    """Draw one subplot per token comparing routed-expert output norms with gate scores.

    Blue bars show the recorded output norm of each routed expert. Orange bars,
    on a twin y-axis, show the softmax gate scores. The title lists the first
    ``num_experts_per_tok`` recorded expert indices for the token. Saved to
    ``<plot_dir>/layer_<layer_idx>.png``.
    """
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(36.0, 5.0))
    num_chosen_experts = deepseek_model.config.num_experts_per_tok
    x_pos = np.arange(num_routed_experts) * 2
    for tok in range(num_tokens):
        ax = axs[tok]
        norms = [all_expert_output[layer_idx][e][tok] for e in range(num_routed_experts)]
        ax.bar(x_pos - 0.35, norms, label='Norm', width=0.6)
        ax.set_xticks(tick_pos, labels=tick_labels, fontsize=22)
        # Gate scores share the x-axis but use their own y-scale.
        twin_ax = ax.twinx()
        twin_ax.bar(x_pos, all_gate_scores[layer_idx][0][tok, :],
                    color='darkorange', align='edge', label='Score', width=0.5)
        choices = all_gate_indices[layer_idx][0]
        chosen_experts = ', '.join(str(choices[tok, k]) for k in range(num_chosen_experts))
        ax.set_title(f'exp {chosen_experts}', fontsize=22)
        if tok == 0:
            ax.set_ylabel(f'Layer {layer_idx}', labelpad=18., fontsize=36)
        ax.legend(loc='upper left', fontsize=18)
        twin_ax.legend(loc='upper right', fontsize=18)
    plt.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close()


def plot_one_layer_long_seq(rankings_counts, layer_idx):
    """Plot rank-pair counts for all routed experts, split across 8 stacked rows.

    Row r covers output-norm ranks [r*per_row, (r+1)*per_row). Each x position
    holds one thin bar per gate-score rank, taken from
    ``rankings_counts[score_rank, norm_rank]``. All rows share the y-limit
    [0, max count + 1]. Saved to ``<plot_dir>/layer<layer_idx>.png``.
    """
    nrows = 8
    fig, axs = plt.subplots(nrows=nrows, layout='constrained', figsize=(24., 16.))
    per_row = int(num_routed_experts / nrows)
    bar_width = 0.0125
    x = np.arange(per_row)
    y_top = np.max(rankings_counts) + 1
    for row_idx in range(nrows):
        ax = axs[row_idx]
        lo, hi = per_row * row_idx, per_row * (row_idx + 1)
        for score_rank in range(num_routed_experts):
            ax.bar(x + bar_width * score_rank, rankings_counts[score_rank, lo:hi], bar_width)
        ax.set_xticks(x, [str(n + 1) for n in range(lo, hi)], fontsize=20)
        ax.tick_params(axis='y', labelsize=20)
        ax.set_ylim([0., y_top])
        if row_idx % 4 == 1:
            ax.set_ylabel('Count of gate score ranking', fontsize=25)
    axs[-1].set_xlabel('Expert output norm ranking', fontsize=25)
    plt.savefig(os.path.join(plot_dir, f'layer{layer_idx}.png'))
    plt.close()


In [None]:
# Forward pass.

def record_layer_output(module, input, output, layer_idx):
    """Forward hook: append the first element of a decoder layer's output to all_layer_output[layer_idx]."""
    first_output = output[0]
    all_layer_output[layer_idx].append(first_output)


def record_expert_output(module, input, output, layer_idx, expert_idx):
    """Forward hook on a routed expert: store the per-row L2 norm of its output as bfloat16 numpy."""
    # output: [num_tokens, hidden_dim] (e.g. Size([7, 2048])) -> norms: [num_tokens]
    norms = output.norm(dim=1)
    all_expert_output[layer_idx][expert_idx] = norms.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on a DeepSeek MoEGate: record softmax routing scores and top-k expert indices.

    The hook output is ignored. Scores are recomputed from the hook input and
    ``module.weight`` as ``softmax(x @ W.T)`` over the last dim. Scores are
    appended to ``all_gate_scores[layer_idx]`` (bfloat16 numpy) and the top-k
    indices, sorted by score, to ``all_gate_indices[layer_idx]``.
    """
    hidden_states = input[0]
    hidden_dim = hidden_states.shape[-1]
    # Flatten all leading dims to rows of hidden_dim; reshape also handles non-contiguous input.
    logits = F.linear(hidden_states.reshape(-1, hidden_dim), module.weight, None)
    scores = logits.softmax(dim=-1)
    # k comes from the model config (as in the plotting code) instead of a hard-coded 6.
    top_k = deepseek_model.config.num_experts_per_tok
    _, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=True)
    all_gate_scores[layer_idx].append(scores.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16))
    all_gate_indices[layer_idx].append(topk_idx.cpu().detach().numpy())


# rankings_counts[layer][score_rank, norm_rank] counts how often a routed expert's
# gate score and output norm had those ranks for a token (rank 0 = smallest).
# Layer 0 is skipped below (loops start at 1); presumably it is not an MoE layer
# in this model. TODO confirm.
token_count = 0
rankings_counts = [np.zeros((num_routed_experts, num_routed_experts)) for _ in range(num_layers)]
for s, sent in enumerate(sentence_lst):
    if s == 10:
        # Use the first ten sentences only.
        break
    enc_input = deepseek_tok.encode(sent, return_tensors='pt').cuda()
    num_tokens = enc_input.shape[1]
    token_count += num_tokens
    print(s, num_tokens, token_count)
    all_layer_output = [[] for _ in range(num_layers)]
    all_expert_output = [{} for _ in range(num_layers)]
    all_gate_scores = [[] for _ in range(num_layers)]
    all_gate_indices = [[] for _ in range(num_layers)]
    handles = []

    # Obtain the original output feature vectors of experts
    # and gate choices when topk=6.
    for name, module in deepseek_model.named_modules():
        if isinstance(module, DeepseekDecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, MoEGate):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))

    with torch.no_grad():
        output = deepseek_model(enc_input)
    for h in handles:
        h.remove()
    handles = []

    # Modify the number of chosen experts.
    for i in range(1, num_layers):
        curr_layer = deepseek_model.model.layers[i]
        if hasattr(curr_layer.mlp, 'num_experts_per_tok'):
            curr_layer.mlp.num_experts_per_tok = num_routed_experts
        if hasattr(curr_layer.mlp, 'gate'):
            curr_layer.mlp.gate.top_k = num_routed_experts

    # Iterate over the layers and register a hook once a time.
    for i in range(1, num_layers):
        for name, module in deepseek_model.named_modules():
            if isinstance(module, DeepseekMLP) and 'shared_experts' in name:
                continue
            elif isinstance(module, DeepseekMLP) and 'experts' in name:
                layer_idx = int(name.split('.')[2])
                if layer_idx == i:
                    expert_idx = int(name.split('.')[-1])
                    handles.append(module.register_forward_hook(
                        functools.partial(record_expert_output, layer_idx=layer_idx, expert_idx=expert_idx)
                    ))
                elif layer_idx > i:
                    break
        with torch.no_grad():
            # Feed the topk=6 output of previous layer as input.
            output = deepseek_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        for h in handles:
            h.remove()
        # Start the next layer with an empty list instead of carrying (and
        # re-removing) every handle registered so far.
        handles = []

    # Revert to the original value.
    for i in range(1, num_layers):
        curr_layer = deepseek_model.model.layers[i]
        if hasattr(curr_layer.mlp, 'num_experts_per_tok'):
            curr_layer.mlp.num_experts_per_tok = deepseek_model.config.num_experts_per_tok
        if hasattr(curr_layer.mlp, 'gate'):
            curr_layer.mlp.gate.top_k = deepseek_model.config.num_experts_per_tok

    if not use_short_input:
        # Count the norm-score ranking pairs.
        for i in range(1, num_layers):
            layer_scores = all_gate_scores[i][0]
            for j in range(num_tokens):
                curr_token_output = np.array(
                    [all_expert_output[i][k][j] for k in range(num_routed_experts)], dtype=np.float64
                )
                curr_gate_score = layer_scores[j, :]
                # ranks[e] = position of expert e in ascending order. Build new
                # arrays: curr_gate_score is a view into all_gate_scores, so
                # writing ranks into it would overwrite the recorded scores.
                norm_ranks = np.empty(num_routed_experts, dtype=np.int64)
                norm_ranks[np.argsort(curr_token_output)] = np.arange(num_routed_experts)
                score_ranks = np.empty(num_routed_experts, dtype=np.int64)
                score_ranks[np.argsort(curr_gate_score)] = np.arange(num_routed_experts)
                np.add.at(rankings_counts[i], (score_ranks, norm_ranks), 1)

In [None]:
# Save and plot.
if use_short_input:
    with open(os.path.join(output_dir, 'all_gate_scores'), 'wb') as f:
        pickle.dump(all_gate_scores, f)
    with open(os.path.join(output_dir, 'all_gate_indices'), 'wb') as f:
        pickle.dump(all_gate_indices, f)
    with open(os.path.join(output_dir, 'all_expert_output'), 'wb') as f:
        pickle.dump(all_expert_output, f)

    for i in range(1, num_layers):
        plot_one_layer_short_seq(all_gate_scores, all_gate_indices, all_expert_output, i, num_tokens)

else:
    with open(os.path.join(output_dir, 'rankings_counts'), 'wb') as f:
        pickle.dump(rankings_counts, f)
    # Plot layer one by one (layer 0 is not recorded).
    for layer in range(1, num_layers):
        plot_one_layer_long_seq(rankings_counts[layer], layer)
    # Plot all layers. np.sum allocates a new array; starting from
    # rankings_counts[1] and using += would overwrite layer 1's counts in place.
    total_rankings_counts = np.sum(rankings_counts[1:], axis=0)
    plot_one_layer_long_seq(total_rankings_counts, 'ALL')

### Grok

In [None]:
# Input.
use_short_input = True # Set False to use the long sequence.
if use_short_input:
    raw_input = "As an open source alternative to"
    sentence_lst = [raw_input]
else:
    sentence_lst = []
    with open('./wikitext103_test.csv') as csv_file:
        for row in csv.reader(csv_file, delimiter='\n'):
            for sent in row[0].split('\n'):
                sent = sent.strip()
                # Keep non-empty lines that do not start with '=' (presumably
                # wikitext section headings).
                if sent and not sent.startswith('='):
                    sentence_lst.append(sent)

num_layers = grok_model.config.num_hidden_layers
num_experts = grok_model.config.num_experts
tick_labels = list(map(str, range(num_experts)))
# Long-sequence runs store rank counts, so they get a separate '_count' directory.
save_dir = os.path.join(WORK_DIR, 'grok/grok_expert_norm') + ('' if use_short_input else '_count')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


def plot_one_layer_short_seq(all_gate_scores, all_gate_indices, all_expert_output, layer_idx, num_tokens):
    """Draw one subplot per token comparing Grok expert output norms with gate scores.

    Blue bars show the recorded output norm of each expert. Orange bars, on a
    twin y-axis, show the softmax gate scores. The title names the two
    recorded top experts. Saved to ``<plot_dir>/layer_<layer_idx>.png``.
    """
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(22.0, 2.8))
    x_pos = np.arange(num_experts) * 2
    for tok in range(num_tokens):
        ax = axs[tok]
        norms = [all_expert_output[layer_idx][e][tok] for e in range(num_experts)]
        ax.bar(x_pos - 0.35, norms, label='Norm', width=0.6)
        # Gate scores share the x-axis but use their own y-scale.
        twin_ax = ax.twinx()
        twin_ax.bar(x_pos, all_gate_scores[layer_idx][0][tok, :], tick_label=tick_labels,
                    color='darkorange', align='edge', label='Score', width=0.5)
        ax.set_xticks(x_pos, labels=tick_labels, fontsize=18)
        first, second = all_gate_indices[layer_idx][0][tok, :2]
        ax.set_title(f'exp {first},{second}', fontsize=18)
        if tok == 0:
            ax.set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=22)
        ax.legend(loc='upper left', fontsize=12)
        twin_ax.legend(loc='upper right', fontsize=12)
    plt.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close()


def plot_one_layer_long_seq(rankings_counts, layer_idx):
    """Grouped bar chart of gate-score-rank vs output-norm-rank counts for one Grok layer.

    Args:
        rankings_counts: (num_experts, num_experts) array; entry [r, c] counts
            gate-score rank r paired with output-norm rank c (0 = smallest).
        layer_idx: Layer index, or 'ALL' for the sum over layers; used in the file name.

    Each x group is one norm rank with one bar per score rank. Saved to
    ``<plot_dir>/layer<layer_idx>.png``.
    """
    fig, ax = plt.subplots(layout='constrained', figsize=(6.5, 4.0))
    # One group spans 0.8 of the unit spacing (0.1 per bar for 8 experts).
    bar_width = 0.8 / num_experts
    x = np.arange(num_experts)
    for i in range(num_experts):
        ax.bar(x + bar_width * i, rankings_counts[i, :], bar_width)
    # Center each tick under its group; a fixed 3.5-bar offset is only right for 8 experts.
    group_center = (num_experts - 1) / 2 * bar_width
    ax.set_xticks(x + group_center, [str(i+1) for i in range(num_experts)], fontsize=13)
    ax.tick_params(axis='y', labelsize=11)
    ax.set_xlabel('Expert output norm ranking', fontsize=15)
    ax.set_ylabel('Count of gate score ranking', fontsize=15)
    plt.savefig(os.path.join(plot_dir, f'layer{layer_idx}.png'))
    plt.close()


In [None]:
# Forward pass.

def record_layer_output(module, input, output, layer_idx):
    """Forward hook: keep the first element of a decoder layer's output.

    The element (shape noted as (num_tokens, hidden_dim)) is appended to
    ``all_layer_output[layer_idx]``.
    """
    first_output = output[0]
    all_layer_output[layer_idx].append(first_output)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on Grok's router linear layer: store top-2 indices and softmax scores.

    The logits are softmaxed along dim 1 in float32. The indices of the two
    largest probabilities go to ``all_gate_indices[layer_idx]`` and the full
    probability matrix goes to ``all_gate_scores[layer_idx]``, both as numpy arrays.
    """
    probs = F.softmax(output, dim=1, dtype=torch.float)
    chosen = torch.topk(probs, 2, dim=-1).indices
    all_gate_indices[layer_idx].append(chosen.cpu().detach().numpy())
    all_gate_scores[layer_idx].append(probs.cpu().detach().numpy())


def record_expert_output(module, input, output, layer_idx, expert_idx):
    """Forward hook on a Grok expert: store the per-row L2 norm of its output as bfloat16 numpy."""
    # output: (num_tokens, hidden_dim) -> norms: (num_tokens,)
    norms = output.norm(dim=1)
    all_expert_output[layer_idx][expert_idx] = norms.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)


# rankings_counts[layer][score_rank, norm_rank] counts how often an expert's gate
# score and output norm had those ranks for a token (rank 0 = smallest).
# Filled for the long sequence only, mirroring the Mixtral and DeepSeek cells.
token_count = 0
rankings_counts = [np.zeros((num_experts, num_experts)) for _ in range(num_layers)]
for s, sent in enumerate(sentence_lst):
    if s == 10:
        # Use the first ten sentences only.
        break
    enc_input = grok_tok.encode(sent, return_tensors='pt').cuda()
    attention_mask = torch.ones_like(enc_input)
    inputs = {
        "input_ids": grok_tok(sent, return_tensors='pt').input_ids.cuda(),
        "attention_mask": attention_mask,
        "max_new_tokens": 1,
    }
    num_tokens = enc_input.shape[1]
    token_count += num_tokens
    print(s, num_tokens, token_count)
    all_layer_output = [[] for _ in range(num_layers)]
    all_expert_output = [{} for _ in range(num_layers)]
    all_gate_scores = [[] for _ in range(num_layers)]
    all_gate_indices = [[] for _ in range(num_layers)]
    handles = []

    # Obtain the original output feature vectors of experts
    # and gate choices when topk=2.
    for name, module in grok_model.named_modules():
        if isinstance(module, DecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, torch.nn.Linear) and 'gate' in name:
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))

    output = grok_model.generate(**inputs)
    for h in handles:
        h.remove()
    handles = []

    # Modify the number of chosen experts to ALL.
    for i in range(num_layers):
        grok_model.model.layers[i].moe_block.top_k = num_experts
    # Iterate over the layers and register a hook once a time.
    for i in range(num_layers):
        for name, module in grok_model.named_modules():
            if isinstance(module, MoeMLP):
                layer_idx = int(name.split('.')[2])
                if layer_idx == i:
                    expert_idx = int(name.split('.')[-1])
                    handles.append(module.register_forward_hook(
                        functools.partial(record_expert_output, layer_idx=layer_idx, expert_idx=expert_idx)
                    ))
                elif layer_idx > i:
                    break
        if i == 0:
            with torch.no_grad():
                output = grok_model(enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
        else:
            with torch.no_grad():
                # Feed the topk=2 output of previous layer as input.
                output = grok_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        for h in handles:
            h.remove()
        handles = []
    # Revert to the original value.
    for i in range(num_layers):
        grok_model.model.layers[i].moe_block.top_k = grok_model.config.num_experts_per_tok

    if not use_short_input:
        # Count the norm-score ranking pairs; without this, rankings_counts stays
        # all zeros and the long-sequence plots are empty.
        # Assumes all_gate_scores[i][0] is (num_tokens, num_experts), as in the
        # Mixtral cell. TODO confirm for Grok's router output.
        for i in range(num_layers):
            layer_scores = all_gate_scores[i][0]
            for j in range(num_tokens):
                curr_token_output = np.array(
                    [all_expert_output[i][k][j] for k in range(num_experts)], dtype=np.float64
                )
                curr_gate_score = layer_scores[j, :]
                # ranks[e] = position of expert e in ascending order.
                norm_ranks = np.empty(num_experts, dtype=np.int64)
                norm_ranks[np.argsort(curr_token_output)] = np.arange(num_experts)
                score_ranks = np.empty(num_experts, dtype=np.int64)
                score_ranks[np.argsort(curr_gate_score)] = np.arange(num_experts)
                np.add.at(rankings_counts[i], (score_ranks, norm_ranks), 1)


In [None]:
# Save and plot.
if use_short_input:
    with open(os.path.join(output_dir, 'all_gate_scores'), 'wb') as f:
        pickle.dump(all_gate_scores, f)
    with open(os.path.join(output_dir, 'all_gate_indices'), 'wb') as f:
        pickle.dump(all_gate_indices, f)
    with open(os.path.join(output_dir, 'all_expert_output'), 'wb') as f:
        pickle.dump(all_expert_output, f)

    for i in range(num_layers):
        plot_one_layer_short_seq(all_gate_scores, all_gate_indices, all_expert_output, i, num_tokens)

else:
    with open(os.path.join(output_dir, 'rankings_counts'), 'wb') as f:
        pickle.dump(rankings_counts, f)
    # Plot layer one by one.
    for layer in range(num_layers):
        plot_one_layer_long_seq(rankings_counts[layer], layer)
    # Plot all layers. np.sum allocates a new array; starting from
    # rankings_counts[0] and using += would overwrite layer 0's counts in place.
    total_rankings_counts = np.sum(rankings_counts, axis=0)
    plot_one_layer_long_seq(total_rankings_counts, 'ALL')

## Intermediate States of Experts

Only the short sequence is used in this section.

### Mixtral and Mistral

In [None]:
# Input.
raw_input = "As an open source alternative to"
# Both tokenizers are expected to give the same ids; mix_enc_input is used below.
mix_enc_input = mixtral_tok.encode(raw_input, return_tensors='pt')
mis_enc_input = mistral_tok.encode(raw_input, return_tensors='pt')

num_layers, num_experts, intermediate_size = (
    mixtral_model.config.num_hidden_layers,
    mixtral_model.config.num_experts,
    mixtral_model.config.intermediate_size,
)
num_tokens = mix_enc_input.shape[1]
# Per-layer recording buffers filled by the forward hooks below.
all_layer_output, all_expert_act, all_gate_indices = (
    [[] for _ in range(num_layers)],
    [{} for _ in range(num_layers)],
    [[] for _ in range(num_layers)],
)
handles = []


def record_layer_output(module, input, output, layer_idx):
    """Forward hook: keep the first element of a decoder layer's output.

    The kept value is appended to the module-level list
    ``all_layer_output[layer_idx]``.
    """
    first_item = output[0]
    all_layer_output[layer_idx].append(first_item)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on a Mixtral router: record the top-2 expert indices per token.

    ``output`` is the router's score tensor. The indices of its two largest
    entries along the last dimension (sorted by descending score) are appended
    as an integer NumPy array to ``all_gate_indices[layer_idx]``.
    """
    scores = output
    _, expert_indices = torch.topk(scores, 2, dim=-1, sorted=True)
    # Keep indices as integers, as the DeepSeek and Grok hooks do. Casting them
    # to bfloat16 is lossy for large indices and makes them print and compare
    # as floats.
    all_gate_indices[layer_idx].append(expert_indices.cpu().detach().numpy())


def record_expert_act(module, input, output, layer_idx, expert_idx):
    """Forward hook on a Mixtral expert: store SiLU(w1(x)) for its input x.

    The activation is recomputed from the hook's input (``input[0]``); the
    hook's ``output`` is ignored. The result is saved as a bfloat16 NumPy array
    in ``all_expert_act[layer_idx][expert_idx]``.
    """
    # Original author's note: shape is [num_tokens, intermediate_size].
    gated = F.silu(module.w1(input[0]))
    gated_cpu = gated.float().cpu().detach()
    all_expert_act[layer_idx][expert_idx] = gated_cpu.numpy().astype(ml_dtypes.bfloat16)


def record_ffn_output(module, input, output, layer_idx):
    """Forward hook on a dense Mistral MLP: store SiLU(gate_proj(x)).

    The result goes under key -1 of ``all_expert_act[layer_idx]``, next to the
    Mixtral experts' entries, as a bfloat16 NumPy array.
    """
    # Original author's note: shape is [1, num_tokens, intermediate_size].
    ffn_act = F.silu(module.gate_proj(input[0]))
    ffn_np = ffn_act.float().cpu().detach().numpy()
    all_expert_act[layer_idx][-1] = ffn_np.astype(ml_dtypes.bfloat16)


# Obtain the original output feature vectors of experts
# and gate choices when topk=2.
try:
    for name, module in mixtral_model.named_modules():
        if isinstance(module, MistralDecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, torch.nn.Linear) and 'gate' in name:
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))
    with torch.no_grad():
        mix_output = mixtral_model(mix_enc_input)
finally:
    # Always detach hooks, even if the forward pass fails, so they do not keep
    # firing (and appending) during later cells.
    for h in handles:
        h.remove()
    handles = []

# Modify the number of chosen experts to ALL.
for i in range(num_layers):
    mixtral_model.model.layers[i].mlp.num_experts_per_token = num_experts
try:
    # Iterate over the layers and register hooks on one layer at a time.
    for i in range(num_layers):
        try:
            for name, module in mixtral_model.named_modules():
                if isinstance(module, FeedForward):
                    layer_idx = int(name.split('.')[2])
                    if layer_idx == i:
                        expert_idx = int(name.split('.')[-1])
                        handles.append(module.register_forward_hook(
                            functools.partial(record_expert_act, layer_idx=layer_idx, expert_idx=expert_idx)
                        ))
                    elif layer_idx > i:
                        break
            with torch.no_grad():
                if i == 0:
                    mix_output = mixtral_model(mix_enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
                else:
                    # Feed the topk=2 output of previous layer as input.
                    mix_output = mixtral_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        finally:
            for h in handles:
                h.remove()
            handles = []
finally:
    # Revert to the original value even if a forward pass above failed.
    # Otherwise the model would keep routing every token to all experts in
    # later cells.
    for i in range(num_layers):
        mixtral_model.model.layers[i].mlp.num_experts_per_token = mixtral_model.config.num_experts_per_token

# Obtain Mistral FFNs' output.
try:
    for name, module in mistral_model.named_modules():
        if isinstance(module, MistralMLP):
            layer_idx = int(name.split('.')[1])
            handles.append(module.register_forward_hook(
                functools.partial(record_ffn_output, layer_idx=layer_idx)
            ))
    with torch.no_grad():
        mis_output = mistral_model(mis_enc_input)
finally:
    for h in handles:
        h.remove()
    # Reset like the other stages, so no stale (already removed) handles linger.
    handles = []

# Global minimum over all recorded activations (experts and Mistral FFN). It is
# the lower colour limit of the 'full_colorbar' figures.
global_vmin = math.inf
for i in range(num_layers):
    for act in all_expert_act[i].values():
        curr_vmin = np.min(act)
        if curr_vmin < global_vmin:
            global_vmin = curr_vmin

In [None]:
# Save and plot.
# x ticks every 4000 intermediate neurons; one y tick per expert, plus a final
# row 'F' for the dense Mistral FFN.
xtick_labels = list(map(str, range(0, intermediate_size, 4000)))
ytick_labels = [*map(str, range(num_experts)), 'F']
save_dir = os.path.join(WORK_DIR, 'mixtral/mixtral_experts_inter')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (os.path.join(plot_dir, 'auto_colorbar'),
                os.path.join(plot_dir, 'full_colorbar'),
                output_dir):
    os.makedirs(sub_dir, exist_ok=True)

output_dict = {'global_vmin': global_vmin}
for file_name, payload in (('all_expert_act', all_expert_act),
                           ('all_gate_indices', all_gate_indices),
                           ('output_dict', output_dict)):
    with open(os.path.join(output_dir, file_name), 'wb') as f:
        pickle.dump(payload, f)


def plot_one_layer(all_expert_act, all_gate_indices, layer_idx, range_type, global_vmin=None):
    """Plot one layer's intermediate activations of every Mixtral expert and the
    dense Mistral FFN, one subplot row per input token.

    Args:
        all_expert_act: per-layer dicts mapping expert index to an activation
            array indexed [token, neuron]; key -1 holds the Mistral FFN,
            indexed [0, token, neuron].
        all_gate_indices: per-layer lists whose first entry holds the top-2
            expert indices per token; used in the subplot titles.
        layer_idx: layer to plot.
        range_type: 'auto_colorbar' (one colour range shared across this
            figure's rows) or 'full_colorbar' (range [global_vmin, 1.0]).
        global_vmin: lower colour limit for 'full_colorbar'.

    Reads the module-level num_tokens, num_experts, intermediate_size,
    xtick_labels, ytick_labels and plot_dir, and saves the figure to
    plot_dir/range_type/layer_{layer_idx}.png.
    """
    # One row per token. The loop below indexes one Axes per token, so the
    # previous fixed nrows=3 raised IndexError for prompts longer than 3 tokens.
    # The height scales with the row count (7.0 for 3 rows, as before).
    fig, axs = plt.subplots(nrows=num_tokens, layout='constrained',
                            figsize=(16.0, 2.0 * num_tokens + 1.0))
    # plt.subplots returns a bare Axes instead of an array when nrows == 1.
    axs = np.atleast_1d(axs)
    imlst = []
    for i in range(num_tokens):
        # Rows 0..num_experts-1 are Mixtral experts; the last row is the Mistral FFN.
        curr_map = np.empty((num_experts+1, intermediate_size))
        for j in range(num_experts):
            curr_map[j] = all_expert_act[layer_idx][j][i, :]
        curr_map[-1] = all_expert_act[layer_idx][-1][0, i, :]
        if range_type == 'auto_colorbar':
            im = axs[i].imshow(curr_map, aspect='auto')
            imlst.append(im)
        elif range_type == 'full_colorbar':
            im = axs[i].imshow(curr_map, aspect='auto', vmin=global_vmin, vmax=1.0)
        axs[i].set_xticks(np.arange(0, intermediate_size, 4000), labels=xtick_labels, fontsize=13)
        axs[i].set_yticks(np.arange(num_experts+1), labels=ytick_labels, fontsize=13)
        # Minor ticks at half-integers draw separator lines between the rows.
        axs[i].set_yticks(np.arange(-.5, num_experts+1, 1), minor=True)
        axs[i].tick_params(axis='y', which='minor', length=0)
        axs[i].grid(axis='y', which='minor', color='k', linestyle='-', linewidth=.2)
        exp1, exp2 = all_gate_indices[layer_idx][0][i, 0], all_gate_indices[layer_idx][0][i, 1]
        axs[i].set_title(f'expert {exp1},{exp2}', fontsize=16)
    if range_type == 'auto_colorbar':
        # Share one colour scale across all token rows of this figure.
        local_vmin = min(img.get_array().min() for img in imlst)
        local_vmax = max(img.get_array().max() for img in imlst)
        norm = colors.Normalize(vmin=local_vmin, vmax=local_vmax)
        for img in imlst:
            img.set_norm(norm)
    fig.suptitle(f'Layer {layer_idx}', fontsize=22)
    cbar = fig.colorbar(im, ax=axs, shrink=1.)
    cbar.ax.tick_params(labelsize=15)
    plt.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close()


# Render each layer twice: with a per-figure colour range, then with the
# shared global lower limit.
for i in range(num_layers):
    for range_type, vmin in (('auto_colorbar', None), ('full_colorbar', global_vmin)):
        plot_one_layer(all_expert_act, all_gate_indices, i, range_type, vmin)


### DeepSeek

In [None]:
# Input: the same short prompt, tokenized for DeepSeek and moved to the GPU.
raw_input = "As an open source alternative to"
enc_input = deepseek_tok.encode(raw_input, return_tensors='pt').cuda()

# NOTE(review): `cos` is not referenced anywhere in the visible code.
cos = torch.nn.CosineSimilarity(dim=0)
num_layers = deepseek_model.config.num_hidden_layers
num_routed_experts = deepseek_model.config.n_routed_experts
# Routed experts use the MoE-specific intermediate width.
intermediate_size = deepseek_model.config.moe_intermediate_size
num_tokens = enc_input.shape[1]

# Per-layer buffers that the hooks below fill in.
all_layer_output, all_expert_act, all_gate_indices = (
    [list() for _ in range(num_layers)],
    [dict() for _ in range(num_layers)],
    [list() for _ in range(num_layers)],
)
handles = []


def record_layer_output(module, input, output, layer_idx):
    """Forward hook: append element 0 of the layer's output to
    ``all_layer_output[layer_idx]``."""
    layer_list = all_layer_output[layer_idx]
    layer_list.append(output[0])


def record_expert_act(module, input, output, layer_idx, expert_idx):
    """Forward hook on a DeepSeek routed expert: store SiLU(gate_proj(x)).

    The activation is recomputed from the hook input (``input[0]``) and saved
    as a bfloat16 NumPy array in ``all_expert_act[layer_idx][expert_idx]``.
    """
    # act_neurons size = [num_tokens, intermediate_size], where
    # intermediate_size is config.moe_intermediate_size (the width of
    # gate_proj's output), not hidden_dim. The plotting code reads rows of
    # exactly that length.
    act_neurons = F.silu(module.gate_proj(input[0]))
    all_expert_act[layer_idx][expert_idx] = act_neurons.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on a DeepSeek MoEGate: record the chosen expert indices.

    The gate output is unpacked as a 3-tuple; only its first element (the
    expert indices) is kept, as a NumPy array, in ``all_gate_indices[layer_idx]``.
    """
    topk_idx, _topk_weight, _aux = output
    all_gate_indices[layer_idx].append(topk_idx.cpu().detach().numpy())


# Obtain the original output feature vectors of experts
# and gate choices when topk=6.
try:
    for name, module in deepseek_model.named_modules():
        if isinstance(module, DeepseekDecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, MoEGate):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))
    with torch.no_grad():
        output = deepseek_model(enc_input)
finally:
    # Always detach hooks, even if the forward pass fails.
    for h in handles:
        h.remove()
    handles = []

# Modify the number of chosen experts to all routed experts
# (layer 0 is skipped throughout this cell).
for i in range(1, num_layers):
    curr_layer = deepseek_model.model.layers[i]
    if hasattr(curr_layer.mlp, 'num_experts_per_tok'):
        curr_layer.mlp.num_experts_per_tok = num_routed_experts
    if hasattr(curr_layer.mlp, 'gate'):
        curr_layer.mlp.gate.top_k = num_routed_experts

try:
    # Iterate over the layers and register hooks on one layer at a time.
    for i in range(1, num_layers):
        try:
            for name, module in deepseek_model.named_modules():
                if isinstance(module, DeepseekMLP) and '.experts' in name:
                    layer_idx = int(name.split('.')[2])
                    if layer_idx == i:
                        expert_idx = int(name.split('.')[-1])
                        handles.append(module.register_forward_hook(
                            functools.partial(record_expert_act, layer_idx=layer_idx, expert_idx=expert_idx)
                        ))
                    elif layer_idx > i:
                        break
            with torch.no_grad():
                # Feed the topk=6 output of previous layer as input.
                output = deepseek_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        finally:
            for h in handles:
                h.remove()
            # Start each layer with an empty list. Previously the list kept
            # growing and already-removed handles were removed again on every
            # iteration.
            handles = []
finally:
    # Revert to the original value even if a forward pass above failed.
    for i in range(1, num_layers):
        curr_layer = deepseek_model.model.layers[i]
        if hasattr(curr_layer.mlp, 'num_experts_per_tok'):
            curr_layer.mlp.num_experts_per_tok = deepseek_model.config.num_experts_per_tok
        if hasattr(curr_layer.mlp, 'gate'):
            curr_layer.mlp.gate.top_k = deepseek_model.config.num_experts_per_tok

# Global minimum over all recorded expert activations; lower colour limit of
# the 'full_colorbar' figures.
global_vmin = math.inf
for i in range(num_layers):
    for act in all_expert_act[i].values():
        curr_vmin = np.min(act)
        if curr_vmin < global_vmin:
            global_vmin = curr_vmin


In [None]:
# Save and plot.
# x ticks every 400 intermediate neurons; y ticks on every 4th routed expert.
xtick_labels = [str(pos) for pos in range(0, intermediate_size, 400)]
ytick_pos = list(range(0, num_routed_experts, 4))
ytick_labels = [str(pos) for pos in ytick_pos]
save_dir = os.path.join(WORK_DIR, 'deepseek/deepseek_experts_inter')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (os.path.join(plot_dir, 'auto_colorbar'),
                os.path.join(plot_dir, 'full_colorbar'),
                output_dir):
    os.makedirs(sub_dir, exist_ok=True)

output_dict = {'global_vmin': global_vmin}
for file_name, payload in (('all_expert_act', all_expert_act),
                           ('all_gate_indices', all_gate_indices),
                           ('output_dict', output_dict)):
    with open(os.path.join(output_dir, file_name), 'wb') as f:
        pickle.dump(payload, f)


def plot_one_layer(all_expert_act, all_gate_indices, layer_idx, range_type, global_vmin=None):
    """Plot one layer's routed-expert activations, one column of subplots per token.

    Each subplot is a heat map (rows = routed experts, columns = intermediate
    neurons), titled with the experts the gate actually chose for that token.
    'auto_colorbar' shares one colour range across this figure;
    'full_colorbar' uses [global_vmin, 1.0]. Reads the module-level
    num_tokens, num_routed_experts, intermediate_size, tick lists, plot_dir
    and deepseek_model, and saves plot_dir/range_type/layer_{layer_idx}.png.
    """
    fig, axs = plt.subplots(ncols=num_tokens, layout='constrained', figsize=(32.0, 5.0))
    top_k = deepseek_model.config.num_experts_per_tok
    auto_images = []
    for tok in range(num_tokens):
        ax = axs[tok]
        heat = np.empty((num_routed_experts, intermediate_size))
        for expert in range(num_routed_experts):
            heat[expert] = all_expert_act[layer_idx][expert][tok, :]
        picked = ', '.join(str(all_gate_indices[layer_idx][0][tok, k]) for k in range(top_k))
        if range_type == 'auto_colorbar':
            im = ax.imshow(heat, aspect='auto')
            auto_images.append(im)
        elif range_type == 'full_colorbar':
            im = ax.imshow(heat, aspect='auto', vmin=global_vmin, vmax=1.0)
        ax.set_xticks(np.arange(0, intermediate_size, 400), labels=xtick_labels, fontsize=15)
        ax.set_yticks(ytick_pos, labels=ytick_labels, fontsize=15)
        # Half-integer minor ticks carry the thin separator lines between experts.
        ax.set_yticks(np.arange(-.5, num_routed_experts, 1), minor=True)
        if tok == 0:
            ax.set_ylabel(f'Layer {layer_idx}', labelpad=16., fontsize=30)
        ax.tick_params(axis='y', which='minor', length=0)
        ax.grid(axis='y', which='minor', color='k', linestyle='-', linewidth=.2)
        ax.set_title(f'exp {picked}', fontsize=18)
    if range_type == 'auto_colorbar':
        shared_norm = colors.Normalize(
            vmin=min(img.get_array().min() for img in auto_images),
            vmax=max(img.get_array().max() for img in auto_images),
        )
        for img in auto_images:
            img.set_norm(shared_norm)
    cbar = fig.colorbar(im, ax=axs, shrink=1.)
    cbar.ax.tick_params(labelsize=15)
    plt.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close()


# Layer 0 is skipped; every other layer is rendered with both colour modes.
for i in range(1, num_layers):
    for range_type, vmin in (('auto_colorbar', None), ('full_colorbar', global_vmin)):
        plot_one_layer(all_expert_act, all_gate_indices, i, range_type, vmin)


### Grok

In [None]:
# Input.
# Tokenize once and reuse the same tensor for generate(), the attention mask
# and the layer-0 forward pass below. Previously input_ids were produced by a
# second tokenizer call, so nothing guaranteed they matched enc_input.
raw_input = "As an open source alternative to"
enc_input = grok_tok.encode(raw_input, return_tensors='pt').cuda()
attention_mask = torch.ones_like(enc_input)
inputs = {
    "input_ids": enc_input,
    "attention_mask": attention_mask,
    "max_new_tokens": 1,
}

num_layers = grok_model.config.num_hidden_layers
num_experts = grok_model.config.num_experts
intermediate_size = grok_model.config.intermediate_size
num_tokens = enc_input.shape[1]
# Per-layer buffers that the hooks below fill in.
all_layer_output = [[] for _ in range(num_layers)]
all_expert_act = [{} for _ in range(num_layers)]
all_gate_indices = [[] for _ in range(num_layers)]
handles = []


def record_layer_output(module, input, output, layer_idx):
    """Forward hook: store ``output[0]`` of a decoder layer in
    ``all_layer_output[layer_idx]``."""
    head, *_ = output[:1]
    all_layer_output[layer_idx].append(head)


def record_gate_output(module, input, output, layer_idx):
    """Forward hook on a Grok router: record the top-2 expert indices per token.

    ``output`` holds the router logits. They are softmax-normalised over the
    last (expert) axis, and the indices of the two largest probabilities are
    appended as a NumPy array to ``all_gate_indices[layer_idx]``.
    """
    router_logits = output
    # Normalise over the same (last) axis that topk selects over. The previous
    # dim=1 only coincided with it for 2-D logits.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    # shape: (num_tokens, topk)
    _, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    all_gate_indices[layer_idx].append(selected_experts.cpu().detach().numpy())


def record_expert_act(module, input, output, layer_idx, expert_idx):
    """Forward hook on a Grok expert (MoeMLP): store GELU(linear(x)).

    Recomputed from ``input[0]`` and saved as a bfloat16 NumPy array in
    ``all_expert_act[layer_idx][expert_idx]``.
    """
    # Original author's note: shape is (num_tokens, intermediate_size).
    activated = F.gelu(module.linear(input[0])).float().cpu().detach()
    all_expert_act[layer_idx][expert_idx] = activated.numpy().astype(ml_dtypes.bfloat16)


# Obtain the original output feature vectors of experts
# and gate choices when topk=2.
try:
    for name, module in grok_model.named_modules():
        if isinstance(module, DecoderLayer):
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_layer_output, layer_idx=layer_idx)
            ))
        elif isinstance(module, torch.nn.Linear) and 'gate' in name:
            layer_idx = int(name.split('.')[2])
            handles.append(module.register_forward_hook(
                functools.partial(record_gate_output, layer_idx=layer_idx)
            ))
    with torch.no_grad():
        output = grok_model.generate(**inputs)
finally:
    # Always detach hooks, even if generation fails.
    for h in handles:
        h.remove()
    handles = []

# Modify the number of chosen experts to ALL.
for i in range(num_layers):
    grok_model.model.layers[i].moe_block.top_k = num_experts
try:
    # Iterate over the layers and register hooks on one layer at a time.
    for i in range(num_layers):
        try:
            for name, module in grok_model.named_modules():
                if isinstance(module, MoeMLP):
                    layer_idx = int(name.split('.')[2])
                    if layer_idx == i:
                        expert_idx = int(name.split('.')[-1])
                        handles.append(module.register_forward_hook(
                            functools.partial(record_expert_act, layer_idx=layer_idx, expert_idx=expert_idx)
                        ))
                    elif layer_idx > i:
                        break
            with torch.no_grad():
                if i == 0:
                    output = grok_model(enc_input, decoder_layer_idx=i, use_cache=False) # Set use_cache=False to prevent error.
                else:
                    # Feed the topk=2 output of previous layer as input.
                    output = grok_model(inputs_embeds=all_layer_output[i-1][0], decoder_layer_idx=i, use_cache=False)
        finally:
            for h in handles:
                h.remove()
            handles = []
finally:
    # Revert to the original value even if a forward pass above failed.
    for i in range(num_layers):
        grok_model.model.layers[i].moe_block.top_k = grok_model.config.num_experts_per_tok

# Global minimum over all recorded expert activations; lower colour limit of
# the 'full_colorbar' figures.
global_vmin = math.inf
for i in range(num_layers):
    for act in all_expert_act[i].values():
        curr_vmin = np.min(act)
        if curr_vmin < global_vmin:
            global_vmin = curr_vmin


In [None]:
# Save and plot.
# x ticks every 10000 intermediate neurons; one y tick per expert.
xtick_labels = list(map(str, range(0, intermediate_size, 10000)))
ytick_labels = list(map(str, range(num_experts)))
save_dir = os.path.join(WORK_DIR, 'grok/grok_experts_inter')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (os.path.join(plot_dir, 'auto_colorbar'),
                os.path.join(plot_dir, 'full_colorbar'),
                output_dir):
    os.makedirs(sub_dir, exist_ok=True)

output_dict = {'global_vmin': global_vmin}
for file_name, payload in (('all_expert_act', all_expert_act),
                           ('all_gate_indices', all_gate_indices),
                           ('output_dict', output_dict)):
    with open(os.path.join(output_dir, file_name), 'wb') as f:
        pickle.dump(payload, f)


def plot_one_layer(all_expert_act, all_gate_indices, layer_idx, range_type, global_vmin=None):
    """Plot one layer's Grok expert activations, one subplot row per token.

    Rows of each heat map are experts, columns are intermediate neurons; each
    subplot is titled with the top-2 experts chosen for that token.
    'auto_colorbar' shares one colour range across this figure;
    'full_colorbar' uses [global_vmin, 1.0]. Reads the module-level
    num_tokens, num_experts, intermediate_size, tick labels and plot_dir, and
    saves plot_dir/range_type/layer_{layer_idx}.png.
    """
    fig, axs = plt.subplots(nrows=num_tokens, layout='constrained', figsize=(16.0, 14.0))
    auto_images = []
    for tok in range(num_tokens):
        ax = axs[tok]
        heat = np.empty((num_experts, intermediate_size))
        for expert in range(num_experts):
            heat[expert] = all_expert_act[layer_idx][expert][tok, :]
        if range_type == 'auto_colorbar':
            im = ax.imshow(heat, aspect='auto')
            auto_images.append(im)
        elif range_type == 'full_colorbar':
            im = ax.imshow(heat, aspect='auto', vmin=global_vmin, vmax=1.0)
        ax.set_xticks(np.arange(0, intermediate_size, 10000), labels=xtick_labels, fontsize=12)
        ax.set_yticks(np.arange(num_experts), labels=ytick_labels, fontsize=14)
        # Half-integer minor ticks carry the separator lines between experts.
        ax.set_yticks(np.arange(-.5, num_experts, 1), minor=True)
        ax.tick_params(axis='y', which='minor', length=0)
        ax.grid(axis='y', which='minor', color='k', linestyle='-', linewidth=.2)
        chosen = all_gate_indices[layer_idx][0]
        first_exp, second_exp = chosen[tok, 0], chosen[tok, 1]
        ax.set_title(f'expert {first_exp},{second_exp}', fontsize=16)
    if range_type == 'auto_colorbar':
        shared_norm = colors.Normalize(
            vmin=min(img.get_array().min() for img in auto_images),
            vmax=max(img.get_array().max() for img in auto_images),
        )
        for img in auto_images:
            img.set_norm(shared_norm)
    fig.suptitle(f'Layer {layer_idx}', fontsize=22)
    cbar = fig.colorbar(im, ax=axs, shrink=1.)
    cbar.ax.tick_params(labelsize=12)
    plt.savefig(os.path.join(plot_dir, range_type, f'layer_{layer_idx}.png'))
    plt.close()


# Render each layer with both colour-range modes.
for i in range(num_layers):
    for range_type, vmin in (('auto_colorbar', None), ('full_colorbar', global_vmin)):
        plot_one_layer(all_expert_act, all_gate_indices, i, range_type, vmin)


## Chosen Experts

In this experiment, we utilize another input containing about 64 tokens. In addition to the base model of Mixtral (Mixtral-Base), we include its instruct version (Mixtral-Instruct).

### Mixtral-Base

In [None]:
# Longer prompt (about 64 tokens) for the gate-choice experiment.
raw_input = "As an open source alternative to Chat GPT, I do not have personal opinions. However, I can provide objective information about Chat GPT's capabilities and limitations based on its architecture and training data. Chat GPT is a powerful language model based on the GPT (Generative Pre-trained Transformer"
enc_input = mixtral_tok.encode(raw_input, return_tensors="pt").cuda()

num_layers = mixtral_model.config.num_hidden_layers
num_experts = mixtral_model.config.num_experts
num_tokens = enc_input.shape[1]
gate_outputs = [list() for _ in range(num_layers)]
handles = []
# One x tick per decoded token, one y tick per expert.
xtick_labels = list(map(mixtral_tok.decode, enc_input[0]))
ytick_labels = list(map(str, range(num_experts)))
save_dir = os.path.join(WORK_DIR, 'mixtral/mixtral_gate_choice')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (plot_dir, output_dir):
    os.makedirs(sub_dir, exist_ok=True)


def plot_one_layer(all_expert_weights, layer_idx):
    """Save a heat map of one layer's routing weights.

    ``all_expert_weights`` is indexed [expert, token]. Colours span the fixed
    range [0, 1]. Reads the module-level num_tokens, num_experts,
    xtick_labels, ytick_labels and plot_dir.
    """
    fig, ax = plt.subplots(layout='constrained', figsize=(14.0, 3.))
    ax.imshow(all_expert_weights, cmap=mlp.colormaps['Blues'], vmin=0., vmax=1.)
    ax.set_xticks(np.arange(num_tokens), labels=xtick_labels, rotation='vertical', fontsize=13)
    ax.set_yticks(np.arange(num_experts), labels=ytick_labels, fontsize=13)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=20)
    # Grow the axes' height by 0.3 (figure coordinates).
    left, bottom, width, height = ax.get_position().bounds
    ax.set_position([left, bottom, width, height + 0.3])
    fig.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close(fig)


def record_output(module, input, output, layer_idx):
    """Forward hook on a Mixtral router: record top-2 routing weights and indices.

    The two largest router scores per token are renormalised with a softmax
    over just those two. The pair (weights as bfloat16 NumPy, indices as NumPy)
    is appended to ``gate_outputs[layer_idx]``.
    """
    top_scores, top_idx = torch.topk(output, 2, dim=-1)
    top_weights = top_scores.softmax(dim=-1)
    weights_np = top_weights.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
    gate_outputs[layer_idx].append((weights_np, top_idx.cpu().detach().numpy()))
    
    
# Hook every router linear layer (a Linear whose name contains 'gate').
for name, module in mixtral_model.named_modules():
    if not (isinstance(module, torch.nn.Linear) and 'gate' in name):
        continue
    layer_idx = int(name.split(".")[2])
    hook = functools.partial(record_output, layer_idx=layer_idx)
    handles.append(module.register_forward_hook(hook))

with torch.no_grad():
    output = mixtral_model(enc_input)
for h in handles:
    h.remove()

with open(os.path.join(output_dir, 'gate_outputs'), 'wb') as f:
    pickle.dump(gate_outputs, f)

# Scatter the top-2 weights into a dense [token, expert] matrix (zeros
# elsewhere) and plot it transposed, with experts on the y axis.
token_rows = np.arange(0, num_tokens)
for i, gate_output in enumerate(gate_outputs):
    expert_weights, expert_indices = gate_output[0]
    all_expert_weights = np.zeros((num_tokens, num_experts))
    all_expert_weights[token_rows, expert_indices.T] = expert_weights.T
    plot_one_layer(all_expert_weights.T, i)


### Mixtral-Instruct

In [None]:
# Same ~64-token prompt as the Mixtral-Base experiment.
raw_input = "As an open source alternative to Chat GPT, I do not have personal opinions. However, I can provide objective information about Chat GPT's capabilities and limitations based on its architecture and training data. Chat GPT is a powerful language model based on the GPT (Generative Pre-trained Transformer"
enc_input = mixtral_instruct_tok.encode(raw_input, return_tensors="pt")

num_layers = mixtral_instruct_model.config.num_hidden_layers
num_experts = mixtral_instruct_model.config.num_local_experts
num_tokens = enc_input.shape[1]
xtick_labels = list(map(mixtral_instruct_tok.decode, enc_input[0]))
ytick_labels = list(map(str, range(num_experts)))
gate_outputs = [list() for _ in range(num_layers)]
handles = []
# NOTE(review): 'instuct' looks like a typo of 'instruct'; kept as-is so the
# existing output location does not change.
save_dir = os.path.join(WORK_DIR, 'mixtral/mixtral_instuct_gate_choice')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (plot_dir, output_dir):
    os.makedirs(sub_dir, exist_ok=True)


def plot_one_layer(all_expert_weights, layer_idx):
    """Save a heat map of one layer's routing weights (Mixtral-Instruct).

    ``all_expert_weights`` is indexed [expert, token]. Reads the module-level
    num_tokens, num_experts, xtick_labels, ytick_labels and plot_dir.
    """
    fig, ax = plt.subplots(layout='constrained', figsize=(14.0, 3.))
    # Fix the colour range to [0, 1], the range of the softmaxed top-2 weights,
    # as the Mixtral-Base plot does. With autoscaling, each layer was
    # normalised to its own maximum, so the figures could not be compared
    # across layers or against Mixtral-Base.
    im = ax.imshow(all_expert_weights, cmap=mlp.colormaps['Blues'], vmin=0., vmax=1.)
    ax.set_xticks(np.arange(num_tokens), labels=xtick_labels, rotation='vertical', fontsize=13)
    ax.set_yticks(np.arange(num_experts), labels=ytick_labels, fontsize=13)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=20)
    # Grow the axes' height by 0.3 (figure coordinates).
    l, b, w, h = ax.get_position().bounds
    ax.set_position([l, b, w, h+0.3])
    plt.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close()


def record_output(module, input, output, layer_idx):
    """Forward hook on a Mixtral-Instruct router: record top-2 weights and indices.

    The top-2 router scores per token are softmaxed over just those two. The
    pair (weights as bfloat16 NumPy, indices as NumPy) is appended to
    ``gate_outputs[layer_idx]``.
    """
    best_scores, best_idx = torch.topk(output, 2, dim=-1)
    normed = best_scores.softmax(dim=-1)
    weights_np = normed.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
    gate_outputs[layer_idx].append((weights_np, best_idx.cpu().detach().numpy()))

    
# Hook every router linear layer (a Linear whose name contains 'gate').
for name, module in mixtral_instruct_model.named_modules():
    if not (isinstance(module, torch.nn.Linear) and 'gate' in name):
        continue
    layer_idx = int(name.split(".")[2])
    hook = functools.partial(record_output, layer_idx=layer_idx)
    handles.append(module.register_forward_hook(hook))

with torch.no_grad():
    output = mixtral_instruct_model(enc_input)
for h in handles:
    h.remove()

with open(os.path.join(output_dir, 'gate_outputs'), 'wb') as f:
    pickle.dump(gate_outputs, f)

# Dense [token, expert] matrix holding the top-2 weights, plotted transposed.
token_rows = np.arange(0, num_tokens)
for i, gate_output in enumerate(gate_outputs):
    expert_weights, expert_indices = gate_output[0]
    all_expert_weights = np.zeros((num_tokens, num_experts))
    all_expert_weights[token_rows, expert_indices.T] = expert_weights.T
    plot_one_layer(all_expert_weights.T, i)

### DeepSeek

In [None]:
# Same ~64-token prompt, tokenized for DeepSeek and moved to the GPU.
raw_input = "As an open source alternative to Chat GPT, I do not have personal opinions. However, I can provide objective information about Chat GPT's capabilities and limitations based on its architecture and training data. Chat GPT is a powerful language model based on the GPT (Generative Pre-trained Transformer"
enc_input = deepseek_tok.encode(raw_input, return_tensors="pt").cuda()

num_layers = deepseek_model.config.num_hidden_layers
num_routed_experts = deepseek_model.config.n_routed_experts
num_tokens = enc_input.shape[1]
xtick_labels = list(map(deepseek_tok.decode, enc_input[0]))
# Label every 4th routed expert.
ytick_labels = list(map(str, range(0, num_routed_experts, 4)))
gate_outputs = [list() for _ in range(num_layers)]
handles = []
save_dir = os.path.join(WORK_DIR, 'deepseek/deepseek_gate_choice')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (plot_dir, output_dir):
    os.makedirs(sub_dir, exist_ok=True)


def plot_one_layer(all_expert_weights, layer_idx):
    """Save a heat map of one DeepSeek layer's routing weights.

    ``all_expert_weights`` is indexed [expert, token]; the colour range is
    autoscaled. Reads the module-level num_tokens, num_routed_experts,
    xtick_labels, ytick_labels and plot_dir.
    """
    fig, ax = plt.subplots(layout='constrained', figsize=(10.0, 14.0))
    ax.imshow(all_expert_weights, cmap=mlp.colormaps['Blues'])
    ax.set_xticks(np.arange(num_tokens), labels=xtick_labels, rotation='vertical', fontsize=13.5)
    ax.set_yticks(np.arange(0, num_routed_experts, 4), labels=ytick_labels, fontsize=15)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=20)
    fig.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close(fig)


def record_output(module, input, output, layer_idx):
    """Forward hook on a DeepSeek MoEGate: record (weights, indices).

    The gate output is unpacked as (indices, weights, third item). Weights are
    stored as bfloat16 NumPy and indices as NumPy, appended as a pair to
    ``gate_outputs[layer_idx]``.
    """
    topk_idx, topk_weight, _ = output
    weights_np = topk_weight.float().cpu().detach().numpy().astype(ml_dtypes.bfloat16)
    gate_outputs[layer_idx].append((weights_np, topk_idx.cpu().detach().numpy()))

    
# Hook every MoEGate module.
for name, module in deepseek_model.named_modules():
    if not isinstance(module, MoEGate):
        continue
    layer_idx = int(name.split(".")[2])
    hook = functools.partial(record_output, layer_idx=layer_idx)
    handles.append(module.register_forward_hook(hook))

with torch.no_grad():
    output = deepseek_model(enc_input)
for h in handles:
    h.remove()

with open(os.path.join(output_dir, 'gate_outputs'), 'wb') as f:
    pickle.dump(gate_outputs, f)

# Layer 0 is skipped. For the rest, scatter the chosen experts' weights into a
# dense [token, expert] matrix and plot it transposed.
token_rows = np.arange(0, num_tokens)
for i in range(1, len(gate_outputs)):
    gate_output = gate_outputs[i]
    expert_weights, expert_indices = gate_output[0]
    all_expert_weights = np.zeros((num_tokens, num_routed_experts))
    all_expert_weights[token_rows, expert_indices.T] = expert_weights.T
    plot_one_layer(all_expert_weights.T, i)

### Grok

In [None]:
# Same ~64-token prompt, tokenized for Grok; generate() runs one step.
raw_input = "As an open source alternative to Chat GPT, I do not have personal opinions. However, I can provide objective information about Chat GPT's capabilities and limitations based on its architecture and training data. Chat GPT is a powerful language model based on the GPT (Generative Pre-trained Transformer"
enc_input = grok_tok(raw_input, return_tensors="pt").input_ids.cuda()
attention_mask = torch.ones_like(enc_input)
inputs = dict(
    input_ids=enc_input,
    attention_mask=attention_mask,
    max_new_tokens=1,
)

num_layers = grok_model.config.num_hidden_layers
num_experts = grok_model.config.num_experts
num_tokens = enc_input.shape[1]
xtick_labels = list(map(grok_tok.decode, enc_input[0]))
ytick_labels = list(map(str, range(num_experts)))
gate_outputs = [list() for _ in range(num_layers)]
handles = []
save_dir = os.path.join(WORK_DIR, 'grok/grok_gate_choice')
plot_dir = os.path.join(save_dir, 'figure')
output_dir = os.path.join(save_dir, 'data')
for sub_dir in (plot_dir, output_dir):
    os.makedirs(sub_dir, exist_ok=True)


def plot_one_layer(all_expert_weights, layer_idx):
    """Save a heat map of one Grok layer's routing weights.

    ``all_expert_weights`` is indexed [expert, token]; colours span [0, 1].
    Reads the module-level num_tokens, num_experts, xtick_labels,
    ytick_labels and plot_dir.
    """
    fig, ax = plt.subplots(layout='constrained', figsize=(14.0, 3.))
    ax.imshow(all_expert_weights, cmap=mlp.colormaps['Blues'], vmin=0., vmax=1.)
    ax.set_xticks(np.arange(num_tokens), labels=xtick_labels, rotation='vertical', fontsize=12)
    ax.set_yticks(np.arange(num_experts), labels=ytick_labels, fontsize=12)
    ax.set_ylabel(f'Layer {layer_idx}', labelpad=14., fontsize=20)
    # Grow the axes' height by 0.3 (figure coordinates).
    left, bottom, width, height = ax.get_position().bounds
    ax.set_position([left, bottom, width, height + 0.3])
    fig.savefig(os.path.join(plot_dir, f'layer_{layer_idx}.png'))
    plt.close(fig)


def record_output(module, input, output, layer_idx):
    """Forward hook on a Grok router: record the top-2 routing probabilities and indices.

    ``output`` holds the router logits. They are softmax-normalised (float32)
    over the last (expert) axis, and the pair (top-2 probabilities, their
    indices) is appended as NumPy arrays to ``gate_outputs[layer_idx]``.
    """
    router_logits = output
    # Normalise over the same (last) axis that topk selects over. The previous
    # dim=1 only coincided with it for 2-D logits.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    # shape: (num_tokens, topk)
    routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    gate_outputs[layer_idx].append((routing_weights.cpu().detach().numpy(), selected_experts.cpu().detach().numpy()))

    
# Hook every router linear layer (a Linear whose name contains 'gate').
for name, module in grok_model.named_modules():
    if not (isinstance(module, torch.nn.Linear) and 'gate' in name):
        continue
    layer_idx = int(name.split(".")[2])
    hook = functools.partial(record_output, layer_idx=layer_idx)
    handles.append(module.register_forward_hook(hook))

with torch.no_grad():
    output = grok_model.generate(**inputs)
for h in handles:
    h.remove()

with open(os.path.join(output_dir, 'gate_outputs'), 'wb') as f:
    pickle.dump(gate_outputs, f)

# Only the first recorded call per layer is plotted. Its top-2 weights are
# scattered into a dense [token, expert] matrix and shown transposed.
token_rows = np.arange(0, num_tokens)
for i, gate_output in enumerate(gate_outputs):
    expert_weights, expert_indices = gate_output[0]
    all_expert_weights = np.zeros((num_tokens, num_experts))
    all_expert_weights[token_rows, expert_indices.T] = expert_weights.T
    plot_one_layer(all_expert_weights.T, i)
