### Sleep simulations

#### Installation:

In [None]:
!pip install numpy==1.24.2
!pip tensorflow-macos==2.11.0

#### Imports:

In [None]:
from sleep_utils import *
from random import shuffle

#### Train initial VAEs

Train initial VAEs to avoid repeating this each time (leave commented out to use the VAE weights provided):

In [None]:
# seeds = range(0, 1)
# train_with_schedule_multiple_seeds(seeds, 
#                        num_cycles=10, 
#                        start_fraction_rem=0, 
#                        end_fraction_rem=0,
#                        inverted=True,
#                        use_initial_weights=False)

# !mv decoder.h5 decoder_inv.h5
# !mv encoder.h5 encoder_inv.h5

In [None]:
# seeds = range(0, 1)
# train_with_schedule_multiple_seeds(seeds, 
#                        num_cycles=10, 
#                        start_fraction_rem=0, 
#                        end_fraction_rem=0,
#                        inverted=False,
#                        use_initial_weights=False)

# !mv decoder.h5 decoder_non_inv.h5
# !mv encoder.h5 encoder_non_inv.h5

#### Baselines without sleep phase alternation

Before modelling how differing schedules of REM / NREM sleep stages affect learning, let's just test whether generative replay helps avoid catastrophic forgetting of representations.

The shuffled_baselines() function below can be used to do this. With baseline_type='new', only the new memories are used to train the VAE. With baseline_type='old', only self-generated memories (i.e. samples from the existing VAE) are used to train the VAE. With baseline_type='both', a shuffled mixture of new and self-generated memories is used.

In [None]:
def shuffled_baselines(baseline_type='both',
                       use_initial_weights=True, 
                       latent_dim=5, 
                       seed=0, 
                       inverted=True, 
                       lr=0.001,
                       num=1000,
                       continue_training=True):
    """Train an MNIST VAE on new and/or self-generated items and report errors.

    The VAE is evaluated on ``mnist_test`` and ``fmnist_test`` before and
    after a short round of training (10 epochs, batch size 1) on ``num``
    items chosen according to ``baseline_type``.

    Args:
        baseline_type: Which items to train on. ``'new'`` uses a random
            subset of ``fmnist_train``; ``'old'`` uses samples drawn from
            the VAE itself; ``'both'`` uses a random selection of ``num``
            items from the union of the two.
        use_initial_weights: If ``False``, train a fresh VAE on
            ``mnist_train`` first; otherwise load the saved ``.h5`` weights.
        latent_dim: Latent dimension passed to the model constructors and
            to ``sample_item``.
        seed: Seed for NumPy's global random generator.
        inverted: If ``True``, split the datasets by inversion and use the
            ``*_inv.h5`` weights; otherwise split by digits and use the
            ``*_non_inv.h5`` weights.
        lr: Learning rate of the Adam optimizer used for the training round.
        num: Number of training items.
        continue_training: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(mnist_errors, fmnist_errors)``. Each is a two-element
        list holding the mean error returned by ``plot_error_dists`` before
        and after training.

    Raises:
        ValueError: If ``baseline_type`` is not ``'new'``, ``'old'`` or
            ``'both'``.
    """
    # Fail fast: previously an unknown value only surfaced as a NameError on
    # `train_data` after the datasets and model had already been prepared.
    valid_types = ('new', 'old', 'both')
    if baseline_type not in valid_types:
        raise ValueError(
            f"baseline_type must be one of {valid_types}, got {baseline_type!r}"
        )

    np.random.seed(seed)

    if inverted is True:
        mnist_train, mnist_test, fmnist_train, fmnist_test = prepare_datasets(
            split_by_digits=False, split_by_inversion=True)
    else:
        mnist_train, mnist_test, fmnist_train, fmnist_test = prepare_datasets(
            split_by_digits=True, split_by_inversion=False)

    if use_initial_weights is False:
        # Keep the freshly trained model. It used to be discarded and replaced
        # by a model loaded from the saved weights, which made this branch a
        # no-op apart from the wasted training time.
        # NOTE(review): assumes train_mnist_vae returns the trained VAE
        # instance - confirm against sleep_utils.
        vae = train_mnist_vae(mnist_train, 'mnist', generative_epochs=25,
                              learning_rate=0.001, latent_dim=latent_dim)
    else:
        print("Starting with saved weights:")
        encoder, decoder = models_dict['mnist'](latent_dim=latent_dim)
        vae = VAE(encoder, decoder, kl_weighting=1)
        # Same condition as the dataset split above, so the weights always
        # match the data (a non-bool `inverted` used to load no weights).
        if inverted is True:
            vae.encoder.load_weights("encoder_inv.h5")
            vae.decoder.load_weights("decoder_inv.h5")
        else:
            vae.encoder.load_weights("encoder_non_inv.h5")
            vae.decoder.load_weights("decoder_non_inv.h5")

    opt = keras.optimizers.Adam(learning_rate=lr, jit_compile=False)
    vae.compile(optimizer=opt)

    # Evaluate the model before the training round.
    m_err, f_err = plot_error_dists(vae, mnist_test, fmnist_test)
    check_generative_recall(vae, mnist_test, noise_level=0.0)

    preview = np.array([sample_item(vae, latent_dim=latent_dim) for _ in range(100)])
    show_samples(preview)

    mnist_errors = [np.mean(m_err)]
    fmnist_errors = [np.mean(f_err)]

    random_indices = np.random.choice(fmnist_train.shape[0], num, replace=False)
    fmnist_subset = fmnist_train[random_indices]

    if baseline_type == 'new':
        # No self-generated items are needed, so skip sampling them.
        train_data = fmnist_subset
    else:
        sampled_digits = np.array(
            [sample_item(vae, latent_dim=latent_dim) for _ in range(num)])
        if baseline_type == 'old':
            train_data = sampled_digits
        else:
            # Pick `num` items at random from the union of both sources.
            # NumPy's generator is used (rather than random.shuffle) so that
            # the selection is controlled by `seed`.
            combined = np.concatenate([sampled_digits, fmnist_subset])
            keep = np.random.permutation(len(combined))[:num]
            train_data = combined[keep]

    print("Show training samples:")
    show_samples(train_data)

    vae.fit(train_data, epochs=10, verbose=0, batch_size=1, shuffle=True)

    # Test reconstruction error of mnist_test and fmnist_test after training.
    m_err, f_err = plot_error_dists(vae, mnist_test, fmnist_test)
    mnist_errors.append(np.mean(m_err))
    fmnist_errors.append(np.mean(f_err))

    check_generative_recall(vae, mnist_test, noise_level=0.0)
    check_generative_recall(vae, fmnist_test, noise_level=0.0)

    return mnist_errors, fmnist_errors

