In [None]:
from shutil import rmtree


In [None]:
@inputs(data_session_id=Types.String)
@inputs(sub_path=Types.String)
@inputs(stream_name=Types.String)
@inputs(stream_config_string=Types.String)
@outputs(output_zip=Types.Blob)
@batch_sub_task(version="1", memory_hint=4000)
def download_and_prepare_dataset(
    data_session_id,
    sub_path,
    stream_name,
    stream_config_string,
    output_zip
):
    """
    Download one dataset stream from S3 and archive it into a single zip blob.

    Takes a dataset triple (collection id, sub_path and stream_name), downloads
    the selected frames into a scratch folder, zips that folder and publishes
    the archive as ``output_zip`` so the training workflow has a quicker time
    downloading it all. The stream config may sub-select a subset of the
    content of the dataset.

    Args:
        data_session_id: Passed as ``collection_id`` into ``DATA_PATH_FORMAT``
            and embedded in every downloaded file name.
        sub_path: Passed as ``sub_path`` into ``DATA_PATH_FORMAT``.
        stream_name: Passed as ``stream_name`` into ``DATA_PATH_FORMAT`` and
            embedded in every downloaded file name.
        stream_config_string: JSON document; the optional keys ``start``
            (default 0), ``end`` (default 100), ``every_n`` (default 1) and
            ``every_n_offset`` (default 0) are forwarded to
            ``should_include_frame_in_subset``.
        output_zip: Task output reference that receives the zip blob.
    """
    # Function-scope import: the file only imports ``rmtree`` from shutil.
    from shutil import copyfileobj

    dataset_config_json = ujson.loads(stream_config_string)
    # NOTE(review): ``wf_params`` is not a parameter of this function; it has
    # to be supplied by the enclosing scope or by the decorators - confirm.
    tmp_folder = wf_params.working_directory.get_named_tempfile("input")
    output_zip_file_name = wf_params.working_directory.get_named_tempfile("output.zip")

    # Always start from an empty scratch folder so files left behind by an
    # earlier run cannot end up in the archive.
    output_dir_path = Path(tmp_folder)
    if output_dir_path.exists():
        rmtree(tmp_folder)
    output_dir_path.mkdir(0o777, parents=True, exist_ok=False)

    s3_client = boto3.client("s3")
    bucket = BUCKET_NAME
    prefix = DATA_PATH_FORMAT.format(
        collection_id=data_session_id, sub_path=sub_path, stream_name=stream_name
    )

    # Sub-selection settings, with defaults for keys the config leaves out.
    # NOTE(review): the units of start/end (percent vs. frame index) are
    # decided by should_include_frame_in_subset - confirm there.
    start = dataset_config_json.get("start", 0)
    end = dataset_config_json.get("end", 100)
    every_n = dataset_config_json.get("every_n", 1)
    every_n_offset = dataset_config_json.get("every_n_offset", 0)

    # List all objects within the S3 bucket path
    dataset_keys = s3_list_contents_paginated(
        s3_client=s3_client, bucket=bucket, prefix=prefix
    )
    dataset_size = len(dataset_keys)

    for s3_object in dataset_keys:
        key = s3_object["Key"]
        frame_id = should_include_frame_in_subset(
            frame_key_name=key,
            dataset_size=dataset_size,
            subset_start=start,
            subset_end=end,
            subset_every_n=every_n,
            subset_every_n_offset=every_n_offset,
        )
        # NOTE(review): a falsy result means "skip this object". If the helper
        # can return 0 as a valid frame id, that frame is skipped as well -
        # verify the helper's return contract.
        if not frame_id:
            continue

        # download, with a more specific name to tmp
        new_file_name = (
            f"{tmp_folder}/Frame_{data_session_id}_{stream_name}_{frame_id:05d}.png"
        )
        print("downloading {} to {}".format(key, new_file_name))
        s3_client.download_file(bucket, key, new_file_name)

    zipdir(tmp_folder, output_zip_file_name)
    zip_blob = Types.Blob()
    with zip_blob as fileobj:
        # Binary mode is required for a zip archive. Copy in chunks rather
        # than reading the whole archive into memory at once.
        with open(output_zip_file_name, mode="rb") as zip_file:
            copyfileobj(zip_file, fileobj)

    # Bind the populated blob to the declared task output; nothing else in
    # this function hands the archive to ``output_zip``.
    output_zip.set(zip_blob)