In [1]:
import os
import sys

# Make the parent of the notebook's working directory importable.
# Presumably this is what lets later cells resolve packages that live one
# level up (e.g. `config`, `batch_prediction`) -- confirm against repo layout.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)

# Only append when missing, so re-running this cell does not keep adding
# duplicate entries to sys.path.
if parent_directory not in sys.path:
    sys.path.append(parent_directory)

In [2]:
# %%writefile ../config/config.py

import os
from dotenv import load_dotenv

# Let python-dotenv pull variables from a .env file (if one is found) before
# the settings below are read.
load_dotenv()

_env = os.environ

# Every setting except `base_image` is read with a subscript lookup, so a
# missing variable raises KeyError as soon as this module is imported.
pipeline_root = _env["PIPELINE_ROOT"]
# Read with .get(): evaluates to None when CONTAINER_IMAGE is not set.
base_image = _env.get("CONTAINER_IMAGE")
project_id = _env["PROJECT_ID"]
region = _env["REGION"]
service_account = _env["SERVICE_ACCOUNT"]
bucket_name = _env["BUCKET_NAME"]
model_gcs_path = _env["MODEL_GCS_PATH"]
input_data_gcs_path = _env["GCS_URL"]
table_ref = _env["TABLE_REF"]

In [None]:
# %%writefile ../batch_prediction/batch_predict.py
from config.config import base_image
from kfp.v2 import dsl
from typing import Optional

@dsl.component(base_image=base_image)
def batch_predict(
    model_gcs_path: str,
    input_data_gcs_path: str,
    table_ref: str,
    project: str,
    target_column: Optional[str] = None,
    sample_size: int = 4,
):
    """
    Loads data from GCS, obtains predictions, and writes data to BigQuery

    Args:
        model_gcs_path: ``gs://<bucket>/<object>`` URI of the model file,
            which is loaded with ``joblib.load``.
        input_data_gcs_path: ``gs://<bucket>/<object>`` URI of the input CSV.
        table_ref: Destination BigQuery table; predictions are appended.
        project: GCP project used to create the BigQuery client.
        target_column: Column to drop before predicting. When omitted, the
            last column of the CSV is dropped instead.
        sample_size: Number of rows randomly drawn from the CSV for scoring.
            Defaults to 4, the value that used to be hard-coded. It is capped
            at the number of rows available; 0 or a negative value scores
            every row.

    Returns:
        A message naming the table the predictions were written to.

    Raises:
        ValueError: If a GCS URI lacks a bucket or object path, or the input
            CSV contains no rows.
    """
    # NOTE(review): there is no return annotation, so KFP presumably does not
    # expose the returned string as a component output -- confirm whether
    # `-> str` is wanted before relying on it downstream.

    import pandas as pd
    import joblib
    from google.cloud import storage, bigquery

    def _split_gcs_uri(gcs_uri):
        # Split "gs://bucket/path/to/object" into (bucket, object path).
        bucket_part, _, object_part = gcs_uri.replace("gs://", "").partition("/")
        if not bucket_part or not object_part:
            raise ValueError(
                f"Expected a GCS URI like gs://<bucket>/<object>, got {gcs_uri!r}"
            )
        return bucket_part, object_part

    def _download(client, gcs_uri, local_path):
        # Copy the object behind `gcs_uri` to `local_path` and return that path.
        bucket_part, object_part = _split_gcs_uri(gcs_uri)
        client.bucket(bucket_part).blob(object_part).download_to_filename(local_path)
        return local_path

    storage_client = storage.Client()

    model_filename = _download(storage_client, model_gcs_path, "/tmp/model.joblib")
    # NOTE(security): joblib.load unpickles the file and can therefore run
    # arbitrary code; only point model_gcs_path at trusted artifacts.
    model = joblib.load(model_filename)

    input_data_filename = _download(
        storage_client, input_data_gcs_path, "/tmp/input_data.csv"
    )
    input_data = pd.read_csv(input_data_filename)
    if len(input_data) == 0:
        raise ValueError(f"No rows found in {input_data_gcs_path}")

    # Cap the sample at the rows available: DataFrame.sample raises when asked
    # for more rows than exist (without replacement).
    if sample_size is not None and sample_size > 0:
        input_data = input_data.sample(min(sample_size, len(input_data)))

    if target_column:
        input_data = input_data.drop(columns=[target_column], errors="ignore")
    else:
        # Copy so the column assignment below acts on an independent frame
        # instead of a slice of the original.
        input_data = input_data.iloc[:, :-1].copy()

    # Object-dtype columns are converted to pandas "category" dtype before
    # being handed to the model.
    categorical_cols = input_data.select_dtypes(include=["object"]).columns
    input_data[categorical_cols] = input_data[categorical_cols].astype("category")

    predictions = model.predict(input_data)
    print("Predictions success!")
    print(f"prediction: {predictions}")

    bigquery_client = bigquery.Client(project=project)
    job_config = bigquery.LoadJobConfig(
        schema=[
            bigquery.SchemaField("prediction", "FLOAT"),
        ],
        write_disposition=bigquery.WriteDisposition.WRITE_APPEND,
    )
    job = bigquery_client.load_table_from_dataframe(
        pd.DataFrame({"prediction": predictions}),
        table_ref,
        job_config=job_config,
    )
    job.result()  # Wait for the job to complete

    return f"Predictions written to {table_ref}"

