Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FewShotExamples managed value (redo) #444

Merged
merged 15 commits into from
May 14, 2024
Merged

Conversation

andrewnguonly
Copy link
Contributor

@andrewnguonly andrewnguonly commented May 14, 2024

Summary

This PR is a reimplementation of the old "Add FewShotExamples managed value" PR. A new PR is created because there are merge conflicts with the old PR. The changes in this PR are mostly derived from the old PR with a few noted exceptions.

Implementation

  1. Add score field to CheckpointMetadata. Clients can specify a score for a checkpoint to mark a thread as "good". This is different from the old PR.
  2. Add search() and asearch() APIs to BaseCheckpointSaver for searching checkpoints by metadata without filtering on thread_id. Implement the APIs in MemorySaver, SqliteSaver, and AsyncSqliteSaver. The implementation is mostly copied from the list() method. The old PR implemented an API called list_w_score().
  3. Copy FewShotExamples class from the old PR.
  4. Copy the tests from the old PR.
  5. Copy learning.ipynb from the old PR.

To Do

loop = asyncio.get_running_loop()
iter = await loop.run_in_executor(None, self.search, metadata_query)

def next_item(iter: Iterator[CheckpointTuple]) -> CheckpointTuple:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally tried to copy the implementation from alist(), but it was not working as expected since StopIteration is raised in the subprocess, not the main thread.

I'm not sure if alist() actually works as expected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

signature seems wrong? also this function seems equivalent to next(iter, None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll adopt the implementation from this PR: #447

with self.cursor(transaction=False) as cur:
cur.execute(
query,
(() if before is None else (before["configurable"]["thread_ts"],)),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to format this nicer.


where = "WHERE "
for query_key, query_value in metadata_query.items():
where += f"json_extract(CAST(metadata AS TEXT), '$.{query_key}') {_where_value(query_value)} AND "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily performant to call json_extract() for each metadata query key, but it's the simplest way to do it without creating a subtable, etc.

query = (
f"SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints {search_where(metadata_query)}ORDER BY thread_ts DESC"
if before is None
else f"SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints {search_where(metadata_query)}AND thread_ts < ? ORDER BY thread_ts DESC"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug (search_where() can return empty string). Need to fix.

@@ -382,3 +437,46 @@ def put(
"thread_ts": checkpoint["ts"],
}
}


def search_where(metadata_query: CheckpointMetadata, predicates: List[str] = []) -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update implementation so that query values are parameterized instead of hardcoded into query.

@nfcampos nfcampos merged commit a694aaa into main May 14, 2024
12 checks passed
@nfcampos nfcampos deleted the an/10may/few-shot-clone branch May 14, 2024 22:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants