Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session to project #6249

Merged
merged 6 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 11 additions & 11 deletions langchain/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def wandb_tracing_enabled(

@contextmanager
def tracing_v2_enabled(
session_name: Optional[str] = None,
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
) -> Generator[None, None, None]:
Expand All @@ -120,7 +120,7 @@ def tracing_v2_enabled(
example_id = UUID(example_id)
cb = LangChainTracer(
example_id=example_id,
session_name=session_name,
project_name=project_name,
)
tracing_v2_callback_var.set(cb)
yield
Expand All @@ -131,12 +131,12 @@ def tracing_v2_enabled(
def trace_as_chain_group(
group_name: str,
*,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
) -> Generator[CallbackManager, None, None]:
"""Get a callback manager for a chain group in a context manager."""
cb = LangChainTracer(
session_name=session_name,
project_name=project_name,
example_id=example_id,
)
cm = CallbackManager.configure(
Expand All @@ -152,12 +152,12 @@ def trace_as_chain_group(
async def atrace_as_chain_group(
group_name: str,
*,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]:
"""Get a callback manager for a chain group in a context manager."""
cb = LangChainTracer(
session_name=session_name,
project_name=project_name,
example_id=example_id,
)
cm = AsyncCallbackManager.configure(
Expand Down Expand Up @@ -1039,10 +1039,10 @@ def _configure(
tracing_v2_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING_V2") or tracer_v2 is not None
)
tracer_session = os.environ.get("LANGCHAIN_SESSION")
tracer_project = os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
)
debug = _get_debug()
if tracer_session is None:
tracer_session = "default"
if (
verbose
or debug
Expand Down Expand Up @@ -1072,7 +1072,7 @@ def _configure(
callback_manager.add_handler(tracer, True)
else:
handler = LangChainTracerV1()
handler.load_session(tracer_session)
handler.load_session(tracer_project)
callback_manager.add_handler(handler, True)
if wandb_tracing_enabled_ and not any(
isinstance(handler, WandbTracer) for handler in callback_manager.handlers
Expand All @@ -1090,7 +1090,7 @@ def _configure(
callback_manager.add_handler(tracer_v2, True)
else:
try:
handler = LangChainTracer(session_name=tracer_session)
handler = LangChainTracer(project_name=tracer_project)
callback_manager.add_handler(handler, True)
except Exception as e:
logger.warning(
Expand Down
8 changes: 5 additions & 3 deletions langchain/callbacks/tracers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class LangChainTracer(BaseTracer):
def __init__(
self,
example_id: Optional[Union[UUID, str]] = None,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
client: Optional[LangChainPlusClient] = None,
**kwargs: Any,
) -> None:
Expand All @@ -55,7 +55,9 @@ def __init__(
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
self.project_name = project_name or os.getenv(
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
)
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)
self.client = client or LangChainPlusClient()
Expand Down Expand Up @@ -103,7 +105,7 @@ def _persist_run_single(self, run: Run) -> None:
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
try:
self.client.create_run(**run_dict, session_name=self.session_name)
self.client.create_run(**run_dict, project_name=self.project_name)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("post", e)
Expand Down
64 changes: 31 additions & 33 deletions langchain/client/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,18 @@ async def run_coroutine_with_semaphore(
return results


async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChainTracer]:
"""
Initialize a tracer to share across tasks.

Args:
session_name: The session name for the tracer.
project_name: The project name for the tracer.

Returns:
A LangChainTracer instance with an active session.
A LangChainTracer instance with an active project.
"""
if session_name:
tracer = LangChainTracer(session_name=session_name)
if project_name:
tracer = LangChainTracer(project_name=project_name)
return tracer
else:
return None
Expand All @@ -260,12 +260,12 @@ async def arun_on_examples(
*,
concurrency_level: int = 5,
num_repetitions: int = 1,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Run the chain on examples and store traces to the specified session name.
Run the chain on examples and store traces to the specified project name.

Args:
examples: Examples to run the model or chain over
Expand All @@ -276,7 +276,7 @@ async def arun_on_examples(
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
session_name: Session name to use when tracing runs.
project_name: Project name to use when tracing runs.
verbose: Whether to print progress.
tags: Tags to add to the traces.

Expand Down Expand Up @@ -307,7 +307,7 @@ async def process_example(

await _gather_with_concurrency(
concurrency_level,
functools.partial(_tracer_initializer, session_name),
functools.partial(_tracer_initializer, project_name),
*(functools.partial(process_example, e) for e in examples),
)
return results
Expand Down Expand Up @@ -386,11 +386,11 @@ def run_on_examples(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
num_repetitions: int = 1,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Run the chain on examples and store traces to the specified session name.
"""Run the chain on examples and store traces to the specified project name.

Args:
examples: Examples to run model or chain over.
Expand All @@ -401,14 +401,14 @@ def run_on_examples(
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
session_name: Session name to use when tracing runs.
project_name: Project name to use when tracing runs.
verbose: Whether to print progress.
tags: Tags to add to the run traces.
Returns:
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, Any] = {}
tracer = LangChainTracer(session_name=session_name) if session_name else None
tracer = LangChainTracer(project_name=project_name) if project_name else None
for i, example in enumerate(examples):
result = run_llm_or_chain(
example,
Expand All @@ -425,13 +425,13 @@ def run_on_examples(
return results


def _get_session_name(
session_name: Optional[str],
def _get_project_name(
project_name: Optional[str],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: str,
) -> str:
if session_name is not None:
return session_name
if project_name is not None:
return project_name
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if isinstance(llm_or_chain_factory, BaseLanguageModel):
model_name = llm_or_chain_factory.__class__.__name__
Expand All @@ -446,13 +446,13 @@ async def arun_on_dataset(
*,
concurrency_level: int = 5,
num_repetitions: int = 1,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Run the chain on a dataset and store traces to the specified session name.
Run the chain on a dataset and store traces to the specified project name.

Args:
client: Client to use to read the dataset.
Expand All @@ -464,19 +464,18 @@ async def arun_on_dataset(
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
session_name: Name of the session to store the traces in.
project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment.
tags: Tags to add to each run in the session.

Returns:
A dictionary containing the run's session name and the resulting model outputs.
A dictionary containing the run's project name and the resulting model outputs.
"""
client_ = client or LangChainPlusClient()
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
client_.create_session(session_name, mode="eval")
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))

Expand All @@ -485,12 +484,12 @@ async def arun_on_dataset(
llm_or_chain_factory,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
session_name=session_name,
project_name=project_name,
verbose=verbose,
tags=tags,
)
return {
"session_name": session_name,
"project_name": project_name,
"results": results,
}

Expand All @@ -500,12 +499,12 @@ def run_on_dataset(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
num_repetitions: int = 1,
session_name: Optional[str] = None,
project_name: Optional[str] = None,
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified session name.
"""Run the chain on a dataset and store traces to the specified project name.

Args:
dataset_name: Name of the dataset to run the chain on.
Expand All @@ -516,30 +515,29 @@ def run_on_dataset(
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
session_name: Name of the session to store the traces in.
project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment.
tags: Tags to add to each run in the session.

Returns:
A dictionary containing the run's session name and the resulting model outputs.
A dictionary containing the run's project name and the resulting model outputs.
"""
client_ = client or LangChainPlusClient()
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
client_.create_session(session_name, mode="eval")
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))
results = run_on_examples(
examples,
llm_or_chain_factory,
num_repetitions=num_repetitions,
session_name=session_name,
project_name=project_name,
verbose=verbose,
tags=tags,
)
return {
"session_name": session_name,
"project_name": project_name,
"results": results,
}