Skip to content

Commit

Permalink
misc: refactored VectorStore trait
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jun 26, 2024
1 parent 787bfe4 commit 054b34e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 29 deletions.
14 changes: 13 additions & 1 deletion src/agent/rag/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use async_trait::async_trait;
use naive::NaiveVectorStore;
use serde::{Deserialize, Serialize};

use super::generator::Client;
Expand All @@ -23,10 +24,21 @@ pub struct Document {
#[async_trait]
pub trait VectorStore: Send {
#[allow(clippy::borrowed_box)]
fn new_with_generator(generator: Box<dyn Client>) -> Result<Self>
async fn new(embedder: Box<dyn Client>, config: Configuration) -> Result<Self>
where
Self: Sized;

async fn add(&mut self, document: Document) -> Result<()>;
async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>>;
}

pub async fn factory(
flavor: &str,
embedder: Box<dyn Client>,
config: Configuration,
) -> Result<Box<dyn VectorStore>> {
match flavor {
"naive" => Ok(Box::new(NaiveVectorStore::new(embedder, config).await?)),
_ => Err(anyhow!("rag flavor '{flavor} not supported yet")),
}
}
46 changes: 22 additions & 24 deletions src/agent/rag/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,39 @@ use async_trait::async_trait;
use colored::Colorize;
use glob::glob;

use super::{Document, Embeddings, VectorStore};
use super::{Configuration, Document, Embeddings, VectorStore};
use crate::agent::{generator::Client, rag::metrics};

// TODO: integrate other more efficient vector databases.

pub struct NaiveVectorStore {
config: Configuration,
embedder: Box<dyn Client>,
documents: HashMap<String, Document>,
embeddings: HashMap<String, Embeddings>,
}

impl NaiveVectorStore {
// TODO: add persistency
pub async fn from_indexed_path(generator: Box<dyn Client>, path: &str) -> Result<Self> {
let path = std::fs::canonicalize(path)?.display().to_string();
#[async_trait]
impl VectorStore for NaiveVectorStore {
#[allow(clippy::borrowed_box)]
async fn new(embedder: Box<dyn Client>, config: Configuration) -> Result<Self>
where
Self: Sized,
{
// TODO: add persistency
let documents = HashMap::new();
let embeddings = HashMap::new();
let mut store = Self {
config,
documents,
embeddings,
embedder,
};

let path = std::fs::canonicalize(&store.config.path)?
.display()
.to_string();
let expr = format!("{}/**/*.txt", path);
let mut store = NaiveVectorStore::new_with_generator(generator)?;

for path in (glob(&expr)?).flatten() {
let doc_name = path.display();
Expand All @@ -39,24 +55,6 @@ impl NaiveVectorStore {

Ok(store)
}
}

#[async_trait]
impl VectorStore for NaiveVectorStore {
#[allow(clippy::borrowed_box)]
fn new_with_generator(embedder: Box<dyn Client>) -> Result<Self>
where
Self: Sized,
{
let documents = HashMap::new();
let embeddings = HashMap::new();

Ok(Self {
documents,
embeddings,
embedder,
})
}

async fn add(&mut self, document: Document) -> Result<()> {
if self.documents.contains_key(&document.name) {
Expand Down
7 changes: 3 additions & 4 deletions src/agent/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use metrics::Metrics;
use super::{
generator::{Client, Message},
namespaces::{self, Namespace},
rag::{naive::NaiveVectorStore, Document, VectorStore},
rag::{Document, VectorStore},
task::Task,
Invocation,
};
Expand Down Expand Up @@ -84,12 +84,11 @@ impl State {

// add RAG namespace
let rag: Option<Box<dyn VectorStore>> = if let Some(config) = task.get_rag_config() {
let v_store: NaiveVectorStore =
NaiveVectorStore::from_indexed_path(embedder, &config.path).await?;
let v_store = super::rag::factory("naive", embedder, config).await?;

namespaces.push(namespaces::NAMESPACES.get("rag").unwrap()());

Some(Box::new(v_store))
Some(v_store)
} else {
None
};
Expand Down

0 comments on commit 054b34e

Please sign in to comment.