/
datastore.py
85 lines (68 loc) · 2.42 KB
/
datastore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import faiss
import warnings
from utils.constants import *
class DataStore(object):
    """
    Builds a Datastore object, i.e., the set of all key-value pairs constructed from the training examples.

    One training example consists of a vector (of size `dimension`) and a label (`int`).
    The datastore has a maximal number of examples that it can store (`capacity`).

    Attributes
    ----------
    capacity: maximal number of examples kept by `build`; when it is not positive,
        `build` stores nothing
    strategy: name of the rule used to select examples when more than `capacity` are given;
        a name that is not in `ALL_STRATEGIES` is replaced by "random" (with a warning)
    dimension: size of the vectors held by `index`
    rng (numpy.random._generator.Generator): random generator used by the "random" strategy
    index: `faiss.IndexFlatL2` of size `dimension` holding the stored vectors
    labels: labels of the stored vectors, in the order the vectors were added to `index`;
        `None` while nothing is stored

    Methods
    -------
    __init__
    build
    clear
    """

    def __init__(self, capacity, strategy, dimension, rng):
        self.capacity = capacity
        # goes through the `strategy` setter, which validates the name
        self.strategy = strategy
        self.dimension = dimension
        self.rng = rng

        self.index = faiss.IndexFlatL2(self.dimension)
        self.labels = None

    @property
    def strategy(self):
        return self.__strategy

    @strategy.setter
    def strategy(self, strategy):
        if strategy in ALL_STRATEGIES:
            self.__strategy = strategy
        else:
            # unknown names are not rejected: fall back to "random" and warn the caller
            warnings.warn("strategy is set to random!", RuntimeWarning)
            self.__strategy = "random"

    def build(self, train_vectors, train_labels):
        """
        (re)builds the datastore: adds vectors to `index` according to `strategy`

        Whatever a previous call stored is dropped first, so that `index` and `labels`
        always describe the same examples and never hold more than `capacity` of them.
        When at most `capacity` examples are given, all of them are stored (and
        `train_labels` is kept as is, without copy), whatever the strategy.

        :param train_vectors:
        :type train_vectors: numpy.array of shape (n_samples, dimension)
        :param train_labels:
        :type train_labels: numpy.array of shape (n_samples,)
        :raises NotImplementedError: if more than `capacity` examples are given and
            `strategy` is not "random"
        """
        # Start from an empty datastore: adding on top of earlier content would leave
        # `index` with more rows than `labels` (which is overwritten below).
        self.index.reset()
        self.labels = None

        if self.capacity <= 0:
            return

        n_train_samples = len(train_vectors)

        if n_train_samples <= self.capacity:
            # everything fits, no selection needed
            self.index.add(train_vectors)
            self.labels = train_labels
            return

        if self.strategy != "random":
            raise NotImplementedError(f"{self.strategy} is not implemented")

        # keep `capacity` distinct examples drawn uniformly at random
        selected_indices = self.rng.choice(n_train_samples, size=self.capacity, replace=False)
        self.index.add(train_vectors[selected_indices])
        self.labels = train_labels[selected_indices]

    def clear(self):
        """
        empties the datastore by reinitializing `index` and clearing `labels`
        """
        self.index = faiss.IndexFlatL2(self.dimension)
        self.labels = None