In [17]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# SPDX-License-Identifier: CC-BY-4.0
#
# Code for "Bayesian continual learning and forgetting in neural networks"
# Djohan Bonnet, Kellian Cottart, Tifenn Hirtzlin, Tarcisius Januel, Thomas Dalgaty, Elisa Vianello, Damien Querlioz
# arXiv: 2504.13569
# Portions of the code are adapted from the Pytorch project (BSD-3-Clause)
#
# Author: Kellian Cottart <kellian.cottart@gmail.com>
# Date: 2025-07-03

In [18]:

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import os
import seaborn as sns
import re
import json
import pandas as pd
# Font sizes. They are not referenced in this cell; presumably they are
# consumed by plotting cells further down the notebook (TODO confirm).
AXESSIZE = 28
FONTSIZE = 26
TICKSIZE = 24
LEGENDSIZE = 26
# Keep text as real text in exported SVGs instead of converting it to paths.
plt.rcParams['svg.fonttype'] = 'none'
# Ask JAX to preallocate only 5 % of the GPU memory.
# NOTE(review): jax is already imported at this point; this relies on the
# variable being read when the backend is first initialised, not at import
# time -- confirm, or move the assignment above the jax import.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.05"
# Output directory, created up front.
FOLDER = "output-figures"
os.makedirs(FOLDER, exist_ok=True)
results_folder = "RESULTS-PRESYNAPTIC"

