In [1]:
# !pip3 install ray ray[client] --user
# !pip3 install datasets --user

In [None]:
from transformers import (
    AutoConfig, 
    AutoModelForCausalLM, 
    AutoTokenizer, 
    default_data_collator,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset, load_from_disk

import ray
from ray.train.huggingface import TransformersTrainer
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from transformers.trainer_utils import get_last_checkpoint
import os
import wandb

# W&B public-API client. The config cell reads `wandb_api.api_key` from it as the
# fallback API key when WANDB_API_KEY is not provided through the environment.
# NOTE(review): constructing the client may require (or prompt for) a W&B login
# when no key is configured locally — confirm for the installed wandb version.
wandb_api = wandb.Api()

In [None]:
# W&B settings that are forwarded to the Ray workers via the runtime environment.
WANDB_PROJECT = os.environ.get('WANDB_PROJECT', 'run-ray')
# Prefer the environment variable; fall back to the key stored by `wandb login`.
# `or` makes the fallback lazy and also covers an empty WANDB_API_KEY.
WANDB_API_KEY = os.environ.get('WANDB_API_KEY') or wandb_api.api_key
# Ray's runtime_env["env_vars"] only accepts string values, so a missing key would
# otherwise fail later with an unclear error; fail fast with an actionable one.
if not WANDB_API_KEY:
    raise RuntimeError(
        "No W&B API key found: set WANDB_API_KEY or run `wandb login` first."
    )

In [None]:
# Ray cluster to connect to; override with the RAY_ADDRESS environment variable.
RAY_ADDRESS = os.environ.get('RAY_ADDRESS', 'ray://jupyter:10001')

# Extra packages and W&B credentials the remote training workers need.
# NOTE(review): the pip packages are unpinned — consider pinning them to the
# versions installed on the driver to keep runs reproducible.
runtime_env = {
    "pip": ['wandb', 's3fs'],
    "env_vars": {"WANDB_PROJECT": WANDB_PROJECT,
                 'WANDB_API_KEY': WANDB_API_KEY}
}

# Re-running this cell must not try to open a second Ray Client connection.
if not ray.is_initialized():
    ray.init(RAY_ADDRESS, runtime_env=runtime_env)

In [None]:
def train_func(config):
    """Per-worker training loop run by Ray's ``TorchTrainer``.

    Loads a dataset previously written with ``save_to_disk`` from an
    S3-compatible bucket (anonymous access), fine-tunes a causal LM with the
    Hugging Face ``Trainer``, and resumes from the newest checkpoint in
    ``output_dir`` when one exists.

    Args:
        config: Ray ``train_loop_config`` dict (``None``/empty is accepted).
            Recognised keys and the defaults used when a key is absent:

            * ``model_name``    -- model id for ``from_pretrained`` (``"gpt2"``)
            * ``dataset_path``  -- dataset location (``"s3://train/wiki-test"``)
            * ``s3_endpoint``   -- S3 endpoint URL (``"http://minio:9000"``)
            * ``output_dir``    -- checkpoint dir (``"/home/ubuntu/<model_name>-test"``)
            * ``learning_rate`` -- ``2e-5``
            * ``max_steps``     -- ``10000``
            * ``save_steps``    -- ``100``
            * ``batch_size``    -- per-device train batch size, ``6``
    """
    import s3fs

    config = config or {}
    model_name = config.get("model_name", "gpt2")
    dataset_path = config.get("dataset_path", "s3://train/wiki-test")
    s3_endpoint = config.get("s3_endpoint", "http://minio:9000")
    output_dir = config.get("output_dir", f"/home/ubuntu/{model_name}-test")

    # The filesystem's storage_options are handed to `datasets`, which builds
    # its own fs from them to read the saved dataset from the bucket.
    fs = s3fs.S3FileSystem(endpoint_url=s3_endpoint, anon=True)
    # NOTE(review): default_data_collator does no tokenization, so the saved
    # dataset presumably already holds tensor-ready columns (e.g. input_ids,
    # labels) — confirm against the preprocessing job.
    train_dataset = load_from_disk(
        dataset_path,
        storage_options=fs.storage_options,
        keep_in_memory=False,
    )
    model = AutoModelForCausalLM.from_pretrained(model_name)

    args = TrainingArguments(
        output_dir=output_dir,
        save_strategy="steps",
        logging_strategy="steps",
        learning_rate=config.get("learning_rate", 2e-5),
        weight_decay=0.01,
        max_steps=config.get("max_steps", 10000),
        save_steps=config.get("save_steps", 100),
        save_total_limit=2,  # only the two newest checkpoints are kept on disk
        logging_steps=1,
        per_device_train_batch_size=config.get("batch_size", 6),
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=default_data_collator,
    )

    # get_last_checkpoint() lists output_dir and raises if the directory does not
    # exist (e.g. on a brand-new run), so only search it when it is present.
    checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    if checkpoint:
        print(f"Resuming training from checkpoint: {checkpoint}")
    else:
        print("No checkpoint found; training from scratch.")
    trainer.train(resume_from_checkpoint=checkpoint)

In [None]:
# Run `train_func` on the Ray cluster with a single GPU-backed worker.
scaling_config = ScalingConfig(use_gpu=True, num_workers=1)
ray_trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
)
# fit() blocks until the run ends and returns Ray Train's `Result` for it.
result = ray_trainer.fit()