if __name__ == "__main__":
    # argparse is not imported at module scope in this file, so bring it in
    # here; without this the parser construction below raises NameError.
    import argparse

    parser = argparse.ArgumentParser(
        description="Obtain batch prediction from Vertex AI model"
    )
    parser.add_argument("--model_gcs_path", required=True, help="Model GCS path")
    parser.add_argument(
        "--input_data_gcs_path", required=True, help="Input data GCS path"
    )
    parser.add_argument("--table_ref", required=True, help="GCP output data table")
    parser.add_argument("--project", required=True, help="GCP project name")
    parser.add_argument(
        "--target_column",
        default=None,
        help="Column to drop before predicting (default: drop the last column)",
    )

    args = parser.parse_args()

    # `batch_predict` is a KFP component object, not a plain function. KFP v2
    # components only accept keyword arguments and are meant to be
    # instantiated inside a pipeline, so run the wrapped Python function when
    # the SDK exposes it; otherwise fall back to calling the component itself.
    run_batch_predict = getattr(batch_predict, "python_func", batch_predict)
    run_batch_predict(
        model_gcs_path=args.model_gcs_path,
        input_data_gcs_path=args.input_data_gcs_path,
        table_ref=args.table_ref,
        project=args.project,
        target_column=args.target_column,
    )

In [None]:
# %%writefile ../inference_pipeline.py

# Integrate the pipeline into a self-contained script for automated runs
from kfp import dsl, compiler
from kfp.dsl import pipeline
from batch_prediction.batch_predict import batch_predict
from config.config import model_gcs_path, input_data_gcs_path, table_ref, \
    project_id, region, pipeline_root, service_account

@pipeline(
    name="inference_pipeline",
    description="A pipeline that returns predictions from deployed model",
    pipeline_root=pipeline_root,
)
def inference_pipeline(
    model_gcs_path: str = model_gcs_path,
    input_data_gcs_path: str = input_data_gcs_path,
    table_ref: str = table_ref,
    project: str = project_id,
):
    """Single-step pipeline that runs the ``batch_predict`` component.

    Every parameter defaults to the matching value imported from
    ``config.config`` and is forwarded unchanged to the component.
    """
    # The task object returned by the component is not used afterwards, so it
    # is not bound to a name.
    batch_predict(
        model_gcs_path=model_gcs_path,
        input_data_gcs_path=input_data_gcs_path,
        table_ref=table_ref,
        project=project,
    )

if __name__ == "__main__":
    # Step 1: compile the pipeline function into a local JSON template.
    compiled_template = "inference_pipeline.json"
    pipeline_compiler = compiler.Compiler()
    pipeline_compiler.compile(
        pipeline_func=inference_pipeline,
        package_path=compiled_template,
    )

    # Step 2: submit the compiled template as a Vertex AI pipeline job.
    from google.cloud import aiplatform

    aiplatform.init(project=project_id, location=region)

    # Runtime values for the pipeline parameters, taken from config.config.
    run_parameters = {
        "model_gcs_path": model_gcs_path,
        "input_data_gcs_path": input_data_gcs_path,
        "table_ref": table_ref,
        "project": project_id,
    }
    pipeline_job = aiplatform.PipelineJob(
        display_name="inference-pipeline",
        template_path=compiled_template,
        parameter_values=run_parameters,
        enable_caching=True,
    )
    # The job runs under the service account named in the configuration.
    pipeline_job.submit(service_account=service_account)