In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from torch import nn
import collections

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=double_quant_config,
    # attn_implementation="flash_attention_2",
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

In [18]:
activation = collections.defaultdict(list)

def getActivation(name):
    def hook(model, input, output):
        assert type(output) is not tuple
        activation[name].append(output.detach())

    return hook

experts = model.model.layers[15].block_sparse_moe.experts

for expert_idx in range(8):
    experts[expert_idx].w3.register_forward_hook(getActivation(f"layer_15_expert_{expert_idx}_w_3"))
    experts[expert_idx].act_fn.register_forward_hook(
        getActivation(f"layer_15_expert_{expert_idx}_act_w1")
    )
    experts[expert_idx].w2.register_forward_hook(getActivation(f"layer_15_expert_{expert_idx}_w_2"))


# gate_hook = model.model.layers[15].block_sparse_moe.gate.register_forward_hook(getActivation("layer_15_gate"))

# moe_hook = model.model.layers[15].block_sparse_moe.register_forward_hook(
#     getActivation("layer_15_moe_block")
# )

In [19]:
model.model.layers[15].block_sparse_moe.experts[0].act_fn

SiLU()

In [20]:
prompt = ""

model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

generated_ids = model.generate(**model_inputs, max_new_tokens=1, do_sample=False)
print(tokenizer.batch_decode(generated_ids)[0])

# print(activation['15_6'].shape)
# print(activation["15_6"])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> #


In [21]:
print(activation)
for key, value in activation.items():
    print(f"{key}: {value[0].shape} {len(value)}")

defaultdict(<class 'list'>, {'layer_15_expert_4_act_w1': [tensor([[0.0084, 0.1205, 0.0468,  ..., 0.1027, 0.0519, 0.0596]],
       device='cuda:0', dtype=torch.float16), tensor([[0.0084, 0.1205, 0.0468,  ..., 0.1027, 0.0519, 0.0596]],
       device='cuda:0', dtype=torch.float16)], 'layer_15_expert_4_w_3': [tensor([[-0.0749, -0.0735, -0.3325,  ..., -0.3240,  0.2378, -0.1473]],
       device='cuda:0', dtype=torch.float16), tensor([[-0.0749, -0.0735, -0.3325,  ..., -0.3240,  0.2378, -0.1473]],
       device='cuda:0', dtype=torch.float16)], 'layer_15_expert_4_w_2': [tensor([[-0.0640,  0.1124, -0.1562,  ..., -0.0375,  0.0068, -0.0149]],
       device='cuda:0', dtype=torch.float16)], 'layer_15_expert_6_act_w1': [tensor([[0.1271, 0.2178, 0.0239,  ..., 0.0504, 0.0281, 0.0688]],
       device='cuda:0', dtype=torch.float16), tensor([[0.1271, 0.2178, 0.0239,  ..., 0.0504, 0.0281, 0.0688]],
       device='cuda:0', dtype=torch.float16)], 'layer_15_expert_6_w_3': [tensor([[-0.1215,  0.2671, -0.3540, 

In [24]:
w2_acts = model.model.layers[15].block_sparse_moe.experts[4].w2(activation["layer_15_expert_4_act_w1"][0] * activation["layer_15_expert_4_w_3"][0])
print(w2_acts.shape)
print(activation["layer_15_expert_4_w_2"][0].shape)
assert(torch.allclose(w2_acts, activation["layer_15_expert_4_w_2"][0]))

torch.Size([1, 4096])
torch.Size([1, 4096])


In [17]:
print(model.model.layers[15].block_sparse_moe.experts[5].w1.weight.shape)
print(model.model.layers[15].block_sparse_moe.experts[5].w2.weight.shape)
print(model.model.layers[15].block_sparse_moe.experts[5].w3.weight.shape)

print(model.model.layers[15].block_sparse_moe.gate.weight.shape)
print(activation.keys())
print(activation['count'])
print(activation["layer_15_expert_4"].shape)
print(activation["layer_15_gate"].shape)
print(activation["layer_15_moe_block"].shape)

torch.Size([29360128, 1])
torch.Size([29360128, 1])
torch.Size([29360128, 1])
torch.Size([16384, 1])
dict_keys(['count', 'layer_15_gate', 'layer_15_expert_4', 'layer_15_expert_6', 'layer_15_moe_block'])
8
torch.Size([1, 14336])
torch.Size([1, 8])
torch.Size([1, 1, 4096])


In [23]:
expert_mask = torch.nn.functional.one_hot(
    torch.tensor([3, 2, 4]), num_classes=8
).permute(2, 1, 0)

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3