# Compare: SageMaker Fine-Tuned vs MLX Base Model
---

This notebook performs a **side-by-side comparison** between:

1. **Base Model (Local MLX)**: `gemma-3-270m-it-bf16` running on Apple Silicon
2. **Fine-Tuned Model (SageMaker)**: Model fine-tuned on financial sentiment data

**Purpose**: Evaluate how fine-tuning improves sentiment classification accuracy.

**Prerequisites**:
- Run notebook 03 to deploy the fine-tuned model endpoint (keep it running)
- Have the endpoint and inference component names ready

---

**Test Data**: 600 samples from `tmp_cache_local_dataset/test_data.jsonl`  
**Labels**: positive, negative, neutral, bullish, bearish

## 1. Setup and Dependencies

In [None]:
import json
import os
import time
import boto3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

print("Setup complete!")

In [None]:
# -----------------------------------------------------------------------------
# CONFIGURATION
# Everything a reader may need to tune lives in this cell. The SageMaker names
# must match the deployment created in notebook 03.
# -----------------------------------------------------------------------------

# Remote fine-tuned model: the endpoint and inference component names are both
# derived from MODEL_NAME, so only MODEL_NAME needs editing.
MODEL_NAME = "gemma-3-270m-sentiment-2026-01-09-07-37-13-218"  # UPDATE THIS
SAGEMAKER_ENDPOINT_NAME = "ep-" + MODEL_NAME
INFERENCE_COMPONENT_NAME = "ic-" + MODEL_NAME
AWS_REGION = "eu-west-2"

# Local base model identifier passed to mlx_lm's load()
MLX_MODEL_ID = "mlx-community/gemma-3-270m-it-bf16"

# JSONL test set, resolved relative to the current working directory
TEST_DATA_PATH = os.path.join(os.getcwd(), "tmp_cache_local_dataset", "test_data.jsonl")

# Evaluation settings. The generation settings below are sent to both backends.
MAX_SAMPLES = None  # None -> evaluate every sample in the file
MAX_TOKENS = 32
TEMPERATURE = 0.1
TOP_P = 0.9

# Echo the effective configuration so it is recorded in the cell output
print(
    f"SageMaker Endpoint: {SAGEMAKER_ENDPOINT_NAME}\n"
    f"MLX Model: {MLX_MODEL_ID}\n"
    f"Test Data: {TEST_DATA_PATH}\n"
    f"Max Samples: {MAX_SAMPLES or 'All'}"
)

## 2. Load Models

In [None]:
# Load the base model and its tokenizer with mlx_lm.
# Both names are used by generate_local_mlx() further down.
print(
    f"Loading local model: {MLX_MODEL_ID}\n"
    "This may take a moment on first run (downloading model weights)..."
)

mlx_model, mlx_tokenizer = load(MLX_MODEL_ID)

print("Local MLX model loaded successfully!")

In [None]:
# Runtime client used by generate_sagemaker() to invoke the deployed endpoint.
# Credentials are resolved by boto3; only the region is set explicitly here.
sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=AWS_REGION)

print(
    f"SageMaker client configured for region: {AWS_REGION}\n"
    f"Endpoint: {SAGEMAKER_ENDPOINT_NAME}\n"
    f"Inference Component: {INFERENCE_COMPONENT_NAME}"
)

## 3. Inference Functions (Like-for-Like Comparison)

Both inference functions use:
- **Same input format**: messages list from test data
- **Same label extraction**: shared `extract_label()` function
- **Same valid labels**: positive, negative, neutral, bullish, bearish

