In [None]:
# default_exp pod.api
%load_ext autoreload
%autoreload 2

In [None]:
# export
import os
from typing import Any, Dict, List, Generator, Deque
import requests
import urllib
from hashlib import sha256
from collections import deque
import json

In [None]:
# export
# Address of the Pod: taken from the POD_ADDRESS environment variable; an unset
# or empty value falls back to the local default.
DEFAULT_POD_ADDRESS = os.getenv("POD_ADDRESS") or "http://localhost:3030"
# API version used as a path segment when PodAPI builds its base URL.
POD_VERSION = "v4"

In [None]:
# export
class PodError(Exception):
    """Error carrying the status code and message of a failed Pod request.

    Args:
        status: Status code of the failed request, if any.
        message: Text describing the failure, if any.
        **kwargs: Optional extra context; stored unchanged on ``self.kwargs``.
    """

    def __init__(self, status=None, message=None, **kwargs) -> None:
        # Exception.__init__ rejects keyword arguments, so forwarding **kwargs
        # would raise TypeError. Only the positional pair becomes `args`.
        super().__init__(status, message)
        self.status = status
        self.message = message
        self.kwargs = kwargs

    def __str__(self) -> str:
        """Return the truthy parts of `args`, separated by single spaces."""
        return " ".join(str(arg) for arg in self.args if arg)

