# SageMaker Training Job 訓練 Ultralytics YOLO11（YOLOv11）

此 Notebook 會：
1. 生成訓練程式 `train.py` 與 `requirements.txt`
2. 以 **SageMaker PyTorch Estimator** 啟動 Training Job
3. 在 Notebook 內**串流顯示訓練日誌**，並**嘗試動態解析每個 epoch 指標，畫出收斂曲線**
4. 訓練完成後下載 artifacts（含 `best.pt`、`runs/`、`results.csv`、`results.png`）並顯示完整曲線

> 注意：動態曲線解析依賴 Ultralytics 的日誌格式；若版本輸出格式不同，可能解析不到任何指標。但只要 Training Job **成功完成**，就能用 `results.csv` / `results.png` 畫出完整收斂曲線。


## 0) 先決條件

- 已在 SageMaker Studio / Notebook Instance 執行
- IAM role 具備 SageMaker 與 S3 權限
- 你已把 YOLO 格式資料集打包成 zip 放在 S3

資料 zip 內部建議結構：
```
mydata/
  images/train
  images/val
  labels/train
  labels/val
```


In [None]:
import os, json, tarfile, re, time
from pathlib import Path

import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import TrainingInput
from sagemaker.s3 import S3Downloader

import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, clear_output, Image

# Shared SageMaker session, execution role and AWS API clients used by later cells.
sess = sagemaker.Session()
region = sess.boto_region_name
role = sagemaker.get_execution_role()
# One client per service, both pinned to the session's region.
sm, logs = (boto3.client(service, region_name=region) for service in ('sagemaker', 'logs'))

print('region:', region)
print('default bucket:', sess.default_bucket())


## 1) 參數設定（請改成你的 S3 路徑與類別名稱）

In [None]:
# ====== Settings to edit ======
# S3 URI of the YOLO-format dataset zip  <<<<<< change to your dataset
S3_DATA_ZIP = 's3://YOUR-BUCKET/datasets/mydata.zip'
# Zip file name as train.py looks it up inside the 'training' channel directory.
# Derived from the S3 key so the two settings cannot drift apart.
DATA_ZIP_FILENAME = S3_DATA_ZIP.rstrip('/').rsplit('/', 1)[-1]

# Class names (replace with your classes)
CLASS_NAMES = ['class0', 'class1']

# YOLO11 weights (alternatives: yolo11s.pt / yolo11m.pt / yolo11l.pt / yolo11x.pt)
YOLO_MODEL = 'yolo11n.pt'

# Training hyperparameters
EPOCHS = 50
IMGSZ = 640
BATCH = 16
WORKERS = 4

# Training hardware (GPU recommended: g5 / g4dn)
INSTANCE_TYPE = 'ml.g5.2xlarge'
INSTANCE_COUNT = 1

# S3 prefix that receives the job's output artifacts
OUTPUT_S3 = f"s3://{sess.default_bucket()}/yolo11/output/"

print('S3_DATA_ZIP:', S3_DATA_ZIP)
print('OUTPUT_S3:', OUTPUT_S3)


## 2) 生成訓練程式碼（train.py）與 requirements.txt

重點：
- SageMaker 會把 `training` channel 掛載在 `/opt/ml/input/data/training/`
- artifacts 必須寫到 `/opt/ml/model/`：訓練結束後 SageMaker 會把該目錄打包成 `model.tar.gz` 並上傳到 `output_path`
- 我們把 Ultralytics 的 `runs/` 放到 `/opt/ml/model/runs/`，方便事後下載 `results.csv` / `results.png`


In [None]:
# Generate the SageMaker entry point (src/train.py) and its pip requirements.
# The whole src/ directory is later passed to the estimator as source_dir.
src_dir = Path('src')
src_dir.mkdir(exist_ok=True)