In [None]:
def generate_local_mlx(messages: list) -> str:
    """
    Generate a raw text response from the local MLX base model.

    Args:
        messages: List of message dicts with "role" and "content" keys.
            Entries whose role is "assistant" are dropped before prompting,
            so a full test sample (including its expected answer) can be
            passed in directly.

    Returns:
        The generated text with leading/trailing whitespace removed. No label
        extraction is performed here; the caller is responsible for mapping
        the text to a label.
    """
    # Never show the expected answer to the model
    input_messages = [m for m in messages if m["role"] != "assistant"]

    # Render the conversation into a single prompt string that ends with the
    # generation prompt, so the model continues as the assistant.
    prompt = mlx_tokenizer.apply_chat_template(
        input_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Same temperature / top_p as the SageMaker request payload
    sampler = make_sampler(temp=TEMPERATURE, top_p=TOP_P)

    response = generate(
        mlx_model,
        mlx_tokenizer,
        prompt=prompt,
        max_tokens=MAX_TOKENS,
        verbose=False,
        sampler=sampler
    )

    # Strip surrounding whitespace so the output is post-processed exactly
    # like the SageMaker response; otherwise a stray newline makes an
    # otherwise-correct answer fail an exact string comparison.
    return response.strip()

print("MLX inference function defined.")

In [None]:
def generate_sagemaker(messages: list) -> str:
    """
    Generate a raw text response from the model behind the SageMaker endpoint.

    Args:
        messages: List of message dicts with "role" and "content" keys.
            Entries whose role is "assistant" are removed before the request
            is sent.

    Returns:
        The "content" of the first choice in the endpoint's JSON response,
        with leading/trailing whitespace removed.
    """
    # Chat-completions style request body; the expected answer is filtered out
    request_body = json.dumps({
        "messages": [msg for msg in messages if msg["role"] != "assistant"],
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_tokens": MAX_TOKENS,
    })

    raw = sagemaker_runtime.invoke_endpoint(
        EndpointName=SAGEMAKER_ENDPOINT_NAME,
        InferenceComponentName=INFERENCE_COMPONENT_NAME,
        ContentType="application/json",
        Body=request_body,
    )

    # The response body is a stream of UTF-8 encoded JSON
    parsed = json.loads(raw["Body"].read().decode("utf-8"))
    first_choice = parsed["choices"][0]

    return first_choice["message"]["content"].strip()

print("SageMaker inference function defined.")

## 4. Run Evaluation Loop

Run both models on the test dataset and collect predictions.

In [None]:
# Load test data: one JSON object per line (JSONL).
# - The file is read explicitly as UTF-8 so the result does not depend on the
#   platform's default encoding.
# - Whitespace-only lines (e.g. a trailing blank line) are skipped instead of
#   being handed to json.loads, which would raise on them.
with open(TEST_DATA_PATH, "r", encoding="utf-8") as f:
    test_samples = [json.loads(line) for line in f if line.strip()]

# Apply sample limit if set (a falsy MAX_SAMPLES keeps every sample)
if MAX_SAMPLES:
    test_samples = test_samples[:MAX_SAMPLES]

print(f"Loaded {len(test_samples)} test samples")

In [None]:
# Run evaluation on both models

# Labels a prediction is allowed to resolve to
VALID_LABELS = {'positive', 'negative', 'neutral', 'bullish', 'bearish'}


def extract_label(response: str) -> str:
    """
    Map a raw model response onto one of VALID_LABELS.

    The response is lower-cased and split into alphabetic words; the first
    word that is a valid label wins. This lets answers such as "Positive",
    "positive." or "Sentiment: bearish" count as their label instead of
    failing an exact string comparison.

    Args:
        response: Raw text returned by a model.

    Returns:
        The first valid label found in the response, or "unknown" when the
        response contains none.
    """
    # Replace every non-letter with a space so punctuation cannot glue itself
    # to a label (e.g. "positive." or "**neutral**").
    words = "".join(ch if ch.isalpha() else " " for ch in response.lower()).split()
    return next((word for word in words if word in VALID_LABELS), "unknown")


ground_truth = []
mlx_predictions = []
sagemaker_predictions = []
texts = []
# Raw responses are kept alongside the extracted labels so that "unknown"
# predictions can be inspected afterwards.
mlx_raw_responses = []
sagemaker_raw_responses = []

print(f"Running inference on {len(test_samples)} samples...")
print("This will run each sample through both MLX (local) and SageMaker (remote).\n")

for i, sample in enumerate(tqdm(test_samples, desc="Evaluating")):
    messages = sample["messages"]

    # Extract ground truth (the last message is taken to be the assistant's answer)
    expected = messages[-1]["content"].strip().lower()
    ground_truth.append(expected)

    # Extract text for display: take the last user message rather than a fixed
    # index, so samples without a system message do not display the label.
    user_text = next(
        (m["content"] for m in reversed(messages) if m["role"] == "user"),
        ""
    )
    texts.append(user_text)

    # Get MLX prediction; a failure is recorded as "error" so one bad sample
    # does not abort the whole run.
    try:
        mlx_raw = generate_local_mlx(messages)
        mlx_pred = extract_label(mlx_raw)
    except Exception as e:
        mlx_raw, mlx_pred = "", "error"
        print(f"\nMLX error on sample {i}: {e}")
    mlx_raw_responses.append(mlx_raw)
    mlx_predictions.append(mlx_pred)

    # Get SageMaker prediction, post-processed exactly like the MLX one
    try:
        sm_raw = generate_sagemaker(messages)
        sm_pred = extract_label(sm_raw)
    except Exception as e:
        sm_raw, sm_pred = "", "error"
        print(f"\nSageMaker error on sample {i}: {e}")
    sagemaker_raw_responses.append(sm_raw)
    sagemaker_predictions.append(sm_pred)

    # Small delay for SageMaker rate limiting
    time.sleep(0.1)

print(f"\nCompleted: {len(ground_truth)} samples evaluated")

## 5. Results Comparison

Compare accuracy, classification metrics, and confusion matrices for both models.

In [None]:
# Per-sample correctness flags: True where the prediction equals the expected label
mlx_correct = [truth == pred for truth, pred in zip(ground_truth, mlx_predictions)]
sm_correct = [truth == pred for truth, pred in zip(ground_truth, sagemaker_predictions)]

# Accuracy = fraction of True flags
mlx_accuracy = sum(mlx_correct) / len(mlx_correct)
sm_accuracy = sum(sm_correct) / len(sm_correct)

# Positive when the SageMaker model scored higher than the MLX base model
improvement = sm_accuracy - mlx_accuracy

# Build the summary table line by line, then print it in one go
summary_lines = [
    "=" * 60,
    "ACCURACY COMPARISON",
    "=" * 60,
    f"{'Model':<30} {'Accuracy':>15} {'Correct':>12}",
    "-" * 60,
    f"{'MLX Base (Local)':<30} {mlx_accuracy:>14.2%} {sum(mlx_correct):>8}/{len(ground_truth)}",
    f"{'SageMaker Fine-Tuned':<30} {sm_accuracy:>14.2%} {sum(sm_correct):>8}/{len(ground_truth)}",
    "-" * 60,
]

if improvement == 0:
    summary_lines.append(f"{'No change in accuracy':<30}")
elif improvement > 0:
    summary_lines.append(f"{'Fine-tuning improvement:':<30} {improvement:>+14.2%}")
else:
    summary_lines.append(f"{'Fine-tuning change:':<30} {improvement:>+14.2%}")

summary_lines.append("=" * 60)
print("\n".join(summary_lines))

In [None]:
# One row per evaluated sample: the input text, the expected label, and each
# model's prediction together with whether it matched.
results_df = pd.DataFrame(
    dict(
        text=list(texts),
        expected=ground_truth,
        base_predicted=mlx_predictions,
        sagemaker_predicted=sagemaker_predictions,
        base_correct=mlx_correct,
        sagemaker_correct=sm_correct,
    )
)

results_df

In [None]:
VALID_LABELS = {'positive', 'negative', 'neutral', 'bullish', 'bearish'}

# Report rows are restricted to the valid labels, in alphabetical order
labels = sorted(VALID_LABELS)

# Same report layout for both models: banner, title, banner, metrics table
for report_title, model_preds in (
    ("CLASSIFICATION REPORT: MLX Base Model (Local)", mlx_predictions),
    ("CLASSIFICATION REPORT: SageMaker Fine-Tuned Model", sagemaker_predictions),
):
    print("\n" + "=" * 60)
    print(report_title)
    print("=" * 60)
    # zero_division=0 reports 0 for a label with no predictions instead of warning
    print(classification_report(ground_truth, model_preds, labels=labels, zero_division=0))

In [None]:
# Side-by-side confusion matrices
# Axis labels cover every value seen in the ground truth or in either model's
# predictions, so outputs outside the expected label set get their own row/column.
all_labels = sorted(set(ground_truth) | set(mlx_predictions) | set(sagemaker_predictions))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

cm_mlx = confusion_matrix(ground_truth, mlx_predictions, labels=all_labels)
cm_sm = confusion_matrix(ground_truth, sagemaker_predictions, labels=all_labels)

# (axis, matrix, colour map, title) for each panel, left to right
panels = (
    (axes[0], cm_mlx, 'Blues', f'MLX Base Model (Local)\nAccuracy: {mlx_accuracy:.2%}'),
    (axes[1], cm_sm, 'Greens', f'SageMaker Fine-Tuned\nAccuracy: {sm_accuracy:.2%}'),
)

for ax, matrix, palette, title in panels:
    sns.heatmap(matrix, annot=True, fmt='d', cmap=palette,
                xticklabels=all_labels, yticklabels=all_labels, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(title)

plt.tight_layout()
plt.show()