diff --git a/.changeset/open-lamps-drop.md b/.changeset/open-lamps-drop.md new file mode 100644 index 00000000..2fc5b72a --- /dev/null +++ b/.changeset/open-lamps-drop.md @@ -0,0 +1,6 @@ +--- +'@e2b/code-interpreter-python': minor +'@e2b/code-interpreter': minor +--- + +added context methods to the sdk diff --git a/js/src/sandbox.ts b/js/src/sandbox.ts index 4272d846..0a87d20f 100644 --- a/js/src/sandbox.ts +++ b/js/src/sandbox.ts @@ -320,4 +320,96 @@ export class Sandbox extends BaseSandbox { throw formatRequestTimeoutError(error) } } + + /** + * Removes a context. + * + * @param context context to remove. + * + * @returns void. + */ + async removeCodeContext(context: Context | string): Promise { + try { + const id = typeof context === 'string' ? context : context.id + const res = await fetch(`${this.jupyterUrl}/contexts/${id}`, { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + ...this.connectionConfig.headers, + }, + keepalive: true, + signal: this.connectionConfig.getSignal( + this.connectionConfig.requestTimeoutMs + ), + }) + + const error = await extractError(res) + if (error) { + throw error + } + } catch (error) { + throw formatRequestTimeoutError(error) + } + } + + /** + * List all contexts. + * + * @returns list of contexts. + */ + async listCodeContexts(): Promise { + try { + const res = await fetch(`${this.jupyterUrl}/contexts`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + ...this.connectionConfig.headers, + }, + keepalive: true, + signal: this.connectionConfig.getSignal( + this.connectionConfig.requestTimeoutMs + ), + }) + + const error = await extractError(res) + if (error) { + throw error + } + + return await res.json() + } catch (error) { + throw formatRequestTimeoutError(error) + } + } + + /** + * Restart a context. + * + * @param context context to restart. + * + * @returns void. + */ + async restartCodeContext(context: Context | string): Promise { + try { + const id = typeof context === 'string' ? context : context.id + const res = await fetch(`${this.jupyterUrl}/contexts/${id}/restart`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...this.connectionConfig.headers, + }, + keepalive: true, + signal: this.connectionConfig.getSignal( + this.connectionConfig.requestTimeoutMs + ), + }) + + const error = await extractError(res) + if (error) { + throw error + } + } catch (error) { + throw formatRequestTimeoutError(error) + } + } } diff --git a/js/tests/contexts.test.ts b/js/tests/contexts.test.ts new file mode 100644 index 00000000..7dda2090 --- /dev/null +++ b/js/tests/contexts.test.ts @@ -0,0 +1,63 @@ +import { expect } from 'vitest' + +import { sandboxTest } from './setup' + +sandboxTest('create context with no options', async ({ sandbox }) => { + const context = await sandbox.createCodeContext() + + const contexts = await sandbox.listCodeContexts() + const lastContext = contexts[contexts.length - 1] + + expect(lastContext.id).toBe(context.id) + expect(lastContext.language).toBe(context.language) + expect(lastContext.cwd).toBe(context.cwd) +}) + +sandboxTest('create context with options', async ({ sandbox }) => { + const context = await sandbox.createCodeContext({ + language: 'python', + cwd: '/root', + }) + + const contexts = await sandbox.listCodeContexts() + const lastContext = contexts[contexts.length - 1] + + expect(lastContext.id).toBe(context.id) + expect(lastContext.language).toBe(context.language) + expect(lastContext.cwd).toBe(context.cwd) +}) + +sandboxTest('remove context', async ({ sandbox }) => { + const context = await sandbox.createCodeContext() + + await sandbox.removeCodeContext(context.id) + const contexts = await sandbox.listCodeContexts() + + expect(contexts.map((context) => context.id)).not.toContain(context.id) +}) + +sandboxTest('list contexts', async ({ sandbox }) => { + const contexts = await sandbox.listCodeContexts() + + // default contexts should include python and javascript + expect(contexts.map((context) => context.language)).toContain('python') + expect(contexts.map((context) => context.language)).toContain('javascript') +}) + +sandboxTest('restart context', async ({ sandbox }) => { + const context = await sandbox.createCodeContext() + + // set a variable in the context + await sandbox.runCode('x = 1', { context: context }) + + // restart the context + await sandbox.restartCodeContext(context.id) + + // check that the variable no longer exists + const execution = await sandbox.runCode('x', { context: context }) + + // check for an NameError with message "name 'x' is not defined" + expect(execution.error).toBeDefined() + expect(execution.error?.name).toBe('NameError') + expect(execution.error?.value).toBe("name 'x' is not defined") +}) diff --git a/python/e2b_code_interpreter/code_interpreter_async.py b/python/e2b_code_interpreter/code_interpreter_async.py index 789d6a70..98f59dd0 100644 --- a/python/e2b_code_interpreter/code_interpreter_async.py +++ b/python/e2b_code_interpreter/code_interpreter_async.py @@ -1,7 +1,7 @@ import logging import httpx -from typing import Optional, Dict, overload, Union, Literal +from typing import Optional, Dict, overload, Union, Literal, List from httpx import AsyncClient from e2b import ( @@ -273,3 +273,89 @@ async def create_code_context( return Context.from_json(data) except httpx.TimeoutException: raise format_request_timeout_error() + + async def remove_code_context( + self, + context: Union[Context, str], + ) -> None: + """ + Removes a context. + + :param context: Context to remove. Can be a Context object or a context ID string. + + :return: None + """ + context_id = context.id if isinstance(context, Context) else context + + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + + try: + response = await self._client.delete( + f"{self._jupyter_url}/contexts/{context_id}", + headers=headers, + timeout=self.connection_config.request_timeout, + ) + + err = await aextract_exception(response) + if err: + raise err + except httpx.TimeoutException: + raise format_request_timeout_error() + + async def list_code_contexts(self) -> List[Context]: + """ + List all contexts. + + :return: List of contexts. + """ + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + + try: + response = await self._client.get( + f"{self._jupyter_url}/contexts", + headers=headers, + timeout=self.connection_config.request_timeout, + ) + + err = await aextract_exception(response) + if err: + raise err + + data = response.json() + return [Context.from_json(context_data) for context_data in data] + except httpx.TimeoutException: + raise format_request_timeout_error() + + async def restart_code_context( + self, + context: Union[Context, str], + ) -> None: + """ + Restart a context. + + :param context: Context to restart. Can be a Context object or a context ID string. + + :return: None + """ + context_id = context.id if isinstance(context, Context) else context + + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + + try: + response = await self._client.post( + f"{self._jupyter_url}/contexts/{context_id}/restart", + headers=headers, + timeout=self.connection_config.request_timeout, + ) + + err = await aextract_exception(response) + if err: + raise err + except httpx.TimeoutException: + raise format_request_timeout_error() diff --git a/python/e2b_code_interpreter/code_interpreter_sync.py b/python/e2b_code_interpreter/code_interpreter_sync.py index 6cf56c11..67492398 100644 --- a/python/e2b_code_interpreter/code_interpreter_sync.py +++ b/python/e2b_code_interpreter/code_interpreter_sync.py @@ -1,7 +1,7 @@ import logging import httpx -from typing import Optional, Dict, overload, Literal, Union +from typing import Optional, Dict, overload, Literal, Union, List from httpx import Client from e2b import Sandbox as BaseSandbox, InvalidArgumentException @@ -270,3 +270,89 @@ def create_code_context( return Context.from_json(data) except httpx.TimeoutException: raise format_request_timeout_error() + + def remove_code_context( + self, + context: Union[Context, str], + ) -> None: + """ + Removes a context. + + :param context: Context to remove. Can be a Context object or a context ID string. + + :return: None + """ + context_id = context.id if isinstance(context, Context) else context + + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + + try: + response = self._client.delete( + f"{self._jupyter_url}/contexts/{context_id}", + headers=headers, + timeout=self.connection_config.request_timeout, + ) + + err = extract_exception(response) + if err: + raise err + except httpx.TimeoutException: + raise format_request_timeout_error() + + def list_code_contexts(self) -> List[Context]: + """ + List all contexts. + + :return: List of contexts. + """ + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + + try: + response = self._client.get( + f"{self._jupyter_url}/contexts", + headers=headers, + timeout=self.connection_config.request_timeout, + ) + + err = extract_exception(response) + if err: + raise err + + data = response.json() + return [Context.from_json(context_data) for context_data in data] + except httpx.TimeoutException: + raise format_request_timeout_error() + + def restart_code_context( + self, + context: Union[Context, str], + ) -> None: + """ + Restart a context. + + :param context: Context to restart. Can be a Context object or a context ID string. + + :return: None + """ + context_id = context.id if isinstance(context, Context) else context + + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + + try: + response = self._client.post( + f"{self._jupyter_url}/contexts/{context_id}/restart", + headers=headers, + timeout=self.connection_config.request_timeout, + ) + + err = extract_exception(response) + if err: + raise err + except httpx.TimeoutException: + raise format_request_timeout_error() diff --git a/python/tests/async/test_async_contexts.py b/python/tests/async/test_async_contexts.py new file mode 100644 index 00000000..6be07734 --- /dev/null +++ b/python/tests/async/test_async_contexts.py @@ -0,0 +1,62 @@ +from e2b_code_interpreter.code_interpreter_async import AsyncSandbox + + +async def test_create_context_with_no_options(async_sandbox: AsyncSandbox): + context = await async_sandbox.create_code_context() + + contexts = await async_sandbox.list_code_contexts() + last_context = contexts[-1] + + assert last_context.id == context.id + assert last_context.language == context.language + assert last_context.cwd == context.cwd + + +async def test_create_context_with_options(async_sandbox: AsyncSandbox): + context = await async_sandbox.create_code_context( + language="python", + cwd="/root", + ) + + contexts = await async_sandbox.list_code_contexts() + last_context = contexts[-1] + + assert last_context.id == context.id + assert last_context.language == context.language + assert last_context.cwd == context.cwd + + +async def test_remove_context(async_sandbox: AsyncSandbox): + context = await async_sandbox.create_code_context() + + await async_sandbox.remove_code_context(context.id) + + contexts = await async_sandbox.list_code_contexts() + assert context.id not in [ctx.id for ctx in contexts] + + +async def test_list_contexts(async_sandbox: AsyncSandbox): + contexts = await async_sandbox.list_code_contexts() + + # default contexts should include python and javascript + languages = [context.language for context in contexts] + assert "python" in languages + assert "javascript" in languages + + +async def test_restart_context(async_sandbox: AsyncSandbox): + context = await async_sandbox.create_code_context() + + # set a variable in the context + await async_sandbox.run_code("x = 1", context=context) + + # restart the context + await async_sandbox.restart_code_context(context.id) + + # check that the variable no longer exists + execution = await async_sandbox.run_code("x", context=context) + + # check for a NameError with message "name 'x' is not defined" + assert execution.error is not None + assert execution.error.name == "NameError" + assert execution.error.value == "name 'x' is not defined" diff --git a/python/tests/sync/test_contexts.py b/python/tests/sync/test_contexts.py new file mode 100644 index 00000000..a7cbd884 --- /dev/null +++ b/python/tests/sync/test_contexts.py @@ -0,0 +1,62 @@ +from e2b_code_interpreter.code_interpreter_sync import Sandbox + + +def test_create_context_with_no_options(sandbox: Sandbox): + context = sandbox.create_code_context() + + contexts = sandbox.list_code_contexts() + last_context = contexts[-1] + + assert last_context.id == context.id + assert last_context.language == context.language + assert last_context.cwd == context.cwd + + +def test_create_context_with_options(sandbox: Sandbox): + context = sandbox.create_code_context( + language="python", + cwd="/root", + ) + + contexts = sandbox.list_code_contexts() + last_context = contexts[-1] + + assert last_context.id == context.id + assert last_context.language == context.language + assert last_context.cwd == context.cwd + + +def test_remove_context(sandbox: Sandbox): + context = sandbox.create_code_context() + + sandbox.remove_code_context(context.id) + + contexts = sandbox.list_code_contexts() + assert context.id not in [ctx.id for ctx in contexts] + + +def test_list_contexts(sandbox: Sandbox): + contexts = sandbox.list_code_contexts() + + # default contexts should include python and javascript + languages = [context.language for context in contexts] + assert "python" in languages + assert "javascript" in languages + + +def test_restart_context(sandbox: Sandbox): + context = sandbox.create_code_context() + + # set a variable in the context + sandbox.run_code("x = 1", context=context) + + # restart the context + sandbox.restart_code_context(context.id) + + # check that the variable no longer exists + execution = sandbox.run_code("x", context=context) + + # check for a NameError with message "name 'x' is not defined" + assert execution.error is not None + assert execution.error.name == "NameError" + assert execution.error.value == "name 'x' is not defined" diff --git a/template/server/main.py b/template/server/main.py index ea89a9d8..1f296926 100644 --- a/template/server/main.py +++ b/template/server/main.py @@ -2,7 +2,7 @@ import sys import httpx -from typing import Dict, Union, Literal, Set +from typing import Dict, Union, Literal, List from contextlib import asynccontextmanager from fastapi import FastAPI, Request @@ -133,19 +133,18 @@ async def post_contexts(request: CreateContext) -> Context: @app.get("/contexts") -async def get_contexts() -> Set[Context]: +async def get_contexts() -> List[Context]: logger.info("Listing contexts") - context_ids = websockets.keys() - - return set( + return [ Context( - id=websockets[context_id].context_id, - language=websockets[context_id].language, - cwd=websockets[context_id].cwd, + id=ws.context_id, + language=ws.language, + cwd=ws.cwd, ) - for context_id in context_ids - ) + for key, ws in websockets.items() + if key != "default" + ] @app.post("/contexts/{context_id}/restart")