<a href="https://colab.research.google.com/github/its-emile/AISafetyResources/blob/main/Memory_safe_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# memory-safe (blind) agentic simulation

Hypothesis: an LLM can solve an agentic task without seeing any of the intermediate data passed between tool calls, while each tool enforces a policy on where its inputs may come from and where its outputs may go. The LLM only plans the data flow. It can neither read the data nor route it anywhere a tool's policy does not allow.

In this example simulation, the agent stores a log line without the LLM ever seeing the timestamp that gets stored.

In a real-life scenario, this would let an agent orchestrate calls to privileged tools without exposing the tools' authentication parameters or sensitive intermediate outputs to the agent itself, or to tools that should not receive them.
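
The policy described above is two-sided. A value may flow from a tool output to a tool input only if the producing tool lists the consumer as an allowed sink *and* the consuming tool lists the producer as an allowed source. Below is a minimal sketch of that single-edge check. It assumes the `in_params_allowed_sources` / `out_params_allowed_sinks` layout used by `TOOL_LIST` later in this notebook, and it assumes the call id in a reference such as `1:time` has already been resolved to its method name. `edge_allowed` is a hypothetical helper, not part of the notebook. Literals from `_` are treated as having no sink policy, and `result` as having no input policy.

```python
from typing import Any, Dict

def edge_allowed(tools: Dict[str, Any], source: str, sink: str) -> bool:
    """Return True if data may flow from `source` to `sink` ("method:param" strings)."""
    src_method, src_param = source.split(":")
    dst_method, dst_param = sink.split(":")
    # Producer side: the source tool must list this sink.
    # Literals from the "_" pseudo-method carry no sink policy.
    if src_method != "_":
        sinks = tools.get(src_method, {}).get("out_params_allowed_sinks", {})
        if sink not in sinks.get(src_param, []):
            return False
    # "result" is a pseudo-method with no input policy of its own.
    if dst_method == "result":
        return True
    # Consumer side: the sink tool must list this source for that input parameter.
    sources = tools.get(dst_method, {}).get("in_params_allowed_sources", {})
    return source in sources.get(dst_param, [])

# Same policies as TOOL_LIST further down.
TOOLS = {
    "get_time": {"in_params_allowed_sources": {},
                 "out_params_allowed_sinks": {"time": ["store_log:time"]}},
    "store_log": {"in_params_allowed_sources": {"time": ["get_time:time"], "event": ["_:event"]},
                  "out_params_allowed_sinks": {"log_lines": ["result:log_lines"]}},
}

print(edge_allowed(TOOLS, "get_time:time", "store_log:time"))  # True
print(edge_allowed(TOOLS, "get_time:time", "result:time"))     # False: get_time never lists result as a sink
print(edge_allowed(TOOLS, "_:event", "store_log:time"))        # False: store_log reads time only from get_time
```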

In [192]:
from google.colab import userdata
from google import genai
from abc import ABC, abstractmethod
from typing import Any

# Use Gemini for now since it has a free API.
# Ideally we would end up using two models from different labs.
GEMINI_CLIENT = genai.Client(api_key=userdata.get("GOOGLE_API_KEY"))


class Model(ABC):
    @abstractmethod
    def model_json(self, message: str, output_schema: dict) -> Any:
        """Return a response whose .text is JSON conforming to output_schema."""

class Gemini(Model):
    def __init__(self):
        self.client = GEMINI_CLIENT

    def model_json(self, message: str, output_schema: dict) -> Any:
        """Constrain generation to output_schema; returns the SDK response (JSON string in .text)."""
        return self.client.models.generate_content(
            model="gemini-2.0-flash",
            contents=message,
            config={
                "response_mime_type": "application/json",
                "response_schema": output_schema,
            },
        )
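
Since the agent only needs `model_json(...)` to return an object whose `.text` is a JSON plan, the execution logic can be exercised offline with a canned model. Below is a sketch under that assumption. `StubModel` and its plan are hypothetical, not Gemini output, and `types.SimpleNamespace` stands in for the SDK response object.

```python
import json
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any, List

class Model(ABC):
    @abstractmethod
    def model_json(self, message: str, output_schema: dict) -> Any:
        """Return a response whose .text is JSON conforming to output_schema."""

class StubModel(Model):
    """Hypothetical offline model: ignores the prompt and returns a fixed plan."""

    def __init__(self, plan: List[dict]):
        self.plan = plan
        self.prompts: List[str] = []  # record prompts so tests can inspect what the model "saw"

    def model_json(self, message: str, output_schema: dict) -> Any:
        self.prompts.append(message)
        # Mimic the response shape the agent reads via `.text`.
        return SimpleNamespace(text=json.dumps(self.plan))

plan = [{"method_name": "_", "call_id": 0,
         "parameter_dict": [{"param_name": "event", "param_source": "hello"}]}]
stub = StubModel(plan)
print(json.loads(stub.model_json("any prompt", {}).text)[0]["method_name"])  # _
```

Plugging it in assumes the agent obtains its plan through a replaceable `self.model` (e.g. `a.model = StubModel(plan)`), rather than constructing `Gemini()` inline.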

In [193]:
import time
import json
from typing import List, Dict, Any, Optional

AGENT_TEMPLATE="""
You are a highly capable agent, but operating in a blind environment. You have access to the tools below, but won't see the data I/O between them:
{tools}

To fulfill your task below, you will write an execution flow in json as a series of {{method_name, parameter_dict, call_id, out_params}}
Rule 1: You will not see the output data of each call but you can reference any output parameter name by formatting a variable name as call_id:out_param_name
Rule 2: For security reasons, the method "_" must be used for any literals; any other method calls will fail if a parameter_dict contains any literals: must set params as references to call_id:out_param_name
Rule 3: For security reasons, the last method must have "method_name":"result", "call_id":"result" and "parameter_dict" must contain all information the user may need regarding your task completion
Rule 4: The constraints in each tool above are hard constraints and must be respected: you must not use parameter references unless the tools above permit them with exact parameter names matching.
Rule 5: You must use integers for call_id and must start with method _ having this integer call_id: {initial_call_id}
Rule 6: You should only call _ once (group all the literals).

Your new task to fulfill is:
<usertask>
{task}
</usertask>

Now carefully design an execution flow for this task, and print it only as valid json-formatted as per the example, and don't forget to pass to each method call all the parameters it requires.
Note: You've already executed some prior calls below on the user's behalf. Continue the execution flow.
{prior_calls}
"""

SCHEMA_EXEC_FLOW = {
                "type": "array",
                "items": {
                  "type": "object",
                  "properties": {
                    "method_name": {
                      "type": "string",
                    },
                    "call_id": {
                      "type": "integer",
                    },
                    "out_params": {
                      "type": "array",
                      "items": {
                        "type": "string",
                      }
                    },
                    "parameter_dict": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "param_name": {
                            "type": "string",
                          },
                          "param_source": {
                            "type": "string",
                          },
                        },
                      }
                    }
                  },
                  "required": ["method_name","call_id"],
                },
              }

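class MemorySafeAgent:
  def __init__(self, tools: Dict[str, Any]):
    self.model = Gemini()
    self.tools = tools            # per-tool data-flow policies (see TOOL_LIST)
    self.call_results = {}        # call_id -> outputs; never shown to the model
    self.call_history_safe = {}   # call_id -> method name + references/literals; safe to show the model

  def process_task(self, task: str):
    instructions = AGENT_TEMPLATE.format(
          tools = self.tools,
          task = task,
          initial_call_id = len(self.call_results),
          prior_calls = self.call_history_safe)
    print("Model instructions ",instructions)
    generated_flow = self.model.model_json(
      instructions,
      output_schema = SCHEMA_EXEC_FLOW
    ).text
    print("Model generated flow:",generated_flow)
    self.current_execution_flow = json.loads(generated_flow)

    for step in self.current_execution_flow:
      if step["method_name"]=="_":
        # Literals: each param_source holds the literal value itself.
        self.call_results[step["call_id"]]={}
        for p in step["parameter_dict"]:
          self.call_results[step["call_id"]][p["param_name"]] = p["param_source"]
        self.call_history_safe[step["call_id"]]={"method_name": "_", "parameter_dict": step["parameter_dict"]}
      else:
        self.call_history_safe[step["call_id"]]={"method_name": step["method_name"], "parameter_dict": step["parameter_dict"]}

        if step["method_name"]=="result":
          # "result" is not a tool, so only the producer side of the policy applies:
          # each source tool must list result:<param_name> among its allowed sinks.
          for p in step["parameter_dict"]:
            source_call, source_param = p["param_source"].split(":")
            source_method = self.call_history_safe[int(source_call)]["method_name"]
            allowed = self.tools.get(source_method, {}).get("out_params_allowed_sinks", {}).get(source_param, [])
            if source_method != "_" and f'result:{p["param_name"]}' not in allowed:
              raise ValueError(f"{source_method} does not authorize result to read {source_param}")
          self.call_results[step["call_id"]] = self.memory_fetch_unsafe(step["parameter_dict"])
          return self.call_results[step["call_id"]]
        else:
          self.call_results[step["call_id"]] = self.method_call(step["method_name"], step["parameter_dict"])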


  def method_call(self, method_name: str, params: List[Dict[str, Any]]):
    if method_name not in self.tools:
      raise ValueError(f"Unknown method name: {method_name}")

    param_values = self.memory_fetch_safe(method_name, params)

    print(f"Calling {method_name} with {param_values}")

    method = getattr(self, method_name)
    res = method(**param_values)

    print(f"Result: {res}")
    return res

  def memory_fetch_safe(self, method_name: str, param_dict: List[Dict[str, Any]]) -> Dict[str, Any]:
    for p in param_dict:
      # verify that both the source (producer) and this method (consumer) allow this data flow
      source_call, source_param = p['param_source'].split(":")
      if int(source_call) not in self.call_history_safe:
        raise ValueError(f"Unknown source call id: {source_call}")
      source_method = self.call_history_safe[int(source_call)]["method_name"]
      print(f"verifying param ({p['param_name']}) access policy for {method_name} (source: {source_method}:{source_param})")

      # producer side: literals from "_" carry no sink policy
      allowed_sinks = self.tools.get(source_method, {}).get("out_params_allowed_sinks", {}).get(source_param, [])
      if source_method != "_" and f'{method_name}:{p["param_name"]}' not in allowed_sinks:
        raise ValueError(f"{source_method} does not authorize {method_name} to read {source_param}")

      # consumer side: this method must accept this source for this input parameter
      allowed_sources = self.tools[method_name]["in_params_allowed_sources"].get(p['param_name'], [])
      if f'{source_method}:{source_param}' not in allowed_sources:
        raise ValueError(f"{method_name} does not authorize reading {source_param} from {source_method}")

    # params are permitted by the specified method and their source.
    return self.memory_fetch_unsafe(param_dict)

  def memory_fetch_unsafe(self, param_dict: List[Dict[str, Any]]) -> Dict[str, Any]:
    # resolves call_id:out_param references to values; NOT policy-checked, callers must validate first
    params={}
    for p in param_dict:
      print(f'loading {p["param_name"]} = {p["param_source"]}')

      p_source, p_name = p["param_source"].split(":")

      if int(p_source) not in self.call_results:
        raise ValueError(f"Unknown source call id: {p_source}")

      if p_name not in self.call_results[int(p_source)]:
        raise ValueError(f"Unknown output from {self.call_history_safe[int(p_source)]['method_name']}: {p_name}")

      value = self.call_results[int(p_source)][p_name]
      params[p["param_name"]] = value
      print(f'loaded {p["param_name"]} from {p["param_source"]} = {value}')
    return params

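  # simulated tools for testing purposes
  def get_time(self) -> Dict[str, str]:
    fmt = "%Y-%m-%d %H:%M:%S"
    return {"time": time.strftime(fmt)}

  def store_log(self, time: str, event: str | None = None) -> Dict[str, int]:
    # `time` shadows the time module here; the name must match the "time" input in TOOL_LIST.
    # Simulated: prints instead of persisting, and returns how many store_log calls are recorded (including this one).
    print(f"storing log {time}: {event}")
    return {"log_lines": len([k for k in self.call_history_safe if self.call_history_safe[k]["method_name"] == "store_log"])}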
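
The agent relies on the model to follow Rules 2, 3, 5 and 6 of `AGENT_TEMPLATE`, and nothing above rejects a malformed plan before tools start running. Below is a sketch of a hypothetical pre-execution check, `validate_flow` (not part of the notebook), that enforces those rules on a parsed plan. It assumes the `SCHEMA_EXEC_FLOW` shape, where `parameter_dict` is a list of `{param_name, param_source}` objects. Rule 3 asks for `"call_id": "result"`, but Rule 5 and the schema require integer call ids, so this sketch only checks `method_name` for the last step. The example plan is illustrative, not recorded model output.

```python
from typing import Any, Dict, List

def validate_flow(flow: List[Dict[str, Any]], initial_call_id: int) -> List[str]:
    """Return a list of rule violations for a parsed execution flow (empty if none)."""
    if not flow:
        return ["empty flow"]
    errors: List[str] = []
    # Rule 5: start with "_" at the expected integer call_id.
    if flow[0].get("method_name") != "_" or flow[0].get("call_id") != initial_call_id:
        errors.append(f'flow must start with "_" at call_id {initial_call_id}')
    # Rule 6: a single literal block.
    if sum(1 for s in flow if s.get("method_name") == "_") > 1:
        errors.append('"_" is called more than once')
    # Rule 3: finish with the result pseudo-method.
    if flow[-1].get("method_name") != "result":
        errors.append('last step must be "result"')
    seen = set(range(initial_call_id))  # ids from earlier tasks stay referenceable
    for step in flow:
        cid = step.get("call_id")
        if not isinstance(cid, int) or cid in seen:
            errors.append(f"bad or duplicate call_id: {cid!r}")
        if step.get("method_name") != "_":
            # Rule 2: every non-literal parameter must reference an earlier call.
            for p in step.get("parameter_dict", []):
                ref = str(p.get("param_source", ""))
                head, sep, name = ref.partition(":")
                if not (sep and name and head.isdigit() and int(head) in seen):
                    errors.append(f"step {cid}: {p.get('param_name')} is not a reference to an earlier call: {ref!r}")
        if isinstance(cid, int):
            seen.add(cid)
    return errors

flow = [
    {"method_name": "_", "call_id": 0,
     "parameter_dict": [{"param_name": "event", "param_source": "we just discovered nuclear fusion"}]},
    {"method_name": "get_time", "call_id": 1, "parameter_dict": []},
    {"method_name": "store_log", "call_id": 2,
     "parameter_dict": [{"param_name": "time", "param_source": "1:time"},
                        {"param_name": "event", "param_source": "0:event"}]},
    {"method_name": "result", "call_id": 3,
     "parameter_dict": [{"param_name": "log_lines", "param_source": "2:log_lines"}]},
]
print(validate_flow(flow, initial_call_id=0))  # []

# A plan that passes a literal timestamp straight into store_log and never returns a result.
bad = flow[:2] + [{"method_name": "store_log", "call_id": 2,
                   "parameter_dict": [{"param_name": "time", "param_source": "2025-01-01 00:00:00"}]}]
print(validate_flow(bad, initial_call_id=0))
```

Running this check before the execution loop would turn a rule violation into a rejected plan instead of a partially executed one.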



In [194]:
# EXAMPLE USAGE:

# TOOL_LIST must specify for each tool:
# - in_params_allowed_sources: where each input parameter may be read from
#   (e.g. "_:<name>" if literals written directly by the LLM are allowed)
# - out_params_allowed_sinks: which method:param inputs may read each output parameter

TOOL_LIST = {
    "get_time": {
        "in_params_allowed_sources": {
        },
        "out_params_allowed_sinks": {
            "time": ["store_log:time"]
        }
    },
    "store_log": {
        "in_params_allowed_sources": {
            "time": ["get_time:time"],
            "event": ["_:event"]
        },
        "out_params_allowed_sinks": {
            "log_lines": ["result:log_lines"]
        }
    }
}

a = MemorySafeAgent(TOOL_LIST)
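
Each permitted edge is declared twice in `TOOL_LIST`: once in the producer's `out_params_allowed_sinks` and once in the consumer's `in_params_allowed_sources`. `memory_fetch_safe` requires both declarations. Below is a hypothetical lint, `one_sided_edges` (not part of the notebook), that lists edges declared on only one side. Without such a check, those edges fail only at call time. The lint treats `_` (literals) and `result` as pseudo-methods whose edges can be declared on the real tool's side only.

```python
import copy
from typing import Any, Dict, List, Set, Tuple

PSEUDO = {"_", "result"}  # pseudo-methods without their own policy entries

def one_sided_edges(tools: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return sorted (source, sink) pairs declared by only one of the two tools."""
    from_sinks: Set[Tuple[str, str]] = set()
    from_sources: Set[Tuple[str, str]] = set()
    for method, spec in tools.items():
        for out_param, sinks in spec.get("out_params_allowed_sinks", {}).items():
            for sink in sinks:
                from_sinks.add((f"{method}:{out_param}", sink))
        for in_param, sources in spec.get("in_params_allowed_sources", {}).items():
            for source in sources:
                from_sources.add((source, f"{method}:{in_param}"))
    mismatched = []
    for source, sink in from_sinks ^ from_sources:  # edges present on one side only
        if source.split(":")[0] in PSEUDO or sink.split(":")[0] in PSEUDO:
            continue
        mismatched.append((source, sink))
    return sorted(mismatched)

TOOLS = {
    "get_time": {"in_params_allowed_sources": {},
                 "out_params_allowed_sinks": {"time": ["store_log:time"]}},
    "store_log": {"in_params_allowed_sources": {"time": ["get_time:time"], "event": ["_:event"]},
                  "out_params_allowed_sinks": {"log_lines": ["result:log_lines"]}},
}
print(one_sided_edges(TOOLS))  # []

# Hypothetical misconfiguration: store_log no longer accepts get_time:time.
broken = copy.deepcopy(TOOLS)
broken["store_log"]["in_params_allowed_sources"]["time"] = []
print(one_sided_edges(broken))  # [('get_time:time', 'store_log:time')]
```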

In [195]:
a.process_task("log this event: we just discovered nuclear fusion. then tell me how large the log is")

Model instructions  
You are a highly capable agent, but operating in a blind environment. You have access to the tools below, but won't see the data I/O between them:
{'get_time': {'in_params_allowed_sources': {}, 'out_params_allowed_sinks': {'time': ['store_log:time']}}, 'store_log': {'in_params_allowed_sources': {'time': ['get_time:time'], 'event': ['_:event']}, 'out_params_allowed_sinks': {'log_lines': ['result:log_lines']}}}

To fulfill your task below, you will write an execution flow in json as a series of {method_name, parameter_dict, call_id, out_params}
Rule 1: You will not see the output data of each call but you can reference any output parameter name by formatting a variable name as call_id:out_param_name
Rule 2: For security reasons, the method "_" must be used for any literals; any other method calls will fail if a parameter_dict contains any literals: must set params as references to call_id:out_param_name
Rule 3: For security reasons, the last method must have "method_

{'log_lines': 1}

In [196]:
a.process_task("now log this event: turns out it was just the melting of ice. then tell me how large the log is")

{'log_lines': 2}
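
The blindness claim can be checked directly: after a run, no tool output should appear anywhere in the text given to the model. Below is a sketch that assumes you can capture the prompt string (for example the `instructions` that `process_task` prints) along with the agent's `call_results` and `call_history_safe`. `find_leaks` is hypothetical. It uses plain substring matching, so it skips very short values such as the integer `1`, which would otherwise match call ids by coincidence. The timestamp below is made up.

```python
from typing import Any, Dict, List, Tuple

def find_leaks(prompt: str, call_results: Dict[int, Dict[str, Any]],
               call_history: Dict[int, Dict[str, Any]], min_len: int = 4) -> List[Tuple[int, str]]:
    """Return (call_id, out_param) pairs whose tool output text appears in `prompt`."""
    leaks = []
    for call_id, outputs in call_results.items():
        if call_history.get(call_id, {}).get("method_name") == "_":
            continue  # literals were written by the model itself
        for name, value in (outputs or {}).items():
            text = str(value)
            # Skip short values (e.g. small integers) that would match by coincidence.
            if len(text) >= min_len and text in prompt:
                leaks.append((call_id, name))
    return leaks

history = {
    0: {"method_name": "_", "parameter_dict": [{"param_name": "event", "param_source": "we just discovered nuclear fusion"}]},
    1: {"method_name": "get_time", "parameter_dict": []},
    2: {"method_name": "store_log", "parameter_dict": [{"param_name": "time", "param_source": "1:time"},
                                                       {"param_name": "event", "param_source": "0:event"}]},
    3: {"method_name": "result", "parameter_dict": [{"param_name": "log_lines", "param_source": "2:log_lines"}]},
}
results = {
    0: {"event": "we just discovered nuclear fusion"},
    1: {"time": "2025-01-01 12:00:00"},  # made-up timestamp
    2: {"log_lines": 1},
    3: {"log_lines": 1},
}
safe_prompt = f"prior calls: {history}"  # what the agent feeds back as prior_calls
leaky_prompt = safe_prompt + " (debug: time was 2025-01-01 12:00:00)"
print(find_leaks(safe_prompt, results, history))   # []
print(find_leaks(leaky_prompt, results, history))  # [(1, 'time')]
```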