# Build one record per run folder and create the DataFrame once at the end,
# instead of re-concatenating a growing frame on every iteration.
records = []
# os.listdir returns entries in arbitrary order; sorting makes the row order
# (and anything indexed by row, such as colors/markers) reproducible.
for folder in sorted(os.listdir(results_folder)):
    current_path = os.path.join(results_folder, folder)
    if not os.path.isdir(current_path):
        # Stray files in the results folder are not runs; skip them instead of
        # failing when trying to open a config below them.
        continue
    # The parameters of a run are read from its first config only.
    config_path = os.path.join(current_path, "config0/config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    # n_iterations is the number of "config*" sub-directories of the run.
    n_iterations = len([
        entry for entry in os.listdir(current_path)
        if entry.startswith("config") and os.path.isdir(os.path.join(current_path, entry))
    ])

    # Optimizers whose name contains "mesu" are labelled with their N parameter.
    if "mesu" in config["optimizer"]:
        opt_label = config["optimizer"] + " N=" + str(config["optimizer_params"]["N"])
    else:
        opt_label = config["optimizer"]

    # Drop the first and last entries of the layer list and keep the first of
    # what remains (presumably the first hidden-layer width -- TODO confirm).
    inner_layers = config["network_params"]["layers"][1:-1]

    row = {
        "path": current_path,
        "opt": opt_label,
        "layers": int(inner_layers[0]),
        "n_tasks": config["n_tasks"],
        "n_epochs": config["epochs"],
        "n_train_samples": config["n_train_samples"],
        "n_test_samples": config["n_test_samples"],
        "n_iterations": n_iterations,
    }
    records.append(row)
df = pd.DataFrame(records)
# One color for each path
colors = sns.color_palette("viridis", len(df["path"].unique()))
# Ten distinct markers, repeated once (20 entries in total).
markers = ["D", "o", "s", "h", "^", "x", "v", "p", "*", "X"] * 2

In [19]:
def _load_run_accuracies(path, n_tasks, n_epochs, n_iterations):
    """Load every saved accuracy array of one run.

    For each iteration ``it`` in ``range(n_iterations)`` the files
    ``<path>/config<it>/accuracy/task=<task>-epoch=<epoch>.npy`` are read with
    tasks as the outer loop and epochs as the inner loop.

    Returns:
        An array whose first axis indexes the iteration and whose second axis
        indexes the (task, epoch) step in loading order; remaining axes are
        whatever each file contains (presumably one accuracy per evaluated
        task -- TODO confirm).

    Raises:
        FileNotFoundError: if any expected ``.npy`` file is missing.
    """
    per_iteration = []
    for it in range(n_iterations):
        accuracy_path = os.path.join(path, f"config{it}", "accuracy")
        steps = [
            jnp.load(os.path.join(accuracy_path, f"task={task}-epoch={epoch}.npy"))
            for task in range(n_tasks)
            for epoch in range(n_epochs)
        ]
        per_iteration.append(jnp.array(steps))
    return jnp.array(per_iteration)


# One tuple per run (same order as the rows of df):
# (mean, std, per-entry mean, per-entry std), all computed at the last step.
data = []
for _, row in df.iterrows():
    path = row["path"]
    n_tasks = row["n_tasks"]
    n_epochs = row["n_epochs"]
    n_iterations = row["n_iterations"]
    # Values are scaled by 100 (shown as percentages in the table).
    full_accuracies = _load_run_accuracies(path, n_tasks, n_epochs, n_iterations) * 100
    # Average over the last axis, then keep only the last (task, epoch) step:
    # one value per iteration.
    accuracy_array = jnp.mean(full_accuracies, -1)[:, -1]
    accuracy_mean = accuracy_array.mean()
    accuracy_std = accuracy_array.std()
    # Last step again, but averaged over iterations only, so the last axis is kept.
    last_step = full_accuracies[:, -1, :]
    last_accuracies_full = last_step.mean(axis=0)
    last_accuracies_std = last_step.std(axis=0)
    data.append((accuracy_mean, accuracy_std, last_accuracies_full, last_accuracies_std))
# Add new columns to df
df["accuracies"] = [d[0] for d in data]
df["accuracies_std"] = [d[1] for d in data]
df["accuracies_full"] = [d[2] for d in data]
df["accuracies_full_std"] = [d[3] for d in data]


In [20]:
# Display the run table, including the accuracy columns added in the previous cell.
df

Unnamed: 0,path,opt,layers,n_tasks,n_epochs,n_train_samples,n_test_samples,n_iterations,accuracies,accuracies_std,accuracies_full,accuracies_full_std
0,RESULTS-PRESYNAPTIC/20250704-190454-permutedmn...,mesu N=5800000,200,10,10,10,10,3,92.45693,0.57739115,"[95.76323, 91.93042, 88.29794, 89.51656, 91.48...","[0.17135017, 1.2763591, 3.149945, 0.53284943, ..."


In [21]:
# "accuracies_full" entry of the row labelled 0. Explicit label-based .loc
# replaces chained [col][0] indexing.
df.loc[0, "accuracies_full"]

Array([95.76323 , 91.93042 , 88.29794 , 89.51656 , 91.48303 , 91.770164,
       92.6449  , 93.55969 , 94.537926, 95.065445], dtype=float32)

In [22]:
# "accuracies_full_std" entry of the row labelled 0. Explicit label-based .loc
# replaces chained [col][0] indexing.
df.loc[0, "accuracies_full_std"]

Array([0.17135017, 1.2763591 , 3.149945  , 0.53284943, 1.0274543 ,
       1.2751026 , 0.23735102, 0.42987263, 0.52653974, 0.14863707],      dtype=float32)

In [None]:
# Results obtained by running the presynaptic-consolidation code.
# One dict per seed, mapping "task<N>" to an accuracy given as a fraction
# (converted to percent below). Keys are not stored in numeric order.
seeds = [
    {
        "task1": 0.845300018787384,
        "task2": 0.8666999936103821,
        "task3": 0.8234000205993652,
        "task4": 0.822700023651123,
        "task5": 0.8640000224113464,
        "task6": 0.8616999983787537,
        "task7": 0.8709999918937683,
        "task8": 0.8824999928474426,
        "task9": 0.8910999894142151,
        "task10": 0.9018999934196472,
    },
    {
        "task1": 0.7404000163078308,
        "task10": 0.9054999947547913,
        "task2": 0.8001999855041504,
        "task3": 0.8154000043869019,
        "task4": 0.8621000051498413,
        "task5": 0.8536999821662903,
        "task6": 0.8615000247955322,
        "task7": 0.8695999979972839,
        "task8": 0.885699987411499,
        "task9": 0.8916000127792358,
    },
    {
        "task1": 0.8238000273704529,
        "task10": 0.9006999731063843,
        "task2": 0.8690999746322632,
        "task3": 0.8434000015258789,
        "task4": 0.8730999827384949,
        "task5": 0.8640999794006348,
        "task6": 0.8640000224113464,
        "task7": 0.849399983882904,
        "task8": 0.8758000135421753,
        "task9": 0.883400022983551,
    },
]


def _task_number(task_name):
    """Return N as an int for a key of the form "task<N>" (the text after the 4-character prefix)."""
    return int(task_name[4:])


# Build one row per seed with the values ordered by task number, so that
# "task10" ends up last rather than right after "task1".
rows_by_seed = []
for seed in seeds:
    ordered_names = sorted(seed, key=_task_number)
    rows_by_seed.append(jnp.array([seed[name] for name in ordered_names]))
# Shape (n_seeds, n_tasks), expressed in percent.
accuracies = jnp.array(rows_by_seed) * 100

# Mean over the task axis: one value per seed.
mean_per_seed = accuracies.mean(-1)
report = [
    ("---- accuracies ----", accuracies),
    # Reduced over axis 0 (seeds): one value per task.
    ("---- accuracies mean per task per seed ----", accuracies.mean(0)),
    ("---- accuracies std per task per seed ----", accuracies.std(0)),
    # Mean / std across seeds of the per-seed mean over tasks.
    ("---- accuracies mean over tasks over seeds ----", mean_per_seed.mean()),
    ("---- accuracies std over tasks over seeds ----", mean_per_seed.std()),
]
for label, value in report:
    print(label)
    print(value)

---- accuracies ----
[[84.53     86.67     82.340004 82.270004 86.4      86.17     87.1
  88.25     89.11     90.19    ]
 [74.04     80.02     81.54     86.21     85.369995 86.15     86.96
  88.57     89.16     90.55    ]
 [82.380005 86.909996 84.34     87.31     86.409996 86.4      84.939995
  87.58     88.340004 90.07    ]]
---- accuracies mean per task per seed ----
[80.31667  84.53334  82.740005 85.263336 86.06     86.240005 86.333336
 88.13335  88.87001  90.270004]
---- accuracies std per task per seed ----
[4.5242333  3.1929123  1.1775657  2.1637182  0.48792246 0.11343202
 0.9868937  0.41249824 0.37532127 0.20396195]
---- accuracies mean over tasks over seeds ----
85.87601
---- accuracies std over tasks over seeds ----
0.72368556
