In [7]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    MixtralModel,
    MixtralConfig,
    AutoModelForCausalLM
)
from torch import nn
import collections
import time
import os
from einops import rearrange
import torch.nn.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

cuda


In [3]:
layer = 15

W3 = 0
ACT_FN = 1
W2 = 2
GATE = 3

sequence_length = 128
batch_size = 6
num_experts = 8
expert_ffn_dims = 14336
expert_hidden_dims = 4096

In [4]:
double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

config = MixtralConfig(
    num_experts_per_tok=8,
    num_hidden_layers=layer + 1,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=double_quant_config,
    attn_implementation="flash_attention_2",
    config=config,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [5]:
def load_generated_dataset(
    filename, batch_size, dataset_relative_path="../dataset"
):
    dataset = pd.read_csv(f"{dataset_relative_path}/{filename}")
    # print(dataset.head(10))
    dataset_list = dataset["0"].tolist()
    dataloader = torch.utils.data.DataLoader(dataset_list, batch_size=batch_size)
    return dataloader


dataset = load_generated_dataset("pile.csv", batch_size=batch_size)

In [21]:
i = 0

moe = model.model.layers[layer].block_sparse_moe

hooks = []

for batch in tqdm(dataset):
    if i > 0:
        break
    i += 1

    w3 = torch.zeros(num_experts, sequence_length * batch_size, expert_ffn_dims, device=device)
    act_fn = torch.zeros(
        num_experts, sequence_length * batch_size, expert_ffn_dims, device=device
    )
    mlp = torch.zeros(num_experts, sequence_length * batch_size, expert_hidden_dims, device=device)
    router_logits = torch.zeros(sequence_length * batch_size, num_experts, device=device)

    def getActivation(expert_idx, type):
        def hook(model, input, output):
            if type == W3:
                w3[expert_idx] = output.detach()
            elif type == ACT_FN:
                act_fn[expert_idx] = output.detach()
            elif type == W2:
                mlp[expert_idx] = output.detach()
            elif type == GATE:
                router_logits[:] = output.detach()
                # print(router_logits)

        return hook

    for expert_idx in range(num_experts):
        # w3_hook = experts[expert_idx].w3.register_forward_hook(getActivation(expert_idx, W3))
        # act_fn_hook = experts[expert_idx].act_fn.register_forward_hook(getActivation(expert_idx, ACT_FN))
        w2_hook = moe.experts[expert_idx].w2.register_forward_hook(getActivation(expert_idx, W2))
        gate_hook = moe.gate.register_forward_hook(getActivation(expert_idx, GATE))
        # hooks.extend([w3_hook, act_fn_hook])
        hooks.extend([w2_hook, gate_hook])

    tokenizer.pad_token = tokenizer.eos_token

    batch_tokens = tokenizer(
        batch,
        padding="max_length",
        truncation=True,
        max_length=sequence_length,
        return_tensors="pt",
    ).to(device)

    try:
        output = model(batch_tokens["input_ids"])
    except Exception as e:
        for hook in hooks:
            hook.remove()
    # print(w3.shape)
    # acts = w3 * act_fn
    # print(acts.shape)
    # acts = rearrange(
    # acts, "experts sequences dims -> sequences experts dims"
    # )
    # print(acts.shape)
    print(mlp.shape)
    print(router_logits.shape)
    print(router_logits)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    print(routing_weights)
    routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    print(routing_weights.shape)
    print(selected_experts)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    print(routing_weights.shape)

    for hook in hooks:
        hook.remove()

    # print(tokenizer.batch_decode(generated_ids))

  0%|          | 1/1939 [00:01<32:33,  1.01s/it]

torch.Size([8, 768, 4096])
torch.Size([768, 8])
tensor([[-0.1617, -0.1150, -0.0352,  ..., -0.0876,  0.0043, -0.0533],
        [-0.1617, -0.1150, -0.0352,  ..., -0.0876,  0.0043, -0.0533],
        [-0.5854,  1.0127,  0.3003,  ..., -0.4873,  0.6367,  0.2510],
        ...,
        [ 1.5879, -0.1747, -0.6221,  ...,  0.5190, -0.8062, -0.5576],
        [ 0.7373,  0.0413, -0.6836,  ...,  0.2683,  0.0616, -1.1279],
        [-0.9150,  0.7690, -0.8047,  ..., -0.1301,  0.7295,  0.6567]],
       device='cuda:0')
tensor([[0.1124, 0.1177, 0.1275,  ..., 0.1210, 0.1327, 0.1252],
        [0.1124, 0.1177, 0.1275,  ..., 0.1210, 0.1327, 0.1252],
        [0.0578, 0.2858, 0.1402,  ..., 0.0638, 0.1962, 0.1334],
        ...,
        [0.4778, 0.0820, 0.0524,  ..., 0.1641, 0.0436, 0.0559],
        [0.2511, 0.1252, 0.0606,  ..., 0.1571, 0.1278, 0.0389],
        [0.0433, 0.2333, 0.0484,  ..., 0.0949, 0.2242, 0.2085]],
       device='cuda:0')
torch.Size([768, 2])
tensor([[4, 6],
        [4, 6],
        [1, 6],
   




In [None]:
Suppose the feature fires hard on the following activation:

4 -5 4 | 0 0 0 | 0 0 0 | 3 2 3

def feature_to_expert()
    feature: 0 1 0 0 0 0 0 0 0 
    feature * w_dec: 4 -5 4 | 1 -2 1 | 1 0 0 | 3 2 3 |
    sum of absolute values: 13 | 4 | 1 | 8
    softmax experts: .6, .1, .1, .2

We're trying to figure out how to map features back to experts. And how we decide to do this will change how we train the autoencoder. Is the following strategy sensible? Specifically, when going from activations to experts in the psuedocode below, will doing sum of absolute values fuck us up? Doesn't seem elegant. But do you buy that these are indeed the experts corresponding to some feature? If not, there's potentially other ideas we can try that require a lot more Mixtral fucking. 

We only turn on 2 experts (like the default model), zeroing out everything else.

Suppose on this activation a feature would fire really hard:

4 -5 4 | 0 0 0 | 0 0 0 | 3 2 3

def feature_to_expert(feature)
    feature: 0 1 0 0 0 0 0 0 0 
    feature * w_dec: 4 -5 4 | 1 -2 1 | 1 0 0 | 3 2 3 | (because sometimes it fires on other experts)
    sum of absolute values: 13 | 4 | 1 | 8
    normalize across experts: .5, .1, .1, .3