In [None]:
# make sure we are working in module directory
# `!git ...` captures the command's output as a list of lines; the first line
# is the repository root.
repo_root = !git rev-parse --show-toplevel
module_path = repo_root[0] + "/backend/heatflask"
%cd $module_path

import sys
# Later cells use relative imports (e.g. `from . import DataAPIs`); setting
# __package__ names the package they are resolved against.
__package__ = "heatflask"
# We are now inside the package directory, so its parent ("..") has to be on
# sys.path for the "heatflask" package itself to be importable.
if ".." not in sys.path:
    sys.path.insert(0, "..")
    
# Make cells wider
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
# %load StreamCodecs.py
# Here we define a custom encoding/compression scheme for streams
# Run-Length-Diff encoding
#
#  It is RLE on successive differences, which in our case are small enough to
#  be 8 bit integers

import numpy as np

# Inputs accepted by the codec: a list of ints, a list of floats, or a numpy array
Nums = list[int] | list[float] | np.ndarray
# An encoded stream is a plain bytes object
RLDEncoded = bytes


def positive_non_decreasing(vals: Nums) -> bool:
    """Check that a sequence starts at a value >= 0 and never goes down.

    Despite the name, zero is accepted as a first value, and equal
    neighbours are allowed.

    Args:
        vals: indexable, sized sequence of numbers with at least one element.

    Returns:
        True when ``vals[0] >= 0`` and every element is >= its predecessor,
        False otherwise.

    Raises:
        IndexError: if ``vals`` is empty (the first element is read
            unconditionally).
    """
    previous = vals[0]
    if previous < 0:
        return False

    for index in range(1, len(vals)):
        current = vals[index]
        if current < previous:
            return False
        previous = current
    return True


def rld_encode(vals: Nums) -> RLDEncoded:
    """Run-Length-Diff encode a sequence of numbers into bytes.

    Layout of the returned bytes:
      * byte 0: ``1`` when the values are non-negative and non-decreasing
        (differences are stored as uint8), ``0`` otherwise (stored as int8).
      * bytes 1-2: the first value as an int16 in native byte order.
      * bytes 3+: successive differences.  A long run of equal differences
        is stored as three bytes ``marker, difference, count`` where the
        marker is 255 in uint8 mode and -128 in int8 mode; short runs are
        written out literally.

    When the first element is a Python ``float`` every value is converted by
    adding 0.5 and converting to a 32-bit integer; otherwise the input is
    converted straight to 32-bit integers.

    Args:
        vals: non-empty sequence of numbers.

    Returns:
        The encoded bytes.  A single-element input yields just the 3-byte
        header (no differences).

    Raises:
        IndexError: if ``vals`` is empty.
    """
    vals = (
        np.fromiter((v + 0.5 for v in vals), dtype="i4", count=len(vals))
        if type(vals[0]) is float
        else np.array(vals, dtype="i4")
    )

    increasing = positive_non_decreasing(vals)
    my_dtype = np.uint8 if increasing else np.int8
    rl_marker = 255 if increasing else -128
    # A run holds at most max_reps + 1 differences so that the stored count
    # (reps + 1) still fits in one byte of my_dtype.
    max_reps = 254 if increasing else 126

    ntype = b"\x01" if increasing else b"\x00"
    firstval = np.array(vals[0], dtype=np.int16).tobytes()

    n = len(vals)
    if n < 2:
        # A lone value has no differences; previously this fell through to
        # vals[1] and raised IndexError.
        return ntype + firstval

    # There are n - 1 differences and no run ever takes more bytes than the
    # differences it covers, so n bytes are always enough.
    # NOTE(review): a difference that equals the marker value or does not fit
    # in 8 bits gets no special treatment below -- confirm the inputs can
    # never produce one.
    encoded = np.empty(n, dtype=my_dtype)
    reps = 0  # extra repeats of the current difference seen so far
    j = 0  # write position in `encoded`

    v = vals[1]
    d = v - vals[0]

    for i in range(2, n):
        next_v = vals[i]
        next_d = next_v - v

        if (d == next_d) and (reps < max_reps):
            reps += 1

        else:
            if reps == 0:
                encoded[j] = d
                j += 1
            elif reps <= 2:
                # Up to three equal differences: a literal copy is no longer
                # than the 3-byte run record.
                reps += 1
                while reps:
                    encoded[j] = d
                    j += 1
                    reps -= 1
            else:
                encoded[j] = rl_marker
                encoded[j + 1] = d
                encoded[j + 2] = reps + 1
                j += 3
                reps = 0
        d = next_d
        v = next_v

    # Flush the run that was still open when the input ended.
    if reps == 0:
        encoded[j] = d
        j += 1
    elif reps == 1:
        encoded[j] = d
        encoded[j + 1] = d
        j += 2
    else:
        encoded[j] = rl_marker
        encoded[j + 1] = d
        encoded[j + 2] = reps + 1
        j += 3

    return ntype + firstval + encoded[:j].tobytes()


