# core

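> This is a single notebook that contains all the source code: `VectoLite`, a small vector store built on `sqlite-vec` with a `diskcache` side store for item contents.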

In [None]:
#| default_exp core

In [2]:
#| hide
from nbdev.showdoc import *
import json
from diskcache import Cache
import hashlib
import orjson
import sqlite_vec
import pysqlite3

from typing import List
import struct


def serialize_f32(vector: List[float]) -> bytes:
    """Packs the floats of `vector` back-to-back as 32-bit floats and returns the bytes."""
    # One "f" (4-byte float) per element, e.g. "4f" for a 4-dimensional vector.
    layout = f"{len(vector)}f"
    return struct.pack(layout, *vector)



In [5]:
#| export

class VectoLite:
    """
    A small vector store backed by SQLite with the sqlite-vec extension.

    Embeddings are stored in a ``vec0`` virtual table. The remaining fields of each
    item are stored as JSON bytes in a ``diskcache.Cache`` under the row's rowid, and
    the MD5 hash of those bytes is stored as a second key (hash -> rowid) so that an
    item whose contents were already inserted is skipped.

    Note: the cache is shared by everything written through one instance, so an
    instance is meant to be used with a single table.
    """

    def __init__(self, path: str):
        """
        Initializes the VectoLite instance with a connection to the SQLite database.

        Args:
            path (str): Base path. The database is opened at ``f"{path}.sqlite"`` and
                the cache is created with ``Cache(path)``.
        """
        self.path = path
        self.db = pysqlite3.connect(f'{path}.sqlite')
        # Extension loading is only enabled for as long as it takes to load sqlite-vec.
        self.db.enable_load_extension(True)
        sqlite_vec.load(self.db)
        self.db.enable_load_extension(False)
        self.cache = Cache(path)
        # Cached row count of `self.table_name`; None means "not counted yet".
        self.rownums = None
        self.table_name = 'myvecs'

    def print_version(self):
        """
        Prints the SQLite and SQLite-vec versions.
        """
        sqlite_version, vec_version = self.db.execute(
            "select sqlite_version(), vec_version()"
        ).fetchone()
        print(f"sqlite_version={sqlite_version}, vec_version={vec_version}")

    def _has_table(self, table_name):
        """Returns True if a table named `table_name` is listed in sqlite_master."""
        row = self.db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            [table_name],
        ).fetchone()
        return row is not None

    @property
    def table_exists(self):
        """
        Checks if the table `self.table_name` exists in the SQLite database.

        Returns:
            bool: True if the table exists, False otherwise.
        """
        return self._has_table(self.table_name)

    @property
    def table_len(self):
        """
        Returns the number of rows in the table `self.table_name`.

        The count is queried once and then kept in `self.rownums`; `insert` keeps the
        cached value up to date.

        Returns:
            int: The number of rows in the table.
        """
        # Compare against None explicitly: a cached count of 0 is a valid value.
        if self.rownums is None:
            self.rownums = self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}").fetchone()[0]
        return self.rownums

    def parse_item(self, item):
        """
        Splits an item into its hash, its serialized contents, and its vector.

        Args:
            item (dict): The item to parse. Must contain a 'vector' key.

        Returns:
            tuple: The MD5 hex digest (str) of the serialized contents, the contents
                serialized with orjson (bytes, every key except 'vector'), and the
                value of item['vector'] as given (not serialized).
        """
        contents = orjson.dumps({k: v for k, v in item.items() if k != 'vector'})
        md5_hash = hashlib.md5(contents).hexdigest()
        return md5_hash, contents, item['vector']

    def insert(self, table_name, stream):
        """
        Inserts a stream of items into the specified table.

        The table is created on the first item if it does not exist yet, with the
        embedding dimension taken from that item's vector. Items whose contents hash
        is already in the cache are skipped; the rest of the stream is still processed.
        If the insert fails, the cache keys added by this call are removed again and
        the cached row count is discarded.

        Args:
            table_name (str): The name of the table to insert the items into. It also
                becomes `self.table_name`, i.e. the table used by later queries.
            stream (iterable): An iterable of dicts, each with a 'vector' key.
        """
        if table_name != self.table_name:
            # Keep table_exists / table_len / query pointed at the table being written,
            # and drop the row count that was cached for the previous table.
            self.table_name = table_name
            self.rownums = None

        table_ready = self.table_exists
        written = []  # cache keys added by this call, undone if the insert fails
        try:
            with self.db:
                for item in stream:
                    md5_hash, contents, vector = self.parse_item(item)

                    # Edge case: if the table does not exist, create it
                    if not table_ready:
                        self.db.execute(f"CREATE VIRTUAL TABLE {table_name} USING vec0(embedding float[{len(vector)}])")
                        self.rownums = 0
                        table_ready = True

                    # If we have already inserted this item, skip it and carry on
                    if md5_hash in self.cache:
                        continue

                    # Insert the item into the table
                    i = self.table_len + 1
                    self.db.execute(
                        f"INSERT INTO {table_name}(rowid, embedding) VALUES (?, ?)",
                        [i, serialize_f32(vector)],
                    )
                    written += [i, md5_hash]
                    self.cache[i] = contents
                    self.cache[md5_hash] = i
                    self.rownums += 1
        except BaseException:
            # The connection context manager rolls the inserts back; do the same for
            # the cache so it does not reference rows that were never committed.
            for key in written:
                self.cache.pop(key, None)
            self.rownums = None
            raise

    def query_idx(self, query, k=5):
        """
        Queries the table `self.table_name` for the nearest neighbors of the query vector.

        Args:
            query (list): The query vector.
            k (int, optional): Maximum number of neighbors to return. Defaults to 5.

        Returns:
            list: A two-element list ``[rowids, distances]`` of equal-length tuples,
                ordered by distance. Both tuples are empty when nothing matches.
        """
        results = self.db.execute(
            f"""
              SELECT
                rowid,
                distance
              FROM {self.table_name}
              WHERE embedding MATCH ?
              ORDER BY distance
              LIMIT {int(k)}
            """,
            [serialize_f32(query)],
        ).fetchall()
        if not results:
            # zip(*[]) would yield nothing, which callers cannot unpack into two values.
            return [(), ()]
        return list(zip(*results))

    def query(self, query, k=5):
        """
        Queries the table `self.table_name` and returns the stored contents of the
        nearest neighbors of the query vector.

        Args:
            query (list): The query vector.
            k (int, optional): Maximum number of neighbors to return. Defaults to 5.

        Returns:
            tuple: ``(items, distances)`` where `items` is a list of the decoded
                contents of each neighbor and `distances` is the matching tuple of
                distances.
        """
        idxs, dists = self.query_idx(query, k)
        return [json.loads(self.cache[i].decode()) for i in idxs], dists

In [6]:
# VectoLite appends ".sqlite" to the path it is given, so this opens the database
# file "simsity.sqlite.sqlite" and creates the diskcache cache at "simsity.sqlite".
db = VectoLite("simsity.sqlite")
db.print_version()

sqlite_version=3.46.1, vec_version=v0.1.1


In [7]:
# Five toy 4-dimensional vectors as (label, vector) pairs; the label is not stored.
items = [
    (1, [0.1, 0.1, 0.1, 0.1]),
    (2, [0.2, 0.2, 0.2, 0.2]),
    (3, [0.3, 0.3, 0.3, 0.3]),
    (4, [0.4, 0.4, 0.4, 0.4]),
    (5, [0.5, 0.5, 0.5, 0.5]),
]
query = [0.3, 0.3, 0.3, 0.3]

# Each record carries only the 0-based enumerate index "i" next to its vector, which
# is why the query results are reported as {'i': ...} rather than by the labels above.
db.insert('myvecs', [{'vector': item[1], "i": i} for i, item in enumerate(items)])
print(db.query(query, k=3))

([{'i': 2}, {'i': 3}, {'i': 1}], (0.0, 0.19999998807907104, 0.20000001788139343))


In [8]:
#| hide
# Runs nbdev's export step, which writes the notebook's exported cells to the Python module.
import nbdev; nbdev.nbdev_export()