In [49]:
import argparse
import boto3
import evaluate
import importlib
import json
import lighteval
import os
import pathlib
import requests
import shutil
import sys
import tarfile
import time
import torch
import transformers
import uuid
import wandb

import awswrangler as wr
import numpy as np

from botocore.exceptions import ClientError
from datasets import load_dataset, DatasetDict, Dataset
from datetime import datetime, timezone
from IPython.display import display
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
from sagemaker import image_uris, utils as sm_utils
from sagemaker.huggingface.processing import HuggingFaceProcessor
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.processing import FrameworkProcessor
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.workflow.steps import ProcessingStep
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.session import get_execution_role
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoConfig, 
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)
from typing import List, Union, Optional


# Adding ../01_modules or ./01_modules to the system path so that we can load modules from 
# there as well
if '__file__' in globals():
    script_dir = pathlib.Path(__file__).parent.resolve()
else:
    script_dir = pathlib.Path().absolute()
modules_path_in_dev = os.path.abspath(os.path.join(script_dir, '..', '01_modules'))
modules_path_in_prod = os.path.abspath(os.path.join(script_dir, '01_modules'))
if os.path.exists(modules_path_in_dev):
    sys.path.append(modules_path_in_dev)
if os.path.exists(modules_path_in_prod):
    sys.path.append(modules_path_in_prod)


# # Jupyter only reads a local module the first time after 
# # kernel start. Re-running a cell with 
# # "from mymodulename import *" would not change
# # anything, even if the imported module has since changed.
# # As a workaround, we need to directly load the module, 
# # use importlib.reload to reload it and then import * 
import utils
_ = importlib.reload(utils)
import config
_ = importlib.reload(config) 

_ = """
ml.g6.xlarge:   6
ml.g6.2xlarge:  3
ml.g6.4xlarge:  3
ml.g6.8xlarge:  3
ml.g6.12xlarge: 3
ml.g6.16xlarge: 3
ml.g6.24xlarge: 3
ml.g6.48xlarge: 3

google-bert/bert-base-uncased
distilbert/distilbert-base-uncased
microsoft/deberta-v3-base
FacebookAI/roberta-base
answerdotai/ModernBERT-base
answerdotai/ModernBERT-large
allenai/scibert_scivocab_uncased
google/bigbird-roberta-base
allenai/longformer-base-4096
"""

utils.py loaded: v0.2.12
config.py loaded: v0.1


In [52]:
JOB_NAME = utils.create_supervised_multiclass_classification_training_job(
    SCRIPT_FILEPATH = script_dir,
    MODEL_NAME='distilbert-base-uncased',
    INSTANCE_TYPE='ml.g6.8xlarge',
    HF_DATASET_SUFFIX='_Title_TopicIndex',
    TEXT_KEY='title',
    LABEL_TYPE='topic',
    SAMPLE='100', # must be string
    MAX_RUNTIME_S=2*60*60
)
JOB_NAME

Training job created: distilbert-topic-title-s100-0906222925


'distilbert-topic-title-s100-0906222925'

In [53]:
JOB_NAME = utils.create_supervised_multiclass_classification_training_job(
    SCRIPT_FILEPATH = script_dir,
    MODEL_NAME='google-bert/bert-base-uncased',
    INSTANCE_TYPE='ml.g6.24xlarge',
    HF_DATASET_SUFFIX='_Abstract_TopicIndex',
    TEXT_KEY='abstract',
    LABEL_TYPE='topic',
    SAMPLE='100', # must be string
    MAX_RUNTIME_S=24*60*60
)
JOB_NAME

Training job created: google-topic-abstract-s100-0907004910


'google-topic-abstract-s100-0907004910'

In [27]:
_ = """
ml.g6.xlarge:   6
ml.g6.2xlarge:  3
ml.g6.4xlarge:  3
ml.g6.8xlarge:  3
ml.g6.12xlarge: 3
ml.g6.16xlarge: 3
ml.g6.24xlarge: 3
ml.g6.48xlarge: 3
"""

INSTANCE_TYPE = 'ml.g6.4xlarge'
ENTRY_POINT = '05_tuning_basic/05_12_tuning_basic_simple.py'
MODEL_NAME = 'distilbert-base-uncased'

