In [None]:
import functools
import glob
import itertools
import re
from dataclasses import dataclass
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# parse log
def parse_input_ids(data):
    """Extract every logged ``input_ids`` tensor from a raw log string.

    Only the first row of each printed ``tensor([[...`` is captured (the match
    stops at the first ``]``). Returns one list of ints per occurrence, in the
    order they appear in ``data``.
    """
    token_lists = []
    for captured in re.findall(r"input_ids: tensor\(\[\[([\s\d,]+)", data):
        pieces = captured.split(',')
        token_lists.append([int(piece.strip()) for piece in pieces])
    return token_lists

def parse_expert_ids(data):
    """Extract per-layer top-k expert selections from a raw log string.

    Each ``Layer <n> topk_ids: tensor([...], device`` match yields one record.
    Returns ``(layer_ids, expert_ids)`` where ``layer_ids[i]`` is the layer
    number of record ``i`` and ``expert_ids[i]`` is a list of rows, one per
    printed tensor row, each row a list of selected expert ids.
    """
    pattern = r"Layer (\d+) topk_ids: tensor\(\[([\[\]\d,\s]+)\], device"
    layer_ids, expert_ids = [], []
    for layer_text, rows_text in re.findall(pattern, data):
        layer_ids.append(int(layer_text))
        rows = []
        # torch prints one tensor row per line, separated by ",\n".
        for row_text in rows_text.split(',\n'):
            inner = row_text.strip()[1:-1]  # drop the surrounding [ ]
            rows.append([int(value) for value in inner.split(', ')])
        expert_ids.append(rows)
    return layer_ids, expert_ids

def parse_file(filename):
    """Read one routing-trace log and parse it into aligned record lists.

    Returns ``(layer_ids, input_ids, expert_ids)``; all three have one entry
    per logged record.

    Raises:
        AssertionError: if the three parsed sequences differ in length.
    """
    with open(filename, "r") as log_file:
        raw = log_file.read()

    input_ids = parse_input_ids(raw)
    layer_ids, expert_ids = parse_expert_ids(raw)

    n_records, n_experts, n_inputs = len(layer_ids), len(expert_ids), len(input_ids)
    assert n_records == n_experts, f"len(layer_ids)={n_records} and len(expert_ids)={n_experts}"
    assert n_records == n_inputs, f"len(layer_ids)={n_records} and len(input_ids)={n_inputs}"

    return layer_ids, input_ids, expert_ids


# Smoke-test the parser on one trace file before loading the whole directory.
filename = "MTBench_Mixtral/139916882249008.txt"
(layer_ids,
 input_ids,
 expert_ids) = parse_file(filename)

