# Deploy and Evaluate Fine-Tuned Gemma-3-4B for Financial Sentiment
---
This notebook deploys the fine-tuned **Gemma-3-4B-IT** model from notebook 2 and evaluates it on the held-out test set.

**Prerequisites:**
- Run notebook 1 (`01_data_analysis.ipynb`) to generate training data
- Run notebook 2 (`02_finetune_training.ipynb`) to fine-tune the model
- Training job must be completed successfully

**What this notebook does:**
1. Extract the model artifacts (`model.tar.gz`) into an uncompressed S3 prefix
2. Deploy model to SageMaker endpoint using DJL LMI (vLLM)
3. Evaluate on 200 test samples from `test_data.jsonl`
4. Calculate accuracy, precision, recall, F1-score
5. Cleanup resources

---

**Model:** [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)  
**Instance:** ml.g5.2xlarge (1x A10G, 24GB VRAM)

## 1. Setup and Dependencies

In [None]:
import boto3
import sagemaker
from sagemaker.session import Session
import json
import time
import os
import pandas as pd
from getpass import getpass
from tqdm import tqdm

In [None]:
# Configuration
# Account-specific settings; edit these to match your own environment and the
# training job produced by notebook 02.
region = "eu-west-2"
# IAM role passed as ExecutionRoleArn when the model and endpoint config are created.
role = "arn:aws:iam::889772146711:role/SageMakerExecutionRole"
s3_bucket = "sagemaker-eu-west-2-889772146711"
training_job_name = "google--gemma-3-4b-it-sentiment-finetune-20260107185405"

# S3 paths (matches notebook 02 output structure)
base_job_name = "google--gemma-3-4b-it-sentiment-finetune"
# Full s3:// URI of the compressed training output.
s3_model_uri = f"s3://{s3_bucket}/{base_job_name}/{training_job_name}/output/model.tar.gz"
# Key prefix only (no scheme or bucket): used with Bucket=/Key= style S3 calls.
s3_uncompressed_prefix = f"{base_job_name}/{training_job_name}/output/uncompressed/"
# Same location as a full s3:// URI: used as the model data source for the endpoint.
s3_model_prefix = f"s3://{s3_bucket}/{s3_uncompressed_prefix}"

# Create clients
# `sess` is the SageMaker SDK session (used to wait for the endpoint); the
# boto3 clients cover control-plane, S3 and runtime (invoke) calls respectively.
sess = Session(boto3.Session(region_name=region))
sm_client = boto3.client("sagemaker", region_name=region)
s3_client = boto3.client("s3", region_name=region)
sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=region)

print(f"Training job: {training_job_name}")
print(f"Model artifacts: {s3_model_uri}")
print(f"Uncompressed path: {s3_model_prefix}")

In [None]:
# HuggingFace token (required for Gemma gated model).
# Prompted for interactively so the secret is never written into the notebook
# source. It is used later to download preprocessor_config.json and is passed
# to the serving container as the HF_TOKEN environment variable.
hf_token = getpass("Enter HuggingFace token: ")

## 2. Prepare Model Artifacts

Extract `model.tar.gz` into an uncompressed S3 prefix (streamed, with no local download), then add the missing `preprocessor_config.json` that vLLM requires.

In [None]:
%%time
import tarfile
from smart_open import open as smart_open
from huggingface_hub import hf_hub_download

# Check if already extracted: any object under the model sub-prefix is taken
# to mean a previous run of this cell already unpacked the archive.
response = s3_client.list_objects_v2(
    Bucket=s3_bucket,
    Prefix=s3_uncompressed_prefix + "google/gemma-3-4b-it/",
    MaxKeys=5,
)
already_extracted = response.get("KeyCount", 0) > 0

