feat(datasets): Allow multiple splits per example, add splits to experiment metadata on evaluation #709

Merged (4 commits) on May 20, 2024
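In short: a dataset example can now belong to several splits at once (split?: string | string[] on the JS side, split: Optional[str | List[str]] on the Python side), and evaluation now records the union of the splits it saw under dataset_splits in the experiment metadata. A minimal sketch of the new surface using the Python client — the dataset name is illustrative, and per the tests below a bare string split is stored as a one-element list:

from langsmith import Client

client = Client()
dataset = client.create_dataset("multi-split-demo")  # illustrative name

# Each entry in `splits` may now be a single split name or a list of names.
client.create_examples(
    inputs=[{"q": "a"}, {"q": "b"}, {"q": "c"}],
    outputs=[{"a": 1}, {"a": 2}, {"a": 3}],
    splits=["train", "test", ["train", "validation"]],
    dataset_id=dataset.id,
)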
2 changes: 1 addition & 1 deletion .github/workflows/integration_tests.yml
@@ -42,7 +42,7 @@ jobs:
- name: Install dependencies
run: |
poetry install --with dev
poetry run pip install -U langchain
poetry run pip install -U langchain langchain_anthropic langchain_openai rapidfuzz
- name: Run Python integration tests
uses: ./.github/actions/python-integration-tests
with:
4 changes: 2 additions & 2 deletions js/src/client.ts
@@ -239,7 +239,7 @@
exampleId?: string;

metadata?: KVMap;
split?: string;
split?: string | string[];
};

type AutoBatchQueueItem = {
@@ -2036,7 +2036,7 @@
inputs: Array<KVMap>;
outputs?: Array<KVMap>;
metadata?: Array<KVMap>;
splits?: Array<string>;
splits?: Array<string | Array<string>>;
sourceRunIds?: Array<string>;
exampleIds?: Array<string>;
datasetId?: string;
@@ -2602,7 +2602,7 @@
public async logEvaluationFeedback(
evaluatorResponse: EvaluationResult | EvaluationResults,
run?: Run,
sourceInfo?: { [key: string]: any }
): Promise<EvaluationResult[]> {
const results: Array<EvaluationResult> =
this._selectEvalResults(evaluatorResponse);
16 changes: 16 additions & 0 deletions js/src/evaluation/_runner.ts
@@ -62,8 +62,8 @@
runs?: AsyncGenerator<Run>;
evaluationResults?: AsyncGenerator<EvaluationResults>;
summaryResults?: AsyncGenerator<
(runsArray: Run[]) => AsyncGenerator<EvaluationResults, any, unknown>,
any,
unknown
>;
examples?: Example[];
@@ -142,8 +142,8 @@
_evaluationResults?: AsyncGenerator<EvaluationResults>;

_summaryResults?: AsyncGenerator<
(runsArray: Run[]) => AsyncGenerator<EvaluationResults, any, unknown>,
any,
unknown
>;

@@ -207,7 +207,7 @@
get evaluationResults(): AsyncGenerator<EvaluationResults> {
if (this._evaluationResults === undefined) {
return async function* (this: _ExperimentManager) {
for (const _ of await this.getExamples()) {
yield { results: [] };
}
}.call(this);
@@ -629,7 +629,7 @@
this.client._selectEvalResults(summaryEvalResult);
aggregateFeedback.push(...flattenedResults);
for (const result of flattenedResults) {
const { targetRunId, ...feedback } = result;
const evaluatorInfo = feedback.evaluatorInfo;
delete feedback.evaluatorInfo;

@@ -694,13 +694,29 @@
).date;
}

async _getDatasetSplits(): Promise<string[] | undefined> {
const examples = await this.getExamples();
const allSplits = examples.reduce((acc, ex) => {
if (ex.metadata && ex.metadata.dataset_split) {
if (Array.isArray(ex.metadata.dataset_split)) {
ex.metadata.dataset_split.forEach((split) => acc.add(split));
} else if (typeof ex.metadata.dataset_split === "string") {
acc.add(ex.metadata.dataset_split);
}
}
return acc;
}, new Set<string>());
return allSplits.size ? Array.from(allSplits) : undefined;
}

async _end(): Promise<void> {
const experiment = this._experiment;
if (!experiment) {
throw new Error("Experiment not yet started.");
}
const projectMetadata = await this._getExperimentMetadata();
projectMetadata["dataset_version"] = await this._getDatasetVersion();
projectMetadata["dataset_splits"] = await this._getDatasetSplits();
// Update revision_id if not already set
if (!projectMetadata["revision_id"]) {
projectMetadata["revision_id"] = await getDefaultRevisionId();
@@ -809,7 +825,7 @@
return results;
}

type ForwardFn = ((...args: any[]) => Promise<any>) | ((...args: any[]) => any);

async function _forward(
fn: ForwardFn,
4 changes: 2 additions & 2 deletions js/src/schemas.ts
@@ -229,7 +229,7 @@ export interface RunUpdate {
export interface ExampleCreate extends BaseExample {
id?: string;
created_at?: string;
split?: string;
split?: string | string[];
}

export interface Example extends BaseExample {
@@ -245,7 +245,7 @@ export interface ExampleUpdate {
inputs?: KVMap;
outputs?: KVMap;
metadata?: KVMap;
split?: string;
split?: string | string[];
}
export interface BaseDataset {
name: string;
23 changes: 17 additions & 6 deletions js/src/tests/client.int.test.ts
@@ -97,12 +97,22 @@ test.concurrent("Test LangSmith Client Dataset CRD", async () => {
await client.updateExample(example.id, {
inputs: { col1: "updatedExampleCol1" },
outputs: { col2: "updatedExampleCol2" },
split: "my_split2",
split: ["my_split2"],
});
// Says 'example updated' or something similar
const newExampleValue = await client.readExample(example.id);
expect(newExampleValue.inputs.col1).toBe("updatedExampleCol1");
expect(newExampleValue.metadata?.dataset_split).toBe("my_split2");
expect(newExampleValue.metadata?.dataset_split).toStrictEqual(["my_split2"]);

await client.updateExample(example.id, {
inputs: { col1: "updatedExampleCol3" },
outputs: { col2: "updatedExampleCol4" },
split: "my_split3",
});
// Says 'example updated' or something similar
const newExampleValue2 = await client.readExample(example.id);
expect(newExampleValue2.inputs.col1).toBe("updatedExampleCol3");
expect(newExampleValue2.metadata?.dataset_split).toStrictEqual(["my_split3"]);
await client.deleteExample(example.id);
const examples2 = await toArray(
client.listExamples({ datasetId: newDataset.id })
@@ -489,7 +499,7 @@ test.concurrent(
{ output: "hi there 3" },
],
metadata: [{ key: "value 1" }, { key: "value 2" }, { key: "value 3" }],
splits: ["train", "test", "train"],
splits: ["train", "test", ["train", "validation"]],
datasetId: dataset.id,
});
const initialExamplesList = await toArray(
@@ -520,19 +530,20 @@
);
expect(example1?.outputs?.output).toEqual("hi there 1");
expect(example1?.metadata?.key).toEqual("value 1");
expect(example1?.metadata?.dataset_split).toEqual("train");
expect(example1?.metadata?.dataset_split).toEqual(["train"]);
const example2 = examplesList2.find(
(e) => e.inputs.input === "hello world 2"
);
expect(example2?.outputs?.output).toEqual("hi there 2");
expect(example2?.metadata?.key).toEqual("value 2");
expect(example2?.metadata?.dataset_split).toEqual("test");
expect(example2?.metadata?.dataset_split).toEqual(["test"]);
const example3 = examplesList2.find(
(e) => e.inputs.input === "hello world 3"
);
expect(example3?.outputs?.output).toEqual("hi there 3");
expect(example3?.metadata?.key).toEqual("value 3");
expect(example3?.metadata?.dataset_split).toEqual("train");
expect(example3?.metadata?.dataset_split).toContain("train");
expect(example3?.metadata?.dataset_split).toContain("validation");

await client.createExample(
{ input: "hello world" },
85 changes: 84 additions & 1 deletion js/src/tests/evaluate.int.test.ts
@@ -1,6 +1,6 @@
import { EvaluationResult } from "../evaluation/evaluator.js";
import { evaluate } from "../evaluation/_runner.js";
import { Example, Run } from "../schemas.js";
import { Example, Run, TracerSession } from "../schemas.js";
import { Client } from "../index.js";
import { afterAll, beforeAll } from "@jest/globals";
import { RunnableLambda } from "@langchain/core/runnables";
@@ -30,6 +30,13 @@ afterAll(async () => {
await client.deleteDataset({
datasetName: TESTING_DATASET_NAME,
});
try {
await client.deleteDataset({
datasetName: "my_splits_ds2",
});
} catch (_) {
//pass
}
});

test("evaluate can evaluate", async () => {
@@ -351,6 +358,82 @@ test("can pass multiple evaluators", async () => {
);
});

test("split info saved correctly", async () => {
const client = new Client();
// create a new dataset
await client.createDataset("my_splits_ds2", {
description:
"For testing purposed. Is created & deleted for each test run.",
});
// create examples
await client.createExamples({
inputs: [{ input: 1 }, { input: 2 }, { input: 3 }],
outputs: [{ output: 2 }, { output: 3 }, { output: 4 }],
splits: [["test"], ["train"], ["validation", "test"]],
datasetName: "my_splits_ds2",
});

const targetFunc = (input: Record<string, any>) => {
console.log("__input__", input);
return {
foo: input.input + 1,
};
};
await evaluate(targetFunc, {
data: client.listExamples({ datasetName: "my_splits_ds2" }),
description: "splits info saved correctly",
});

const exp = client.listProjects({ referenceDatasetName: "my_splits_ds2" });
let myExp: TracerSession | null = null;
for await (const session of exp) {
myExp = session;
}
expect(myExp?.extra?.metadata?.dataset_splits.sort()).toEqual(
["test", "train", "validation"].sort()
);

await evaluate(targetFunc, {
data: client.listExamples({
datasetName: "my_splits_ds2",
splits: ["test"],
}),
description: "splits info saved correctly",
});

const exp2 = client.listProjects({ referenceDatasetName: "my_splits_ds2" });
let myExp2: TracerSession | null = null;
for await (const session of exp2) {
if (myExp2 === null || session.start_time > myExp2.start_time) {
myExp2 = session;
}
}

expect(myExp2?.extra?.metadata?.dataset_splits.sort()).toEqual(
["test", "validation"].sort()
);

await evaluate(targetFunc, {
data: client.listExamples({
datasetName: "my_splits_ds2",
splits: ["train"],
}),
description: "splits info saved correctly",
});

const exp3 = client.listProjects({ referenceDatasetName: "my_splits_ds2" });
let myExp3: TracerSession | null = null;
for await (const session of exp3) {
if (myExp3 === null || session.start_time > myExp3.start_time) {
myExp3 = session;
}
}

expect(myExp3?.extra?.metadata?.dataset_splits.sort()).toEqual(
["train"].sort()
);
});

test("can pass multiple summary evaluators", async () => {
const targetFunc = (input: Record<string, any>) => {
console.log("__input__", input);
15 changes: 12 additions & 3 deletions python/langsmith/client.py
@@ -2936,7 +2936,7 @@ def create_examples(
inputs: Sequence[Mapping[str, Any]],
outputs: Optional[Sequence[Optional[Mapping[str, Any]]]] = None,
metadata: Optional[Sequence[Optional[Mapping[str, Any]]]] = None,
splits: Optional[Sequence[Optional[str]]] = None,
splits: Optional[Sequence[Optional[str | List[str]]]] = None,
source_run_ids: Optional[Sequence[Optional[ID_TYPE]]] = None,
ids: Optional[Sequence[Optional[ID_TYPE]]] = None,
dataset_id: Optional[ID_TYPE] = None,
@@ -2953,6 +2953,9 @@ def create_examples(
The output values for the examples.
metadata : Optional[Sequence[Optional[Mapping[str, Any]]]], default=None
The metadata for the examples.
splits : Optional[Sequence[Optional[str | List[str]]]], default=None
The splits for the examples, which are divisions
of your dataset such as 'train', 'test', or 'validation'.
source_run_ids : Optional[Sequence[Optional[ID_TYPE]]], default=None
The IDs of the source runs associated with the examples.
ids : Optional[Sequence[ID_TYPE]], default=None
@@ -3012,7 +3015,7 @@ def create_example(
created_at: Optional[datetime.datetime] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[str, Any]] = None,
split: Optional[str] = None,
split: Optional[str | List[str]] = None,
example_id: Optional[ID_TYPE] = None,
) -> ls_schemas.Example:
"""Create a dataset example in the LangSmith API.
@@ -3034,6 +3037,9 @@ def create_example(
The output values for the example.
metadata : Mapping[str, Any] or None, default=None
The metadata for the example.
split : str or List[str] or None, default=None
The splits for the example, which are divisions
of your dataset such as 'train', 'test', or 'validation'.
example_id : UUID or None, default=None
The ID of the example to create. If not provided, a new
example will be created.
@@ -3165,7 +3171,7 @@ def update_example(
inputs: Optional[Dict[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Dict] = None,
split: Optional[str] = None,
split: Optional[str | List[str]] = None,
dataset_id: Optional[ID_TYPE] = None,
) -> Dict[str, Any]:
"""Update a specific example.
@@ -3180,6 +3186,9 @@ def update_example(
The output values to update.
metadata : Dict or None, default=None
The metadata to update.
split : str or List[str] or None, default=None
The dataset split to update, such as
'train', 'test', or 'validation'.
dataset_id : UUID or None, default=None
The ID of the dataset to update.

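The docstring additions above describe the new parameter shapes. Continuing the sketch from the top of the diff, the corresponding calls might look like this (the final assertion reflects the round-trip behavior pinned down by the JS integration test, where a bare string comes back as a one-element list):

example = client.create_example(
    inputs={"text": "hello"},
    outputs={"label": "greeting"},
    dataset_id=dataset.id,
    split=["train", "validation"],  # one example, two splits
)

# A bare string is also accepted and round-trips as ["test"].
client.update_example(example_id=example.id, split="test")
assert client.read_example(example.id).metadata["dataset_split"] == ["test"]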
18 changes: 18 additions & 0 deletions python/langsmith/evaluation/_runner.py
@@ -1322,13 +1322,31 @@ def _get_dataset_version(self) -> Optional[str]:
max_modified_at = max(modified_at) if modified_at else None
return max_modified_at.isoformat() if max_modified_at else None

def _get_dataset_splits(self) -> Optional[list[str]]:
examples = list(self.examples)
splits = set()
for example in examples:
if (
example.metadata
and example.metadata.get("dataset_split")
and isinstance(example.metadata["dataset_split"], list)
):
for split in example.metadata["dataset_split"]:
if isinstance(split, str):
splits.add(split)
else:
splits.add("base")

return list(splits)

def _end(self) -> None:
experiment = self._experiment
if experiment is None:
raise ValueError("Experiment not started yet.")

project_metadata = self._get_experiment_metadata()
project_metadata["dataset_version"] = self._get_dataset_version()
project_metadata["dataset_splits"] = self._get_dataset_splits()
self.client.update_project(
experiment.id,
end_time=datetime.datetime.now(datetime.timezone.utc),
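_get_dataset_splits above collapses every example's dataset_split metadata into one de-duplicated list, falling back to "base" whenever the value is missing, empty, or not a list. A standalone sketch of the same aggregation (sorted here for determinism; the method itself returns the set in arbitrary order):

def aggregate_splits(examples_metadata):
    splits = set()
    for metadata in examples_metadata:
        value = (metadata or {}).get("dataset_split")
        if isinstance(value, list) and value:
            # Union of all list-valued splits, ignoring non-string entries.
            splits.update(s for s in value if isinstance(s, str))
        else:
            # Missing, empty, or non-list values fall back to "base".
            splits.add("base")
    return sorted(splits)

assert aggregate_splits(
    [{"dataset_split": ["test"]}, {"dataset_split": ["train", "test"]}, {}]
) == ["base", "test", "train"]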
4 changes: 2 additions & 2 deletions python/langsmith/schemas.py
@@ -63,7 +63,7 @@ class ExampleCreate(ExampleBase):

id: Optional[UUID]
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
split: Optional[str] = None
split: Optional[Union[str, List[str]]] = None


class Example(ExampleBase):
@@ -106,7 +106,7 @@ class ExampleUpdate(BaseModel):
inputs: Optional[Dict[str, Any]] = None
outputs: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None
split: Optional[str] = None
split: Optional[Union[str, List[str]]] = None

class Config:
"""Configuration class for the schema."""
21 changes: 18 additions & 3 deletions python/tests/integration_tests/test_client.py
@@ -109,9 +109,9 @@ def test_datasets(langchain_client: Client) -> None:
def test_list_examples(langchain_client: Client) -> None:
"""Test list_examples."""
examples = [
("Shut up, idiot", "Toxic", "train"),
("Shut up, idiot", "Toxic", ["train", "validation"]),
("You're a wonderful person", "Not toxic", "test"),
("This is the worst thing ever", "Toxic", "train"),
("This is the worst thing ever", "Toxic", ["train"]),
("I had a great day today", "Not toxic", "test"),
("Nobody likes you", "Toxic", "train"),
("This is unacceptable. I want to speak to the manager.", "Not toxic", None),
@@ -133,6 +133,11 @@ def test_list_examples(langchain_client: Client) -> None:
)
assert len(example_list) == 3

example_list = list(
langchain_client.list_examples(dataset_id=dataset.id, splits=["validation"])
)
assert len(example_list) == 1

example_list = list(
langchain_client.list_examples(dataset_id=dataset.id, splits=["test"])
)
@@ -148,11 +153,21 @@ def test_list_examples(langchain_client: Client) -> None:
example.id
for example in example_list
if example.metadata is not None
and example.metadata.get("dataset_split") == "test"
and "test" in example.metadata.get("dataset_split", [])
][0],
split="train",
)

example_list = list(
langchain_client.list_examples(dataset_id=dataset.id, splits=["test"])
)
assert len(example_list) == 1

example_list = list(
langchain_client.list_examples(dataset_id=dataset.id, splits=["train"])
)
assert len(example_list) == 4

langchain_client.create_example(
inputs={"text": "What's up!"},
outputs={"label": "Not toxic"},
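These assertions pin down how filtering interacts with multi-split examples: an example tagged ["train", "validation"] is returned both by a splits=["train"] query and by a splits=["validation"] query, and moving an example from "test" to "train" shifts the per-split counts accordingly. A short sketch reusing the client and dataset from the earlier sketches:

train_examples = list(client.list_examples(dataset_id=dataset.id, splits=["train"]))
validation_examples = list(client.list_examples(dataset_id=dataset.id, splits=["validation"]))

# In the sketch dataset, the only "validation" example is also a "train"
# example, so it appears in both result lists (counted once in each).
train_ids = {example.id for example in train_examples}
assert all(example.id in train_ids for example in validation_examples)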