### Interleaved training study

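Code to test how much interleaving is optimal when learning two datasets (MNIST and Fashion-MNIST) simultaneously. The training sequence alternates blocks of `n` samples from each dataset, and several values of `n` are compared by reconstruction error.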

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from generative_model import VAE, build_encoder_decoder_small
from generative_tests import check_generative_recall
from sleep_utils import *
import config
from tensorflow import keras
import numpy as np
from random import randrange
from PIL import Image
import matplotlib.backends.backend_pdf
from generative_model import models_dict
import matplotlib
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import random
from tensorflow.keras.datasets import mnist, fashion_mnist
from sklearn.model_selection import train_test_split

In [None]:
def interleave_datasets(mnist_data, fashion_data, n):
    """Interleave two equally sized datasets in alternating runs of ``n`` samples.

    The result is ``first[0:n], second[0:n], first[n:2n], second[n:2n], ...``.
    If the length is not a multiple of ``n``, the final runs are shorter.

    Args:
        mnist_data: First dataset (array-like, samples along axis 0).
        fashion_data: Second dataset, with the same number of samples.
        n: Run length (interleaving factor); must be >= 1.

    Returns:
        A NumPy array with ``2 * len(mnist_data)`` samples along axis 0 and
        the same trailing sample shape as the inputs.

    Raises:
        ValueError: If ``n < 1`` or the datasets differ in length.
    """
    # n <= 0 would never advance through the data (infinite loop in a
    # manual index walk), so reject it explicitly.
    if n < 1:
        raise ValueError(f"Interleaving factor n must be >= 1, got {n}")

    mnist_x = np.asarray(mnist_data)
    fashion_x = np.asarray(fashion_data)

    if len(mnist_x) != len(fashion_x):
        raise ValueError("Datasets should have the same number of samples")

    chunks = []
    for start in range(0, len(mnist_x), n):
        chunks.append(mnist_x[start:start + n])
        chunks.append(fashion_x[start:start + n])

    # np.concatenate needs at least one array; keep the per-sample shape for
    # empty input instead of collapsing to a shape-(0,) array.
    if not chunks:
        return np.empty((0,) + mnist_x.shape[1:], dtype=mnist_x.dtype)

    return np.concatenate(chunks)


In [None]:
def compute_reconstruction_error(vae, test_data):
    """Return the mean squared reconstruction error of ``vae`` on ``test_data``.

    The inputs are encoded, the first element of the encoder's output is
    decoded, and the squared difference to the inputs is averaged over every
    element.
    """
    latent = vae.encoder.predict(test_data)[0]
    reconstruction = vae.decoder.predict(latent)
    squared_diff = np.square(test_data - reconstruction)
    return np.mean(squared_diff)

def train_mnist_vae(train_data, dataset, generative_epochs=50, latent_dim=20, kl_weighting=1,
                    learning_rate=0.01, batch_size=1):
    """Build, train and save a VAE on ``train_data``.

    Args:
        train_data: Training samples, passed directly to ``vae.fit``.
        dataset: Key into ``models_dict`` selecting the encoder/decoder builder.
        generative_epochs: Number of training epochs.
        latent_dim: Latent size forwarded to the ``models_dict[dataset]`` builder.
        kl_weighting: Forwarded to ``VAE`` (presumably the KL-term weight;
            confirm against ``generative_model.VAE``).
        learning_rate: Adam learning rate.
        batch_size: Mini-batch size for ``fit``. The default of 1 makes each
            update see a single sample, in the order given by ``train_data``.

    Returns:
        The trained ``VAE``. As a side effect, its encoder and decoder are
        saved to ``encoder.h5`` and ``decoder.h5`` in the current working
        directory, overwriting any existing files.
    """
    encoder, decoder = models_dict[dataset](latent_dim=latent_dim)
    vae = VAE(encoder, decoder, kl_weighting)
    opt = keras.optimizers.Adam(learning_rate=learning_rate, jit_compile=False)
    vae.compile(optimizer=opt)
    # shuffle=False keeps the sample order exactly as supplied, which matters
    # when train_data has been deliberately interleaved.
    vae.fit(train_data, epochs=generative_epochs, verbose=1,
            batch_size=batch_size, shuffle=False)
    vae.encoder.save('encoder.h5')
    vae.decoder.save('decoder.h5')
    return vae

# Load train/test splits for both datasets with the flags used in this study.
(mnist_train_x, mnist_test_x,
 fashion_train_x, fashion_test_x) = prepare_datasets(split_by_digits=False,
                                                     split_by_inversion=True)

# Truncate every split (presumably to keep each training run short). Equal
# train lengths are also what interleave_datasets requires, assuming both
# datasets provide at least 20000 training samples.
mnist_train_x, fashion_train_x = mnist_train_x[:20000], fashion_train_x[:20000]
mnist_test_x, fashion_test_x = mnist_test_x[:1000], fashion_test_x[:1000]

def run_experiment(random_seed, plot=False, n_values=(1, 5, 10, 50)):
    """Train one VAE per interleaving factor and measure reconstruction error.

    For each ``n`` in ``n_values``, a fresh VAE is trained on the module-level
    MNIST and Fashion-MNIST training sets interleaved in alternating runs of
    ``n`` samples. It is then evaluated on both module-level test sets.

    Args:
        random_seed: Seed applied to Python's ``random``, NumPy and TensorFlow
            so that runs with the same seed are comparable.
        plot: If truthy, show generative-recall and error-distribution plots
            for each trained model.
        n_values: Interleaving factors to test.

    Returns:
        A list of ``(n, mnist_error, fashion_error)`` tuples, one per ``n``.
    """
    # Seeding NumPy alone does not control TensorFlow's weight initialisation,
    # so seed every RNG in play. Cast first: random.seed rejects NumPy integer
    # types on Python 3.11+, and np.random.randint yields NumPy integers.
    seed = int(random_seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    reconstruction_errors = []

    for n in n_values:
        # Combine and interleave the datasets in alternating runs of n samples
        train_x = interleave_datasets(mnist_train_x, fashion_train_x, n)

        # Train a fresh VAE on the interleaved sequence
        vae = train_mnist_vae(train_x, 'mnist', generative_epochs=1, learning_rate=0.001)

        # Compute the mean reconstruction error for MNIST and Fashion-MNIST
        mnist_error = compute_reconstruction_error(vae, mnist_test_x)
        fashion_error = compute_reconstruction_error(vae, fashion_test_x)

        reconstruction_errors.append((n, mnist_error, fashion_error))
        print(f"Interleaving factor {n}: MNIST error = {mnist_error}, Fashion-MNIST error = {fashion_error}")

        if plot:
            check_generative_recall(vae, mnist_test_x, noise_level=0.15)
            check_generative_recall(vae, fashion_test_x, noise_level=0.15)
            plot_error_dists(vae, mnist_test_x, fashion_test_x)

    return reconstruction_errors


In [None]:
# One random seed per repetition of the experiment; results are averaged below.
num_seeds = 1
seeds = np.random.randint(0, 10000, size=num_seeds)

all_errors = []
for seed in seeds:
    print(f"Running experiment with random seed: {seed}")
    errors = run_experiment(seed, plot=True)
    all_errors.append(errors)

# Each run returns a list of (n, mnist_error, fashion_error) tuples, so
# reducing over axis 0 (the seed axis) yields one row per interleaving factor.
mean_errors, std_errors = (
    np.mean(all_errors, axis=0),
    np.std(all_errors, axis=0),
)


In [None]:
plt.figure(figsize=(10, 6))

# Take the x-values from the results themselves (column 0 holds n) instead of
# re-typing the list, so the axis always matches the factors actually tested.
n_values = mean_errors[:, 0]
# Columns 1 and 2 hold the MNIST and Fashion-MNIST errors, averaged over seeds.
plt.errorbar(n_values, mean_errors[:, 1], yerr=std_errors[:, 1], label='MNIST', marker='o', capsize=5)
plt.errorbar(n_values, mean_errors[:, 2], yerr=std_errors[:, 2], label='Fashion-MNIST', marker='o', capsize=5)

plt.xlabel('Interleaving Factor (n)')
plt.ylabel('Reconstruction Error')
plt.title('Reconstruction Error vs. Interleaving Factor (with error bars)')
plt.legend()

plt.show()