# Extract if needed
if not already_extracted:
    print("Extracting model.tar.gz to S3 (streaming - no local download)...")
    print("This may take 5-10 minutes.")

    file_count = 0
    total_bytes = 0

    # The archive is read as a stream from S3 and each regular file inside it
    # is written back to S3 under the uncompressed prefix.
    with smart_open(s3_model_uri, "rb") as f:
        with tarfile.open(fileobj=f, mode="r:*") as tar:
            for member in tar:
                # Directories, links, etc. have no payload to upload.
                if not member.isfile():
                    continue
                file_obj = tar.extractfile(member)
                if file_obj is None:
                    continue

                # upload_fileobj reads the member in chunks and switches to a
                # multipart upload for large objects. This avoids holding a
                # whole weight shard in memory and the size ceiling of a
                # single put_object request.
                with file_obj:
                    s3_client.upload_fileobj(
                        file_obj,
                        s3_bucket,
                        s3_uncompressed_prefix + member.name,
                    )

                file_count += 1
                # The tar header already records the payload size.
                total_bytes += member.size
                if file_count % 10 == 0:
                    print(f"  Extracted {file_count} files ({total_bytes / 1e9:.2f} GB)...")

    print(f"Extraction complete: {file_count} files, {total_bytes / 1e9:.2f} GB")
else:
    print("Model already extracted. Skipping.")

# Add missing preprocessor_config.json (required for vLLM with Gemma 3 VLM).
# The file is fetched from the base model repo on the HuggingFace Hub and
# copied next to the extracted weights in S3.
print("\nAdding preprocessor_config.json from HuggingFace...")
config_path = hf_hub_download(
    repo_id="google/gemma-3-4b-it",
    filename="preprocessor_config.json",
    token=hf_token,
)

# Read the downloaded file fully, then upload it in a single request.
with open(config_path, "rb") as config_file:
    config_bytes = config_file.read()

s3_client.put_object(
    Bucket=s3_bucket,
    Key=s3_uncompressed_prefix + "google/gemma-3-4b-it/preprocessor_config.json",
    Body=config_bytes,
)
print(f"Model artifacts ready at: {s3_model_prefix}")

## 3. Create SageMaker Resources

Deploy using DJL LMI container with vLLM backend.

In [None]:
# Resource names
# All SageMaker resource names derive from one base name so they can be
# matched up (and cleaned up) together.
# NOTE(review): name_from_base presumably appends a unique suffix, so
# re-running this cell would yield new names - confirm against the SDK docs.
model_name = sagemaker.utils.name_from_base("gemma-3-4b-sentiment")
endpoint_config_name = f"epc-{model_name}"
endpoint_name = f"ep-{model_name}"
inference_component_name = f"ic-{model_name}"
variant_name = "AllTraffic"
instance_type = "ml.g5.2xlarge"  # 1x A10G (24GB)
# Number of accelerator devices requested for the inference component.
num_gpu = 1

print(f"Model: {model_name}")
print(f"Endpoint: {endpoint_name}")
print(f"Instance: {instance_type}")

In [None]:
# Create Model
# Serving image for the DJL LMI container, pulled from the registry in `region`.
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.34.0-lmi16.0.0-cu128"

# Environment handed to the serving container; every value is a string.
# NOTE(review): HF_TOKEN is stored in plain text on the model resource - confirm
# this is acceptable for the account's security policy.
container_env = {
    "HF_TOKEN": hf_token,
    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
    "MESSAGES_API_ENABLED": "true",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "8",
    "OPTION_MODEL_LOADING_TIMEOUT": "1500",
    "SERVING_FAIL_FAST": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_ENABLE_STREAMING": "true",
    "MAX_TOTAL_TOKENS": "4096",
}

# The artifacts are loaded from the uncompressed S3 prefix rather than from a
# model.tar.gz, hence S3Prefix with no compression.
model_data_source = {
    "S3DataSource": {
        "S3Uri": s3_model_prefix,
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}

sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image,
        "ModelDataSource": model_data_source,
        "Environment": container_env,
    },
)
print(f"Created model: {model_name}")