train_py = r'''
"""SageMaker entry point: unpack a YOLO-format dataset zip and train Ultralytics YOLO."""
import argparse
import json
import os
import shutil
import zipfile
from pathlib import Path

import yaml
from ultralytics import YOLO


def _find_dataset_root(extract_dir):
    """Return the directory under extract_dir that contains images/train.

    Accepts both a zip with a top-level folder (mydata/images/train) and a flat
    zip (images/train at the root); macOS '__MACOSX' metadata folders are skipped.
    Raises FileNotFoundError when no images/train directory exists.
    """
    if (extract_dir / 'images' / 'train').is_dir():
        return extract_dir
    candidates = sorted(
        (p for p in extract_dir.glob('**/images/train')
         if p.is_dir() and '__MACOSX' not in p.parts),
        # Shallowest match first, then alphabetical, so the choice is deterministic.
        key=lambda p: (len(p.parts), str(p)),
    )
    if not candidates:
        raise FileNotFoundError(f'No images/train directory found under {extract_dir}')
    return candidates[0].parent.parent


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--data_zip', type=str, default='mydata.zip')
    p.add_argument('--model', type=str, default='yolo11n.pt')
    p.add_argument('--epochs', type=int, default=50)
    p.add_argument('--imgsz', type=int, default=640)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--workers', type=int, default=4)
    p.add_argument('--class_names_json', type=str, required=True)
    args = p.parse_args()

    # SageMaker publishes the channel and model directories as environment
    # variables; fall back to the conventional paths when they are absent.
    sm_data_dir = Path(os.environ.get('SM_CHANNEL_TRAINING', '/opt/ml/input/data/training'))
    sm_model_dir = Path(os.environ.get('SM_MODEL_DIR', '/opt/ml/model'))
    sm_model_dir.mkdir(parents=True, exist_ok=True)

    work_dir = Path('/tmp/data')
    work_dir.mkdir(parents=True, exist_ok=True)

    zip_path = sm_data_dir / args.data_zip
    if not zip_path.exists():
        raise FileNotFoundError(f'Dataset zip not found: {zip_path}')

    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(work_dir)

    dataset_root = _find_dataset_root(work_dir)

    class_names = json.loads(args.class_names_json)
    data_yaml = {
        'path': str(dataset_root),
        'train': str(dataset_root / 'images/train'),
        'val': str(dataset_root / 'images/val'),
        'names': class_names,
    }
    yaml_path = work_dir / 'data.yaml'
    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(data_yaml, f)

    model = YOLO(args.model)
    project_dir = sm_model_dir / 'runs'
    model.train(
        data=str(yaml_path),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        workers=args.workers,
        project=str(project_dir),
        name='train',
        # Always write into runs/train (never train2, ...) so the weight paths
        # below and the notebook's results.csv lookup stay valid.
        exist_ok=True,
        verbose=True,
    )

    weights_dir = project_dir / 'train' / 'weights'
    for weight_name in ('best.pt', 'last.pt'):
        weight_src = weights_dir / weight_name
        if weight_src.exists():
            shutil.copy2(weight_src, sm_model_dir / weight_name)
            print('Saved', weight_name, 'to', sm_model_dir / weight_name)
        else:
            print('WARNING: weights file not found:', weight_src)

    print('Training complete.')


if __name__ == '__main__':
    main()
'''

(src_dir / 'train.py').write_text(train_py, encoding='utf-8')
(src_dir / 'requirements.txt').write_text('ultralytics\nopencv-python\npyyaml\n', encoding='utf-8')
print('Wrote:', src_dir / 'train.py')
print('Wrote:', src_dir / 'requirements.txt')


## 3) 建立 Estimator 並啟動 Training Job（非阻塞）

In [None]:
# Hyperparameter names mirror train.py's argparse options (--data_zip, --epochs, ...).
hyperparameters = dict(
    data_zip=DATA_ZIP_FILENAME,
    model=YOLO_MODEL,
    epochs=EPOCHS,
    imgsz=IMGSZ,
    batch=BATCH,
    workers=WORKERS,
    # The class list travels as a JSON string; train.py decodes it with json.loads.
    class_names_json=json.dumps(CLASS_NAMES, ensure_ascii=False),
)

estimator_kwargs = dict(
    entry_point='train.py',
    source_dir=str(src_dir),
    role=role,
    framework_version='2.2',
    py_version='py310',
    instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE,
    hyperparameters=hyperparameters,
    output_path=OUTPUT_S3,
)
est = PyTorch(**estimator_kwargs)

# Unique, time-stamped job name so later cells can look the job up.
job_name = 'yolo11-train-{}'.format(int(time.time()))
print('TrainingJobName:', job_name)

train_channels = {'training': TrainingInput(S3_DATA_ZIP, content_type='application/zip')}
# wait=False returns immediately; the next cell monitors the job.
est.fit(inputs=train_channels, job_name=job_name, wait=False)
print('已送出 Training Job。接著跑下一格開始動態監控。')


## 4) 動態顯示訓練過程與收斂曲線（即時）

此格會定期從 CloudWatch Logs 拉取最新日誌並嘗試解析 epoch 指標，且會持續輪詢，直到 Training Job 結束（Completed / Failed / Stopped）為止。


