# What If Model Memory Is Reorganized Like Human Episodic recall?

In [9]:
# 1. Setup & Imports
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import matplotlib.pyplot as plt

# # EpiNet Encoder Module
## 1. **Import** the required libraries  
## 2. **Define** the `Encoder` class  
## 3. **Instantiate** the encoder and inspect its architecture  
## 4. **Run** a forward pass on a dummy image to verify output shape 

In [17]:
class Encoder(nn.Module):
    """Transforms input images into a latent embedding z."""
    def __init__(self, latent_dim: int = 128):
        super(Encoder, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 28×28 → 28×28
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 28×28 → 14×14
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 14×14 → 14×14
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 14×14 → 7×7
            nn.Flatten(),     # → (64*7*7)
    )
        # Project flattened features to latent_dim
        self.project = nn.Linear(64 * 7 * 7, latent_dim)  # → z ∈ ℝᵈ
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature_extractor(x)
        z = self.project(features)
        return z # Matches “z ∈ ℝᵈ”

In [18]:
# Instantiate encoder to verify
encoder = Encoder(latent_dim=128)
print(encoder)
dummy = torch.randn(1, 1, 28, 28)
print("Output shape:", encoder(dummy).shape)

Encoder(
  (feature_extractor): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (project): Linear(in_features=3136, out_features=128, bias=True)
)
Output shape: torch.Size([1, 128])


# # MemoryController Module
# This defines and tests the `MemoryController` class, which computes
# biologically‐inspired salience decay:
# **Sections:** 
# 1. Define `MemoryController`  
# 2. Instantiate and Inspect  
# 3. Test `decay()` with dummy data  

In [19]:
# 3. MemoryController Module
class MemoryController:
    """
    Computes the biologically-inspired salience decay.
    salience_m = r * exp(–α · (τ_now – τ_m))

    Notation:
      • r      … initial salience score (r₀)
      • τ_m    … timestamp when memory m was stored
      • α      … decay rate (learnable or fixed)
      • salience_m … decayed salience at the current time
    """
    def __init__(self, alpha: float):
        """
        α: decay rate (higher → faster forgetting)
        """
        self.alpha = alpha

    def decay(self, r: torch.Tensor, tau_m: torch.Tensor) -> torch.Tensor:
        """
           Apply decay to a batch of memories.

        Args:
            r      (torch.Tensor): shape [M], initial salience scores for M memories
            tau_m  (torch.Tensor): shape [M], stored timestamps (in seconds) for each memory

        Returns:
            torch.Tensor of shape [M]:
              current salience_m = r * exp(–α · (τ_now – τ_m))
        """
        #  Current time in seconds, on same device as inputs
        tau_now = torch.tensor(time.time(), device=r.device)
        #  Δτ = τ_now – τ_m
        delta_tau = tau_now - tau_m
        #  Apply decay element-wise
        salience_m = r * torch.exp(-self.alpha * delta_tau)
        return salience_m

In [20]:
# ## Instantiate and Inspect
#
# Create a `MemoryController` with a chosen decay rate.

# %%
alpha = 0.1
mem_ctrl = MemoryController(alpha=alpha)
print(mem_ctrl)

<__main__.MemoryController object at 0x12180fa90>


In [21]:
# Test `decay()` with Dummy Data
# Dummy salience scores
r = torch.tensor([1.0, 0.5, 2.0], dtype=torch.float32)
# Simulate that these memories were stored 0s, 10s, and 20s ago
now = time.time()
tau_m = torch.tensor([now, now - 10, now - 20], dtype=torch.float32)

# Compute decayed saliences
decayed = mem_ctrl.decay(r, tau_m)
print("Initial r:     ", r)
print("Timestamps Δτ: ", (torch.tensor(time.time()) - tau_m))
print("Decayed salience:", decayed)

Initial r:      tensor([1.0000, 0.5000, 2.0000])
Timestamps Δτ:  tensor([0., 0., 0.])
Decayed salience: tensor([1.0000, 0.5000, 2.0000])


# # EpisodicMemory Module
# This defines and tests the `EpisodicMemory` class, which stores a fixed‑capacity buffer of memories
# and evicts the lowest‑salience trace when full.
# **Sections:**
# 1. Define `EpisodicMemory`  
# 2. Instantiate and Inspect  
# 3. Test `add()` and Eviction Logic  

