In [2]:
!pip install equinox



In [3]:
# Colab-only cell: prompts for files to upload into the runtime.
# NOTE(review): `uploaded` is not referenced by any later cell visible in this
# notebook — confirm this cell is still needed.
from google.colab import files
uploaded = files.upload()

KeyboardInterrupt: 

In [4]:
# Colab-only cell: mount Google Drive at /content/drive, forcing a remount.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# NOTE(review): the path tweak below is disabled. Presumably it was meant to
# make the project's `scripts` package importable from Drive — confirm, then
# either re-enable it or remove it.
# import sys
# sys.path.append('/content/drive/My Drive/julian_scripts')

ValueError: mount failed

In [1]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import matplotlib.pyplot as plt

from scripts.models import LinearFlow, ConcatMLP, Autoencoder, SharedWeightAutoencoder, MLP
from scripts.losses import CNF_batch_loss, CNF_reverse_kl_batch_loss
from scripts.distributions import sample_multimodal_gaussian, multimodal_gaussian_logpdf, define_distributions, get_hypersphere_modes, find_3_orthogonal_points
from scripts.training import train_CNF, train_CNF_parallel, train_CNF_analytic_parallel
from scripts.utils.ode_solver import phi
from scripts.utils.plotting import make_fig_ax
from scripts.utils.distribution_statistics import stable_rank_svd

ModuleNotFoundError: No module named 'scripts'

In [2]:
# Registry: model-type name -> class used to construct that model.
# Looked up via `model_dict[model_type]` by the training helpers below.
model_dict = dict(
    linear=LinearFlow,
    mlp=MLP,
    autoencoder=Autoencoder,
    shared_weight_autoencoder=SharedWeightAutoencoder,
    concat=ConcatMLP,
)

def _rank_of_first_weight(weight_list):
    """Stable rank of the first weight matrix (converted to a JAX array)."""
    return stable_rank_svd(jnp.array(weight_list[0]))


def _rank_of_weight_and_antisym_part(weight_list):
    """Stable ranks of the first weight matrix and of its antisymmetric part."""
    weight = weight_list[0]
    return jnp.array([stable_rank_svd(weight), stable_rank_svd((weight - weight.T)/2)])


def _ranks_at(first_idx, second_idx):
    """Build a function returning the stable ranks of two entries of a weight list."""
    def _ranks(weight_list):
        return jnp.array([stable_rank_svd(weight_list[first_idx]), stable_rank_svd(weight_list[second_idx])])
    return _ranks


# Registry: model-type name -> function mapping a model's weight list to the
# stable rank(s) tracked for that architecture.
calc_rank_dict = {
    'linear': _rank_of_first_weight,
    'linear_with_bias': _rank_of_first_weight,
    'linear_antisym': _rank_of_weight_and_antisym_part,
    'mlp': _ranks_at(0, 2),
    'autoencoder': _ranks_at(0, 1),
    'shared_weight_autoencoder': _rank_of_first_weight,
    'concat': _ranks_at(0, 5),
}

"""
model args:
antisym=(dim, key, init_var)
linear=(dim, key, init_var, init_weight)
concat=(datasize, width, depth, key, init_std)
mlp=(dim, width, depth, hidden_act, final_act, init_std, key)
"""

'\nmodel args:\nantisym=(dim, key, init_var)\nlinear=(dim, key, init_var, init_weight)\nconcat=(datasize, width, depth, key, init_std)\nmlp=(dim, width, depth, hidden_act, final_act, init_std, key)\n'

In [3]:
def create_cnf_model(model_skeleton, key):
    """Build one CNF model by calling ``model_skeleton`` with ``key``.

    Args:
        model_skeleton: Callable taking a single PRNG key and returning a model.
        key: The key forwarded to ``model_skeleton``.

    Returns:
        Whatever ``model_skeleton(key)`` returns.
    """
    cnf_model = model_skeleton(key)
    return cnf_model

def unbatch_models(batched_models):
    """
    Convert a batched pytree back into a list of individual models.

    Every array leaf is expected to carry the model index on its leading axis.
    The number of models is read from the first such leaf, so this works for
    any model type (it does not require a ``W`` attribute).

    Input: SharedWeightAutoencoder(W=f32[5,2,2], b=f32[5,2], d=weak_i32[5], h=weak_i32[5])
    Output: [model_0, model_1, model_2, model_3, model_4]

    Args:
        batched_models: Pytree whose array leaves are stacked along axis 0.

    Returns:
        List of pytrees with the same structure as ``batched_models``, where
        the i-th entry holds the i-th slice of every batched array leaf.
        Leaves without a leading axis (non-arrays, 0-d arrays) are passed
        through unchanged.

    Raises:
        ValueError: If the pytree contains no array leaf with a leading axis.
    """
    def _has_batch_axis(leaf):
        # Only array-like leaves with at least one dimension can be sliced.
        return getattr(leaf, 'ndim', 0) >= 1

    batched_leaves = [
        leaf for leaf in jax.tree_util.tree_leaves(batched_models) if _has_batch_axis(leaf)
    ]
    if not batched_leaves:
        raise ValueError(
            "unbatch_models expects at least one array leaf with a leading batch axis"
        )
    num_models = batched_leaves[0].shape[0]

    def _select(index):
        # Extract the index-th slice from each batched array in the pytree.
        return jax.tree_util.tree_map(
            lambda leaf: leaf[index] if _has_batch_axis(leaf) else leaf,
            batched_models,
        )

    return [_select(index) for index in range(num_models)]

#### Train parallel with backprop

In [4]:
def initalize_and_train_parallel(num_models, dim, radii, training_iterations, key, num_repeats, model_type, model_args, mode_arrangement, unimodal_init=False, batch_size=128, verbose=True, eps_deg=0):
    """Initialise a batch of CNF models and train them in parallel with backprop.

    One model is created per entry of ``radii``. Model ``k`` is initialised
    from the ``k % num_repeats``-th of ``num_repeats`` initialisation keys, so
    the same set of initialisations is reused across radii. Each model gets
    two initial and two target mode centres, and is trained with SGD
    (learning rate 1e-3) on the reverse-KL CNF loss.

    Args:
        num_models: Number of models; must match the number of entries in
            ``radii`` for every arrangement that scales modes by radius.
        dim: Dimension of the sample space.
        radii: 1-D array with one radius per model.
        training_iterations: Forwarded to ``train_CNF_parallel``.
        key: JAX PRNG key used to derive the training and initialisation keys.
        num_repeats: Number of distinct model initialisations.
        model_type: Key into ``model_dict`` selecting the model class.
        model_args: Positional arguments passed to the model class after the
            initialisation key.
        mode_arrangement: One of ``'orthogonal'``, ``'orthogonal_eps_overlap'``,
            ``'symmetric'`` or ``'random_hypersphere_orthogonal'``.
        unimodal_init: If True, initial samples are drawn from a standard
            normal and the initial modes are ignored by the sampler (they are
            still built and handed to the trainer).
        batch_size: Number of initial samples drawn per call of the sampler.
        verbose: Forwarded to ``train_CNF_parallel``.
        eps_deg: Rotation angle in degrees applied to the initial modes; only
            used by ``'orthogonal_eps_overlap'``.

    Returns:
        The tuple ``(trained_models, loss_history, weight_history,
        rank_history, grads_history)`` returned by ``train_CNF_parallel``.

    Raises:
        ValueError: If ``mode_arrangement`` is unknown, if ``radii`` cannot be
            matched to ``num_models``, or if ``dim < 4`` with
            ``'orthogonal_eps_overlap'``.
    """
    supported_arrangements = ('orthogonal', 'orthogonal_eps_overlap', 'symmetric', 'random_hypersphere_orthogonal')
    if mode_arrangement not in supported_arrangements:
        raise ValueError(f"Unknown mode_arrangement {mode_arrangement!r}; expected one of {supported_arrangements}")
    radii = jnp.asarray(radii)

    # Key derivation is deliberately kept as before (first sub-key of a 4-way
    # split, then re-split) so existing seeds reproduce the same runs.
    key = jax.random.split(key, 4)[0]
    train_key, *keys = jax.random.split(key, num_repeats+1)

    # Initialise models: one per radius, cycling through the num_repeats
    # different initialisations.
    model_skeleton = lambda init_key: model_dict[model_type](init_key, *model_args)
    models_list = [create_cnf_model(model_skeleton, keys[k%num_repeats]) for k in range(radii.shape[0])]
    models = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *models_list)

    # Build the mode centres. This is done regardless of unimodal_init so that
    # both batches always exist when the trainer is called.
    identity = jnp.identity(dim)
    num_initial_modes = 2
    num_target_modes = 2
    if mode_arrangement == 'random_hypersphere_orthogonal':
        # The same fixed key is used for every radius.
        initial_modes1 = jax.vmap(lambda radius: get_hypersphere_modes(dim, 1, radius, jax.random.PRNGKey(0))[0])(radii)
        new_initial_modes = jax.vmap(lambda initial_mode1: find_3_orthogonal_points(initial_mode1.reshape((-1, 1))))(initial_modes1)

        def stack_modes(initial, new_modes):
            # First orthogonal point joins the initial modes; the rest are targets.
            initial_modes = jnp.stack([initial, new_modes[0]])
            target_modes = new_modes[1:]
            return initial_modes, target_modes

        initial_modes_batch, target_modes_batch = jax.vmap(stack_modes)(initial_modes1, new_initial_modes)
    else:
        if radii.size != num_models:
            raise ValueError(f"radii has {radii.size} entries but num_models is {num_models}")
        radii_column = radii.reshape((num_models, 1, 1))

        if mode_arrangement == 'symmetric':
            # Rows alternate +e_0, -e_0, +e_1, -e_1, ...
            candidate_modes = jnp.repeat(identity, 2, axis=0) * (-1)**jnp.arange(2*dim)[:, jnp.newaxis]
        else:
            candidate_modes = identity
        initial_modes = candidate_modes[jnp.arange(num_initial_modes)]
        target_modes = candidate_modes[jnp.arange(num_initial_modes, num_initial_modes+num_target_modes)]

        if mode_arrangement == 'orthogonal_eps_overlap':
            if dim < 4:
                raise ValueError("mode_arrangement 'orthogonal_eps_overlap' requires dim >= 4")
            # Rotate the initial modes by eps_deg in the (0, 2) and (1, 3)
            # coordinate planes, i.e. towards the (unrotated) target modes.
            eps_rad = jnp.deg2rad(eps_deg)
            eps_rotation = jnp.array([[jnp.cos(eps_rad), 0, -jnp.sin(eps_rad), 0],
                                    [0, jnp.cos(eps_rad), 0, -jnp.sin(eps_rad)],
                                    [jnp.sin(eps_rad), 0 , jnp.cos(eps_rad), 0],
                                    [0, jnp.sin(eps_rad), 0, jnp.cos(eps_rad)]])
            identity_reduced = jnp.identity(dim-4)
            zeros = jnp.zeros((dim-4, 4))
            rotation_block = jnp.block([[eps_rotation, zeros.T], [zeros, identity_reduced]])
            initial_modes = (rotation_block @ initial_modes.T).T

        initial_modes_batch = jnp.tile(initial_modes, (num_models, 1, 1))*radii_column
        target_modes_batch = jnp.tile(target_modes, (num_models, 1, 1))*radii_column

    # Define initial distribution
    if unimodal_init:
        # `shape` is passed as a tuple, the documented form for the batch shape.
        get_initial_samples = lambda key, initial_modes: jax.random.multivariate_normal(key, mean=jnp.zeros(dim), cov=jnp.identity(dim), shape=(batch_size,))
    else:
        initial_covs = jnp.tile(jnp.identity(dim), (num_initial_modes, 1, 1))
        initial_weights = jnp.ones(num_initial_modes)
        get_initial_samples = lambda key, initial_modes: sample_multimodal_gaussian(key, means=initial_modes, covs=initial_covs, weights=initial_weights, num_samples=batch_size)

    # Define target distribution
    target_covs = jnp.tile(jnp.identity(dim), (num_target_modes, 1, 1)) # dimension=dim, num_modes=2
    target_weights = jnp.ones(num_target_modes)
    target_pdf = lambda x, target_modes: multimodal_gaussian_logpdf(x, target_modes, target_covs, target_weights)

    # Loss fn
    ts_forwards = [0, 1, 0.01]
    loss_fn = lambda model, zs, key, target_modes: CNF_reverse_kl_batch_loss(model, zs, ts_forwards, lambda x: target_pdf(x, target_modes), key, approx=True)

    # Training params
    lr = 1e-3
    optimizer = optax.sgd(lr)
    calc_ranks_parallel = lambda weight_list: stable_rank_svd(jnp.array(weight_list))

    trained_models, loss_history, weight_history, rank_history, grads_history = train_CNF_parallel(models, get_initial_samples, loss_fn,
                                                            optimizer, train_key, calc_ranks_parallel,
                                                            initial_modes_batch, target_modes_batch, training_iterations, save_weights_and_grads=False,
                                                            verbose=verbose)
    return trained_models, loss_history, weight_history, rank_history, grads_history