In [None]:
# Create Endpoint Configuration
# A single production variant pinned to exactly one instance. The variant names
# no model here; the model is attached later through an inference component.
production_variant = {
    "VariantName": variant_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    # Generous one-hour limits for downloading the weights and for the
    # container to pass its first health check.
    "ModelDataDownloadTimeoutInSeconds": 3600,
    "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
    "ManagedInstanceScaling": {"Status": "ENABLED", "MinInstanceCount": 1, "MaxInstanceCount": 1},
    "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
}

sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ExecutionRoleArn=role,
    ProductionVariants=[production_variant],
)
print(f"Created endpoint config: {endpoint_config_name}")

In [None]:
# Deploy Endpoint
# Creating the endpoint only starts provisioning; the call returns immediately.
sm_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)
print(f"Creating endpoint: {endpoint_name}")
print("Waiting for endpoint (this may take several minutes)...")

# Block until the SDK session reports the endpoint is ready. The model itself
# is attached afterwards as an inference component.
sess.wait_for_endpoint(endpoint_name)
print(f"Endpoint {endpoint_name} is InService!")

In [None]:
# Create Inference Component
# Resources reserved on the endpoint instance for one copy of the model.
component_resources = {
    "NumberOfAcceleratorDevicesRequired": num_gpu,
    "NumberOfCpuCoresRequired": 1,
    "MinMemoryRequiredInMb": 1024,
}

sm_client.create_inference_component(
    InferenceComponentName=inference_component_name,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification={
        "ModelName": model_name,
        "ComputeResourceRequirements": component_resources,
    },
    RuntimeConfig={"CopyCount": 1},
)
print(f"Creating inference component: {inference_component_name}")
print("Waiting for inference component (this may take ~10 minutes)...")

# Poll every 30 seconds until the component reaches a terminal state.
# NOTE(review): there is no upper bound on the wait; a component that never
# leaves a non-terminal state keeps this cell running until interrupted.
start_time = time.time()
terminal_states = ("InService", "Failed")
status = None
while status not in terminal_states:
    if status is not None:
        # The previous poll was non-terminal: report it, then back off.
        print(f"  Status: {status}")
        time.sleep(30)
    desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name)
    status = desc["InferenceComponentStatus"]

if status == "Failed":
    raise Exception(f"Inference component failed: {desc.get('FailureReason', 'Unknown')}")

print(f"\nInference component ready! ({(time.time() - start_time)/60:.1f} minutes)")

## 4. Test Inference

Quick sanity check before running full evaluation.

In [None]:
def invoke_model(messages, max_tokens=32, temperature=0.1):
    """Send a chat request to the deployed inference component.

    Args:
        messages: List of chat messages (dicts with "role" and "content").
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature sent with the request.

    Returns:
        The content of the first returned choice, stripped of surrounding
        whitespace and lower-cased.
    """
    request_body = json.dumps(
        {
            "messages": messages,
            "temperature": temperature,
            "top_p": 0.9,
            "max_tokens": max_tokens,
        }
    )

    raw_response = sagemaker_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        InferenceComponentName=inference_component_name,
        ContentType="application/json",
        Body=request_body,
    )

    # The response body is a stream holding a chat-completions style JSON document.
    completion = json.loads(raw_response["Body"].read().decode("utf-8"))
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"].strip().lower()

In [None]:
# Quick test
# One hand-written example: a system prompt that restricts the reply to a
# single label, followed by a user message with clearly positive news.
# NOTE(review): the system prompt has no space after "text." - presumably it
# mirrors the prompt used during fine-tuning; verify before changing it.
test_messages = [
    {"role": "system", "content": "You are a financial sentiment analysis expert. Your task is to analyze the sentiment expressed in the given financial text.Only reply with positive, neutral, or negative."},
    {"role": "user", "content": "Apple reported record quarterly revenue of $123.9 billion, up 11% year over year."}
]

# invoke_model returns the stripped, lower-cased reply text.
response = invoke_model(test_messages)
print(f"Test response: {response}")

## 5. Evaluate on Test Set

Run inference on all 200 test samples and calculate metrics.

In [None]:
# Load test data
test_data_path = os.path.join(os.getcwd(), "tmp_cache_local_dataset", "test_data.jsonl")