In [22]:
# 4. EpisodicMemory Module
class EpisodicMemory:
    """
    Fixed‑capacity buffer storing tuples, m = (z, c, r₀, τₘ, yₘ).
    admitting only the most salient memories over time.
    Admission & eviction are driven by current salience decay:
      • Admit new memory if under capacity,
      • Otherwise evict the memory with lowest decayed salience.
    """
    def __init__(
            self,
            capacity: int,
            latent_dim: int,
            decay_rate: float,
            device: torch.device
    ):
        """
        Args:
            capacity (int): maximum number of memories
            latent_dim (int): dimensionality of the latent space
            decay_rate (float): decay rate for salience decay
            device (torch.device): device to store the memory
        """
        self.capacity = capacity
        self.device = device

        # Memory buffer (initially empty)
        self.z_buffer = torch.empty((0, latent_dim), device=device) # z: [N × d] latent embeddings z
        self.c_buffer = torch.empty((0, latent_dim), device=device) # c: [N × d] context vectors
        self.r0_buffer = torch.empty((0,), device=device)  # r0: [N] initial salience score r₀
        self.tau_buffer = torch.empty((0,), device=device) # timestamps τₘ when stored
        self.y_buffer = torch.empty((0,), dtype=torch.long, device=device) # labels yₘ

        # Reuse your MemoryController for decay
        self.mem_ctrl = MemoryController(decay_rate)

    def add(
            self,
            z: torch.Tensor,
            c: torch.Tensor,
            r0: float,
            y: torch.Tensor
    ):
        """
        Try to admit new memory (z,c,r0,τ,y).
        If at capacity, evict the lowest‐decayed‐salience memory.
        """
        # Preparing tensors for concatenation
        z = z.detach().to(self.device).view(1, -1)
        c = c.detach().to(self.device).view(1, -1)
        r0 = torch.tensor([r0], device=self.device)
        y = y.detach().to(self.device).view(1)
        tau = torch.tensor([time.time()], device=self.device)

        # If under capacity, just append/admit
        if self.z_buffer.shape[0] < self.capacity:
            self._append(z, c, r0, tau, y)
            return

        # Otherwise compute decayed salience of existing memories
        s_existing = self.mem_ctrl.decay(self.r0_buffer, self.tau_buffer)

        # If this new memory isn't more salience than the least one, skip
        if r0 <= s_existing.min():
            return

        # Else evict the lowest‐salience and replace it
        idx = torch.argmin(s_existing).item()
        self._replace(idx, z, c, r0, tau, y)

    def _append(self, z, c, r0, tau, y):
        """Add a new memory at the end of each buffer."""
        self.z_buffer = torch.cat([self.z_buffer, z], dim=0)
        self.c_buffer = torch.cat([self.c_buffer, c], dim=0)
        self.r0_buffer = torch.cat([self.r0_buffer, r0], dim=0)
        self.tau_buffer = torch.cat([self.tau_buffer, tau], dim=0)
        self.y_buffer = torch.cat([self.y_buffer, y], dim=0)


    def _replace(self, idx, z, c, r0, tau, y):
        """Overwrite the memory at index `idx` with the new one."""
        self.z_buffer[idx] = z
        self.c_buffer[idx] = c
        self.r0_buffer[idx] = r0
        self.tau_buffer[idx] = tau
        self.y_buffer[idx] = y

    def clear(self):
        """Reset all buffers to empty."""
        self.__init__(self.capacity, self.z_buffer.size(1), self.mem_ctrl.alpha, self.device)



In [23]:
# ## Instantiate and Inspect
# Example: capacity=3, latent_dim=4
capacity   = 3
latent_dim = 4
decay_rate = 0.1
device = torch.device('cpu')

mem = EpisodicMemory(capacity, latent_dim, decay_rate, device)
print(mem)


<__main__.EpisodicMemory object at 0x121814550>


In [24]:
# ## Test `add()` and Eviction Logic
# Create some dummy embeddings and labels
for i in range(5):
    z = torch.randn(latent_dim)
    c = torch.randn(latent_dim)
    y = torch.tensor(i % 2)           # binary labels
    r0 = float(i + 1)                 # increasing initial salience
    mem.add(z, c, r0, y)
    print(f"After adding memory {i+1}: buffer size = {mem.z_buffer.size(0)}")
    print("r0_buffer:", mem.r0_buffer.tolist())

