In [2]:
import os
import tarfile
import boto3
from urllib.parse import urlparse

import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import image_uris

# ========= CONFIG =========
# S3 URI of the model artifact; step 2 downloads it to deploy_src/model.pt2.
MODEL_PT2_S3 = "s3://ai-bmi-predictor-v2/image-segmentation/sapiens/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2"
# S3 URI the packaged model.tar.gz is uploaded to in step 4; also passed to
# PyTorchModel as model_data in step 5.
PACKAGED_MODEL_S3 = "s3://ai-bmi-predictor-v2/image-segmentation/sapiens/sagemaker/model.tar.gz"

# Endpoint name used for deploy(); also embedded in the autoscaling resource id
# and in the scaling policy name.
ENDPOINT_NAME = "sapiens-segmentation-endpoint"
INSTANCE_TYPE = "ml.g4dn.4xlarge"      # GPU recommended
# Framework / Python versions used both for the inference image lookup and for
# the PyTorchModel definition.
PYTORCH_VERSION = "2.6"
PY_VERSION = "py312"

# ---- Option A autoscaling settings (min >= 1) ----
# Passed to register_scalable_target as MinCapacity / MaxCapacity.
AUTOSCALING_MIN_CAPACITY = 1
AUTOSCALING_MAX_CAPACITY = 4
# TargetValue of the SageMakerVariantInvocationsPerInstance target-tracking policy.
TARGET_INVOCATIONS_PER_INSTANCE = 20.0   # GPU models often need a lower target; tune later
# Passed to the policy as ScaleOutCooldown / ScaleInCooldown.
SCALE_OUT_COOLDOWN_SECONDS = 60
SCALE_IN_COOLDOWN_SECONDS = 300

# ========= ROLE =========
# Resolve the execution role through the SageMaker SDK. Any failure while
# importing or calling get_execution_role() leaves ROLE as None and is reported
# below as a ValueError; the original exception is chained as the cause so the
# real reason (import error, credentials, ...) is not lost.
ROLE = None
_role_lookup_error = None
try:
    from sagemaker import get_execution_role
    ROLE = get_execution_role()
except Exception as exc:  # broad on purpose: every failure maps to the same user-facing error
    _role_lookup_error = exc
if not ROLE:
    raise ValueError(
        "ROLE is None. Set ROLE to your SageMaker execution role ARN if running outside SageMaker."
    ) from _role_lookup_error

# ========= HELPERS =========
def parse_s3_uri(uri: str):
    """Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    Args:
        uri: S3 URI naming a single object.

    Returns:
        A ``(bucket, key)`` tuple of strings; leading slashes are removed
        from the key.

    Raises:
        ValueError: If the scheme is not ``s3``, the bucket is missing, or
            the object key is empty (e.g. ``s3://bucket`` or ``s3://bucket/``).
    """
    parsed = urlparse(uri)
    bucket = parsed.netloc
    # Strip before validating: a path of just "/" is truthy, so checking the
    # raw path would let an empty object key through.
    key = parsed.path.lstrip("/")
    if parsed.scheme != "s3" or not bucket or not key:
        raise ValueError(f"Invalid S3 URI: {uri}")
    return bucket, key

def s3_download(s3_uri: str, local_path: str):
    """Fetch the object named by ``s3_uri`` and store it at ``local_path``."""
    bucket, key = parse_s3_uri(s3_uri)
    s3_client = boto3.client("s3")
    s3_client.download_file(Bucket=bucket, Key=key, Filename=local_path)

def s3_upload(local_path: str, s3_uri: str):
    """Send the file at ``local_path`` to the object named by ``s3_uri``."""
    bucket, key = parse_s3_uri(s3_uri)
    s3_client = boto3.client("s3")
    s3_client.upload_file(Filename=local_path, Bucket=bucket, Key=key)

# ========= 1) Create inference.py =========
# Local staging directory for the entry-point script; step 3 adds this
# directory to model.tar.gz under the archive name "code".
os.makedirs("deploy_src/code", exist_ok=True)