# JSON Lines: one JSON document per line.
# - UTF-8 is requested explicitly so decoding does not depend on the platform's
#   default encoding (the texts may contain currency symbols and similar).
# - Whitespace-only lines (e.g. a stray blank line at the end of the file) are
#   skipped instead of crashing json.loads.
with open(test_data_path, "r", encoding="utf-8") as f:
    test_samples = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(test_samples)} test samples")

In [None]:
import string

# Run inference on all test samples
predictions = []
ground_truth = []
errors = []


def _first_label(text):
    """Return the first non-empty token of *text* without surrounding punctuation.

    Replies such as "positive." or "**positive**" are reduced to "positive" so
    that formatting noise is not scored as a wrong label. Returns "unknown"
    when the reply contains no usable token.
    """
    for token in text.split():
        label = token.strip(string.punctuation)
        if label:
            return label
    return "unknown"


print("Running inference on test set...")
for i, sample in enumerate(tqdm(test_samples)):
    messages = sample["messages"]

    # Extract ground truth (the last message is used as the reference answer)
    expected = messages[-1]["content"].strip().lower()
    ground_truth.append(expected)

    # Get model prediction (only system + user messages)
    input_messages = [m for m in messages if m["role"] != "assistant"]

    try:
        predicted = invoke_model(input_messages)
    except Exception as e:
        # Best effort: record the failure and keep the two lists aligned with a
        # sentinel label so one bad request does not abort the whole run.
        errors.append((i, str(e)))
        predictions.append("error")
    else:
        # Normalize prediction to expected labels
        predictions.append(_first_label(predicted))

    # Rate limiting - small delay between requests
    time.sleep(0.1)

print(f"\nCompleted: {len(predictions)} predictions, {len(errors)} errors")

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Filter out errors for metric calculation: keep only the positions where the
# request succeeded (failed requests are marked with the "error" sentinel).
valid_indices = [pos for pos, label in enumerate(predictions) if label != "error"]
valid_predictions = [predictions[pos] for pos in valid_indices]
valid_ground_truth = [ground_truth[pos] for pos in valid_indices]

print(f"Valid samples for evaluation: {len(valid_predictions)}/{len(predictions)}")
print()

# Calculate metrics
accuracy = accuracy_score(valid_ground_truth, valid_predictions)
print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print()

# Detailed classification report (per-label precision / recall / F1);
# zero_division=0 is passed for labels that have no samples.
report = classification_report(valid_ground_truth, valid_predictions, zero_division=0)
print("Classification Report:")
print("=" * 60)
print(report)

In [None]:
# Confusion Matrix
# Use every label seen in either list so that unexpected model outputs get
# their own row/column instead of being dropped.
labels = sorted(set(valid_ground_truth) | set(valid_predictions))
cm = confusion_matrix(valid_ground_truth, valid_predictions, labels=labels)

print("Confusion Matrix:")
print("=" * 60)
# Name both axes so the printed table is self-explanatory.
cm_df = pd.DataFrame(cm, index=labels, columns=labels).rename_axis(
    index="Actual", columns="Predicted"
)
print(cm_df)
print()
print("Rows = Actual labels, Columns = Predicted labels")

In [None]:
# Sample predictions (first 10)
print("Sample Predictions (first 10):")
print("=" * 80)
for i, sample in enumerate(test_samples[:10]):
    # The message at index 1 is shown as the input text, truncated to 80 chars.
    user_msg = sample["messages"][1]["content"][:80]
    expected = ground_truth[i]
    predicted = predictions[i]
    if expected == predicted:
        match = "✓"
    else:
        match = "✗"
    print(f"{match} Expected: {expected:10} | Predicted: {predicted:10} | Text: {user_msg}...")

In [None]:
# Show misclassified examples
print("Misclassified Examples (first 10):")
print("=" * 80)
# Collect (sample index, expected, predicted) for every wrong answer; failed
# requests (the "error" sentinel) are not counted as misclassifications.
misclassified = [
    (i, ground_truth[i], guess)
    for i, guess in enumerate(predictions)
    if ground_truth[i] != guess and guess != "error"
]

