# Fine-tune a small LM on a customer-support dataset (Local MacBook)

This notebook is adapted to run on a local MacBook (CPU or Apple MPS). By default it fine-tunes `microsoft/Phi-3-mini-4k-instruct` with LoRA adapters; set `MODEL_ID` to a smaller model on memory-constrained machines. Batch size and epochs are kept conservative so it can run locally for demonstration. If the public dataset is unavailable, the notebook falls back to a tiny synthetic dataset.

### How to use
1. (Optional) Run the first code cell (`%pip install ...`) to install or upgrade the dependencies; skip it if they are already installed.
2. Adjust settings in the `User settings` cell (model, dataset, epochs, batch_size).
3. Run cells top-to-bottom. Training on CPU/MPS is slow; expect longer runtimes.


In [None]:
%pip install --upgrade transformers datasets accelerate peft bitsandbytes trl

In [None]:
import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

# Device selection: prefer CUDA, then Apple MPS, and fall back to the CPU.
# `torch.backends.mps` is looked up with getattr so a PyTorch build without
# that attribute yields None here instead of raising.
_mps_backend = getattr(torch.backends, 'mps', None)
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if _mps_backend is not None and _mps_backend.is_available()
    else 'cpu'
)

print(f'Using device: {device}')

In [None]:
# ----------------- User settings (adjust before run) -----------------
# Each setting is read from the environment variable of the same name and
# falls back to the default given here.
_setting = os.environ.get

# String settings.
MODEL_ID = _setting('MODEL_ID', 'microsoft/Phi-3-mini-4k-instruct')
DATASET_ID = _setting('DATASET_ID', 'bitext/Bitext-customer-support-llm-chatbot-training-dataset')
# Model-specific output directory so checkpoints of different models do not mix.
OUTPUT_DIR = _setting('OUTPUT_DIR', './local_ft_output_phi3')

# Integer settings; int() raises ValueError on a non-numeric override.
EPOCHS = int(_setting('EPOCHS', '3'))
BATCH_SIZE = int(_setting('BATCH_SIZE', '1'))
MAX_LENGTH = int(_setting('MAX_LENGTH', '256'))

# Boolean flag: '1', 'true' or 'yes' (any letter case) turns LoRA on.
_TRUTHY_VALUES = ('1', 'true', 'yes')
USE_LORA = _setting('USE_LORA', 'true').lower() in _TRUTHY_VALUES

print(
    'Settings:',
    f' MODEL_ID={MODEL_ID}',
    f' DATASET_ID={DATASET_ID}',
    f' OUTPUT_DIR={OUTPUT_DIR}',
    f' EPOCHS={EPOCHS}, BATCH_SIZE={BATCH_SIZE}, MAX_LENGTH={MAX_LENGTH}, USE_LORA={USE_LORA}',
    sep='\n',
)
# --------------------------------------------------------------------

In [None]:
def safe_load_customer_dataset(dataset_id):
    """Load a customer-support dataset, with a synthetic fallback for demos.

    Args:
        dataset_id: Identifier handed unchanged to ``load_dataset``.

    Returns:
        A single dataset split. When the loader returns a mapping of splits,
        the ``'train'`` split is preferred and the first available split is
        used otherwise; a non-mapping result is returned as is. If loading
        raises, a three-row synthetic ``Dataset`` with ``customer`` and
        ``agent`` columns is returned instead.
    """
    try:
        loaded = load_dataset(dataset_id)
    except Exception as e:
        # Deliberate best-effort: any loading failure switches to demo data
        # so the rest of the notebook can still run.
        print(f'Could not load dataset {dataset_id}: {e}')
        print('Falling back to a tiny synthetic customer support dataset for demo.')
        samples = [
            {'customer': "My order hasn't arrived, it's been 10 days.", 'agent': "I'm sorry. Can you share your order id?"},
            {'customer': 'I was charged twice for the same order.', 'agent': "I can help. Please share the transaction id."},
            {'customer': 'How do I return an item?', 'agent': "You can start a return from your orders page."},
        ]
        return Dataset.from_list(samples)

    if isinstance(loaded, dict):
        if 'train' in loaded:
            return loaded['train']
        if loaded:
            # Returning the whole mapping would make len() count splits and
            # break row indexing downstream, so pick one split explicitly.
            first_split = next(iter(loaded))
            print(f"Dataset {dataset_id} has no 'train' split; using '{first_split}' instead.")
            return loaded[first_split]
    return loaded

def build_prompt(row):
    """Render one dataset row as a training prompt string.

    Recognised column layouts, checked in this order:

    1. ``customer`` / ``agent``
    2. ``input`` / ``output``
    3. ``text`` (returned with a trailing newline appended)
    4. ``instruction`` / ``response`` (the layout of the default Bitext
       customer-support dataset)

    Question/answer pairs are formatted as a ``Human:`` line followed by an
    ``Assistant:`` line, each terminated by a newline.

    Args:
        row: Mapping from column name to value for a single example.

    Returns:
        The prompt string, or ``str(row)`` when no known layout matches.
    """
    def as_dialogue(question, answer):
        return f"Human: {question}\nAssistant: {answer}\n"

    if 'customer' in row and 'agent' in row:
        return as_dialogue(row['customer'], row['agent'])
    if 'input' in row and 'output' in row:
        return as_dialogue(row['input'], row['output'])
    if 'text' in row:
        return row['text'] + "\n"
    # Checked after 'text' so rows that already matched an earlier layout
    # keep producing exactly the same prompt as before.
    if 'instruction' in row and 'response' in row:
        return as_dialogue(row['instruction'], row['response'])
    return str(row)

