# Deepfake Voice Detection on Amazon SageMaker

Deploy the fine-tuned Wav2Vec2 model for deepfake voice detection on Amazon SageMaker using a custom inference script.

## Step 1: Environment Setup

Install required dependencies and configure the environment.

In [None]:
# Install the notebook's dependencies from requirements.txt; -U upgrades packages
# that are already installed and -q keeps pip's output quiet.
%pip install -r requirements.txt -Uq

## Step 2: Download Model

Clone the Deepfake model of your choice using `git-xet`:

```bash
# Make sure git-xet is installed (https://hf.co/docs/hub/git-xet)
brew install git-xet
git xet install

git clone https://huggingface.co/garystafford/wav2vec2-deepfake-voice-detector
```

### Alternative: Using Hugging Face Hub

In [None]:
from huggingface_hub import snapshot_download

# Hugging Face Hub repository to download.
repo_id = "garystafford/wav2vec2-deepfake-voice-detector"
# Download into a local directory with the same name the `git clone` alternative
# produces, so the packaging step can use one --model-path for either route.
snapshot_download(repo_id, local_dir="wav2vec2-deepfake-voice-detector")

print(f"Model downloaded: {repo_id}")

## Step 3: Package Model Artifacts

Package the model and code directories and upload to S3.

```txt
model.tar.gz
  ├─ model/
  │   ├─ config.json
  │   ├─ model.safetensors
  │   └─ preprocessor_config.json
  └─ code/
      ├─ inference.py
      └─ requirements.txt
```

In [None]:
%%sh

# Run the repository's packaging helper against the downloaded model directory.
# NOTE(review): presumably this stages the model/ and code/ directories that the
# packaging cell below archives — confirm against prepare_sagemaker_model.py.
python prepare_sagemaker_model.py --model-path wav2vec2-deepfake-voice-detector

In [None]:
# Settings shared by the packaging, upload, and deploy cells.
# Local file name of the gzip-compressed tarball built from model/ and code/.
artifact_path = "model.tar.gz"
# S3 key prefix for the uploaded artifact; also the prefix of the generated
# model/endpoint name.
key_prefix = "wav2vec2-deepfake-voice-detector"

In [None]:
import os

# Set an environment variable for your existing SageMaker execution role ARN.
# NOTE(review): the value below is a template — replace the account ID and
# `<your-role-id>` with your own before running. The assignment is unconditional,
# so it overwrites any SAGEMAKER_ROLE_ARN already present in the environment.
os.environ["SAGEMAKER_ROLE_ARN"] = (
    "arn:aws:iam::676164205626:role/service-role/AmazonSageMaker-ExecutionRole-<your-role-id>"
)

In [None]:
%%time

import os
import tarfile

# Archive the model/ and code/ directories as the two top-level entries of a
# gzip-compressed tarball written to `artifact_path`.
with tarfile.open(artifact_path, mode="w:gz") as tar:
    tar.add("model", arcname="model")
    tar.add("code", arcname="code")

### Validate Model Artifacts

Before uploading and deploying, verify `model.tar.gz` contains `model/preprocessor_config.json`, `model/config.json`, and `model/model.safetensors`. Also confirm the packaged `code/inference.py` includes the latest model directory resolution logic.

In [None]:
# Inspect model.tar.gz contents and validate required files
import os
import tarfile

# Fall back to the default artifact name when this cell runs before the
# settings cell has defined `artifact_path`.
artifact_path = globals().get("artifact_path", "model.tar.gz")

# Raise instead of assert: assert statements are stripped under `python -O`,
# which would silently skip this check.
if not os.path.isfile(artifact_path):
    raise FileNotFoundError(f"Missing artifact: {artifact_path}")

# Archive members that must be present before uploading and deploying.
required = [
    "model/preprocessor_config.json",
    "model/config.json",
    "model/model.safetensors",
    "code/inference.py",
    "code/requirements.txt",
]

# Only the member listing is needed, so the archive is closed right after it.
with tarfile.open(artifact_path, "r:gz") as tar:
    names = tar.getnames()

print("Tar contents (first 20):", names[:20])

# Use a set so each lookup is O(1) regardless of how many members the tar has.
present = set(names)
missing = [p for p in required if p not in present]
if missing:
    raise FileNotFoundError(f"Missing required paths in tar: {missing}")