def decoded_length(enc: np.ndarray, rl_marker: int) -> int:
    """Count how many values an encoded difference array expands to.

    Args:
        enc: the encoded differences (bytes 3+ of an RLD stream) as an
            8-bit integer array.
        rl_marker: the value that introduces a ``marker, difference, count``
            run record.

    Returns:
        1 (for the start value) plus one per literal difference plus the
        count of every run record, as a Python ``int``.

    Raises:
        IndexError: if a run marker sits in the last two positions, so its
            count byte is missing.
    """
    length = 1
    i = 0
    n = len(enc)
    while i < n:
        if enc[i] == rl_marker:
            # int(): adding an 8-bit NumPy scalar to the running total can
            # keep the total 8 bits wide and wrap it.
            length += int(enc[i + 2])
            i += 3
        else:
            length += 1
            i += 1
    return length


def rld_decode(enc: RLDEncoded, dtype=np.int32) -> Nums:
    """Decode bytes produced by ``rld_encode`` back into an array.

    Args:
        enc: encoded stream: a type byte (0 = int8 differences, anything
            else = uint8 differences), an int16 start value in native byte
            order, then the encoded differences.
        dtype: dtype of the returned array.

    Returns:
        A numpy array of ``dtype`` holding the start value followed by the
        running sum of the differences.  Sums are accumulated exactly and
        converted with ``astype`` at the end, so a value that does not fit
        in ``dtype`` is wrapped by that cast rather than raising.

    Raises:
        IndexError: if a run record is truncated.
    """
    ntype = np.frombuffer(enc, dtype="i1", count=1, offset=0)[0]
    increasing = ntype != 0
    rl_marker = 255 if increasing else -128

    start_val = int(np.frombuffer(enc, dtype="i2", count=1, offset=1)[0])
    enc_diffs = np.frombuffer(enc, dtype="u1" if increasing else "i1", offset=3)

    n_out = decoded_length(enc_diffs, rl_marker)

    # Work on plain Python ints: the previous NumPy-scalar arithmetic kept the
    # running sum in int16 (wrapping past 32767) and could keep the position
    # counter 8 bits wide.
    diffs = enc_diffs.tolist()
    n_diffs = len(diffs)

    decoded = np.empty(n_out, dtype=np.int64)
    decoded[0] = start_val
    total = start_val
    i = 0  # position in diffs
    j = 1  # position in decoded
    while i < n_diffs:
        if diffs[i] == rl_marker:
            d = diffs[i + 1]
            reps = diffs[i + 2]
            for _ in range(reps):
                total += d
                decoded[j] = total
                j += 1
            i += 3
        else:
            total += diffs[i]
            decoded[j] = total
            i += 1
            j += 1
    return decoded.astype(dtype)


In [None]:
# %load Streams.py
"""
Functions and constants pertaining to the Streams data store.  Each activity
has the streams time, latlng, and altitude.

***  For Jupyter notebook ***
Paste one of these Jupyter magic directives to the top of a cell
 and run it, to do these things:
    %%cython --annotate      # Compile and run the cell
    %load Streams.py         # Load Streams.py file into this (empty) cell
    %%writefile Streams.py   # Write the contents of this cell to Streams.py
"""

import os
import time
import datetime
from logging import getLogger
import msgpack
import polyline
import asyncio
import types
from typing import TypedDict, Awaitable, AsyncGenerator, Coroutine, cast

from . import DataAPIs
from .DataAPIs import db
from . import Strava
from . import StreamCodecs
from .Users import UserField as U


log = getLogger(__name__)
log.setLevel("DEBUG")
log.propagate = True

# Name passed to DataAPIs.init_collection / stats / drop
COLLECTION_NAME = "streams_v0"
# Prefix of the Redis keys built by cache_key()
CACHE_PREFIX = "S:"

SECS_IN_HOUR = 60 * 60
SECS_IN_DAY = 24 * SECS_IN_HOUR