In [5]:
def initalize_and_train_parallel_analytic(num_models, dim, radius, num_initial_modes, num_target_modes, training_iterations, key, model_type, model_args, mode_arrangement, batch_size, optimizer, unimodal_init=False, verbose=True, eps_deg=0):
    """
    Train ``num_models`` differently seeded models in parallel with the analytic gradient.

    Dimension, radius and the numbers of modes are shared by all models; only
    the initialisation key differs. The bias history produced by the trainer
    is not returned.

    Returns:
        ``(trained_models, loss_history, weight_history, rank_history, grads_history)``
    """
    # Sub-keys: [0] unused, [1] training, [2:] one initialisation key per model.
    _, train_key, *init_keys = jax.random.split(key, num_models + 2)

    def model_skeleton(init_key):
        # Model classes take the PRNG key first, then the user-supplied args.
        return model_dict[model_type](init_key, *model_args)

    # Build each model separately, then stack matching leaves along a new leading axis.
    models_list = [create_cnf_model(model_skeleton, init_key) for init_key in init_keys]
    models = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *models_list)

    # Same mode centres for every model (fixed key), replicated along the model axis.
    initial_modes_single, target_modes_single, _, _, _, _ = define_distributions(
        dim, radius, radius, num_initial_modes, num_target_modes, mode_arrangement,
        num_samples=1024, key=jax.random.PRNGKey(0), unimodal_init=unimodal_init, eps_deg=eps_deg,
    )
    tile_reps = (num_models, 1, 1)
    initial_modes_batch = jnp.tile(initial_modes_single, tile_reps)
    target_modes_batch = jnp.tile(target_modes_single, tile_reps)

    def calc_ranks_parallel(weight_list):
        # Stable rank computed on the stacked weights.
        return stable_rank_svd(jnp.array(weight_list))

    training_outputs = train_CNF_analytic_parallel(
        models, initial_modes_batch, target_modes_batch,
        optimizer, train_key, calc_ranks_parallel, training_iterations,
        batch_size, save_weights_and_grads=True, save_biases=False, verbose=verbose,
    )
    trained_models, loss_history, weight_history, biases_history, rank_history, grads_history = training_outputs
    return trained_models, loss_history, weight_history, rank_history, grads_history