for rank, (i, expected, predicted) in enumerate(misclassified[:10], start=1):
    user_msg = test_samples[i]["messages"][1]["content"][:100]
    print(f"{rank}. Expected: {expected:10} | Predicted: {predicted:10}")
    print(f"   Text: {user_msg}...")
    print()

## 6. Save Results

In [None]:
# Save evaluation results
# One row per test sample. Rows whose prediction is the "error" sentinel count
# as incorrect here, so this accuracy can be lower than the metric computed on
# valid samples only.
texts = [sample["messages"][1]["content"] for sample in test_samples]
is_correct = [truth == guess for truth, guess in zip(ground_truth, predictions)]
results_df = pd.DataFrame(
    {
        "text": texts,
        "expected": ground_truth,
        "predicted": predictions,
        "correct": is_correct,
    }
)

results_path = os.path.join(os.getcwd(), "evaluation_results.csv")
results_df.to_csv(results_path, index=False)
print(f"Results saved to: {results_path}")

# Summary
correct_column = results_df["correct"]
print("\nSummary:")
print(f"  Total samples: {len(test_samples)}")
print(f"  Correct: {correct_column.sum()}")
print(f"  Incorrect: {(~correct_column).sum()}")
print(f"  Accuracy: {correct_column.mean()*100:.2f}%")

## 7. Cleanup

Delete all SageMaker resources to stop billing. **Only run this when you're done with the endpoint.**

In [None]:
import botocore

def delete_resource(delete_fn, name, rtype):
    """Run a SageMaker delete call and print what happened.

    A resource that no longer exists is reported and ignored, so the cleanup
    cell can be re-run safely. Any other failure is re-raised.

    Args:
        delete_fn: Zero-argument callable that performs the deletion.
        name: Resource name, used only in the printed status line.
        rtype: Human-readable resource type, used only in the printed status line.

    Raises:
        botocore.exceptions.ClientError: If the deletion fails for any reason
            other than the resource being absent.
    """
    try:
        delete_fn()
    except botocore.exceptions.ClientError as e:
        error = e.response.get("Error", {})
        code = error.get("Code", "")
        message = error.get("Message", "").lower()

        # ValidationException is also raised for requests that are rejected
        # for other reasons (for example a resource that cannot be deleted
        # yet). Treat it as "not found" only when the message says so;
        # otherwise a refused deletion would be reported as a missing resource
        # and the resource would silently keep running.
        not_found_phrases = ("could not find", "cannot find", "not found", "does not exist")
        is_missing = code == "ResourceNotFound" or (
            code == "ValidationException"
            and any(phrase in message for phrase in not_found_phrases)
        )
        if not is_missing:
            raise
        print(f"⊘ {rtype} not found: {name}")
    else:
        print(f"✓ Deleted {rtype}: {name}")

In [None]:
# Run this cell to perform cleanup
# The inference component is deleted first, followed by a fixed pause before
# the endpoint itself is removed.
# NOTE(review): the 60s pause is a fixed guess; the component's status is not
# polled, so deletion may still be in progress when the endpoint call is made.
delete_resource(
    lambda: sm_client.delete_inference_component(InferenceComponentName=inference_component_name),
    inference_component_name,
    "Inference component",
)
print("Waiting 60s for inference component deletion...")
time.sleep(60)

# Remaining resources, in deletion order: endpoint, its config, then the model.
remaining_resources = [
    (lambda: sm_client.delete_endpoint(EndpointName=endpoint_name), endpoint_name, "Endpoint"),
    (
        lambda: sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name),
        endpoint_config_name,
        "Endpoint config",
    ),
    (lambda: sm_client.delete_model(ModelName=model_name), model_name, "Model"),
]
for delete_call, resource_name, resource_type in remaining_resources:
    delete_resource(delete_call, resource_name, resource_type)
print("\n✓ Cleanup complete!")