In [None]:
import boto3
import os
import requests
import threading
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
from datetime import datetime, timezone


# This script performs the upload of a local datafile to a given Collection (as identified by its Collection uuid),
# where the datafile becomes a Dataset accessible via the Data Portal UI. In order to use this script, you must have
# a Curation API key, you must know the id of the Collection to which you wish to upload the datafile, and you must 
# decide upon a string tag (the `curator_tag`) to use to uniquely identify the resultant Dataset within this 
# Collection going forward. Uploads to a tag that has not been used yet in the given Collection will result in a
# new Dataset being created. Uploads to a tag for which there already exists a Dataset in the given Collection will 
# result in the existing Dataset being replaced by the new Dataset created from the datafile that you are uploading.

# You can only add/replace Datasets in private Collections or revision Collections. 


# Curators: substitute field values here ↓
# Path to a file containing your Curation API key.
api_key_file = "api_key_file.txt"
# Absolute path of the h5ad datafile to upload.
filename = "/absolute/path/to_datafile.h5ad"
# Identifier of the resulting Dataset *within this Collection*; **MUST POSSESS '.h5ad' SUFFIX**.
# A tag not yet used in the Collection creates a new Dataset; reusing a tag replaces that Dataset.
curator_tag = "arbitrary/tag/choosen-by-you.h5ad"
# Id of your (non-public) Collection, e.g. "1234abcd-5678-efab-1a2b-3c4d5e6f8a9b".
collection_id = "01234567-89ab-cdef-0123-456789abcdef"
# ↑


# -------------------------------------------------------------------------------------------------->
# Don't make any functional changes below unless you're feeling particularly motivated 🎲🎲
# -------------------------------------------------------------------------------------------------->


#####################################################################################################################
# 1) Use API key to obtain a temporary access token (for authentication & authorization with Curation API routes)
#####################################################################################################################
# Read the API key inside a context manager so the file handle is closed as soon as it has been read.
with open(api_key_file) as api_key_fp:
    api_key = api_key_fp.read().strip()
access_token_headers = {"x-api-key": api_key}
access_token_url = "https://api.cellxgene.dev.single-cell.czi.technology/curation/v1/auth/token"
resp = requests.post(access_token_url, headers=access_token_headers)
# Stop here on a rejected key or server error. Otherwise the script would carry on with
# access_token=None and only fail later, confusingly, in step 2.
resp.raise_for_status()
access_token = resp.json().get("access_token")
print("Retrieved access token")
# print(access_token)  # Uncomment to verify access token


#####################################################################################################################
# 2) Use access token to obtain temporary s3 write credentials for a given Collection; these credentials will only work for THIS Collection.
#####################################################################################################################
s3_credentials_url = f"https://api.cellxgene.dev.single-cell.czi.technology/curation/v1/collections/{collection_id}/datasets/s3-upload-credentials"
# NOTE(review): these headers are built once, so every refresh reuses the step-1 access token.
# If that token can expire before the upload finishes, refreshes will fail. Confirm its lifetime.
s3_cred_headers = {"Authorization": f"Bearer {access_token}"}

# Fixed UTC offset of the local machine, captured once at start-up.
# It is used only to render the credential expiry in local time.
time_zone_info = datetime.now(timezone.utc).astimezone().tzinfo


def retrieve_s3_credentials():
    """Fetch Collection-scoped S3 upload credentials from the Curation API.

    Serves both as the initial credential fetch and as the ``refresh_using``
    callback handed to botocore's ``RefreshableCredentials`` below.

    Returns:
        dict with the keys ``RefreshableCredentials.create_from_metadata`` expects:
        ``access_key``, ``secret_key``, ``token`` and ``expiry_time``. The last is an
        ISO-8601 string carrying a UTC offset.

    Raises:
        requests.HTTPError: if the Curation API responds with an error status.
    """
    resp = requests.post(s3_credentials_url, headers=s3_cred_headers)
    # Surface HTTP failures directly. Otherwise a missing "Credentials" entry shows up
    # as an opaque AttributeError on None a few lines below.
    resp.raise_for_status()
    s3_creds = resp.json().get("Credentials")
    # "Expiration" is an epoch timestamp. Interpret it as UTC first, then convert to the
    # start-up offset. The old approach (naive local fromtimestamp + replace(tzinfo=...))
    # shifted the instant by the DST difference whenever the expiry's local offset differed
    # from the start-up offset.
    expiry = datetime.fromtimestamp(s3_creds.get("Expiration"), tz=timezone.utc).astimezone(time_zone_info)
    s3_creds_formatted = {
        "access_key": s3_creds.get("AccessKeyId"),
        "secret_key": s3_creds.get("SecretAccessKey"),
        "token": s3_creds.get("SessionToken"),
        "expiry_time": expiry.isoformat(),
    }
    print("Retrieved/refreshed s3 credentials")
    return s3_creds_formatted