HF_DATASET_SUFFIX = '_Title_SubfieldIndex'
LABEL_TYPE = 'subfield'
TEXT_KEY = 'title'
TEXT_KEY_RENAME_TO = 'text'
LABEL_KEY_RENAME_TO = 'label'
SAMPLE = '1'          # must be string
VOLUME_SIZE_GB = 450
MAX_RUNTIME_S = 60*60



NOW = datetime.now().strftime('%Y%m%d%H%M%S')
JOB_NAME = f'{MODEL_NAME}{HF_DATASET_SUFFIX}_sample-{SAMPLE}_{NOW}_{INSTANCE_TYPE.replace(".","-")}' # TODO: add more params
SAGEMAKER_CLIENT = boto3.client('sagemaker', region_name=config.AWS_REGION)
S3_CLIENT = boto3.client('s3')
EXECUTION_ROLE = get_execution_role()
SCRIPT_FILEPATH = script_dir
SOURCE_DIRPATH = SCRIPT_FILEPATH.parents[0]
ROOT_DIRPATH = SCRIPT_FILEPATH.parents[1]
TEMP_DIRPATH = pathlib.Path(f'./_code/{JOB_NAME}')
TAR_FILEPATH = pathlib.Path(f'./_tar/source-{JOB_NAME}.tar.gz')
ENV_VARS = {
    'HUGGINGFACE_HUB_CACHE': '/tmp/.cache'
}
# config.DEFAULT_S3_BUCKET_NAME
print('SOURCE_DIRPATH', SOURCE_DIRPATH)

SOURCE_DIRPATH /home/sagemaker-user/research_methodology_extraction/src


In [28]:
if TEMP_DIRPATH.parents[0].exists():
    shutil.rmtree(TEMP_DIRPATH.parents[0])
TEMP_DIRPATH.mkdir(parents=True, exist_ok=True)

if TAR_FILEPATH.parents[0].exists():
    shutil.rmtree(TAR_FILEPATH.parents[0])
TAR_FILEPATH.parents[0].mkdir(parents=True, exist_ok=True)

ignore_names = {'__pycache__', '.ipynb_checkpoints'}
for item in SOURCE_DIRPATH.iterdir():
    name = item.name
    if name in ignore_names:
        continue
    dest = TEMP_DIRPATH / name
    if item.is_dir():
        # print('item.is_dir()', item, dest)
        for item2 in item.iterdir():
            name2 = item2.name
            if name2 in ignore_names:
                continue
            dest2 = TEMP_DIRPATH / name / name2
            if item2.is_dir():
                pass
                # print('item2.is_dir() NOT COPYING', item2, dest2)
                # shutil.copytree(item, dest, ignore=shutil.ignore_patterns('__pycache__', '*.pyc', '*.pyo', '*.tmp'), dirs_exist_ok=True)
            else:
                # print('else', item2, dest2)
                dest.mkdir(parents=True, exist_ok=True)
                shutil.copy2(item2, dest2)
        # shutil.copytree(item, dest, ignore=shutil.ignore_patterns('__pycache__', '*.pyc', '*.pyo', '*.tmp'), dirs_exist_ok=True)
    else:
        # print('else', item, dest)
        shutil.copy2(item, dest)

shutil.copy2(ROOT_DIRPATH / 'requirements_train.txt', TEMP_DIRPATH / 'requirements.txt')

# Tar the temp_dir (its contents become root of /opt/ml/code)
with tarfile.open(TAR_FILEPATH, 'w:gz') as tar:
    tar.add(str(TEMP_DIRPATH), arcname='.')


In [29]:
code_s3_key = f'02_code/train/{JOB_NAME}/source.tar.gz'
S3_CLIENT.upload_file(str(TAR_FILEPATH), config.DEFAULT_S3_BUCKET_NAME, code_s3_key)
code_s3_uri = f's3://{config.DEFAULT_S3_BUCKET_NAME}/{code_s3_key}'

In [30]:
image_uri = image_uris.retrieve(
    framework='huggingface',
    region=config.AWS_REGION,
    version='4.26.0',                 # transformers version
    py_version='py39',
    instance_type=INSTANCE_TYPE,
    image_scope='training',
    base_framework_version='pytorch1.13.1'
)
print('Using training image:', image_uri)