print("Artifact validation passed ✅")

In [None]:
import os
import boto3

# One boto3 session feeds every low-level client; the region falls back to
# us-east-1 when the session has none configured.
session = boto3.Session()
region = session.region_name or "us-east-1"
sts, iam, sagemaker_client = (
    session.client(service_name, region_name=region)
    for service_name in ("sts", "iam", "sagemaker")
)

# Derive the bucket name with the SageMaker SDK's default naming convention
# (sagemaker-<region>-<account-id>) from the caller's account.
account_id = sts.get_caller_identity()["Account"]
default_bucket = f"sagemaker-{region}-{account_id}"
# Execution role comes from the environment; it is None when the variable is unset.
role_arn = os.environ.get("SAGEMAKER_ROLE_ARN")

print(f"Account ID: {account_id}")
print(f"RoleArn: {role_arn}")
print(f"Region: {region}")
print(f"Default S3 Bucket: {default_bucket}")

In [None]:
%%time

import os
import boto3

# Build the S3 client from the shared session and region so the upload targets
# the same account/region the bucket name was derived from, instead of whatever
# the default boto3 session happens to resolve.
s3 = session.client("s3", region_name=region)

# Object key: <key_prefix>/<artifact file name>.
file_name = os.path.basename(artifact_path)
s3_key = f"{key_prefix}/{file_name}"

# NOTE(review): this assumes `default_bucket` already exists — upload_file does
# not create buckets.
s3.upload_file(artifact_path, default_bucket, s3_key)

# S3 URI handed to the SageMaker model as its model_data.
model_s3_path = f"s3://{default_bucket}/{s3_key}"
print(f"Uploaded model artifact to: {model_s3_path}")

## Step 4: Deploy to SageMaker

In [None]:
%%time

import os
from datetime import datetime, timezone

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.model import Model
from sagemaker.session import Session as SmSession

# Wrap the boto3 session from the setup cell so the SageMaker SDK uses the same
# session (and therefore the same region) as the earlier cells.
sm_session = SmSession(boto_session=session)

# The registry host embeds the region, so the image URI follows `region`.
# https://gallery.ecr.aws/deep-learning-containers/pytorch-inference
image_uri = (
    f"763104351884.dkr.ecr.{region}.amazonaws.com/"
    "pytorch-inference:2.6.0-gpu-py312-cu124-ubuntu22.04-sagemaker-v1.56"
)

# Note: the current SageMaker handler returns the model's raw labels (e.g., 'real'/'fake')
# and does not support label swapping or server-side trimming via env vars.
container_env = {}

instance_type = "ml.g4dn.xlarge"

# A single UTC-timestamped name is used for both the model and the endpoint.
custom_endpoint_model_name = (
    f"{key_prefix}-{datetime.now(timezone.utc):%Y-%m-%d-%H-%M-%S}"
)
# Publish the name explicitly in the notebook namespace for later cells.
globals()["custom_endpoint_model_name"] = custom_endpoint_model_name

custom_model = Model(
    name=custom_endpoint_model_name,
    image_uri=image_uri,
    model_data=model_s3_path,
    entry_point="inference.py",
    source_dir="code",
    env=container_env,
    role=role_arn,
    sagemaker_session=sm_session,
)

print(f"Deploying to endpoint: {custom_endpoint_model_name}")
print("Container env:", container_env)

# Takes 7-11 minutes to deploy endpoint
predictor = custom_model.deploy(
    endpoint_name=custom_endpoint_model_name,
    instance_type=instance_type,
    initial_instance_count=1,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
    container_startup_health_check_timeout=300,
)

## Step 5: Real-time Inference

Update the `endpoint_name` variable with your deployed endpoint name and run inference on local audio files.

In [None]:
import os
import librosa
import base64
import json
import boto3

# An explicitly edited endpoint_name always wins. While it is still the
# placeholder, fall back to the endpoint created by the deploy cell in this
# session (the deploy cell stores its name in `custom_endpoint_model_name`).
endpoint_name = "<your-endpoint-name>"  # replace with your endpoint name if needed
if endpoint_name == "<your-endpoint-name>":
    endpoint_name = globals().get("custom_endpoint_model_name") or endpoint_name