# Final contents
print("\nFinal z_buffer shape:", mem.z_buffer.shape)
print("Final r0_buffer values:", mem.r0_buffer.tolist())

After adding memory 1: buffer size = 1
r0_buffer: [1.0]
After adding memory 2: buffer size = 2
r0_buffer: [1.0, 2.0]
After adding memory 3: buffer size = 3
r0_buffer: [1.0, 2.0, 3.0]
After adding memory 4: buffer size = 3
r0_buffer: [4.0, 2.0, 3.0]
After adding memory 5: buffer size = 3
r0_buffer: [4.0, 5.0, 3.0]

Final z_buffer shape: torch.Size([3, 4])
Final r0_buffer values: [4.0, 5.0, 3.0]


# # RecallEngine Module
# This defines and tests the `RecallEngine` class, which retrieves the most
# salient memories based on cosine similarity with a query and decayed salience.
# **Sections:** 
# 1. Define `RecallEngine`  
# 2. Instantiate and Inspect  
# 3. Test `recall()` with Dummy Data  

In [25]:
# 5. RecallEngine Module
# - **top_k**: number of highest‑scoring memories to retrieve  
# - **RecallScoreₘ** = cos(z_t, cₘ) · salienceₘ  
# - **salienceₘ** = r₀ₘ · exp(–α·(τ_now – τₘ))  
# - **Recall embedding**: where the sum is over the Top‑K memories.

class RecallEngine:
    """
    Handles salience decay, scoring, and top-k retrieval from EpisodicMemory.
    Retrieves salient memories based on cosine similarity and decayed salience.

    Given a query embedding z_t and stored memories (z_m, c_m, r0_m, τ_m),
    computes for each memory:
      RecallScore_m = cos(z_t, c_m) * salience_m
    where salience_m = r0_m * exp(-α * (τ_now - τ_m)).
    Selects Top‑K memories by RecallScore, then computes recall embedding:
      r_t = (1 / Σ_i r_i) * Σ_i r_i * z_i,
    where r_i is the decayed salience of the selected memories.
    """
    def __init__(self, top_k: int):
        """
        top_k: number of highest‑scoring memories to retrieve.
        """
        self.top_k = top_k

    def recall(self, z_query: torch.Tensor, memory: EpisodicMemory) -> torch.Tensor:
        """
        Perform memory recall for a batch of query embeddings.

        Args:
            z_query (torch.Tensor): shape [B, d], query latent embeddings z_t.
            memory (EpisodicMemory): contains buffers:
              - c_buffer: [N, d] context embeddings c_m
              - z_buffer: [N, d] latent embeddings z_m
              - r0_buffer: [N]   initial salience r0_m
              - tau_buffer: [N]  timestamps τ_m

        Returns:
            torch.Tensor: shape [B, d], recall embeddings r_t.
        """
        # Guard for empty memory
        if memory.z_buffer.size(0) == 0:
            return torch.zeros_like(z_query)

        # Compute decayed salience for all stored memories: [N]
        salience = memory.mem_ctrl.decay(memory.r0_buffer, memory.tau_buffer)

        # Compute cosine similarity: [B, N]
        cos_sim = F.cosine_similarity(
            z_query.unsqueeze(1),         # [B, 1, d]
            memory.c_buffer.unsqueeze(0),   # [1, N, d]
            dim=-1
        )

        # Compute recall scores for selection: [B, N]
        recall_scores = cos_sim * salience.unsqueeze(0)

        # Select Top-K indices by recall score: [B, K]
        _, top_idx = torch.topk(recall_scores, self.top_k, dim=1)

        # Gather salience and latent embeddings of selected memories
        salience_topk = salience[top_idx]          # [B, K]
        z_topk = memory.z_buffer[top_idx]   # [B, K, d]

        # Normalize by sum of salience: weights = r_i / Σ_j r_j
        weights = salience_topk / (salience_topk.sum(dim=1, keepdim=True) + 1e-8)

        # Weighted sum to form recall embedding: [B, d]
        recall_emb = (weights.unsqueeze(-1) * z_topk).sum(dim=1)
        return recall_emb


