In [48]:
# 1. Setup & Imports
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt

In [49]:
class Encoder(nn.Module):
    """Transforms input images into a latent embedding z."""
    def __init__(self, latent_dim: int = 128):
        super(Encoder, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 28×28 → 28×28
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 28×28 → 14×14
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 14×14 → 14×14
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 14×14 → 7×7
            nn.Flatten(),     # → (64*7*7)
    )
        # Project flattened features to latent_dim
        self.project = nn.Linear(64 * 7 * 7, latent_dim)  # → z ∈ ℝᵈ
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature_extractor(x)
        z = self.project(features)
        return z # Matches “z ∈ ℝᵈ”

In [50]:
class Decoder(nn.Module):
    """
    Mixture-of-Experts Decoder for EpiNet.

    Given latent embedding z_t and recall embedding r_t,
    concatenates to h_t ∈ ℝ^{2d}, then uses E parallel experts
    and a gating network to produce class logits:
      • Gate: g = softmax(G·h_t) ∈ ℝ^E
      • Experts: ℓ^{(e)} = expert_e(h_t) ∈ ℝ^K
      • logits = Σ_{e=1}^E g_e · ℓ^{(e)}
    """
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        num_classes: int,
        num_experts: int = 4
    ):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.num_experts = num_experts
        # Gating network: 2d → E
        self.gate = nn.Linear(latent_dim * 2, num_experts)
        # Experts: each maps 2d → hidden_dim → num_classes
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim * 2, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, num_classes)
            ) for _ in range(num_experts)
        ])

    def forward(self, z: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """
        Args:
          z (torch.Tensor): [B, d], latent embedding from Encoder
          r (torch.Tensor): [B, d], recall embedding from RecallEngine
        Returns:
          logits (torch.Tensor): [B, num_classes]
        """
        # Concatenate embeddings: [B, 2d]
        h = torch.cat([z, r], dim=1)
        # Compute gating weights: [B, E]
        gate_logits = self.gate(h)
        gate_weights = F.softmax(gate_logits, dim=1)
        # Expert outputs: stack into [B, E, K]
        expert_outputs = torch.stack(
            [expert(h) for expert in self.experts],
            dim=1
        )
        # Weighted sum of experts: [B, K]
        gate_weights = gate_weights.unsqueeze(-1)  # [B, E, 1]
        logits = (gate_weights * expert_outputs).sum(dim=1)
        return logits

In [51]:
class NoReplayModel(nn.Module):
    """
    A pure supervised model: encode → decode,
    with no episodic memory or replay.
    """
    def __init__(self,
                 latent_dim: int = 128,
                 hidden_dim: int = 256,
                 num_classes: int = 10,
                 device: torch.device = None):
        super().__init__()
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = Encoder(latent_dim).to(self.device)
        self.decoder = Decoder(latent_dim, hidden_dim, num_classes).to(self.device)
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None):
        x = x.to(self.device)
        # 1) encode
        z = self.encoder(x)                        # [B, latent_dim]
        # 2) no recall → zero vector
        recall_vec = torch.zeros_like(z)           # [B, latent_dim]
        # 3) decode
        logits = self.decoder(z, recall_vec)       # [B, num_classes]

        if y is None:
            return logits

        # 4) supervised loss only
        y = y.to(self.device)
        loss = self.criterion(logits, y)
        return loss

