# Deploy Fine-Tuned Llama-3.2-1B for Customer Support Triage
---
This notebook deploys the fine-tuned **Llama-3.2-1B-Instruct** model from notebook 1 to a SageMaker endpoint.

**Prerequisites:**
- Run notebook 1 (`01_finetune_training.ipynb`) to fine-tune the model
- Training job must be completed successfully

**What this notebook does:**
1. Extract the `model.tar.gz` training artifact into an uncompressed S3 prefix
2. Deploy model to SageMaker endpoint using DJL LMI (vLLM)
3. Run a quick sanity check
4. Provide cleanup instructions

**Next step:** Run notebook 3 (`03_evaluate.ipynb`) for full evaluation

---

**Model:** [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)  
**Instance:** ml.g4dn.xlarge (1x T4, 16GB VRAM)

## 1. Setup and Dependencies

In [None]:
import boto3
import sagemaker
from sagemaker.session import Session
import json
import time
from getpass import getpass

In [None]:
# Deployment settings shared by the cells below
region = "eu-west-2"
s3_bucket = "sagemaker-eu-west-2-889772146711"
role = "arn:aws:iam::889772146711:role/SageMakerExecutionRole"

# UPDATE THIS after running notebook 01
training_job_name = "llama1b-cs-ft-1768325350-20260113172916"

# S3 locations (expected to mirror the output structure of notebook 01):
#   <base_job_name>/<training_job_name>/output/model.tar.gz      - packed artifact
#   <base_job_name>/<training_job_name>/output/uncompressed/     - unpacked files
base_job_name = "llama1b-cs-ft-1768325350"
s3_uncompressed_prefix = "/".join(
    (base_job_name, training_job_name, "output", "uncompressed", "")
)
s3_model_prefix = "s3://" + s3_bucket + "/" + s3_uncompressed_prefix
s3_model_uri = "s3://" + "/".join(
    (s3_bucket, base_job_name, training_job_name, "output", "model.tar.gz")
)

# AWS clients and the SageMaker session, all bound to `region`
s3_client = boto3.client("s3", region_name=region)
sm_client = boto3.client("sagemaker", region_name=region)
sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=region)
sess = Session(boto3.Session(region_name=region))

print(
    f"Training job: {training_job_name}",
    f"Model artifacts: {s3_model_uri}",
    f"Uncompressed path: {s3_model_prefix}",
    sep="\n",
)

In [None]:
# Enter and validate HuggingFace token (required for gated models like Llama).
# Whitespace picked up by copy/paste is stripped so that it neither trips the
# prefix check nor ends up inside the token stored in `hf_token`.
hf_token = getpass("Enter your HuggingFace token: ").strip()

# Validate token format. Only the prefix and a minimum length are checked here;
# the token is not verified against the HuggingFace API.
if not hf_token:
    raise ValueError("❌ HuggingFace token cannot be empty!")
if not hf_token.startswith("hf_"):
    raise ValueError(
        "❌ Invalid HuggingFace token format!\n"
        f"   Token should start with 'hf_' but starts with '{hf_token[:3]}...'\n"
        "   Get a valid token at: https://huggingface.co/settings/tokens"
    )
if len(hf_token) < 20:
    raise ValueError(
        "❌ HuggingFace token too short!\n"
        f"   Token is only {len(hf_token)} characters (expected 37+)\n"
        "   Make sure you copied the full token from: https://huggingface.co/settings/tokens"
    )

print("✓ HuggingFace token accepted")
print("  - Format: Valid (starts with 'hf_')")
print(f"  - Length: {len(hf_token)} characters")
# Show only the last four characters: cell output is saved with the notebook,
# so the preview should reveal as little of the secret as possible.
print(f"  - Preview: hf_...{hf_token[-4:]}")

## 2. Prepare Model Artifacts

Extract `model.tar.gz` to S3 (streaming, no local download).

In [None]:
%%time
import tarfile
from smart_open import open as smart_open