In [26]:
# ## Instantiate and Inspect

# Set up a small EpisodicMemory and RecallEngine
device     = torch.device('cpu')
latent_dim = 4
capacity   = 5
decay_rate = 0.1
top_k      = 2

memory      = EpisodicMemory(capacity, latent_dim, decay_rate, device)
recall_eng  = RecallEngine(top_k=top_k)

print(memory)
print(recall_eng)


<__main__.EpisodicMemory object at 0x121815cf0>
<__main__.RecallEngine object at 0x121816b60>


In [27]:
# ## Test `recall()` with Dummy Data
# - Add 5 dummy memories with increasing initial salience  
# - Wait briefly to generate time differences  
# - Query with two random embeddings  

# %%
# Add dummy memories
for i in range(5):
    z = torch.randn(latent_dim)
    c = torch.randn(latent_dim)
    y = torch.tensor(i % 3)           # dummy labels
    r0 = float(i + 1)                 # salience increasing
    memory.add(z, c, r0, y)
    time.sleep(0.1)                   # small delay to vary τ_m

print("Stored r0_buffer:", memory.r0_buffer.tolist())

# Create a batch of 2 query embeddings
z_query = torch.randn(2, latent_dim)

# Perform recall
r_emb = recall_eng.recall(z_query, memory)

print("Query embeddings:\n", z_query)
print("Recall embeddings:\n", r_emb)

Stored r0_buffer: [1.0, 2.0, 3.0, 4.0, 5.0]
Query embeddings:
 tensor([[ 0.5505,  2.0736, -0.5507, -0.7457],
        [ 0.2652, -1.1463, -0.9636,  0.6260]])
Recall embeddings:
 tensor([[ 0.8314,  0.4124, -0.0066,  0.7256],
        [ 0.4540, -0.4153, -0.2932, -0.3724]])


# # Decoder Module for EpiNet
# This defines and tests the `Decoder` class, which takes a latent embedding `z_t`
# and a recall embedding `r_t`, concatenates them, and produces class logits via a two-layer MLP.
# **Sections:** 
# 1. Define `Decoder`  
# 2. Instantiate and Inspect  
# 3. Forward Pass Test  

In [32]:
#   - `z_t` ∈ ℝᵈ: latent embedding from the encoder  
#   - `r_t` ∈ ℝᵈ: recall embedding from the memory 
# Concatenate `[z_t; r_t]` → ℝ²ᵈ 

# 6. Decoder Module

class Decoder(nn.Module):
    """
    Mixture-of-Experts Decoder for EpiNet.

    Given latent embedding z_t and recall embedding r_t,
    concatenates to h_t ∈ ℝ^{2d}, then uses E parallel experts
    and a gating network to produce class logits:
      • Gate: g = softmax(G·h_t) ∈ ℝ^E
      • Experts: ℓ^{(e)} = expert_e(h_t) ∈ ℝ^K
      • logits = Σ_{e=1}^E g_e · ℓ^{(e)}
    """
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        num_classes: int,
        num_experts: int = 4
    ):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.num_experts = num_experts
        # Gating network: 2d → E
        self.gate = nn.Linear(latent_dim * 2, num_experts)
        # Experts: each maps 2d → hidden_dim → num_classes
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim * 2, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, num_classes)
            ) for _ in range(num_experts)
        ])

    def forward(self, z: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """
        Args:
          z (torch.Tensor): [B, d], latent embedding from Encoder
          r (torch.Tensor): [B, d], recall embedding from RecallEngine
        Returns:
          logits (torch.Tensor): [B, num_classes]
        """
        # Concatenate embeddings: [B, 2d]
        h = torch.cat([z, r], dim=1)
        # Compute gating weights: [B, E]
        gate_logits = self.gate(h)
        gate_weights = F.softmax(gate_logits, dim=1)
        # Expert outputs: stack into [B, E, K]
        expert_outputs = torch.stack(
            [expert(h) for expert in self.experts],
            dim=1
        )
        # Weighted sum of experts: [B, K]
        gate_weights = gate_weights.unsqueeze(-1)  # [B, E, 1]
        logits = (gate_weights * expert_outputs).sum(dim=1)
        return logits

In [33]:
# ## 3. Instantiate and Inspect

latent_dim = 128
hidden_dim = 256
num_classes = 10

decoder = Decoder(latent_dim, hidden_dim, num_classes)
print(decoder)


Decoder(
  (gate): Linear(in_features=256, out_features=4, bias=True)
  (experts): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=256, out_features=10, bias=True)
    )
  )
)


