In [1]:
from kfp.v2.dsl import component, pipeline
from kfp.v2.dsl import Dataset, Output, Input, Metrics, Markdown, Artifact
from kfp.v2 import compiler

In [2]:
# Component 1: Load images from given directory and parse them into JSON
@component(packages_to_install=["google-cloud-storage"])
def parse_images_from_bucket(
    bucket_name: str,
    image_dir: str,
    output_json_file: Output[Artifact]
):
    """List the images under a GCS prefix and serialize them to a JSON artifact.

    The artifact has the shape
    ``{"instances": [{"image": <base64 str>, "file_path": "gs://<bucket>/<object>"}, ...]}``.

    Args:
        bucket_name: Name of the GCS bucket to read from.
        image_dir: Object-name prefix to list. Only objects ending in
            ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are included.
        output_json_file: KFP output artifact the JSON payload is written to.
    """
    # Lightweight KFP components are executed in isolation, so every import
    # must live inside the function body. google-cloud-storage is not part of
    # the default base image, hence packages_to_install on the decorator.
    from google.cloud import storage
    import base64
    import json

    image_extensions = ('.png', '.jpg', '.jpeg')

    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    image_blobs = [
        blob
        for blob in bucket.list_blobs(prefix=image_dir)
        if blob.name.lower().endswith(image_extensions)
    ]

    instances = []
    for image_blob in image_blobs:
        # Base64 makes the raw image bytes JSON-serializable.
        base64_str = base64.b64encode(image_blob.download_as_bytes()).decode('utf-8')
        instances.append({
            'image': base64_str,
            'file_path': f"gs://{bucket_name}/{image_blob.name}",
        })

    if not instances:
        # Still emit a valid (empty) payload so downstream steps run, but make
        # the likely misconfiguration (wrong bucket/prefix) visible in the logs.
        print(f"WARNING: no {image_extensions} images found under gs://{bucket_name}/{image_dir}")

    with open(output_json_file.path, "w") as file:
        json.dump({"instances": instances}, file)

In [None]:
# TODO
# Component 2: Trigger Batch Prediction Job and Get Batch Prediction Result
@component(packages_to_install=["google-cloud-storage"])
def get_batch_prediction(
    json_file: Input[Artifact],
    visualization: Output[Markdown]
):
    """Run batch prediction on the parsed images and render a Markdown report.

    Only the visualization step is implemented so far. Each instance in
    ``json_file`` is rendered as a header with its GCS path, followed by the
    image inlined as a base64 data URI.

    Args:
        json_file: Artifact produced by ``parse_images_from_bucket``. It is a
            JSON object whose ``"instances"`` list holds dicts with ``"image"``
            (base64 string) and ``"file_path"`` (``gs://`` URI) keys.
        visualization: Markdown artifact the report is written to.
    """
    # Imports must live inside the component. google-cloud-storage is installed
    # via packages_to_install; presumably it is needed by the TODO steps below.
    from google.cloud import storage
    import json

    with open(json_file.path, 'r') as file:
        image_json = json.load(file)

    # TODO: Get Batch Prediction Results

    # TODO: Static Visualization
    # Please update the Markdown file after getting the prediction result
    with open(visualization.path, 'w') as f:
        for image_dict in image_json['instances']:
            file_path = image_dict['file_path']
            # Match the data-URI MIME type to the source extension instead of
            # labelling every image as PNG.
            if file_path.lower().endswith(('.jpg', '.jpeg')):
                mime_type = 'image/jpeg'
            else:
                mime_type = 'image/png'
            f.write(f"# {file_path}\n\n")
            # Trailing blank line so the next "# ..." starts on its own line
            # and renders as a header rather than trailing the image markup.
            f.write(f"![Image](data:{mime_type};base64,{image_dict['image']})\n\n")

    # TODO: Save Metadata To BigQuery
    # Schema: Path, Original Image, Masked Image, Number of Masks

In [None]:
# Pipeline Initialization
@pipeline(
    pipeline_root="gs://pipeline_artifacts",
    name="sam-pipeline",
)
def sam_pipeline(
    bucket_name: str = "wallace-playground",
    image_dir: str = "batch_1"
):
    """Parse images from a GCS prefix, then hand them to the batch-prediction step.

    Args:
        bucket_name: GCS bucket containing the input images.
        image_dir: Object-name prefix inside ``bucket_name`` to read images from.
    """
    # KFP v2 only accepts keyword arguments when instantiating components;
    # positional arguments raise a TypeError at pipeline-definition time.
    parse_image_op = parse_images_from_bucket(
        bucket_name=bucket_name,
        image_dir=image_dir,
    )

    get_batch_prediction(json_file=parse_image_op.outputs['output_json_file'])

In [None]:
# Serialize the pipeline definition into a JSON spec that Vertex AI Pipelines can run.
compiler.Compiler().compile(
    pipeline_func=sam_pipeline,
    package_path="sam_pipe.json",
)

In [None]:
# Copy the compiled spec to GCS. The PipelineJob in the next cell is built from
# the local 'sam_pipe.json', so this upload only keeps a copy in the bucket.
!gsutil cp sam_pipe.json gs://segment-anything

In [None]:
from google.cloud import aiplatform

# Vertex AI target for the pipeline run.
# NOTE(review): this value looks like a numeric project number rather than a
# project ID — confirm that is intended.
PROJECT_ID = "633534855904"
REGION = "us-central1"

# Build the run from the locally compiled spec and hand it to Vertex AI.
# NOTE(review): per the Vertex AI SDK docs, failure_policy="slow" keeps
# scheduling remaining tasks after a task fails — confirm that is desired.
job = aiplatform.PipelineJob(
    display_name="sam_test",
    template_path="sam_pipe.json",
    project=PROJECT_ID,
    location=REGION,
    enable_caching=True,
    failure_policy="slow",
)
job.submit()