In [1]:
# Load Model
from mlx_lm import load 

import re
# mlx-community/Llama-4-Scout-17B-16E-Instruct-4bit
model, tokenizer = load("mlx-community/meta-llama-Llama-4-Scout-17B-16E-4bit",lazy=True)
print("Model loaded!!")


  from .autonotebook import tqdm as notebook_tqdm
Fetching 17 files: 100%|██████████| 17/17 [00:00<00:00, 170174.63it/s]


MoE class initialized
Model loaded!!


In [2]:
# prune experts
from mlx.utils import tree_flatten  
num_params = sum(v.size for _, v in tree_flatten(model.parameters()))
# add commas to the number
print("Active parameters:", f"end={num_params:,}")

for layer_idx, layer in enumerate(model.language_model.model.layers):
    # print(f"Processing layer {layer_idx}...")
    experts_group = layer.feed_forward.experts

    components_to_prune = []
    if hasattr(experts_group, 'gate_proj') and hasattr(experts_group, 'up_proj') and hasattr(experts_group, 'down_proj'):
        components_to_prune.extend([experts_group.gate_proj, experts_group.up_proj, experts_group.down_proj])
    elif hasattr(experts_group, 'fc1') and hasattr(experts_group, 'fc2'):
        components_to_prune.extend([experts_group.fc1, experts_group.fc2])
    else:
        # print(f"Warning: Layer {layer_idx} - experts_group (type: {type(experts_group)}) does not have expected proj/fc attributes.")
        continue

    for comp_idx, component in enumerate(components_to_prune):
        if not (hasattr(component, 'weight') and hasattr(component, 'scales') and hasattr(component, 'biases')):
            # print(f"  Skipping component {comp_idx} in layer {layer_idx} (type: {type(component)}), missing one or more of: weight, scales, biases (quantization parameters).")
            continue

        # print(f"  Pruning component {comp_idx} (type: {type(component)}) in layer {layer_idx}.")
        # print(f"    Old shapes: W={component.weight.shape}, S={component.scales.shape}, B_quant={component.biases.shape}")

        component.weight = component.weight[0:1]
        component.scales = component.scales[0:1]
        component.biases = component.biases[0:1]  # Quantization biases

        # print(f"    New shapes: W={component.weight.shape}, S={component.scales.shape}, B_quant={component.biases.shape}")

        if hasattr(component, 'bias') and component.bias is not None:
            # print(f"    Old additive bias shape: {component.bias.shape}")
            component.bias = component.bias[0:1]  # Additive bias
            # print(f"    New additive bias shape: {component.bias.shape}")
        # else:
            # print(f"    Component {comp_idx} in layer {layer_idx} has no additive bias or it is None.")

    # The following lines were in the original cell for per-layer param count, kept commented
    # layer_params = sum(v.size for _, v in tree_flatten(layer.parameters()))
    # print(f"Layer {layer_idx} parameters after pruning: {layer_params:,}")
    # print(f"Finished processing layer {layer_idx}.")

num_params_after = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Total parameters after pruning: {num_params_after:,}")

Active parameters: end=16,839,459,840
Total parameters after pruning: 2,683,683,840


In [3]:
import mlx.core as mx
mx.eval(model.parameters())
print("Evaluated model parameters. Check memory footprint now.")

Evaluated model parameters. Check memory footprint now.


In [4]:

prompt = "Hello! Please tell me a joke"

if tokenizer.chat_template is not None:
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True
    )


prompt = mx.array(prompt)
print(prompt)
print(tokenizer.decode(prompt.tolist()))

array([200000, 200005, 1556, ..., 140680, 200006, 368], dtype=int32)
<|begin_of_text|><|header_start|>user<|header_end|>

Hello! Please tell me a joke<|eot|><|header_start|>assistant<|header_end|>




In [None]:

max_tokens = 5
for _ in range(max_tokens):
    embedding_tokens = model.language_model.model.embed_tokens(prompt[None])
    x = embedding_tokens
    for layer in model.language_model.model.layers:
        x = layer(x)
    lm_head_output = model.language_model.lm_head(x)

    new_token = lm_head_output[:,-1,:].tolist()
    new_token = new_token[0].index(max(new_token[0]))
    prompt = mx.array(prompt + [new_token])
    print(tokenizer.decode(prompt.tolist()))


In [None]:
prompt

In [None]:
new_token = logits[0].index(max(logits[0]))

In [None]:
# append new token to prompt




In [None]:
new_token