In [52]:
def train_test_split_loader(
    batch_size:  int,
    num_workers: int  = 4,
    pin_memory:  bool = True,
    test_frac:   float = 0.2
):
    """
    Returns four DataLoaders for Split‑MNIST with an 80/20 train‑test split in each task:
      - Task1 (digits 0–4): train1, test1
      - Task2 (digits 5–9): train2, test2
    """
    # Transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Load full MNIST training set
    full = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    targets = full.targets

    # Split indices
    idx1 = (targets < 5).nonzero(as_tuple=True)[0]
    idx2 = (targets >= 5).nonzero(as_tuple=True)[0]
    ds1 = Subset(full, idx1)
    ds2 = Subset(full, idx2)

    # 80/20 split each
    def split(ds):
        n = len(ds)
        n_test = int(n * test_frac)
        n_train = n - n_test
        return random_split(ds, [n_train, n_test])

    train1_ds, test1_ds = split(ds1)
    train2_ds, test2_ds = split(ds2)

    # DataLoaders
    loader_args = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train1 = DataLoader(train1_ds, shuffle=True,  **loader_args)
    test1  = DataLoader(test1_ds,  shuffle=False, **loader_args)
    train2 = DataLoader(train2_ds, shuffle=True,  **loader_args)
    test2  = DataLoader(test2_ds,  shuffle=False, **loader_args)

    return train1, test1, train2, test2

In [53]:
batch_size = 16
train1, test1, train2, test2 = train_test_split_loader(batch_size=batch_size)

In [10]:
epochs_per_task = 10

def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """
    Compute classification accuracy of `model` on `loader`.
    Automatically uses whatever device the model’s parameters live on.
    """
    device = next(model.parameters()).device
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)               # returns logits only
            preds = torch.argmax(logits, dim=1)
            correct += (preds == yb).sum().item()
            total   += yb.size(0)
    return correct / total if total else 0.0

