In [2]:
%pip install -U "transformers>=4.56,<4.58" "datasets>=2.20,<3" \
  "accelerate>=0.34.2" "peft>=0.16,<0.18" sentencepiece


Note: you may need to restart the kernel to use updated packages.


In [1]:
# CELL 0 — reset helpers
import gc, sys

# best-effort delete any lingering big vars
for name in list(globals().keys()):
    if name in ("model","base_model","policy","trainer","tok_ds","ds"):
        try: del globals()[name]
        except: pass

gc.collect()

# if torch is already imported, empty MPS cache
if "torch" in sys.modules:
    import torch
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


In [1]:
# CELL 2 — device, dtype, and hub flags
import os, torch
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"   # avoid fast-downloader requirement

if torch.backends.mps.is_available():
    device_map = {"": "mps"}; dtype = torch.float16; dev = "mps"
else:
    device_map = {"": "cpu"}; dtype = torch.float32; dev = "cpu"

print("Device:", dev, "| dtype:", dtype)


Device: mps | dtype: torch.float16


In [2]:
# CELL 3 — load tiny Alpaca subset and format + EOS
from datasets import load_dataset

alpaca_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
raw = load_dataset("json", data_files=alpaca_url, split="train[:1000]")

def to_ia_eos(ex):
    instr = ex["instruction"] + (f"\n\n{ex['input']}" if ex.get("input") else "")
    return {"text": f"Instruction: {instr}\nAnswer: {ex['output']}</s>"}  # generic EOS token; tokenizer will map

ds = raw.map(to_ia_eos, remove_columns=raw.column_names)
print("Rows:", len(ds))
print(ds[0]["text"][:200])


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Rows: 1000
Instruction: Give three tips for staying healthy.
Answer: 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. 
2. Exercise regularly to keep your body active and strong. 
3


In [4]:
# CELL 4 — tokenizer & base model
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID  = "HuggingFaceTB/SmolLM2-135M"
CACHE_DIR = "./_hf_cache_smol"

local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    local_dir=CACHE_DIR,
    allow_patterns=["*.safetensors","*.bin","*.json","*.model","tokenizer*","*merges*"],
    resume_download=True,
)

tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    low_cpu_mem_usage=True,
    device_map=device_map,
    attn_implementation="sdpa",
)
base_model.resize_token_embeddings(len(tokenizer))
print("Base on:", next(base_model.parameters()).device)


Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

Base on: mps:0


In [5]:
# CELL 5 — LoRA
from peft import LoraConfig, get_peft_model

lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1,
    bias="none", task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"],
)

model = get_peft_model(base_model, lora_cfg)
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.print_trainable_parameters()


'NoneType' object has no attribute 'cadam32bit_grad_fp32'
trainable params: 2,442,240 || all params: 136,957,248 || trainable%: 1.7832


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [6]:
# CELL 6 — tokenize with short context
MAX_LEN = 320

def tok(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=MAX_LEN)

tok_ds = ds.map(tok, batched=True, remove_columns=["text"]).shuffle(42)
len(tok_ds), tok_ds.column_names[:4]


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

(1000, ['input_ids', 'attention_mask'])

In [7]:
# CELL 7 — trainer
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
import gc
gc.collect()
if torch.backends.mps.is_available(): torch.mps.empty_cache()

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="./colab2_smollm2_lora",
    per_device_train_batch_size=2,     # tiny batch
    gradient_accumulation_steps=16,    # effective batch 32
    learning_rate=1e-4,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_steps=10,
    save_steps=200,
    save_total_limit=1,
    report_to="none",
    optim="adafactor",                 # memory-friendly
    fp16=False, bf16=False,            # leave flags False; MPS uses fp16 underneath
    torch_compile=False,
    dataloader_pin_memory=False,
    dataloader_num_workers=0,
)

trainer = Trainer(model=model, args=args, train_dataset=tok_ds, data_collator=collator)
trainer.train()


The model is already on multiple devices. Skipping the move to device specified in `args`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,2.3208
20,2.2321
30,2.1437


TrainOutput(global_step=32, training_loss=2.2209331393241882, metrics={'train_runtime': 161.7253, 'train_samples_per_second': 6.183, 'train_steps_per_second': 0.198, 'total_flos': 208599736320000.0, 'train_loss': 2.2209331393241882, 'epoch': 1.0})

In [8]:
# CELL 8 — inference helper
import time, torch

def qa_complete(prompt: str, max_new_tokens: int = 12) -> str:
    t0 = time.time()
    text = f"Instruction: {prompt}\nAnswer:"
    enc  = tokenizer(text, return_tensors="pt")
    dev  = next(model.parameters()).device
    enc  = {k: v.to(dev) for k,v in enc.items()}
    model.eval()
    with torch.inference_mode():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    dec = tokenizer.decode(out[0], skip_special_tokens=True)
    ans = dec.split("Answer:",1)[-1].split("\n",1)[0].strip()
    print(f"(gen {max_new_tokens} tok on {dev} in {time.time()-t0:.2f}s)")
    return ans

print(qa_complete("Say hello in one short sentence."))
print(qa_complete("What's 2 times 3?"))


(gen 12 tok on mps:0 in 0.83s)
I am a student.
(gen 12 tok on mps:0 in 0.39s)
6