# Check if already extracted: any object under the expected model directory
# counts as "done".
# NOTE(review): an interrupted earlier run would also satisfy this check and
# leave a partial copy in place - confirm the prefix is complete if in doubt.
response = s3_client.list_objects_v2(
    Bucket=s3_bucket,
    Prefix=s3_uncompressed_prefix + "meta-llama/Llama-3.2-1B-Instruct/",
    MaxKeys=5,
)
already_extracted = response.get("KeyCount", 0) > 0

# Extract if needed
if already_extracted:
    print("Model already extracted. Skipping.")
else:
    print("Extracting model.tar.gz to S3 (streaming - no local download)...")
    print("This may take 5-10 minutes.")

    file_count = 0
    total_bytes = 0

    with smart_open(s3_model_uri, "rb") as f, tarfile.open(fileobj=f, mode="r:*") as tar:
        for member in tar:
            # Directories, links and other special entries have no payload to copy.
            if not member.isfile():
                continue
            file_obj = tar.extractfile(member)
            if file_obj is None:
                continue

            # upload_fileobj reads the member in chunks (switching to a
            # multipart upload for large objects) instead of loading the whole
            # file into memory and sending it as a single PUT.
            s3_client.upload_fileobj(
                file_obj,
                s3_bucket,
                s3_uncompressed_prefix + member.name,
            )

            file_count += 1
            total_bytes += member.size  # payload size recorded in the tar header
            if file_count % 10 == 0:
                print(f"  Extracted {file_count} files ({total_bytes / 1e9:.2f} GB)...")

    print(f"Extraction complete: {file_count} files, {total_bytes / 1e9:.2f} GB")

print(f"Model artifacts ready at: {s3_model_prefix}")

## 3. Create SageMaker Resources

Deploy using the DJL LMI container with the vLLM backend.

In [None]:
# Names of the SageMaker resources created in the following cells.
# Everything is derived from one generated model name so the pieces can be
# matched to each other.
model_name = sagemaker.utils.name_from_base("llama-3-2-1b-customer-support")
endpoint_name = "ep-" + model_name
endpoint_config_name = "epc-" + model_name
inference_component_name = "ic-" + model_name

# Hosting layout
variant_name = "AllTraffic"
num_gpu = 1
instance_type = "ml.g4dn.xlarge"  # 1x T4 (16GB) - sufficient for 1B model

print(
    f"Model: {model_name}",
    f"Endpoint: {endpoint_name}",
    f"Inference Component: {inference_component_name}",
    f"Instance: {instance_type}",
    "\n⚠️  SAVE THESE NAMES - You'll need them for notebook 03!",
    sep="\n",
)

In [None]:
# Register the SageMaker Model: DJL inference image + uncompressed artifacts in S3
inference_image = (
    "763104351884.dkr.ecr.{}.amazonaws.com/djl-inference:0.34.0-lmi16.0.0-cu128"
).format(region)

# Where the container should pull the model files from (an S3 prefix of
# uncompressed files rather than a tarball).
model_data_source = {
    "S3DataSource": {
        "S3Uri": s3_model_prefix,
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}

# Environment variables handed to the serving container.
# NOTE(review): HF_TOKEN is passed as a plain environment value - confirm this
# is acceptable for the account, or whether it is needed at all when the
# weights are loaded from S3.
container_environment = {
    "HF_TOKEN": hf_token,
    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
    "MESSAGES_API_ENABLED": "true",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "8",
    "OPTION_MODEL_LOADING_TIMEOUT": "1500",
    "SERVING_FAIL_FAST": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_ENABLE_STREAMING": "true",
    "MAX_TOTAL_TOKENS": "4096",
    # T4 doesn't support FA2; use xformers backend
    "VLLM_ATTENTION_BACKEND": "XFORMERS",
}

sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image,
        "ModelDataSource": model_data_source,
        "Environment": container_environment,
    },
)
print("Created model: " + model_name)

In [None]:
# Endpoint configuration: one production variant pinned to a single instance
production_variant = {
    "VariantName": variant_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    # Both timeouts are set to one hour.
    "ModelDataDownloadTimeoutInSeconds": 3600,
    "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
    # Managed scaling is enabled but min == max == 1, so the instance count stays fixed.
    "ManagedInstanceScaling": {
        "Status": "ENABLED",
        "MinInstanceCount": 1,
        "MaxInstanceCount": 1,
    },
    "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
}

sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ExecutionRoleArn=role,
    ProductionVariants=[production_variant],
)
print("Created endpoint config: " + endpoint_config_name)

In [None]:
# Launch the endpoint from the configuration above, then wait on it
sm_client.create_endpoint(
    EndpointConfigName=endpoint_config_name,
    EndpointName=endpoint_name,
)
print(
    f"Creating endpoint: {endpoint_name}",
    "Waiting for endpoint (this may take several minutes)...",
    sep="\n",
)

# Returns once the session helper has finished waiting for the endpoint.
sess.wait_for_endpoint(endpoint_name)
print("Endpoint " + endpoint_name + " is InService!")

In [None]:
# Create Inference Component: attaches the model to the endpoint variant and
# reserves the compute it needs (GPU count, CPU cores, memory).
sm_client.create_inference_component(
    InferenceComponentName=inference_component_name,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification={
        "ModelName": model_name,
        "ComputeResourceRequirements": {
            "NumberOfAcceleratorDevicesRequired": num_gpu,
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    RuntimeConfig={"CopyCount": 1},
)
print(f"Creating inference component: {inference_component_name}")
print("Waiting for inference component (this may take ~10 minutes)...")

# Poll until the component reaches a terminal state. The wait is capped so a
# component stuck in a transitional state cannot hang the notebook forever;
# the cap matches the one-hour container startup timeout used in the endpoint
# configuration.
max_wait_seconds = 3600
poll_interval_seconds = 30

start_time = time.time()
while True:
    desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name)
    status = desc["InferenceComponentStatus"]
    if status in ("InService", "Failed"):
        break

    elapsed_seconds = time.time() - start_time
    if elapsed_seconds > max_wait_seconds:
        raise TimeoutError(
            f"Inference component {inference_component_name} still in status "
            f"'{status}' after {elapsed_seconds / 60:.1f} minutes"
        )

    print(f"  Status: {status}")
    time.sleep(poll_interval_seconds)

if status == "Failed":
    # RuntimeError is a subclass of Exception, so existing handlers still match.
    raise RuntimeError(f"Inference component failed: {desc.get('FailureReason', 'Unknown')}")

print(f"\nInference component ready! ({(time.time() - start_time)/60:.1f} minutes)")

## 4. Test Inference

Quick sanity check before running full evaluation in notebook 03.

In [None]:
def invoke_model(messages, max_tokens=512, temperature=0.1):
    """Send chat messages to the deployed inference component and return the reply.

    Args:
        messages: Chat messages, sent under the "messages" key of the request.
        max_tokens: Sent as "max_tokens" in the request payload.
        temperature: Sent as "temperature"; "top_p" is fixed at 0.9.

    Returns:
        The "content" string of the first choice's message in the JSON
        response, with surrounding whitespace stripped.
    """
    request_body = json.dumps(
        {
            "messages": messages,
            "temperature": temperature,
            "top_p": 0.9,
            "max_tokens": max_tokens,
        }
    )

    # Target the specific inference component hosted on the endpoint.
    raw_response = sagemaker_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        InferenceComponentName=inference_component_name,
        ContentType="application/json",
        Body=request_body,
    )

    parsed = json.loads(raw_response["Body"].read().decode("utf-8"))
    first_choice = parsed["choices"][0]
    return first_choice["message"]["content"].strip()

In [None]:
# Smoke test: push one sample customer support ticket through the endpoint
test_messages = [
    {
        "role": "user",
        "content": (
            "Ticket #9999\n"
            "Customer: test.user@example.com\n"
            "Plan: Enterprise\n"
            "Issue: URGENT - Users can't log in to the dashboard. "
            "Getting 401 errors across the board. Our sales team is blocked."
        ),
    }
]

response = invoke_model(test_messages)
print("Test Response:", "=" * 60, response, sep="\n")
print("\n✓ Endpoint is working! Proceed to notebook 03 for full evaluation.")