In [6]:
import collections
import itertools
import os
import time

import pandas as pd
import torch
from einops import rearrange
from torch import nn
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    MixtralConfig,
    MixtralModel,
)

In [2]:
# Use the GPU when one is visible; tokenized batches are moved to this device
# before the forward pass in the capture cell.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Tokenizer for the same checkpoint id that the model cell loads.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

cuda


In [3]:
# Index of the decoder layer whose MoE experts are probed. The model cell keeps
# only layer + 1 decoder layers, so nothing after this one is computed.
layer = 15

# Tags passed to the hook factory to say which buffer a captured output goes to.
W3, ACT_FN = 0, 1

# Sequence length used to size the activation buffers, and rows per DataLoader batch.
sequence_length, batch_size = 128, 6

# Expert count and per-expert hidden width. These are presumably Mixtral-8x7B's
# num_local_experts / intermediate_size; TODO confirm against the checkpoint config.
num_experts, expert_dims = 8, 14336

In [4]:
MODEL_ID = "mistralai/Mixtral-8x7B-v0.1"

# Load the weights in 4 bit, and also quantize the quantization constants
# (double quantization) to save more memory.
double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

# Start from the checkpoint's own config rather than MixtralConfig() defaults, so
# every architectural field matches the weights. Then override only what the
# probe needs:
#  - num_experts_per_tok=num_experts: top-k routing with k equal to the expert
#    count sends every token to every expert, so each expert sees the whole
#    batch. The capture cell's buffers rely on this.
#  - num_hidden_layers=layer + 1: layers after the probed one are not built,
#    saving memory and time.
config = MixtralConfig.from_pretrained(
    MODEL_ID,
    num_experts_per_tok=num_experts,
    num_hidden_layers=layer + 1,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=double_quant_config,
    attn_implementation="flash_attention_2",
    config=config,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [8]:
def load_generated_dataset(
    filename, batch_size, dataset_relative_path="../dataset", column="0"
):
    """Load one column of a CSV file into a batching DataLoader.

    Args:
        filename: CSV file name inside ``dataset_relative_path``.
        batch_size: Number of rows per batch.
        dataset_relative_path: Directory that contains the CSV.
        column: Header of the column holding the samples. Defaults to ``"0"``,
            the column this notebook has always read.

    Returns:
        A ``torch.utils.data.DataLoader`` that yields the column's values in
        file order, ``batch_size`` at a time (the last batch may be shorter).
    """
    path = os.path.join(dataset_relative_path, filename)
    samples = pd.read_csv(path)[column].tolist()
    # Default collation keeps a batch of strings as a plain Python list,
    # which the tokenizer accepts directly.
    return torch.utils.data.DataLoader(samples, batch_size=batch_size)


# DataLoader over ../dataset/pile.csv (the loader's default directory), batch_size rows per batch.
dataset = load_generated_dataset(filename="pile.csv", batch_size=batch_size)

In [10]:
# Capture per-expert MLP activations at `layer` for the first `num_batches` batches.
# For every expert, forward hooks record the outputs of its `w3` and `act_fn`
# submodules. Their elementwise product is then flattened to one row per token:
#   acts: (tokens_in_batch, num_experts * expert_dims)

num_batches = 1  # DataLoader batches to process (the original loop stopped after one)

# Reuse EOS as the padding token; padding="max_length" below needs one.
tokenizer.pad_token = tokenizer.eos_token

experts = model.model.layers[layer].block_sparse_moe.experts


def getActivation(expert_idx, kind):
    """Build a forward hook that stores a module's output for one expert.

    Args:
        expert_idx: Row of the destination buffer to write.
        kind: ``W3`` writes into the global ``w3`` buffer, ``ACT_FN`` into ``act_fn``.

    The buffers are looked up when the hook fires, so they must be allocated
    before the forward pass that triggers it.
    """

    def hook(module, inputs, output):
        if kind == W3:
            w3[expert_idx] = output.detach()
        elif kind == ACT_FN:
            act_fn[expert_idx] = output.detach()

    return hook


for batch in tqdm(itertools.islice(dataset, num_batches), total=min(num_batches, len(dataset))):
    batch_tokens = tokenizer(
        batch,
        padding="max_length",
        truncation=True,
        max_length=sequence_length,  # kept in sync with the buffer sizing below
        return_tensors="pt",
    ).to(device)
    input_ids = batch_tokens["input_ids"]
    attention_mask = batch_tokens["attention_mask"]

    # Size the buffers from the real batch. The last DataLoader batch can hold
    # fewer than `batch_size` rows, and fixed-size buffers then fail on assignment.
    num_tokens = input_ids.numel()
    w3 = torch.zeros(num_experts, num_tokens, expert_dims)
    act_fn = torch.zeros(num_experts, num_tokens, expert_dims)

    # Count positions over real tokens only, so a padded row gets the same
    # positions as its unpadded text. Pad slots get a dummy position.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    # Register exactly one hook per module and always remove them, even if the
    # forward pass raises, so hooks never accumulate on the shared model.
    hooks = []
    try:
        for expert_idx in range(num_experts):
            expert = experts[expert_idx]
            hooks.append(expert.w3.register_forward_hook(getActivation(expert_idx, W3)))
            hooks.append(expert.act_fn.register_forward_hook(getActivation(expert_idx, ACT_FN)))

        # Inference only: no autograd graph and no KV cache are needed.
        with torch.no_grad():
            output = model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=False,
            )
    finally:
        for hook in hooks:
            hook.remove()

    # NOTE(review): in the HF Mixtral sparse-MoE forward, each expert may receive
    # its tokens in routing-mask order rather than token order, and that order may
    # differ between experts. w3 * act_fn is consistent within one expert, but the
    # concatenation across experts below assumes row r is the same token for every
    # expert. Confirm against the installed transformers version.
    # NOTE(review): rows for padding tokens are included; filter with
    # attention_mask.reshape(-1) if only real tokens are wanted.
    acts = rearrange(w3 * act_fn, 'experts sequences dims -> sequences (experts dims)')
    print(acts.shape)

  0%|          | 1/1939 [00:03<1:38:00,  3.03s/it]

tensor([[[-1.4424e+00,  1.0535e-01, -1.5508e+00,  ...,  5.1465e-01,
           6.0400e-01, -4.7119e-02],
         [-9.0820e-01, -3.0933e-01, -1.3018e+00,  ...,  3.8281e-01,
           4.7241e-01, -3.0566e-01],
         [-8.4229e-01, -1.4343e-01, -9.5020e-01,  ...,  1.1152e+00,
           3.5205e-01, -7.7393e-01],
         ...,
         [-1.8525e+00, -3.8208e-01, -4.5483e-01,  ..., -7.9883e-01,
          -1.5820e-01,  1.8091e-01],
         [ 1.1494e+00,  3.1689e-01,  2.3938e-01,  ...,  9.6375e-02,
          -4.3945e-01, -2.6904e-01],
         [ 1.2168e+00,  1.9617e-01, -1.2732e-01,  ...,  1.8958e-01,
          -3.1348e-01, -6.5723e-01]],

        [[ 1.7432e-01,  8.7256e-01, -6.5576e-01,  ...,  4.9194e-01,
          -2.8076e-01, -4.7394e-02],
         [-2.4902e-01, -5.9863e-01, -8.7256e-01,  ..., -3.7207e-01,
           5.1318e-01,  6.6992e-01],
         [ 1.9092e-01, -1.4294e-01, -1.4561e+00,  ...,  3.0908e-01,
           9.8779e-01, -1.9507e-01],
         ...,
         [ 4.3945e-01,  3


