In [42]:
import os
import tarfile
import time
from time import gmtime, strftime

import boto3
import sagemaker

In [43]:
# All AWS clients in this notebook are created from this session, so they
# share one region.
AWS_REGION = 'ap-northeast-2'  # Seoul
boto_session = boto3.Session(region_name=AWS_REGION)

In [44]:
# Look up the AWS account ID of the active credentials; it is used below to
# build the ECR image URI.
sts = boto_session.client('sts')
caller_identity = sts.get_caller_identity()
accountId = caller_identity['Account']

In [45]:
# SageMaker SDK session wrapping the boto session created above.
session = sagemaker.Session(
    boto_session=boto_session,
)

In [57]:
# Default SageMaker bucket for this account/region, plus a unique, UTC
# timestamped name for the training job.
bucket = session.default_bucket()
bucket_path = 's3://{}/'.format(bucket)
timestamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
job_name = 'yolo-' + timestamp
print('Training job', job_name)

Training job yolo-2023-02-16-02-24-23


In [60]:
# Resolve the IAM role the training job runs as.
# Inside SageMaker (notebook instance / Studio) the current identity is already
# a role and get_execution_role returns its ARN. Elsewhere (e.g. a local IAM
# user) it raises ValueError, so fall back to a named execution role.
# Catching only ValueError keeps genuine failures (bad credentials, typos,
# KeyboardInterrupt) visible instead of silently taking the fallback path.
SAGEMAKER_ROLE_NAME = 'AmazonSageMaker-ExecutionRole-20230112T181165'
try:
    role = sagemaker.get_execution_role(sagemaker_session=session)
except ValueError:
    iam = boto_session.client('iam')
    role = iam.get_role(RoleName=SAGEMAKER_ROLE_NAME)['Role']['Arn']

Couldn't call 'get_role' to get Role ARN from role name dongkyl to get Role path.


In [61]:
# ECR URI of the custom training image. The region comes from the boto session
# so it cannot drift from the region the rest of the notebook uses.
training_img = f'{accountId}.dkr.ecr.{boto_session.region_name}.amazonaws.com/yolov8-training-gpu:latest'

In [62]:
# Request body for SageMaker CreateTrainingJob, assembled from named parts.
# NOTE(review): the "train" channel points at the whole bucket, which also
# contains this job's output prefix (s3://{bucket}/yolov8), so earlier model
# artifacts get downloaded as training input. Consider narrowing S3Uri to the
# dataset prefix.
# NOTE(review): the image is tagged "-gpu" but ml.m5.xlarge is a CPU instance
# type. Confirm this is intentional.
algorithm_spec = {
    "TrainingImage": training_img,
    "TrainingInputMode": "File",
}
train_channel = {
    "ChannelName": "train",
    "DataSource": {
        "S3DataSource": {
            "S3DataType": "S3Prefix",
            "S3Uri": f"s3://{bucket}",
            "S3DataDistributionType": "FullyReplicated",
        }
    },
    "ContentType": "image/jpeg",
}
training_config = dict(
    TrainingJobName=job_name,
    AlgorithmSpecification=algorithm_spec,
    RoleArn=role,
    OutputDataConfig={"S3OutputPath": f"s3://{bucket}/yolov8"},
    InputDataConfig=[train_channel],
    ResourceConfig={
        "InstanceType": "ml.m5.xlarge",
        "InstanceCount": 1,
        "VolumeSizeInGB": 75,
    },
    StoppingCondition={"MaxRuntimeInSeconds": 86400},  # 24 hour cap
)

In [64]:
# Launch the training job and poll once a minute until it reaches a terminal
# state. "Stopped" is terminal as well. Without it, a job stopped from the
# console would leave this loop polling forever.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
POLL_INTERVAL_SECONDS = 60

client = boto_session.client('sagemaker')
client.create_training_job(**training_config)

job_description = client.describe_training_job(TrainingJobName=job_name)
status = job_description["TrainingJobStatus"]
print(status)
while status not in TERMINAL_STATUSES:
    time.sleep(POLL_INTERVAL_SECONDS)
    job_description = client.describe_training_job(TrainingJobName=job_name)
    status = job_description["TrainingJobStatus"]
    print(status)