# Both TTLs end up in seconds: MONGO_STREAMS_TTL is read as a number of days
# (default 10) and REDIS_STREAMS_TTL as a number of hours (default 4).
MONGO_TTL = int(os.environ.get("MONGO_STREAMS_TTL", 10)) * SECS_IN_DAY
REDIS_TTL = int(os.environ.get("REDIS_STREAMS_TTL", 4)) * SECS_IN_HOUR
# Raw environment string or None.  It is only tested for truthiness, so any
# non-empty value (including "0") switches Strava imports off in aiter_query.
OFFLINE = os.environ.get("OFFLINE")

# Mutable holder for the collection handle that get_collection() fills in on
# first use.
myBox = types.SimpleNamespace(collection=None)


async def get_collection():
    """Return the streams collection handle, creating it on first use.

    The handle comes from ``DataAPIs.init_collection`` and is kept in
    ``myBox.collection`` so later calls return it without re-initializing.
    """
    cached = myBox.collection
    if cached is not None:
        return cached

    myBox.collection = await DataAPIs.init_collection(
        COLLECTION_NAME, ttl=MONGO_TTL, cache_prefix=CACHE_PREFIX
    )
    return myBox.collection


# Precision argument given to polyline.encode / polyline.decode for latlng data
POLYLINE_PRECISION = 6


class EncodedStreams(TypedDict):
    """Dict that encode_streams msgpack-packs for one activity."""
    t: StreamCodecs.RLDEncoded  # "time" stream, RLD-encoded
    a: StreamCodecs.RLDEncoded  # "altitude" stream, RLD-encoded
    p: str  # "latlng" stream as returned by polyline.encode


# msgpack serialization of an EncodedStreams dict
PackedStreams = bytes


def encode_streams(rjson: Strava.Streams) -> PackedStreams:
    """Compress one activity's streams into a single msgpack blob.

    Reads ``rjson["time"]["data"]``, ``rjson["altitude"]["data"]`` and
    ``rjson["latlng"]["data"]``; the first two are RLD-encoded, the last is
    polyline-encoded, and the three results are packed under the keys
    ``"t"``, ``"a"`` and ``"p"``.
    """
    time_bytes = StreamCodecs.rld_encode(rjson["time"]["data"])
    altitude_bytes = StreamCodecs.rld_encode(rjson["altitude"]["data"])
    path = polyline.encode(rjson["latlng"]["data"], POLYLINE_PRECISION)

    enc: EncodedStreams = {"t": time_bytes, "a": altitude_bytes, "p": path}
    return msgpack.packb(enc)


def decode_streams(msgpacked_streams: PackedStreams):
    """Expand a blob made by encode_streams back into its three streams.

    Returns a dict with ``"time"`` (RLD-decoded with dtype ``"u2"``),
    ``"altitude"`` (RLD-decoded with dtype ``"i2"``) and ``"latlng"``
    (polyline-decoded).
    """
    unpacked: EncodedStreams = msgpack.unpackb(msgpacked_streams)

    time_stream = StreamCodecs.rld_decode(unpacked["t"], dtype="u2")
    altitude_stream = StreamCodecs.rld_decode(unpacked["a"], dtype="i2")
    latlng_stream = polyline.decode(unpacked["p"], POLYLINE_PRECISION)

    return {
        "time": time_stream,
        "altitude": altitude_stream,
        "latlng": latlng_stream,
    }


class StreamsDoc(TypedDict):
    """Layout of the document built by mongo_doc()."""
    _id: int  # activity id
    mpk: PackedStreams  # msgpack-packed streams
    ts: datetime.datetime  # timestamp; refreshed by aiter_query on a Mongo hit


def mongo_doc(activity_id: int, packed: PackedStreams, ts=None) -> StreamsDoc:
    """Build the document stored for one activity's streams.

    Args:
        activity_id: activity id; converted with ``int()`` and used as ``_id``.
        packed: msgpack-packed streams, stored under ``mpk``.
        ts: timestamp to store; when omitted (or falsy) the current time is
            used as a naive UTC datetime.

    Returns:
        A dict with the keys ``_id``, ``mpk`` and ``ts``.
    """
    if not ts:
        # Naive UTC, i.e. the same value utcnow() gives.  aiter_query rewrites
        # this field with UTC time, so the default must not be local time.
        ts = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return {
        "_id": int(activity_id),
        "mpk": packed,
        "ts": ts,
    }


def cache_key(aid: int):
    """Return the Redis key for an activity id: CACHE_PREFIX followed by the id."""
    return "{}{}".format(CACHE_PREFIX, aid)


# One query result: (activity id, msgpack-packed streams)
StreamsQueryResult = tuple[int, PackedStreams]