In [11]:
# Instantiate
model = NoReplayModel(
    latent_dim=128,
    hidden_dim=256,
    num_classes=10,
    device=torch.device('cpu')
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop for Task 1
for epoch in range(epochs_per_task):
    model.train()
    for x, y in train1:
        optimizer.zero_grad()
        loss = model(x, y)
        loss.backward()
        optimizer.step()

# Evaluate Task 1
acc1_after_T1 = evaluate(model, test1)

# Training loop for Task 2
for epoch in range(epochs_per_task):
    model.train()
    for x, y in train2:
        optimizer.zero_grad()
        loss = model(x, y)
        loss.backward()
        optimizer.step()

# Evaluate Task 1 again
acc1_after_T2 = evaluate(model, test1)


In [12]:
acc1_after_T1


0.9872528190880863

In [13]:
acc1_after_T2

0.0

In [14]:
## Episodic Memory Added

In [54]:
# 3. MemoryController Module
class MemoryController:
    """
    Computes the biologically-inspired salience decay.
    salience_m = r * exp(–α · (τ_now – τ_m))

    Notation:
      • r      … initial salience score (r₀)
      • τ_m    … timestamp when memory m was stored
      • α      … decay rate (learnable or fixed)
      • salience_m … decayed salience at the current time
    """
    def __init__(self, alpha: float):
        """
        α: decay rate (higher → faster forgetting)
        """
        self.alpha = alpha

    def decay(self, r: torch.Tensor, tau_m: torch.Tensor) -> torch.Tensor:
        """
           Apply decay to a batch of memories.

        Args:
            r      (torch.Tensor): shape [M], initial salience scores for M memories
            tau_m  (torch.Tensor): shape [M], stored timestamps (in seconds) for each memory

        Returns:
            torch.Tensor of shape [M]:
              current salience_m = r * exp(–α · (τ_now – τ_m))
        """
        #  Current time in seconds, on same device as inputs
        tau_now = torch.tensor(time.time(), device=r.device)
        #  Δτ = τ_now – τ_m
        delta_tau = tau_now - tau_m
        #  Apply decay element-wise
        salience_m = r * torch.exp(-self.alpha * delta_tau)
        return salience_m

In [55]:
# 4. EpisodicMemory Module
class EpisodicMemory:
    """
    Fixed‑capacity buffer storing tuples, m = (z, c, r₀, τₘ, yₘ).
    admitting only the most salient memories over time.
    Admission & eviction are driven by current salience decay:
      • Admit new memory if under capacity,
      • Otherwise evict the memory with lowest decayed salience.
    """
    def __init__(
            self,
            capacity: int,
            latent_dim: int,
            decay_rate: float,
            device: torch.device
    ):
        """
        Args:
            capacity (int): maximum number of memories
            latent_dim (int): dimensionality of the latent space
            decay_rate (float): decay rate for salience decay
            device (torch.device): device to store the memory
        """
        self.capacity = capacity
        self.device = device

        # Memory buffer (initially empty)
        self.z_buffer = torch.empty((0, latent_dim), device=device) # z: [N × d] latent embeddings z
        self.c_buffer = torch.empty((0, latent_dim), device=device) # c: [N × d] context vectors
        self.r0_buffer = torch.empty((0,), device=device)  # r0: [N] initial salience score r₀
        self.tau_buffer = torch.empty((0,), device=device) # timestamps τₘ when stored
        self.y_buffer = torch.empty((0,), dtype=torch.long, device=device) # labels yₘ

        # Reuse your MemoryController for decay
        self.mem_ctrl = MemoryController(decay_rate)

    def add(
            self,
            z: torch.Tensor,
            c: torch.Tensor,
            r0: float,
            y: torch.Tensor
    ):
        """
        Try to admit new memory (z,c,r0,τ,y).
        If at capacity, evict the lowest‐decayed‐salience memory.
        """
        # Preparing tensors for concatenation
        z = z.detach().to(self.device).view(1, -1)
        c = c.detach().to(self.device).view(1, -1)
        r0 = torch.tensor([r0], device=self.device)
        y = y.detach().to(self.device).view(1)
        tau = torch.tensor([time.time()], device=self.device)

        # If under capacity, just append/admit
        if self.z_buffer.shape[0] < self.capacity:
            self._append(z, c, r0, tau, y)
            return

        # Otherwise compute decayed salience of existing memories
        s_existing = self.mem_ctrl.decay(self.r0_buffer, self.tau_buffer)

        # If this new memory isn't more salience than the least one, skip
        if r0 <= s_existing.min():
            return

        # Else evict the lowest‐salience and replace it
        idx = torch.argmin(s_existing).item()
        self._replace(idx, z, c, r0, tau, y)

    def _append(self, z, c, r0, tau, y):
        """Add a new memory at the end of each buffer."""
        self.z_buffer = torch.cat([self.z_buffer, z], dim=0)
        self.c_buffer = torch.cat([self.c_buffer, c], dim=0)
        self.r0_buffer = torch.cat([self.r0_buffer, r0], dim=0)
        self.tau_buffer = torch.cat([self.tau_buffer, tau], dim=0)
        self.y_buffer = torch.cat([self.y_buffer, y], dim=0)


    def _replace(self, idx, z, c, r0, tau, y):
        """Overwrite the memory at index `idx` with the new one."""
        self.z_buffer[idx] = z
        self.c_buffer[idx] = c
        self.r0_buffer[idx] = r0
        self.tau_buffer[idx] = tau
        self.y_buffer[idx] = y

    def clear(self):
        """Reset all buffers to empty."""
        self.__init__(self.capacity, self.z_buffer.size(1), self.mem_ctrl.alpha, self.device)


In [56]:
# 3. MemoryController Module
class MemoryController:
    """
    Computes the biologically-inspired salience decay.
    salience_m = r * exp(–α · (τ_now – τ_m))

    Notation:
      • r      … initial salience score (r₀)
      • τ_m    … timestamp when memory m was stored
      • α      … decay rate (learnable or fixed)
      • salience_m … decayed salience at the current time
    """
    def __init__(self, alpha: float):
        """
        α: decay rate (higher → faster forgetting)
        """
        self.alpha = alpha

    def decay(self, r: torch.Tensor, tau_m: torch.Tensor) -> torch.Tensor:
        """
           Apply decay to a batch of memories.

        Args:
            r      (torch.Tensor): shape [M], initial salience scores for M memories
            tau_m  (torch.Tensor): shape [M], stored timestamps (in seconds) for each memory

        Returns:
            torch.Tensor of shape [M]:
              current salience_m = r * exp(–α · (τ_now – τ_m))
        """
        #  Current time in seconds, on same device as inputs
        tau_now = torch.tensor(time.time(), device=r.device)
        #  Δτ = τ_now – τ_m
        delta_tau = tau_now - tau_m
        #  Apply decay element-wise
        salience_m = r * torch.exp(-self.alpha * delta_tau)
        return salience_m

In [57]:
# 5. RecallEngine Module
# - **top_k**: number of highest‑scoring memories to retrieve  
# - **RecallScoreₘ** = cos(z_t, cₘ) · salienceₘ  
# - **salienceₘ** = r₀ₘ · exp(–α·(τ_now – τₘ))  
# - **Recall embedding**: where the sum is over the Top‑K memories.

class RecallEngine:
    """
    Handles salience decay, scoring, and top-k retrieval from EpisodicMemory.
    Retrieves salient memories based on cosine similarity and decayed salience.

    Given a query embedding z_t and stored memories (z_m, c_m, r0_m, τ_m),
    computes for each memory:
      RecallScore_m = cos(z_t, c_m) * salience_m
    where salience_m = r0_m * exp(-α * (τ_now - τ_m)).
    Selects Top‑K memories by RecallScore, then computes recall embedding:
      r_t = (1 / Σ_i r_i) * Σ_i r_i * z_i,
    where r_i is the decayed salience of the selected memories.
    """
    def __init__(self, top_k: int):
        """
        top_k: number of highest‑scoring memories to retrieve.
        """
        self.top_k = int(top_k)

    def recall(self, z_query: torch.Tensor, memory: EpisodicMemory) -> torch.Tensor:
        """
        Perform memory recall for a batch of query embeddings.

        Args:
            z_query (torch.Tensor): shape [B, d], query latent embeddings z_t.
            memory (EpisodicMemory): contains buffers:
              - c_buffer: [N, d] context embeddings c_m
              - z_buffer: [N, d] latent embeddings z_m
              - r0_buffer: [N]   initial salience r0_m
              - tau_buffer: [N]  timestamps τ_m

        Returns:
            torch.Tensor: shape [B, d], recall embeddings r_t.
        """
        # Guard for empty memory
        if memory.z_buffer.size(0) == 0:
            return torch.zeros_like(z_query)

        # Compute decayed salience for all stored memories: [N]
        salience = memory.mem_ctrl.decay(memory.r0_buffer, memory.tau_buffer)

        # Compute cosine similarity: [B, N]
        cos_sim = F.cosine_similarity(
            z_query.unsqueeze(1),         # [B, 1, d]
            memory.c_buffer.unsqueeze(0),   # [1, N, d]
            dim=-1
        )

        # Compute recall scores for selection: [B, N]
        recall_scores = cos_sim * salience.unsqueeze(0)

        # Select Top-K indices by recall score: [B, K]
        _, top_idx = torch.topk(recall_scores, self.top_k, dim=1)

        # Gather salience and latent embeddings of selected memories
        salience_topk = salience[top_idx]          # [B, K]
        z_topk = memory.z_buffer[top_idx]   # [B, K, d]

        # Normalize by sum of salience: weights = r_i / Σ_j r_j
        weights = salience_topk / (salience_topk.sum(dim=1, keepdim=True) + 1e-8)

        # Weighted sum to form recall embedding: [B, d]
        recall_emb = (weights.unsqueeze(-1) * z_topk).sum(dim=1)
        return recall_emb

In [58]:
class EpiNetModel(nn.Module):
    """
    EpiNetModel integrates:
      1) Encoder → latent embedding z_t
      2) RecallEngine/EpisodicMemory → recall embedding r_t
      3) Decoder → class logits
    Computes joint loss:
      L_total = L_task + λ·L_replay
    """
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        num_classes: int,
        capacity: int,
        decay_rate: float,
        top_k: int,
        lambda_coef: float,
        device: torch.device = None
    ):
        super(EpiNetModel, self).__init__()
        self.device      = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Core modules
        self.encoder       = Encoder(latent_dim)
        self.decoder       = Decoder(latent_dim, hidden_dim, num_classes)
        self.memory        = EpisodicMemory(capacity, latent_dim, decay_rate, self.device)
        self.recall_engine = RecallEngine(top_k)

        # Loss and hyperparameters
        self.criterion   = nn.CrossEntropyLoss()
        self.lambda_coef = lambda_coef

        # Move to device
        self.to(self.device)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> torch.Tensor:
        x = x.to(self.device)
        # Encode
        z = self.encoder(x)
        # Recall
        r = self.recall_engine.recall(z, self.memory)
        # Predict
        logits = self.decoder(z, r)
        if y is None:
            return logits

        # Task loss
        y = y.to(self.device)
        loss_task = self.criterion(logits, y)

        # --- Replay loss per Core Math with capacity guard ---
        # Decayed salience: [N]
        salience_mem = self.memory.mem_ctrl.decay(
            self.memory.r0_buffer,
            self.memory.tau_buffer
        )
        # Cosine similarity [B, N]
        cos_sim_mem = F.cosine_similarity(
            z.unsqueeze(1),                # [B,1,d]
            self.memory.c_buffer.unsqueeze(0),  # [1,N,d]
            dim=-1
        )
        # Recall scores [B, N]
        recall_scores = cos_sim_mem * salience_mem.unsqueeze(0)

        # Only compute replay if we have any memories
        N = self.memory.z_buffer.size(0)
        if N > 0:
            # clamp k by current memory size
            k = min(self.recall_engine.top_k, N)
            _, top_idx = torch.topk(
                recall_scores,
                k,
                dim=1
            )
            # Gather for replay
            salience_topk = salience_mem[top_idx]        # [B,K]
            z_topk        = self.memory.z_buffer[top_idx] # [B,K,d]
            y_topk        = self.memory.y_buffer[top_idx] # [B,K]
            # Flatten for batch decode
            B, K, d = z_topk.shape
            z_flat = z_topk.view(B*K, d)
            r_flat = torch.zeros_like(z_flat)
            y_flat = y_topk.view(B*K)
            # Replay logits and per-item losses
            logits_mem = self.decoder(z_flat, r_flat)
            losses_mem = F.cross_entropy(logits_mem, y_flat, reduction='none').view(B, K)
            # Weighted sum for replay loss
            loss_replay = (salience_topk * losses_mem).sum()
            # Total loss with replay
            loss = loss_task + self.lambda_coef * loss_replay
        else:
            # No memories yet, skip replay
            loss = loss_task

        # Update episodic memory
        with torch.no_grad():
            initial_r0 = 1.0
            for zi, yi in zip(z, y):
                # Use z as both embedding & context
                self.memory.add(zi, zi, initial_r0, yi)

        return loss

