## Multi Model Endpoint 생성하기

## 환경설정

In [95]:
import sys, os

vits_path = os.path.abspath("./vits")
if vits_path not in sys.path:
    sys.path.append(vits_path)

for i in sys.path:
    print(i)

/home/ec2-user/SageMaker/lab/00-trition-tts-vits/02-tts-vits-docker-trition
/home/ec2-user/miniconda3/envs/conda-vits-py310/lib/python310.zip
/home/ec2-user/miniconda3/envs/conda-vits-py310/lib/python3.10
/home/ec2-user/miniconda3/envs/conda-vits-py310/lib/python3.10/lib-dynload

/home/ec2-user/miniconda3/envs/conda-vits-py310/lib/python3.10/site-packages
/home/ec2-user/SageMaker/lab/00-trition-tts-vits/02-tts-vits-docker-trition/vits


In [103]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
!pip uninstall -y sagemaker
!pip install sagemaker

Found existing installation: sagemaker 2.232.2
Uninstalling sagemaker-2.232.2:
  Successfully uninstalled sagemaker-2.232.2
Collecting sagemaker
  Using cached sagemaker-2.232.2-py3-none-any.whl.metadata (16 kB)
Using cached sagemaker-2.232.2-py3-none-any.whl (1.6 MB)
Installing collected packages: sagemaker
Successfully installed sagemaker-2.232.2


In [2]:
# 완전히 새로 설치
!pip uninstall -y sagemaker boto3 awscli
!pip cache purge
!pip install sagemaker boto3 awscli
!pip install --upgrade sagemaker>=2.231.0

Found existing installation: sagemaker 2.232.2
Uninstalling sagemaker-2.232.2:
  Successfully uninstalled sagemaker-2.232.2
Found existing installation: boto3 1.35.51
Uninstalling boto3-1.35.51:
  Successfully uninstalled boto3-1.35.51
