In [1]:
# Load Model
from mlx_lm import load, generate

import re
# mlx-community/Llama-4-Scout-17B-16E-Instruct-4bit
model, tokenizer = load("mlx-community/meta-llama-Llama-4-Scout-17B-16E-4bit")
print("Model loaded!!")


  from .autonotebook import tqdm as notebook_tqdm
Fetching 17 files: 100%|██████████| 17/17 [00:00<00:00, 200289.80it/s]


MoE class initialized
Model loaded!!


In [2]:
tokenizer.decode(tokenizer.encode("hello"))

'<|begin_of_text|>hello'

In [3]:
# prune experts
from mlx.utils import tree_flatten  
num_params = sum(v.size for _, v in tree_flatten(model.parameters()))
# add commas to the number
print(f"{num_params:,}")

for layer_idx, layer in enumerate(model.language_model.model.layers):
    # print(f"Processing layer {layer_idx}...")
    experts_group = layer.feed_forward.experts

    components_to_prune = []
    if hasattr(experts_group, 'gate_proj') and hasattr(experts_group, 'up_proj') and hasattr(experts_group, 'down_proj'):
        components_to_prune.extend([experts_group.gate_proj, experts_group.up_proj, experts_group.down_proj])
    elif hasattr(experts_group, 'fc1') and hasattr(experts_group, 'fc2'):
        components_to_prune.extend([experts_group.fc1, experts_group.fc2])
    else:
        # print(f"Warning: Layer {layer_idx} - experts_group (type: {type(experts_group)}) does not have expected proj/fc attributes.")
        continue

    for comp_idx, component in enumerate(components_to_prune):
        if not (hasattr(component, 'weight') and hasattr(component, 'scales') and hasattr(component, 'biases')):
            # print(f"  Skipping component {comp_idx} in layer {layer_idx} (type: {type(component)}), missing one or more of: weight, scales, biases (quantization parameters).")
            continue

        # print(f"  Pruning component {comp_idx} (type: {type(component)}) in layer {layer_idx}.")
        # print(f"    Old shapes: W={component.weight.shape}, S={component.scales.shape}, B_quant={component.biases.shape}")

        component.weight = component.weight[:1]
        component.scales = component.scales[:1]
        component.biases = component.biases[:1]  # Quantization biases

        # print(f"    New shapes: W={component.weight.shape}, S={component.scales.shape}, B_quant={component.biases.shape}")

        if hasattr(component, 'bias') and component.bias is not None:
            # print(f"    Old additive bias shape: {component.bias.shape}")
            component.bias = component.bias[:1]  # Additive bias
            # print(f"    New additive bias shape: {component.bias.shape}")
        # else:
            # print(f"    Component {comp_idx} in layer {layer_idx} has no additive bias or it is None.")

    # The following lines were in the original cell for per-layer param count, kept commented
    # layer_params = sum(v.size for _, v in tree_flatten(layer.parameters()))
    # print(f"Layer {layer_idx} parameters after pruning: {layer_params:,}")
    # print(f"Finished processing layer {layer_idx}.")

num_params_after = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Total parameters after pruning: {num_params_after:,}")

16,839,459,840
Total parameters after pruning: 2,683,683,840


In [5]:
prompt = "hello"
import mlx.core as mx

prompt_tokenized = tokenizer.encode(prompt)


print(prompt_tokenized)
#response = generate(model, tokenizer, prompt=prompt, verbose=True)

prompt = mx.array(prompt_tokenized)


[200000, 25681]


In [None]:
output = model.language_model.model.embed_tokens(prompt[None])
predicted_tokens = mx.argmax(output, axis=2)
predicted_tokens = predicted_tokens[0][-1].tolist()
print(tokenizer.decode(predicted_tokens))

You are now in layer 0
LOGITS SHAPE (1, 2, 16)
LOGITS MAXED array([[[1, 0, 0, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0]]], dtype=float16)
INDICES SHAPE (1, 2, 1)
INDICES array([[[0],
        [0]]], dtype=uint32)
SCORES array([[[0.730957],
        [0.730957]]], dtype=float16)
You are now in layer 1
LOGITS SHAPE (1, 2, 16)
LOGITS MAXED array([[[1, 0, 0, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0]]], dtype=float16)
INDICES SHAPE (1, 2, 1)
INDICES array([[[0],
        [0]]], dtype=uint32)
SCORES array([[[0.730957],
        [0.730957]]], dtype=float16)
You are now in layer 2
LOGITS SHAPE (1, 2, 16)
LOGITS MAXED array([[[1, 0, 0, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0]]], dtype=float16)
INDICES SHAPE (1, 2, 1)
INDICES array([[[0],
        [0]]], dtype=uint32)
SCORES array([[[0.730957],
        [0.730957]]], dtype=float16)
You are now in layer 3
LOGITS SHAPE (1, 2, 16)
LOGITS MAXED array([[[1, 0, 0, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0]]], dtype=float16)
INDICES SHAPE 

In [None]:
y = model(prompt[None])
x = embedding_tokens
for layer in model.language_model.model.layers:
    x = layer(x)
lm_head_output = model.language_model.lm_head(x)

predicted_tokens = mx.argmax(lm_head_output, axis=2)
predicted_tokens = predicted_tokens[0][-1].tolist()
print(tokenizer.decode(predicted_tokens))

In [14]:
predicted_tokens = mx.argmax(lm_head_output, axis=2)
predicted_tokens = predicted_tokens[0][-1].tolist()
tokenizer.decode(predicted_tokens)

'_'

In [12]:
predicted_tokens[0][-1]

array(75, dtype=uint32)

In [None]:
prompt = "hello"

if tokenizer.chat_template is not None:
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True
    )
        logits = model(y[None], cache=prompt_cache)

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]

            if logits_processors:
                nonlocal tokens
                tokens = mx.concat([tokens, y]) if tokens is not None else y

                for processor in logits_processors:
                    logits = processor(tokens, logits)

            quantize_cache_fn(prompt_cache)

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs)
            return y, logprobs.squeeze(0)

