Skip to content

Commit

Permalink
Feat/indexing faissless (#173)
Browse files Browse the repository at this point in the history
* fix: fix searcher always being reloaded

* feat: implement torch kmeans

* chore: lower cutoff

* chore: move warning

* chore: higher kmeans batch size

* chore: argument support

* chore: restore all default behaviour when `use_faiss` is True after having been false

* chore: lint

* chore: print exception if one occurs when using pytorch indexing

* chore: make _original_train_kmeans robust to subsequent calls

* nit: comment

feat: rework kmeans to be closer to FAISS

chore: store kmeans functions as class attributes

fix: method assignment

chore: more memory efficient

lint

chore: lower bsize, results unaffected

feat: better batching, slower max doc count

chore: batch size safe for 8gb GPUs

chore: more elaborate warning

chore: use external lib to support minibatching, revert to homebrew later

* poetry lock

* lint

* chore: better batch size

* 0.0.8 dependency prep
  • Loading branch information
bclavie committed Mar 18, 2024
1 parent f8c53cb commit d27b693
Show file tree
Hide file tree
Showing 8 changed files with 1,954 additions and 1,163 deletions.
13 changes: 4 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,14 @@ jobs:
with:
python-version: 3.9

- name: Cache Poetry virtualenv
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: ${{ runner.os }}-poetry-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry-
- name: Install Poetry
uses: snok/install-poetry@v1.3.1

- name: Clean poetry
run: rm poetry.lock

- name: Install dependencies
run: poetry install --with dev
run: poetry install --with dev --no-cache

- name: Run tests
run: poetry run pytest tests/
2,661 changes: 1,540 additions & 1,121 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
version = "0.0.7post11"
version = "0.0.8"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <ben@clavie.eu>"]
license = "Apache-2.0"
Expand All @@ -9,20 +9,19 @@ packages = [{include = "ragatouille"}]
repository = "https://github.com/bclavie/ragatouille"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
ruff = "^0.1.9"
python = ">=3.9,<4.0"
faiss-cpu = "^1.7.4"
transformers = "^4.36.2"
voyager = "^2.0.2"
aiohttp = "3.9.1"
sentence-transformers = "^2.2.2"
torch = "^2.0.1"
llama-index = "^0.9.24"
torch = ">=1.13"
llama-index = ">=0.7"
langchain_core = "^0.1.4"
colbert-ai = "0.2.19"
langchain = "^0.1.0"
onnx = "^1.15.0"
srsly = "2.4.8"
fast-pytorch-kmeans= "0.2.0.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
4 changes: 4 additions & 0 deletions ragatouille/RAGPretrainedModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def index(
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
bsize: int = 32,
use_faiss: bool = False,
):
"""Build an index from a list of documents.
Expand Down Expand Up @@ -215,6 +216,7 @@ def index(
max_document_length=max_document_length,
overwrite=overwrite_index,
bsize=bsize,
use_faiss=use_faiss,
)

def add_to_index(
Expand All @@ -227,6 +229,7 @@ def add_to_index(
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
bsize: int = 32,
use_faiss: bool = False,
):
"""Add documents to an existing index.
Expand Down Expand Up @@ -258,6 +261,7 @@ def add_to_index(
new_docid_metadata_map=new_docid_metadata_map,
index_name=index_name,
bsize=bsize,
use_faiss=use_faiss,
)

def delete_from_index(
Expand Down
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.7post11"
__version__ = "0.0.8"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

Expand Down
15 changes: 12 additions & 3 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
self.pid_docid_map = None
self.docid_pid_map = None
self.docid_metadata_map = None
self.base_model_max_tokens = 512
self.base_model_max_tokens = 510
if n_gpu == -1:
n_gpu = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()

Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
)
self.base_model_max_tokens = (
self.inference_ckpt.bert.config.max_position_embeddings
)
) - 4

self.run_context = Run().context(self.run_config)
self.run_context.__enter__() # Manually enter the context
Expand Down Expand Up @@ -125,6 +125,7 @@ def add_to_index(
new_docid_metadata_map: Optional[List[dict]] = None,
index_name: Optional[str] = None,
bsize: int = 32,
use_faiss: bool = False,
):
self.index_name = index_name if index_name is not None else self.index_name
if self.index_name is None:
Expand Down Expand Up @@ -181,6 +182,7 @@ def add_to_index(
new_collection,
verbose=self.verbose != 0,
bsize=bsize,
use_faiss=use_faiss,
)
self.config = self.model_index.config