[0mAn error occurred during configuration: option format: invalid choice: 'columns' (choose from 'human', 'abspath')
Collecting sagemaker
  Using cached sagemaker-2.232.2-py3-none-any.whl.metadata (16 kB)
Collecting boto3
  Using cached boto3-1.35.51-py3-none-any.whl.metadata (6.7 kB)
Collecting awscli
  Downloading awscli-1.35.17-py3-none-any.whl.metadata (11 kB)
Collecting docutils<0.17,>=0.10 (from awscli)
  Using cached docutils-0.16-py2.py3-none-any.whl.metadata (2.7 kB)
Collecting rsa<4.8,>=3.1.2 (from awscli)
  Using cached rsa-4.7.2-py3-none-any.whl.metadata (3.6 kB)
Using cached sagemaker-2.232.2-py3-none-any.whl (1.6 MB)
Using cached boto3-1.35.51-py3-none-any.whl (139 kB)
Downloading awscli-1.35.17-py3-none-any.whl (4.5 MB)
[2K   [90m━━━━━━

## (빈)엔드포인트 생성 : 인프라 준비

In [96]:
import boto3
import sagemaker
from sagemaker.model import Model
from sagemaker.multidatamodel import MultiDataModel
from sagemaker import image_uris
from datetime import datetime
import time

class TritonEndpointManager:
    def __init__(self, region_name, instance_type="ml.g4dn.xlarge", triton_version="24.05"):
        """
        Triton MME 엔드포인트 관리자 초기화
        
        Args:
            region_name (str): AWS 리전명
            instance_type (str): 인스턴스 타입
            triton_version (str): Triton 버전
        """
        self.sagemaker_session = sagemaker.Session()
        self.default_bucket = self.sagemaker_session.default_bucket()
        self.role = sagemaker.get_execution_role()
        self.region_name = region_name
        self.instance_type = instance_type
        self.triton_version = triton_version
        self.sagemaker_client = boto3.client('sagemaker', region_name=region_name)
        
    def _get_triton_image(self):
        """Triton 이미지 URI 조회"""
        return image_uris.retrieve(
            framework="sagemaker-tritonserver",
            region=self.region_name,
            version=self.triton_version,
            instance_type=self.instance_type
        )
        
    def _validate_endpoint_name(self, endpoint_name):
        """엔드포인트 이름 유효성 검사"""
        if not endpoint_name:
            raise ValueError("Endpoint name cannot be empty")
        if len(endpoint_name) > 63:
            raise ValueError("Endpoint name must be less than 64 characters")
        if not endpoint_name[0].isalpha():
            raise ValueError("Endpoint name must start with a letter")
            
    def create_endpoint(self, endpoint_name, model_prefix="models/triton", 
                       instance_count=1, startup_timeout=300, initial_model=None,
                       wait_for_creation=True, tags=None):
        """
        MME 엔드포인트 생성
        
        Args:
            endpoint_name (str): 생성할 엔드포인트 이름
            model_prefix (str): S3 모델 저장 경로 접두사
            instance_count (int): 인스턴스 수
            startup_timeout (int): 시작 타임아웃(초)
            initial_model (dict): 초기 모델 정보 (선택사항)
            wait_for_creation (bool): 생성 완료까지 대기 여부
            tags (list): 엔드포인트에 적용할 태그 리스트
        """
        try:
            self._validate_endpoint_name(endpoint_name)
            
            # Triton 이미지 가져오기
            triton_image_uri = self._get_triton_image()
            
            # MME 모델 설정
            mme_name = f"{endpoint_name}-models"
            
            # 태그 설정
            deployment_tags = tags or []
            deployment_tags.append({
                'Key': 'CreatedBy',
                'Value': 'TritonEndpointManager'
            })
            deployment_tags.append({
                'Key': 'CreatedAt',
                'Value': datetime.utcnow().isoformat()
            })
            
            mme = MultiDataModel(
                name=mme_name,
                model_data_prefix=f's3://{self.default_bucket}/{model_prefix}/',
                image_uri=triton_image_uri,
                role=self.role
            )
            
            # 엔드포인트 배포
            predictor = mme.deploy(
                initial_instance_count=instance_count,
                instance_type=self.instance_type,
                endpoint_name=endpoint_name,
                tags=deployment_tags,
                wait=wait_for_creation
            )
            
            if wait_for_creation:
                self._wait_for_endpoint(endpoint_name, startup_timeout)
            
            print(f"Successfully created:")
            print(f"- Endpoint: {endpoint_name}")
            print(f"- Model Repository: {mme_name}")
            print(f"- S3 Model Prefix: {model_prefix}")

            return predictor

        except Exception as e:
            print(f"Error creating endpoint: {e}")
            raise
            
    def _wait_for_endpoint(self, endpoint_name, timeout):
        """엔드포인트 생성 완료 대기"""
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(f"Endpoint creation timed out after {timeout} seconds")
                
            status = self.get_endpoint_status(endpoint_name)
            if status == 'InService':
                break
            elif status in ['Failed', 'OutOfService']:
                raise RuntimeError(f"Endpoint creation failed with status: {status}")
                
            time.sleep(30)
    
    def get_endpoint_status(self, endpoint_name):
        """엔드포인트 상태 조회"""
        try:
            self._validate_endpoint_name(endpoint_name)
            response = self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
            return response['EndpointStatus']
        except Exception as e:
            print(f"Error getting endpoint status: {e}")
            raise

            
    def update_endpoint(self, endpoint_name, instance_count=None, instance_type=None):
        """
        엔드포인트 설정 업데이트
        
        Args:
            endpoint_name (str): 업데이트할 엔드포인트 이름
            instance_count (int): 변경할 인스턴스 수
            instance_type (str): 변경할 인스턴스 타입
        """
        try:
            self._validate_endpoint_name(endpoint_name)
            
            update_config = {}
            if instance_count is not None:
                update_config['DesiredInstanceCount'] = instance_count
            if instance_type is not None:
                update_config['InstanceType'] = instance_type
                
            if update_config:
                self.sagemaker_client.update_endpoint(
                    EndpointName=endpoint_name,
                    EndpointConfig=update_config
                )
                print(f"Endpoint {endpoint_name} update initiated")
                
        except Exception as e:
            print(f"Error updating endpoint: {e}")
            raise
             
            
    # 엔드포인트와 설정 삭제 메서드 추가
    def delete_endpoint_and_config(self, endpoint_name):
        """엔드포인트와 설정 삭제"""
        try:
            # 엔드포인트 삭제
            try:
                self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
                print(f"Endpoint {endpoint_name} deletion initiated")
            except self.sagemaker_client.exceptions.ClientError as e:
                if "Could not find endpoint" not in str(e):
                    raise

            # 엔드포인트 설정 삭제
            try:
                self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
                print(f"Endpoint config {endpoint_name} deleted")
            except self.sagemaker_client.exceptions.ClientError as e:
                if "Could not find endpoint configuration" not in str(e):
                    raise

            print(f"Successfully deleted endpoint and config: {endpoint_name}")

        except Exception as e:
            print(f"Error deleting endpoint and config: {e}")
            raise
    
    def get_endpoint_metrics(self, endpoint_name, start_time=None, end_time=None):
        """
        엔드포인트 메트릭 조회
        
        Args:
            endpoint_name (str): 엔드포인트 이름
            start_time (datetime): 시작 시간
            end_time (datetime): 종료 시간
        """
        try:
            cloudwatch = boto3.client('cloudwatch', region_name=self.region_name)
            
            metrics = cloudwatch.get_metric_data(
                MetricDataQueries=[
                    {
                        'Id': 'invocations',
                        'MetricStat': {
                            'Metric': {
                                'Namespace': 'AWS/SageMaker',
                                'MetricName': 'Invocations',
                                'Dimensions': [
                                    {'Name': 'EndpointName', 'Value': endpoint_name}
                                ]
                            },
                            'Period': 300,
                            'Stat': 'Sum'
                        }
                    },
                    {
                        'Id': 'latency',
                        'MetricStat': {
                            'Metric': {
                                'Namespace': 'AWS/SageMaker',
                                'MetricName': 'ModelLatency',
                                'Dimensions': [
                                    {'Name': 'EndpointName', 'Value': endpoint_name}
                                ]
                            },
                            'Period': 300,
                            'Stat': 'Average'
                        }
                    }
                ],
                StartTime=start_time or datetime.utcnow().replace(hour=0, minute=0, second=0),
                EndTime=end_time or datetime.utcnow()
            )
            
            return metrics
            
        except Exception as e:
            print(f"Error getting endpoint metrics: {e}")
            raise

In [9]:
# 매니저 초기화
endpoint_manager = TritonEndpointManager(
    region_name='us-east-1'
)


In [None]:
# 기존 엔드포인트와 설정 삭제
# endpoint_manager.delete_endpoint_and_config('tts-endpoint')

# 새 엔드포인트 생성
predictor = endpoint_manager.create_endpoint(
    endpoint_name='tts-endpoint',
    model_prefix='models/tts-models',
    instance_count=1,
    startup_timeout=300,
    wait_for_creation=False  # 기본값은 True입니다

)

In [122]:
status = endpoint_manager.get_endpoint_status('tts-endpoint')
print(f"Endpoint status: {status}")

Endpoint status: Creating


In [123]:
def wait_for_endpoint(endpoint_name, check_interval=30):
    while True:
        status = endpoint_manager.get_endpoint_status(endpoint_name)
        print(f"Status: {status}")
        
        if status == 'InService':
            return True
        elif status in ['Failed', 'OutOfService']:
            return False
            
        time.sleep(check_interval)

# 사용
is_ready = wait_for_endpoint('tts-endpoint')
if is_ready:
    print("Endpoint is ready for use!")
else:
    print("Endpoint creation failed")

Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Endpoint is ready for use!


## 실제 모델 추가 및 관리

In [78]:
import boto3
import sagemaker
from sagemaker.multidatamodel import MultiDataModel
from sagemaker import image_uris
import time
import os
import tarfile
import json
import numpy as np
from datetime import datetime

class TritonMMEManager:
    def __init__(self, role, endpoint_name, region_name, default_bucket, base_prefix='models/triton'):
        self.role = role
        self.endpoint_name = endpoint_name
        self.region_name = region_name
        self.default_bucket = default_bucket
        self.base_prefix = base_prefix.strip('/')
        self.sagemaker_client = boto3.client('sagemaker')
        self.runtime_client = boto3.client('sagemaker-runtime')
        self.s3_client = boto3.client('s3')
        self.model_keys = {} # 모델 키 저장을 위한 딕셔너리 추가


    def _load_model(self, model_key, s3_uri):
        """엔드포인트에 모델 로드 요청"""
        try:
            print(f"Requesting model load for: {s3_uri}")
            print(f"Using model key: {model_key}")
            
            # inference 코드의 형식을 참고한 입력 데이터
            load_request = {
                "inputs": [
                    {
                        "name": "x",
                        "shape": [1, 1],
                        "datatype": "INT64",
                        "data": [[0]]  # 더미 데이터
                    },
                    {
                        "name": "x_length",
                        "shape": [1, 1],
                        "datatype": "INT64",
                        "data": [[1]]  # 더미 데이터
                    },
                    {
                        "name": "noise_scale",
                        "shape": [1, 1],
                        "datatype": "FP32",
                        "data": [[0.667]]  # 기본값
                    },
                    {
                        "name": "length_scale",
                        "shape": [1, 1],
                        "datatype": "FP32",
                        "data": [[1.0]]  # 기본값
                    },
                    {
                        "name": "noise_scale_w",
                        "shape": [1, 1],
                        "datatype": "FP32",
                        "data": [[0.8]]  # 기본값
                    }
                ]
            }
            
            print(f"Load request: {json.dumps(load_request, indent=2)}")
            
            response = self.runtime_client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                ContentType='application/json',
                TargetModel=model_key,
                Body=json.dumps(load_request)
            )
            
            print("Load request sent successfully")
            return response
            
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
            
    def _wait_for_model_load(self, model_key, s3_uri, timeout=300, check_interval=10):
        """모델 로드 완료 대기"""
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                # 상태 확인도 같은 형식 사용
                status_request = {
                    "inputs": [
                        {
                            "name": "x",
                            "shape": [1, 1],
                            "datatype": "INT64",
                            "data": [[0]]
                        },
                        {
                            "name": "x_length",
                            "shape": [1, 1],
                            "datatype": "INT64",
                            "data": [[1]]
                        },
                        {
                            "name": "noise_scale",
                            "shape": [1, 1],
                            "datatype": "FP32",
                            "data": [[0.667]]
                        },
                        {
                            "name": "length_scale",
                            "shape": [1, 1],
                            "datatype": "FP32",
                            "data": [[1.0]]
                        },
                        {
                            "name": "noise_scale_w",
                            "shape": [1, 1],
                            "datatype": "FP32",
                            "data": [[0.8]]
                        }
                    ]
                }
                
                print(f"Checking status for model at: {s3_uri}")
                
                response = self.runtime_client.invoke_endpoint(
                    EndpointName=self.endpoint_name,
                    ContentType='application/json',
                    TargetModel=model_key,
                    Body=json.dumps(status_request)
                )
                
                if response['ResponseMetadata']['HTTPStatusCode'] == 200:
                    response_body = json.loads(response['Body'].read().decode())
                    print(f"Status check response: {response_body}")
                    print(f"Model loaded successfully at: {s3_uri}")
                    return True
                    
            except Exception as e:
                print(f"Waiting for model to load... ({e})")
                time.sleep(check_interval)
            
        raise TimeoutError(f"Model failed to load within {timeout} seconds")

    def _get_s3_paths(self, unique_name):
        """S3 관련 경로들을 생성"""
        model_key = f"{unique_name}/model.tar.gz"
        s3_key = f"{self.base_prefix}/{model_key}"
        s3_uri = f"s3://{self.default_bucket}/{s3_key}"
        return model_key, s3_key, s3_uri

    def add_new_model(self, base_model_name, model_path, config_path):
        """새로운 모델을 MultiModel Endpoint에 추가"""
        try:
            # 1. 고유한 이름 생성
            unique_name, model_dir = self._generate_unique_model_path(base_model_name)
            
            # 2. 모델 패키징
            print(f"Packaging model {unique_name}...")
            packaged_model = self._package_model(
                model_path, 
                config_path, 
                f"{unique_name}.tar.gz",
                model_dir
            )
            
            # 3. S3 경로 생성
            model_key, s3_key, s3_uri = self._get_s3_paths(unique_name)
            
            print(f"Uploading model to: {s3_uri}")
            print(f"Using S3 key: {s3_key}")
            print(f"Target model key: {model_key}")
            
            # 4. S3에 업로드
            self.s3_client.upload_file(
                packaged_model,
                self.default_bucket,
                s3_key
            )
            
            # 5. 업로드 확인
            try:
                self.s3_client.head_object(Bucket=self.default_bucket, Key=s3_key)
                print(f"Confirmed model upload to: {s3_uri}")
            except Exception as e:
                raise Exception(f"Failed to verify model upload: {e}")
            
            # 6. 모델 로드
            print(f"Loading model {unique_name} to endpoint...")
            self._load_model(model_key, s3_uri)
            
            # 7. 로드 상태 확인
            self._wait_for_model_load(model_key, s3_uri)
            
            # 8. 임시 파일 정리
            os.remove(packaged_model)
                        
            # 모델 키 저장
            self.model_keys[base_model_name] = {
                'key': model_key,
                'added_at': datetime.now().isoformat(),
                's3_uri': s3_uri
            }
            
            print(f"Successfully added model {unique_name} to endpoint {self.endpoint_name}")
            print(f"Model URI: {s3_uri}")
            print(f"Target model key: {model_key}")
            
            return model_key
            
        except Exception as e:
            print(f"Error adding new model: {e}")
            raise

    def _generate_unique_model_path(self, base_name):
        """고유한 모델 경로 생성"""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        unique_name = f"{base_name}_{timestamp}"
        model_dir = f"tts-vits_{timestamp}"
        return unique_name, model_dir

    def _package_model(self, model_path, config_path, output_name, model_dir):
        """Triton 형식으로 모델 패키징"""
        version_dir = os.path.join(model_dir, '1')
        os.makedirs(version_dir, exist_ok=True)
        
        import shutil
        shutil.copy2(config_path, os.path.join(model_dir, 'config.pbtxt'))
        shutil.copy2(model_path, os.path.join(version_dir, 'model.pt'))
        
        with tarfile.open(output_name, 'w:gz') as tar:
            tar.add(model_dir, arcname=model_dir)
        
        shutil.rmtree(model_dir)
        return output_name
    
    ## 모델키 관련 함수들
    def get_model_key(self, model_name):
        """저장된 모델 키 조회"""
        return self.model_keys.get(model_name, {}).get('key')
        
    def list_models(self):
        """저장된 모든 모델 정보 조회"""
        return self.model_keys
        
    def xf(self, file_path='model_keys.json'):
        """모델 키를 파일로 저장"""
        try:
            with open(file_path, 'w') as f:
                json.dump(self.model_keys, f, indent=2)
            print(f"Model keys saved to {file_path}")
        except Exception as e:
            print(f"Error saving model keys: {e}")
            
    def load_keys_from_file(self, file_path='model_keys.json'):
        """파일에서 모델 키 로드"""
        try:
            with open(file_path, 'r') as f:
                self.model_keys = json.load(f)
            print(f"Model keys loaded from {file_path}")
        except FileNotFoundError:
            print(f"No saved keys found at {file_path}")
        except Exception as e:
            print(f"Error loading model keys: {e}")


In [79]:
# 2. MME 매니저로 모델 추가 및 관리
mme_manager = TritonMMEManager(
    role=endpoint_manager.role,                    # SageMaker 실행 역할
    endpoint_name='tts-endpoint',          # 위에서 생성한 엔드포인트 이름
    region_name='us-east-1',                      # AWS 리전
    default_bucket=endpoint_manager.default_bucket, # S3 버킷
    base_prefix='models/tts-models'               # 모델 저장 경로 (위에서 지정한 것과 동일하게)
)

In [98]:
# 새 모델 추가
model_key = mme_manager.add_new_model(
    base_model_name='tts-vits',
    model_path="triton-serve-jit/tts-vits/1/model.pt",
    config_path="triton-serve-jit/tts-vits/config.pbtxt"
)

Packaging model tts-vits_20241101_053049...
Uploading model to: s3://sagemaker-us-east-1-603420654815/models/tts-models/tts-vits_20241101_053049/model.tar.gz
Using S3 key: models/tts-models/tts-vits_20241101_053049/model.tar.gz
Target model key: tts-vits_20241101_053049/model.tar.gz
Confirmed model upload to: s3://sagemaker-us-east-1-603420654815/models/tts-models/tts-vits_20241101_053049/model.tar.gz
Loading model tts-vits_20241101_053049 to endpoint...
Requesting model load for: s3://sagemaker-us-east-1-603420654815/models/tts-models/tts-vits_20241101_053049/model.tar.gz
Using model key: tts-vits_20241101_053049/model.tar.gz
Load request: {
  "inputs": [
    {
      "name": "x",
      "shape": [
        1,
        1
      ],
      "datatype": "INT64",
      "data": [
        [
          0
        ]
      ]
    },
    {
      "name": "x_length",
      "shape": [
        1,
        1
      ],
      "datatype": "INT64",
      "data": [
        [
          1
        ]
      ]
    },
    

In [73]:
%%writefile monitoring.py
import boto3

class MMEModelMonitor:
    def __init__(self, endpoint_name):
        self.endpoint_name = endpoint_name
        self.cloudwatch = boto3.client('cloudwatch')

    def log_model_metrics(self, model_key, metrics_data):
        """
        개별 모델의 메트릭을 CloudWatch에 기록
        
        Args:
            model_key: 모델 식별자 (예: tts-model_20241030/model.tar.gz)
            metrics_data: 기록할 메트릭 데이터
        """
        try:
            self.cloudwatch.put_metric_data(
                Namespace='Custom/MMEModels',  # 사용자 정의 네임스페이스
                MetricData=[
                    {
                        'MetricName': 'ModelInvocations',
                        'Value': 1,
                        'Unit': 'Count',
                        'Dimensions': [
                            {'Name': 'EndpointName', 'Value': self.endpoint_name},
                            {'Name': 'ModelKey', 'Value': model_key}
                        ]
                    },
                    {
                        'MetricName': 'ProcessingTime',
                        'Value': metrics_data.get('processing_time', 0),
                        'Unit': 'Milliseconds',
                        'Dimensions': [
                            {'Name': 'EndpointName', 'Value': self.endpoint_name},
                            {'Name': 'ModelKey', 'Value': model_key}
                        ]
                    },
                    {
                        'MetricName': 'MemoryUsage',
                        'Value': metrics_data.get('memory_usage', 0),
                        'Unit': 'Megabytes',
                        'Dimensions': [
                            {'Name': 'EndpointName', 'Value': self.endpoint_name},
                            {'Name': 'ModelKey', 'Value': model_key}
                        ]
                    }
                ]
            )
        except Exception as e:
            print(f"Error logging metrics: {e}")

Overwriting monitoring.py


In [74]:
%%writefile inference.py

import boto3
import numpy as np
import json
import IPython.display as ipd
from IPython.display import display
from text import text_to_sequence
import time
import sys
import torch
import commons
import utils
from monitoring import MMEModelMonitor

class VITSInference:
    def __init__(self, endpoint_name, target_model, config_path="vits/configs/ljs_base.json"):

        """
        VITS Inference 클래스 초기화
        
        Args:
            endpoint_name: SageMaker 엔드포인트 이름
            target_model: 타겟 모델 경로
        """
        self.runtime_client = boto3.client('sagemaker-runtime')
        self.endpoint_name = endpoint_name
        self.target_model = target_model
        # 모델 설정 로드
        self.hps = utils.get_hparams_from_file(config_path)
        self.monitor = MMEModelMonitor(endpoint_name)
        print(f"Initialized VITSInference with endpoint: {endpoint_name}, model: {target_model}")

    def get_text(self, text):
        """텍스트 전처리"""
        text_norm = text_to_sequence(text, self.hps.data.text_cleaners)
        if self.hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm
    
    def prepare_input(self, text, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0):
        """입력 데이터 준비"""
        try:
            # 1. 텍스트 전처리
            stn_tst = self.get_text(text)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
            
            # 2. NumPy 배열로 변환
            x_np = stn_tst.numpy().reshape(1, -1)
            x_length_np = x_tst_lengths.numpy()
            
            # 3. Triton 형식에 맞게 입력 데이터 구성
            input_data = {
                "inputs": [
                    {
                        "name": "x",
                        "shape": [1, x_np.shape[1]],
                        "datatype": "INT64",
                        "data": x_np.tolist()
                    },
                    {
                        "name": "x_length",
                        "shape": [1, 1],
                        "datatype": "INT64",
                        "data": x_length_np.reshape(-1, 1).tolist()
                    },
                    {
                        "name": "noise_scale",
                        "shape": [1, 1],
                        "datatype": "FP32",
                        "data": [[float(noise_scale)]]
                    },
                    {
                        "name": "length_scale",
                        "shape": [1, 1],
                        "datatype": "FP32",
                        "data": [[float(length_scale)]]
                    },
                    {
                        "name": "noise_scale_w",
                        "shape": [1, 1],
                        "datatype": "FP32",
                        "data": [[float(noise_scale_w)]]
                    }
                ]
            }
            return input_data
            
        except Exception as e:
            print(f"입력 데이터 준비 중 오류 발생: {e}")
            raise
                  
    def invoke_endpoint(self, input_data):
        """엔드포인트 호출"""
        try:
            print(f"Invoking endpoint {self.endpoint_name}")
            print(f"Input data shape: x={input_data['inputs'][0]['shape']}, "
                  f"x_length={input_data['inputs'][1]['shape']}")
            
            response = self.runtime_client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                ContentType='application/json',
                TargetModel=self.target_model,
                Body=json.dumps(input_data)
            )
            
            print("Response received from endpoint")
            result = json.loads(response['Body'].read().decode())
            
            # 결과 검증
            if 'outputs' not in result:
                raise ValueError("Invalid response format: 'outputs' not found")
            if not result['outputs'] or 'data' not in result['outputs'][0]:
                raise ValueError("Invalid response format: no data in outputs")

            # 검증이 통과된 경우에만 결과 반환
            return {"audio_output": np.array(result['outputs'][0]['data'])}

        except Exception as e:
            print(f"Endpoint 호출 중 오류 발생: {e}")
            raise
            
    def generate_audio(self, text, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0):
        """음성 생성"""
        try:
            print(f"Generating audio for text: {text}")
            
            start_time = time.time()
            
            # 1. 입력 데이터 준비
            input_data = self.prepare_input(text, noise_scale, noise_scale_w, length_scale)
            print("Input data prepared successfully")
            
            # 2. 추론 실행
            result = self.invoke_endpoint(input_data)
            print("Inference completed successfully")
            
            # 3. 처리 시간 계산
            processing_time = (time.time() - start_time) * 1000  # 밀리초 단위
            
            # 4. 메트릭 기록
            self.monitor.log_model_metrics(
                model_key=self.target_model,
                metrics_data={
                    'processing_time': processing_time,
                    'memory_usage': sys.getsizeof(result['audio_output']) / (1024 * 1024)  # MB 단위
                }
            )
            
            # 3. 오디오 데이터 변환
            audio_data = np.array(result['audio_output'])
            print(f"Audio data shape: {audio_data.shape}")
            
            # 4. 오디오 재생
            return ipd.Audio(audio_data, rate=self.hps.data.sampling_rate, normalize=False)
            
        except Exception as e:
            print(f"음성 생성 중 오류 발생: {e}")
            raise

Overwriting inference.py


In [99]:
import boto3

s3 = boto3.resource('s3')
bucket = s3.Bucket('sagemaker-us-east-1-603420654815')
base_prefix = 'models/tts-models/'

# base_prefix 이후의 경로만 출력
for obj in bucket.objects.filter(Prefix=base_prefix):
    # base_prefix 길이만큼 잘라내고 출력
    relative_path = obj.key[len(base_prefix):]
    if relative_path:  # 빈 문자열이 아닌 경우만 출력
        print(relative_path)    

tts-vits_20241101_004530/model.tar.gz
tts-vits_20241101_012352/model.tar.gz
tts-vits_20241101_031137/model.tar.gz
tts-vits_20241101_053049/model.tar.gz


In [150]:
!pip install boto3



In [100]:
import boto3
import numpy as np
import time
import sys
import json
import IPython.display as ipd
from IPython.display import display
from text import text_to_sequence
import commons
import utils
from monitoring import MMEModelMonitor
from inference import VITSInference  # VITSInference 클래스 import


# 모델 키 정의
models = {
    "model1": "tts-vits_20241101_004530/model.tar.gz",
    "model2": "tts-vits_20241101_004530/model.tar.gz",
    "model3":"tts-vits_20241101_031137/model.tar.gz",
    "model4":"tts-vits_20241101_053049/model.tar.gz"

}

test_text = (
    "Amazon Bedrock is a fully managed service that offers "
    "a choice of high-performing foundation models from leading AI companies"
)

for model_name, model_key in models.items():
    print(f"\n=== Testing {model_name} ===")
    try:
        vits = VITSInference(
            endpoint_name='tts-endpoint',
            target_model=model_key
        )
        audio = vits.generate_audio(test_text)
        display(audio)
        
    except Exception as e:
        print(f"Error with {model_name}: {e}")


=== Testing model1 ===
Initialized VITSInference with endpoint: tts-endpoint, model: tts-vits_20241101_004530/model.tar.gz
Generating audio for text: Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models from leading AI companies
Input data prepared successfully
Invoking endpoint tts-endpoint
Input data shape: x=[1, 267], x_length=[1, 1]
Response received from endpoint
Inference completed successfully
Audio data shape: (179968,)



=== Testing model2 ===
Initialized VITSInference with endpoint: tts-endpoint, model: tts-vits_20241101_004530/model.tar.gz
Generating audio for text: Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models from leading AI companies
Input data prepared successfully
Invoking endpoint tts-endpoint
Input data shape: x=[1, 267], x_length=[1, 1]
Response received from endpoint
Inference completed successfully
Audio data shape: (183296,)



=== Testing model3 ===
Initialized VITSInference with endpoint: tts-endpoint, model: tts-vits_20241101_031137/model.tar.gz
Generating audio for text: Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models from leading AI companies
Input data prepared successfully
Invoking endpoint tts-endpoint
Input data shape: x=[1, 267], x_length=[1, 1]
Response received from endpoint
Inference completed successfully
Audio data shape: (179968,)



=== Testing model4 ===
Initialized VITSInference with endpoint: tts-endpoint, model: tts-vits_20241101_053049/model.tar.gz
Generating audio for text: Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models from leading AI companies
Input data prepared successfully
Invoking endpoint tts-endpoint
Input data shape: x=[1, 267], x_length=[1, 1]
Response received from endpoint
Inference completed successfully
Audio data shape: (189440,)
