Skip to content

Commit

Permalink
langchain[patch]: Fix run on dataset method for built-in evaluators (#…
Browse files Browse the repository at this point in the history
…4332)

* Fix run on dataset method for built-in evaluators

* Fix typo

* Update run_on_dataset.int.test.ts
  • Loading branch information
jacoblee93 committed Feb 7, 2024
1 parent 59f63f9 commit f893b04
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 24 deletions.
2 changes: 2 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,7 @@
"node-llama-cpp": "2.7.3",
"notion-to-md": "^3.1.0",
"officeparser": "^4.0.4",
"openai": "^4.26.1",
"pdf-parse": "1.1.1",
"peggy": "^3.0.2",
"playwright": "^1.32.1",
Expand All @@ -1286,6 +1287,7 @@
"vectordb": "^0.1.4",
"weaviate-ts-client": "^1.4.0",
"web-auth-library": "^1.0.3",
"wikipedia": "^2.1.2",
"youtube-transcript": "^1.0.6",
"youtubei.js": "^5.8.0"
},
Expand Down
8 changes: 3 additions & 5 deletions langchain/src/evaluation/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ export abstract class LLMStringEvaluator<
* @param config
*/
abstract _evaluateStrings(
args: StringEvaluatorArgs,
callOptions?: ExtractLLMCallOptions<this["llm"]>,
args: StringEvaluatorArgs & ExtractLLMCallOptions<this["llm"]>,
config?: Callbacks | BaseCallbackConfig
): Promise<ChainValues>;

Expand All @@ -195,12 +194,11 @@ export abstract class LLMStringEvaluator<
* @param config
*/
evaluateStrings(
args: StringEvaluatorArgs,
callOptions?: ExtractLLMCallOptions<this["llm"]>,
args: StringEvaluatorArgs & ExtractLLMCallOptions<this["llm"]>,
config?: Callbacks | BaseCallbackConfig
): Promise<ChainValues> {
this.checkEvaluationArgs(args.reference, args.input);
return this._evaluateStrings(args, callOptions, config);
return this._evaluateStrings(args, config);
}
}

Expand Down
8 changes: 2 additions & 6 deletions langchain/src/evaluation/criteria/criteria.ts
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,10 @@ export class CriteriaEvalChain extends LLMStringEvaluator {
}

async _evaluateStrings(
args: StringEvaluatorArgs,
callOptions: ExtractLLMCallOptions<this["llm"]>,
args: StringEvaluatorArgs & ExtractLLMCallOptions<this["llm"]>,
config?: Callbacks | BaseCallbackConfig
): Promise<ChainValues> {
const result = await this.call(
{ ...this.getEvalInput(args), ...callOptions },
config
);
const result = await this.call({ ...this.getEvalInput(args) }, config);

return this._prepareOutput(result);
}
Expand Down
48 changes: 35 additions & 13 deletions langchain/src/smith/runner_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import {
RunnableConfig,
RunnableLambda,
} from "@langchain/core/runnables";
import { RunCollectorCallbackHandler } from "@langchain/core/tracers/run_collector";
import { LangChainTracer } from "@langchain/core/tracers/tracer_langchain";
import { BaseTracer } from "@langchain/core/tracers/base";
import { ChainValues } from "@langchain/core/utils/types";
import { Client, Example, Feedback, Run } from "langsmith";
import { EvaluationResult, RunEvaluator } from "langsmith/evaluation";
Expand Down Expand Up @@ -35,7 +35,7 @@ export type ChainOrFactory =
| (() => (obj: unknown) => unknown)
| (() => (obj: unknown) => Promise<unknown>);

class RunIdExtractor {
class SingleRunIdExtractor {
runIdPromiseResolver: (runId: string) => void;

runIdPromise: Promise<string>;
Expand All @@ -59,6 +59,30 @@ class RunIdExtractor {
}
}

/**
 * Tracer that hands back the run passed to `persistRun` through a promise,
 * so a caller can await the run object instead of polling for it.
 */
class SingleRunExtractor extends BaseTracer {
  /** Settles `runPromise`; assigned synchronously inside the constructor. */
  runPromiseResolver: (run: Run) => void;

  /** Fulfilled with the run given to `persistRun`. */
  runPromise: Promise<Run>;

  /** The name of the callback handler. */
  name = "single_run_extractor";

  constructor() {
    super();
    // The executor runs synchronously, so the resolver is captured before
    // the constructor returns.
    this.runPromise = new Promise<Run>((resolve) => {
      this.runPromiseResolver = resolve;
    });
  }

  /**
   * Fulfils the pending promise with the given run. Once the promise has
   * settled, further calls have no effect on its value.
   * @param run The run to expose through `extract()`.
   */
  async persistRun(run: Run): Promise<void> {
    this.runPromiseResolver(run);
  }

  /**
   * @returns A promise that resolves to the run given to `persistRun`.
   */
  async extract(): Promise<Run> {
    return this.runPromise;
  }
}

/**
* Wraps an evaluator function + implements the RunEvaluator interface.
*/
Expand All @@ -76,7 +100,7 @@ class DynamicRunEvaluator implements RunEvaluator {
* @returns A promise that resolves to the evaluation result.
*/
async evaluateRun(run: Run, example?: Example): Promise<EvaluationResult> {
const extractor = new RunIdExtractor();
const extractor = new SingleRunIdExtractor();
const tracer = new LangChainTracer({ projectName: "evaluators" });
const result = await this.evaluator.invoke(
{
Expand Down Expand Up @@ -169,7 +193,7 @@ class PreparedRunEvaluator implements RunEvaluator {
rawReferenceOutput: example?.outputs,
run,
});
const extractor = new RunIdExtractor();
const extractor = new SingleRunIdExtractor();
const tracer = new LangChainTracer({ projectName: "evaluators" });
if (this.isStringEvaluator) {
const evalResult = await this.evaluator.evaluateStrings(
Expand Down Expand Up @@ -278,25 +302,23 @@ const loadExamples = async ({
}) => {
const exampleIterator = client.listExamples({ datasetName });
const configs: RunnableConfig[] = [];
const runCollectors = [];
const runExtractors = [];
const examples = [];
for await (const example of exampleIterator) {
const runCollector = new RunCollectorCallbackHandler({
exampleId: example.id,
});
const runExtractor = new SingleRunExtractor();
configs.push({
callbacks: [
new LangChainTracer({ exampleId: example.id, projectName }),
runCollector,
runExtractor,
],
});
examples.push(example);
runCollectors.push(runCollector);
runExtractors.push(runExtractor);
}
return {
configs,
examples,
runCollectors,
runExtractors,
};
};

Expand Down Expand Up @@ -456,7 +478,7 @@ export const runOnDataset = async (
const dataset = await testClient.readDataset({ datasetName });
const datasetId = dataset.id;
const testConcurrency = maxConcurrency ?? 5;
const { configs, examples, runCollectors } = await loadExamples({
const { configs, examples, runExtractors } = await loadExamples({
datasetName,
client: testClient,
projectName: testProjectName,
Expand Down Expand Up @@ -494,7 +516,7 @@ export const runOnDataset = async (
progress.complete();
const runs: Run[] = [];
for (let i = 0; i < examples.length; i += 1) {
runs.push(runCollectors[i].tracedRuns[0]);
runs.push(await runExtractors[i].extract());
}
let evalResults: Record<
string,
Expand Down
Loading

0 comments on commit f893b04

Please sign in to comment.