In [42]:
latent_dim = 16
hidden_dim = 32
num_classes = 10
capacity = 20
decay_rate = 0.05
top_k = 5
lambda_coef = 0.5
device = torch.device('cpu')
model = EpiNetModel(
    latent_dim, hidden_dim, num_classes,
    capacity, decay_rate, top_k, lambda_coef,
    device
)

In [59]:
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb)                  # uses forward(x) → logits only
            preds  = logits.argmax(dim=1)
            correct += (preds == yb.to(logits.device)).sum().item()
            total   += yb.size(0)
    return correct/total if total else 0.0


In [60]:
def train_one_task(
    task_id: int,
    train_loader: DataLoader,
    test_loader:  DataLoader,
    model:        nn.Module,
    optimizer:    torch.optim.Optimizer,
    scaler:       GradScaler,
    epochs:       int
):
    print(f"=== Training Task {task_id} ===")
    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(model.device), yb.to(model.device)
            optimizer.zero_grad()
            with autocast():
                loss = model(xb, yb)        # memory is updated inside forward
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item() * xb.size(0)
        avg_loss = total_loss / len(train_loader.dataset)
        acc = evaluate(model, test_loader)
        print(f"[Task {task_id} · epoch {epoch}] loss={avg_loss:.4f}, test_acc={acc:.4f}")
    print()


