# Agent AI State Management — Workshop Demo

This notebook walks through three context management strategies, progressing from naive to sophisticated:

1. **Phase 1: Sliding Window** — The problem: lost context
2. **Phase 2: Summarization** — A partial fix: compressed history
3. **Phase 3: Temperature-Scored Tiered Memory** — The full solution: self-organizing memory

By the end, you'll see side by side how each strategy handles a 25-turn conversation, and why tiered memory with temperature scoring recalls early details that the other two strategies lose.

## Setup & Dependencies

Install required packages. We use `sentence-transformers` for local embeddings (no API key needed for the embedding step) and `openai` for the LLM calls.

In [None]:
!pip install -q numpy sentence-transformers tiktoken openai matplotlib

In [None]:
import os
import json
import numpy as np
import tiktoken
import matplotlib.pyplot as plt
from math import exp
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from openai import OpenAI

# --- Configuration ---
# Set your OpenAI API key (or use any compatible provider)
# os.environ["OPENAI_API_KEY"] = "sk-..."

LLM_MODEL = "gpt-4o-mini"  # Use a fast, cheap model for the demo
EMBED_MODEL = "all-MiniLM-L6-v2"  # Local embedding model, no API key needed

client = OpenAI()
embedder = SentenceTransformer(EMBED_MODEL)
# gpt-4o-mini uses the o200k_base encoding (gpt-4 uses cl100k_base)
tokenizer = tiktoken.get_encoding("o200k_base")

print(f"LLM: {LLM_MODEL}")
print(f"Embeddings: {EMBED_MODEL} (local)")
print("Ready.")

## Helper Functions

Shared utilities used by all three phases.

In [None]:
def count_tokens(messages: list[dict]) -> int:
    """Count tokens in a list of chat messages."""
    total = 0
    for m in messages:
        total += len(tokenizer.encode(m.get("content", "")))
        total += 4  # role + formatting overhead per message
    return total


def llm_call(messages: list[dict]) -> str:
    """Make an LLM call and return the response text."""
    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=300,
        temperature=0.3,
    )
    # content can be None (e.g. on a refusal), so fall back to an empty string
    return response.choices[0].message.content or ""


def embed(text: str) -> list[float]:
    """Generate embedding using local model."""
    return embedder.encode(text).tolist()


def cosine_similarity(a, b):
    """Compute cosine similarity between two vectors."""
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))


def extract_keywords(text: str) -> list[str]:
    """Simple keyword extraction — split and filter short/stop words."""
    stop = {"i", "we", "the", "a", "an", "is", "are", "to", "do", "it",
            "my", "our", "and", "or", "of", "on", "in", "at", "for", "with",
            "have", "has", "that", "this", "what", "how", "why", "did", "say"}
    words = text.lower().replace("?", "").replace(".", "").replace(",", "").split()
    return [w for w in words if w not in stop and len(w) > 2]


print("Helpers loaded.")
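
To get a feel for what the keyword helper produces, here is a small standalone sketch that mirrors the `extract_keywords` logic above and applies it to one recall question and the turn that answers it. The `keywords` and `overlap` names are stand-ins for illustration, not part of the notebook; `overlap` follows the same "fraction of query keywords found in the memory" idea that Phase 3 uses for its entity-overlap signal.

```python
# Stand-in for extract_keywords above: same stop list and filters, copied here
# so the snippet runs on its own.
STOP = {"i", "we", "the", "a", "an", "is", "are", "to", "do", "it",
        "my", "our", "and", "or", "of", "on", "in", "at", "for", "with",
        "have", "has", "that", "this", "what", "how", "why", "did", "say"}


def keywords(text: str) -> list[str]:
    words = text.lower().replace("?", "").replace(".", "").replace(",", "").split()
    return [w for w in words if w not in STOP and len(w) > 2]


def overlap(query: str, memory: str) -> float:
    """Fraction of query keywords that also appear in the memory."""
    q, m = set(keywords(query)), set(keywords(memory))
    return len(q & m) / len(q) if q and m else 0.0


question = "What's our current DB connection pool size?"
fact = "Our pool size is set to 5 connections."

print(keywords(question))  # "what's" survives; "db" is dropped (too short)
print(keywords(fact))      # the number "5" is dropped (too short)
print(overlap(question, fact))
```

Only `pool` and `size` match, so the overlap is 2 of 5 query keywords (0.4). There is no stemming, so `connection` and `connections` count as different words, and short tokens such as `5` or `db` never become keywords. Keep that in mind when reading the entity-overlap scores later.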

## The Scenario

A 25-turn DevOps conversation where a user progressively describes a production issue. Early turns establish key facts that are referenced by recall questions at the end.

**Turns 0-19**: Build up context about the system, the problem, and the environment.  
**Turns 20-24**: Ask recall questions that require information from early turns.

In [None]:
scenario = [
    # --- Context building (turns 0-19) ---
    "Hi, I'm Alex. I'm working on a Python microservice.",                  # 0
    "We deploy to AWS using ECS.",                                           # 1
    "The service handles payment processing with Stripe.",                   # 2
    "We're seeing timeouts on the /checkout endpoint.",                      # 3
    "The database is PostgreSQL on RDS, port 5432.",                         # 4
    "Our CI/CD pipeline uses GitHub Actions.",                               # 5
    "The team uses Docker with multi-stage builds.",                         # 6
    "We recently added a Redis cache for sessions.",                         # 7
    "The monitoring stack is Prometheus + Grafana.",                         # 8
    "We suspect the Stripe webhook is causing the timeout.",                 # 9
    "The webhook handler does a synchronous DB write.",                      # 10
    "We tried increasing the timeout to 30s but it didn't help.",           # 11
    "The error logs show connection pool exhaustion.",                       # 12
    "Our pool size is set to 5 connections.",                                # 13
    "Traffic spikes happen during lunch hours, 11am-1pm.",                  # 14
    "We also have a batch job that runs at noon.",                           # 15
    "The batch job processes refunds from the previous day.",                # 16
    "It opens 3 long-running DB connections.",                               # 17
    "So during peak, we have 3 batch + N request connections.",              # 18
    "We're considering PgBouncer for connection pooling.",                   # 19
    # --- Recall questions (turns 20-24) ---
    "What cloud provider and service are we using?",                         # 20 → AWS ECS
    "What's our current DB connection pool size?",                           # 21 → 5
    "What payment processor do we use?",                                     # 22 → Stripe
    "Why do we think the timeouts happen at lunch?",                         # 23 → batch + traffic
    "Summarize the root cause and proposed solution.",                       # 24 → pool exhaustion + PgBouncer
]

# Ground-truth answers for evaluation
expected_answers = {
    20: "AWS ECS",
    21: "5 connections",
    22: "Stripe",
    23: "batch job (3 connections) + lunch traffic spike causes pool exhaustion",
    24: "connection pool exhaustion (5 pool + 3 batch at peak); proposed fix: PgBouncer",
}

RECALL_START = 20
print(f"Scenario: {len(scenario)} turns ({RECALL_START} context + {len(scenario) - RECALL_START} recall questions)")

---

# Phase 1: Sliding Window

The simplest approach — keep only the last N turns in the prompt. Everything older is discarded.

**What to watch for**: When the recall questions arrive (turns 20-24), the sliding window will have lost all early facts.

In [None]:
class SlidingWindowAgent:
    """Keeps the last `window_size` turns as context."""

    def __init__(self, window_size: int = 5):
        self.history: list[dict] = []
        self.window_size = window_size
        self.token_log: list[int] = []
        self.responses: list[str] = []

    def chat(self, user_message: str) -> tuple[str, int]:
        self.history.append({"role": "user", "content": user_message})

        # Only send the last N turns: N * 2 messages, counting assistant replies
        window = self.history[-self.window_size * 2:]
        tokens_used = count_tokens(window)
        self.token_log.append(tokens_used)

        response = llm_call(messages=window)
        self.history.append({"role": "assistant", "content": response})
        self.responses.append(response)
        return response, tokens_used


print("SlidingWindowAgent defined (window_size=5)")

In [None]:
# Run Phase 1
agent_sw = SlidingWindowAgent(window_size=5)

print("=" * 60)
print("PHASE 1: SLIDING WINDOW")
print("=" * 60)

for i, turn in enumerate(scenario):
    response, tokens = agent_sw.chat(turn)
    marker = " <<< RECALL" if i >= RECALL_START else ""
    print(f"\nTurn {i:2d} [{tokens:4d} tokens]{marker}")
    print(f"  User: {turn}")
    if i >= RECALL_START:
        print(f"  Agent: {response[:200]}")
        print(f"  Expected: {expected_answers[i]}")

### Phase 1 Observation

Notice how the agent **cannot answer** the recall questions because the relevant facts (from turns 1-4, 13, etc.) fell outside the 5-turn window. The agent may guess or hallucinate.

**Problem**: Fixed window = fixed amnesia.
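
The amnesia can be checked without calling the LLM. The sketch below replays only the history bookkeeping of `SlidingWindowAgent.chat` (the `visible_user_turns` helper is a hypothetical name introduced here) and lists which user turns are still in the prompt at a given turn.

```python
def visible_user_turns(turn: int, window_size: int = 5) -> list[int]:
    """Replay the history bookkeeping and list the user turns in the prompt."""
    history: list[tuple[str, int]] = []
    for t in range(turn + 1):
        history.append(("user", t))
        window = history[-window_size * 2:]   # same slice as SlidingWindowAgent.chat
        history.append(("assistant", t))      # placeholder for the LLM reply
    return [t for role, t in window if role == "user"]


print(visible_user_turns(20))        # user turns in the prompt for the first recall question
print(13 in visible_user_turns(21))  # is the pool-size fact (turn 13) still visible?
```

At turn 20 the prompt holds user turns 16 through 20 only, so the AWS/ECS fact from turn 1 is gone, and at turn 21 the pool-size fact from turn 13 is out of reach as well.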

---

# Phase 2: Summarization

Keep a rolling summary of older turns plus a small recent window. When history grows too long, summarize and compress.

**What to watch for**: Better recall than sliding window, but summaries are lossy — specific details (exact numbers, port values) may be lost.

In [None]:
class SummarizationAgent:
    """Compresses older turns into a rolling summary."""

    def __init__(self, window_size: int = 3):
        self.history: list[dict] = []
        self.summary: str = ""
        self.window_size = window_size
        self.token_log: list[int] = []
        self.responses: list[str] = []

    def _summarize(self, old_turns: list[dict]) -> str:
        prompt = f"Previous summary:\n{self.summary}\n\nNew conversation turns:\n"
        for t in old_turns:
            prompt += f"{t['role']}: {t['content']}\n"
        prompt += ("\nWrite a concise summary preserving ALL key facts "
                   "(names, numbers, services, ports, timestamps). "
                   "Keep it under 200 words:")
        return llm_call(messages=[{"role": "user", "content": prompt}])

    def chat(self, user_message: str) -> tuple[str, int]:
        self.history.append({"role": "user", "content": user_message})

        # Summarize when history exceeds threshold
        if len(self.history) > self.window_size * 2:
            old_turns = self.history[:-self.window_size]
            self.summary = self._summarize(old_turns)
            self.history = self.history[-self.window_size:]

        messages = []
        if self.summary:
            messages.append({
                "role": "system",
                "content": f"Summary of earlier conversation:\n{self.summary}"
            })
        messages.extend(self.history[-self.window_size * 2:])

        tokens_used = count_tokens(messages)
        self.token_log.append(tokens_used)

        response = llm_call(messages=messages)
        self.history.append({"role": "assistant", "content": response})
        self.responses.append(response)
        return response, tokens_used


print("SummarizationAgent defined (window_size=3)")

In [None]:
# Run Phase 2
agent_sum = SummarizationAgent(window_size=3)

print("=" * 60)
print("PHASE 2: SUMMARIZATION")
print("=" * 60)

for i, turn in enumerate(scenario):
    response, tokens = agent_sum.chat(turn)
    marker = " <<< RECALL" if i >= RECALL_START else ""
    print(f"\nTurn {i:2d} [{tokens:4d} tokens]{marker}")
    print(f"  User: {turn}")
    if i >= RECALL_START:
        print(f"  Agent: {response[:200]}")
        print(f"  Expected: {expected_answers[i]}")

print(f"\nFinal summary ({len(agent_sum.summary)} chars):")
print(agent_sum.summary)

### Phase 2 Observation

Recall usually improves: the summary tends to retain key facts like "AWS ECS" and "Stripe". But notice:

- **Lossy**: Specific numbers (pool size = 5, port 5432) may be compressed away
- **No prioritization**: The summary treats all facts equally
- **Extra LLM calls**: Every summarization step costs tokens and latency, and `token_log` does not count them (it only records the chat prompt)

**Improvement needed**: What if we could automatically prioritize the facts that matter most?
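
To put a number on the extra LLM calls, the following sketch replays the trimming logic of `SummarizationAgent.chat` with placeholder strings instead of real messages. The `summarization_turns` helper is a hypothetical name used only for this illustration.

```python
def summarization_turns(num_turns: int = 25, window_size: int = 3) -> list[int]:
    """Replay SummarizationAgent's trimming logic and record when it summarizes."""
    history: list[str] = []
    triggered = []
    for t in range(num_turns):
        history.append(f"user {t}")
        if len(history) > window_size * 2:      # same threshold as chat()
            triggered.append(t)                 # _summarize() would be called here
            history = history[-window_size:]
        history.append(f"assistant {t}")        # placeholder for the LLM reply
    return triggered


calls = summarization_turns()
print(calls)
print(len(calls), "extra LLM calls on top of the 25 chat calls")
```

With `window_size=3` the agent summarizes on every second turn from turn 3 onward, which adds 11 summarization calls across the 25-turn scenario.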

---

# Phase 3: Temperature-Scored Tiered Memory

Every turn is stored as a memory with an embedding. On each query, we score all memories using a **temperature formula** that accounts for recency, relevance, frequency, entity overlap, and agent match. Memories are organized into tiers:

| Tier | Temperature | Behavior |
|------|------------|----------|
| **HOT** | >= 0.70 | Injected into context, fast access |
| **WARM** | 0.50 - 0.70 | Available for retrieval |
| **COLD** | < 0.50 | Archived, rarely accessed |

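Memories are re-scored on every query and automatically **promote** (cold -> warm -> hot) or **demote** (hot -> warm -> cold) as their temperature crosses these thresholds. A memory can skip a tier in a single step.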

### Step 1: Memory Data Model

In [None]:
@dataclass
class Memory:
    """A single memory unit with temperature and tier tracking."""
    id: str
    text: str
    embedding: list[float]
    memory_type: str              # SEMANTIC, EPISODIC, PROCEDURAL
    temperature: float = 0.0      # unscored until the first recall(); matches the "cold" default tier
    tier: str = "cold"            # hot, warm, cold
    access_count: int = 0
    last_accessed: datetime = field(default_factory=datetime.now)
    created_at: datetime = field(default_factory=datetime.now)
    entities: list[str] = field(default_factory=list)


print("Memory dataclass defined.")
print(f"Fields: {[f.name for f in Memory.__dataclass_fields__.values()]}")

### Step 2: Temperature Scoring Formula

The temperature of a memory is a weighted combination of five signals:

```
temperature = 0.30 * recency         (how recently was it accessed?)
            + 0.25 * relevance        (cosine similarity to current query)
            + 0.20 * frequency         (how often has it been accessed?)
            + 0.15 * entity_overlap    (do query keywords match memory keywords?)
            + 0.10 * agent_match       (was it created by the same agent?)
```

In [None]:
WEIGHTS = {
    "recency": 0.30,
    "relevance": 0.25,
    "frequency": 0.20,
    "entity_overlap": 0.15,
    "agent_match": 0.10,
}


def compute_temperature(
    memory: Memory,
    query_embedding: list[float],
    query_entities: list[str],
    max_access: int = 50,
) -> float:
    """Compute the temperature score for a memory given a query."""

    # Recency: exponential decay based on hours since last access.
    # The demo finishes within minutes, so this stays close to 1.0 here.
    hours_ago = (datetime.now() - memory.last_accessed).total_seconds() / 3600
    recency = exp(-0.1 * hours_ago)

    # Relevance: cosine similarity between query and memory embeddings
    relevance = cosine_similarity(query_embedding, memory.embedding)

    # Frequency: normalized access count
    frequency = min(memory.access_count / max_access, 1.0)

    # Entity overlap: fraction of query keywords found in memory
    if query_entities and memory.entities:
        overlap = len(set(query_entities) & set(memory.entities)) / len(set(query_entities))
    else:
        overlap = 0.0

    # Agent match: always 1.0 in this single-agent demo
    agent = 1.0

    temp = (
        WEIGHTS["recency"] * recency
        + WEIGHTS["relevance"] * relevance
        + WEIGHTS["frequency"] * frequency
        + WEIGHTS["entity_overlap"] * overlap
        + WEIGHTS["agent_match"] * agent
    )
    return round(temp, 4)


print("Temperature formula defined.")
print(f"Weights: {WEIGHTS}")
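
Plugging a few values into the formula shows how the weights interact with the tier thresholds. This is a standalone sketch: the `temperature` helper is a hypothetical simplification that takes the five signals directly instead of a `Memory` object, and the input values are made up for illustration.

```python
from math import exp, log

# Weights copied from the cell above so the snippet runs on its own.
WEIGHTS = {"recency": 0.30, "relevance": 0.25, "frequency": 0.20,
           "entity_overlap": 0.15, "agent_match": 0.10}


def temperature(recency, relevance, frequency, overlap, agent=1.0):
    """Weighted sum of the five signals, rounded like compute_temperature."""
    signals = {"recency": recency, "relevance": relevance, "frequency": frequency,
               "entity_overlap": overlap, "agent_match": agent}
    return round(sum(WEIGHTS[name] * value for name, value in signals.items()), 4)


# Just-accessed memory, unrelated to the query: only recency and agent_match count.
floor = temperature(recency=1.0, relevance=0.0, frequency=0.0, overlap=0.0)

# Same memory with a fairly similar embedding but no shared keywords.
similar = temperature(recency=1.0, relevance=0.6, frequency=0.0, overlap=0.0)

# A turn scored against itself: identical embedding and full keyword overlap.
self_match = temperature(recency=1.0, relevance=1.0, frequency=0.0, overlap=1.0)

# Recency decay from compute_temperature: exp(-0.1 * hours).
half_life_hours = log(2) / 0.1
day_old = exp(-0.1 * 24)

print(floor, similar, self_match)
print(round(half_life_hours, 2), round(day_old, 3))
```

The three scores come out as 0.40, 0.55, and 0.80. Because recency is near 1.0 and agent match is always 1.0 in this demo, every memory starts around 0.40, so the tiers are decided by the remaining 0.60 of weight. Relevance alone contributes at most 0.25, which is enough for WARM but not for HOT (0.70); a memory needs keyword overlap or accumulated accesses to get there. The decay constant gives recency a half-life of roughly 6.9 hours, and a memory untouched for a day retains about 9% of its recency score.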

### Step 3: Tier Promotion & Demotion

After scoring, memories move between tiers:
- **temp >= 0.70** -> promote to HOT
- **0.50 <= temp < 0.70** -> WARM
- **temp < 0.50** -> demote to COLD

In [None]:
TIER_ORDER = ["cold", "warm", "hot"]


def update_tiers(memories: list[Memory]) -> tuple[list, list]:
    """Update tier assignments and return lists of promotions and demotions."""
    promotions, demotions = [], []

    for m in memories:
        old_tier = m.tier
        if m.temperature >= 0.70:
            m.tier = "hot"
        elif m.temperature >= 0.50:
            m.tier = "warm"
        else:
            m.tier = "cold"

        if m.tier != old_tier:
            change = (m.id, m.text[:50], old_tier, m.tier, m.temperature)
            if TIER_ORDER.index(m.tier) > TIER_ORDER.index(old_tier):
                promotions.append(change)
            else:
                demotions.append(change)

    return promotions, demotions


print("Tier promotion/demotion logic defined.")

### Step 4: Tiered Memory Agent

This agent stores every turn as a memory, scores all memories on each query, promotes/demotes tiers, and injects only the top-k highest-temperature memories into the LLM prompt, whichever tier they fall in.

In [None]:
class TieredMemoryAgent:
    """Agent with temperature-scored tiered memory."""

    def __init__(self, embed_fn, top_k: int = 10):
        self.memories: list[Memory] = []
        self.embed_fn = embed_fn
        self.top_k = top_k
        self.token_log: list[int] = []
        self.responses: list[str] = []
        self.tier_history: list[dict] = []

    def store(self, text: str, entities: list[str], memory_type: str = "EPISODIC"):
        """Store a new memory in the cold tier."""
        emb = self.embed_fn(text)
        mem = Memory(
            id=f"mem_{len(self.memories):03d}",
            text=text,
            embedding=emb,
            memory_type=memory_type,
            entities=entities,
        )
        self.memories.append(mem)

    def recall(self, query: str, query_entities: list[str]):
        """Score memories, update tiers, return top-k."""
        query_emb = self.embed_fn(query)

        # Score all memories
        for m in self.memories:
            m.temperature = compute_temperature(m, query_emb, query_entities)

        # Update tiers
        promotions, demotions = update_tiers(self.memories)
        self.tier_history.append({
            "promotions": promotions,
            "demotions": demotions,
            "tier_counts": self._tier_counts(),
        })

        # Bump access stats for every hot or warm memory (not just the top-k returned)
        for m in self.memories:
            if m.tier in ("hot", "warm"):
                m.access_count += 1
                m.last_accessed = datetime.now()

        ranked = sorted(self.memories, key=lambda m: m.temperature, reverse=True)
        return ranked[:self.top_k], promotions, demotions

    def chat(self, user_message: str) -> tuple[str, int, list, list]:
        """Recall relevant memories, store the turn, generate response."""
        entities = extract_keywords(user_message)

        # Recall before storing: otherwise the new turn matches itself
        # (relevance 1.0, full entity overlap) and takes a top-k slot.
        results, promotions, demotions = self.recall(user_message, entities)
        self.store(user_message, entities)

        # Build context from top memories with tier labels
        context_lines = []
        for m in results:
            context_lines.append(f"[{m.tier.upper()} t={m.temperature:.2f}] {m.text}")
        context = "\n".join(context_lines)

        messages = [
            {"role": "system", "content": f"Relevant memories (sorted by importance):\n{context}"},
            {"role": "user", "content": user_message},
        ]

        tokens_used = count_tokens(messages)
        self.token_log.append(tokens_used)

        response = llm_call(messages=messages)
        self.responses.append(response)
        return response, tokens_used, promotions, demotions

    def _tier_counts(self) -> dict:
        counts = {"hot": 0, "warm": 0, "cold": 0}
        for m in self.memories:
            counts[m.tier] += 1
        return counts


print("TieredMemoryAgent defined (top_k=10)")

In [None]:
# Run Phase 3
agent_tm = TieredMemoryAgent(embed_fn=embed, top_k=10)

print("=" * 60)
print("PHASE 3: TIERED MEMORY WITH TEMPERATURE SCORING")
print("=" * 60)

for i, turn in enumerate(scenario):
    response, tokens, promos, demos = agent_tm.chat(turn)
    marker = " <<< RECALL" if i >= RECALL_START else ""
    print(f"\nTurn {i:2d} [{tokens:4d} tokens]{marker}")
    print(f"  User: {turn}")

    if promos:
        for p in promos:
            print(f"  PROMOTED: {p[1]} ({p[2]} -> {p[3]}, temp={p[4]:.2f})")
    if demos:
        for d in demos:
            print(f"  DEMOTED:  {d[1]} ({d[2]} -> {d[3]}, temp={d[4]:.2f})")

    if i >= RECALL_START:
        print(f"  Agent: {response[:200]}")
        print(f"  Expected: {expected_answers[i]}")

### Phase 3 Observation

The tiered memory agent:
- **Promotes** relevant memories when they match the query (e.g., AWS/ECS facts get promoted when asked about cloud provider)
- **Injects only the top-k** most relevant memories, keeping token usage bounded
- **Retains specific details** (pool size = 5, port 5432) because they're stored as individual memories, not compressed into a lossy summary

---

# Comparison Dashboard

Let's compare all three approaches side-by-side.

### Token Usage Over Turns

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.plot(agent_sw.token_log, label="Sliding Window", marker="o", markersize=4)
ax.plot(agent_sum.token_log, label="Summarization", marker="s", markersize=4)
ax.plot(agent_tm.token_log, label="Tiered Memory", marker="^", markersize=4)

ax.axvline(x=RECALL_START, color="red", linestyle="--", alpha=0.5, label="Recall questions start")
ax.set_xlabel("Turn")
ax.set_ylabel("Tokens Sent to LLM")
ax.set_title("Token Usage by Context Strategy")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Total tokens — Sliding: {sum(agent_sw.token_log):,} | "
      f"Summary: {sum(agent_sum.token_log):,} | "
      f"Tiered: {sum(agent_tm.token_log):,}")

### Tier Distribution Over Time

In [None]:
turns = range(len(agent_tm.tier_history))
hot_counts = [h["tier_counts"]["hot"] for h in agent_tm.tier_history]
warm_counts = [h["tier_counts"]["warm"] for h in agent_tm.tier_history]
cold_counts = [h["tier_counts"]["cold"] for h in agent_tm.tier_history]

fig, ax = plt.subplots(figsize=(12, 5))
ax.stackplot(
    turns, hot_counts, warm_counts, cold_counts,
    labels=["Hot", "Warm", "Cold"],
    colors=["#ef4444", "#f59e0b", "#3b82f6"],
    alpha=0.8,
)
ax.axvline(x=RECALL_START, color="black", linestyle="--", alpha=0.5, label="Recall questions")
ax.set_xlabel("Turn")
ax.set_ylabel("Number of Memories")
ax.set_title("Memory Tier Distribution Over Time")
ax.legend(loc="upper left")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Recall Accuracy Comparison

In [None]:
print("=" * 80)
print("RECALL ACCURACY COMPARISON (Turns 20-24)")
print("=" * 80)

for turn_idx in range(RECALL_START, len(scenario)):
    resp_idx = turn_idx  # response index matches turn index
    print(f"\nTurn {turn_idx}: {scenario[turn_idx]}")
    print(f"  Expected: {expected_answers[turn_idx]}")
    print(f"  Sliding:  {agent_sw.responses[resp_idx][:120]}")
    print(f"  Summary:  {agent_sum.responses[resp_idx][:120]}")
    print(f"  Tiered:   {agent_tm.responses[resp_idx][:120]}")
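
The printout above leaves the grading to the reader. As a rough first pass, a keyword check can flag which answers contain the expected facts. The sketch below is one possible approach under stated assumptions: `REQUIRED_TERMS` and `recall_hit` are hypothetical names, and the term lists were picked by hand from `expected_answers`, not defined anywhere in the notebook.

```python
import re

# Hypothetical grading terms, chosen by hand from expected_answers above.
REQUIRED_TERMS = {
    20: ["AWS", "ECS"],
    21: ["5"],
    22: ["Stripe"],
    23: ["batch"],
    24: ["pool", "PgBouncer"],
}


def recall_hit(turn: int, response: str) -> bool:
    """True if every required term appears as a whole word (case-insensitive)."""
    return all(
        re.search(rf"\b{re.escape(term)}\b", response, flags=re.IGNORECASE)
        for term in REQUIRED_TERMS[turn]
    )


print(recall_hit(20, "You deploy to AWS using ECS."))
print(recall_hit(21, "The database listens on port 5432."))  # "5432" is not "5"
print(recall_hit(24, "Connection pool exhaustion; try PgBouncer."))
```

Inside the notebook, something like `sum(recall_hit(i, agent_tm.responses[i]) for i in REQUIRED_TERMS)` would give a score out of 5 per agent. Keyword matching is crude: it cannot tell a correct paraphrase from a miss, and it will accept a hallucinated answer that happens to contain the right words, so treat it as a quick filter and not as a real evaluation.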

### Final Memory State (Tiered Agent)

Let's inspect which memories ended up in each tier after the full conversation.

In [None]:
print("=" * 70)
print("FINAL MEMORY STATE")
print("=" * 70)

for tier in ["hot", "warm", "cold"]:
    tier_mems = [m for m in agent_tm.memories if m.tier == tier]
    print(f"\n{'=' * 5} {tier.upper()} ({len(tier_mems)} memories) {'=' * 5}")
    for m in sorted(tier_mems, key=lambda x: x.temperature, reverse=True):
        print(f"  [{m.id}] temp={m.temperature:.3f} access={m.access_count:2d} | {m.text[:70]}")

---

# Key Takeaways

| Strategy | Token Efficiency | Recall Quality | Self-Organizing | Complexity |
|----------|-----------------|---------------|-----------------|------------|
| **Sliding Window** | Bounded but wasteful | Poor (loses early context) | No | Low |
| **Summarization** | Good (compressed) | Moderate (lossy) | No | Medium |
| **Tiered Memory** | Best (only relevant) | Best (individual recall) | Yes | Higher |

1. **State is what makes an agent an agent** — without memory management, you have a chatbot with amnesia
2. **Temperature scoring turns memory into a continuous optimization problem** — not binary keep/discard
3. **Tier promotion/demotion creates self-organizing memory** — frequently accessed facts stay hot
4. **Match your strategy to your use case** — sliding window for simple chat, tiered memory for autonomous agents