In [None]:
import getpass
import json
import os
import shutil
import urllib.request

In [None]:
train_raw_file_name = 'train-v1.1.json'
valid_raw_file_name = 'valid-v1.1.json'
base_url = 'https://github.com/yahoojapan/JGLUE/raw/main/datasets/jsquad-v1.1/'
data_dir = 'data/'

# Start from an empty data directory so files from earlier runs never leak in.
shutil.rmtree(data_dir, ignore_errors=True)
os.makedirs(data_dir, exist_ok=True)

# Download each split straight into data_dir. The previous `wget` + `mv` pair wrote
# into the working directory first; a leftover file there made wget save to
# "<name>.1" and `mv` then moved the stale copy instead of the fresh download.
for raw_file_name in (train_raw_file_name, valid_raw_file_name):
    url = base_url + raw_file_name
    print(url)
    with urllib.request.urlopen(url) as response, \
            open(os.path.join(data_dir, raw_file_name), 'wb') as out_file:
        shutil.copyfileobj(response, out_file)

In [None]:
def transform(path):
    """Flatten a SQuAD-format JSON file (e.g. JSQuAD) into instruction-tuning records.

    The file must have a top-level ``data`` list of titles. Each title holds
    ``paragraphs``, and each paragraph has a ``context`` and a ``qas`` list of
    questions (``question``, ``id``, ``answers``).

    Args:
        path: Path to the JSON file.

    Returns:
        list[dict]: One record per question, in file order, with keys
        ``title_index``, ``paragraph_index`` and ``question_index`` (0-based
        positions within the file/title/paragraph), ``input`` (paragraph context),
        ``output`` (text of the first answer), ``instruction`` (the question) and
        ``question_id``.
    """
    # JSQuAD text is Japanese: decode explicitly instead of relying on the locale.
    with open(path, encoding='utf-8') as f:
        raw_data = json.load(f)

    qas = []
    # Iterate the loaded document (the old code read an undefined name `a` here).
    for title_index, title in enumerate(raw_data['data']):
        for paragraph_index, paragraph in enumerate(title['paragraphs']):
            context = paragraph['context']
            for question_index, question in enumerate(paragraph['qas']):
                qas.append({
                    'title_index': title_index,
                    'paragraph_index': paragraph_index,
                    'question_index': question_index,
                    'input': context,
                    # Only the first reference answer is used as the target.
                    'output': question['answers'][0]['text'],
                    'instruction': question['question'],
                    'question_id': question['id'],
                })

    return qas

# Flatten both downloaded JSQuAD splits (train first, then valid) into records.
train_data, valid_data = (
    transform(os.path.join(data_dir, raw_name))
    for raw_name in (train_raw_file_name, valid_raw_file_name)
)

In [None]:
len(train_data)  # number of QA records extracted from the training split

In [None]:
split_num = 5

# Write cumulative training splits. train_data_tmp is never reset, so p{i}.json
# holds every record whose question_index is <= i, grouped by question_index.
# NOTE(review): confirm the cumulative (rather than disjoint) split is intended.
train_data_tmp = []
for i in range(split_num):
    train_data_tmp.extend(record for record in train_data if record['question_index'] == i)
    split_path = os.path.join(data_dir, f'p{i}.json')
    with open(split_path, 'wt') as f:
        json.dump(train_data_tmp, f)

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.huggingface import HuggingFace

# SageMaker execution context: the IAM role the jobs run under, the AWS region,
# and an SDK session (its upload_data() is used below to push files to S3).
role = get_execution_role()
# NOTE(review): `region` is not read anywhere else in this notebook.
region = boto3.Session().region_name
sess = sagemaker.Session()
# NOTE(review): `bucket` is only referenced by a commented-out checkpoint_s3_uri below.
bucket = sess.default_bucket()

sagemaker.__version__  # SDK version, displayed for reproducibility

In [None]:
# Upload every split written above to S3 under the "jsquad" prefix; element i is
# the S3 URI of p{i}.json. Paths are built from data_dir (the old code hardcoded
# "./data/", which silently diverged if data_dir was changed).
input_s3_list = [
    sess.upload_data(
        path=os.path.join(data_dir, f'p{i}.json'),
        key_prefix="jsquad"
    )
    for i in range(split_num)
]