async def strava_import(
    activity_ids: list[int], **user
) -> AsyncGenerator[StreamsQueryResult, bool]:
    """Fetch streams from Strava, yielding each one and storing all of them.

    For every ``(activity id, streams)`` pair the Strava client produces, the
    streams are packed, queued for the Redis cache, collected for MongoDB and
    yielded as ``(activity id, packed streams)``.  Sending a truthy value
    into the generator aborts the remaining fetches.  In either case the
    queued Redis writes are executed and the collected documents are
    inserted into MongoDB before the generator finishes.

    Args:
        activity_ids: ids of the activities whose streams are wanted.
        **user: user record; ``user[U.ID]`` and ``user[U.AUTH]`` are read to
            build the Strava client.
    """
    uid = int(user[U.ID])

    strava = Strava.AsyncClient(uid, **user[U.AUTH])
    await strava.update_access_token()
    coll = await get_collection()

    mongo_docs = []
    # Naive UTC, i.e. the same value utcnow() gives, so that it agrees with
    # the UTC timestamp aiter_query later writes to the same "ts" field.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    aiterator = strava.get_many_streams(activity_ids)

    async with db.redis.pipeline(transaction=True) as pipe:
        async for aid, streams in aiterator:
            packed = encode_streams(streams)

            # queue packed streams to be redis cached
            pipe = pipe.setex(cache_key(aid), REDIS_TTL, packed)

            mongo_docs.append(mongo_doc(aid, packed, ts=now))

            abort_signal = yield aid, packed

            if abort_signal:
                await Strava.AsyncClient.abort(aiterator)
                break

        await pipe.execute()

    # insert_many rejects an empty document list, which is what we have when
    # Strava returned nothing for these ids.
    if mongo_docs:
        await coll.insert_many(mongo_docs)


async def aiter_query(
    activity_ids: list[int], user=None
) -> AsyncGenerator[StreamsQueryResult, bool]:
    """Yield ``(activity id, packed streams)`` for the requested activities.

    Sources are tried in this order:
      1. Redis (hits get their TTL reset),
      2. MongoDB for the Redis misses (hits are copied into Redis and their
         ``ts`` field is refreshed),
      3. a Strava import for whatever is still missing, but only when
         ``user`` is given and OFFLINE is not set.  The import is started
         before the local results are yielded so it runs in the background.

    Sending a truthy value into the generator aborts the query: the import
    (if any) is told to abort and the generator finishes without yielding
    anything further.

    Args:
        activity_ids: ids of the activities whose streams are wanted.
        user: user record passed on to ``strava_import``; ``None`` disables
            the import step.
    """
    if not activity_ids:
        return
    #
    # First we check Redis cache
    #
    t0 = time.perf_counter()
    keys = [cache_key(aid) for aid in activity_ids]
    redis_response = await db.redis.mget(keys)

    # Reset TTL for those cached streams that were hit
    async with db.redis.pipeline(transaction=True) as pipe:
        for k, val in zip(keys, redis_response):
            if val:
                pipe = pipe.expire(k, REDIS_TTL)
        await pipe.execute()

    t1 = time.perf_counter()
    local_result: list[StreamsQueryResult] = [
        (a, s) for a, s in zip(activity_ids, redis_response) if s
    ]
    log.debug(
        "retrieved %d streams from Redis in %d", len(local_result), (t1 - t0) * 1000
    )

    #
    # Next we query MongoDB for streams that were not in Redis
    #
    # activity IDs of cache misses
    activity_ids = [a for a, s in zip(activity_ids, redis_response) if not s]
    if activity_ids:
        # Next we query MongoDB for any cache misses
        t0 = time.perf_counter()
        streams = await get_collection()
        query = {"_id": {"$in": activity_ids}}
        exclusions = {"ts": False}

        cursor = streams.find(query, projection=exclusions)
        mongo_result = [(doc["_id"], doc["mpk"]) async for doc in cursor]
        local_result.extend(mongo_result)
        mongo_result_ids = [_id for _id, mpk in mongo_result]

        # Cache the mongo hits
        async with db.redis.pipeline(transaction=True) as pipe:
            for aid, s in mongo_result:
                pipe = pipe.setex(cache_key(aid), REDIS_TTL, s)
            await pipe.execute()

        # Update TTL for mongo hits (naive UTC, same value as utcnow())
        utc_now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        await streams.update_many(
            {"_id": {"$in": mongo_result_ids}},
            {"$set": {"ts": utc_now}},
        )
        elapsed = (time.perf_counter() - t0) * 1000
        log.debug("retrieved %d streams from Mongo in %d", len(mongo_result), elapsed)

        activity_ids = list(set(activity_ids) - set(mongo_result_ids))

    streams_import = None
    first_fetch = None
    if activity_ids and (user is not None) and (not OFFLINE):
        # Start a fetch process going. We will get back to this...
        t0 = time.perf_counter()
        streams_import = strava_import(activity_ids, **user)
        first_fetch = asyncio.create_task(cast(Coroutine, streams_import.__anext__()))

    # Yield all the results from Redis and Mongo
    for item in local_result:
        abort_signal = yield item
        if abort_signal:
            log.info("Local Streams query aborted")
            if streams_import:
                await Strava.AsyncClient.abort(streams_import)
                # Do not leave the background fetch pending once we stop.
                if first_fetch is not None and not first_fetch.done():
                    first_fetch.cancel()
            # Finish here: falling through to the import phase would hand
            # another item to a consumer that has just asked us to stop.
            return

    if streams_import:
        # Now we yield results of fetches as they come in
        imported_items: list[StreamsQueryResult] = []
        abort_signal = False
        try:
            item1: StreamsQueryResult = await cast(Awaitable, first_fetch)
        except StopAsyncIteration:
            # The import produced nothing.  Letting StopAsyncIteration escape
            # from inside an async generator would surface as RuntimeError.
            pass
        else:
            imported_items.append(item1)
            abort_signal = yield item1

            if not abort_signal:
                async for item in streams_import:
                    imported_items.append(item)
                    abort_signal = yield item
                    if abort_signal:
                        break

        if abort_signal:
            # abort() is awaited at its other call sites; without the await
            # the call here did nothing.
            await Strava.AsyncClient.abort(streams_import)
            log.info("Remote Streams query aborted")

        t1 = time.perf_counter()
        log.debug(
            "retrieved %d streams from Strava in %d",
            len(imported_items),
            (t1 - t0) * 1000,
        )
        imported_ids = set(aid for aid, mpk in imported_items)
        missing_ids = set(activity_ids) - imported_ids
        if missing_ids:
            log.info("unable to import streams for %s", missing_ids)