In [None]:
# Baseline: train only on the new items (the fmnist subset), without any
# self-generated samples. The cell output is (mnist_errors, fmnist_errors).
shuffled_baselines(baseline_type='new', inverted=True)

In [None]:
# Baseline: train only on self-generated items sampled from the VAE.
# The cell output is (mnist_errors, fmnist_errors).
shuffled_baselines(baseline_type='old', inverted=True)

In [None]:
# Baseline: train on a shuffled mix of new items and self-generated samples.
# The cell output is (mnist_errors, fmnist_errors).
shuffled_baselines(baseline_type='both', inverted=True)

#### Try different schedules

For example, here we vary just the number of cycles (for a fixed total number of epochs of training).

In [None]:
# Sweep over learning rates, numbers of cycles and (start, end) REM fractions,
# running every configuration for each seed with a fixed total of 50 epochs.
rem_fraction_pairs = [(0.5, 0.5), (0, 0), (1, 1)]
cycles_values = [50, 20, 10, 5]
lrs = [0.001]

seeds = range(0, 3)

for lr in lrs:
    for num_cycles in cycles_values:
        for start_fraction_rem, end_fraction_rem in rem_fraction_pairs:
            train_with_schedule_multiple_seeds(
                seeds,
                total_eps=50,
                # Pass the swept value. This used to be hard-coded to 25, so
                # every iteration over cycles_values ran the same schedule.
                num_cycles=num_cycles,
                start_fraction_rem=start_fraction_rem,
                end_fraction_rem=end_fraction_rem,
                inverted=True,
                lr=lr,
                num=5,
                continue_training=True,
            )
            