In [None]:
input_s3_list[3]  # spot-check one uploaded split URI (the one for p3.json)

In [None]:
# W&B credentials come from the environment (or an interactive prompt) instead of
# being hardcoded in the notebook. The key that used to be pasted here is exposed
# in the notebook's history and should be revoked.
wandb_api_key = os.environ.get('WANDB_API_KEY') or getpass.getpass('WANDB_API_KEY: ')

# One LoRA fine-tuning job per split: job i reads p{i}.json from its 'train' channel.
for i, s3_uri in enumerate(input_s3_list):
    base_job_name = f"jsquad-eval-{i}"
    hyperparameters = {
        'base_model': 'rinna/japanese-gpt-neox-3.6b',
        'load_in_8bit': True,
        # 'load_in_4bit': True,
        'data_path': f'/opt/ml/input/data/train/p{i}.json',
        'num_epochs': 3,  # default 3
        'cutoff_len': 512,
        'group_by_length': False,
        'output_dir': '/opt/ml/model',
        # 'resume_from_checkpoint': '/opt/ml/checkpoints',
        'lora_target_modules': '[query_key_value]',
        'lora_r': 16,
        'batch_size': 8,
        'micro_batch_size': 8,
        'prompt_template_name': 'alpaca',
        ## wandb setting
        'wandb_project': 'jsquad-eval',
        'wandb_run_name': base_job_name,
        'wandb_watch': "gradients",  # options: false | gradients | all
        'wandb_log_model': "false",  # options: false | true
    }
    environment = {
        'WANDB_API_KEY': wandb_api_key
    }
    huggingface_estimator = HuggingFace(
        base_job_name=base_job_name,
        role=role,
        entry_point='finetune.py',
        source_dir='./scripts/code',
        instance_type='ml.g5.2xlarge',
        instance_count=1,
        volume_size=200,
        transformers_version='4.26',
        pytorch_version='1.13',
        py_version='py39',
        # use_spot_instances=True,
        # max_wait=86400,
        hyperparameters=hyperparameters,
        environment=environment,
        # Raw strings avoid invalid "\d" escape warnings; "\d+" before the dot
        # also captures losses >= 10, which the old "\d\." pattern dropped.
        metric_definitions=[{'Name': 'eval_loss', 'Regex': r"'eval_loss': (\d+\.\d+)"},
                            {'Name': 'train_loss', 'Regex': r"'loss': (\d+\.\d+)"}],
        # checkpoint_s3_uri=f"s3://{bucket}/{base_job_name}/checkpoint/",
    )
    huggingface_estimator.fit({'train': s3_uri}, wait=False)

---
## 出来上がったモデルを評価

In [None]:
import boto3
# Low-level SageMaker API client used below to look up the fine-tuning jobs.
# NOTE(review): this rebinds the name `sagemaker`, shadowing the SDK module imported
# earlier. SDK-module cells (e.g. `sagemaker.Session()`, `sagemaker.__version__`) fail
# after this cell until that import cell is re-run, which in turn breaks this client.
sagemaker = boto3.client('sagemaker')

In [None]:
# Collect the most recent completed fine-tuning jobs, one per split.
# NOTE(review): the evaluation jobs launched further down reuse the "jsquad-eval-"
# prefix, so re-running this cell after they complete would pick those up instead.
response = sagemaker.list_training_jobs(
    NameContains='jsquad-eval-',
    StatusEquals='Completed',  # only completed jobs have model artifacts to fetch
    SortBy='CreationTime',
    SortOrder='Descending',
    MaxResults=split_num,
)
# Order by name so training_jobs[i] is "jsquad-eval-{i}-..." (the job trained on
# p{i}.json); lexicographic order matches split order while split_num <= 10.
# The name is stored as a plain string because describe_training_job() expects a
# str (the old code wrapped it in a one-element list).
training_jobs = [
    {'TrainingJobName': summary['TrainingJobName']}
    for summary in sorted(response['TrainingJobSummaries'],
                          key=lambda summary: summary['TrainingJobName'])
]
print(training_jobs)

In [None]:
# training 失敗発生時の書き換え
training_job_names = [
    'jsquad-eval-0-2023-11-15-05-29-12-947',
    'jsquad-eval-1-2023-11-15-05-29-13-904',
    'jsquad-eval-2-2023-11-15-05-29-16-366',
    'jsquad-eval-3-2023-11-16-11-35-25-879',
    'jsquad-eval-4-2023-11-15-05-29-18-049',
]

