In [None]:
import functools
import itertools
import os
import re
import sys
from dataclasses import dataclass
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import signal

# Make the repository root importable when the kernel is started inside notebooks/.
if "notebooks" in os.path.abspath('.'):
    sys.path.append('../')
from traces import mtbench_mixtral_utils

In [None]:
# Load every recorded trace through the project helper. Later cells treat the result
# as a mapping: they call .values() / .items() on it and index it by trace name.
traces = mtbench_mixtral_utils.load_all()

In [None]:
# Add up the per-trace count matrices, then divide every row by its own total so each
# row holds relative frequencies (rows are plotted as layers, columns as expert ids).
expert_counts_by_layer = sum(trace.expert_counts_by_layer() for trace in traces.values())
expert_freq_by_layer = expert_counts_by_layer / expert_counts_by_layer.sum(axis=1, keepdims=True)

# Take the tick range from the data instead of assuming a 32-layer model.
layer_ticks = np.arange(0, expert_freq_by_layer.shape[0], 4)

ax = sns.heatmap(expert_freq_by_layer, linewidths=.5, cmap='rocket_r',
                 cbar_kws={'label': 'Frequency'})
ax.invert_yaxis()  # put row 0 at the bottom
ax.set_yticks(0.5 + layer_ticks, layer_ticks)  # +0.5 centres each tick on its cell
plt.xlabel("Expert ID")
plt.ylabel("Layer")
plt.title("Mixtral on MTBench")

In [None]:
# One bar per row of expert_freq_by_layer: the standard deviation of that row's
# frequencies. The x positions are sized from the array so they always match y.
plt.figure(figsize=(10, 4))
sns.barplot(x=np.arange(expert_freq_by_layer.shape[0]), y=expert_freq_by_layer.std(axis=1))
plt.xlabel("Layer Number")
plt.ylabel("Standard Deviation\nof Expert Frequencies")

In [None]:
def p_next_expert_given_previous_experts(traces: dict, k=2):
    """Estimate how often each expert follows a given set of ``k`` current experts.

    Args:
        traces: Mapping of trace name to trace object. Each trace must expose
            ``num_experts``, ``num_layers``, ``layer_ids`` and ``expert_ids``;
            ``expert_ids[i]`` is iterated as a sequence of per-token expert
            selections for the i-th recorded invocation.
        k: Number of experts in each per-token selection (size of the
            combinations used as the conditioning key).

    Returns:
        A ``(probabilities, combination_to_idx)`` tuple. ``probabilities`` has
        shape ``(num_layers, n_combinations, num_experts)``; entry
        ``[layer, c, e]`` is the share of observed transitions out of
        combination ``c`` (recorded under ``layer``) that selected expert ``e``
        in the following invocation. Rows with no observations stay all-zero.
        ``combination_to_idx`` maps each sorted expert tuple to its row index.

    Raises:
        ValueError: If ``traces`` is empty.
        KeyError: If a per-token selection is not a k-combination of distinct
            expert ids.
    """
    if not traces:
        raise ValueError("traces must contain at least one trace")

    # Array dimensions come from the first trace; all traces are assumed to agree.
    first_trace = next(iter(traces.values()))
    combinations = list(itertools.combinations(range(first_trace.num_experts), k))
    combination_to_idx = {c: i for i, c in enumerate(combinations)}
    # counts[layer, current_combination_idx, next_expert]
    counts = np.zeros(
        (first_trace.num_layers, len(combinations), first_trace.num_experts),
        dtype=np.uint32)

    for trace in traces.values():
        # Pair every invocation with the one recorded immediately after it.
        for layer_id, cur_experts, next_experts in zip(
                trace.layer_ids, trace.expert_ids, trace.expert_ids[1:]):
            # zip() truncates to the shorter invocation when token counts differ.
            for per_token_cur_experts, per_token_next_experts in zip(cur_experts, next_experts):
                combination = tuple(sorted(per_token_cur_experts))
                combination_idx = combination_to_idx[combination]
                # Fancy indexing: every expert chosen next gets one increment.
                counts[layer_id, combination_idx, per_token_next_experts] += 1

    totals = counts.sum(axis=-1, keepdims=True)
    divisor = np.maximum(totals, 1)  # Avoid dividing by 0 for unseen combinations.
    probabilities = counts / divisor
    return probabilities, combination_to_idx