In [35]:
# ## Forward Pass Test
# Verify that concatenating two random embeddings produces correct logits shape.

# Dummy embeddings
z_dummy = torch.randn(4, latent_dim)  # batch size 4
r_dummy = torch.randn(4, latent_dim)

# Forward pass
logits = decoder(z_dummy, r_dummy)
print("z_dummy shape:", z_dummy.shape)
print("r_dummy shape:", r_dummy.shape)
print("logits shape:", logits.shape)

z_dummy shape: torch.Size([4, 128])
r_dummy shape: torch.Size([4, 128])
logits shape: torch.Size([4, 10])


In [8]:
# 7. EpiNetModel Assembly
class EpiNetModel(nn.Module):
    """Orchestrates encoding, episodic storage, recall, and decoding."""
    def __init__(self, latent_dim=128, memory_capacity=1000, decay_rate=1e-3, top_k=5, num_classes=10):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.memory = EpisodicMemory(memory_capacity)
        self.recaller = RecallEngine(self.memory, decay_rate, top_k)
        self.decoder = Decoder(latent_dim, latent_dim, num_classes)

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)
        recall_vec = self.recaller.recall(z)
        logits = self.decoder(z, recall_vec)
        return logits, z

    def memorize(self, x, z, y, salience: float):
        self.recaller.store(x, z, y, salience)

In [9]:
# 8. Quick Sanity Check
model = EpiNetModel(latent_dim=64, memory_capacity=50, top_k=3)
dummy = torch.randn(4,1,28,28)
logits, z = model(dummy)
print(f"Logits shape: {logits.shape}, Embedding shape: {z.shape}")


Logits shape: torch.Size([4, 10]), Embedding shape: torch.Size([4, 64])


In [10]:
# 9. Full Training Loop (Split-MNIST) + Visualization
def train_split_mnist(
    latent_dim=64, memory_capacity=200, decay_rate=1e-3, top_k=5,
    num_classes=10, batch_size=64, lr=1e-3, beta=0.5, epochs_per_task=3
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    full_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    tasks = [list(range(0,5)), list(range(5,10))]
    model = EpiNetModel(latent_dim, memory_capacity, decay_rate, top_k, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_losses = {0: [], 1: []}

    for t_idx, classes in enumerate(tasks):
        idxs = [i for i, lbl in enumerate(full_train.targets) if int(lbl) in classes]
        loader = DataLoader(Subset(full_train, idxs), batch_size=batch_size, shuffle=True)
        for ep in range(epochs_per_task):
            epoch_loss = 0.0
            for xb, yb in loader:
                xb, yb = xb.to(device), yb.to(device)
                logits, z = model(xb)
                loss_main = F.cross_entropy(logits, yb)
                if model.memory.memory:
                    k_sample = min(len(model.memory.memory), batch_size)
                    mem_idxs = np.random.choice(len(model.memory.memory), k_sample, replace=False)
                    xm = torch.stack([model.memory.memory[i]['x'] for i in mem_idxs]).to(device)
                    ym = torch.tensor([model.memory.memory[i]['y'] for i in mem_idxs],
                                      dtype=torch.long, device=device)
                    logits_mem, _ = model(xm)
                    loss_replay = F.cross_entropy(logits_mem, ym)
                    loss = loss_main + beta * loss_replay
                else:
                    loss = loss_main
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                model.memorize(xb[0], z[0], yb[0], salience=loss_main.item())
                epoch_loss += loss.item() * xb.size(0)
            avg = epoch_loss / len(idxs)
            train_losses[t_idx].append(avg)
            print(f"Task {t_idx} Epoch {ep} Loss: {avg:.4f}")
    return model, train_losses

# Run training
model, losses = train_split_mnist()

# %%
# 10. Plot Loss Curves
for t, vals in losses.items():
    plt.plot(vals, label=f'Task {t}')
plt.title("Training Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

RuntimeError: a Tensor with 64 elements cannot be converted to Scalar