In [None]:
# Re-import edited modules (e.g. fathomnet_voxel51) automatically before each
# cell executes, so changes to the package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from dotenv import load_dotenv
from fathomnet_voxel51.check_gcp_auth import check_gcp_auth

# Populate os.environ from the project-level .env file, then run the project's
# GCP auth check (presumably verifies usable credentials before any bucket
# access below — see fathomnet_voxel51.check_gcp_auth).
load_dotenv(dotenv_path="../.env")
check_gcp_auth()

In [None]:
import json
import asyncio
import aiohttp
from google.cloud import storage
from tqdm.asyncio import tqdm_asyncio

# CONFIGURATION
# Name of the GCS bucket every split is uploaded into (objects land under
# "fathomnet/<split>_images/").
BUCKET_NAME = "voxel51-test"


async def upload_stream(session, url, blob_name, bucket, semaphore):
    """Download one image from ``url`` and upload it to ``bucket`` as ``blob_name``.

    The response body is read fully into memory and then uploaded in one
    request; nothing is written to local disk.

    Args:
        session: ``aiohttp.ClientSession`` used for the HTTP GET.
        url: Source URL of the image.
        blob_name: Destination object name inside ``bucket``.
        bucket: ``google.cloud.storage`` bucket handle.
        semaphore: ``asyncio.Semaphore`` bounding how many transfers run at once.

    Returns:
        ``"skipped"`` if the blob already exists, ``"uploaded"`` on success,
        ``"error_status_<code>"`` for a non-200 HTTP response, or
        ``"error_<message>"`` if any step raised. Errors are returned, not
        raised, so one bad image does not abort the surrounding ``gather``.
    """
    async with semaphore:
        blob = bucket.blob(blob_name)
        loop = asyncio.get_running_loop()
        try:
            # The GCS client is synchronous. Calling it directly would block the
            # event loop (and therefore every other transfer), so its network
            # calls run in the loop's default thread pool instead.
            if await loop.run_in_executor(None, blob.exists):
                return "skipped"

            async with session.get(url) as response:
                if response.status != 200:
                    return f"error_status_{response.status}"
                content = await response.read()
                content_type = response.headers.get("Content-Type")

            # Response is released before the (slower) upload starts.
            await loop.run_in_executor(
                None,
                lambda: blob.upload_from_string(content, content_type=content_type),
            )
            return "uploaded"
        except Exception as e:
            return f"error_{e}"


async def process_split(json_path, split_name, limit=None, concurrent=50):
    """Copy every image listed in a JSON annotation file into the GCS bucket.

    Each entry's ``coco_url`` is fetched and written to
    ``gs://<BUCKET_NAME>/fathomnet/<split_name>_images/<file_name>`` via
    ``upload_stream``; a summary of uploaded / skipped / failed counts is printed.

    Args:
        json_path: Path to a JSON file with a top-level ``"images"`` list whose
            entries carry ``"file_name"`` and ``"coco_url"`` keys.
        split_name: Split label used in the destination prefix (e.g. ``"train"``).
        limit: If not ``None``, only the first ``limit`` images are processed.
        concurrent: Maximum number of images transferred at the same time.
    """
    # 1. Setup GCS
    storage_client = storage.Client()  # Project inferred from auth
    bucket = storage_client.bucket(BUCKET_NAME)

    # Destination prefix for this split, e.g. 'fathomnet/train_images/'
    gcp_prefix = f"fathomnet/{split_name}_images/"

    # 2. Load JSON
    print(f"Loading {json_path} for split '{split_name}'...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    images = data["images"]
    # `is not None` so that limit=0 means "no images" rather than "all images".
    if limit is not None:
        images = images[:limit]
        print(f"Limiting to first {limit} images.")

    # 3. Async stream
    semaphore = asyncio.Semaphore(concurrent)
    print(
        f"Stream-uploading {len(images)} images to gs://{BUCKET_NAME}/{gcp_prefix}..."
    )

    async with aiohttp.ClientSession() as session:
        # file_name is used verbatim as the object name below the prefix; it is
        # not sanitized here.
        tasks = [
            upload_stream(
                session,
                img["coco_url"],
                f"{gcp_prefix}{img['file_name']}",
                bucket,
                semaphore,
            )
            for img in images
        ]
        results = await tqdm_asyncio.gather(*tasks)

    # 4. Report — anything that is neither "uploaded" nor "skipped" is an
    # error string from upload_stream; surface those instead of dropping them.
    errors = [r for r in results if r not in ("uploaded", "skipped")]
    print(
        f"Split '{split_name}' complete: {results.count('uploaded')} uploaded, "
        f"{results.count('skipped')} skipped, {len(errors)} failed."
    )
    for err in errors[:5]:
        print(f"  e.g. {err}")

In [None]:
# load a subset of the data to test the code
train_json = "data/dataset_train.json"
test_json = "data/dataset_test.json"
limit = 100

await process_split(train_json, "train", limit)