# Transition statistics pooled over every loaded trace (default k=2).
probabilities, experts_to_idx = p_next_expert_given_previous_experts(traces)

In [None]:
def plot_probabilities(layer_id, probabilities, experts_to_idx,
                       xlabel="Next Experts", vmax=0.5):
    """Show one layer's transition probabilities as a heatmap.

    Args:
        layer_id: Index into the first axis of ``probabilities``.
        probabilities: Array whose ``[layer_id]`` slice is a 2-D matrix with one
            row per entry of ``experts_to_idx``.
        experts_to_idx: Mapping from expert combination to row index; used to
            label the rows.
        xlabel: Label for the x axis.
        vmax: Upper limit of the colour scale (``None`` lets seaborn choose).
    """
    ax = sns.heatmap(probabilities[layer_id], linewidths=.5, cmap='rocket_r',
                     cbar_kws={'label': 'Probability'}, vmin=0, vmax=vmax)
    ax.invert_yaxis()
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Current Experts")
    ax.set_title(f"Layer {layer_id}")

    # Label every row with the expert combination stored at that index.
    yticks = np.arange(0, len(experts_to_idx), 1)
    idx_to_experts = {i: e for e, i in experts_to_idx.items()}
    yticklabels = [idx_to_experts[i] for i in yticks]
    ax.set_yticks(yticks + 0.5, yticklabels)  # +0.5 centres the label on the cell
    plt.yticks(rotation=0)
    plt.show()
    plt.close()  # release the figure; this function is called once per layer

# One heatmap per layer; the layer count is read from the array, not hard-coded.
for i in range(probabilities.shape[0]):
    plot_probabilities(i, probabilities, experts_to_idx)

In [None]:
def plot_probabilities(layer_id, probabilities, experts_to_idx,
                       xlabel="Nth Most Frequent Next Expert", vmax=0.5):
    """Show one layer's (rank-sorted) transition probabilities as a heatmap.

    Args:
        layer_id: Index into the first axis of ``probabilities``.
        probabilities: Array whose ``[layer_id]`` slice is a 2-D matrix with one
            row per entry of ``experts_to_idx``.
        experts_to_idx: Mapping from expert combination to row index; used to
            label the rows.
        xlabel: Label for the x axis.
        vmax: Upper limit of the colour scale (``None`` lets seaborn choose).
    """
    ax = sns.heatmap(probabilities[layer_id], linewidths=.5, cmap='rocket_r',
                     cbar_kws={'label': 'Probability'}, vmin=0, vmax=vmax)
    ax.invert_yaxis()
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Current Experts")
    ax.set_title(f"Layer {layer_id}")

    # Label every row with the expert combination stored at that index.
    yticks = np.arange(0, len(experts_to_idx), 1)
    idx_to_experts = {i: e for e, i in experts_to_idx.items()}
    yticklabels = [idx_to_experts[i] for i in yticks]
    ax.set_yticks(yticks + 0.5, yticklabels)  # +0.5 centres the label on the cell
    plt.yticks(rotation=0)
    plt.show()
    plt.close()  # release the figure; this function is called once per layer

# Sort each row ascending, then reverse the last axis so column n holds the
# (n+1)-th largest next-expert probability for that row.
sorted_probabilities = np.sort(probabilities, axis=-1)[:, :, ::-1]
# The layer count is read from the array, not hard-coded.
for i in range(sorted_probabilities.shape[0]):
    plot_probabilities(i, sorted_probabilities, experts_to_idx)

In [None]:
# Single sequence

In [None]:
# The original cell indexed an undefined `files` list, which fails on a fresh kernel.
# Use the first loaded trace instead.
# NOTE(review): confirm the first key of `traces` is the sequence that was intended.
first_trace_name = next(iter(traces))
single_trace = {first_trace_name: traces[first_trace_name]}
probabilities, experts_to_idx = p_next_expert_given_previous_experts(single_trace)
# Descending sort along the last axis: column n is the (n+1)-th largest probability.
sorted_probabilities = np.sort(probabilities, axis=-1)[:, :, ::-1]