# Full source of the inference entry point, written to disk verbatim below.
# It is a raw string so that backslash escapes inside it (the "\n" in the
# f-string of _load_model) reach the file unchanged instead of being
# interpreted here.
#
# The script defines four handlers:
#   model_fn   - picks the first (sorted) .pt2/.pt/.pth file in model_dir, loads it
#                via torch.jit.load with a torch.export.load fallback, and returns
#                {"model", "device"}.
#   input_fn   - accepts application/json with an "image_b64" field, or raw image
#                bytes (image/* or application/x-image); returns an RGB PIL image.
#   predict_fn - resizes to 1024x768 (H, W), runs the model, upsamples the logits to
#                the original image size, and returns a base64 PNG mask (255 where
#                the argmax class is non-zero) plus the mask shape.
#   output_fn  - serialises the prediction as JSON; any other accept type raises.
# NOTE(review): these names presumably match the hooks the SageMaker PyTorch
# serving container looks up - confirm against the container documentation.
inference_py = r'''
import base64
import io
import json
import os

import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import transforms

_preprocess = transforms.Compose([
    transforms.Resize((1024, 768)),  # (H, W)
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

def _load_model(model_path: str):
    # Your local code: torch.jit.load first
    try:
        return torch.jit.load(model_path)
    except Exception as e1:
        # Fallback: PT2 archive
        try:
            ep = torch.export.load(model_path)
            return ep.module()
        except Exception as e2:
            raise RuntimeError(
                f"Failed to load model as TorchScript and PT2.\nTorchScript error: {e1}\nPT2 error: {e2}"
            )

def model_fn(model_dir: str):
    candidates = []
    for f in os.listdir(model_dir):
        if f.endswith((".pt2", ".pt", ".pth")):
            candidates.append(os.path.join(model_dir, f))
    if not candidates:
        raise FileNotFoundError(f"No .pt2/.pt/.pth found in {model_dir}")

    model_path = sorted(candidates)[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_path).eval().to(device)
    return {"model": model, "device": device}

def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        payload = json.loads(request_body)
        if "image_b64" not in payload:
            raise ValueError("JSON must include 'image_b64'.")
        img_bytes = base64.b64decode(payload["image_b64"])
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")

    if request_content_type.startswith("image/") or request_content_type == "application/x-image":
        return Image.open(io.BytesIO(request_body)).convert("RGB")

    raise ValueError(f"Unsupported content type: {request_content_type}")

def predict_fn(img: Image.Image, model_bundle):
    model = model_bundle["model"]
    device = model_bundle["device"]

    orig_w, orig_h = img.size  # (W, H)
    x = _preprocess(img).unsqueeze(0).to(device)  # (1, C, 1024, 768)

    with torch.inference_mode():
        out = model(x)

    # Match your code: logits_small = output[0]
    if isinstance(out, (tuple, list)):
        logits_small = out[0]
    else:
        logits_small = out

    if logits_small.ndim == 4:
        logits_small = logits_small[0]  # (C, h, w)

    logits_small = logits_small.to("cpu")

    logits = F.interpolate(
        logits_small.unsqueeze(0),
        size=(orig_h, orig_w),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)

    seg = logits.argmax(dim=0).numpy().astype(np.uint8)
    body_mask_bw = (seg != 0).astype(np.uint8) * 255

    mask_img = Image.fromarray(body_mask_bw, mode="L")
    buf = io.BytesIO()
    mask_img.save(buf, format="PNG")

    return {
        "mask_png_b64": base64.b64encode(buf.getvalue()).decode("utf-8"),
        "mask_shape_hw": [int(orig_h), int(orig_w)],
    }

def output_fn(prediction, accept):
    if accept == "application/json":
        return json.dumps(prediction), accept
    raise ValueError(f"Unsupported accept type: {accept}")
'''

# Overwrite any previous copy so the packaged script always matches the string above.
with open("deploy_src/code/inference.py", "w") as f:
    f.write(inference_py)

print("✅ Wrote deploy_src/code/inference.py")

# ========= 2) Download model =========
# Pull the raw model artifact into the local staging directory.
local_model = "deploy_src/model.pt2"
s3_download(MODEL_PT2_S3, local_model)
print("✅ Downloaded:", MODEL_PT2_S3)

# ========= 3) Create model.tar.gz =========
# Archive layout: model.pt2 at the root, entry-point sources under code/.
tar_path = "model.tar.gz"
archive_members = (
    (local_model, "model.pt2"),
    ("deploy_src/code", "code"),
)
with tarfile.open(tar_path, "w:gz") as archive:
    for member_path, member_arcname in archive_members:
        archive.add(member_path, arcname=member_arcname)