# Raise instead of assert so the guard cannot be stripped by `python -O`.
if not endpoint_name or endpoint_name == "<your-endpoint-name>":
    raise ValueError("Set endpoint_name first")


def canonical_label(label: str | None) -> str | None:
    if label is None:
        return None
    s = str(label).strip()
    if not s:
        return s
    v = s.lower()
    if v in {"fake", "deepfake", "synthetic"}:
        return "Deepfake"
    if v in {"real", "bonafide", "bona-fide", "bona fide", "non-synthetic"}:
        return "Real"
    return s


def canonicalize_probabilities(probs: dict | None) -> dict[str, float]:
    """Re-key a raw probability mapping by canonical label.

    Returns an empty dict when ``probs`` is not a dict. Entries are kept only
    when their key canonicalizes to "Deepfake" or "Real" and their value is an
    int or float; kept values are converted to float. If several raw keys map
    to the same canonical label, the last one encountered wins.
    """
    if not isinstance(probs, dict):
        return {}

    relabeled = ((canonical_label(raw), score) for raw, score in probs.items())
    return {
        name: float(score)
        for name, score in relabeled
        if name in {"Deepfake", "Real"} and isinstance(score, (float, int))
    }


def send_audio_to_sagemaker(audio_file: str) -> dict:
    """Send one audio file to the SageMaker endpoint and return its parsed reply.

    The file is loaded as mono audio resampled to 16 kHz on the client, cast to
    float32, base64-encoded, and posted as JSON to the endpoint named by the
    module-level ``endpoint_name``.

    Args:
        audio_file: Path of the audio file to load.

    Returns:
        The JSON-decoded response body (expected to be a dict).
    """
    target_rate = 16000
    samples, _ = librosa.load(audio_file, sr=target_rate, mono=True)

    # Raw float32 sample bytes, base64-encoded so they fit in a JSON string.
    encoded = base64.b64encode(samples.astype("float32").tobytes())
    request_body = json.dumps(
        {"audio_base64": encoded.decode("utf-8"), "sample_rate": target_rate}
    )

    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=request_body,
    )
    return json.loads(response["Body"].read())


# Collect the audio files to score, sorted by path. Extension matching is
# case-insensitive so files such as "CLIP.WAV" are not silently skipped.
audio_dir = "audio_samples"
audio_extensions = (".flac", ".wav", ".mp3")
audio_files = sorted(
    os.path.join(audio_dir, name)
    for name in os.listdir(audio_dir)
    if name.lower().endswith(audio_extensions)
)

print("Endpoint:", endpoint_name)
print()

for path in audio_files:
    # The listing can go stale if a file is removed while the loop is running.
    if not os.path.exists(path):
        print("Missing:", path)
        continue
    resp = send_audio_to_sagemaker(path)

    raw_pred = resp.get("prediction")
    probs_raw = resp.get("probabilities") or {}
    probs = canonicalize_probabilities(probs_raw)

    pred = canonical_label(raw_pred)
    conf = resp.get("confidence")
    # Derive whichever field the response omitted from the probabilities.
    if conf is None and probs:
        conf = max(probs.values())
    if pred is None and probs:
        pred = max(probs, key=probs.get)

    # Red/green only for the two recognized labels; anything else (including a
    # missing prediction) gets a neutral marker instead of a misleading green.
    pred_symbol = {"Deepfake": "🔴", "Real": "🟢"}.get(pred, "⚪")

    # Format defensively: a response with no prediction or no numeric
    # confidence would otherwise raise inside the format specs below.
    pred_text = "unknown" if pred is None else str(pred)
    conf_text = f"{conf:.1%}" if isinstance(conf, (int, float)) else "n/a"

    # os.path.basename strips the directory on every platform and avoids a
    # backslash inside an f-string expression (a SyntaxError before Python 3.12).
    file_label = os.path.basename(path)
    print(f"{pred_symbol} {file_label:30s} → {pred_text:5s} ({conf_text})")

    if "Deepfake" in probs and "Real" in probs:
        print(
            f"  probabilities: Deepfake={probs['Deepfake']:.8f}  Real={probs['Real']:.8f}"
        )
    else:
        # Show both so it's obvious what the endpoint returned vs what we canonicalized
        print("  probabilities_raw:", probs_raw)
        print("  probabilities_canonical:", probs)
    print()