Update to Use a Factory Method #1250

Merged 1 commit on May 15, 2023
12 changes: 11 additions & 1 deletion examples/src/client/tracing_datasets.ts
@@ -89,8 +89,18 @@ what is 1213 divided by 4345?,approximately 0.2791714614499425
);
}

// Many chains incorporate memory. For independent trials over the dataset, we
// pass in a factory function that creates a new executor for each trial.
// If you know that your chain does not use memory, you can return the same
// executor for each trial.
const executorFactory = async () =>
initializeAgentExecutorWithOptions(tools, model, {
agentType: "chat-conversational-react-description",
verbose: true,
});

// If using the traced dataset, you can update the datasetName to be
// "calculator-example-dataset" or the custom name you chose.
const results = await client.runOnDataset(datasetName, executor);
const results = await client.runOnDataset(datasetName, executorFactory);
console.log(results);
};
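To make the memory caveat above concrete, here is a minimal sketch of a factory whose chain carries its own memory; the `ConversationChain` and `BufferMemory` imports are illustrative assumptions, not part of this diff. Returning one shared executor instead would let turns from earlier dataset examples leak into later trials, which is exactly the failure mode the factory signature avoids:

```ts
import { ChatOpenAI } from "langchain/chat_models/openai";
import { ConversationChain } from "langchain/chains";
import { BufferMemory } from "langchain/memory";

const chatModel = new ChatOpenAI({ temperature: 0 });

// Each call builds a new chain with an empty BufferMemory, so every
// dataset example starts its trial from a clean conversation state.
const memoryChainFactory = async () =>
  new ConversationChain({
    llm: chatModel,
    memory: new BufferMemory(),
  });
```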
62 changes: 48 additions & 14 deletions langchain/src/client/langchainplus.ts
@@ -108,28 +108,57 @@ const stringifyError = (err: Error | unknown): string => {
return result;
};

export function isLLM(llm: BaseLanguageModel | BaseChain): llm is BaseLLM {
export function isLLM(
llm: BaseLanguageModel | (() => Promise<BaseChain>)
): llm is BaseLLM {
const blm = llm as BaseLanguageModel;
return (
typeof blm?._modelType === "function" && blm?._modelType() === "base_llm"
);
}

export function isChatModel(llm: BaseLanguageModel): llm is BaseChatModel {
export function isChatModel(
llm: BaseLanguageModel | (() => Promise<BaseChain>)
): llm is BaseChatModel {
const blm = llm as BaseLanguageModel;
return (
typeof blm?._modelType === "function" &&
blm?._modelType() === "base_chat_model"
);
}

export function isChain(llm: BaseLanguageModel | BaseChain): llm is BaseChain {
const bch = llm as BaseChain;
export async function isChain(
llm: BaseLanguageModel | (() => Promise<BaseChain>)
): Promise<boolean> {
if (isLLM(llm)) {
return false;
}
const bchFactory = llm as () => Promise<BaseChain>;
const bch = await bchFactory();
return (
typeof bch?._chainType === "function" && bch?._chainType() !== undefined
);
}

type _ModelType = "llm" | "chatModel" | "chainFactory";

async function getModelOrFactoryType(
[Inline PR comment from the contributor author: "Extracted this to avoid instantiating multiple times just to check type."]

llm: BaseLanguageModel | (() => Promise<BaseChain>)
): Promise<_ModelType> {
if (isLLM(llm)) {
return "llm";
}
if (isChatModel(llm)) {
return "chatModel";
}
const bchFactory = llm as () => Promise<BaseChain>;
const bch = await bchFactory();
if (typeof bch?._chainType === "function") {
return "chainFactory";
}
throw new Error("Unknown model or factory type");
}

export class LangChainPlusClient {
private apiKey?: string;

@@ -453,12 +482,13 @@ export class LangChainPlusClient {
protected async runChain(
example: Example,
tracer: LangChainTracer,
chain: BaseChain,
chainFactory: () => Promise<BaseChain>,
numRepetitions = 1
): Promise<(ChainValues | string)[]> {
const results: (ChainValues | string)[] = await Promise.all(
Array.from({ length: numRepetitions }).map(async () => {
try {
const chain = await chainFactory();
return chain.call(example.inputs, [tracer]);
} catch (e) {
console.error(e);
@@ -495,51 +525,55 @@

public async runOnDataset(
datasetName: string,
llmOrChain: BaseLanguageModel | BaseChain,
llmOrChainFactory: BaseLanguageModel | (() => Promise<BaseChain>),
numRepetitions = 1,
sessionName: string | undefined = undefined
): Promise<DatasetRunResults> {
const examples = await this.listExamples(undefined, datasetName);
let sessionName_: string;
if (sessionName === undefined) {
const currentTime = new Date().toISOString();
sessionName_ = `${datasetName}-${llmOrChain.constructor.name}-${currentTime}`;
sessionName_ = `${datasetName}-${llmOrChainFactory.constructor.name}-${currentTime}`;
} else {
sessionName_ = sessionName;
}
const results: DatasetRunResults = {};
const modelOrFactoryType = await getModelOrFactoryType(llmOrChainFactory);
await Promise.all(
examples.map(async (example) => {
const tracer = new LangChainTracer({
exampleId: example.id,
sessionName: sessionName_,
});
if (isLLM(llmOrChain)) {
if (modelOrFactoryType === "llm") {
const llm = llmOrChainFactory as BaseLLM;
const llmResult = await this.runLLM(
example,
tracer,
llmOrChain,
llm,
numRepetitions
);
results[example.id] = llmResult;
} else if (isChain(llmOrChain)) {
} else if (modelOrFactoryType === "chainFactory") {
const chainFactory = llmOrChainFactory as () => Promise<BaseChain>;
const chainResult = await this.runChain(
example,
tracer,
llmOrChain,
chainFactory,
numRepetitions
);
results[example.id] = chainResult;
} else if (isChatModel(llmOrChain)) {
} else if (modelOrFactoryType === "chatModel") {
const chatModel = llmOrChainFactory as BaseChatModel;
const chatModelResult = await this.runChatModel(
example,
tracer,
llmOrChain,
chatModel,
numRepetitions
);
results[example.id] = chatModelResult;
} else {
throw new Error(`Unknown llm or chain type: ${llmOrChain}`);
throw new Error(`Unknown llm or chain type: ${llmOrChainFactory}`);
}
})
);
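Taken together, `runOnDataset` now accepts three input shapes and dispatches once on the result of `getModelOrFactoryType`. A hedged sketch of those shapes — the `OpenAI`, `ChatOpenAI`, `LLMChain`, and `PromptTemplate` imports and the prompt text are illustrative assumptions, not part of this diff:

```ts
import { OpenAI } from "langchain/llms/openai";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { LLMChain } from "langchain/chains";
import { PromptTemplate } from "langchain/prompts";

// "llm": a plain language model instance, run directly per example.
const llm = new OpenAI({ temperature: 0 });

// "chatModel": a chat model instance, likewise passed directly.
const chatModel = new ChatOpenAI({ temperature: 0 });

// "chainFactory": chains arrive as async factories, so each trial gets a
// freshly constructed chain rather than one shared, possibly stateful one.
const chainFactory = async () =>
  new LLMChain({
    llm,
    prompt: PromptTemplate.fromTemplate("Q: {question}\nA:"),
  });
```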
22 changes: 12 additions & 10 deletions langchain/src/client/tests/langchainplus.int.test.ts
@@ -156,13 +156,14 @@ what is 1213 divided by 4345?,approximately 0.2791714614499425
const model = new ChatOpenAI({ temperature: 0 });
const tools = [new Calculator()];

const executor = await initializeAgentExecutorWithOptions(tools, model, {
agentType: "chat-conversational-react-description",
verbose: true,
});
const executorFactory = async () =>
await initializeAgentExecutorWithOptions(tools, model, {
agentType: "chat-conversational-react-description",
verbose: true,
});
console.log("Loaded agent.");

const results = await client.runOnDataset(datasetName, executor);
const results = await client.runOnDataset(datasetName, executorFactory);
console.log(results);
expect(Object.keys(results).length).toEqual(2);
});
@@ -205,13 +206,14 @@ what was the total number of points scored in the 2023 super bowl? what is that
new Calculator(),
];

const executor = await initializeAgentExecutorWithOptions(tools, model, {
agentType: "chat-conversational-react-description",
verbose: true,
});
const executorFactory = async () =>
await initializeAgentExecutorWithOptions(tools, model, {
agentType: "chat-conversational-react-description",
verbose: true,
});
console.log("Loaded agent.");

const results = await client.runOnDataset(datasetName, executor);
const results = await client.runOnDataset(datasetName, executorFactory);
console.log(results);
expect(Object.keys(results).length).toEqual(10);
});
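Two usage notes follow from the new signature, sketched under the assumption that `client`, `datasetName`, and `executorFactory` are set up as in the test above: plain models still pass as instances, and `numRepetitions` now yields genuinely independent trials because the factory is invoked once per repetition.

```ts
import { ChatOpenAI } from "langchain/chat_models/openai";

// Chat models and LLMs are stateless from the client's point of view,
// so they are still passed as instances rather than factories.
const chatModel = new ChatOpenAI({ temperature: 0 });
const modelResults = await client.runOnDataset(datasetName, chatModel);

// For chains, the factory runs once per repetition: three repetitions
// mean three fresh executors, each with its own empty memory.
const repeatedResults = await client.runOnDataset(datasetName, executorFactory, 3);
console.log(Object.keys(modelResults).length, Object.keys(repeatedResults).length);
```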