In [None]:
@dataclass(frozen=True)
class QueryTrace:
    """Expert-routing trace for a single query to the LLM.

    The parsed log holds one *record* per logged layer invocation:
    ``layer_ids[i]`` is the layer of record ``i``, ``input_ids[i]`` the token
    ids logged with it, and ``expert_ids[i][t]`` the experts selected for
    token ``t`` of that record.
    """
    layer_ids: Sequence[int]  # record -> layer id
    expert_ids: Sequence[Sequence[Sequence[int]]]  # record, token, selected experts
    token_expert_ids: Sequence[Sequence[Sequence[int]]]  # layer, token (records concatenated in order), selected experts
    input_ids: Sequence[Sequence[int]]  # record, token ids
    has_prefix: bool = True

    @staticmethod
    def _group_by_layer(layer_ids, input_ids, expert_ids):
        """Concatenate each layer's per-token expert selections across records.

        Returns a list indexed by layer id; entry ``l`` holds, in record order,
        the expert selection of every token logged at layer ``l``. Assumes layer
        ids are contiguous starting at 0 (a gap raises KeyError).
        """
        by_layer = {}
        for record_idx, layer in enumerate(layer_ids):
            tokens = by_layer.setdefault(layer, [])
            experts = expert_ids[record_idx]
            # One entry per input token; indexing (rather than zip) still raises
            # IndexError if a record has fewer expert rows than input tokens.
            tokens.extend(experts[t] for t in range(len(input_ids[record_idx])))
        return [by_layer[layer] for layer in range(len(by_layer))]

    @staticmethod
    def from_file(filename):
        """Parse ``filename`` with ``parse_file`` and build a QueryTrace.

        Raises AssertionError (from ``parse_file``) on inconsistent logs.
        """
        layer_ids, input_ids, expert_ids = parse_file(filename)
        return QueryTrace(layer_ids=layer_ids,
                          expert_ids=expert_ids,
                          token_expert_ids=QueryTrace._group_by_layer(layer_ids, input_ids, expert_ids),
                          input_ids=input_ids)

    @functools.cached_property
    def num_layers(self) -> int:
        """Number of distinct layer ids in this trace."""
        return len(np.unique(self.layer_ids))

    @functools.cached_property
    def num_experts(self) -> int:
        """Number of distinct expert ids observed in this trace (not necessarily the model's total)."""
        exp_ids_flat = list(itertools.chain.from_iterable(itertools.chain.from_iterable(self.expert_ids)))
        return len(np.unique(exp_ids_flat))

    @functools.cached_property
    def prompt_token_length(self) -> int:
        """Number of token ids in the first record."""
        return len(self.input_ids[0])

    @functools.cached_property
    def num_generated_tokens(self) -> int:
        # NOTE(review): input_ids has one entry per record (parse_file asserts
        # len(input_ids) == len(layer_ids)), so this counts records after the
        # first rather than generated tokens — confirm intent before relying on it.
        return len(self.input_ids[1:])

    def expert_counts_by_layer(self) -> np.ndarray:
        """Count how often each expert was selected at each layer.

        Returns an int32 array of shape ``(num_layers, num_experts)`` whose entry
        ``[l, e]`` is the number of selections of expert ``e`` at layer ``l``.
        Assumes layer ids lie in ``[0, num_layers)`` and expert ids in
        ``[0, num_experts)``.
        """
        counts = np.zeros((self.num_layers, self.num_experts), dtype=np.int32)
        for layer_id, exp_ids in zip(self.layer_ids, self.expert_ids):
            exp_ids_flat = list(itertools.chain.from_iterable(exp_ids))
            # minlength keeps a slot for experts never picked in this record,
            # without inflating the counts the way padding with range() did.
            counts[layer_id] += np.bincount(exp_ids_flat, minlength=self.num_experts)
        return counts

    def without_prefix(self):
        """Return a copy with the first ``num_layers`` records dropped.

        Presumably those records are the prompt (prefill) pass — confirm against
        the logging code. ``token_expert_ids`` is rebuilt from the remaining records.
        """
        layer_ids = self.layer_ids[self.num_layers:]
        expert_ids = self.expert_ids[self.num_layers:]
        input_ids = self.input_ids[self.num_layers:]
        return QueryTrace(layer_ids=layer_ids,
                          expert_ids=expert_ids,
                          token_expert_ids=QueryTrace._group_by_layer(layer_ids, input_ids, expert_ids),
                          input_ids=input_ids,
                          has_prefix=False)


# Load every trace; deterministic, pure-Python listing instead of `!ls`.
files = sorted(glob.glob("MTBench_Mixtral/*"))

traces = {}
skipped = []  # (filename, AssertionError) for logs whose parsed lengths disagree
for filename in files:
    try:
        traces[filename] = QueryTrace.from_file(filename)
    except AssertionError as err:
        # Best-effort: inconsistent logs are skipped, but recorded rather than silently dropped.
        skipped.append((filename, err))

print(f"Loaded {len(traces)} traces; skipped {len(skipped)} inconsistent file(s).")

In [None]:
# Aggregate expert selections over all traces and normalise each layer's row to frequencies.
expert_counts_by_layer = sum(trace.expert_counts_by_layer() for trace in traces.values())
expert_freq_by_layer = expert_counts_by_layer / expert_counts_by_layer.sum(axis=1)[:, np.newaxis]