# Fail here with SageMaker's reason rather than letting the download cell
# below fail on a missing model.tar.gz.
if status == "Failed":
    raise RuntimeError(
        f"Training job {job_name} failed: "
        f"{job_description.get('FailureReason', 'no FailureReason reported')}"
    )

InProgress
InProgress
InProgress
InProgress
Completed


# Test: load the trained weights and run inference locally

In [65]:
!wget https://raw.githubusercontent.com/ultralytics/ultralytics/main/ultralytics/models/v8/yolov8n.yaml

--2023-02-16 11:29:13--  https://raw.githubusercontent.com/ultralytics/ultralytics/main/ultralytics/models/v8/yolov8n.yaml
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1203 (1.2K) [text/plain]
Saving to: ‘yolov8n.yaml’


2023-02-16 11:29:13 (26.7 MB/s) - ‘yolov8n.yaml’ saved [1203/1203]



In [66]:
import torch
from ultralytics import YOLO

In [67]:
# Build the YOLOv8n network from its architecture yaml. The trained weights
# from the SageMaker job are loaded into it further down.
model_config_path = 'yolov8n.yaml'
model = YOLO(model_config_path)


                   from  n    params  module                                       arguments                     
  0                  -1  1       464  ultralytics.nn.modules.Conv                  [3, 16, 3, 2]                 
  1                  -1  1      4672  ultralytics.nn.modules.Conv                  [16, 32, 3, 2]                
  2                  -1  1      7360  ultralytics.nn.modules.C2f                   [32, 32, 1, True]             
  3                  -1  1     18560  ultralytics.nn.modules.Conv                  [32, 64, 3, 2]                
  4                  -1  2     49664  ultralytics.nn.modules.C2f                   [64, 64, 2, True]             
  5                  -1  1     73984  ultralytics.nn.modules.Conv                  [64, 128, 3, 2]               
  6                  -1  2    197632  ultralytics.nn.modules.C2f                   [128, 128, 2, True]           
  7                  -1  1    295424  ultralytics.nn.modules.Conv                  [128

In [80]:
# Download the model artifact produced by the training job. SageMaker is asked
# where it stored the artifact, instead of rebuilding the key by hand. This
# keeps the cell working if OutputDataConfig.S3OutputPath changes.
artifact_uri = (
    boto_session.client('sagemaker')
    .describe_training_job(TrainingJobName=job_name)['ModelArtifacts']['S3ModelArtifacts']
)
# 's3://<bucket>/<key>'.split('/', 3) -> ['s3:', '', '<bucket>', '<key>']
artifact_bucket, artifact_key = artifact_uri.split('/', 3)[2:]
s3 = boto_session.resource('s3')
b = s3.Bucket(name=artifact_bucket)
b.download_file(Key=artifact_key, Filename='model.tar.gz')

In [81]:
# Unpack the training artifact with the standard library instead of shelling
# out to `tar`, so the cell works wherever the kernel runs. The member names
# are listed, like `tar -v` did.
# NOTE(review): extractall trusts member paths. Fine for our own artifact, but
# do not reuse this on archives from untrusted sources.
with tarfile.open('model.tar.gz', 'r:gz') as model_archive:
    for member_name in model_archive.getnames():
        print(member_name)
    model_archive.extractall()

x model.pth


In [82]:
# Load the weights produced by the SageMaker job into the local model.
# map_location='cpu' lets a checkpoint saved on a GPU training instance load on
# a CPU-only machine. load_state_dict then copies the tensors into the model.
# NOTE(review): torch.load unpickles the file, so only load artifacts you
# produced. On torch >= 1.13, weights_only=True would restrict this.
model.model.load_state_dict(torch.load('model.pth', map_location='cpu'))

<All keys matched successfully>

In [84]:
# Smoke-test the trained model on a sample image. show=True asks ultralytics
# to display the annotated prediction. The bare name on the last line renders
# the results object.
SAMPLE_IMAGE_URL = 'https://ultralytics.com/images/bus.jpg'
results = model.predict(SAMPLE_IMAGE_URL, show=True)
results


Found https://ultralytics.com/images/bus.jpg locally at bus.jpg
image 1/1 /Users/dongkyl/git/sagemaker-custom-docker-yolov8/bus.jpg: 640x480 4 0s, 1 5, 1 11, 55.5ms
Speed: 0.5ms pre-process, 55.5ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 640)
