In [None]:
%pip -q install sagemaker==2.219.0 boto3

import boto3, sagemaker, json, os, pandas as pd, subprocess, tarfile, time
from sagemaker.estimator import Estimator

# Name shared by the ECR repository and the locally built image tag.
repository = "wine-quality-custom"

# SageMaker session, plus the region, AWS account id and execution role
# that the rest of the notebook builds on.
session = sagemaker.Session()
region = session.boto_region_name
account = boto3.client("sts").get_caller_identity()["Account"]
role = sagemaker.get_execution_role()

# Region-pinned ECR client; creating the repository is a no-op if it exists.
ecr = boto3.client("ecr", region_name=region)
try:
    ecr.create_repository(repositoryName=repository)
except ecr.exceptions.RepositoryAlreadyExistsException:
    # Already created by an earlier run -- nothing to do.
    pass

# Fully qualified image reference used for tagging, pushing and training.
ecr_uri = f"{account}.dkr.ecr.{region}.amazonaws.com/{repository}:latest"
ecr_uri


In [None]:
# Authenticate Docker to ECR
auth = boto3.client("ecr").get_authorization_token()
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account}.dkr.ecr.{region}.amazonaws.com


In [None]:
# Build & push image from sagemaker/custom_container/
!cd ../custom_container && docker build -t {repository}:latest .
!docker tag {repository}:latest {ecr_uri}
!docker push {ecr_uri}


In [None]:
# Prepare training data: download the red and white UCI wine-quality CSVs
# (semicolon-delimited), stack them into a single frame, write it to a local
# train.csv and upload that file to the session's default S3 bucket.
UCI_RED = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
UCI_WHITE = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"

red, white = (pd.read_csv(url, sep=";") for url in (UCI_RED, UCI_WHITE))
df = pd.concat([red, white], ignore_index=True)
df.to_csv("train.csv", index=False)

# Everything for this experiment lives under one key prefix in the bucket.
bucket = session.default_bucket()
prefix = "wine-quality-custom"
s3_train = session.upload_data(
    path="train.csv",
    bucket=bucket,
    key_prefix=f"{prefix}/input/train",
)
s3_train


In [None]:
# Run a SageMaker training job with the custom image pushed to ECR:
# one ml.m5.large instance, artifacts written under the output prefix.
est = Estimator(
    image_uri=ecr_uri,
    role=role,
    sagemaker_session=session,
    instance_type="ml.m5.large",
    instance_count=1,
    output_path=f"s3://{bucket}/{prefix}/output",
    environment={},  # optional env vars your program may read; none set
)

# The uploaded CSV is passed as the "train" channel; the container is expected
# to find it at /opt/ml/input/data/train/train.csv.
# NOTE(review): the earlier note also said program/train.py now reads remote
# data itself -- confirm whether the container still consumes this channel.
est.fit(inputs={"train": s3_train})
