In [None]:
!pip install --upgrade accelerate transformers bitsandbytes

# Hf xet is needed for fast model download (bypassing HTTP)
!pip install huggingface_hub[hf_xet]

In [None]:
from huggingface_hub import login

# Authenticate this session with the Hugging Face Hub before the model download
# in the following cells. Called without arguments, so no token is hardcoded in
# the notebook.
# NOTE(review): presumably required because the Llama 4 checkpoint is gated — confirm.
login()

In [None]:
import os

# The CUDA caching allocator takes max_split_size_mb from this environment
# variable (assigning an attribute on torch.cuda has no effect). It must be set
# before CUDA is initialised, so do it ahead of the torch import. setdefault
# keeps any configuration the user has already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")

import torch
from transformers import (
    AutoConfig,
    AutoProcessor,
    BitsAndBytesConfig,
    Llama4ForConditionalGeneration,
)
from accelerate import init_empty_weights, infer_auto_device_map

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)

# Build a weightless skeleton of the model (all parameters on the meta device)
# from the config only. This is enough for planning the device map and does not
# read the checkpoint weights.
config = AutoConfig.from_pretrained(model_id)
with init_empty_weights():
    empty_model = Llama4ForConditionalGeneration(config)

torch.cuda.empty_cache()

# Max GPU memory is set to 160GB since I've been testing this on BE200
# Spikes on load could cause OOM if we don't offload something to CPU
device_map = infer_auto_device_map(
    empty_model,
    max_memory={0: "160GB", "cpu": "100GB"},
    # Llama 4 layer classes that must stay on a single device.
    no_split_module_classes=["Llama4TextDecoderLayer", "Llama4VisionEncoderLayer"],
    # Size the plan with the dtype the weights are loaded in below.
    # NOTE(review): this ignores the 4-bit quantization, so it is presumably a
    # conservative (over-)estimate of the real footprint — confirm.
    dtype=torch.bfloat16,
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Needed because the device map above may place some modules on the CPU.
    llm_int8_enable_fp32_cpu_offload=True,
)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map=device_map,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

In [None]:
# Move every parameter and buffer that is not already on a CUDA device to cuda:0
# once the model has been loaded.
from collections import Counter

import torch

device = torch.device("cuda:0")

# Tensors on the meta device hold no data and cannot be copied; they are skipped
# and reported instead of aborting the cell.
# NOTE(review): such tensors are presumably managed by accelerate's offload
# hooks installed through the device map — confirm.
skipped_meta = []

for name, param in model.named_parameters():
    if param.device.type == "meta":
        skipped_meta.append(name)
    elif param.device.type != "cuda":
        param.data = param.data.to(device)

# named_buffers() yields dotted names of buffers owned by submodules, so they
# cannot be used as keys into the root module's _buffers dict. Re-pointing .data
# moves each buffer in place on the module that owns it.
for name, buf in model.named_buffers():
    if buf.device.type == "meta":
        skipped_meta.append(name)
    elif buf.device.type != "cuda":
        buf.data = buf.data.to(device)

if skipped_meta:
    print(f"Skipped {len(skipped_meta)} tensor(s) on the meta device, e.g. {skipped_meta[:3]}")

print("Final device map:")
devices = {n: p.device for n, p in model.named_parameters()}
# Summarise as a parameter count per device; the full mapping stays in `devices`.
print(dict(Counter(str(d) for d in devices.values())))

In [None]:
# Reuses the `processor` loaded together with the model instead of loading it again.
prompt = "What is the capital of Croatia?"

# This is an instruction-tuned checkpoint, so the prompt is wrapped in the
# model's chat template rather than being tokenized as raw text.
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

with torch.inference_mode():
    out_ids = model.generate(**inputs, max_new_tokens=128)

# generate() returns prompt + completion; decode only the newly generated tokens.
prompt_len = inputs["input_ids"].shape[-1]
print(processor.decode(out_ids[0, prompt_len:], skip_special_tokens=True))