In [36]:
### Without clearing memory after task 1

In [37]:
# 1. Prepare data
train1, test1, train2, test2 = train_test_split_loader(batch_size=128)

# 2. Instantiate model, optimizer, scaler
optimizer = Adam(model.parameters(), lr=1e-3)
scaler    = GradScaler()
epochs    = 10

# 3. Task 1
train_one_task(1, train1, test1, model, optimizer, scaler, epochs)

# 4. inspect Task1 retention
acc1_postT1 = evaluate(model, test1)
print("Task1 end-of-training acc:", acc1_postT1)

# 5. Task 2 (memory remains, so replay is possible)
train_one_task(2, train2, test2, model, optimizer, scaler, epochs)

# 6. Final joint eval
acc1_final = evaluate(model, test1)
acc2_final = evaluate(model, test2)
print(f"Final acc Task1: {acc1_final:.4f}, Task2: {acc2_final:.4f}")


=== Training Task 1 ===


  scaler    = GradScaler()
  with autocast():


[Task 1 · epoch 1] loss=524.9949, test_acc=0.9745
[Task 1 · epoch 2] loss=243.3017, test_acc=0.9850
[Task 1 · epoch 3] loss=100.5207, test_acc=0.9871
[Task 1 · epoch 4] loss=135.0226, test_acc=0.9874
[Task 1 · epoch 5] loss=50.1477, test_acc=0.9856
[Task 1 · epoch 6] loss=18.5658, test_acc=0.9912
[Task 1 · epoch 7] loss=4.8374, test_acc=0.9905
[Task 1 · epoch 8] loss=3.0064, test_acc=0.9873
[Task 1 · epoch 9] loss=1.1887, test_acc=0.9833
[Task 1 · epoch 10] loss=0.0381, test_acc=0.9877