def plot_probabilities(layer_id, probabilities, experts_to_idx,
                       xlabel="Nth Most Frequent Next Expert", vmax=None):
    """Show one layer's (rank-sorted) transition probabilities as a heatmap.

    Args:
        layer_id: Index into the first axis of ``probabilities``.
        probabilities: Array whose ``[layer_id]`` slice is a 2-D matrix with one
            row per entry of ``experts_to_idx``.
        experts_to_idx: Mapping from expert combination to row index; used to
            label the rows.
        xlabel: Label for the x axis.
        vmax: Upper limit of the colour scale; ``None`` (the default here) lets
            seaborn pick it from the data.
    """
    ax = sns.heatmap(probabilities[layer_id], linewidths=.5, cmap='rocket_r',
                     cbar_kws={'label': 'Probability'}, vmin=0, vmax=vmax)
    ax.invert_yaxis()
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Current Experts")
    ax.set_title(f"Layer {layer_id}")

    # Label every row with the expert combination stored at that index.
    yticks = np.arange(0, len(experts_to_idx), 1)
    idx_to_experts = {i: e for e, i in experts_to_idx.items()}
    yticklabels = [idx_to_experts[i] for i in yticks]
    ax.set_yticks(yticks + 0.5, yticklabels)  # +0.5 centres the label on the cell
    plt.yticks(rotation=0)
    plt.show()
    plt.close()  # release the figure; this function is called once per layer

# One heatmap per layer; the layer count is read from the array, not hard-coded.
for i in range(sorted_probabilities.shape[0]):
    plot_probabilities(i, sorted_probabilities, experts_to_idx)

In [None]:
@functools.lru_cache(maxsize=None)
def _combination_ranks(total_num_experts: int, k: int) -> dict:
    """Map every sorted k-subset of range(total_num_experts) to its lexicographic rank."""
    return {c: i for i, c in enumerate(itertools.combinations(range(total_num_experts), k))}


def encode_selected_experts(selected_experts: Sequence[int], total_num_experts: int) -> int:
    """Encode a set of selected experts as a single integer id.

    The id is the position of the sorted selection in
    ``itertools.combinations(range(total_num_experts), len(selected_experts))``.
    The lookup table is built once per ``(total_num_experts, k)`` and cached, so
    repeated calls (e.g. through ``np.apply_along_axis``) cost a dict lookup
    instead of rebuilding and scanning the whole combination list.

    Args:
        selected_experts: Distinct expert ids, in any order.
        total_num_experts: Number of experts to choose from.

    Returns:
        The zero-based rank of the selection.

    Raises:
        ValueError: If the selection is not a combination of distinct ids in
            ``range(total_num_experts)``.
    """
    combination = tuple(sorted(selected_experts))
    ranks = _combination_ranks(total_num_experts, len(combination))
    try:
        return ranks[combination]
    except KeyError:
        # Keep the exception type that list.index() raised in the original.
        raise ValueError(
            f"{combination} is not a combination of distinct experts "
            f"out of {total_num_experts}") from None

# Work on the first loaded trace only.
trace = list(traces.values())[0]
experts_per_token = trace.experts_per_token()
# Collapse the last axis (the experts chosen for one token) into a single pair id.
experts_encoded = np.apply_along_axis(
            lambda x: encode_selected_experts(x, trace.num_experts), -1,
            experts_per_token)

# Sizes come from the trace instead of the literals 32 and 28:
# there are n * (n - 1) / 2 unordered expert pairs.
num_pairs = trace.num_experts * (trace.num_experts - 1) // 2
layer_index = np.arange(trace.num_layers)  # loop-invariant, built once

# counts[layer, pair id at step t, pair id at step t + 1]
# NOTE(review): assumes each row of experts_encoded holds one pair id per layer; confirm.
counts = np.zeros((trace.num_layers, num_pairs, num_pairs), dtype=np.int32)
for experts, next_experts in zip(experts_encoded, experts_encoded[1:]):
    counts[layer_index, experts, next_experts] += 1

# Principal Component Analysis

In [None]:
# PCA with all sequences truncated to minimum sequence length.
def sequence_of_experts(trace: "QueryTrace") -> Sequence[Sequence[int]]:
    """Flatten a trace's expert selections into one list of sorted tuples.

    ``QueryTrace`` is never imported in this notebook, so the annotation is a
    string; an unquoted name would raise NameError when the ``def`` runs.

    Args:
        trace: Object whose ``expert_ids`` is an iterable of iterables of
            expert selections.

    Returns:
        One ``tuple(sorted(selection))`` per selection, in trace order.
    """
    selections = itertools.chain.from_iterable(trace.expert_ids)
    return [tuple(sorted(selection)) for selection in selections]