async def query(**kwargs) -> list[StreamsQueryResult]:
    """Run aiter_query with the given keyword arguments and collect every result in a list."""
    collected = []
    async for result in aiter_query(**kwargs):
        collected.append(result)
    return collected


async def delete(activity_ids: list[int]):
    """Remove the streams of the given activities from MongoDB and from Redis.

    Does nothing when ``activity_ids`` is empty.
    """
    if activity_ids:
        collection = await get_collection()
        await collection.delete_many({"_id": {"$in": activity_ids}})
        await db.redis.delete(*map(cache_key, activity_ids))


async def clear_cache():
    """Delete every Redis key that matches the streams key pattern.

    Returns the result of the Redis delete call, or None when no key matched.
    """
    pattern = cache_key("*")
    matching_keys = await db.redis.keys(pattern)
    if not matching_keys:
        return None
    return await db.redis.delete(*matching_keys)


def stats():
    """Return what DataAPIs.stats gives for the streams collection."""
    collection_stats = DataAPIs.stats(COLLECTION_NAME)
    return collection_stats


def drop():
    """Return what DataAPIs.drop gives for the streams collection."""
    drop_result = DataAPIs.drop(COLLECTION_NAME)
    return drop_result


In [None]:
import logging
logging.basicConfig(level="DEBUG")

# Connect the data stores used by the cells below.
# NOTE(review): the meaning of the two None arguments is defined in DataAPIs -- confirm.
await DataAPIs.connect(None, None)

# How many index entries to pull for this experiment
N_FETCH = 15

from . import Index
result = await Index.query(limit=N_FETCH)
# Activity ids of the returned index documents; input for the streams query below
activity_ids = [d["_id"] for d in result["docs"]]
result

In [None]:
import asyncio
from . import Users
from . import Strava

# Load the first user listed in Users.ADMIN; it is passed as `user` to the
# streams query in the next cell.
admin = await Users.get(Users.ADMIN[0])
admin

In [None]:
# Query streams for the indexed activities; with a user given, anything not
# found in Redis or MongoDB is imported from Strava (unless OFFLINE is set).
q = await query(activity_ids=activity_ids, user=admin)

In [None]:
# Split the (activity id, packed streams) pairs into two parallel tuples
a,b = zip(*q)
a,b

In [None]:
# Show the streams collection statistics after the query above
await stats()

In [None]:
# Presumably closes what DataAPIs.connect opened earlier -- run this last
await DataAPIs.disconnect()