Skip to content

Commit

Permalink
chore: resolve circular dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
joennlae committed Sep 18, 2023
1 parent 009e953 commit d3cb359
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 39 deletions.
6 changes: 4 additions & 2 deletions src/python/halutmatmul/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import numba
from numba import prange

from halutmatmul.maddness import MultiSplit
from halutmatmul.halutmatmul import HalutMatmul
from halutmatmul.maddness_multisplit import MultiSplit

if TYPE_CHECKING: # otherwise circular dependency
from halutmatmul.halutmatmul import HalutMatmul


@numba.jit(parallel=True, nopython=True)
Expand Down
12 changes: 6 additions & 6 deletions src/python/halutmatmul/halutmatmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
from __future__ import annotations
from functools import reduce
from typing import Any, Dict, Optional
from sklearn.cluster import KMeans
import faiss

import numpy as np

import numba

from halutmatmul.maddness import (
learn_proto_and_hash_function,
maddness_lut,
maddness_quantize_luts,
)

from halutmatmul.functions import (
get_str_hash_buckets,
halut_encode_opt,
read_luts_opt,
read_luts_quantized_opt,
)
from halutmatmul.maddness import (
learn_proto_and_hash_function,
maddness_lut,
maddness_quantize_luts,
)


class HalutOfflineStorage:
Expand Down
32 changes: 1 addition & 31 deletions src/python/halutmatmul/maddness.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numba
from sklearn import linear_model

from halutmatmul.maddness_multisplit import MultiSplit
from halutmatmul.functions import halut_encode_opt, split_lists_to_numpy


Expand Down Expand Up @@ -367,37 +368,6 @@ def create_codebook_start_end_idxs(X, number_of_codebooks, algo="start"):
return idxs


# untouched from maddness except for comments etc.
class MultiSplit:
    """Record holding one split's parameters plus an input-rescaling helper.

    Instances carry exactly four attributes (``dim``, ``vals``, ``scaleby``
    and ``offset``), fixed through ``__slots__``.
    """

    __slots__ = ["dim", "vals", "scaleby", "offset"]

    def __init__(self, dim, vals, scaleby=None, offset=None):
        """Store the split parameters.

        ``vals`` is converted with ``np.asarray``; every other argument is
        kept as passed. Leaving ``scaleby`` or ``offset`` at ``None`` makes
        ``preprocess_x`` skip the corresponding step.
        """
        self.vals = np.asarray(vals)
        self.dim = dim
        self.offset = offset
        self.scaleby = scaleby

    def __repr__(self) -> str:
        """Return the parameter summary wrapped in angle brackets."""
        return "<" + self.get_params() + ">"

    def __str__(self) -> str:
        """Return the same text as ``get_params``."""
        return self.get_params()

    def get_params(self) -> str:
        """Return a one-line, human-readable summary of the four attributes."""
        described = ", ".join(
            f"{name}({getattr(self, name)})"
            for name in ("dim", "vals", "scaleby", "offset")
        )
        return f"Multisplit: {described}"

    def preprocess_x(self, x: np.ndarray) -> np.ndarray:
        """Subtract ``offset`` and then multiply by ``scaleby``, when set.

        Either step is skipped if its attribute is ``None``; when both are
        ``None`` the input object itself is returned unchanged.
        """
        shifted = x if self.offset is None else x - self.offset
        if self.scaleby is None:
            return shifted
        return shifted * self.scaleby


def learn_binary_tree_splits(
X: np.ndarray,
K: int = 16,
Expand Down
35 changes: 35 additions & 0 deletions src/python/halutmatmul/maddness_multisplit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# extracted from https://github.com/dblalock/bolt
# SPDX-License-Identifier: MPL-2.0 (as before)
# this file is only needed to resolve a circular dependency

import numpy as np


class MultiSplit:
    """Record holding one split's parameters plus an input-rescaling helper.

    Instances carry exactly four attributes (``dim``, ``vals``, ``scaleby``
    and ``offset``), fixed through ``__slots__``.
    """

    __slots__ = ["dim", "vals", "scaleby", "offset"]

    def __init__(self, dim, vals, scaleby=None, offset=None):
        """Store the split parameters.

        ``vals`` is converted with ``np.asarray``; every other argument is
        kept as passed. Leaving ``scaleby`` or ``offset`` at ``None`` makes
        ``preprocess_x`` skip the corresponding step.
        """
        self.vals = np.asarray(vals)
        self.dim = dim
        self.offset = offset
        self.scaleby = scaleby

    def __repr__(self) -> str:
        """Return the parameter summary wrapped in angle brackets."""
        return "<" + self.get_params() + ">"

    def __str__(self) -> str:
        """Return the same text as ``get_params``."""
        return self.get_params()

    def get_params(self) -> str:
        """Return a one-line, human-readable summary of the four attributes."""
        described = ", ".join(
            f"{name}({getattr(self, name)})"
            for name in ("dim", "vals", "scaleby", "offset")
        )
        return f"Multisplit: {described}"

    def preprocess_x(self, x: np.ndarray) -> np.ndarray:
        """Subtract ``offset`` and then multiply by ``scaleby``, when set.

        Either step is skipped if its attribute is ``None``; when both are
        ``None`` the input object itself is returned unchanged.
        """
        shifted = x if self.offset is None else x - self.offset
        if self.scaleby is None:
            return shifted
        return shifted * self.scaleby

0 comments on commit d3cb359

Please sign in to comment.