In [1]:
import json
from typing import Dict, Any, List

import pytest

"""
    Learning tests/functions for Prefect:
        - can be used to test the connection to the Prefect server
        - can be used to learn how the Prefect client and its GraphQL API work
"""


from omigami.authentication.prefect_factory import prefect_client_factory


def get_prefect_client():
    """
    Return a Prefect client obtained from the project's client factory.

    If the client reports no active tenant, a tenant named "default" is
    created before the client is returned. Shared by the other helpers in
    this notebook.
    """
    client = prefect_client_factory.get()

    if client.active_tenant_id:
        return client

    client.create_tenant("default")
    return client



def test_connection():
    """
    Check that a Prefect client can be obtained via ``get_prefect_client``.

    This only asserts that the returned client object is truthy; any
    server contact happens inside ``get_prefect_client``.
    """
    assert get_prefect_client()


def get_flows():
    """
    Fetch the latest version of every flow in the first project listed by
    the Prefect server.

    Under pytest's default naming rules this function is not collected (its
    name lacks the ``test_`` prefix); call it directly from the notebook.

    Returns:
        Mapping of flow name to flow record, as produced by
        ``get_project_flows`` for the first listed project.

    Raises:
        AssertionError: if the server lists no projects, or the first
            project has no flows.
    """
    prefect_client = get_prefect_client()
    projects = get_projects(prefect_client)
    # Fail with a readable message instead of an IndexError on projects[0].
    assert projects, "Prefect server returned no projects."

    first_project = projects[0]
    flows = get_project_flows(prefect_client, first_project["id"])
    assert flows, f"Project {first_project['name']} has no flows."

    return flows


def get_projects(prefect_client) -> List[Dict[str, Any]]:
    """
    Query the Prefect GraphQL endpoint for the list of projects.

    Args:
        prefect_client: client exposing a ``graphql`` method (e.g. the one
            returned by ``get_prefect_client``).

    Returns:
        One record per project, each carrying the ``name`` and ``id`` fields
        requested in the query.
    """
    query = prefect_client.graphql({"query": {"project": {"name", "id"}}})

    return list(query["data"]["project"])


def get_project_id(prefect_client, project_name: str) -> str:
    """
    Look up a project's id by its name through the GraphQL endpoint.

    Args:
        prefect_client: client exposing a ``graphql`` method.
        project_name: exact name of the project to look up.

    Returns:
        The ``id`` of the first project whose name matches.

    Raises:
        ValueError: if no project with that name is returned.
    """
    # json.dumps produces a double-quoted, escaped literal that is also a
    # valid GraphQL string, so quotes/backslashes in the name cannot break
    # out of the filter expression.
    name_literal = json.dumps(str(project_name))
    query = prefect_client.graphql(
        {"query": {"project(where: {name: {_eq: %s}})" % name_literal: {"name", "id"}}}
    )

    try:
        project_id = query["data"]["project"][0]["id"]
    except (KeyError, IndexError) as err:
        raise ValueError(
            f"Project {project_name} not found in Prefect Server."
        ) from err

    return project_id


def get_project_flows(prefect_client, project_id) -> Dict[str, Any]:
    """
    List the newest version of each flow inside a project.

    Args:
        prefect_client: client exposing a ``graphql`` method.
        project_id: id of the project whose flows are listed.

    Returns:
        Mapping of flow name to the flow record (``name``, ``version``,
        ``id``) with the highest ``version`` for that name. When several
        records share the highest version, the first one returned wins.
    """
    # Escape the id as a GraphQL string literal (JSON escaping is compatible).
    id_literal = json.dumps(str(project_id))
    query = prefect_client.graphql(
        {
            "query": {
                "flow(where: {project_id: {_eq: %s}})" % id_literal: {
                    "name",
                    "version",
                    "id",
                }
            }
        }
    )

    latest: Dict[str, Any] = {}
    for flow in query["data"]["flow"]:
        name = flow["name"]
        # Strict '<' keeps the first-seen record on version ties.
        if name not in latest or latest[name]["version"] < flow["version"]:
            latest[name] = flow

    return latest


In [5]:

"""
Example usage
"""
# Connect once, then list every project registered on the Prefect server.
client = get_prefect_client()
projects = get_projects(prefect_client=client)
print(projects)

[{
    "id": "9bf13a7e-33cc-4d1c-80e5-9f54a6cc13df",
    "name": "spec2vec"
}]
