diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index 6854eb03..74058a6b 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -29,6 +29,7 @@ import { IconArrowMerge, IconArrowsSplit, IconForms, + IconAbacus, } from "@tabler/icons-react"; import RemoveEdge from "./RemoveEdge"; import TextFieldsNode from "./TextFieldsNode"; // Import a custom node @@ -88,6 +89,7 @@ import { isEdgeChromium, isChromium, } from "react-device-detect"; +import MultiEvalNode from "./MultiEvalNode"; const IS_ACCEPTED_BROWSER = (isChrome || @@ -157,6 +159,7 @@ const nodeTypes = { simpleval: SimpleEvalNode, evaluator: CodeEvaluatorNode, llmeval: LLMEvaluatorNode, + multieval: MultiEvalNode, vis: VisNode, inspect: InspectNode, script: ScriptNode, @@ -328,6 +331,7 @@ const App = () => { const addTabularDataNode = () => addNode("table"); const addCommentNode = () => addNode("comment"); const addLLMEvalNode = () => addNode("llmeval"); + const addMultiEvalNode = () => addNode("multieval"); const addJoinNode = () => addNode("join"); const addSplitNode = () => addNode("split"); const addProcessorNode = (progLang: string) => { @@ -1052,6 +1056,15 @@ const App = () => { LLM Scorer{" "} + + } + > + {" "} + Multi-Evaluator{" "} + + Visualizers diff --git a/chainforge/react-server/src/CodeEvaluatorNode.tsx b/chainforge/react-server/src/CodeEvaluatorNode.tsx index f11a5eaa..b14654f0 100644 --- a/chainforge/react-server/src/CodeEvaluatorNode.tsx +++ b/chainforge/react-server/src/CodeEvaluatorNode.tsx @@ -33,6 +33,7 @@ import "ace-builds/src-noconflict/theme-xcode"; import "ace-builds/src-noconflict/ext-language_tools"; import { APP_IS_RUNNING_LOCALLY, + genDebounceFunc, getVarsAndMetavars, stripLLMDetailsFromResponses, toStandardResponseFormat, @@ -188,6 +189,7 @@ export interface CodeEvaluatorComponentProps { onCodeEdit?: (code: string) => void; onCodeChangedFromLastRun?: () => void; onCodeEqualToLastRun?: () => void; + sandbox?: boolean; } /** @@ -206,6 +208,7 @@ export const CodeEvaluatorComponent = forwardRef< onCodeEdit, onCodeChangedFromLastRun, onCodeEqualToLastRun, + sandbox, }, ref, ) { @@ -215,6 +218,10 @@ export const CodeEvaluatorComponent = forwardRef< false, ); + // Debounce helpers + const debounceTimeoutRef = useRef(null); + const debounce = genDebounceFunc(debounceTimeoutRef); + // Controlled handle when user edits code const handleCodeEdit = (code: string) => { if (codeTextOnLastRun !== false) { @@ -223,7 +230,10 @@ export const CodeEvaluatorComponent = forwardRef< else if (!code_changed && onCodeEqualToLastRun) onCodeEqualToLastRun(); } setCodeText(code); - if (onCodeEdit) onCodeEdit(code); + + // Debounce to control number of re-renders to parent, when user is editing/typing: + if (onCodeEdit) + debounce(() => onCodeEdit(code), 200)(); }; // Runs the code evaluator/processor over the inputs, returning the results as a Promise. 
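Note: `genDebounceFunc` is imported from `./backend/utils` in the hunks above, but its implementation is not part of this diff. Below is a minimal sketch of the shape the call sites assume — a factory closing over a React ref that holds the pending timeout. It is an illustration only, not the actual utility:

```ts
import { MutableRefObject } from "react";

// Hypothetical sketch matching the call pattern used above:
//   const debounce = genDebounceFunc(debounceTimeoutRef);
//   debounce(() => onCodeEdit(code), 200)();
export function genDebounceFunc(
  timeoutRef: MutableRefObject<ReturnType<typeof setTimeout> | null>,
) {
  return function debounce<T extends unknown[]>(
    func: (...args: T) => void,
    delayMs: number,
  ) {
    return (...args: T) => {
      // Cancel any previously scheduled call, so only the last call in a burst
      // (e.g., while the user is typing in the Ace editor) actually fires.
      if (timeoutRef.current !== null) clearTimeout(timeoutRef.current);
      timeoutRef.current = setTimeout(() => func(...args), delayMs);
    };
  };
}
```

With a helper of this shape, the debounced `onCodeEdit`/`onPromptEdit` callbacks fire at most once per 200 ms pause in typing, so the parent node's state sync does not re-render on every keystroke.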
@@ -233,6 +243,8 @@ export const CodeEvaluatorComponent = forwardRef< script_paths?: string[], runInSandbox?: boolean, ) => { + if (runInSandbox === undefined) runInSandbox = sandbox; + // Double-check that the code includes an 'evaluate' or 'process' function, whichever is needed: const find_func_regex = node_type === "evaluator" @@ -317,7 +329,7 @@ export const CodeEvaluatorComponent = forwardRef< mode={progLang} theme="xcode" onChange={handleCodeEdit} - value={code} + value={codeText} name={"aceeditor_" + id} editorProps={{ $blockScrolling: true }} width="100%" diff --git a/chainforge/react-server/src/LLMEvalNode.tsx b/chainforge/react-server/src/LLMEvalNode.tsx index ec0dfabc..35c8ee0d 100644 --- a/chainforge/react-server/src/LLMEvalNode.tsx +++ b/chainforge/react-server/src/LLMEvalNode.tsx @@ -21,7 +21,7 @@ import LLMResponseInspectorModal, { } from "./LLMResponseInspectorModal"; import InspectFooter from "./InspectFooter"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; -import { stripLLMDetailsFromResponses } from "./backend/utils"; +import { genDebounceFunc, stripLLMDetailsFromResponses } from "./backend/utils"; import { AlertModalContext } from "./AlertModal"; import { Dict, LLMResponse, LLMSpec, QueryProgress } from "./backend/typing"; import { Status } from "./StatusIndicatorComponent"; @@ -116,11 +116,18 @@ export const LLMEvaluatorComponent = forwardRef< ); const apiKeys = useStore((state) => state.apiKeys); + // Debounce helpers + const debounceTimeoutRef = useRef(null); + const debounce = genDebounceFunc(debounceTimeoutRef); + const handlePromptChange = useCallback( (e: React.ChangeEvent) => { // Store prompt text setPromptText(e.target.value); - if (onPromptEdit) onPromptEdit(e.target.value); + + // Update the caller, but debounce to reduce the number of callbacks when user is typing + if (onPromptEdit) + debounce(() => onPromptEdit(e.target.value), 200)(); }, [setPromptText, onPromptEdit], ); @@ -157,36 +164,49 @@ export const LLMEvaluatorComponent = forwardRef< " " + formatting_instr + "\n```\n{input}\n```"; - - // Keeping track of progress (unpacking the progress state since there's only a single LLM) const llm_key = llmScorers[0].key ?? ""; - const _progress_listener = onProgressChange - ? (progress_by_llm: Dict) => - onProgressChange({ - success: progress_by_llm[llm_key].success, - error: progress_by_llm[llm_key].error, - }) - : undefined; - - // Run LLM as evaluator - return evalWithLLM( - id ?? Date.now().toString(), - llmScorers[0], - template, - input_node_ids, - apiKeys ?? {}, - _progress_listener, - ).then(function (res) { - // Check if there's an error; if so, bubble it up to user and exit: - if (res.errors && res.errors.length > 0) throw new Error(res.errors[0]); - else if (res.responses === undefined) - throw new Error( - "Unknown error encountered when requesting evaluations: empty response returned.", - ); - // Success! - return res.responses; - }); + // Fetch info about the number of queries we'll need to make + return grabResponses(input_node_ids) + .then(function (resps) { + // Create progress listener + // Keeping track of progress (unpacking the progress state since there's only a single LLM) + const num_resps_required = resps.reduce( + (acc, resp_obj) => acc + resp_obj.responses.length, + 0, + ); + return onProgressChange + ? 
(progress_by_llm: Dict) => + onProgressChange({ + success: + (100 * progress_by_llm[llm_key].success) / num_resps_required, + error: + (100 * progress_by_llm[llm_key].error) / num_resps_required, + }) + : undefined; + }) + .then((progress_listener) => { + // Run LLM as evaluator + return evalWithLLM( + id ?? Date.now().toString(), + llmScorers[0], + template, + input_node_ids, + apiKeys ?? {}, + progress_listener, + ); + }) + .then(function (res) { + // Check if there's an error; if so, bubble it up to user and exit: + if (res.errors && res.errors.length > 0) throw new Error(res.errors[0]); + else if (res.responses === undefined) + throw new Error( + "Unknown error encountered when requesting evaluations: empty response returned.", + ); + + // Success! + return res.responses; + }); }; // Export the current internal state as JSON @@ -305,41 +325,22 @@ const LLMEvaluatorNode: React.FC = ({ data, id }) => { if (showAlert) showAlert(typeof err === "string" ? err : err?.message); }; - // Fetch info about the number of queries we'll need to make - grabResponses(input_node_ids) - .then(function (resps) { - // Create progress listener - const num_resps_required = resps.reduce( - (acc, resp_obj) => acc + resp_obj.responses.length, - 0, - ); - const onProgressChange = (prog: QueryProgress) => { - setProgress({ - success: (100 * prog.success) / num_resps_required, - error: (100 * prog.error) / num_resps_required, - }); - }; - - // Run LLM evaluator - llmEvaluatorRef?.current - ?.run(input_node_ids, onProgressChange) - .then(function (evald_resps) { - // Ping any vis + inspect nodes attached to this node to refresh their contents: - pingOutputNodes(id); - - console.log(evald_resps); - setLastResponses(evald_resps); - - if (!showDrawer) setUninspectedResponses(true); - - setStatus(Status.READY); - setProgress(undefined); - }) - .catch(handleError); + // Run LLM evaluator + llmEvaluatorRef?.current + ?.run(input_node_ids, setProgress) + .then(function (evald_resps) { + // Ping any vis + inspect nodes attached to this node to refresh their contents: + pingOutputNodes(id); + + console.log(evald_resps); + setLastResponses(evald_resps); + + if (!showDrawer) setUninspectedResponses(true); + + setStatus(Status.READY); + setProgress(undefined); }) - .catch(() => { - handleError("Error pulling input data for node: No input data found."); - }); + .catch(handleError); }, [ inputEdgesForNode, llmEvaluatorRef, diff --git a/chainforge/react-server/src/LLMResponseInspector.tsx b/chainforge/react-server/src/LLMResponseInspector.tsx index da05760f..2c062fb2 100644 --- a/chainforge/react-server/src/LLMResponseInspector.tsx +++ b/chainforge/react-server/src/LLMResponseInspector.tsx @@ -15,6 +15,8 @@ import { ActionIcon, Tooltip, TextInput, + Stack, + ScrollArea, } from "@mantine/core"; import { useToggle } from "@mantine/hooks"; import { @@ -36,6 +38,7 @@ import { ResponseBox, ResponseGroup, genResponseTextsDisplay, + getEvalResultStr, } from "./ResponseBoxes"; import { getLabelForResponse } from "./ResponseRatingToolbar"; import { @@ -51,6 +54,34 @@ const getLLMName = (resp_obj: LLMResponse) => const escapeRegExp = (txt: string) => txt.replace(/[-[\]{}()*+?.,\\^$|#\s]/g, "\\$&"); +function getEvalResCols(responses: LLMResponse[]) { + // Look for + extract any consistent, *named* evaluation metrics (dicts) + const metric_names = new Set(); + let has_unnamed_metric = false; + let eval_res_cols = []; + responses.forEach((res_obj) => { + if (res_obj?.eval_res?.items === undefined) return; + 
res_obj.eval_res.items.forEach((item) => { + if (typeof item !== "object") { + has_unnamed_metric = true; + return; + } + Object.keys(item).forEach((metric_name) => metric_names.add(metric_name)); + }); + }); + + if (metric_names.size === 0 || has_unnamed_metric) + // None found, but there are scores, OR, there is at least one unnamed score. Add a generic col for scores: + eval_res_cols.push("Score"); + + if (metric_names.size > 0) { + // Add a column for each named metric: + eval_res_cols = eval_res_cols.concat(Array.from(metric_names)); + } + + return eval_res_cols; +} + function getIndicesOfSubstringMatches( s: string, substr: string, @@ -265,7 +296,7 @@ const LLMResponseInspector: React.FC = ({ ); // The var name to use for columns in the table view - const [tableColVar, setTableColVar] = useState("LLM"); + const [tableColVar, setTableColVar] = useState("$LLM"); const [userSelectedTableCol, setUserSelectedTableCol] = useState(false); // State of the 'only show scores' toggle when eval results are present @@ -306,6 +337,12 @@ const LLMResponseInspector: React.FC = ({ const contains_eval_res = batchedResponses.some( (res_obj) => res_obj.eval_res !== undefined, ); + const contains_multi_evals = contains_eval_res + ? batchedResponses.some((res_obj) => { + const items = res_obj.eval_res?.items; + return items && items.length > 0 && typeof items[0] === "object"; + }) + : false; setShowEvalScoreOptions(contains_eval_res); // Set the variables accessible in the MultiSelect for 'group by' @@ -315,25 +352,37 @@ const LLMResponseInspector: React.FC = ({ // in the future we can add special types of variables without name collisions ({ value: name, label: name }), ) - .concat({ value: "LLM", label: "LLM" }); + .concat({ value: "$LLM", label: "LLM" }); + if (contains_eval_res && viewFormat === "table") + msvars.push({ value: "$EVAL_RES", label: "Eval results" }); setMultiSelectVars(msvars); // If only one LLM is present, and user hasn't manually selected one to plot, // and there's more than one prompt variable as input, default to plotting the - // first found prompt variable as columns instead: + // eval scores, or the first found prompt variable as columns instead: if ( viewFormat === "table" && !userSelectedTableCol && - tableColVar === "LLM" && - found_llms.length === 1 && - found_vars.length > 1 + tableColVar === "$LLM" ) { - setTableColVar(found_vars[0]); - return; // useEffect will replot with the new values + if ( + contains_multi_evals || + (found_llms.length === 1 && contains_eval_res) + ) { + // Plot eval scores on columns + setTableColVar("$EVAL_RES"); + return; + } else if (found_llms.length === 1 && found_vars.length > 1) { + setTableColVar(found_vars[0]); + return; // useEffect will replot with the new values + } } // If this is the first time receiving responses, set the multiSelectValue to whatever is the first: if (!receivedResponsesOnce) { + if (contains_multi_evals) + // If multiple evals are detected, default to "table" format: + setViewFormat("table"); setMultiSelectValue([msvars[0].value]); setReceivedResponsesOnce(true); } else if ( @@ -411,6 +460,7 @@ const LLMResponseInspector: React.FC = ({ resps: LLMResponse[], eatenvars: string[], fixed_width: number, + hide_eval_scores?: boolean, ) => { const hide_llm_name = eatenvars.includes("LLM"); return resps.map((res_obj, res_idx) => { @@ -435,6 +485,7 @@ const LLMResponseInspector: React.FC = ({ contains_eval_res && onlyShowScores, hide_llm_name ? 
undefined : getLLMName(res_obj), wideFormat, + hide_eval_scores, ); // At the deepest level, there may still be some vars left over. We want to display these @@ -467,9 +518,10 @@ const LLMResponseInspector: React.FC = ({ let var_cols: string[], colnames: string[], getColVal: (r: LLMResponse) => string | number | undefined, - found_sel_var_vals: string[]; + found_sel_var_vals: string[], + eval_res_cols: string[]; let metavar_cols: string[] = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying. - if (tableColVar === "LLM") { + if (tableColVar === "$LLM") { var_cols = found_vars; getColVal = getLLMName; found_sel_var_vals = found_llms; @@ -480,19 +532,38 @@ const LLMResponseInspector: React.FC = ({ .filter((v) => v !== tableColVar) .concat(found_llms.length > 1 ? ["LLM"] : []); // only add LLM column if num LLMs > 1 getColVal = (r) => r.vars[tableColVar]; + colnames = var_cols; + found_sel_var_vals = []; + } + // If the user wants to plot eval results in separate column, OR there's only a single LLM to show + if (tableColVar === "$EVAL_RES") { + // Plot evaluation results on separate column(s): + eval_res_cols = getEvalResCols(responses); + // if (tableColVar === "$EVAL_RES") { + // This adds a column, "Response", abusing the way getColVal and found_sel_var_vals is used + // below by making a dummy value (one giant group with all responses in it). We then + // sort the responses by LLM, to give a nicer view. + colnames = colnames.concat("Response", eval_res_cols); + getColVal = () => "_"; + found_sel_var_vals = ["_"]; + responses.sort((a, b) => getLLMName(a).localeCompare(getLLMName(b))); + // } else { + // colnames = colnames.concat(eval_res_cols); + // } + } else if (tableColVar !== "$LLM") { // Get the unique values for the selected variable found_sel_var_vals = Array.from( responses.reduce((acc, res_obj) => { acc.add( tableColVar in res_obj.vars - ? (res_obj.vars[tableColVar] as string) + ? res_obj.vars[tableColVar] : "(unspecified)", ); return acc; }, new Set()), ); - colnames = var_cols.concat(found_sel_var_vals); + colnames = colnames.concat(found_sel_var_vals); } const getVar = (r: LLMResponse, v: string) => @@ -517,6 +588,27 @@ const LLMResponseInspector: React.FC = ({ const val = resp_objs[0].metavars[v]; return val !== undefined ? val : "(unspecified)"; }); + let eval_cols_vals: React.ReactNode[] = []; + if (eval_res_cols && eval_res_cols.length > 0) { + // We can assume that there's only one response object, since to + // if eval_res_cols is set, there must be only one LLM. + eval_cols_vals = eval_res_cols.map((metric_name, metric_idx) => { + const items = resp_objs[0].eval_res?.items; + if (!items) return "(no result)"; + return items.map((item) => { + if (item === undefined) return "(undefined)"; + if ( + typeof item !== "object" && + metric_idx === 0 && + metric_name === "Score" + ) + return getEvalResultStr(item, true); + else if (typeof item === "object" && metric_name in item) + return getEvalResultStr(item[metric_name], true); + else return "(unspecified)"; + }); // treat n>1 resps per prompt as multi-line results in the column + }); + } const resp_objs_by_col_var = groupResponsesBy( resp_objs, getColVal, @@ -526,7 +618,14 @@ const LLMResponseInspector: React.FC = ({ const rs = resp_objs_by_col_var[val]; // Return response divs as response box here: return ( -
{generateResponseBoxes(rs, var_cols, 100)}
+
+ {generateResponseBoxes( + rs, + var_cols, + 100, + eval_res_cols !== undefined, + )} +
); } else { return {empty_cell_text}; @@ -534,10 +633,12 @@ const LLMResponseInspector: React.FC = ({ }); return ( - + {var_cols_vals.map((c, i) => ( - {c} + + {c} + ))} {metavar_cols_vals.map((c, i) => ( @@ -550,13 +651,25 @@ const LLMResponseInspector: React.FC = ({ {c} ))} + {eval_cols_vals.map((c, i) => ( + + {c} + + ))} ); }, ); setResponseDivs([ - +
{colnames.map((c) => ( @@ -603,7 +716,7 @@ const LLMResponseInspector: React.FC = ({ const defaultOpened = !first_opened || eatenvars.length === 0 || - eatenvars[eatenvars.length - 1] === "LLM"; + eatenvars[eatenvars.length - 1] === "$LLM"; first_opened = true; leaf_id += 1; return ( @@ -627,13 +740,13 @@ const LLMResponseInspector: React.FC = ({ // we also bucket any 'leftover' responses that didn't have the requested variable (a kind of 'soft fail') const group_name = varnames[0]; const [grouped_resps, leftover_resps] = - group_name === "LLM" + group_name === "$LLM" ? groupResponsesBy(resps, getLLMName) : groupResponsesBy(resps, (r) => group_name in r.vars ? r.vars[group_name] : null, ); const get_header = - group_name === "LLM" + group_name === "$LLM" ? (key: string, val?: string) => (
= ({ const defaultOpened = !first_opened || eatenvars.length === 0 || - eatenvars[eatenvars.length - 1] === "LLM"; + eatenvars[eatenvars.length - 1] === "$LLM"; const grouped_resps_divs = Object.keys(grouped_resps).map((g) => groupByVars( grouped_resps[g], diff --git a/chainforge/react-server/src/MultiEvalNode.tsx b/chainforge/react-server/src/MultiEvalNode.tsx new file mode 100644 index 00000000..d010679c --- /dev/null +++ b/chainforge/react-server/src/MultiEvalNode.tsx @@ -0,0 +1,872 @@ +import React, { + useState, + useCallback, + useEffect, + useMemo, + useRef, + useContext, +} from "react"; +import { Handle, Position } from "reactflow"; +import { v4 as uuid } from "uuid"; +import { + TextInput, + Text, + Group, + ActionIcon, + Menu, + Card, + rem, + Collapse, + Button, + Alert, + Tooltip, +} from "@mantine/core"; +import { useDisclosure } from "@mantine/hooks"; +import { + IconAbacus, + IconBox, + IconChevronDown, + IconChevronRight, + IconDots, + IconPlus, + IconRobot, + IconSearch, + IconSparkles, + IconTerminal, + IconTrash, +} from "@tabler/icons-react"; +import BaseNode from "./BaseNode"; +import NodeLabel from "./NodeLabelComponent"; +import InspectFooter from "./InspectFooter"; +import LLMResponseInspectorModal, { + LLMResponseInspectorModalRef, +} from "./LLMResponseInspectorModal"; +import useStore from "./store"; +import { + APP_IS_RUNNING_LOCALLY, + batchResponsesByUID, + genDebounceFunc, + toStandardResponseFormat, +} from "./backend/utils"; +import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; +import { + CodeEvaluatorComponent, + CodeEvaluatorComponentRef, +} from "./CodeEvaluatorNode"; +import { LLMEvaluatorComponent, LLMEvaluatorComponentRef } from "./LLMEvalNode"; +import { GatheringResponsesRingProgress } from "./LLMItemButtonGroup"; +import { Dict, LLMResponse, QueryProgress } from "./backend/typing"; +import { AlertModalContext } from "./AlertModal"; +import { Status } from "./StatusIndicatorComponent"; + +const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY(); + +const EVAL_TYPE_PRETTY_NAME = { + python: "Python", + javascript: "JavaScript", + llm: "LLM", +}; + +export interface EvaluatorContainerProps { + name: string; + type: string; + padding?: string | number; + onDelete: () => void; + onChangeTitle: (newTitle: string) => void; + progress?: QueryProgress; + customButton?: React.ReactNode; + children: React.ReactNode; + initiallyOpen?: boolean; +} + +/** A wrapper for a single evaluator, that can be renamed */ +const EvaluatorContainer: React.FC = ({ + name, + type: evalType, + padding, + onDelete, + onChangeTitle, + progress, + customButton, + children, + initiallyOpen, +}) => { + const [opened, { toggle }] = useDisclosure(initiallyOpen ?? false); + const _padding = useMemo(() => padding ?? "0px", [padding]); + const [title, setTitle] = useState(name ?? "Criteria"); + + const handleChangeTitle = (newTitle: string) => { + setTitle(newTitle); + if (onChangeTitle) onChangeTitle(newTitle); + }; + + return ( + + + + + + setTitle(e.target.value)} + onBlur={(e) => handleChangeTitle(e.target.value)} + placeholder="Criteria name" + variant="unstyled" + size="sm" + className="nodrag nowheel" + styles={{ + input: { + padding: "0px", + height: "14pt", + minHeight: "0pt", + fontWeight: 500, + }, + }} + /> + + + {customButton} + + + {evalType} + + + {progress ? 
( + + ) : ( + <> + )} + {/* */} + + + + + + + + + {/* }> + Inspect scores + + }> + Help / info + */} + } + color="red" + onClick={onDelete} + > + Delete + + + + + + + + + {children} + + + ); +}; + +export interface EvaluatorContainerDesc { + name: string; // the user's nickname for the evaluator, which displays as the title of the banner + uid: string; // a unique identifier for this evaluator, since name can change + type: "python" | "javascript" | "llm"; // the type of evaluator + state: Dict; // the internal state necessary for that specific evaluator component (e.g., a prompt for llm eval, or code for code eval) + progress?: QueryProgress; + justAdded?: boolean; +} + +export interface MultiEvalNodeProps { + data: { + evaluators: EvaluatorContainerDesc[]; + refresh: boolean; + title: string; + }; + id: string; +} + +/** A node that stores multiple evaluator functions (can be mix of LLM scorer prompts and arbitrary code.) */ +const MultiEvalNode: React.FC = ({ data, id }) => { + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + const pullInputData = useStore((state) => state.pullInputData); + const pingOutputNodes = useStore((state) => state.pingOutputNodes); + const bringNodeToFront = useStore((state) => state.bringNodeToFront); + const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); + + const flags = useStore((state) => state.flags); + const AI_SUPPORT_ENABLED = useMemo(() => { + return flags.aiSupport; + }, [flags]); + + const [status, setStatus] = useState(Status.NONE); + // For displaying error messages to user + const showAlert = useContext(AlertModalContext); + const inspectModal = useRef(null); + + // -- EvalGen access -- + // const pickCriteriaModalRef = useRef(null); + // const onClickPickCriteria = () => { + // const inputs = handlePullInputs(); + // pickCriteriaModalRef?.current?.trigger(inputs, (implementations: EvaluatorContainerDesc[]) => { + // // Returned if/when the Pick Criteria modal finishes generating implementations. + // console.warn(implementations); + // // Append the returned implementations to the end of the existing eval list + // setEvaluators((evs) => evs.concat(implementations)); + // }); + // }; + + const [uninspectedResponses, setUninspectedResponses] = useState(false); + const [lastResponses, setLastResponses] = useState([]); + const [lastRunSuccess, setLastRunSuccess] = useState(true); + const [showDrawer, setShowDrawer] = useState(false); + + // Debounce helpers + const debounceTimeoutRef = useRef(null); + const debounce = genDebounceFunc(debounceTimeoutRef); + + /** Store evaluators as array of JSON serialized state: + * { name: // the user's nickname for the evaluator, which displays as the title of the banner + * type: 'python' | 'javascript' | 'llm' // the type of evaluator + * state: // the internal state necessary for that specific evaluator component (e.g., a prompt for llm eval, or code for code eval) + * } + */ + const [evaluators, setEvaluators] = useState(data.evaluators ?? 
[]); + + // Add an evaluator to the end of the list + const addEvaluator = useCallback( + (name: string, type: EvaluatorContainerDesc["type"], state: Dict) => { + setEvaluators(evaluators.concat({ name, uid: uuid(), type, state, justAdded: true })); + }, + [evaluators], + ); + + // Sync evaluator state to stored state of this node + useEffect(() => { + setDataPropsForNode(id, { evaluators: evaluators.map((e) => ({...e, justAdded: undefined})) }); + }, [evaluators]); + + // Generate UI for the evaluator state + const evaluatorComponentRefs = useRef< + { + type: "code" | "llm"; + name: string; + ref: CodeEvaluatorComponentRef | LLMEvaluatorComponentRef | null; + }[] + >([]); + + const updateEvalState = ( + idx: number, + transformFunc: (e: EvaluatorContainerDesc) => void, + ) => { + setStatus(Status.WARNING); + setEvaluators((es) => + es.map((e, i) => { + if (idx === i) transformFunc(e); + return e; + }), + ); + }; + + // const evaluatorComponents = useMemo(() => { + // // evaluatorComponentRefs.current = []; + + // return evaluators.map((e, idx) => { + // let component: React.ReactNode; + // if (e.type === "python" || e.type === "javascript") { + // component = ( + // + // (evaluatorComponentRefs.current[idx] = { + // type: "code", + // name: e.name, + // ref: el, + // }) + // } + // code={e.state?.code} + // progLang={e.type} + // type="evaluator" + // id={id} + // onCodeEdit={(code) => + // updateEvalState(idx, (e) => (e.state.code = code)) + // } + // showUserInstruction={false} + // /> + // ); + // } else if (e.type === "llm") { + // component = ( + // + // (evaluatorComponentRefs.current[idx] = { + // type: "llm", + // name: e.name, + // ref: el, + // }) + // } + // prompt={e.state?.prompt} + // grader={e.state?.grader} + // format={e.state?.format} + // id={id} + // showUserInstruction={false} + // onPromptEdit={(prompt) => + // updateEvalState(idx, (e) => (e.state.prompt = prompt)) + // } + // onLLMGraderChange={(grader) => + // updateEvalState(idx, (e) => (e.state.grader = grader)) + // } + // onFormatChange={(format) => + // updateEvalState(idx, (e) => (e.state.format = format)) + // } + // /> + // ); + // } else { + // console.error( + // `Unknown evaluator type ${e.type} inside multi-evaluator node. Cannot display evaluator UI.`, + // ); + // component = Error: Unknown evaluator type {e.type}; + // } + // return ( + // { + // delete evaluatorComponentRefs.current[idx]; + // setEvaluators(evaluators.filter((_, i) => i !== idx)); + // }} + // onChangeTitle={(newTitle) => + // setEvaluators( + // evaluators.map((e, i) => { + // if (i === idx) e.name = newTitle; + // console.log(e); + // return e; + // }), + // ) + // } + // padding={e.type === "llm" ? 
"8px" : undefined} + // > + // {component} + // + // ); + // }); + // }, [evaluators, id]); + + const handleError = useCallback( + (err: Error | string) => { + console.error(err); + setStatus(Status.ERROR); + showAlert && showAlert(err); + }, + [showAlert, setStatus], + ); + + const handlePullInputs = useCallback(() => { + // Pull input data + try { + const pulled_inputs = pullInputData(["responseBatch"], id); + if (!pulled_inputs || !pulled_inputs.responseBatch) { + console.warn(`No inputs to the Multi-Evaluator node.`); + return []; + } + // Convert to standard response format (StandardLLMResponseFormat) + return pulled_inputs.responseBatch.map(toStandardResponseFormat); + } catch (err) { + handleError(err as Error); + return []; + } + }, [pullInputData, id, toStandardResponseFormat]); + + const handleRunClick = useCallback(() => { + // Pull inputs to the node + const pulled_inputs = handlePullInputs(); + if (!pulled_inputs || pulled_inputs.length === 0) return; + + // Get the ids from the connected input nodes: + // TODO: Remove this dependency; have everything go through pull instead. + const input_node_ids = inputEdgesForNode(id).map((e) => e.source); + if (input_node_ids.length === 0) { + console.warn("No inputs to multi-evaluator node."); + return; + } + + // Sanity check that there's evaluators in the multieval node + if ( + !evaluatorComponentRefs.current || + evaluatorComponentRefs.current.length === 0 + ) { + console.error("Cannot run multievals: No current evaluators found."); + return; + } + + // Set status and created rejection callback + setStatus(Status.LOADING); + setLastResponses([]); + + // Helper function to update progress ring on a single evaluator component + const updateProgressRing = ( + evaluator_idx: number, + progress?: QueryProgress, + ) => { + // Update the progress rings, debouncing to avoid too many rerenders + debounce( + (_idx, _progress) => + setEvaluators((evs) => { + if (_idx >= evs.length) return evs; + evs[_idx].progress = _progress; + return [...evs]; + }), + 30, + )(evaluator_idx, progress); + }; + + // Run all evaluators here! + // TODO + const runPromises = evaluatorComponentRefs.current.map( + ({ type, name, ref }, idx) => { + if (ref === null) return { type: "error", name, result: null }; + + // Start loading spinner status on running evaluators + updateProgressRing(idx, { success: 0, error: 0 }); + + // Run each evaluator + if (type === "code") { + // Run code evaluator + // TODO: Change runInSandbox to be user-controlled, for Python code evals (right now it is always sandboxed) + return (ref as CodeEvaluatorComponentRef) + .run(pulled_inputs, undefined) + .then((ret) => { + console.log("Code evaluator done!", ret); + updateProgressRing(idx, undefined); + if (ret.error !== undefined) throw new Error(ret.error); + return { + type: "code", + name, + result: ret.responses, + }; + }); + } else { + // Run LLM-based evaluator + // TODO: Add back live progress, e.g. (progress) => updateProgressRing(idx, progress)) but with appropriate mapping for progress. + return (ref as LLMEvaluatorComponentRef) + .run(input_node_ids, (progress) => { + updateProgressRing(idx, progress); + }) + .then((ret) => { + console.log("LLM evaluator done!", ret); + updateProgressRing(idx, undefined); + return { + type: "llm", + name, + result: ret, + }; + }); + } + }, + ); + + // When all evaluators finish... 
+ Promise.allSettled(runPromises).then((settled) => { + if (settled.some((s) => s.status === "rejected")) { + setStatus(Status.ERROR); + setLastRunSuccess(false); + // @ts-expect-error Reason exists on rejected settled promises, but TS doesn't know it for some reason. + handleError(settled.find((s) => s.status === "rejected").reason); + return; + } + + // Remove progress rings without errors + setEvaluators((evs) => + evs.map((e) => { + if (e.progress && !e.progress.error) e.progress = undefined; + return e; + }), + ); + + // Ignore null refs + settled = settled.filter( + (s) => s.status === "fulfilled" && s.value.result !== null, + ); + + // Success -- set the responses for the inspector + // First we need to group up all response evals by UID, *within* each evaluator. + const evalResults = settled.map((s) => { + const v = + s.status === "fulfilled" + ? s.value + : { type: "code", name: "Undefined", result: [] }; + if (v.type === "llm") return v; // responses are already batched by uid + // If code evaluator, for some reason, in this version of CF the code eval has de-batched responses. + // We need to re-batch them by UID before returning, to correct this: + return { + type: v.type, + name: v.name, + result: batchResponsesByUID(v.result ?? []), + }; + }); + + // Now we have a duplicates of each response object, one per evaluator run, + // with evaluation results per evaluator. They are not yet merged. We now need + // to merge the evaluation results within response objects with the same UIDs. + // It *should* be the case (invariant) that response objects with the same UID + // have exactly the same number of evaluation results (e.g. n=3 for num resps per prompt=3). + const merged_res_objs_by_uid: Dict = {}; + // For each set of evaluation results... + evalResults.forEach(({ name, result }) => { + // For each response obj in the results... + result?.forEach((res_obj: LLMResponse) => { + // If it's not already in the merged dict, add it: + const uid = res_obj.uid; + if ( + res_obj.eval_res !== undefined && + !(uid in merged_res_objs_by_uid) + ) { + // Transform evaluation results into dict form, indexed by "name" of the evaluator: + res_obj.eval_res.items = res_obj.eval_res.items.map((item) => { + if (typeof item === "object") item = item.toString(); + return { + [name]: item, + }; + }); + res_obj.eval_res.dtype = "KeyValue_Mixed"; // "KeyValue_Mixed" enum; + merged_res_objs_by_uid[uid] = res_obj; // we don't make a copy, to save time + } else { + // It is already in the merged dict, so add the new eval results + // Sanity check that the lengths of eval result lists are equal across evaluators: + if (merged_res_objs_by_uid[uid].eval_res === undefined) return; + else if ( + // @ts-expect-error We've already checked that eval_res is defined, yet TS throws an error anyway... skip it: + merged_res_objs_by_uid[uid].eval_res.items.length !== + res_obj.eval_res?.items?.length + ) { + console.error( + `Critical error: Evaluation result lists for response ${uid} do not contain the same number of items per evaluator. Skipping...`, + ); + return; + } + // Add the new evaluation result, keyed by evaluator name: + // @ts-expect-error We've already checked that eval_res is defined, yet TS throws an error anyway... skip it: + merged_res_objs_by_uid[uid].eval_res.items.forEach((item, idx) => { + if (typeof item === "object") { + let v = res_obj.eval_res?.items[idx]; + if (typeof v === "object") v = v.toString(); + item[name] = v ?? 
"undefined"; + } + }); + } + }); + }); + + // We now have a dict of the form { uid: LLMResponse } + // We need return only the values of this dict: + setLastResponses(Object.values(merged_res_objs_by_uid)); + setLastRunSuccess(true); + + setStatus(Status.READY); + }); + }, [ + handlePullInputs, + pingOutputNodes, + status, + showDrawer, + evaluators, + evaluatorComponentRefs, + ]); + + const showResponseInspector = useCallback(() => { + if (inspectModal && inspectModal.current && lastResponses) { + setUninspectedResponses(false); + inspectModal.current.trigger(); + } + }, [inspectModal, lastResponses]); + + // Something changed upstream + useEffect(() => { + if (data.refresh && data.refresh === true) { + setDataPropsForNode(id, { refresh: false }); + setStatus(Status.WARNING); + } + }, [data]); + + return ( + + } + status={status} + handleRunClick={handleRunClick} + runButtonTooltip="Run all evaluators over inputs" + /> + + + {/* */} + + + {/* {evaluatorComponents} */} + {evaluators.map((e, idx) => ( + + + + ) : undefined + } + onDelete={() => { + delete evaluatorComponentRefs.current[idx]; + setEvaluators(evaluators.filter((_, i) => i !== idx)); + }} + onChangeTitle={(newTitle) => + setEvaluators((evs) => + evs.map((e, i) => { + if (i === idx) e.name = newTitle; + console.log(e); + return e; + }), + ) + } + padding={e.type === "llm" ? "8px" : undefined} + > + {e.type === "python" || e.type === "javascript" ? ( + + (evaluatorComponentRefs.current[idx] = { + type: "code", + name: e.name, + ref: el, + }) + } + code={e.state?.code} + progLang={e.type} + sandbox={e.state?.sandbox} + type="evaluator" + id={id} + onCodeEdit={(code) => + updateEvalState(idx, (e) => (e.state.code = code)) + } + showUserInstruction={false} + /> + ) : e.type === "llm" ? ( + + (evaluatorComponentRefs.current[idx] = { + type: "llm", + name: e.name, + ref: el, + }) + } + prompt={e.state?.prompt} + grader={e.state?.grader} + format={e.state?.format} + id={`${id}-${e.uid}`} + showUserInstruction={false} + onPromptEdit={(prompt) => + updateEvalState(idx, (e) => (e.state.prompt = prompt)) + } + onLLMGraderChange={(grader) => + updateEvalState(idx, (e) => (e.state.grader = grader)) + } + onFormatChange={(format) => + updateEvalState(idx, (e) => (e.state.format = format)) + } + /> + ) : ( + Error: Unknown evaluator type {e.type} + )} + + ))} + + + {/* TO IMPLEMENT */} + +
+ + + + + + + + + + + } + onClick={() => + addEvaluator( + `Criteria ${evaluators.length + 1}`, + "javascript", + { + code: "function evaluate(r) {\n\treturn r.text.length;\n}", + }, + ) + } + > + JavaScript + + {IS_RUNNING_LOCALLY ? ( + } + onClick={() => + addEvaluator(`Criteria ${evaluators.length + 1}`, "python", { + code: "def evaluate(r):\n\treturn len(r.text)", + sandbox: true, + }) + } + > + Python + + ) : ( + <> + )} + } + onClick={() => + addEvaluator(`Criteria ${evaluators.length + 1}`, "llm", { + prompt: "", + format: "bin", + }) + } + > + LLM + + {/* {AI_SUPPORT_ENABLED ? : <>} */} + {/* {AI_SUPPORT_ENABLED ? ( + } + onClick={onClickPickCriteria} + > + Let an AI decide! + + ) : ( + <> + )} */} + + +
+ + {/* EvalGen {evaluators && evaluators.length === 0 ? ( + + + + */} + {/* */} + {/* + ) : ( + <> + )} */} + + {lastRunSuccess && lastResponses && lastResponses.length > 0 ? ( + + Inspect scores  + + + } + onClick={showResponseInspector} + showNotificationDot={uninspectedResponses} + isDrawerOpen={showDrawer} + showDrawerButton={true} + onDrawerClick={() => { + setShowDrawer(!showDrawer); + setUninspectedResponses(false); + bringNodeToFront(id); + }} + /> + ) : ( + <> + )} + + +
+ ); +}; + +export default MultiEvalNode; diff --git a/chainforge/react-server/src/ResponseBoxes.tsx b/chainforge/react-server/src/ResponseBoxes.tsx index 7d484f9d..3fa0b395 100644 --- a/chainforge/react-server/src/ResponseBoxes.tsx +++ b/chainforge/react-server/src/ResponseBoxes.tsx @@ -1,5 +1,5 @@ import React, { Suspense, useMemo, lazy } from "react"; -import { Collapse, Flex } from "@mantine/core"; +import { Collapse, Flex, Stack } from "@mantine/core"; import { useDisclosure } from "@mantine/hooks"; import { truncStr } from "./backend/utils"; import { @@ -15,19 +15,25 @@ const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar")); /* HELPER FUNCTIONS */ const SUCCESS_EVAL_SCORES = new Set(["true", "yes"]); const FAILURE_EVAL_SCORES = new Set(["false", "no"]); -const getEvalResultStr = ( - eval_item: string[] | Dict | string | number | boolean, +export const getEvalResultStr = ( + eval_item: EvaluationScore, + hide_prefix: boolean, ) => { if (Array.isArray(eval_item)) { - return "scores: " + eval_item.join(", "); + return (hide_prefix ? "" : "scores: ") + eval_item.join(", "); } else if (typeof eval_item === "object") { - const strs = Object.keys(eval_item).map((key) => { + const strs = Object.keys(eval_item).map((key, j) => { let val = eval_item[key]; if (typeof val === "number" && val.toString().indexOf(".") > -1) val = val.toFixed(4); // truncate floats to 4 decimal places - return `${key}: ${val}`; + return ( +
+ {key}: + {getEvalResultStr(val, true)} +
+ ); }); - return strs.join(", "); + return {strs}; } else { const eval_str = eval_item.toString().trim().toLowerCase(); const color = SUCCESS_EVAL_SCORES.has(eval_str) @@ -37,7 +43,7 @@ const getEvalResultStr = ( : "black"; return ( <> - {"score: "} + {!hide_prefix && {"score: "}} {eval_str} ); @@ -164,10 +170,12 @@ export const genResponseTextsDisplay = ( onlyShowScores?: boolean, llmName?: string, wideFormat?: boolean, + hideEvalScores?: boolean, ): React.ReactNode[] | React.ReactNode => { if (!res_obj) return <>; - const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null; + const eval_res_items = + !hideEvalScores && res_obj.eval_res ? res_obj.eval_res.items : null; // Bucket responses that have the same text, and sort by the // number of same responses so that the top div is the most prevalent response. @@ -251,7 +259,7 @@ export const genResponseTextsDisplay = ( )} {eval_res_items ? (

- {getEvalResultStr(resp_str_to_eval_res[r])} + {getEvalResultStr(resp_str_to_eval_res[r], true)}

              ) : (
                <>
diff --git a/chainforge/react-server/src/backend/typing.ts b/chainforge/react-server/src/backend/typing.ts
index 7608cdcd..3bbeb14b 100644
--- a/chainforge/react-server/src/backend/typing.ts
+++ b/chainforge/react-server/src/backend/typing.ts
@@ -199,6 +199,7 @@ export type EvaluationScore =
   | number
   | string
   | Dict;
+
 export type EvaluationResults = {
   items: EvaluationScore[];
   dtype:
diff --git a/chainforge/react-server/src/text-fields-node.css b/chainforge/react-server/src/text-fields-node.css
index 188d3ae9..8274f31a 100644
--- a/chainforge/react-server/src/text-fields-node.css
+++ b/chainforge/react-server/src/text-fields-node.css
@@ -4,6 +4,9 @@
 .monofont {
   font-family: var(--monofont);
 }
+.linebreaks {
+  white-space: pre-wrap;
+}
 
 .text-fields-node {
   background-color: #fff;
@@ -390,7 +393,7 @@ g.ytick text {
   padding-bottom: 20px;
   min-width: 160px;
   border-right: 1px solid #eee;
-  padding-left: 8px !important;
+  padding-left: 0px !important;
   padding-right: 0px !important;
 }
 .inspect-responses-drawer {
@@ -646,17 +649,18 @@ g.ytick text {
   cursor: text;
 }
 .small-response-metrics {
-  font-size: 10pt;
+  font-size: 9pt;
   font-family: -apple-system, "Segoe UI", "Roboto", "Oxygen", "Ubuntu",
     "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif;
   font-weight: 500;
   text-align: center;
   border-top-left-radius: 20px;
   border-top-right-radius: 20px;
-  padding: 0px 2px 1px 0px;
+  padding: 0px 2px 2px 0px;
   margin: 8px 20% -6px 20%;
-  background-color: rgba(255, 255, 255, 0.3);
+  /* background-color: rgba(255, 255, 255, 0.3); */
   color: #333;
+  white-space: pre-wrap;
 }
 .num-same-responses {
   position: relative;
diff --git a/setup.py b/setup.py
index 8563e37b..57534eb4 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ def readme():
 
 setup(
     name='chainforge',
-    version='0.3.1.2',
+    version='0.3.1.5',
     packages=find_packages(),
     author="Ian Arawjo",
     description="A Visual Programming Environment for Prompt Engineering",
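For reference, a simplified, self-contained sketch of the merge that `MultiEvalNode` performs after `Promise.allSettled` — not the PR's exact code (the types are trimmed and the helper name is invented): each evaluator returns its own copy of the graded responses, and the scores are folded into one `eval_res.items` entry per response UID, keyed by the evaluator's title. This keyed form (`dtype: "KeyValue_Mixed"`) is what `getEvalResCols` and the `$EVAL_RES` table columns in `LLMResponseInspector` consume.

```ts
type Score = boolean | number | string;
type EvaluationScore = Score | Record<string, Score>;

interface EvalResponse {
  uid: string; // shared by all copies of the same underlying LLM response
  eval_res?: { items: EvaluationScore[]; dtype?: string };
}

interface EvaluatorRun {
  name: string; // the evaluator's user-facing title, e.g. "Criteria 1"
  result: EvalResponse[]; // responses carrying only that evaluator's scores
}

// Invented helper name; mirrors the merge logic inside MultiEvalNode's handleRunClick.
function mergeEvalResultsByUID(runs: EvaluatorRun[]): EvalResponse[] {
  const merged: Record<string, EvalResponse> = {};
  for (const { name, result } of runs) {
    for (const resp of result) {
      if (!resp.eval_res) continue;
      const prior = merged[resp.uid];
      if (!prior || !prior.eval_res) {
        // First evaluator to report on this UID: turn each plain score into a
        // { [evaluatorName]: score } dict (n items = n responses per prompt).
        merged[resp.uid] = {
          ...resp,
          eval_res: {
            items: resp.eval_res.items.map((item) => ({
              [name]: typeof item === "object" ? JSON.stringify(item) : item,
            })),
            dtype: "KeyValue_Mixed",
          },
        };
      } else if (prior.eval_res.items.length === resp.eval_res.items.length) {
        // Later evaluators add their score under their own key, index-by-index.
        prior.eval_res.items.forEach((item, i) => {
          const v = resp.eval_res?.items[i];
          if (typeof item === "object" && v !== undefined)
            item[name] = typeof v === "object" ? JSON.stringify(v) : v;
        });
      }
      // Mismatched item counts are skipped; the node logs a critical error for them.
    }
  }
  return Object.values(merged);
}
```

A merged response then looks like, for example, `{ uid, eval_res: { items: [{ "Criteria 1": true, "Criteria 2": 142 }], dtype: "KeyValue_Mixed" } }`, which the inspector renders as one table column per metric name.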