# Label every 4th layer; derive the layer count from the data instead of assuming 32.
layer_ticks = np.arange(0, expert_freq_by_layer.shape[0], 4)

ax = sns.heatmap(expert_freq_by_layer, linewidths=.5, cmap='rocket_r',
                 cbar_kws={'label': 'Frequency'})
ax.invert_yaxis()
ax.set_yticks(0.5 + layer_ticks, layer_ticks)
ax.set_xlabel("Expert ID")
ax.set_ylabel("Layer")
ax.set_title("Mixtral on MTBench");

In [None]:
# Std of expert frequencies within each layer: higher values mean more uneven expert usage.
layer_freq_std = expert_freq_by_layer.std(axis=1)

plt.figure(figsize=(10, 4))
sns.barplot(x=np.arange(len(layer_freq_std)), y=layer_freq_std)
plt.xlabel("Layer Number")
plt.ylabel("Standard Deviation\nof Expert Frequencies")
plt.title("Per-layer imbalance of expert usage (Mixtral on MTBench)");

In [None]:
def p_next_expert_given_previous_experts(traces: dict, k=2):
    """Estimate P(expert at layer l+1 | expert combination at layer l).

    For every token, the sorted k-expert combination selected at layer ``l`` is
    paired with the experts that same record position selects in the following
    record, but only when that record is layer ``l + 1``.

    Args:
        traces: mapping name -> trace exposing ``num_layers``, ``num_experts``,
            ``layer_ids`` and ``expert_ids`` (as QueryTrace does). Array sizes
            are taken from the first trace.
        k: size of each expert combination (experts selected per token).

    Returns:
        ``(probabilities, combination_to_idx)``. ``probabilities`` has shape
        ``(num_layers, C(num_experts, k), num_experts)``; row ``[l, c]`` sums to
        1 over next-layer experts, or is all zeros if combination ``c`` never
        occurred at layer ``l``. ``combination_to_idx`` maps each sorted
        combination tuple to its row index.
    """
    first_trace = next(iter(traces.values()))
    combinations = list(itertools.combinations(range(first_trace.num_experts), k))
    combination_to_idx = {c: i for i, c in enumerate(combinations)}
    # (layer, current combination, next expert)
    counts = np.zeros((first_trace.num_layers, len(combinations), first_trace.num_experts), dtype=np.uint32)
    for trace in traces.values():
        records = zip(trace.layer_ids, trace.layer_ids[1:], trace.expert_ids, trace.expert_ids[1:])
        for layer_id, next_layer_id, cur_experts, next_experts in records:
            # Consecutive records describe the same tokens only for adjacent layers
            # of one pass; a wrap-around (last layer -> layer 0) starts a new pass.
            if next_layer_id != layer_id + 1:
                continue
            for per_token_cur_experts, per_token_next_experts in zip(cur_experts, next_experts):
                combination_idx = combination_to_idx[tuple(sorted(per_token_cur_experts))]
                counts[layer_id, combination_idx, per_token_next_experts] += 1

    totals = counts.sum(axis=-1, keepdims=True)
    # Unobserved combinations keep an all-zero row instead of 0/0 = NaN.
    probabilities = counts / np.maximum(totals, 1)
    return probabilities, combination_to_idx

probabilities, experts_to_idx = p_next_expert_given_previous_experts(traces)

In [None]:
def plot_probabilities(layer_id, probabilities, experts_to_idx):
    """Draw and show the heatmap of ``probabilities[layer_id]``.

    Rows are labelled with the expert combinations from ``experts_to_idx``
    (combination -> row index); the colour scale is fixed to [0, 0.5].
    """
    combo_at_row = {row: combo for combo, row in experts_to_idx.items()}
    rows = np.arange(0, len(experts_to_idx), 1)

    ax = sns.heatmap(probabilities[layer_id], linewidths=.5, cmap='rocket_r',
                     cbar_kws={'label': 'Probability'}, vmin=0, vmax=0.5)
    ax.invert_yaxis()
    plt.xlabel("Next Experts")
    plt.ylabel("Current Experts")
    plt.title(f"Layer {layer_id}")

    # Heatmap cells are centred on half-integer positions.
    ax.set_yticks(rows + 0.5, [combo_at_row[row] for row in rows])
    plt.yticks(rotation=0)
    plt.show()
    plt.close()

