In [None]:
## Huggingface Reference

# import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# model_id = "mistralai/Mixtral-8x7B-v0.1"
# tokenizer = AutoTokenizer.from_pretrained(model_id)

# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16
# )
# hf_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="cuda:0")

# hf_model.model.layers[0].block_sparse_moe

In [None]:
import torch as t
from transformers import BitsAndBytesConfig
from nnsight import LanguageModel

In [None]:
# Experiment configuration
DEVICE = 'cuda:0'  # passed as device_map when the model is loaded

# 4-bit NF4 weight quantization with float16 compute; double quantization off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=t.float16,
)

# Keyword arguments forwarded to every model.trace(...) call.
# Both 'validate' and 'scan' follow DEBUGGING: enabled together while
# debugging, disabled together otherwise (the original notes this as a
# tracer speedup).
DEBUGGING = False
tracer_kwargs = {'validate': DEBUGGING, 'scan': DEBUGGING}

In [None]:
# Wrap Mixtral-8x7B in an nnsight LanguageModel, quantized per bnb_config and
# placed on DEVICE.
# NOTE(review): dispatch=True presumably loads the weights immediately rather
# than on first trace -- confirm against the nnsight documentation.
model = LanguageModel(
    "mistralai/Mixtral-8x7B-v0.1",
    dispatch=True,
    device_map=DEVICE,
    quantization_config=bnb_config,
)

In [None]:
# Simple inference: trace a small batch of prompts and print, for each prompt,
# the token with the highest score at the final sequence position.
dataset = [
    "I like to",
    "Sometimes, YOLOing things is just very",
]

with model.trace(dataset, **tracer_kwargs):
    out = model.output.save()

# First element of the saved model output.
# assumes this is the logits tensor, one row per prompt -- TODO confirm
batch_scores = out[0]
print(f'Output shape: {batch_scores.shape}')

for prompt, prompt_scores in zip(dataset, batch_scores):
    # Take the last position first, then the argmax over the final dimension.
    predicted_token_id = prompt_scores[-1].argmax(dim=-1)
    predicted_token_str = model.tokenizer.decode(predicted_token_id)
    print(f'Input: {prompt} -> Output: {predicted_token_str}')


In [None]:
# Show the model object's repr as the cell output.
model

In [None]:
# Show the model configuration (later cells read num_local_experts from it).
model.config

In [None]:
# Sanity check: save the output of layer 0's self-attention module for a
# single prompt and display it.

prompt_str = "class MyModel(nn.Module):\n    def __init__(self):\n"

# no_grad is listed first so it is the outermost context: `with` exits its
# contexts in reverse order, and nnsight runs the forward pass when the trace
# context exits, so no_grad must still be active at that point.
with t.no_grad(), model.trace(prompt_str, **tracer_kwargs):
    act = model.model.layers[0].self_attn.output.save()

act.value

In [None]:
# Submodules to inspect: one entry per expert in the sparse-MoE block of a
# single decoder layer.
LAYER = 15

# Index each expert individually so the list holds num_local_experts distinct
# modules rather than repeated references to the whole decoder layer.
# NOTE(review): relies on the HF Mixtral layout
# `layers[k].block_sparse_moe.experts[j]` -- confirm with the `model` repr.
moe_experts = model.model.layers[LAYER].block_sparse_moe.experts
submodules = [moe_experts[i] for i in range(model.config.num_local_experts)]

In [None]:
# Cache the output of every selected submodule for this prompt.
# Open question: are the routings deterministic?

prompt_str = "class MyModel(nn.Module):\n    def __init__(self):\n"

# Maps each submodule to its saved output.
acts = {}

# The two context managers must be separated by a comma: `a and b` is a single
# expression that evaluates to just one of its operands, so only that one
# would be entered and the trace context would never be active.
# no_grad comes first so it is still in effect when the trace context exits
# and nnsight runs the forward pass.
# NOTE(review): a submodule that is never called during the forward pass may
# have no output to save -- confirm how nnsight handles that case.
with t.no_grad(), model.trace(prompt_str, **tracer_kwargs):
    for submodule in submodules:
        acts[submodule] = submodule.output.save()

In [None]:
# Check determinism