In [1]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    MixtralModel,
    MixtralConfig,
    AutoModelForCausalLM
)
from torch import nn
import collections
import time
import os
from einops import rearrange
import torch.nn.functional as F

In [2]:
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Tokenizer published alongside the Mixtral-8x7B base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

cuda


In [51]:
# Decoder layer under study: the model config requests `layer + 1` hidden
# layers and the MoE block at this index is the one that gets hooked.
layer = 15

# Tags that tell a forward hook which kind of module it is attached to.
W2, GATE = 0, 1

# Tokenisation / batching.
batch_size = 6
sequence_length = 128

# Mixture-of-experts geometry.
num_experts = 8
num_experts_per_token = 2
expert_hidden_dims = 4096  # width of each expert's captured w2 output
expert_ffn_dims = 14336  # NOTE(review): not referenced by any cell shown here

In [4]:
# Hub id of the checkpoint. Defined once so the config and the weights are
# always resolved from the same repository.
model_id = "mistralai/Mixtral-8x7B-v0.1"

# Load the weights in 4-bit with double (nested) quantisation enabled.
double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

# Start from the checkpoint's own config.json and override only what this
# experiment changes. Building `MixtralConfig(...)` from scratch would rely on
# the library's default values for every other field, which are not guaranteed
# to match the checkpoint across transformers versions.
config = MixtralConfig.from_pretrained(
    model_id,
    num_experts_per_tok=num_experts_per_token,
    # Layers past the one being studied are never needed, so do not load them.
    num_hidden_layers=layer + 1,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=double_quant_config,
    attn_implementation="flash_attention_2",
    config=config,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [5]:
def load_generated_dataset(
    filename, batch_size, dataset_relative_path="../dataset", text_column="0"
):
    """Load a CSV of text samples and wrap them in a batched ``DataLoader``.

    Args:
        filename: Name of the CSV file inside ``dataset_relative_path``.
        batch_size: Number of samples per batch.
        dataset_relative_path: Directory that contains the CSV file.
        text_column: Header of the column holding the text samples. Defaults
            to ``"0"``, the header this function has always read.

    Returns:
        A ``torch.utils.data.DataLoader`` over the samples as a list of
        strings, in file order (no shuffling is requested).

    Raises:
        KeyError: If ``text_column`` is not a column of the CSV.
    """
    frame = pd.read_csv(os.path.join(dataset_relative_path, filename))
    if text_column not in frame.columns:
        raise KeyError(
            f"column {text_column!r} not found in {filename!r}; "
            f"available columns: {list(frame.columns)}"
        )

    # Empty cells come back from pandas as float NaN, which a tokenizer cannot
    # handle. Drop them and make sure every remaining sample is a `str`.
    texts = frame[text_column].dropna().astype(str).tolist()
    return torch.utils.data.DataLoader(texts, batch_size=batch_size)


# Despite the name, `dataset` is the batched DataLoader returned by
# `load_generated_dataset`, iterating over "pile.csv" in chunks of `batch_size`.
dataset = load_generated_dataset("pile.csv", batch_size=batch_size)

In [52]:
# Sparse-MoE block of the decoder layer under study.
moe = model.model.layers[layer].block_sparse_moe

# Fixed-length padding below needs a pad token; reuse EOS for it.
# Loop-invariant, so it is set once up front.
tokenizer.pad_token = tokenizer.eos_token

# Only the first `max_batches` batches of the loader are processed.
max_batches = 1


def getActivation(expert_idx, kind):
    """Build a forward hook that captures the output of the hooked module.

    Args:
        expert_idx: Slot of ``mlps`` to fill when ``kind == W2``; ignored for
            ``GATE``.
        kind: ``W2`` stores the output in ``mlps[expert_idx]``; ``GATE`` stores
            it in the module-level ``router_logits``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
    """

    def hook(module, inputs, output):
        # This cell runs at module scope, so `router_logits` has to be rebound
        # as a global. `mlps` is only mutated in place and needs no declaration.
        global router_logits

        if kind == W2:
            mlps[expert_idx] = output.detach()
        elif kind == GATE:
            router_logits = output.detach()

    return hook


for batch_idx, batch in enumerate(tqdm(dataset)):
    if batch_idx >= max_batches:
        break

    # Fresh capture slots for this batch; the hooks below fill them in.
    mlps = [None] * num_experts
    router_logits = None

    # Tokenise before registering hooks so a tokenizer failure cannot leave
    # hooks attached to the model.
    batch_tokens = tokenizer(
        batch,
        padding="max_length",
        truncation=True,
        max_length=sequence_length,
        return_tensors="pt",
    ).to(device)

    # One hook on the router and one on each expert's w2 projection.
    hooks = [moe.gate.register_forward_hook(getActivation(None, GATE))]
    hooks.extend(
        moe.experts[expert_idx].w2.register_forward_hook(
            getActivation(expert_idx, W2)
        )
        for expert_idx in range(num_experts)
    )

    # NOTE(review): only `input_ids` are passed, so padded positions are not
    # masked and their rows are part of `activations` — confirm this is intended.
    try:
        # Captured tensors are detached anyway, so no autograd graph is needed.
        with torch.no_grad():
            output = model(batch_tokens["input_ids"])
    finally:
        # Always detach the hooks, and let any forward-pass error propagate
        # instead of continuing with missing captures.
        for handle in hooks:
            handle.remove()

    if router_logits is None:
        raise RuntimeError(
            "The gate hook never fired, so no router logits were captured."
        )

    # Rows of the router output are what `top_x` indexes below, so size the
    # buffer from it rather than from batch_size * sequence_length (which is
    # wrong for a short final batch).
    num_tokens = router_logits.shape[0]

    # Recompute the routing decision from the captured logits: softmax over
    # experts, keep the top-k, renormalise the kept weights to sum to 1.
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(
        routing_weights, num_experts_per_token, dim=-1
    )
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    # After the permute, expert_mask[e] is a (top-k slot, token) mask of where
    # expert `e` was selected.
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(
        2, 1, 0
    )

    # Per-expert contribution for every token; rows stay zero for tokens that
    # were not routed to that expert.
    activations = torch.zeros(
        num_experts, num_tokens, expert_hidden_dims, device=router_logits.device
    )

    for expert_idx in range(num_experts):
        # idx: which top-k slot chose this expert; top_x: which token row.
        idx, top_x = torch.where(expert_mask[expert_idx])

        if top_x.shape[0] == 0:
            continue

        # assumes the captured w2 output has one row per routed token, in the
        # same order as `top_x` — TODO confirm against the Mixtral MoE forward.
        current_hidden_states = mlps[expert_idx] * routing_weights[top_x, idx, None]
        activations[expert_idx, top_x] = current_hidden_states

    # Lay each token's per-expert contributions side by side in one row.
    activations = rearrange(
        activations, 'experts sequences hidden -> sequences (experts hidden)'
    )
    print(activations.shape)

  0%|          | 1/1939 [00:00<14:12,  2.27it/s]

torch.Size([8, 768, 4096])
torch.Size([768, 32768])





In [None]:
# NOTE(review): leftover scratch expression — `t1` is not defined in any cell
# shown here and this cell has no execution count; presumably safe to delete.
t1

In [7]:
Suppose the feature fires hard on the following activation:

4 -5 4 | 0 0 0 | 0 0 0 | 3 2 3

def feature_to_expert()
    feature: 0 1 0 0 0 0 0 0 0 
    feature * w_dec: 4 -5 4 | 1 -2 1 | 1 0 0 | 3 2 3 |
    sum of absolute values: 13 | 4 | 1 | 8
    softmax experts: .6, .1, .1, .2

SyntaxError: invalid syntax (3962263434.py, line 1)

We're trying to figure out how to map features back to experts, and how we decide to do this will change how we train the autoencoder. Is the following strategy sensible? Specifically, when going from activations to experts in the pseudocode below, will taking the sum of absolute values cause problems? It doesn't seem elegant. But do you buy that these are indeed the experts corresponding to a given feature? If not, there are other ideas we could try, but they would require much more invasive changes to Mixtral.

We only turn on 2 experts (like the default model), zeroing out everything else.

Suppose on this activation a feature would fire really hard:

4 -5 4 | 0 0 0 | 0 0 0 | 3 2 3

def feature_to_expert(feature)
    feature: 0 1 0 0 0 0 0 0 0 
    feature * w_dec: 4 -5 4 | 1 -2 1 | 1 0 0 | 3 2 3 | (because sometimes it fires on other experts)
    sum of absolute values: 13 | 4 | 1 | 8
    normalize across experts (divide each sum by the total, 26): .50, .15, .04, .31