# One figure per layer; derive the layer count from the data instead of assuming 32.
for layer_id in range(probabilities.shape[0]):
    plot_probabilities(layer_id, probabilities, experts_to_idx)

In [None]:
def p_next_experts_given_previous_experts(traces: dict, k=2):
    """Estimate P(expert combination at layer l+1 | expert combination at layer l).

    For every token, the sorted k-expert combination selected at layer ``l`` is
    paired with the sorted combination at the same record position in the
    following record, but only when that record is layer ``l + 1``.

    Args:
        traces: mapping name -> trace exposing ``num_layers``, ``num_experts``,
            ``layer_ids`` and ``expert_ids`` (as QueryTrace does). Array sizes
            are taken from the first trace.
        k: size of each expert combination (experts selected per token).

    Returns:
        ``(probabilities, combination_to_idx)``. ``probabilities`` has shape
        ``(num_layers, C, C)`` with ``C = C(num_experts, k)``; row ``[l, c]``
        sums to 1 over next-layer combinations, or is all zeros if ``c`` never
        occurred at layer ``l``.
    """
    first_trace = next(iter(traces.values()))
    combinations = list(itertools.combinations(range(first_trace.num_experts), k))
    combination_to_idx = {c: i for i, c in enumerate(combinations)}
    num_combinations = len(combinations)
    # (layer, current combination, next combination)
    counts = np.zeros((first_trace.num_layers, num_combinations, num_combinations), dtype=np.uint32)
    for trace in traces.values():
        records = zip(trace.layer_ids, trace.layer_ids[1:], trace.expert_ids, trace.expert_ids[1:])
        for layer_id, next_layer_id, cur_experts, next_experts in records:
            # Consecutive records describe the same tokens only for adjacent layers
            # of one pass; a wrap-around (last layer -> layer 0) starts a new pass.
            if next_layer_id != layer_id + 1:
                continue
            for per_token_cur_experts, per_token_next_experts in zip(cur_experts, next_experts):
                combination_idx = combination_to_idx[tuple(sorted(per_token_cur_experts))]
                next_comb_idx = combination_to_idx[tuple(sorted(per_token_next_experts))]
                counts[layer_id, combination_idx, next_comb_idx] += 1

    totals = counts.sum(axis=-1, keepdims=True)
    # Unobserved combinations keep an all-zero row instead of 0/0 = NaN.
    probabilities = counts / np.maximum(totals, 1)
    return probabilities, combination_to_idx

probabilities, experts_to_idx = p_next_experts_given_previous_experts(traces)

In [None]:
def plot_probabilities(layer_id, probabilities, experts_to_idx):
    """Draw, save and show the heatmap of ``probabilities[layer_id]``.

    Both axes are labelled with the expert combinations from ``experts_to_idx``
    (combination -> index); the colour scale is fixed to [0, 0.5]. The figure is
    written to ``layer{layer_id}_layer_layer.png`` at 300 dpi before being shown.
    """
    combo_at = {position: combo for combo, position in experts_to_idx.items()}
    positions = np.arange(0, len(experts_to_idx), 1)
    tick_labels = [combo_at[p] for p in positions]

    ax = sns.heatmap(probabilities[layer_id], linewidths=.5, cmap='rocket_r',
                     cbar_kws={'label': 'Probability'}, vmin=0, vmax=0.5)
    ax.invert_yaxis()
    plt.xlabel("Next Experts")
    plt.ylabel("Current Experts")
    plt.title(f"Layer {layer_id}")

    centred = positions + 0.5  # heatmap cells are centred on half-integers
    ax.set_yticks(centred, tick_labels)
    ax.set_xticks(centred, tick_labels)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.savefig(f"layer{layer_id}_layer_layer.png", dpi=300)
    plt.show()
    plt.close()