# Build one integer sequence per trace: every sorted expert tuple is replaced by its
# index in experts_to_idx. The loop variables (trace, experts, expert_ids) stay bound
# as notebook globals after the loop finishes.
expert_sequences = []
for trace in traces.values():
    experts = sequence_of_experts(trace)
    expert_ids = [experts_to_idx[e] for e in experts]

    expert_sequences.append(expert_ids)

In [None]:
# Truncate every sequence to the shortest trace's token count so they stack into a matrix.
min_token_length = min(t.num_tokens for t in traces.values())

truncated_expert_sequences = np.array([e[:min_token_length] for e in expert_sequences])
# Subtract each sequence's own mean (row-wise centring).
demeaned = truncated_expert_sequences - truncated_expert_sequences.mean(axis=-1, keepdims=True)

# Reduced SVD: same singular values, without the full square Vt that is never used.
U, s, Vt = np.linalg.svd(demeaned, full_matrices=False)

plt.stem(s)
plt.title("Singular Values")

# Poor Man's Attempt at a Frequency Analysis

In [None]:
def sequence_of_experts(trace: "QueryTrace") -> Sequence[Sequence[int]]:
    """Flatten a trace's expert selections into one list of sorted tuples.

    ``QueryTrace`` is never imported in this notebook, so the annotation is a
    string; an unquoted name would raise NameError when the ``def`` runs.

    Args:
        trace: Object whose ``expert_ids`` is an iterable of iterables of
            expert selections.

    Returns:
        One ``tuple(sorted(selection))`` per selection, in trace order.
    """
    selections = itertools.chain.from_iterable(trace.expert_ids)
    return [tuple(sorted(selection)) for selection in selections]

# Build one integer sequence per trace: every sorted expert tuple is replaced by its
# index in experts_to_idx. The loop variables (trace, experts, expert_ids) stay bound
# as notebook globals after the loop finishes.
expert_sequences = []
for trace in traces.values():
    experts = sequence_of_experts(trace)
    expert_ids = [experts_to_idx[e] for e in experts]

    expert_sequences.append(expert_ids)

    # Disabled experiment: per-trace periodogram plot.
    # f, Pxx_den = signal.periodogram(expert_ids)

    # plt.figure()
    # plt.semilogy(f, Pxx_den)
    # plt.plot()

In [None]:
pxx_densities = [] # per sequence

# Periodogram (power spectral density estimate) of each encoded expert sequence.
for s in expert_sequences:
    f, pxx_den = signal.periodogram(s)
    pxx_densities.append(pxx_den)

# Regroup by position: entry i collects the i-th density value of every sequence whose
# periodogram is long enough. The original read from expert_sequences here, leaving
# pxx_densities unused, so the downstream "PSD variance" was really the variance of
# raw expert ids. The name is kept because later cells reference it.
# NOTE(review): the index here is a periodogram bin, not a token position.
pxx_densities_per_token = []
for token_idx in range(max(map(len, pxx_densities))):
    per_token_densities = []
    for seq in pxx_densities:
        if token_idx < len(seq):
            per_token_densities.append(seq[token_idx])
    pxx_densities_per_token.append(per_token_densities)

In [None]:
# NOTE(review): despite its name, max_tokens is the number of values collected at
# index 0 (one per sequence that has an index 0), not a token count; it is not used below.
max_tokens = len(pxx_densities_per_token[0])
min_tokens = min(t.num_tokens for t in traces.values())
# Shortest and longest trace, in tokens.
print(min_tokens)
print(max(t.num_tokens for t in traces.values()))
# Variance
# One variance per index, taken over the values collected from all sequences.
# NOTE(review): `vars` shadows the Python builtin; later cells read this name.
vars = [np.var(s) for s in pxx_densities_per_token]
sns.lineplot(vars)
plt.title("Variance of Power Spectral Density across Sequences")

In [None]:
# Zoom in on the first ten variances.
sns.lineplot(vars[:10])

In [None]:
# Inspect the values collected at index 4.
pxx_densities_per_token[4]

In [None]:
# Spot-check the variance at index 1.
vars[1]

In [None]:
# Empirical CDF of the per-index variances.
sns.ecdfplot(vars)

In [None]:
# Display every variance. NOTE(review): this dumps the whole list; consider a slice.
vars

