Managing knowledge base columns #9005

Merged · 4 commits · Apr 12, 2024
Changes from 1 commit
12 changes: 12 additions & 0 deletions docs/mindsdb_sql/agents/knowledge-bases.mdx
@@ -27,6 +27,18 @@ CREATE KNOWLEDGE_BASE my_kb
[storage = vector_database.storage_table;]]
```

### Managing input columns

A knowledge base accepts optional column parameters that specify which input columns hold the id, content, and metadata:
```sql
CREATE KNOWLEDGE_BASE my_kb
USING
metadata_columns = ['date', 'creator'], -- optional; if not set, there are no metadata columns
content_columns = ['review'],           -- optional; if not set, all columns are content
id_column = 'index';                    -- optional; default: 'id'
```
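
For intuition, here is how a single source row maps into the knowledge base under the settings above (a hypothetical sketch with made-up values, not output produced by MindsDB):

```python
# Source row, with columns: index, review, date, creator
row = {'index': 7, 'review': 'Great product', 'date': '2024-04-01', 'creator': 'alice'}

# Resulting knowledge base record:
record = {
    'id': 7,                                                 # from id_column = 'index'
    'content': 'Great product',                              # from content_columns = ['review']
    'metadata': {'date': '2024-04-01', 'creator': 'alice'},  # from metadata_columns
}
```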


## Step-by-step guide


133 changes: 120 additions & 13 deletions mindsdb/interfaces/knowledge_base/controller.py
@@ -122,6 +122,8 @@ def insert(self, df: pd.DataFrame):
        if df.empty:
            return

        df = self._adapt_column_names(df)

        # add embeddings
        df_emb = self._df_to_embeddings(df)
        df = pd.concat([df, df_emb], axis=1)
@@ -130,6 +132,111 @@ def insert(self, df: pd.DataFrame):
        db_handler = self._get_vector_db()
        db_handler.do_upsert(self._kb.vector_database_table, df)

    def _adapt_column_names(self, df: pd.DataFrame) -> pd.DataFrame:

        '''
        Convert input columns to the layout expected by the vector db:
        id, content and metadata.
        '''

        params = self._kb.params

        columns = list(df.columns)

        # -- prepare id --

        # if id_column is defined:
        #     use it as id
        # elif 'id' column exists:
        #     use it
        # else:
        #     use hash(content) -- it happens inside of the vector handler

        id_column = params.get('id_column')
        if id_column is not None and id_column not in columns:
            # the configured name is not in the dataframe; ignore it
            id_column = None

        if id_column is None and TableField.ID.value in columns:
            # default value
            id_column = TableField.ID.value

        if id_column is not None:
            # remove from lookup list
            columns.remove(id_column)
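
        # Example (illustrative): with params = {'id_column': 'index'} and a
        # dataframe that has an 'index' column, that column supplies the record ids.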

        # -- prepare content and metadata --

        # if content_columns is defined:
        #     if len(content_columns) > 1:
        #         make text from row (col: value\n col: value)
        #     if metadata_columns is defined:
        #         use them as metadata
        #     else:
        #         use all unused columns as metadata
        # elif metadata_columns is defined:
        #     metadata_columns go to metadata
        #     use all unused columns as content (make text if columns > 1)
        # else:
        #     no metadata
        #     all unused columns go to content (make text if columns > 1)
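
        # Example (illustrative): with dataframe columns ['review', 'date', 'creator']
        # and only content_columns=['review'] set, the remaining ['date', 'creator']
        # become metadata.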

        content_columns = params.get('content_columns')
        metadata_columns = params.get('metadata_columns')

        if content_columns is not None:
            content_columns = list(set(content_columns).intersection(columns))
            if len(content_columns) == 0:
                raise ValueError(f'Content columns {params.get("content_columns")} not found in dataset: {columns}')

            if metadata_columns is not None:
                metadata_columns = list(set(metadata_columns).intersection(columns))
            else:
                # all remaining columns
                metadata_columns = list(set(columns).difference(content_columns))

        elif metadata_columns is not None:
            metadata_columns = list(set(metadata_columns).intersection(columns))
            # use all unused columns as content
            content_columns = list(set(columns).difference(metadata_columns))
        else:
            # all columns go to content
            content_columns = columns

        if not content_columns:
            raise ValueError("Can't find content columns")

        def row_to_document(row: pd.Series) -> str:
            """
            Convert a row in the input dataframe into a document.

            Default implementation is to concatenate all the columns
            in the form of
            field1: value1\nfield2: value2\n...
            """
            fields = row.index.tolist()
            values = row.values.tolist()
            document = "\n".join(
                [f"{field}: {value}" for field, value in zip(fields, values)]
            )
            return document
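
        # For example (illustrative):
        #   row_to_document(pd.Series({'review': 'great', 'date': '2024-04-01'}))
        #   -> "review: great\ndate: 2024-04-01"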

        # create dataframe
        if len(content_columns) == 1:
            c_content = df[content_columns[0]]
        else:
            c_content = df[content_columns].apply(row_to_document, axis=1)
        c_content.name = TableField.CONTENT.value
        df_out = pd.DataFrame(c_content)

        if id_column is not None:
            # take the id values from the source dataframe
            df_out[TableField.ID.value] = df[id_column]

        if metadata_columns:
            # serialize the metadata columns into a dict-like string per row
            df_out[TableField.METADATA.value] = df[metadata_columns].apply(lambda row: str(dict(row)), axis=1)

        return df_out

    def _replace_query_content(self, node, **kwargs):
        if isinstance(node, BinaryOperation):
            if isinstance(node.args[0], Identifier) and isinstance(node.args[1], Constant):
@@ -157,6 +264,9 @@ def _df_to_embeddings(self, df: pd.DataFrame) -> pd.DataFrame:
        :return: dataframe with embeddings
        """

        if df.empty:
            return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])

        model_id = self._kb.embedding_model_id
        # get the input columns
        model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
@@ -171,21 +281,18 @@ def _df_to_embeddings(self, df: pd.DataFrame) -> pd.DataFrame:
        if input_col is not None and input_col != TableField.CONTENT.value:
            df = df.rename(columns={TableField.CONTENT.value: input_col})

        data = df.to_dict('records')

        df_out = project_datanode.predict(
            model_name=model_rec.name,
            data=data,
        )

        target = model_rec.to_predict[0]
        if target != TableField.EMBEDDINGS.value:
            # adapt output for vectordb
            df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
        df_out = df_out[[TableField.EMBEDDINGS.value]]
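
        # At this point df_out holds a single 'embeddings' column, aligned
        # row-by-row with the input dataframe.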

        return df_out
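
Taken together, the precedence rules in `_adapt_column_names` can be exercised end to end. The following is a minimal, self-contained sketch that mirrors that logic with plain pandas (function name, column names, and data are illustrative; error handling and the `TableField` constants are omitted):

```python
import pandas as pd


def adapt_column_names(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """Simplified, illustrative mirror of _adapt_column_names."""
    columns = list(df.columns)

    # id: the configured id_column if present, else an existing 'id' column, else none
    id_column = params.get('id_column')
    if id_column not in columns:
        id_column = 'id' if 'id' in columns else None
    if id_column is not None:
        columns.remove(id_column)

    content_columns = params.get('content_columns')
    metadata_columns = params.get('metadata_columns')
    if content_columns is not None:
        content_columns = [c for c in content_columns if c in columns]
        if metadata_columns is not None:
            metadata_columns = [c for c in metadata_columns if c in columns]
        else:
            # unused columns become metadata
            metadata_columns = [c for c in columns if c not in content_columns]
    elif metadata_columns is not None:
        metadata_columns = [c for c in metadata_columns if c in columns]
        # unused columns become content
        content_columns = [c for c in columns if c not in metadata_columns]
    else:
        # everything is content, no metadata
        content_columns, metadata_columns = columns, []

    # a single content column is taken as-is; several are joined as "field: value" lines
    if len(content_columns) == 1:
        content = df[content_columns[0]]
    else:
        content = df[content_columns].apply(
            lambda row: "\n".join(f"{k}: {v}" for k, v in row.items()), axis=1
        )
    out = pd.DataFrame({'content': content})
    if id_column is not None:
        out['id'] = df[id_column]
    if metadata_columns:
        out['metadata'] = df[metadata_columns].apply(lambda row: str(dict(row)), axis=1)
    return out


df = pd.DataFrame([
    {'index': 1, 'review': 'Great product', 'date': '2024-04-01', 'creator': 'alice'},
    {'index': 2, 'review': 'Arrived late', 'date': '2024-04-03', 'creator': 'bob'},
])
print(adapt_column_names(df, {'id_column': 'index', 'content_columns': ['review']}))
# content comes from 'review'; 'date' and 'creator' fold into metadata
```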