In [None]:
# export
class PodAPI:
    """Low-level HTTP wrapper around the Pod API.

    JSON endpoints are called through `post`, which sends a body of the form
    ``{"auth": ..., "payload": ...}`` to
    ``{url}/{version}/{owner_key}/{endpoint}`` and raises `PodError` for every
    response whose status code is not 200.
    """

    def __init__(
        self,
        database_key: str,
        owner_key: str,
        url: str = DEFAULT_POD_ADDRESS,
        version: str = POD_VERSION,
        auth_json: dict = None,
        verbose: bool = True,
    ) -> None:
        """Store the connection settings and derive base URL and auth payload.

        Args:
            database_key: Key sent as ``databaseKey`` for client
                authentication and used in the path of `upload_file`.
            owner_key: Key that forms the last segment of the base URL.
            url: Address of the Pod, without version or owner segments.
            version: API version segment of the base URL.
            auth_json: Optional auth fields; when given, plugin authentication
                is used instead of client authentication.
            verbose: Whether `test_connection` prints a message on success.
        """
        self.verbose = verbose
        self.database_key = database_key
        self.owner_key = owner_key
        self.version = version
        self._url = url
        self.base_url = f"{url}/{version}/{self.owner_key}"
        self.auth_json = self._create_auth(auth_json)

    def _create_auth(self, auth_json: dict = None) -> dict:
        """Build the ``auth`` object that `post` sends with every request.

        With `auth_json` the result is a ``PluginAuth`` object extended by the
        given fields (a ``"type"`` key inside `auth_json` takes precedence);
        without it, a ``ClientAuth`` object carrying the database key.
        """
        if auth_json is not None:
            return {"type": "PluginAuth", **auth_json}
        return {"type": "ClientAuth", "databaseKey": self.database_key}

    def test_connection(self) -> bool:
        """Return True if a GET on the Pod address completes, else False.

        Only request-level failures count as "not connected"; the status code
        of the response is not inspected. The failure message is printed
        regardless of `verbose`.
        """
        try:
            requests.get(self._url)
        except requests.exceptions.RequestException:
            # Message text is runtime output and is kept verbatim.
            print("Could no connect to backend")
            return False
        if self.verbose:
            # Message text is runtime output and is kept verbatim.
            print("Succesfully connected to pod")
        return True

    @property
    def pod_version(self) -> dict:
        """Fetch ``{url}/version`` and return its JSON body.

        Raises:
            PodError: If the response status is not 200.
        """
        response = requests.get(f"{self._url}/version")
        if response.status_code != 200:
            raise PodError(response.status_code, response.text)
        return response.json()

    def post(self, endpoint: str, payload: Any) -> Any:
        """POST `payload`, wrapped together with the auth object, to `endpoint`.

        Returns:
            The raw `requests` response; callers decode ``.json()`` or
            ``.content`` themselves.

        Raises:
            PodError: If the response status is not 200.
        """
        body = {"auth": self.auth_json, "payload": payload}
        response = requests.post(f"{self.base_url}/{endpoint}", json=body)
        if response.status_code != 200:
            raise PodError(response.status_code, response.text)
        return response

    def get_item(self, uid: str) -> List[dict]:
        """Return the decoded ``get_item`` response for `uid`.

        The tests in this notebook index the result as a list of item dicts.
        """
        return self.post("get_item", uid).json()

    def create_item(self, item: dict) -> str:
        """Create `item` and return the decoded response (the new item's id)."""
        return self.post("create_item", item).json()

    def update_item(self, item: dict) -> list:
        """Send `item` to ``update_item`` and return the decoded response."""
        return self.post("update_item", item).json()

    def get_edges(
        self, uid: str, direction: str = "Outgoing", expand_items: bool = True
    ) -> List[dict]:
        """Return the edges of item `uid` in the given `direction`.

        Args:
            uid: Id of the item whose edges are requested.
            direction: Sent as ``direction``; ``"Outgoing"`` and ``"Incoming"``
                are the values used in this notebook.
            expand_items: Sent as ``expandItems``.
        """
        payload = {"item": uid, "direction": direction, "expandItems": expand_items}
        return self.post("get_edges", payload).json()

    def create_edge(self, edge: dict) -> str:
        """Send `edge` to ``create_edge`` and return the decoded response."""
        return self.post("create_edge", edge).json()

    def delete_item(self, uid) -> list:
        """Send `uid` to ``delete_item`` and return the decoded response."""
        return self.post("delete_item", uid).json()

    def search(self, query: dict) -> List[dict]:
        """Run a single, unpaginated search and return the matching items."""
        return self.post("search", query).json()

    def search_paginate(
        self, query: dict, limit: int = 32, even_page_size: bool = True
    ) -> Generator:
        """Yield the results of `query` page by page.

        The Pod returns uneven page sizes when paginating, which can be an
        issue for some applications. With `even_page_size` set, overflow items
        are buffered in a queue so that every page has exactly `limit` items,
        except possibly the last one.

        Args:
            query: Search query; see `_paginate` for keys it must not contain.
            limit: Requested page size.
            even_page_size: If False, pages are yielded exactly as received.

        Raises:
            ValueError: If even pages are requested with ``limit < 1``, or if
                `_paginate` rejects the query.
            NotImplementedError: If the query contains ``_sortOrder``.
        """
        paginator = self._paginate(query, limit)

        if not even_page_size:
            yield from paginator
            # Stop here instead of falling through into the buffering loop.
            return

        if limit < 1:
            # A non-positive page size would yield empty pages forever.
            raise ValueError("limit must be a positive integer.")

        remaining: Deque[dict] = deque()
        for page in paginator:
            remaining.extend(page)
            # Emit every full page before the next request is made.
            while len(remaining) >= limit:
                yield [remaining.popleft() for _ in range(limit)]

        # Fewer than `limit` items are left once the paginator is exhausted.
        if remaining:
            yield list(remaining)

    def _paginate(self, query: dict, limit: int = 32) -> Generator:
        """Yield raw result pages by advancing a ``dateServerModified`` cursor.

        Each request asks for at most `limit` items modified at or after the
        cursor; the cursor then moves just past the last item of the page.
        Iteration ends with the first empty response.

        Raises:
            ValueError: If `query` already contains ``_limit`` or a
                ``dateServerModified`` filter, which pagination sets itself.
            NotImplementedError: If `query` contains ``_sortOrder``.
        """
        reserved_keys = (
            "_limit",
            "dateServerModified",
            "dateServerModified>=",
            "dateServerModified<",
        )
        if any(key in query for key in reserved_keys):
            raise ValueError("Cannot paginate query that contains a date or limit.")
        if "_sortOrder" in query:
            raise NotImplementedError("Only 'Asc' order is supported.")

        # Work on a copy so the caller's query is never modified.
        query = {**query, "_limit": limit}
        next_date = 0
        while True:
            query["dateServerModified>="] = next_date
            response = self.search(query)
            if not response:
                break

            next_date = response[-1]["dateServerModified"] + 1
            yield response

    def bulk(
        self,
        create_items: List[dict] = None,
        update_items: List[dict] = None,
        create_edges: List[dict] = None,
        delete_items: List[str] = None,
        search: List[dict] = None,
    ) -> Dict[str, Any]:
        """Send several operations in one ``bulk`` request.

        Arguments left as None are omitted from the payload; empty lists are
        still sent. Returns the decoded JSON response.
        """
        candidates = {
            "createItems": create_items,
            "updateItems": update_items,
            "createEdges": create_edges,
            "deleteItems": delete_items,
            "search": search,
        }
        payload = {key: value for key, value in candidates.items() if value is not None}
        return self.post("bulk", payload).json()

    def upload_file(self, file: bytes) -> Any:
        """Upload `file`, addressed by its SHA-256 hex digest.

        With plugin authentication the upload is delegated to `upload_file_b`.

        Returns:
            The raw `requests` response.

        Raises:
            PodError: If the response status is not 200.
        """
        if self.auth_json.get("type") == "PluginAuth":
            # alternative file upload for plugins, with different authentication
            return self.upload_file_b(file)

        sha = sha256(file).hexdigest()
        result = requests.post(
            f"{self.base_url}/upload_file/{self.database_key}/{sha}", data=file
        )
        if result.status_code != 200:
            raise PodError(result.status_code, result.text)

        return result

    def upload_file_b(self, file: bytes) -> Any:
        """Upload `file` with the URL-quoted auth JSON embedded in the path.

        Returns:
            The raw `requests` response.

        Raises:
            PodError: If the response status is not 200.
        """
        sha = sha256(file).hexdigest()
        # NOTE(review): quote() leaves "/" unescaped by default; confirm that
        # auth values can never contain a slash.
        auth = urllib.parse.quote(json.dumps(self.auth_json))
        result = requests.post(f"{self.base_url}/upload_file_b/{auth}/{sha}", data=file)
        if result.status_code != 200:
            raise PodError(result.status_code, result.text)

        return result

    def get_file(self, sha: str) -> bytes:
        """Return the raw response body of ``get_file`` for digest `sha`."""
        return self.post("get_file", {"sha256": sha}).content

    def send_email(self, to: str, subject: str = "", body: str = "") -> Any:
        """Post an email request and return the raw `requests` response."""
        payload = {"to": to, "subject": subject, "body": body}
        return self.post("send_email", payload)

    def plugin_run_logs(self, run_id: str) -> Any:
        """Return the decoded ``get_pluginrun_log`` response for `run_id`."""
        return self.post("get_pluginrun_log", run_id).json()