image_uri = image_uris.retrieve(
    framework='huggingface',
    region=config.AWS_REGION,
    version='4.49.0',                 # transformers version
    py_version='py311',
    instance_type=INSTANCE_TYPE,
    image_scope='training',
    base_framework_version='pytorch2.5.1'
)
print('Using training image:', image_uri)

Using training image: 763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04
Using training image: 763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.5.1-transformers4.49.0-gpu-py311-cu124-ubuntu22.04


In [31]:
''

''

In [32]:
hyperparameters = {
    # SageMaker training toolkit special keys:
    'sagemaker_program': ENTRY_POINT,
    'sagemaker_submit_directory': code_s3_uri,
    'sagemaker_container_log_level': '20',
    'sagemaker_region': config.AWS_REGION,

    # Your script args:
    'runtype': 'prod',
    'now': NOW,
    'instance_type': INSTANCE_TYPE,
    'model_name': MODEL_NAME,
    
    'hf_dataset_suffix': HF_DATASET_SUFFIX,
    'label_type': LABEL_TYPE,
    'text_key': TEXT_KEY,
    'text_key_rename_to': TEXT_KEY_RENAME_TO,
    'label_key_rename_to': LABEL_KEY_RENAME_TO,
    'sample': SAMPLE,          # must be string
    # 'epochs': '5',
    # 'train_batch_size': '32',
    # 'eval_batch_size': '64',
    # 'warmup_steps': '500',
    # 'learning_rate': '5e-5'
}

In [33]:
input_data_config = [
    {
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://sagemaker-research-methodology-extraction/01_data/03_core/unified_works_train/',
                'S3DataDistributionType': 'FullyReplicated'
            }
        },
        'InputMode': 'File'
    }
]

In [34]:
try:
    resp = SAGEMAKER_CLIENT.create_training_job(
        TrainingJobName=JOB_NAME,
        RoleArn=EXECUTION_ROLE,
        AlgorithmSpecification={
            'TrainingImage': image_uri,
            'TrainingInputMode': 'File'
        },
        HyperParameters=hyperparameters,
        InputDataConfig=input_data_config,
        OutputDataConfig={
            'S3OutputPath': f's3://{config.DEFAULT_S3_BUCKET_NAME}/03_training_output/{JOB_NAME}'
        },
        ResourceConfig={
            'InstanceType': INSTANCE_TYPE,
            'InstanceCount': 1,
            'VolumeSizeInGB': VOLUME_SIZE_GB
        },
        StoppingCondition={'MaxRuntimeInSeconds': MAX_RUNTIME_S},
        Environment=ENV_VARS,
        EnableManagedSpotTraining=False
    )
    print('Training job created:', JOB_NAME)
except ClientError as e:
    print('create_training_job failed:')
    print(e.response.get('Error', e))
    raise

Training job created: hf-boto-20250906175041-ml-g6-4xlarge


In [35]:


# 7. (Optional) simple waiter + log group polling
print('Polling status (CTRL+C to stop)...')
LOGS_CLIENT = boto3.client('logs', region_name=config.AWS_REGION)
log_group = '/aws/sagemaker/TrainingJobs'

def stream_logs():
    seen = set()
    while True:
        desc = SAGEMAKER_CLIENT.describe_training_job(TrainingJobName=JOB_NAME)
        status = desc['TrainingJobStatus']
        print('Status:', status)
        # Try to fetch log streams
        try:
            streams = LOGS_CLIENT.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=JOB_NAME
            ).get('logStreams', [])
            for s in streams:
                stream_name = s['logStreamName']
                events = LOGS_CLIENT.get_log_events(
                    logGroupName=log_group,
                    logStreamName=stream_name,
                    startFromHead=True
                )['events']
                for ev in events:
                    if ev['eventId'] in seen: 
                        continue
                    seen.add(ev['eventId'])
                    print(ev['message'].rstrip())
        except LOGS_CLIENT.exceptions.ResourceNotFoundException:
            pass

        if status in ('Completed','Failed','Stopped'):
            print('Final status:', status)
            if status == 'Failed':
                print('Failure reason:', desc.get('FailureReason'))
            break
        time.sleep(30)

# tail inside script run:
stream_logs()

Polling status (CTRL+C to stop)...
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
Status: InProgress