In [None]:
def get_training_job_status(job_name: str):
    """Return ``(TrainingJobStatus, full DescribeTrainingJob response)`` for *job_name*."""
    description = sm.describe_training_job(TrainingJobName=job_name)
    status = description['TrainingJobStatus']
    return status, description

def find_log_stream(job_name: str, log_group='/aws/sagemaker/TrainingJobs'):
    """Return the most recently active CloudWatch log stream of a training job.

    Training-job streams are typically named '<job_name>/algo-N-<timestamp>', so
    the prefix includes the trailing '/' to avoid matching a different job whose
    name merely starts with job_name.

    Returns None when no stream exists yet, including when the log group itself
    has not been created.
    """
    paginator = logs.get_paginator('describe_log_streams')
    try:
        for page in paginator.paginate(logGroupName=log_group, logStreamNamePrefix=f'{job_name}/'):
            streams = page.get('logStreams', [])
            if streams:
                newest = max(streams, key=lambda s: s.get('lastEventTimestamp', 0))
                return newest['logStreamName']
    except logs.exceptions.ResourceNotFoundException:
        # The log group may not exist yet (e.g. before any job in this region has logged).
        return None
    return None

def fetch_logs(log_group, log_stream, next_token=None):
    """Read one page of events from a CloudWatch log stream, oldest first.

    Returns ``(events, nextForwardToken)``; pass the token back as *next_token*
    to continue reading after the events already returned.
    """
    optional = {'nextToken': next_token} if next_token else {}
    response = logs.get_log_events(
        logGroupName=log_group,
        logStreamName=log_stream,
        startFromHead=True,
        **optional,
    )
    events = response.get('events', [])
    forward_token = response.get('nextForwardToken')
    return events, forward_token

# Best-effort parsing of per-epoch metrics from training log lines.
# Parsed rows accumulate in epoch_metrics (one dict per epoch) for live plotting.
epoch_metrics = []

# Ultralytics training progress line, e.g.
#   "      3/50      2.61G      1.234      2.345      1.456         45        640: 100%|..."
# i.e. "epoch/total", a GPU-memory column, then five numbers (for detection
# presumably box_loss, cls_loss, dfl_loss, Instances, Size - verify against
# your Ultralytics version).
ultralytics_epoch_re = re.compile(
    r"^\s*(\d+)/\d+\s+\S+"
    r"\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)"
)
# Generic fallback: a bare epoch number followed by five numbers.
epoch_line_re = re.compile(r"^\s*(\d+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)\s+([0-9.eE+-]+)")


def try_parse_epoch_line(msg: str):
    """Parse one log message into ``{'epoch', 'v1'..'v5'}``, or return None.

    A progress bar redraws itself with carriage returns, so one message may hold
    several updates; segments are tried from the last (most recent) to the first.
    The Ultralytics format is tried before the generic fallback pattern. Matches
    whose numeric fields are not valid floats (e.g. a lone '-') are ignored
    instead of raising ValueError.
    """
    for segment in reversed(msg.split('\r')):
        for pattern in (ultralytics_epoch_re, epoch_line_re):
            m = pattern.match(segment)
            if m is None:
                continue
            try:
                vals = [float(m.group(i)) for i in range(2, 7)]
            except ValueError:
                continue
            return {'epoch': int(m.group(1)), 'v1': vals[0], 'v2': vals[1], 'v3': vals[2], 'v4': vals[3], 'v5': vals[4]}
    return None

# Poll the job status and its CloudWatch log stream until the job reaches a
# terminal state: print the last 50 log lines and live-plot parsed epoch metrics.
log_group = '/aws/sagemaker/TrainingJobs'
TERMINAL_STATUSES = ('Completed', 'Failed', 'Stopped')
log_stream = None
token = None
seen = set()  # (timestamp, message) pairs already handled
tail = []


def _ingest_events(events):
    """Record unseen log events, update epoch_metrics, return their non-empty lines."""
    new_lines = []
    for e in events:
        key = (e.get('timestamp'), e.get('message'))
        if key in seen:
            continue
        seen.add(key)
        msg = (e.get('message') or '').rstrip('\n')
        if not msg:
            continue
        new_lines.append(msg)
        parsed = try_parse_epoch_line(msg)
        if parsed is None:
            continue
        if epoch_metrics and parsed['epoch'] == epoch_metrics[-1]['epoch']:
            # Progress lines repeat within an epoch; keep the latest values
            # rather than the first, earliest ones.
            epoch_metrics[-1] = parsed
        else:
            epoch_metrics.append(parsed)
    return new_lines