Expand Down Expand Up @@ -294,6 +296,7 @@ def index(
max_document_length: int = 256,
overwrite: Union[bool, str] = "reuse",
bsize: int = 32,
use_faiss: bool = False,
):
self.collection = collection
self.config.doc_maxlen = max_document_length
Expand Down Expand Up @@ -341,6 +344,7 @@ def index(
overwrite,
verbose=self.verbose != 0,
bsize=bsize,
use_faiss=use_faiss,
)
self.config = self.model_index.config
self._save_index_metadata()
Expand Down Expand Up @@ -494,7 +498,11 @@ def _set_inference_max_tokens(
not hasattr(self, "inference_ckpt_len_set")
or self.inference_ckpt_len_set is False
):
if max_tokens == "auto" or max_tokens > self.base_model_max_tokens:
if max_tokens == "auto":
max_tokens = self.base_model_max_tokens
else:
max_tokens = int(max_tokens)
if max_tokens > self.base_model_max_tokens:
max_tokens = self.base_model_max_tokens
percentile_90 = np.percentile(
[len(x.split(" ")) for x in documents], 90
Expand All @@ -504,6 +512,7 @@ def _set_inference_max_tokens(
self.base_model_max_tokens,
)
max_tokens = max(256, max_tokens)

if max_tokens > 300:
print(
f"Your documents are roughly {percentile_90} tokens long at the 90th percentile!",
Expand Down
91 changes: 68 additions & 23 deletions ragatouille/models/index.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from time import time
from typing import Any, List, Literal, Optional, TypeVar, Union

import srsly
import torch
from colbert import Indexer, IndexUpdater, Searcher
from colbert.indexing.collection_indexer import CollectionIndexer
from colbert.infra import ColBERTConfig

from ragatouille.models import torch_kmeans

IndexType = Literal["FLAT", "HNSW", "PLAID"]


Expand Down Expand Up @@ -126,6 +130,8 @@ class HNSWModelIndex(ModelIndex):
class PLAIDModelIndex(ModelIndex):
_DEFAULT_INDEX_BSIZE = 32
index_type = "PLAID"
faiss_kmeans = staticmethod(deepcopy(CollectionIndexer._train_kmeans))
pytorch_kmeans = staticmethod(torch_kmeans._train_kmeans)

def __init__(self, config: ColBERTConfig) -> None:
super().__init__(config)
Expand Down Expand Up @@ -168,21 +174,6 @@ def build(
bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE)
assert isinstance(bsize, int)

if torch.cuda.is_available():
import faiss

if not hasattr(faiss, "StandardGpuResources"):
print(
"________________________________________________________________________________\n"
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
"This means that indexing will be slow. To make use of your GPU.\n"
"Please install `faiss-gpu` by running:\n"
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
"________________________________________________________________________________",
)
print("Will continue with CPU indexing in 5 seconds...")
time.sleep(5)

nbits = 2
if len(collection) < 5000:
nbits = 8
Expand All @@ -192,22 +183,76 @@ def build(
self.config, ColBERTConfig(nbits=nbits, index_bsize=bsize)
)

# Instruct colbert-ai to disable forking if nranks == 1
self.config.avoid_fork_if_possible = True

if len(collection) > 100000:
self.config.kmeans_niters = 4
elif len(collection) > 50000:
self.config.kmeans_niters = 10
else:
self.config.kmeans_niters = 20

# Instruct colbert-ai to disable forking if nranks == 1
self.config.avoid_fork_if_possible = True
indexer = Indexer(
checkpoint=checkpoint,
config=self.config,
verbose=verbose,
# Monkey-patch colbert-ai to avoid using FAISS
monkey_patching = (
len(collection) < 100000 and kwargs.get("use_faiss", False) is False
)
indexer.configure(avoid_fork_if_possible=True)
indexer.index(name=index_name, collection=collection, overwrite=overwrite)
if monkey_patching:
print(
"---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----"
)
print("This is a behaviour change from RAGatouille 0.8.0 onwards.")
print(
"This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations."
)
print(
"If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour."
)
print("--------------------")
CollectionIndexer._train_kmeans = self.pytorch_kmeans

# Try to keep runtime stable -- these are values that empirically didn't degrade performance at all on 3 benchmarks.
# More tests required before warning can be removed.
try:
indexer = Indexer(
checkpoint=checkpoint,
config=self.config,
verbose=verbose,
)
indexer.configure(avoid_fork_if_possible=True)
indexer.index(
name=index_name, collection=collection, overwrite=overwrite
)
except Exception as err:
print(
f"PyTorch-based indexing did not succeed with error: {err}",
"! Reverting to using FAISS and attempting again...",
)
monkey_patching = False
if monkey_patching is False:
CollectionIndexer._train_kmeans = self.faiss_kmeans
if torch.cuda.is_available():
import faiss

if not hasattr(faiss, "StandardGpuResources"):
print(
"________________________________________________________________________________\n"
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
"This means that indexing will be slow. To make use of your GPU.\n"
"Please install `faiss-gpu` by running:\n"
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
"________________________________________________________________________________",
)
print("Will continue with CPU indexing in 5 seconds...")
time.sleep(5)
indexer = Indexer(
checkpoint=checkpoint,
config=self.config,
verbose=verbose,
)
indexer.configure(avoid_fork_if_possible=True)
indexer.index(name=index_name, collection=collection, overwrite=overwrite)

return self

def _load_searcher(
Expand Down
Loading

0 comments on commit d27b693

Please sign in to comment.