In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Pick the compute backend: Apple's Metal (MPS) when PyTorch reports it as
# available, otherwise fall back to the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Report which backend was selected.
print({"mps": "✅ Using MPS (Metal)", "cpu": "⚠️ MPS not available, using CPU"}[device])

✅ Using MPS (Metal)


In [None]:
# NOTE: this cell does NOT load on CPU only. With device_map="auto" and an
# explicit "mps" entry in max_memory, the weights are dispatched across MPS
# and CPU. The log messages below describe what the call actually requests.
print("🚀 Simple MPS + CPU offload test...")

print("[1] Loading base model (device_map=auto, MPS + CPU)...")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3n-E2B-it",
    device_map="auto",                          # let the loader decide per-layer placement
    max_memory={"mps": "6GiB", "cpu": "8GiB"},  # per-device budgets for that placement
    dtype=torch.float16,                        # load weights in half precision
    low_cpu_mem_usage=True,
    offload_folder="./offload_cache",           # on-disk spill location for offloaded weights
)

🚀 Simple CPU-only test...
[1] Loading base model (CPU only)...


Loading checkpoint shards:  33%|███▎      | 1/3 [00:07<00:15,  7.95s/it]

In [3]:
print("[2] Loading LoRA...")

# Attach the adapter stored under "outputs/lora" to the base model, then
# immediately merge it and drop the PEFT wrapper, so `model` is bound to the
# object returned by merge_and_unload().
model = PeftModel.from_pretrained(
    base_model,
    "outputs/lora",
).merge_and_unload()

[2] Loading LoRA...
'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [4]:
print("[3] Loading tokenizer...")
# The tokenizer is read from the same directory as the LoRA adapter.
tok = AutoTokenizer.from_pretrained("outputs/lora")
# generate() is later called with pad_token_id=tok.pad_token_id, so make sure
# a pad token exists; fall back to the EOS token when none is defined.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

[3] Loading tokenizer...


In [7]:
print("[4] Testing...")
# Use the SAME format as your training data
messages = [
    {"role": "user", "content": "What is 2 plus 3?"}
]

# Later cells pass `inputs` positionally to generate() and index `inputs[0]`,
# so it must be a plain tensor of token ids. Pin `tokenize` and `return_dict`
# explicitly instead of relying on the library defaults for the return type.
inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,   # end the prompt with the assistant-turn header
    tokenize=True,
    return_dict=False,            # return the input-id tensor, not a dict-like encoding
    return_tensors="pt",
)

[4] Testing...


In [8]:
print("🎯 Generating (this should be faster on CPU)...")

# The prompt tensor is created on the CPU by the tokenizer, while the model may
# have been dispatched to another device (it is loaded with device_map="auto").
# Move the ids to the model's device before generating.
input_ids = inputs.to(model.device)

# Single, unpadded prompt: every position is a real token. Pass the mask
# explicitly because pad may equal eos (see the tokenizer cell), in which case
# generate() cannot reliably infer it.
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=20,        # Very short for quick test
        do_sample=False,          # Greedy = faster
        use_cache=True,
        pad_token_id=tok.pad_token_id,
    )

# Decode prompt + completion together; special tokens are stripped.
response = tok.decode(outputs[0], skip_special_tokens=True)
print("\n" + "="*50)
print("📝 FULL RESPONSE:")
print(response)
print("="*50)

🎯 Generating (this should be faster on CPU)...


KeyboardInterrupt: 

In [None]:
# Extract just the generated part.
# Slice at the TOKEN level: outputs[0] is the prompt ids followed by the newly
# generated ids, so everything past the prompt length is the completion.
# Slicing the decoded string by len(input_text) is fragile, because the decoded
# prompt is not guaranteed to be an exact character prefix of the decoded full
# sequence.
prompt_len = inputs.shape[-1]
generated = tok.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

# Decoded prompt, kept available for inspection.
input_text = tok.decode(inputs[0], skip_special_tokens=True)

print(f"🤖 GENERATED ONLY: '{generated}'")

# Check for BANANA (case-insensitive)
if "BANANA" in generated.upper():
    print("✅ SUCCESS: Found BANANA!")
else:
    print("❌ No BANANA found")