Managing knowledge base columns #9005

Merged 4 commits on Apr 12, 2024
12 changes: 12 additions & 0 deletions docs/mindsdb_sql/agents/knowledge-bases.mdx
@@ -27,6 +27,18 @@ CREATE KNOWLEDGE_BASE my_kb
[storage = vector_database.storage_table;]]
```

### Managing input columns

A knowledge base accepts optional column parameters that define which input columns hold the id, content, and metadata:
```sql
CREATE KNOWLEDGE_BASE my_kb
USING
    metadata_columns = ['date', 'creator'],  -- optional; if not set, there are no metadata columns
    content_columns = ['review'],            -- optional; if not set, all columns are content
    id_column = 'index';                     -- optional; defaults to 'id'
```
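
To make the mapping concrete, here is a minimal sketch in plain pandas (hypothetical data; an illustration of the behavior, not MindsDB API) of how one input row is split under the parameters above:

```python
import pandas as pd

# One hypothetical input row for the example above
row = pd.Series({
    'index': 42,
    'review': 'Great product, fast shipping.',
    'date': '2024-04-01',
    'creator': 'alice',
})

doc_id = row['index']                                       # id_column = 'index'
content = row['review']                                     # content_columns = ['review']
metadata = {key: row[key] for key in ('date', 'creator')}   # metadata_columns

print(doc_id, content, metadata)
```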


## Step-by-step guide


133 changes: 120 additions & 13 deletions mindsdb/interfaces/knowledge_base/controller.py
@@ -122,6 +122,8 @@ def insert(self, df: pd.DataFrame):
if df.empty:
return

df = self._adapt_column_names(df)

# add embeddings
df_emb = self._df_to_embeddings(df)
df = pd.concat([df, df_emb], axis=1)
@@ -130,6 +132,111 @@ def insert(self, df: pd.DataFrame):
db_handler = self._get_vector_db()
db_handler.do_upsert(self._kb.vector_database_table, df)

    def _adapt_column_names(self, df: pd.DataFrame) -> pd.DataFrame:
        '''
        Convert the input columns into the vector-db input format:
        id, content and metadata.
        '''

params = self._kb.params

columns = list(df.columns)

# -- prepare id --

# if id_column is defined:
# use it as id
# elif 'id' column exists:
# use it
# else:
        #     use hash(content) -- this happens inside the vector handler

id_column = params.get('id_column')
if id_column is not None and id_column not in columns:
            # column does not exist in the input; ignore it
id_column = None

if id_column is None and TableField.ID.value in columns:
            # fall back to the default 'id' column
id_column = TableField.ID.value

if id_column is not None:
# remove from lookup list
columns.remove(id_column)

# -- prepare content and metadata --

# if content_columns is defined:
# if len(content_columns) > 1:
# make text from row (col: value\n col: value)
# if metadata_columns is defined:
# use them as metadata
# else:
        #         use all unused columns as metadata
# elif metadata_columns is defined:
# metadata_columns go to metadata
# use all unused columns as content (make text if columns>1)
# else:
# no metadata
# all unused columns go to content (make text if columns>1)

content_columns = params.get('content_columns')
metadata_columns = params.get('metadata_columns')

if content_columns is not None:
content_columns = list(set(content_columns).intersection(columns))
if len(content_columns) == 0:
raise ValueError(f'Content columns {params.get("content_columns")} not found in dataset: {columns}')

if metadata_columns is not None:
metadata_columns = list(set(metadata_columns).intersection(columns))
else:
                # all remaining columns become metadata
metadata_columns = list(set(columns).difference(content_columns))

elif metadata_columns is not None:
metadata_columns = list(set(metadata_columns).intersection(columns))
            # use all unused columns as content
content_columns = list(set(columns).difference(metadata_columns))
else:
# all columns go to content
content_columns = columns
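
        # e.g. with hypothetical columns ['ticket', 'value', 'created_at']:
        # content_columns=['ticket'] leaves ['value', 'created_at'] as metadata;
        # with no parameters at all, every column is folded into content.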

if not content_columns:
raise ValueError("Can't find content columns")

def row_to_document(row: pd.Series) -> str:
"""
Convert a row in the input dataframe into a document

Default implementation is to concatenate all the columns
in the form of
field1: value1\nfield2: value2\n...
"""
fields = row.index.tolist()
values = row.values.tolist()
document = "\n".join(
[f"{field}: {value}" for field, value in zip(fields, values)]
)
return document
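
        # Example with hypothetical data: a row with ticket='MSFT' and value=311
        # becomes the document "ticket: MSFT\nvalue: 311".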

# create dataframe
if len(content_columns) == 1:
c_content = df[content_columns[0]]
else:
c_content = df[content_columns].apply(row_to_document, axis=1)
c_content.name = TableField.CONTENT.value
df_out = pd.DataFrame(c_content)

if id_column is not None:
df_out[TableField.ID.value] = df[id_column]

        if metadata_columns:
df_out[TableField.METADATA.value] = df[metadata_columns].apply(lambda row: str(dict(row)), axis=1)

return df_out
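
To see the mapping end to end, here is a minimal, self-contained sketch (plain pandas, toy data; it assumes TableField.ID/CONTENT/METADATA resolve to 'id'/'content'/'metadata') of what `_adapt_column_names` produces when `content_columns = ['ticket', 'value']` is set and the remaining columns fall through to metadata:

```python
import pandas as pd

df = pd.DataFrame([
    {'id': 1, 'ticket': 'NFLX', 'value': 532, 'created_at': '2020-01-01'},
    {'id': 2, 'ticket': 'MSFT', 'value': 311, 'created_at': '2020-01-02'},
])

# content: multiple content columns are folded into one "field: value" document
content = df[['ticket', 'value']].apply(
    lambda row: '\n'.join(f'{field}: {value}' for field, value in row.items()),
    axis=1,
)
out = pd.DataFrame({'content': content})

# id: taken from the 'id' column because it exists and no id_column was given
out['id'] = df['id']

# metadata: the unused column, serialized the same way the controller does it
out['metadata'] = df[['created_at']].apply(lambda row: str(dict(row)), axis=1)

print(out.loc[1, 'content'])   # -> "ticket: MSFT\nvalue: 311"
```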

def _replace_query_content(self, node, **kwargs):
if isinstance(node, BinaryOperation):
if isinstance(node.args[0], Identifier) and isinstance(node.args[1], Constant):
@@ -157,6 +264,9 @@ def _df_to_embeddings(self, df: pd.DataFrame) -> pd.DataFrame:
:return: dataframe with embeddings
"""

if df.empty:
return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])

model_id = self._kb.embedding_model_id
# get the input columns
model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
@@ -171,21 +281,18 @@ def _df_to_embeddings(self, df: pd.DataFrame) -> pd.DataFrame:
if input_col is not None and input_col != TableField.CONTENT.value:
df = df.rename(columns={TableField.CONTENT.value: input_col})

        # content may have been renamed to the model's input column above
        content_col = input_col if input_col is not None else TableField.CONTENT.value
        data = df[[content_col]].to_dict('records')

        df_out = project_datanode.predict(
            model_name=model_rec.name,
            data=data,
        )

        target = model_rec.to_predict[0]
        if target != TableField.EMBEDDINGS.value:
            # adapt output for vectordb
            df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
        df_out = df_out[[TableField.EMBEDDINGS.value]]

return df_out
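
The tail of this method is what keeps the vector-db input stable: whatever the embedding model names its target column, the output is renamed to the expected embeddings column. A standalone sketch of just that adaptation (column names here are made up; 'embeddings' is the assumed value of TableField.EMBEDDINGS.value):

```python
import pandas as pd

EMBEDDINGS_COL = 'embeddings'  # assumed value of TableField.EMBEDDINGS.value

# Hypothetical model output: the target column happens to be 'embedding_vec'
df_out = pd.DataFrame({'embedding_vec': [[0.1, 0.2], [0.3, 0.4]],
                       'some_debug_col': ['a', 'b']})

target = 'embedding_vec'  # in the controller this is model_rec.to_predict[0]
if target != EMBEDDINGS_COL:
    df_out = df_out.rename(columns={target: EMBEDDINGS_COL})
df_out = df_out[[EMBEDDINGS_COL]]  # drop everything the vector db doesn't need
```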

69 changes: 69 additions & 0 deletions tests/unit/test_knowledge_base.py
@@ -378,3 +378,72 @@ def test_show_knowledge_bases(self):
"""
df = self.run_sql(sql)
assert df.shape[0] == 1

def test_kb_params(self):

df = pd.DataFrame([
{'id': 1, 'ticket': 'NFLX', 'value': 532, 'created_at': '2020-01-01', 'ma': 100},
{'id': 2, 'ticket': 'MSFT', 'value': 311, 'created_at': '2020-01-02', 'ma': 200},
])

self.save_file('stock', df)

# ---- default ----
self.run_sql('create knowledge base kb_test')
self.run_sql('INSERT INTO kb_test select * from files.stock')
ret = self.run_sql("select * from kb_test where content='msft'")
        self.run_sql('drop knowledge base kb_test')  # have to drop the KB, with its model and vector store, before assertions

# second row is the result, all columns in content
content = ret.content[0]
assert 'MSFT' in content and 'created_at' in content and '311' in content and '200' in content

# metadata is empty
assert ret.metadata[0] is None

# id = 2
assert ret.id[0] == '2'

# ---- choose content ----
self.run_sql('''
create knowledge base kb_test
using content_columns = ['ticket', 'value']
''')
self.run_sql('INSERT INTO kb_test select * from files.stock')
ret = self.run_sql("select * from kb_test where content='msft'")
self.run_sql('drop knowledge base kb_test')

metadata = ret.metadata[0]
content = ret.content[0]
# ticket and value in content
assert 'MSFT' in content and '311' in content
        # created_at and ma in metadata
assert 'created_at' in metadata and 'ma' in metadata

# ---- choose metadata ----
self.run_sql('''
create knowledge base kb_test
using metadata_columns = ['created_at', 'value']
''')
self.run_sql('INSERT INTO kb_test select * from files.stock')
ret = self.run_sql("select * from kb_test where content='msft'")
self.run_sql('drop knowledge base kb_test')

metadata = ret.metadata[0]
content = ret.content[0]
# ticket and ma in content
assert 'MSFT' in content and '200' in content
        # created_at and value in metadata
assert 'created_at' in metadata and 'value' in metadata

# ---- choose id ----
self.run_sql('''
create knowledge base kb_test
using id_column='ma'
''')
self.run_sql('INSERT INTO kb_test select * from files.stock')
ret = self.run_sql("select * from kb_test where content='msft'")
self.run_sql('drop knowledge base kb_test')

# id = 200
assert ret.id[0] == '200'
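
A closing note on why these tests use substring checks: metadata comes back as the string form of a Python dict (`str(dict(row))` in the controller), not as a parsed object. If a caller wanted the dict back, `ast.literal_eval` would recover it — a hedged sketch, not part of the test suite:

```python
import ast

# Hypothetical metadata string, shaped like what str(dict(row)) produces
raw = "{'created_at': '2020-01-02', 'value': 311}"
meta = ast.literal_eval(raw)  # safely parses Python literals
assert meta['value'] == 311
```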