In [None]:
import pathlib
import sqlite3
from contextlib import contextmanager


In [None]:
class SQLPriorityQueue:
    """
    Message queue persisted in a SQLite table named ``Queue``.

    Every row carries a ``status`` (0 = waiting, 1 = locked by ``pop``,
    2 = marked done) and an integer ``priority``.  Waiting messages are
    handed out in ascending ``priority`` order; ties are broken by insertion
    order (``rowid``).
    """

    def __init__(self, filename=None, memory=False, **kwargs):
        """
        Open (or adopt) a connection and make sure the schema exists.

        Args:
            filename: path of the database file (``str`` or ``pathlib.Path``),
                the literal ``":memory:"``, or an already opened
                ``sqlite3`` connection to reuse.
            memory: when true, use a private in-memory database and ignore
                ``filename``.
            **kwargs: forwarded to ``sqlite3.connect`` when this class opens
                the connection itself.

        Raises:
            ValueError: if neither a filename/connection nor ``memory=True``
                is supplied.
        """
        if memory or filename == ":memory:":
            self.conn = sqlite3.connect(":memory:", isolation_level=None, **kwargs)
        elif isinstance(filename, (str, pathlib.Path)):
            self.conn = sqlite3.connect(str(filename), isolation_level=None, **kwargs)
            # Tuning that only makes sense for an on-disk database.
            self.conn.execute("PRAGMA journal_mode = 'WAL';")
            self.conn.execute("PRAGMA temp_store = 2;")
            self.conn.execute("PRAGMA synchronous = 1;")
            self.conn.execute(f"PRAGMA cache_size = {-1 * 64_000};")
        elif filename is None:
            # An explicit raise instead of ``assert``: asserts vanish under -O.
            raise ValueError(
                "A filename, an open connection or memory=True is required"
            )
        else:
            # Anything else is treated as an existing connection object.
            self.conn = filename
            # Transactions are managed explicitly through ``transaction()``.
            self.conn.isolation_level = None

        # Rows behave like mappings so they can be turned into dicts.
        self.conn.row_factory = sqlite3.Row

        with self.transaction():
            self.conn.execute(
                """CREATE TABLE IF NOT EXISTS Queue
                ( message TEXT NOT NULL,
                  message_id TEXT,
                  status INTEGER,
                  in_time INTEGER NOT NULL DEFAULT (strftime('%s','now')),
                  lock_time INTEGER,
                  done_time INTEGER,
                  priority INTEGER DEFAULT 0 )
                """
            )

            self.conn.execute("CREATE INDEX IF NOT EXISTS TIdx ON Queue(message_id)")
            self.conn.execute("CREATE INDEX IF NOT EXISTS SIdx ON Queue(status)")

    def put(self, message):
        """
        Insert a new message at the back of the queue.

        The row gets a random hexadecimal ``message_id`` and a priority one
        above the highest priority among the waiting messages.

        Returns:
            The ``rowid`` of the inserted row.
        """
        with self.transaction(mode="IMMEDIATE"):
            rid = self.conn.execute(
                """
                INSERT INTO Queue  (message, message_id, status, in_time, lock_time, done_time, priority)
                VALUES (:message, lower(hex(randomblob(16))), 0, strftime('%s','now'), NULL, NULL, (SELECT COALESCE( MAX( priority ), 0 ) + 1 FROM Queue WHERE STATUS = 0))
                """,
                {"message": message},
            ).lastrowid

        return rid

    def pop(self):
        """
        Lock and return the next waiting message.

        The waiting row with the smallest ``priority`` (oldest first on ties)
        is switched to status 1 and stamped with ``lock_time``.

        Returns:
            The updated row as a ``dict``, or ``None`` when nothing is waiting.
        """
        with self.transaction(mode="IMMEDIATE"):
            message = self.conn.execute(
                """
                UPDATE Queue SET status = 1, lock_time = strftime('%s','now')
                WHERE rowid = (SELECT rowid FROM Queue
                               WHERE status = 0
                               ORDER BY priority, rowid
                               LIMIT 1)
                RETURNING *;
                """
            ).fetchone()

            if not message:
                return None

            return dict(message)

    def update_priority(self, message_id, priority):
        """
        Move the message identified by ``message_id`` to ``priority``.

        Every row whose priority is already ``priority`` or greater is pushed
        back by one first, so the slot is free for the moved message.
        """
        with self.transaction(mode="IMMEDIATE"):
            self.conn.execute(
                """
                UPDATE Queue SET priority = priority + 1 WHERE priority >= :priority
                """,
                {"priority": priority},
            )
            self.conn.execute(
                """
                UPDATE Queue SET priority = :priority WHERE message_id = :message_id
                """,
                {"message_id": message_id, "priority": priority},
            )

    def peek(self):
        """
        Show the next message to be popped without locking it.

        Returns:
            The row as a ``dict``, or ``None`` when nothing is waiting.
        """
        value = self.conn.execute(
            "SELECT * FROM Queue WHERE status = 0 ORDER BY priority, rowid LIMIT 1"
        ).fetchone()
        return dict(value) if value is not None else None

    def get(self, message_id=None, limit=100):
        """
        Get a message by its ``message_id`` if supplied, or all up to ``limit``.

        Returns:
            With ``message_id``: the matching row as a ``dict`` or ``None``.
            Without it: a list of row dicts of any status, ordered by priority.
        """
        if message_id is not None:
            value = self.conn.execute(
                "SELECT * FROM Queue WHERE message_id = :message_id",
                {"message_id": message_id},
            ).fetchone()
            return dict(value) if value is not None else value

        rows = self.conn.execute(
            "SELECT * FROM Queue ORDER BY priority LIMIT :limit",
            {"limit": limit},
        )
        return [dict(row) for row in rows]

    def done(self, message_id):
        """
        Mark message as done.

        If executed multiple times, ``done_time`` will be the last time this
        function is called.

        Returns:
            The cursor's ``lastrowid``, kept for backward compatibility.
            ``sqlite3`` maintains that attribute for inserts, so it should not
            be relied on to identify the updated message.
        """
        cursor = self.conn.execute(
            "UPDATE Queue SET status = 2,  done_time = strftime('%s','now') WHERE message_id = :message_id",
            {"message_id": message_id},
        )
        return cursor.lastrowid

    def qsize(self):
        """Return the number of messages that are not marked done."""
        return next(self.conn.execute("SELECT COUNT(*) FROM Queue WHERE status != 2"))[0]

    def empty(self):
        """Return ``True`` when no message is waiting to be popped."""
        value = self.conn.execute(
            "SELECT COUNT(*) as cnt FROM Queue WHERE status = 0"
        ).fetchone()
        return not bool(value["cnt"])

    @contextmanager
    def transaction(self, mode="DEFERRED"):
        """
        Run the ``with`` body inside an explicit SQLite transaction.

        Args:
            mode: ``"DEFERRED"``, ``"IMMEDIATE"`` or ``"EXCLUSIVE"``.

        Raises:
            ValueError: if ``mode`` is not one of the three valid modes.
        """
        if mode not in {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}:
            raise ValueError(f"Transaction mode '{mode}' is not valid")
        self.conn.execute(f"BEGIN {mode}")
        try:
            # Yield control back to the caller.
            yield
        except BaseException:
            # Undo everything, then let the original exception propagate
            # with its traceback untouched.
            self.conn.rollback()
            raise
        else:
            self.conn.commit()
            