In [1]:
# 1. Setup & Imports
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [2]:
# 2. Encoder Module
class Encoder(nn.Module):
    """Transforms input images into a latent embedding z."""
    def __init__(self, latent_dim: int = 128):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), 
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                  
            nn.Conv2d(32, 64, 3, padding=1),  
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                  
            nn.Flatten(),                    
        )
        self.fc = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature_extractor(x)
        return self.fc(features)

In [4]:
# 3. MemoryController Module
class MemoryController:
    """Decides whether to admit or replace entries in the memory buffer."""
    def __init__(self, capacity: int):
        self.capacity = capacity

    def should_store(self, salience: float, buffer: list) -> bool:
        if len(buffer) < self.capacity:
            return True
        min_r = min(entry['r'] for entry in buffer)
        return salience > min_r

    def replace_index(self, buffer: list) -> int:
        saliences = [entry['r'] for entry in buffer]
        return int(np.argmin(saliences))


In [5]:
# 4. EpisodicMemory Module
class EpisodicMemory:
    """Stores raw episodes (x, z, y, r0, r, tau)."""
    def __init__(self, capacity: int = 1000):
        self.capacity = capacity
        self.memory = []  # list of dicts: {x, z, y, r0, r, tau}

    def add(self, entry: dict):
        if len(self.memory) < self.capacity:
            self.memory.append(entry)
        else:
            idx = entry.pop('replace_idx', None)
            if idx is not None:
                self.memory[idx] = entry


In [6]:
# 5. RecallEngine Module
class RecallEngine:
    """Handles salience decay, scoring, and top-k retrieval from EpisodicMemory."""
    def __init__(self, memory: EpisodicMemory, decay_rate: float = 1e-3, top_k: int = 5):
        self.memory = memory
        self.decay_rate = decay_rate
        self.top_k = top_k
        self.controller = MemoryController(memory.capacity)

    def recall(self, query_z: torch.Tensor) -> torch.Tensor:
        if not self.memory.memory:
            return torch.zeros_like(query_z)
        now = time.time()
        scores = []
        for entry in self.memory.memory:
            delta_t = now - entry['tau']
            entry['r'] = entry['r0'] * np.exp(-self.decay_rate * delta_t)
            sim = F.cosine_similarity(
                query_z.unsqueeze(0), entry['z'].to(query_z.device).unsqueeze(0), dim=1
            )[0]
            scores.append((sim * entry['r']).item())
        k = min(self.top_k, len(scores))
        idxs = np.argsort(scores)[-k:]
        recall_vec = torch.stack([
            self.memory.memory[i]['z'].to(query_z.device) for i in idxs
        ]).mean(dim=0)
        return recall_vec

    def store(self, x, z, y, salience: float):
        entry = {
            'x': x.detach().cpu(),
            'z': z.detach().cpu(),
            'y': y.detach().cpu(),
            'r0': salience,
            'r': salience,
            'tau': time.time(),
        }
        if self.controller.should_store(salience, self.memory.memory):
            if len(self.memory.memory) < self.memory.capacity:
                self.memory.memory.append(entry)
            else:
                entry['replace_idx'] = self.controller.replace_index(self.memory.memory)
                self.memory.memory.append(entry)


In [7]:
# 6. Decoder Module
class Decoder(nn.Module):
    """Takes latent + recall vectors → class logits."""
    def __init__(self, latent_dim: int = 128, recall_dim: int = 128, num_classes: int = 10):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim + recall_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    def forward(self, z: torch.Tensor, recall_vec: torch.Tensor) -> torch.Tensor:
        return self.classifier(torch.cat([z, recall_vec], dim=1))


In [8]:
# 7. EpiNetModel Assembly
class EpiNetModel(nn.Module):
    """Orchestrates encoding, episodic storage, recall, and decoding."""
    def __init__(self, latent_dim=128, memory_capacity=1000, decay_rate=1e-3, top_k=5, num_classes=10):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.memory = EpisodicMemory(memory_capacity)
        self.recaller = RecallEngine(self.memory, decay_rate, top_k)
        self.decoder = Decoder(latent_dim, latent_dim, num_classes)

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)
        recall_vec = self.recaller.recall(z)
        logits = self.decoder(z, recall_vec)
        return logits, z

    def memorize(self, x, z, y, salience: float):
        self.recaller.store(x, z, y, salience)

In [9]:
# 8. Quick Sanity Check
model = EpiNetModel(latent_dim=64, memory_capacity=50, top_k=3)
dummy = torch.randn(4,1,28,28)
logits, z = model(dummy)
print(f"Logits shape: {logits.shape}, Embedding shape: {z.shape}")


Logits shape: torch.Size([4, 10]), Embedding shape: torch.Size([4, 64])


In [10]:
# 9. Full Training Loop (Split-MNIST) + Visualization
def train_split_mnist(
    latent_dim=64, memory_capacity=200, decay_rate=1e-3, top_k=5,
    num_classes=10, batch_size=64, lr=1e-3, beta=0.5, epochs_per_task=3
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    full_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    tasks = [list(range(0,5)), list(range(5,10))]
    model = EpiNetModel(latent_dim, memory_capacity, decay_rate, top_k, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_losses = {0: [], 1: []}

    for t_idx, classes in enumerate(tasks):
        idxs = [i for i, lbl in enumerate(full_train.targets) if int(lbl) in classes]
        loader = DataLoader(Subset(full_train, idxs), batch_size=batch_size, shuffle=True)
        for ep in range(epochs_per_task):
            epoch_loss = 0.0
            for xb, yb in loader:
                xb, yb = xb.to(device), yb.to(device)
                logits, z = model(xb)
                loss_main = F.cross_entropy(logits, yb)
                if model.memory.memory:
                    k_sample = min(len(model.memory.memory), batch_size)
                    mem_idxs = np.random.choice(len(model.memory.memory), k_sample, replace=False)
                    xm = torch.stack([model.memory.memory[i]['x'] for i in mem_idxs]).to(device)
                    ym = torch.tensor([model.memory.memory[i]['y'] for i in mem_idxs],
                                      dtype=torch.long, device=device)
                    logits_mem, _ = model(xm)
                    loss_replay = F.cross_entropy(logits_mem, ym)
                    loss = loss_main + beta * loss_replay
                else:
                    loss = loss_main
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                model.memorize(xb[0], z[0], yb[0], salience=loss_main.item())
                epoch_loss += loss.item() * xb.size(0)
            avg = epoch_loss / len(idxs)
            train_losses[t_idx].append(avg)
            print(f"Task {t_idx} Epoch {ep} Loss: {avg:.4f}")
    return model, train_losses

# Run training
model, losses = train_split_mnist()

# %%
# 10. Plot Loss Curves
for t, vals in losses.items():
    plt.plot(vals, label=f'Task {t}')
plt.title("Training Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

RuntimeError: a Tensor with 64 elements cannot be converted to Scalar