def _poll_log_lines(start_token, max_pages=20):
    """Read log pages until caught up; return (new lines, forward token).

    GetLogEvents returns the same forward token once the end of the stream is
    reached, so an unchanged token means there is nothing more to read.
    """
    new_lines = []
    current = start_token
    for _ in range(max_pages):
        events, next_token = fetch_logs(log_group, log_stream, current)
        new_lines.extend(_ingest_events(events))
        if not next_token or next_token == current:
            break
        current = next_token
    return new_lines, current


print('尋找 CloudWatch Log Stream...')
# The stream only appears once the training container starts, which can take
# several minutes; keep waiting as long as the job itself has not ended.
while True:
    log_stream = find_log_stream(job_name, log_group=log_group)
    if log_stream:
        break
    status, desc = get_training_job_status(job_name)
    if status in TERMINAL_STATUSES:
        raise RuntimeError(
            f'Training job 已結束（{status}）但找不到 log stream。'
            f'FailureReason: {desc.get("FailureReason")}'
        )
    time.sleep(10)

print('log_stream:', log_stream)

while True:
    status, desc = get_training_job_status(job_name)
    new_lines, token = _poll_log_lines(token)

    tail = (tail + new_lines)[-50:]
    clear_output(wait=True)
    print('TrainingJob:', job_name)
    print('Status:', status)
    print('--- logs tail ---')
    for line in tail:
        print(line)

    if len(epoch_metrics) >= 2:
        df_live = pd.DataFrame(epoch_metrics).sort_values('epoch')
        display(df_live.tail(10))
        plt.figure()
        for col in ('v1', 'v2', 'v3'):
            plt.plot(df_live['epoch'], df_live[col], label=col)
        plt.xlabel('epoch')
        plt.legend()
        plt.title('Live convergence (best-effort parsed from logs)')
        plt.show()

    if status in TERMINAL_STATUSES:
        print('\nTraining job finished with status:', status)
        if status != 'Completed':
            print('FailureReason:', desc.get('FailureReason'))
        break

    time.sleep(10)


## 5) 下載 artifacts 並畫出完整收斂曲線（results.csv / results.png）

In [None]:
# Block until the job reaches a terminal state (so this cell also works on its
# own), then download and unpack model.tar.gz and show the curves that
# Ultralytics wrote into runs/train.
while True:
    status, _ = get_training_job_status(job_name)
    if status in ['Completed', 'Failed', 'Stopped']:
        print('Final status:', status)
        break
    time.sleep(20)

if status != 'Completed':
    raise RuntimeError(f'Training job not completed: {status}')

desc = sm.describe_training_job(TrainingJobName=job_name)
model_artifact = desc['ModelArtifacts']['S3ModelArtifacts']
print('Model artifact:', model_artifact)

out_dir = Path('artifacts') / job_name
out_dir.mkdir(parents=True, exist_ok=True)

S3Downloader.download(model_artifact, str(out_dir))
local_tar = out_dir / 'model.tar.gz'
print('Downloaded:', local_tar)

with tarfile.open(local_tar, 'r:gz') as t:
    if hasattr(tarfile, 'data_filter'):
        # The 'data' filter rejects absolute paths, '..' members and special files.
        t.extractall(path=out_dir, filter='data')
    else:
        # Interpreters without extraction filters only support the legacy call.
        t.extractall(path=out_dir)

run_dir = out_dir / 'runs' / 'train'
best_pt = out_dir / 'best.pt'
results_png = run_dir / 'results.png'
results_csv = run_dir / 'results.csv'

print('best.pt exists:', best_pt.exists())
print('results.png exists:', results_png.exists())
print('results.csv exists:', results_csv.exists())

if results_png.exists():
    display(Image(filename=str(results_png)))

if results_csv.exists():
    df = pd.read_csv(results_csv)
    # Some Ultralytics versions pad the CSV header names with spaces; without
    # stripping, 'epoch' is not found and gets plotted as if it were a metric.
    df.columns = df.columns.str.strip()
    display(df.head())

    cols = [c for c in df.columns if c.lower() not in ['epoch', 'time']]
    x = df['epoch'] if 'epoch' in df.columns else range(len(df))

    to_plot = cols[:8]
    plt.figure()
    for c in to_plot:
        plt.plot(x, df[c], label=c)
    plt.xlabel('epoch')
    plt.legend()
    plt.title('Convergence curves (from results.csv)')
    plt.show()
else:
    print('找不到 results.csv；請確認 runs/ 是否有被寫入 /opt/ml/model/runs')