#### Train using backprop

In [None]:
# Experiment setup: every radius in 4..8 is trained from `num_repeats`
# different initialisations (radii = [4]*10 + [5]*10 + ...).
num_repeats = 10
radii = jnp.repeat(jnp.arange(4, 9), num_repeats)
dim = 30
model_args = (dim, 0.5)
eps_deg = 2
mode_arrangement = 'orthogonal_eps_overlap'
# num_models is derived from `radii` so it cannot drift out of sync when
# num_repeats or the radius range is edited.
models, losses, weights, ranks, grads = initalize_and_train_parallel(num_models=radii.shape[0], dim=dim, radii=radii, 
                                                                     training_iterations=5000, key=jax.random.PRNGKey(0), 
                                                                     num_repeats=num_repeats, model_type='linear', 
                                                                     model_args=model_args, mode_arrangement=mode_arrangement, 
                                                                     unimodal_init=False, batch_size=128, verbose=True, eps_deg=eps_deg)

TypeError: initalize_and_train_parallel_analytic() got an unexpected keyword argument 'radii'

#### Train with analytic gradient

In [11]:
key = jax.random.PRNGKey(0)

# Problem definition: `num_models` independently seeded models share the same
# dimension, radius and mode layout.
num_models = 30
dim = 20
radius = 10
num_initial_modes = 3
num_target_modes = 3
model_type = 'linear'
model_args = (dim, 0.01)
mode_arrangement = 'orthogonal'

# Optimisation settings
training_iterations = 5000
batch_size = 10000
lr = 1e-3
optimizer = optax.sgd(lr)

# Run the analytic-gradient training; arguments are passed by name so the
# long signature is easy to audit.
models, losses, weights, ranks, grads = initalize_and_train_parallel_analytic(
    num_models=num_models,
    dim=dim,
    radius=radius,
    num_initial_modes=num_initial_modes,
    num_target_modes=num_target_modes,
    training_iterations=training_iterations,
    key=key,
    model_type=model_type,
    model_args=model_args,
    mode_arrangement=mode_arrangement,
    batch_size=batch_size,
    optimizer=optimizer,
    unimodal_init=False,
    verbose=True,
    eps_deg=0,
)


Mean Grad Norm: 28.50529:   1%|          | 27/5000 [00:12<39:17,  2.11it/s] 


KeyboardInterrupt: 

#### Unbatch models

In [9]:
# Split the batched result into per-seed models and serialise each one together
# with its training history.
unbatched_models = unbatch_models(models)

# The file name is built from the run configuration instead of being
# hard-coded, so saved artefacts cannot be mislabelled when dim / radius /
# modes / arrangement change. The mode count is taken from num_target_modes.
run_name = f'{model_type.capitalize()}_{dim}D_{radius}radius_{num_target_modes}modes_{mode_arrangement}'

for i, model in enumerate(unbatched_models):
    # Histories are indexed (iteration, model); take the column for model i.
    save_losses = losses[:, i]
    save_weights = weights[:, i]
    # Placeholders: biases, ranks and gradients are not saved for this run.
    save_biases = jnp.empty(0)
    save_ranks = jnp.empty(0) #ranks[:, i]
    save_grads = jnp.empty(0) #grads[:, i]
    to_save = {
        'model': model,
        'losses': save_losses,
        'weights': save_weights,
        'biases': save_biases,
        'weight_ranks': save_ranks,
        'gradients': save_grads
    }
    eqx.tree_serialise_leaves(f'{run_name}_seed{i}.eqx', to_save)