print("✅ Created:", tar_path)

# ========= 4) Upload =========
# Publish the archive to the location PyTorchModel reads it from.
s3_upload(tar_path, PACKAGED_MODEL_S3)
print("✅ Uploaded packaged model to:", PACKAGED_MODEL_S3)

# ========= 5) Deploy =========
sess = sagemaker.Session()
region = boto3.Session().region_name

# Look up the inference container image for the configured framework/Python
# versions and instance type.
image_lookup = {
    "framework": "pytorch",
    "region": region,
    "version": PYTORCH_VERSION,
    "py_version": PY_VERSION,
    "image_scope": "inference",
    "instance_type": INSTANCE_TYPE,
}
image_uri = image_uris.retrieve(**image_lookup)

# Model definition: packaged artifact from S3 plus the local entry-point script.
model_settings = {
    "model_data": PACKAGED_MODEL_S3,
    "role": ROLE,
    "entry_point": "inference.py",
    "source_dir": "deploy_src/code",
    "framework_version": PYTORCH_VERSION,
    "py_version": PY_VERSION,
    "image_uri": image_uri,
    "sagemaker_session": sess,
}
model = PyTorchModel(**model_settings)

# Create the endpoint with a single instance; autoscaling is attached in step 6.
endpoint_settings = {
    "initial_instance_count": 1,
    "instance_type": INSTANCE_TYPE,
    "endpoint_name": ENDPOINT_NAME,
}
predictor = model.deploy(**endpoint_settings)

# ========= 6) Enable autoscaling (Option A) =========
print("\n[STEP 6] Enabling autoscaling (Option A: min>=1, max=N)...")

sm = boto3.client("sagemaker")
aas = boto3.client("application-autoscaling")

# Endpoint -> endpoint config -> first production variant gives the variant name.
ep_desc = sm.describe_endpoint(EndpointName=ENDPOINT_NAME)
epc_desc = sm.describe_endpoint_config(EndpointConfigName=ep_desc["EndpointConfigName"])
first_variant = epc_desc["ProductionVariants"][0]
variant_name = first_variant["VariantName"]

resource_id = f"endpoint/{ENDPOINT_NAME}/variant/{variant_name}"

# Identification of the scalable target, shared by both Application Auto Scaling calls.
scalable_target = {
    "ServiceNamespace": "sagemaker",
    "ResourceId": resource_id,
    "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
}

aas.register_scalable_target(
    MinCapacity=AUTOSCALING_MIN_CAPACITY,
    MaxCapacity=AUTOSCALING_MAX_CAPACITY,
    **scalable_target,
)

# Target-tracking on invocations per instance, with scale-in left enabled.
tracking_config = {
    "PredefinedMetricSpecification": {
        "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance"
    },
    "TargetValue": TARGET_INVOCATIONS_PER_INSTANCE,
    "ScaleOutCooldown": SCALE_OUT_COOLDOWN_SECONDS,
    "ScaleInCooldown": SCALE_IN_COOLDOWN_SECONDS,
    "DisableScaleIn": False,
}
aas.put_scaling_policy(
    PolicyName=f"{ENDPOINT_NAME}-invocations-tt",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration=tracking_config,
    **scalable_target,
)

summary_lines = (
    f"✅ Deployed endpoint: {ENDPOINT_NAME} ({region})",
    f"✅ Autoscaling enabled for variant: {variant_name}",
    f"   Min: {AUTOSCALING_MIN_CAPACITY}  Max: {AUTOSCALING_MAX_CAPACITY}",
    f"   Target Invocations/Instance: {TARGET_INVOCATIONS_PER_INSTANCE}",
)
for summary_line in summary_lines:
    print(summary_line)


✅ Wrote deploy_src/code/inference.py
✅ Downloaded: s3://ai-bmi-predictor-v2/image-segmentation/sapiens/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2
✅ Created: model.tar.gz
✅ Uploaded packaged model to: s3://ai-bmi-predictor-v2/image-segmentation/sapiens/sagemaker/model.tar.gz
--------!
[STEP 6] Enabling autoscaling (Option A: min>=1, max=N)...
✅ Deployed endpoint: sapiens-segmentation-endpoint (eu-west-2)
✅ Autoscaling enabled for variant: AllTraffic
   Min: 1  Max: 4
   Target Invocations/Instance: 20.0