print('Helper functions defined')

In [None]:
raw_ds = safe_load_customer_dataset(DATASET_ID)
print(f'Loaded dataset size: {len(raw_ds)} (showing first 2 examples)')
# Slicing a `datasets.Dataset` (raw_ds[:2]) returns a dict of columns, so
# iterating the slice would walk column names. Select rows instead so each
# `ex` is one example.
for i, ex in enumerate(raw_ds.select(range(min(2, len(raw_ds))))):
    print('\n--- example', i, '---')
    print(ex)


# Map to text prompts
def map_to_prompt(example):
    """Add a 'text' column holding the prompt rendered from one example."""
    return {'text': build_prompt(example)}


ds = raw_ds.map(map_to_prompt)

# Split and reduce for local run
MAX_TRAIN_EXAMPLES = 4096
MAX_EVAL_EXAMPLES = 128
if len(ds) > 2000:
    split = ds.train_test_split(test_size=0.05, shuffle=True, seed=42)
    # Cap the subset sizes, but never ask for more rows than a split holds:
    # select() raises on out-of-range indices.
    train_ds = split['train'].select(range(min(MAX_TRAIN_EXAMPLES, len(split['train']))))
    eval_ds = split['test'].select(range(min(MAX_EVAL_EXAMPLES, len(split['test']))))
else:
    split = ds.train_test_split(test_size=0.1, seed=42)
    train_ds = split['train']
    eval_ds = split['test']

print(f'Train size: {len(train_ds)}, Eval size: {len(eval_ds)}')

In [None]:
# ----------------- Tokenizer and base model loading -----------------

# bitsandbytes quantization is deliberately not used: it requires CUDA, which
# is not available on Apple MPS. Half-precision weights are the memory-saving
# measure instead.

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # The tokenizer ships without a pad token: register a dedicated one.
    tokenizer.add_special_tokens({'pad_token': '<|padding|>'})

# Keyword arguments used to load the base weights.
_model_load_kwargs = {
    'dtype': torch.float16,            # half precision to reduce weight memory
    'trust_remote_code': True,
    'ignore_mismatched_sizes': True,   # kept for the embedding resize below
}
# NOTE(review): float16 is requested regardless of `device`, CPU included.
# Confirm this is intended for CPU-only runs.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_load_kwargs)

# Align the embedding rows with the tokenizer size (which includes the pad
# token when one was added above). This happens before the LoRA adapter is
# attached in the following cell. The model is intentionally not moved with
# model.to(device) here; per the original note, the Trainer handles placement.
model.resize_token_embeddings(len(tokenizer))

print('Tokenizing datasets...')


def tokenize_for_lm(examples):
    """Tokenize a batch of prompts and attach causal-LM labels.

    Args:
        examples: Batch mapping with a 'text' entry holding the prompts.

    Returns:
        The tokenizer output, truncated and padded to MAX_LENGTH, with a
        'labels' entry that is a copy of 'input_ids'.
    """
    encoded = tokenizer(
        examples['text'],
        truncation=True,
        max_length=MAX_LENGTH,
        padding='max_length',
    )
    encoded['labels'] = list(encoded['input_ids'])
    return encoded


# Tokenize the train split first, then the eval split, dropping the original
# columns so only the tokenizer outputs remain.
train_tok, eval_tok = (
    subset.map(tokenize_for_lm, batched=True, remove_columns=subset.column_names)
    for subset in (train_ds, eval_ds)
)

from transformers import DataCollatorForLanguageModeling

# mlm=False selects the causal (next-token) objective rather than masked LM.
# NOTE(review): this collator appears to rebuild 'labels' from 'input_ids'
# itself, so the 'labels' column above may be redundant — confirm.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

print('Tokenization complete')

In [None]:
# ----------------- PEFT / LoRA configuration -----------------
# `use_peft` records whether a LoRA adapter was actually attached. It is only
# set once get_peft_model() has succeeded, so it stays False when LoRA is
# disabled, PEFT is not installed, or adapter creation fails.
use_peft = False
if USE_LORA:
    try:
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=16,
            lora_alpha=16,
            # "all-linear" lets PEFT target every linear layer except the
            # output head. A hand-written Llama-style list (q_proj, k_proj,
            # v_proj, gate_proj, up_proj, ...) misses Phi-3's fused
            # `qkv_proj` and `gate_up_proj` projections, and would not fit
            # other values of MODEL_ID either.
            target_modules="all-linear",
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
    except Exception as e:
        # Best-effort: training continues on the plain model without LoRA.
        print(f'PEFT/LoRA unavailable or failed: {e}. Continuing without LoRA.')
    else:
        use_peft = True
        print('Applied LoRA adapter to model (PEFT).')

