In [None]:
import base64
import json
import os

import boto3
import requests
import sagemaker
import torch
import torchaudio
import whisper

# Device string used as the default for tensors in this notebook:
# "cuda" when PyTorch reports an available CUDA device, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class LibriSpeech(torch.utils.data.Dataset):
    """
    Dataset adapter around torchaudio's LibriSpeech corpus.

    Each utterance is padded or trimmed to 30 seconds and returned as a
    log-Mel spectrogram together with its transcript. For the small share of
    utterances longer than that, the trailing seconds are cut off.

    Args:
        split: Name of the LibriSpeech subset, passed to torchaudio as ``url``.
        device: Device the padded/trimmed waveform is moved to before the
            spectrogram is computed.
    """
    def __init__(self, split="test-clean", device=DEVICE):
        # The corpus is downloaded on first use and kept under the user's cache dir.
        cache_root = os.path.expanduser("~/.cache")
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=cache_root, url=split, download=True
        )
        self.device = device

    def __len__(self):
        """Number of utterances in the wrapped torchaudio dataset."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(mel, text)`` for the utterance at index ``item``."""
        # Only the first three of the six fields are used here.
        waveform, rate, transcript, _, _, _ = self.dataset[item]
        assert rate == 16000

        # Flatten to 1-D, fix the length, then move to the target device.
        fixed_length = whisper.pad_or_trim(waveform.flatten())
        mel = whisper.log_mel_spectrogram(fixed_length.to(self.device))

        return (mel, transcript)


In [None]:
# Open the "test-clean" split through the LibriSpeech wrapper class.
dataset = LibriSpeech(split="test-clean")

# Index the wrapped torchaudio dataset directly (not the wrapper's __getitem__)
# so we get the raw waveform and sample rate instead of a mel spectrogram.
audio, sample_rate, text, _, _, _ = dataset.dataset[0]

# Dump the waveform's in-memory bytes and Base64-encode them into a string.
# NOTE(review): dtype and shape are not part of the encoded string, so the
# receiver presumably has to know them to decode the bytes — confirm.
raw_audio_bytes = audio.numpy().tobytes()
audio_base64_encoded = str(base64.b64encode(raw_audio_bytes), "utf-8")

In [None]:
# Request body: the Base64-encoded waveform plus its sample rate,
# kept as a plain dict so it can be serialised to JSON later.
data = dict(
    audio_base64=audio_base64_encoded,
    sample_rate=sample_rate,
)

In [None]:
# Client for the SageMaker runtime API, which serves requests to deployed endpoints.
sagemaker_runtime = boto3.client('sagemaker-runtime')

endpoint_name = 'whisper-gpu-endpoint'
content_type = 'application/json'  # MIME type of the request body; adjust for your use case

# Serialise the request dictionary to a JSON string.
request_payload = json.dumps(data)

# Send the request to the endpoint and wait for the inference result.
response = sagemaker_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=request_payload
)

# Read the whole response body into memory.
response_body = response['Body'].read()

# The endpoint is expected to answer with JSON, but that is not guaranteed:
# an empty or non-JSON body would make json.loads raise. Fall back to the
# decoded text in that case so the reply can still be inspected.
try:
    response_data = json.loads(response_body)
except (json.JSONDecodeError, UnicodeDecodeError):
    response_data = response_body.decode("utf-8", errors="replace")

# Process the response_data as per your application's needs
print(response_data)
