# Import libraries

In [None]:
import glob
import os

from google.cloud import aiplatform

# Set variables

In [None]:
# GCP project, region and identity used for every Vertex AI call below.
project_id = "smle-attribution-d237"
region = "europe-west4"
service_account = "awesomeserviceaccount@smle-attribution-d237.iam.gserviceaccount.com"

# NOTE: despite the name, this is a gs:// URI that includes a path prefix, not
# a bare bucket name. It is the SDK staging location, the wheel upload target,
# and the root of the job's output directory.
bucket_name = "gs://attribute-models-bucket/fit-model"

# Prebuilt Vertex AI PyTorch 1.11 GPU training container.
image_uri = "europe-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-11:latest"

# TensorBoard: `tensorboard_name` is only the display name for a newly created
# instance; `tb_resource_name` points at an existing one.
# NOTE(review): 737104763822 is presumably the project number of `project_id`
# and is hard-coded separately from it; confirm and keep the two in sync.
tensorboard_name = "karan_tb"
tb_resource_name = (
    f"projects/737104763822/locations/{region}"
    "/tensorboards/5001001095590313984"
)

# Job naming: the trainer entry-point module and the per-job output prefix.
display_name = "fit_custom"
module_name = "trainer.train"
gcs_output_uri_prefix = "/".join([bucket_name, display_name])

In [None]:
# Expose the bucket URI to shell cells (e.g. `%%sh`) as $bucket_name.
os.environ.update(bucket_name=bucket_name)

# Set up Vertex AI

In [None]:
# Point the Vertex AI SDK at the configured region/project; `bucket_name` is
# used as the SDK's default staging location.
aiplatform.init(
    location=region,
    project=project_id,
    staging_bucket=bucket_name,
)

# TensorBoard instance

Option 1: create a new instance (uncomment the cell below and run it once)

In [None]:
# Uncomment to create a brand-new TensorBoard instance named `tensorboard_name`
# in `region`. After creating it once, copy `tensorboard.resource_name` into
# `tb_resource_name` and use the cell below (which loads `tb_resource_name`)
# on subsequent runs instead of creating another instance.
# tensorboard = aiplatform.Tensorboard.create(
#     display_name=tensorboard_name,
#     location=region,
#     project=project_id,
# )

Option 2: use an existing instance, identified by `tb_resource_name`

In [None]:
# Attach to the pre-existing TensorBoard instance by its full resource name;
# the training job's `run()` call later logs to it.
tensorboard = aiplatform.Tensorboard(tensorboard_name=tb_resource_name)

# Build and upload the package

In [None]:
# %%sh
# cd ../
# make build
# # Trailing "/" makes gsutil treat $bucket_name as a folder. Without it, a
# # single wheel copied to a prefix with no existing objects becomes an object
# # literally named "fit-model" instead of "fit-model/<wheel>", so the
# # python_package_gcs_uri built later would point at nothing.
# gsutil cp ./dist/*.whl $bucket_name/

In [None]:
# Resolve the file name of the wheel produced by `make build`.
# - Fail with a clear message (instead of a bare IndexError) when nothing has
#   been built yet.
# - glob() returns paths in arbitrary order, so when several wheels sit in
#   dist/ pick the most recently modified one rather than a possibly stale one.
wheel_paths = glob.glob("../dist/*.whl")
if not wheel_paths:
    raise FileNotFoundError(
        "No wheel found in ../dist/ - run `make build` before this cell."
    )
package_name = os.path.basename(max(wheel_paths, key=os.path.getmtime))
package_name

# Custom training job

In [None]:
# Worker pool: a single n1-standard-4 replica with one NVIDIA T4 GPU attached.
replica_count = 1
machine_type = "n1-standard-4"
accelerator_type = "NVIDIA_TESLA_T4"
accelerator_count = 1

# Command-line arguments forwarded to the training module, as alternating
# flag/value strings.
args = "--batch_size 64 --num_epochs 2".split()

In [None]:
# Define the training job: the uploaded wheel is installed in the prebuilt
# `image_uri` container and `module_name` is run as its entry point.
custom_training_job = aiplatform.CustomPythonPackageTrainingJob(
    container_uri=image_uri,
    python_package_gcs_uri="/".join([bucket_name, package_name]),
    python_module_name=module_name,
    display_name=display_name,
)

In [None]:
# Launch the job (synchronous by default, so the cell waits for completion).
# Vertex AI requires `service_account` whenever `tensorboard` is supplied.
custom_training_job.run(
    # Training code arguments and output location.
    args=args,
    base_output_dir=gcs_output_uri_prefix,
    # Compute.
    machine_type=machine_type,
    replica_count=replica_count,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    # Experiment tracking and identity.
    tensorboard=tensorboard.resource_name,
    service_account=service_account,
)