## Test setup

In [None]:
from pymemri.pod.client import PodClient
from pymemri.data.schema import Account, Person
from pymemri.data.itembase import Edge

# The raw API wrapper is built from the client's keys: the cells below create
# data through `client` and read it back through `api`.
client = PodClient()
api = PodAPI(database_key=client.database_key, owner_key=client.owner_key)

In [None]:
# Fixture: one person plus three accounts on different services, each account
# connected to the person by an "owner" edge.
person = Person(displayName="Alice")
accounts = [
    Account(identifier="Alice", service=service)
    for service in ("whatsapp", "instagram", "gmail")
]
edges = [Edge(owned_account, person, "owner") for owned_account in accounts]

client.add_to_schema(Account, Person)
client.bulk_action(create_items=[person, *accounts], create_edges=edges)

# Fixture for the search tests: 100 accounts with identifiers "0" to "99".
search_accounts = [
    Account(identifier=str(number), service="search") for number in range(100)
]
client.bulk_action(create_items=search_accounts)

BULK: Writing 7/7 items/edges
Completed Bulk action, written 7 items/edges
BULK: Writing 100/100 items/edges
Completed Bulk action, written 100 items/edges


True

In [None]:
from pymemri.data.schema import Message

In [None]:
# Seven bulk writes in a fixed order: a large non-mock batch, five small
# batches alternating mock / non-mock, and a final large non-mock batch.
leading_batches = [
    (1000, False),
    (10, True),
    (10, False),
    (10, True),
    (10, False),
    (10, True),
]
for batch_size, is_mock in leading_batches:
    client.bulk_action(
        create_items=[Message(isMock=is_mock) for _ in range(batch_size)]
    )