print('use_peft =', use_peft)

In [None]:
# fp16 mixed precision needs an accelerator; requesting it on a CPU-only
# machine is rejected during trainer setup, so it is switched off there.
# Behaviour on CUDA and MPS is unchanged.
use_fp16_amp = device != 'cpu'

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=10,
    save_total_limit=2,   # keep only the two most recent checkpoints on disk
    fp16=use_fp16_amp,
    remove_unused_columns=False,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    data_collator=data_collator,
)

print('Starting training... (this may be slow on CPU/MPS)')
trainer.train()

trainer.save_model(OUTPUT_DIR)
# Save the tokenizer next to the weights explicitly: the inference cell loads
# it from OUTPUT_DIR, and the Trainer was not given the tokenizer directly.
tokenizer.save_pretrained(OUTPUT_DIR)
print(f'Saved fine-tuned model to {OUTPUT_DIR}')

In [None]:
# --- Dynamic checkpoint selection and setup ---
import os
import glob
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

TOKENIZER_PATH = OUTPUT_DIR
# Use the base model that was actually fine-tuned. A hard-coded id would
# silently pair the adapter with the wrong base whenever MODEL_ID is changed.
BASE_MODEL_ID = MODEL_ID

# Determine the device: CUDA first, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available() else "cpu")

# 1. List all checkpoint directories within the output folder. Plain files
#    that happen to match the pattern are skipped.
checkpoint_dirs = [
    path
    for path in glob.glob(os.path.join(OUTPUT_DIR, "checkpoint-*"))
    if os.path.isdir(path)
]

if not checkpoint_dirs:
    PEFT_CHECKPOINT = OUTPUT_DIR
    print(f"Warning: No explicit checkpoints found. Using base path: {OUTPUT_DIR}")
else:
    # 2. Extract the step number from the directory name and find the max.
    def extract_step(path):
        """Return the step of a ``checkpoint-<step>`` directory, or 0 if none.

        Only the final path component is inspected, so a ``checkpoint-<n>``
        fragment elsewhere in OUTPUT_DIR cannot be mistaken for the step.
        """
        dir_name = os.path.basename(os.path.normpath(path))
        match = re.search(r'checkpoint-(\d+)', dir_name)
        return int(match.group(1)) if match else 0

    latest_checkpoint_dir = max(checkpoint_dirs, key=extract_step)
    PEFT_CHECKPOINT = latest_checkpoint_dir
    print(f"Automatically selected latest checkpoint: {PEFT_CHECKPOINT}")

# --- Model loading ---

# 1) Load the tokenizer saved with the fine-tuned model so the same vocabulary
#    is used.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
print("tokenizer vocab size:", len(tokenizer))

# A PEFT adapter directory contains `adapter_config.json`. A checkpoint
# produced without LoRA (USE_LORA disabled, or LoRA setup failed) holds a
# full model instead, which PeftModel.from_pretrained cannot load.
if os.path.isfile(os.path.join(PEFT_CHECKPOINT, "adapter_config.json")):
    # 2) Load the original base model, then resize its embeddings to the
    #    tokenizer length so it matches the model the adapter was trained on.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        ignore_mismatched_sizes=True,
        dtype=torch.float32,
        trust_remote_code=True,
    )
    print("base model embedding rows before resize:", base.get_input_embeddings().weight.shape[0])

    base.resize_token_embeddings(len(tokenizer))
    print("base model embedding rows after resize:", base.get_input_embeddings().weight.shape[0])

    # 3) Load the PEFT adapter on top of the resized base model.
    #    (`torch_dtype` is not a PeftModel.from_pretrained parameter, so the
    #    former `torch_dtype=None` argument was dropped.)
    model = PeftModel.from_pretrained(base, PEFT_CHECKPOINT)
else:
    print(f"No adapter_config.json in {PEFT_CHECKPOINT}; loading it as a full fine-tuned model.")
    model = AutoModelForCausalLM.from_pretrained(
        PEFT_CHECKPOINT,
        dtype=torch.float32,
        trust_remote_code=True,
    )
model.to(device)

# --- Inference test ---

# Reuse EOS as the padding token. generate() does not read the tokenizer, so
# the same id is also passed to it explicitly below.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Prompt in the "Human: ... / Assistant:" layout, with an instruction prefix.
prompt = "Instruction: You are a friendly customer support assistant. Provide a concise answer to the user's request.\nHuman: I have not received my refund after 10 days. What can I do?\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
model.eval()

# Upper bound on generated tokens, not counting the prompt.
MAX_NEW_TOKENS = 100

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,   # same limit as prompt length + 100 via max_length
        do_sample=True,
        top_p=0.9,
        top_k=0,                         # 0 disables top-k so only nucleus (top-p) filtering applies
        repetition_penalty=1.2,
        pad_token_id=tokenizer.pad_token_id,
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))