# Demo: Base Model vs SageMaker Fine-Tuned Model

This notebook compares the **base Llama-3.2-1B-Instruct** model (run locally via MLX) against the **SageMaker fine-tuned** model on customer support ticket triage.

## Overview
- **Base Model**: Llama-3.2-1B-Instruct (local MLX inference)
- **Fine-Tuned Model**: SageMaker endpoint with LoRA adapters
- **Task**: Convert customer tickets → internal bug reports with severity, owner, and investigation steps

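**Prerequisites:**
- SageMaker endpoint deployed (from notebook 02), with `ENDPOINT_NAME` and `INFERENCE_COMPONENT_NAME` updated in the configuration cell below
- Both the endpoint and its inference component must be `InService`
- `mlx-lm` installed locally for base-model inference (MLX runs on Apple silicon)
- Test data available at `tmp_cache_local_dataset/test_data_with_categories.jsonl`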

## 1. Setup and Configuration

In [None]:
import json
import os
import boto3
from pathlib import Path
from mlx_lm import load, generate

# Reached only if every import above succeeded (including mlx_lm, which must be installed locally).
print("Setup complete!")

In [None]:
# Configuration
MODEL_NAME = "mlx-community/Llama-3.2-1B-Instruct-bf16"

# UPDATE THESE with values from notebook 02 -- or set the SAGEMAKER_REGION,
# SAGEMAKER_ENDPOINT_NAME and SAGEMAKER_INFERENCE_COMPONENT_NAME environment
# variables to target a new deployment without editing this cell.
REGION = os.environ.get("SAGEMAKER_REGION", "eu-west-2")
ENDPOINT_NAME = os.environ.get(
    "SAGEMAKER_ENDPOINT_NAME",
    "ep-llama-3-2-1b-customer-support-2026-01-13-18-39-45-600",
)
INFERENCE_COMPONENT_NAME = os.environ.get(
    "SAGEMAKER_INFERENCE_COMPONENT_NAME",
    "ic-llama-3-2-1b-customer-support-2026-01-13-18-39-45-600",
)

# Create SageMaker clients: the "sagemaker" client is used below to check
# endpoint / inference-component status, "sagemaker-runtime" to invoke the model.
sm_client = boto3.client("sagemaker", region_name=REGION)
sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=REGION)

print(f"Base Model: {MODEL_NAME}")
print(f"SageMaker Endpoint: {ENDPOINT_NAME}")
print(f"Inference Component: {INFERENCE_COMPONENT_NAME}")

## 2. Verify SageMaker Endpoint

In [None]:
# Verify endpoint is ready.
# This cell requires both the endpoint and the inference component hosting the
# fine-tuned model to report "InService"; anything else stops the notebook here.
try:
    endpoint_desc = sm_client.describe_endpoint(EndpointName=ENDPOINT_NAME)
    ic_desc = sm_client.describe_inference_component(InferenceComponentName=INFERENCE_COMPONENT_NAME)

    endpoint_status = endpoint_desc["EndpointStatus"]
    ic_status = ic_desc["InferenceComponentStatus"]

    print(f"Endpoint status: {endpoint_status}")
    print(f"Inference component status: {ic_status}")

    if endpoint_status != "InService" or ic_status != "InService":
        # RuntimeError rather than a bare Exception; still handled by the except below.
        raise RuntimeError(f"Resources not ready. Endpoint: {endpoint_status}, IC: {ic_status}")

    print("\n✓ SageMaker endpoint ready!")
except Exception as e:
    # Print troubleshooting hints, then re-raise so "Run All" halts at this cell.
    print(f"\n❌ Error: {e}")
    print("\nMake sure you've:")
    print("1. Completed notebook 02 (deploy)")
    print("2. Updated ENDPOINT_NAME and INFERENCE_COMPONENT_NAME above")
    raise

## 3. Load Test Data

In [None]:
# Load test data (6 samples, 1 per category)
test_data_path = Path("tmp_cache_local_dataset/test_data_with_categories.jsonl")

# Read JSONL: one JSON object per non-blank line. The encoding is explicit because
# the platform default is not UTF-8 everywhere (e.g. Windows) and ticket text may
# contain non-ASCII characters.
eval_data = []
with open(test_data_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            eval_data.append(json.loads(line))

# Fail with a clear message instead of an IndexError at eval_data[0] below.
if not eval_data:
    raise ValueError(f"No test samples found in {test_data_path}")

print(f"Loaded {len(eval_data)} test samples")
print("\nCategories:")
for sample in eval_data:
    print(f"  - {sample['category']}")

# Preview first sample (user message truncated to 300 characters)
print("\nSample entry:")
print("=" * 60)
sample = eval_data[0]
user_msg = sample["messages"][0]["content"]
# Only show the ellipsis when the message was actually truncated.
ellipsis = "..." if len(user_msg) > 300 else ""
print(f"USER INPUT:\n{user_msg[:300]}{ellipsis}")

## 4. Generate Responses

Run inference on all test samples using:
- **Base model**: Local MLX inference
- **Fine-tuned model**: SageMaker endpoint

In [None]:
def invoke_sagemaker_model(messages, max_tokens=500, temperature=0.1, top_p=0.9):
    """Invoke the SageMaker fine-tuned model and return its reply text.

    Sends a JSON chat payload to ``ENDPOINT_NAME`` / ``INFERENCE_COMPONENT_NAME``
    through the module-level ``sagemaker_runtime`` client.

    Args:
        messages: Chat messages, e.g. ``[{"role": "user", "content": "..."}]``;
            passed through unchanged in the request payload.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        str: Content of the first choice's message, with surrounding whitespace stripped.

    Raises:
        Errors from ``invoke_endpoint`` propagate unchanged; a ``KeyError`` or
        ``IndexError`` is raised if the response does not have the
        ``choices[0].message.content`` shape indexed below.
    """
    payload = {
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
    }

    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        InferenceComponentName=INFERENCE_COMPONENT_NAME,
        ContentType="application/json",
        Body=json.dumps(payload),
    )

    # Assumes an OpenAI-style chat-completions response body.
    result = json.loads(response["Body"].read().decode("utf-8"))
    return result["choices"][0]["message"]["content"].strip()


def generate_base_responses(model, tokenizer, test_data, max_tokens=500):
    """Run every test sample through the local MLX base model.

    Only the first message of each sample (the customer ticket) is used. It is
    wrapped as a single user turn in the tokenizer's chat template, with the
    generation prompt appended, and passed to ``mlx_lm.generate``.

    Args:
        model: Model object as returned by ``mlx_lm.load``.
        tokenizer: Tokenizer returned alongside ``model``.
        test_data: Samples, each with a ``"messages"`` list whose first entry
            carries the user content.
        max_tokens: Per-sample generation limit forwarded to ``generate``.

    Returns:
        list: The value returned by ``generate`` for each sample, in input order.
    """
    total = len(test_data)
    outputs = []
    for position, record in enumerate(test_data, start=1):
        chat = [{"role": "user", "content": record["messages"][0]["content"]}]
        templated = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )
        outputs.append(
            generate(model, tokenizer, prompt=templated, max_tokens=max_tokens, verbose=False)
        )
        print(f"Generated {position}/{total}")
    return outputs


# Confirms this cell's function definitions executed without error.
print("Helper functions defined.")

In [None]:
# Load and run base model (local MLX).
# One print with sep="\n" emits the same three banner lines as separate prints.
print("=" * 60, "LOADING BASE MODEL (Local MLX)", "=" * 60, sep="\n")

# MODEL_NAME is a Hub repo id, so the first run may need to download the weights.
base_model, base_tokenizer = load(MODEL_NAME)

print("\nGenerating responses with base model...")
base_responses = generate_base_responses(base_model, base_tokenizer, eval_data)

print("\nBase model generation complete!")

In [None]:
# Generate responses with fine-tuned model (SageMaker endpoint)
print("=" * 60, "GENERATING WITH FINE-TUNED MODEL (SageMaker)", "=" * 60, sep="\n")

# Like the base-model run, send only the user's ticket, one request per sample.
finetuned_responses = []
for n_done, record in enumerate(eval_data, start=1):
    ticket = record["messages"][0]["content"]
    finetuned_responses.append(invoke_sagemaker_model([{"role": "user", "content": ticket}]))
    print(f"Generated {n_done}/{len(eval_data)}")

print("\nFine-tuned model generation complete!")

## 5. Compare Outputs

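Display a comparison card for each test sample: the customer ticket, then the fine-tuned and base model outputs stacked one above the other, with the expected (ground-truth) response in a collapsible section.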

In [9]:
from IPython.display import display, HTML
import html as html_module


# Shared style for every section's text body. ``white-space: pre-wrap`` keeps the
# indentation and runs of spaces in tickets and model output (plain HTML collapses
# them), while still wrapping long lines.
_COMPARISON_BODY_STYLE = (
    "font-family: 'Lexend', sans-serif; font-weight: 300; font-size: 13px; "
    "color: #212526; line-height: 1.6; white-space: pre-wrap;"
)


def _to_html_text(text):
    """Return *text* HTML-escaped, with each line break rendered as ``<br>``.

    ``None`` renders as an empty string. CRLF and lone CR line endings are
    normalised to LF first; otherwise a leftover CR would show up as an extra
    line break under ``pre-wrap``.
    """
    if text is None:
        return ""
    normalized = str(text).replace("\r\n", "\n").replace("\r", "\n")
    return html_module.escape(normalized).replace("\n", "<br>")


def _comparison_section(label, body_html, background, accent, label_color):
    """Return one card section: an uppercase *label* above the already-escaped *body_html*."""
    return f"""
        <div style="background: {background}; border-left: 6px solid {accent}; padding: 18px 22px;">
            <div style="font-weight: 500; color: {label_color}; margin-bottom: 12px; font-size: 13px; text-transform: uppercase; letter-spacing: 1px;">
                {label}
            </div>
            <div style="{_COMPARISON_BODY_STYLE}">{body_html}</div>
        </div>
"""


def styled_comparison(idx, total, user_content, expected, finetuned, base):
    """Generate styled HTML for a single comparison (Equal Experts brand).

    The card stacks the customer ticket, the fine-tuned model response and the
    base model response, followed by a collapsible ground-truth section. All text
    is HTML-escaped, so ticket or model output cannot inject markup.

    Args:
        idx: 1-based sample number shown in the header.
        total: Total number of samples shown in the header.
        user_content: Customer ticket text.
        expected: Ground-truth response text.
        finetuned: Fine-tuned (SageMaker) model response.
        base: Base (local MLX) model response.

    Returns:
        str: An HTML fragment suitable for ``IPython.display.HTML``.
    """
    # Equal Experts Brand Colors
    # Primary: EE Blue #1795D4, Secondary: Tech Blue #22567C
    # Accents: Transform Teal #269C9E, Equal Ember #F07C00
    # Neutrals: Dark Data #212526, The Cloud #F5F5F5, Byte White #FFFFFF
    sections = "".join(
        [
            _comparison_section(
                "Customer Ticket",
                _to_html_text(user_content),
                background="#FFFFFF",
                accent="#1795D4",
                label_color="#22567C",
            ),
            _comparison_section(
                "Fine-Tuned Model Response (SageMaker)",
                _to_html_text(finetuned),
                background="#F5F5F5",
                accent="#269C9E",
                label_color="#269C9E",
            ),
            _comparison_section(
                "Base Model Response (Local MLX)",
                _to_html_text(base),
                background="#FFFFFF",
                accent="#212526",
                label_color="#212526",
            ),
        ]
    )
    expected_html = _to_html_text(expected)

    return f"""
    <link href="https://fonts.googleapis.com/css2?family=Lexend:wght@300;400;500&display=swap" rel="stylesheet">
    <div style="font-family: 'Lexend', sans-serif; margin: 24px 0; overflow: hidden; box-shadow: 0 2px 8px rgba(33,37,38,0.1); border: 1px solid #E0E0E0;">
        <div style="background: #22567C; color: #FFFFFF; padding: 18px 24px; font-size: 20px; font-weight: 400;">
            Sample {idx} of {total}
        </div>
{sections}
        <details style="background: #F5F5F5; border-left: 6px solid #F07C00; padding: 18px 22px;">
            <summary style="font-weight: 500; color: #22567C; font-size: 13px; text-transform: uppercase; letter-spacing: 1px; cursor: pointer; border-radius: 4px; padding: 4px 0;">
                Expected Response (Ground Truth) — click to expand
            </summary>
            <div style="{_COMPARISON_BODY_STYLE} margin-top: 14px;">{expected_html}</div>
        </details>
    </div>
    """


# Display styled comparisons.
# Guard against stale results: if one generation cell was re-run (or failed part-way),
# the response lists no longer line up with eval_data. Stop before rendering anything.
if not (len(eval_data) == len(base_responses) == len(finetuned_responses)):
    raise ValueError(
        f"Mismatched result counts: {len(eval_data)} samples, "
        f"{len(base_responses)} base responses, "
        f"{len(finetuned_responses)} fine-tuned responses. Re-run the generation cells."
    )

for i, (sample, base_response, finetuned_response) in enumerate(
    zip(eval_data, base_responses, finetuned_responses), start=1
):
    user_content = sample["messages"][0]["content"]
    expected_output = sample["messages"][1]["content"]

    display(
        HTML(
            styled_comparison(
                i, len(eval_data), user_content, expected_output, finetuned_response, base_response
            )
        )
    )