In [None]:
# Calculate probabilities
sums = counts.sum(axis=-1)
divisor = np.maximum(np.ones_like(sums), sums) # Avoid dividing by 0.
probabilities = counts / divisor[:, np.newaxis]

In [None]:
# Display the probability array computed in the previous cell.
probabilities

In [None]:
# Number of unordered pairs that can be drawn from 8 items (evaluates to 28).
len(list(itertools.combinations(list(range(8)), 2)))

In [None]:
# IPython help lookup (trailing `?`); this is interactive syntax, not plain Python.
itertools.combinations?

In [None]:
probabilities.shape

# Flatten all leading axes of layer 0 into rows and keep the last axis as columns.
# A fixed reshape(64, 8) only fits an (8, 8, 8) slice and raises ValueError for the
# shapes produced elsewhere in this notebook; -1 gives the same result for (8, 8, 8).
ax = sns.heatmap(probabilities[0].reshape(-1, probabilities.shape[-1]), linewidths=.5, cmap='rocket_r',
                 cbar_kws={'label': 'Frequency'})

In [None]:
# NOTE(review): `layer_ids` and `expert_ids` are not assigned in any cell visible in
# this notebook; they must already exist in the kernel for this cell to run.
num_layers = len(np.unique(layer_ids))

num_experts = len(np.unique(list(itertools.chain.from_iterable(itertools.chain.from_iterable(expert_ids)))))

# Accumulate, per layer, how many times each expert id occurs.
expert_counts_by_layer = np.zeros((num_layers, num_experts))
for layer_id, exp_ids in zip(layer_ids, expert_ids):
    exp_ids_flat = list(itertools.chain.from_iterable(exp_ids))
    # minlength guarantees one slot per expert. The original padded the input with
    # range(num_experts) and computed the corrected `counts` (minus 1), but then added
    # the uncorrected bincount, inflating every cell by one per invocation.
    counts = np.bincount(np.asarray(exp_ids_flat, dtype=np.intp), minlength=num_experts)
    expert_counts_by_layer[layer_id] += counts

ax = sns.heatmap(expert_counts_by_layer, linewidths=.5, cmap='rocket_r')
ax.invert_yaxis()
# Tick range follows the computed layer count instead of a literal 32.
ax.set_yticks(0.5 + np.arange(0, num_layers, 4), np.arange(0, num_layers, 4))
plt.xlabel("Expert ID")
plt.ylabel("Layer")

In [None]:
# Shape of `counts` after summing out its last axis.
counts.sum(axis=-1).shape

In [None]:
# Shape of whatever `counts` currently holds (the name is reassigned by several cells).
counts.shape

In [None]:
# Look up one specific trace by name and print the index of every entry of
# trace.expert_ids that contains more than one element.
trace = traces['MTBench-Mixtral/139916901616816.txt']
for i, x in enumerate([len(e_ids) for e_ids in trace.expert_ids]):
    if x > 1:
        print(i)


In [None]:
# Scratch check: `+=` through a list index increments each listed position once,
# so the displayed slice a[0, :, 0] is [1., 1.].
a = np.zeros((2, 2, 2))
a[0, [0, 1], 0] += 1
a[0, :, 0]

In [None]:
@dataclass(frozen=True)
class Invocation:
    """Immutable record bundling one layer id with its input ids and expert ids."""
    layer_id: int
    input_ids: np.ndarray
    expert_ids: np.ndarray

def groupby_invocations(layer_ids, input_ids, expert_ids):
    """Zip three parallel sequences into a list of ``Invocation`` records.

    The i-th record takes ``layer_ids[i]`` unchanged and converts
    ``input_ids[i]`` and ``expert_ids[i]`` to numpy arrays. As with ``zip``,
    the result is as long as the shortest input.
    """
    records = []
    for lid, inputs, experts in zip(layer_ids, input_ids, expert_ids):
        record = Invocation(
            layer_id=lid,
            input_ids=np.array(inputs),
            expert_ids=np.array(experts),
        )
        records.append(record)
    return records

# NOTE(review): `layer_ids`, `input_ids` and `expert_ids` are not assigned in any cell
# visible in this notebook; they must already exist in the kernel for this line to run.
invocations = groupby_invocations(layer_ids, input_ids, expert_ids)

In [None]:
# Display the first Invocation record.
invocations[0] 