In [None]:
from kfp import dsl, compiler

@dsl.component(packages_to_install=["gcsfs", "transformers", "datasets==2.16", "evaluate==0.4.3", "accelerate", "scikit-learn", "kubeflow-training"])
def start_distributed_training(bucket: str, dataset_file: str, output_model_name: str, gproject: str) -> str:
    """Submit a distributed BERT fine-tuning PyTorchJob and wait until it is running.

    Args:
        bucket: GCS bucket name (no gs:// prefix) that holds the dataset and
            receives the exported model.
        dataset_file: Path of the JSON dataset file inside the bucket.
        output_model_name: Name used both as the Trainer output directory on
            the worker and as the destination folder inside the bucket.
        gproject: Google Cloud project passed to gcsfs as storage option.

    Returns:
        The literal string "job is running" once the job reports the
        Running condition.
    """
    # Only the client is needed in the launcher; every ML import lives inside
    # train_func, which is what actually runs on the workers.
    from kubeflow.training import TrainingClient

    def train_func(parameters):
        """Fine-tune BERT for multi-label classification on one PyTorch worker.

        Reads the keys MODEL_NAME, STORAGE_OPTIONS, BUCKET and DATASET_FILE
        from parameters.
        """
        # NOTE(review): this function is handed to TrainingClient.create_job as
        # train_func, so it presumably has to stay self-contained (imports and
        # helpers defined inside) - confirm against the SDK before extracting
        # anything to module level.
        import os
        import gcsfs
        import numpy as np
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node
        from transformers import (
            AutoModelForSequenceClassification,
            AutoTokenizer,
            Trainer,
            TrainingArguments,
        )
        import evaluate

        model_name = parameters["MODEL_NAME"]
        storage_options = parameters["STORAGE_OPTIONS"]
        bucket_name = parameters["BUCKET"]

        # Model and tokenizer must come from the same checkpoint: the cased and
        # uncased BERT vocabularies differ, so mixing them feeds the model
        # token ids that point at the wrong embeddings.
        base_checkpoint = "bert-base-uncased"

        # [1] Load the dataset from GCS and hold out 20% for evaluation.
        dataset = load_dataset(
            "json",
            data_files=f'gs://{bucket_name}/{parameters["DATASET_FILE"]}',
            storage_options=storage_options,
        )
        ds = dataset["train"].train_test_split(test_size=0.2)

        # Every column other than the two text fields is treated as one label.
        text_columns = ("body", "title")
        labels = [name for name in ds["train"].features if name not in text_columns]
        id2label = dict(enumerate(labels))
        label2id = {label: idx for idx, label in enumerate(labels)}

        print("-" * 40)
        print("Download BERT Model")
        model = AutoModelForSequenceClassification.from_pretrained(
            base_checkpoint,
            problem_type="multi_label_classification",
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        )
        tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)

        # [2] Preprocess dataset.
        def preprocess_data(example):
            """Tokenize title + body and build the multi-hot label vector."""
            text = f'{example["title"]}\n{example["body"]}'
            encoding = tokenizer(text, padding=True, truncation=True)

            # Position i of the vector corresponds to labels[i], which is the
            # same index label2id assigns. The comparison with True is kept on
            # purpose so that only true/1 values count as a positive label.
            encoding["labels"] = [
                1.0 if (label in example and example[label] == True) else 0.0
                for label in labels
            ]
            return encoding

        # Tokenize both splits and drop the raw columns.
        print("-" * 40)
        print("Map dataset to BERT Tokenizer")
        encoded_dataset = ds.map(preprocess_data, remove_columns=ds["train"].column_names)
        encoded_dataset.set_format("torch")

        # Distribute train and test datasets between PyTorch workers.
        # Every worker will process chunk of training data.
        # RANK and WORLD_SIZE will be set by Kubeflow Training Operator.
        RANK = int(os.environ["RANK"])
        WORLD_SIZE = int(os.environ["WORLD_SIZE"])
        distributed_ds_train = split_dataset_by_node(
            encoded_dataset["train"],
            rank=RANK,
            world_size=WORLD_SIZE,
        )
        distributed_ds_test = split_dataset_by_node(
            encoded_dataset["test"],
            rank=RANK,
            world_size=WORLD_SIZE,
        )

        # [3] Metrics: each (example, label) pair is scored as one binary decision.
        clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        def compute_metrics(eval_pred):
            """Threshold the logits at probability 0.5 and score them flattened."""
            logits, references = eval_pred
            probabilities = sigmoid(logits)
            predictions = (probabilities > 0.5).astype(int).reshape(-1)
            return clf_metrics.compute(
                predictions=predictions,
                references=references.astype(int).reshape(-1),
            )

        batch_size = 3
        metric_name = "f1"
        # NOTE(review): transformers is not pinned; newer releases appear to
        # rename evaluation_strategy to eval_strategy - confirm before upgrading.
        args = TrainingArguments(
            output_dir=f"{model_name}",
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=5,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model=metric_name,
        )

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=distributed_ds_train,
            eval_dataset=distributed_ds_test,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
        )

        print("-" * 40)
        print(f"Start Distributed Training. RANK: {RANK} WORLD_SIZE: {WORLD_SIZE}")

        trainer.train()

        print("-" * 40)
        print("Training is complete")

        # Export trained model to GCS from the worker with RANK = 0 (master).
        if RANK == 0:
            trainer.save_model(f"./{model_name}")
            fs = gcsfs.GCSFileSystem(**storage_options)
            # Explicit file list: the output directory also receives the
            # per-epoch checkpoints, which are not meant to be uploaded.
            exported_files = [
                "config.json",
                "model.safetensors",
                "special_tokens_map.json",
                "tokenizer_config.json",
                "tokenizer.json",
                "training_args.bin",
                "vocab.txt",
            ]
            for filename in exported_files:
                fs.put(f"{model_name}/{filename}", f"{bucket_name}/{model_name}/{filename}")

        print("-" * 40)
        print("Model export complete")

    job_name = "training-pipeline-job"
    # One client is reused for both creating the job and waiting on it.
    client = TrainingClient()

    # Create PyTorchJob
    client.create_job(
        name=job_name,
        train_func=train_func,
        parameters={
            "BUCKET": bucket,
            "STORAGE_OPTIONS": {"project": gproject, "token": "google_default"},
            "MODEL_NAME": output_model_name,
            "DATASET_FILE": dataset_file,
        },
        num_workers=2,  # Number of PyTorch workers to use.
        resources_per_worker={
            "cpu": "3",
            "memory": "10G",
            "gpu": "1",
        },
        # PIP packages will be installed during PyTorchJob runtime.
        # evaluate is pinned to the same version as in the component decorator.
        packages_to_install=[
            "gcsfs",
            "transformers",
            "datasets==2.16",
            "evaluate==0.4.3",
            "accelerate",
            "scikit-learn",
            "kubeflow-training",
        ],
    )
    # Wait until PyTorchJob has Running condition.
    client.wait_for_job_conditions(
        job_name,
        expected_conditions={"Running"},
    )
    return "job is running"

@dsl.pipeline
def training_pipeline(bucket: str, dataset_file: str, output_model_name: str, gproject: str) -> str:
    # Single-step pipeline: forward every pipeline argument unchanged to the
    # start_distributed_training component and expose that step's output as
    # the pipeline output.
    launch_step = start_distributed_training(
        bucket=bucket,
        dataset_file=dataset_file,
        output_model_name=output_model_name,
        gproject=gproject,
    )
    return launch_step.output

# Compile the pipeline definition into a YAML package next to this file.
compiler.Compiler().compile(
    pipeline_func=training_pipeline,
    package_path='training_pipeline.yaml',
)