Task1 end-of-training acc: 0.9877430952770061
=== Training Task 2 ===
[Task 2 · epoch 1] loss=2.6182, test_acc=0.0000
[Task 2 · epoch 2] loss=2.3647, test_acc=0.0000
[Task 2 · epoch 3] loss=2.3410, test_acc=0.0000
[Task 2 · epoch 4] loss=2.3160, test_acc=0.0000
[Task 2 · epoch 5] loss=2.2898, test_acc=0.0000
[Task 2 · epoch 6] loss=2.2626, test_acc=0.2060
[Task 2 · epoch 7] loss=2.2345, test_acc=0.2060
[Task 2 · epoch 8] loss=440.0653, test_acc=0.2060
[Task 2 · epoch 9] loss=410.3359, test_acc=0.2060
[T

In [38]:
acc1_postT1

0.9877430952770061

In [39]:
acc1_final


0.0

In [40]:
acc2_final

0.20595238095238094

In [43]:
# 1. Prepare data
train1, test1, train2, test2 = train_test_split_loader(batch_size=128)

# 2. Instantiate model, optimizer, scaler
optimizer = Adam(model.parameters(), lr=1e-3)
scaler    = GradScaler()
epochs    = 10

# 3. Task 1
train_one_task(1, train1, test1, model, optimizer, scaler, epochs)

# 4. inspect Task1 retention
acc1_postT1 = evaluate(model, test1)
print("Task1 end-of-training acc:", acc1_postT1)

# 5. Clear memory before Task 2 (fresh slate)
model.memory.clear()
train_one_task(2, train2, test2, model, optimizer, scaler, epochs)

# 6. Final joint eval
acc1_final = evaluate(model, test1)
acc2_final = evaluate(model, test2)
print(f"Final acc Task1: {acc1_final:.4f}, Task2: {acc2_final:.4f}")


=== Training Task 1 ===


  scaler    = GradScaler()
  with autocast():


[Task 1 · epoch 1] loss=147.3190, test_acc=0.9773
[Task 1 · epoch 2] loss=16.4186, test_acc=0.9922
[Task 1 · epoch 3] loss=3.2950, test_acc=0.9922
[Task 1 · epoch 4] loss=0.6319, test_acc=0.9912
[Task 1 · epoch 5] loss=0.2859, test_acc=0.9900
[Task 1 · epoch 6] loss=0.1678, test_acc=0.9956
[Task 1 · epoch 7] loss=0.0990, test_acc=0.9935
[Task 1 · epoch 8] loss=0.0697, test_acc=0.9946
[Task 1 · epoch 9] loss=0.0596, test_acc=0.9946
[Task 1 · epoch 10] loss=0.1681, test_acc=0.9941

Task1 end-of-training acc: 0.9941166857329629
=== Training Task 2 ===
[Task 2 · epoch 1] loss=139.4774, test_acc=0.9517
[Task 2 · epoch 2] loss=0.4626, test_acc=0.9702
[Task 2 · epoch 3] loss=0.4414, test_acc=0.9794
[Task 2 · epoch 4] loss=0.2438, test_acc=0.9838
[Task 2 · epoch 5] loss=0.1068, test_acc=0.9845
[Task 2 · epoch 6] loss=0.0812, test_acc=0.9864
[Task 2 · epoch 7] loss=0.0677, test_acc=0.9872
[Task 2 · epoch 8] loss=0.0579, test_acc=0.9891
[Task 2 · epoch 9] loss=0.5099, test_acc=0.9757
[Task 2 · e

In [44]:
acc1_postT1

0.9941166857329629

In [45]:
acc1_final

0.0

In [61]:
latent_dim = 16
hidden_dim = 32
num_classes = 10
capacity = 5000
decay_rate = 0.05
top_k = 5
lambda_coef = 0.5
device = torch.device('cpu')
model = EpiNetModel(
    latent_dim, hidden_dim, num_classes,
    capacity, decay_rate, top_k, lambda_coef,
    device
)

In [62]:
# 1. Prepare data
train1, test1, train2, test2 = train_test_split_loader(batch_size=128)

# 2. Instantiate model, optimizer, scaler
optimizer = Adam(model.parameters(), lr=1e-3)
scaler    = GradScaler()
epochs    = 10

# 3. Task 1
train_one_task(1, train1, test1, model, optimizer, scaler, epochs)

# 4. inspect Task1 retention
acc1_postT1 = evaluate(model, test1)
print("Task1 end-of-training acc:", acc1_postT1)

# 5. Task 2 (memory remains, so replay is possible)
train_one_task(2, train2, test2, model, optimizer, scaler, epochs)

# 6. Final joint eval
acc1_final = evaluate(model, test1)
acc2_final = evaluate(model, test2)
print(f"Final acc Task1: {acc1_final:.4f}, Task2: {acc2_final:.4f}")

=== Training Task 1 ===


  scaler    = GradScaler()
  with autocast():


[Task 1 · epoch 1] loss=170.8623, test_acc=0.9734
[Task 1 · epoch 2] loss=34.0531, test_acc=0.9863
[Task 1 · epoch 3] loss=25.6282, test_acc=0.9884
[Task 1 · epoch 4] loss=17.4391, test_acc=0.9917
[Task 1 · epoch 5] loss=6.2336, test_acc=0.9936
[Task 1 · epoch 6] loss=5.5208, test_acc=0.9917
[Task 1 · epoch 7] loss=5.1656, test_acc=0.9956
[Task 1 · epoch 8] loss=5.3923, test_acc=0.9943
[Task 1 · epoch 9] loss=2.6825, test_acc=0.9943
[Task 1 · epoch 10] loss=1.0483, test_acc=0.9946

Task1 end-of-training acc: 0.9946069619218827
=== Training Task 2 ===
[Task 2 · epoch 1] loss=1.5114, test_acc=0.7816
[Task 2 · epoch 2] loss=0.4180, test_acc=0.9262
[Task 2 · epoch 3] loss=37.8773, test_acc=0.9471
[Task 2 · epoch 4] loss=60.0141, test_acc=0.9668
[Task 2 · epoch 5] loss=42.7368, test_acc=0.9764
[Task 2 · epoch 6] loss=29.0783, test_acc=0.9759
[Task 2 · epoch 7] loss=22.6780, test_acc=0.9794
[Task 2 · epoch 8] loss=11.2401, test_acc=0.9830
[Task 2 · epoch 9] loss=10.1574, test_acc=0.9835
[Tas

In [63]:
latent_dim = 16
hidden_dim = 32
num_classes = 10
capacity = 50000
decay_rate = 0.05
top_k = 5
lambda_coef = 0.5
device = torch.device('cpu')
model = EpiNetModel(
    latent_dim, hidden_dim, num_classes,
    capacity, decay_rate, top_k, lambda_coef,
    device
)

In [64]:
# 1. Prepare data
train1, test1, train2, test2 = train_test_split_loader(batch_size=128)

# 2. Instantiate model, optimizer, scaler
optimizer = Adam(model.parameters(), lr=1e-3)
scaler    = GradScaler()
epochs    = 10

# 3. Task 1
train_one_task(1, train1, test1, model, optimizer, scaler, epochs)

# 4. inspect Task1 retention
acc1_postT1 = evaluate(model, test1)
print("Task1 end-of-training acc:", acc1_postT1)

# 5. Task 2 (memory remains, so replay is possible)
train_one_task(2, train2, test2, model, optimizer, scaler, epochs)

# 6. Final joint eval
acc1_final = evaluate(model, test1)
acc2_final = evaluate(model, test2)
print(f"Final acc Task1: {acc1_final:.4f}, Task2: {acc2_final:.4f}")

=== Training Task 1 ===


  scaler    = GradScaler()
  with autocast():


[Task 1 · epoch 1] loss=155.1899, test_acc=0.9632
[Task 1 · epoch 2] loss=28.1807, test_acc=0.9788
[Task 1 · epoch 3] loss=13.0096, test_acc=0.9856
[Task 1 · epoch 4] loss=7.7389, test_acc=0.9886
[Task 1 · epoch 5] loss=5.7921, test_acc=0.9920
[Task 1 · epoch 6] loss=4.8277, test_acc=0.9940
[Task 1 · epoch 7] loss=3.1445, test_acc=0.9949
[Task 1 · epoch 8] loss=3.1893, test_acc=0.9941
[Task 1 · epoch 9] loss=3.1580, test_acc=0.9956
[Task 1 · epoch 10] loss=2.2608, test_acc=0.9951

Task1 end-of-training acc: 0.9950972381108024
=== Training Task 2 ===
[Task 2 · epoch 1] loss=171.7361, test_acc=0.9789
[Task 2 · epoch 2] loss=11.2597, test_acc=0.9849
[Task 2 · epoch 3] loss=7.3447, test_acc=0.9840
[Task 2 · epoch 4] loss=5.2615, test_acc=0.9869
[Task 2 · epoch 5] loss=5.1503, test_acc=0.9878
[Task 2 · epoch 6] loss=3.4098, test_acc=0.9889
[Task 2 · epoch 7] loss=3.4443, test_acc=0.9900
[Task 2 · epoch 8] loss=2.2764, test_acc=0.9881
[Task 2 · epoch 9] loss=2.2736, test_acc=0.9906
[Task 2 ·