# The last write stays the final expression so the cell still shows its result.
client.bulk_action(create_items=[Message(isMock=False) for _ in range(1000)])


BULK: Writing 1000/1000 items/edges
Completed Bulk action, written 1000 items/edges
BULK: Writing 10/10 items/edges
Completed Bulk action, written 10 items/edges
BULK: Writing 10/10 items/edges
Completed Bulk action, written 10 items/edges
BULK: Writing 10/10 items/edges
Completed Bulk action, written 10 items/edges
BULK: Writing 10/10 items/edges
Completed Bulk action, written 10 items/edges
BULK: Writing 10/10 items/edges
Completed Bulk action, written 10 items/edges
BULK: Writing 1000/1000 items/edges
Completed Bulk action, written 1000 items/edges


True

In [None]:
# Filtered search with an explicit limit; the cell displays how many items came back.
len(client.search({"type": "Message", "_limit": 30, "isMock": True}))

30

## Tests

In [None]:
# The version endpoint must answer with a JSON object that has a "cargo" key.
version = api.pod_version

assert "cargo" in version

### Get, Create

In [None]:
# Get created person
# NOTE: `person` is rebound here from the Person object to the raw API response
# (a list of item dicts); later cells rely on `person[0]`.
person = api.get_item(person.id)

assert len(person)
assert person[0]["displayName"] == 'Alice'

In [None]:
# Add new person: create_item must return the new item's id as a non-empty string
new_person = {"displayName": "clara", "type": "Person"}
new_id = api.create_item(new_person)

assert len(new_id) and isinstance(new_id, str)

In [None]:
# Update new person: change one property, add another, then read the item back
new_person_updated = {"displayName": "Clara", "id": new_id, "birthDate": 631152000000}
api.update_item(new_person_updated)

# `new_person` now holds the raw API response (a list of item dicts)
new_person = api.get_item(new_id)
assert new_person[0]["displayName"] == "Clara"
assert new_person[0]["birthDate"] == 631152000000

### Edges

In [None]:
# Get incoming edges of the person: one "owner" edge per account
edges = api.get_edges(person[0]["id"], direction="Incoming")
assert len(edges) == 3
edge_items = [edge["item"] for edge in edges]
assert set(item["id"] for item in edge_items) == set(account.id for account in accounts)

# Get outgoing edges of each account: exactly one, pointing at the person
for account in accounts:
    edges = api.get_edges(account.id, direction="Outgoing")
    assert len(edges) == 1
    item = edges[0]["item"]
    assert item["id"] == person[0]["id"]

In [None]:
# Create edges: link the first person to the newly created one
src_id = person[0]["id"]
tgt_id = new_person[0]["id"]
edge = {"_source": src_id, "_target": tgt_id, "_name": "relationship"}

edge_id = api.create_edge(edge)