# One figure per layer; derive the layer count from the data instead of assuming 32.
for layer_id in range(probabilities.shape[0]):
    plot_probabilities(layer_id, probabilities, experts_to_idx)

In [None]:
def p_next_token_experts_given_previous_token_experts(traces: dict, k=2):
    """Estimate P(combination for token t+1 | combination for token t) per layer.

    Within each layer, consecutive entries of ``trace.token_expert_ids[layer]``
    are paired: the sorted k-expert combination of one token conditions the
    sorted combination of the next.

    Args:
        traces: mapping name -> trace exposing ``num_layers``, ``num_experts``
            and ``token_expert_ids`` (as QueryTrace does). Array sizes are taken
            from the first trace.
        k: size of each expert combination (experts selected per token).

    Returns:
        ``(probabilities, combination_to_idx)``. ``probabilities`` has shape
        ``(num_layers, C, C)`` with ``C = C(num_experts, k)``; row ``[l, c]``
        sums to 1 over next-token combinations, or is all zeros if ``c`` never
        occurred at layer ``l``.
    """
    first_trace = next(iter(traces.values()))
    combinations = list(itertools.combinations(range(first_trace.num_experts), k))
    combination_to_idx = {c: i for i, c in enumerate(combinations)}
    num_combinations = len(combinations)
    # (layer, current-token combination, next-token combination)
    counts = np.zeros((first_trace.num_layers, num_combinations, num_combinations), dtype=np.uint32)
    for trace in traces.values():
        for layer_id in range(trace.num_layers):
            layer_tokens = trace.token_expert_ids[layer_id]
            for per_token_cur_experts, per_token_next_experts in zip(layer_tokens, layer_tokens[1:]):
                combination_idx = combination_to_idx[tuple(sorted(per_token_cur_experts))]
                next_comb_idx = combination_to_idx[tuple(sorted(per_token_next_experts))]
                counts[layer_id, combination_idx, next_comb_idx] += 1

    totals = counts.sum(axis=-1, keepdims=True)
    # Unobserved combinations keep an all-zero row instead of 0/0 = NaN.
    probabilities = counts / np.maximum(totals, 1)
    return probabilities, combination_to_idx

probabilities, experts_to_idx = p_next_token_experts_given_previous_token_experts(traces)

In [None]:
def plot_probabilities(layer_id, probabilities, experts_to_idx):
    """Draw, save and show the heatmap of ``probabilities[layer_id]``.

    Both axes are labelled with the expert combinations from ``experts_to_idx``
    (combination -> index); the colour scale is fixed to [0, 0.5]. The figure is
    written to ``layer{layer_id}_token_token.png`` at 300 dpi before being shown.
    """
    combo_at = {position: combo for combo, position in experts_to_idx.items()}
    positions = np.arange(0, len(experts_to_idx), 1)
    tick_labels = [combo_at[p] for p in positions]

    ax = sns.heatmap(probabilities[layer_id], linewidths=.5, cmap='rocket_r',
                     cbar_kws={'label': 'Probability'}, vmin=0, vmax=0.5)
    ax.invert_yaxis()
    plt.xlabel("Next Experts")
    plt.ylabel("Current Experts")
    plt.title(f"Layer {layer_id}")

    centred = positions + 0.5  # heatmap cells are centred on half-integers
    ax.set_yticks(centred, tick_labels)
    ax.set_xticks(centred, tick_labels)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.savefig(f"layer{layer_id}_token_token.png", dpi=300)
    plt.show()
    plt.close()

# One figure per layer; derive the layer count from the data instead of assuming 32.
for layer_id in range(probabilities.shape[0]):
    plot_probabilities(layer_id, probabilities, experts_to_idx)