Skip to content

Commit

Permalink
Make dataset runner parallel
Browse files Browse the repository at this point in the history
For the LCP Client
  • Loading branch information
vowelparrot committed May 12, 2023
2 parents 32c1c33 + 0df18d8 commit a3fe67b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 47 deletions.
97 changes: 52 additions & 45 deletions langchain/src/client/langchainplus.ts
Expand Up @@ -406,16 +406,17 @@ export class LangChainPlusClient {
llm: BaseLLM,
numRepetitions = 1
): Promise<(LLMResult | string)[]> {
const results: (LLMResult | string)[] = [];
for (let i = 0; i < numRepetitions; i += 1) {
try {
const prompts = example.inputs.prompts as string[];
results.push(await llm.generate(prompts, undefined, [tracer]));
} catch (e) {
console.error(e);
results.push(stringifyError(e));
}
}
const results: (LLMResult | string)[] = await Promise.all(
Array.from({ length: numRepetitions }).map(async () => {
try {
const prompts = example.inputs.prompts as string[];
return await llm.generate(prompts, undefined, [tracer]);
} catch (e) {
console.error(e);
return stringifyError(e);
}
})
);
return results;
}

Expand All @@ -425,15 +426,16 @@ export class LangChainPlusClient {
chain: BaseChain,
numRepetitions = 1
): Promise<(ChainValues | string)[]> {
const results: (ChainValues | string)[] = [];
for (let i = 0; i < numRepetitions; i += 1) {
try {
results.push(await chain.call(example.inputs, [tracer]));
} catch (e) {
console.error(e);
results.push(stringifyError(e));
}
}
const results: (ChainValues | string)[] = await Promise.all(
Array.from({ length: numRepetitions }).map(async () => {
try {
return await chain.call(example.inputs, [tracer]);
} catch (e) {
console.error(e);
return stringifyError(e);
}
})
);
return results;
}

Expand All @@ -444,37 +446,42 @@ export class LangChainPlusClient {
sessionName: string | undefined = undefined
): Promise<DatasetRunResults> {
const examples = await this.listExamples(undefined, datasetName);
let sessionName_ = sessionName;
let sessionName_: string;
if (sessionName === undefined) {
const currentTime = new Date().toISOString();
sessionName_ = `${datasetName}-${llmOrChain.constructor.name}-${currentTime}`;
} else {
sessionName_ = sessionName;
}
const results: DatasetRunResults = {};
const tracer = new LangChainTracer();
await tracer.newSession(sessionName_);
for (const example of examples) {
if (isLLM(llmOrChain)) {
const llmResult = await this.runLLM(
example,
tracer,
llmOrChain,
numRepetitions
);
results[example.id] = llmResult;
} else if (isChain(llmOrChain)) {
const ChainResult = await this.runChain(
example,
tracer,
llmOrChain,
numRepetitions
);
results[example.id] = ChainResult;
} else if (isChatModel(llmOrChain)) {
throw new Error("Chat models not yet supported");
} else {
throw new Error(` llm or chain type: ${llmOrChain}`);
}
}
await new LangChainTracer().newSession(sessionName_);
await Promise.all(
examples.map(async (example) => {
const tracer = new LangChainTracer(example.id);
await tracer.loadSession(sessionName_);
if (isLLM(llmOrChain)) {
const llmResult = await this.runLLM(
example,
tracer,
llmOrChain,
numRepetitions
);
results[example.id] = llmResult;
} else if (isChain(llmOrChain)) {
const ChainResult = await this.runChain(
example,
tracer,
llmOrChain,
numRepetitions
);
results[example.id] = ChainResult;
} else if (isChatModel(llmOrChain)) {
throw new Error("Chat models not yet supported");
} else {
throw new Error(` llm or chain type: ${llmOrChain}`);
}
})
);
return results;
}
}
4 changes: 2 additions & 2 deletions langchain/src/client/tests/langchainplus.int.test.ts
Expand Up @@ -60,7 +60,7 @@ test("Test LangChainPlus Client Dataset CRD", async () => {
expect(deleted.id).toBe(datasetId);
});

test.skip("Test LangChainPlus Client Run Chain Over Dataset", async () => {
test("Test LangChainPlus Client Run Chain Over Dataset", async () => {
const client: LangChainPlusClient = await LangChainPlusClient.create(
"http://localhost:8000"
);
Expand Down Expand Up @@ -112,5 +112,5 @@ what is 1213 divided by 4345?,approximately 0.2791714614499425

const results = await client.runOnDataset(datasetName, executor);
console.log(results);
expect(results.length).toEqual(10);
expect(Object.keys(results).length).toEqual(10);
});

1 comment on commit a3fe67b

@vercel
Copy link

@vercel vercel bot commented on a3fe67b May 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.