for i, training_job_name in enumerate(training_job_names):
    training_jobs[i]['TrainingJobName'] = training_job_name

In [None]:
# Attach each job's S3 model-artifact URI; the dicts inside training_jobs are
# updated in place.
for job in training_jobs:
    description = sagemaker.describe_training_job(TrainingJobName=job['TrainingJobName'])
    job['S3ModelArtifacts'] = description['ModelArtifacts']['S3ModelArtifacts']

In [None]:
training_jobs  # job names with their resolved S3 model-artifact URIs

In [None]:
peft_path_list = []
for i, training_job in enumerate(training_jobs):
    model_uri = training_job['S3ModelArtifacts']
    peft_path = f'{str(i).zfill(5)}'
    peft_path_list.append(peft_path)
    !mkdir -p {peft_path}
    !aws s3 cp {model_uri} ./{peft_path}
    !tar zxvf ./{peft_path}/model.tar.gz -C ./{peft_path}/
    checkpoints=[]
    for obj in os.listdir(f'./{peft_path}/'):
        if 'checkpoint' in obj:
            checkpoints.append(obj)
        else:
            print(f'del {obj}')
            !rm -rf {obj}
    max_point = 0
    for checkpoint in checkpoints:
        max_point = int(checkpoint.split('-')[-1]) if max_point < int(checkpoint.split('-')[-1]) else max_point
    for checkpoint in checkpoints:
        if max_point == int(checkpoint.split('-')[-1]):
            print(max_point)
            os.rename(f'./{peft_path}/{checkpoint}',f'./{peft_path}/peft')
        else:
            !rm -rf ./{peft_path}/{checkpoint}

In [None]:
# Upload each local adapter directory (containing peft/) to
# s3://<default bucket>/jsquad/peft/<dir>. The key prefix must not start with "/":
# upload_data joins it into the object key, so a leading slash produces keys that
# begin with "/" (an empty top-level "folder") and URIs like s3://bucket//jsquad/...
peft_s3_list = [
    sess.upload_data(
        path=peft_path,
        key_prefix=f'jsquad/peft/{peft_path}'
    )
    for peft_path in peft_path_list
]


In [None]:
pip install sagemaker-ssh-helper

In [None]:
# Launch one evaluation job per uploaded PEFT adapter directory. The adapter is
# passed as the 'train' channel; the entry point is entrypoint.py from the
# ./lm-evaluation-harness/ checkout.
# NOTE(review): these jobs reuse the "jsquad-eval-{i}" base name of the fine-tuning
# jobs, so name-based lookups of training jobs cannot tell the two apart.
# Optional interactive debugging via sagemaker-ssh-helper: import SSHEstimatorWrapper
# from sagemaker_ssh_helper.wrapper, pass
# dependencies=[SSHEstimatorWrapper.dependency_dir()] to the estimator, call
# SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=3600) before
# fit(), and ssh_wrapper.get_instance_ids(timeout_in_sec=900) after it.
for split_idx, adapter_s3_uri in enumerate(peft_s3_list):
    eval_estimator = HuggingFace(
        base_job_name=f"jsquad-eval-{split_idx}",
        role=role,
        entry_point='entrypoint.py',
        source_dir='./lm-evaluation-harness/',
        instance_type='ml.g5.2xlarge',
        instance_count=1,
        volume_size=200,
        transformers_version='4.28.1',
        pytorch_version='2.0.0',
        py_version='py310',
    )
    eval_estimator.fit({'train': adapter_s3_uri}, wait=False)

Reference commands for running the JSQuAD evaluation by hand (presumably from the `lm-evaluation-harness/` checkout):

```bash
pip install -e ".[ja]"
pip install bitsandbytes
pip install accelerate

python main.py --model hf-causal-experimental --model_args pretrained=rinna/japanese-gpt-neox-3.6b,load_in_8bit=True,device_map_option=auto,dtype=float16,peft=../model_0/checkpoint-5000 --tasks 'jsquad-1.1-0.2' --num_fewshot '1' --device 'cuda' --output_path 'result_0.json'
```