Skip to content
This repository has been archived by the owner on Feb 23, 2022. It is now read-only.

Commit

Permalink
Merge pull request #285 from multinet-app/remove_with_client
Browse files Browse the repository at this point in the history
Remove with_client wrapper
  • Loading branch information
jjnesbitt committed Feb 4, 2020
2 parents 8381691 + 5dd18e9 commit 235c36e
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 120 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ ENV/
.DS_Store

# Arango test materials.
arango/
arango-apps/
/arango/
/arango-apps/
server.out
server.pid

Expand Down
178 changes: 65 additions & 113 deletions multinet/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
import os

from arango import ArangoClient
from arango.database import StandardDatabase, StandardCollection
from arango.graph import Graph
from arango.database import StandardDatabase
from arango.collection import StandardCollection

from arango.exceptions import DatabaseCreateError, EdgeDefinitionCreateError
from requests.exceptions import ConnectionError

from typing import Callable, Any, Optional, Sequence, List, Set, Generator, Tuple
from typing import Any, Sequence, List, Set, Generator, Tuple
from typing_extensions import TypedDict
from multinet.types import EdgeDirection, TableType
from multinet.errors import InternalServerError

from multinet.errors import (
BadQueryArgument,
Expand All @@ -31,45 +35,31 @@
GraphNodesSpec = TypedDict("GraphNodesSpec", {"count": int, "nodes": List[str]})
GraphEdgesSpec = TypedDict("GraphEdgesSpec", {"count": int, "edges": List[str]})

arango = ArangoClient(
host=os.environ.get("ARANGO_HOST", "localhost"),
port=int(os.environ.get("ARANGO_PORT", "8529")),
)

def with_client(fun: Callable) -> Callable:
"""Call target function `fun`, passing in an authenticated ArangoClient object."""

def wrapper(*args: Any, **kwargs: Any) -> Callable:
kwargs["arango"] = kwargs.get(
"arango",
ArangoClient(
host=os.environ.get("ARANGO_HOST", "localhost"),
port=int(os.environ.get("ARANGO_PORT", "8529")),
),
)
return fun(*args, **kwargs)

return wrapper
def db(name: str) -> StandardDatabase:
"""Return a handle for Arango database `name`."""
return arango.db(
name, username="root", password=os.environ.get("ARANGO_PASSWORD", "letmein")
)


@with_client
def check_db(arango: ArangoClient) -> bool:
def check_db() -> bool:
"""Check the database to see if it's alive."""
try:
db("_system", arango=arango).has_database("test")
db("_system").has_database("test")
return True
except ConnectionError:
return False


@with_client
def db(name: str, arango: ArangoClient) -> StandardDatabase:
"""Return a handle for Arango database `name`."""
return arango.db(
name, username="root", password=os.environ.get("ARANGO_PASSWORD", "letmein")
)


@with_client
def create_workspace(name: str, arango: ArangoClient) -> None:
def create_workspace(name: str) -> None:
"""Create a new workspace named `name`."""
sysdb = db("_system", arango=arango)
sysdb = db("_system")
if not sysdb.has_database(name):
try:
sysdb.create_database(name)
Expand All @@ -79,77 +69,63 @@ def create_workspace(name: str, arango: ArangoClient) -> None:
raise AlreadyExists("Workspace", name)


@with_client
def delete_workspace(name: str, arango: ArangoClient) -> None:
def delete_workspace(name: str) -> None:
"""Delete the workspace named `name`."""
sysdb = db("_system", arango=arango)
sysdb = db("_system")
if sysdb.has_database(name):
sysdb.delete_database(name)


@with_client
def get_workspace(name: str, arango: ArangoClient) -> WorkspaceSpec:
def get_workspace(name: str) -> WorkspaceSpec:
"""Return a single workspace, if it exists."""
sysdb = db("_system", arango=arango)
sysdb = db("_system")
if not sysdb.has_database(name):
raise WorkspaceNotFound(name)

return {"name": name, "owner": "", "readers": [], "writers": []}


@with_client
def get_workspace_db(name: str, arango: ArangoClient) -> StandardDatabase:
def get_workspace_db(name: str) -> StandardDatabase:
"""Return the Arango database associated with a workspace, if it exists."""
get_workspace(name, arango=arango)
return db(name, arango=arango)
get_workspace(name)
return db(name)


@with_client
def get_graph_collection(
workspace: str, graph: str, arango: ArangoClient
) -> StandardCollection:
def get_graph_collection(workspace: str, graph: str) -> Graph:
"""Return the Arango collection associated with a graph, if it exists."""
space = get_workspace_db(workspace, arango=arango)
space = get_workspace_db(workspace)
if not space.has_graph(graph):
raise GraphNotFound(workspace, graph)

return space.graph(graph)


@with_client
def get_table_collection(
workspace: str, table: str, arango: ArangoClient
) -> StandardCollection:
def get_table_collection(workspace: str, table: str) -> StandardCollection:
"""Return the Arango collection associated with a table, if it exists."""
space = get_workspace_db(workspace, arango=arango)
space = get_workspace_db(workspace)
if not space.has_collection(table):
raise TableNotFound(workspace, table)

return space.collection(table)


@with_client
def get_workspaces(arango: ArangoClient) -> Generator[str, None, None]:
def get_workspaces() -> Generator[str, None, None]:
"""Return a list of all workspace names."""
sysdb = db("_system", arango=arango)
sysdb = db("_system")
return (workspace for workspace in sysdb.databases() if workspace != "_system")


@with_client
def workspace_tables(
workspace: str, table_type: TableType, arango: ArangoClient
workspace: str, table_type: TableType
) -> Generator[str, None, None]:
"""Return a list of all table names in the workspace named `workspace`."""

def edge_table(fields: Sequence[str]) -> bool:
return "_from" in fields and "_to" in fields

space = get_workspace_db(workspace, arango=arango)
space = get_workspace_db(workspace)
tables = (
(
table["name"],
edge_table(table_fields(workspace, table["name"], arango=arango)),
)
(table["name"], edge_table(table_fields(workspace, table["name"])))
for table in space.collections()
if not table["name"].startswith("_")
)
Expand All @@ -175,12 +151,9 @@ def is_node(x: Tuple[Any, bool]) -> bool:
return (table[0] for table in tables if desired_type(table))


@with_client
def workspace_table(
workspace: str, table: str, offset: int, limit: int, arango: ArangoClient
) -> dict:
def workspace_table(workspace: str, table: str, offset: int, limit: int) -> dict:
"""Return a specific table named `name` in workspace `workspace`."""
get_table_collection(workspace, table, arango=arango)
get_table_collection(workspace, table)

query = f"""
FOR d in {table}
Expand All @@ -200,12 +173,9 @@ def workspace_table(
return {"count": list(count)[0], "rows": list(rows)}


@with_client
def graph_node(
workspace: str, graph: str, table: str, node: str, arango: ArangoClient
) -> dict:
def graph_node(workspace: str, graph: str, table: str, node: str) -> dict:
"""Return the data associated with a particular node in a graph."""
space = get_workspace_db(workspace, arango=arango)
space = get_workspace_db(workspace)
graphs = filter(lambda g: g["name"] == graph, space.graphs())
try:
next(graphs)
Expand Down Expand Up @@ -233,41 +203,36 @@ def graph_node(
return {k: data[k] for k in data if k != "_rev"}


@with_client
def workspace_graphs(workspace: str, arango: ArangoClient) -> List[str]:
def workspace_graphs(workspace: str) -> List[str]:
"""Return a list of all graph names in workspace `workspace`."""
space = get_workspace_db(workspace, arango=arango)
space = get_workspace_db(workspace)
return [graph["name"] for graph in space.graphs()]


@with_client
def workspace_graph(workspace: str, graph: str, arango: ArangoClient) -> GraphSpec:
def workspace_graph(workspace: str, graph: str) -> GraphSpec:
"""Return a specific graph named `name` in workspace `workspace`."""
get_graph_collection(workspace, graph)

# Get the lists of node and edge tables.
node_tables = graph_node_tables(workspace, graph, arango=arango)
edge_table = graph_edge_table(workspace, graph, arango=arango)
node_tables = graph_node_tables(workspace, graph)
edge_table = graph_edge_table(workspace, graph)

return {"nodeTables": node_tables, "edgeTable": edge_table}


@with_client
def graph_nodes(
workspace: str, graph: str, offset: int, limit: int, arango: ArangoClient
) -> GraphNodesSpec:
def graph_nodes(workspace: str, graph: str, offset: int, limit: int) -> GraphNodesSpec:
"""Return the nodes of a graph."""
get_graph_collection(workspace, graph)

# Get the actual node data.
node_tables = graph_node_tables(workspace, graph, arango=arango)
node_tables = graph_node_tables(workspace, graph)
node_query = f"""
FOR c in [{", ".join(node_tables)}]
FOR d in c
LIMIT {offset}, {limit}
RETURN d
"""
nodes = aql_query(workspace, node_query, arango=arango)
nodes = aql_query(workspace, node_query)

# Get the total node count.
count_query = f"""
Expand All @@ -276,54 +241,47 @@ def graph_nodes(
COLLECT WITH COUNT INTO count
RETURN count
"""
count = aql_query(workspace, count_query, arango=arango)
count: int = next(aql_query(workspace, count_query))

return {"count": list(count)[0], "nodes": list(nodes)}
return {"count": count, "nodes": list(nodes)}


@with_client
def table_fields(workspace: str, table: str, arango: ArangoClient) -> List[str]:
def table_fields(workspace: str, table: str) -> List[str]:
"""Return a list of column names for `query.table` in `query.workspace`."""
space = db(workspace, arango=arango)
space = db(workspace)
if space.has_collection(table) and space.collection(table).count() > 0:
sample = space.collection(table).random()
return list(sample.keys())
else:
return []


@with_client
def delete_table(workspace: str, table: str, arango: ArangoClient) -> str:
def delete_table(workspace: str, table: str) -> str:
"""Delete a table."""
space = db(workspace, arango=arango)
space = db(workspace)
if space.has_collection(table):
space.delete_collection(table)

return table


@with_client
def aql_query(
workspace: str, query: str, arango: ArangoClient
) -> Generator[dict, None, None]:
def aql_query(workspace: str, query: str) -> Generator[Any, None, None]:
"""Perform an AQL query in the given workspace."""
aql = db(workspace, arango=arango).aql
aql = db(workspace).aql

cursor = aql.execute(query)
return cursor


@with_client
def create_graph(
workspace: str,
graph: str,
edge_table: str,
from_vertex_collections: Set[str],
to_vertex_collections: Set[str],
arango: ArangoClient,
) -> bool:
"""Create a graph named `graph`, defined by`node_tables` and `edge_table`."""
space = db(workspace, arango=arango)
space = db(workspace)
if space.has_graph(graph):
return False

Expand All @@ -344,37 +302,32 @@ def create_graph(
return True


@with_client
def delete_graph(workspace: str, graph: str, arango: ArangoClient) -> str:
def delete_graph(workspace: str, graph: str) -> str:
"""Delete graph `graph` from workspace `workspace`."""
space = db(workspace, arango=arango)
space = db(workspace)
if space.has_graph(graph):
space.delete_graph(graph)

return graph


@with_client
def graph_node_tables(
workspace: str, graph: str, arango: ArangoClient
) -> List[StandardCollection]:
def graph_node_tables(workspace: str, graph: str) -> List[str]:
"""Return the node tables associated with a graph."""
g = get_graph_collection(workspace, graph)
return g.vertex_collections()


@with_client
def graph_edge_table(
workspace: str, graph: str, arango: ArangoClient
) -> Optional[StandardCollection]:
def graph_edge_table(workspace: str, graph: str) -> str:
"""Return the edge tables associated with a graph."""
g = get_graph_collection(workspace, graph)
edge_collections = g.edge_definitions()

return None if not edge_collections else edge_collections[0]["edge_collection"]
if not edge_collections:
raise InternalServerError

return edge_collections[0]["edge_collection"]


@with_client
def node_edges(
workspace: str,
graph: str,
Expand All @@ -383,11 +336,10 @@ def node_edges(
offset: int,
limit: int,
direction: EdgeDirection,
arango: ArangoClient,
) -> GraphEdgesSpec:
"""Return the edges connected to a node."""
get_table_collection(workspace, table)
edge_table = graph_edge_table(workspace, graph, arango=arango)
edge_table = graph_edge_table(workspace, graph)

def query_text(filt: str) -> str:
return f"""
Expand Down

0 comments on commit 235c36e

Please sign in to comment.