# The new edge must be the person's only outgoing edge
edges = api.get_edges(person[0]["id"], direction="Outgoing")
assert len(edges) == 1
assert edges[0]["name"] == "relationship"
assert edges[0]["item"]["id"] == tgt_id

### Delete

In [None]:
# After deletion the item can still be fetched, but is flagged as deleted
api.delete_item(new_person[0]["id"])

item = api.get_item(new_person[0]["id"])
assert item[0]["deleted"] == True

### Search

In [None]:
# Test search with 100 new accounts: a single search must return all of them
results = api.search({"type": "Account", "service": "search"})

assert len(results) == 100
assert {found["id"] for found in results} == {created.id for created in search_accounts}

In [None]:
# Test paginated search
expected_ids = {created.id for created in search_accounts}
account_query = {"type": "Account", "service": "search"}

# Pages as returned by the Pod: only required to be non-empty
paginator = api.search_paginate(account_query, limit=10, even_page_size=False)

results = []
for page in paginator:
    assert len(page)
    results.extend(page)

assert len(results) == 100, len(results)
assert {found["id"] for found in results} == expected_ids
identifiers = [found["identifier"] for found in results]

# Test paginated search with even page sizes
paginator = api.search_paginate(account_query, limit=10, even_page_size=True)

results = []
for page in paginator:
    assert len(page) == 10
    results.extend(page)

assert len(results) == 100, len(results)
assert {found["id"] for found in results} == expected_ids
identifiers2 = [found["identifier"] for found in results]

# Both modes must return the items in the same order
assert identifiers == identifiers2

In [None]:
# Empty search: a query that matches nothing returns an empty list
result = api.search({"type": "Account", "service": "wrong_service"})
assert result == []

In [None]:
# Empty paginated search: still a generator, but exhausted on the first `next`
result = api.search_paginate({"type": "Account", "service": "wrong_service"}, limit=10)

assert isinstance(result, Generator)
try:
    page = next(result)
    assert False, "StopIteration expected for empty generator"
except StopIteration:
    pass

### Bulk

In [None]:
# Test Bulk: create ten accounts in a single request
bulk_accounts = [
    {"type": "Account", "identifier": str(i), "service": "bulk"} for i in range(10)
]

# The response reports one entry under "createItems" per created item
result = api.bulk(create_items=bulk_accounts)
assert len(result["createItems"]) == 10

### Error handling

In [None]:
# Test error: duplicate id
# Create an item, then try to create a second one that reuses its id.
new_id = api.create_item({"displayName": "clara", "type": "Person"})
new_person = {"displayName": "clara", "type": "Person", "id": new_id}

try:
    api.create_item(new_person)
except PodError as e:
    print("Expected error raised:")
    print(e)
else:
    assert False, "PodError 500 expected"

Expected error raised:
500 Failure: Database rusqlite error UNIQUE constraint failed: items.id, Failed to execute insert_item with parameters


In [None]:
# hide
# Star import as used for nbdev tooling cells; presumably what brings
# `notebook2script` into scope.
from nbdev.export import *
# Converts the project's notebooks (see the "Converted ..." lines in the output).
notebook2script()

Converted basic.ipynb.
Converted cvu.utils.ipynb.
Converted data.dataset.ipynb.
Converted data.loader.ipynb.
Converted data.photo.ipynb.
Converted exporters.exporters.ipynb.
Converted index.ipynb.
Converted itembase.ipynb.
Converted plugin.authenticators.credentials.ipynb.
Converted plugin.authenticators.oauth.ipynb.
Converted plugin.listeners.ipynb.
Converted plugin.pluginbase.ipynb.
Converted plugin.states.ipynb.
Converted plugins.authenticators.password.ipynb.
Converted pod.api.ipynb.
Converted pod.client.ipynb.
Converted pod.db.ipynb.
Converted pod.utils.ipynb.
Converted template.config.ipynb.
Converted template.formatter.ipynb.
Converted test_schema.ipynb.
Converted test_utils.ipynb.
