In [1]:
import sqlite3
from typing import Optional

from numba.experimental import jitclass
from slumba import sqlite_udaf, create_aggregate

In [2]:
@sqlite_udaf
@jitclass
class Avg:
    """Arithmetic mean of the non-NULL inputs, used as a SQLite aggregate.

    ``step`` adds one input and ``inverse`` removes one; ``finalize`` and
    ``value`` report the current mean, or ``None`` when no non-NULL inputs
    are being aggregated.
    """

    total: float  # running sum of the non-NULL inputs currently aggregated
    count: int    # number of non-NULL inputs currently aggregated

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def step(self, value: Optional[float]) -> None:
        """Add one input row; NULL (``None``) inputs are ignored."""
        if value is not None:
            self.total += value
            self.count += 1

    def finalize(self) -> Optional[float]:
        """Return the mean of the aggregated inputs, or ``None`` if there are none."""
        count = self.count
        if count:
            return self.total / count
        return None

    def value(self) -> Optional[float]:
        """Return the current mean (same as ``finalize``).

        NOTE(review): presumably invoked per row when used as a window
        function (SQLite's xValue callback) — confirm against slumba.
        """
        return self.finalize()

    def inverse(self, value: Optional[float]) -> None:
        """Remove one input row previously added by ``step``; ``None`` is ignored."""
        if value is not None:
            self.total -= value
            self.count -= 1
            if self.count == 0:
                # Repeated float add/subtract can leave a tiny residue
                # (e.g. 0.1 + 0.2 - 0.1 - 0.2 != 0.0). Once nothing is being
                # aggregated the exact sum is 0, so drop the accumulated error
                # instead of carrying it into later window frames.
                self.total = 0.0

In [3]:
# Throwaway in-memory database: nothing is persisted to disk.
con = sqlite3.connect(database=':memory:')

In [4]:
# Register the jitclass-based `Avg` on `con` so SQL can call it as `my_avg`.
# NOTE(review): the `1` is presumably the number of SQL arguments (mirroring
# sqlite3.Connection.create_aggregate's n_arg) — confirm against slumba.
create_aggregate(con, 'my_avg', 1, Avg)

In [5]:
# Compare the custom `my_avg` against SQLite's built-in `avg` over the same
# window. `UNION` (not `UNION ALL`) de-duplicates the rows of `t`, which is
# harmless here because the three values are distinct. With only ORDER BY in
# the OVER clause, the recorded output is a running (cumulative) average.
# NOTE(review): with this default frame, rows presumably never leave the
# window, so `Avg.inverse` may never be called. A frame such as
# `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW` would exercise it.
query = """
WITH t AS (
  SELECT 1 AS c UNION
  SELECT 2 AS c UNION
  SELECT NULL AS c
)
SELECT my_avg(c) OVER (ORDER BY c) AS my_udaf,
       avg(c) OVER (ORDER BY c) AS builtin_udaf
FROM t
"""

In [6]:
rows = con.execute(query).fetchall()
# Sanity check: the custom aggregate should reproduce SQLite's built-in avg
# row-for-row (NULL/None included). Exact equality is fine for these small
# integer inputs; switch to a tolerance if the test data becomes arbitrary floats.
assert all(mine == builtin for mine, builtin in rows), rows
rows

[(None, None), (1.0, 1.0), (1.5, 1.5)]