# Seed with an initial fetch. botocore calls `refresh_using` again to obtain new
# credentials as `expiry_time` approaches.
session_creds = RefreshableCredentials.create_from_metadata(
    metadata=retrieve_s3_credentials(),
    refresh_using=retrieve_s3_credentials,
    method="sts-assume-role",
)
#print(retrieve_s3_credentials())  # Uncomment to verify s3 credentials


#####################################################################################################################
# 3) Upload file using temporary s3 credentials
#####################################################################################################################
# Build a botocore session that uses the refreshable, Collection-scoped credentials from step 2.
# NOTE(review): `_credentials` is a private botocore attribute, so this wiring depends on botocore
# internals. Re-check it when upgrading botocore. It must be assigned before the client is created.
session = get_session()
session._credentials = session_creds
# Wrap the botocore session so the boto3 S3 client below is created from it.
boto3_session = boto3.Session(botocore_session=session)
s3 = boto3_session.client("s3")

# Size of the datafile in bytes; the upload progress callback divides by this to compute % uploaded.
filesize = os.path.getsize(filename)

# For logging % uploaded ↓
def get_progress_cb(total_bytes=None):
    """Build a thread-safe progress callback for ``s3.upload_file(Callback=...)``.

    The returned callable receives the number of bytes transferred since its
    previous invocation, possibly from several transfer threads. It accumulates
    those counts and prints the running percentage, rounded to one decimal place,
    whenever that percentage rises above the previously computed one. Each
    printout ends in a carriage return so the next one overwrites it in place.

    Args:
        total_bytes: Size of the upload in bytes. Defaults to the module-level
            ``filesize``, read when this function is called, so the original
            zero-argument usage keeps working.

    Returns:
        A callable ``progress_cb(num_bytes)`` that returns None.
    """
    if total_bytes is None:
        total_bytes = filesize
    lock = threading.Lock()
    uploaded_bytes = 0
    prev_percent = 0

    def progress_cb(num_bytes):
        nonlocal uploaded_bytes, prev_percent
        # `with` releases the lock even if something below raises. The previous
        # manual acquire()/release() pair would leave it held forever.
        with lock:
            uploaded_bytes += num_bytes
            # A 0-byte file has nothing to transfer: report it as complete
            # instead of raising ZeroDivisionError.
            fraction = uploaded_bytes / total_bytes if total_bytes else 1.0
            percent_of_total_upload = round(fraction * 100, 1)
            should_update_progress_printout = percent_of_total_upload > prev_percent
            # Tracked unconditionally (as before), so a negative byte count lowers it.
            prev_percent = percent_of_total_upload
            # Print while still holding the lock so concurrent callers cannot emit
            # percentages out of order.
            if should_update_progress_printout:
                print(f"{percent_of_total_upload}% uploaded\r", end="")

    return progress_cb
# ↑

try:
    # curator_tag must end in '.h5ad'. Check it before anything is written to the bucket.
    # Raising here routes the problem through the same FAILED report used for upload errors.
    if not curator_tag.endswith(".h5ad"):
        raise ValueError(f"curator_tag must end with '.h5ad' (got '{curator_tag}')")
    print(f"Uploading {filename} to Collection {collection_id} with tag '{curator_tag}'...")
    s3.upload_file(
        Filename=filename,
        Bucket="cellxgene-dataset-submissions-dev",
        # The object key is "<collection_id>/<curator_tag>".
        Key=f"{collection_id}/{curator_tag}",
        Callback=get_progress_cb(),
    )
except Exception as e:
    # Top-level boundary of the script: report any failure from validation or the upload itself.
    print("\n\n\033[1m\033[38;5;9mFAILED\033[0m")  # 'FAILED' in bold red
    print(f"\n\n{e}")
else:
    print("\n\n\033[1m\033[38;5;10mSUCCESS\033[0m")  # 'SUCCESS' in bold green
    print(f"\nFile {filename} successfully uploaded to Collection {collection_id} with tag {curator_tag}")