In [19]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import znnl as nl
import tensorflow_datasets as tfds
import numpy as np
from flax import linen as nn
import optax
from jax import random
from jax.lib import xla_bridge
import matplotlib.pyplot as plt
print(f"Using: {xla_bridge.get_backend().platform}")
import copy as cp
import jax.numpy as jnp
from papyrus.measurements import Loss, Accuracy, NTKEntropy, NTKTrace

Using: gpu


The Data Filter

In [20]:
def filter_numbers(data_set: dict, nr_type) -> dict:
    """
    Takes in a data set and returns a filtered data set only consisting of odd or even numbers.

    The integer label of every sample is recovered from the one-hot encoded
    targets with an argmax over axis 1, so the targets must be two-dimensional.

    Arguments
    ---------
    data_set : dict
            data set to be filtered. Must provide the keys 'inputs' and 'targets'.
    nr_type : integer
            For even numbers put 0, for odd numbers put 1

    Returns
    -------
    filtered_data_set : dict
            filtered data set with the keys 'inputs' and 'targets'. The leading
            sample axis is always kept, also when only one sample (or none)
            matches. For any other value of nr_type a hint is printed and
            None is returned.
    """
    # Reject unsupported selectors before touching the data.
    if nr_type not in (0, 1):
        print("Please enter correct arguments. See documentation for help.")
        return None

    # decompose data set into inputs and targets:
    inputs = data_set['inputs']
    targets = data_set['targets']

    # Parity of every label: 0 for even digits, 1 for odd digits.
    parity = np.argmax(targets, axis=1) % 2

    # flatnonzero always yields a 1-D index array. The former
    # argwhere(...).squeeze() collapsed to a 0-d index when exactly one sample
    # matched, which silently dropped the sample axis of the result.
    selected = np.flatnonzero(parity == nr_type)

    # Construct and return new data set:
    return {"inputs": inputs[selected], "targets": targets[selected]}

Selecting Subsets of Data to Shorten Computation Time

In [21]:
def select_subset_of_data(dataset: dict, seed: int, num_samples: int):
    """
    Draws a random selection of samples from every entry of a data set.

    The same row indices are applied to all entries of the dictionary, so
    inputs and targets stay aligned. Indices are drawn independently, hence
    a sample may appear more than once in the result.

    Arguments
    ---------
    dataset : dict
            data set; must contain the key 'targets', whose first axis
            defines the number of available samples.
    seed : integer
            used to initialize a random number generator
    num_samples : integer
            Number of samples to be taken from the data set

    Returns
    -------
    subset : dict
            Dictionary with the same keys as dataset, each holding the
            selected rows.
    """
    # Seeded key, so that the same seed always yields the same selection:
    key = random.PRNGKey(seed)
    n_available = dataset['targets'].shape[0]

    # Draw the row indices in [0, n_available) and convert them to NumPy:
    drawn = random.randint(key, shape=(num_samples,), minval=0, maxval=n_available)
    row_ids = np.array(drawn)

    # Apply the same rows to every entry of the data set:
    picked = {}
    for name, values in dataset.items():
        picked[name] = jnp.take(values, row_ids, axis=0)
    return picked

The MNIST Data Set

In [22]:
# Module-level MNIST data generator from ZnNL. Its `train_ds` and `test_ds`
# attributes are read inside train_model_MNIST.
# NOTE(review): presumably this loads MNIST through tensorflow_datasets when
# instantiated — confirm against the ZnNL documentation.
data_generator = nl.data.MNISTGenerator()

The Function for the Study

In [25]:
def _make_loss_accuracy_recorder(name):
    """Create a JaxRecorder called `name` that measures cross-entropy loss and label accuracy."""
    return nl.training_recording.JaxRecorder(
        name=name,
        storage_path=".",
        measurements=[
            Loss(apply_fn=nl.loss_functions.CrossEntropyLoss()),
            Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
        ],
        chunk_size=1e5,
        update_rate=1,
    )


def train_model_MNIST(NN_model, training_type, epochs=2000, batch_size=64, fresh_cv_report=None):
    """
    Trains a model on the even MNIST digits and plots NTK observables.

    The model is trained on the even digits of the module-level
    `data_generator`. During training, loss and accuracy are recorded on the
    even and odd digits of both the train and the test split, and the NTK
    trace and NTK entropy are recorded on 100 samples drawn (seed 0) from the
    odd test digits. Afterwards a 2x2 figure is shown: trace and entropy
    against the recording index (top row) and against the loss on the odd
    test digits (bottom row).

    All recorders are created with storage_path=".".
    NOTE(review): presumably the recorder data is therefore written to the
    current working directory — confirm against the papyrus documentation.

    Arguments
    ---------
    NN_model : ZnNL model
            Neural network model to be trained. Must provide `ntk_apply_fn`.
    training_type :
            Currently unused; kept so that existing calls keep working.
    epochs : integer
            Number of training epochs.
    batch_size : integer
            Batch size used for training.
    fresh_cv_report : mapping, optional
            Report with the keys "ntk_trace" and "ntk_entropy" of an untrained
            reference model. When given, it is added to every panel with the
            label "random"; when None, only the trained model is plotted.

    Returns
    -------
    None
            The function only trains, records and plots.
            NOTE(review): the model is presumably updated in place by the
            training strategy — confirm.
    """
    # Filter every split once and reuse the result for recorders and training.
    train_even = filter_numbers(data_generator.train_ds, 0)
    train_odd = filter_numbers(data_generator.train_ds, 1)
    test_even = filter_numbers(data_generator.test_ds, 0)
    test_odd = filter_numbers(data_generator.test_ds, 1)

    # The Training Recorders (train split):
    train_recorder_even = _make_loss_accuracy_recorder("train_recorder_even")
    train_recorder_odd = _make_loss_accuracy_recorder("train_recorder_odd")
    train_recorder_even.instantiate_recorder(data_set=train_even, model=NN_model)
    train_recorder_odd.instantiate_recorder(data_set=train_odd, model=NN_model)

    # The Testing Recorders (test split):
    test_recorder_even = _make_loss_accuracy_recorder("test_recorder_even")
    test_recorder_odd = _make_loss_accuracy_recorder("test_recorder_odd")
    test_recorder_even.instantiate_recorder(data_set=test_even, model=NN_model)
    test_recorder_odd.instantiate_recorder(data_set=test_odd, model=NN_model)

    # The Collective Variable Recorder:
    cv_recorder = nl.training_recording.JaxRecorder(
        name="cv_recorder",
        storage_path=".",
        measurements=[
            NTKEntropy(name="ntk_entropy"),
            NTKTrace(name="ntk_trace"),
        ],
        chunk_size=1e5,
        update_rate=1,
    )

    # The NTK has to be computed with the model that is actually trained.
    ntk_computation = nl.analysis.JAXNTKComputation(
        apply_fn=NN_model.ntk_apply_fn,
        batch_size=10,
    )

    # The NTK is evaluated on a small subset to keep the computation affordable.
    cv_recorder.instantiate_recorder(
        data_set=select_subset_of_data(dataset=test_odd, seed=0, num_samples=100),
        model=NN_model,
        ntk_computation=ntk_computation,
    )

    # The training Strategy:
    production_pretraining = nl.training_strategies.SimpleTraining(
        model=NN_model,
        loss_fn=nl.loss_functions.CrossEntropyLoss(),
        accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
        recorders=[
            train_recorder_even,
            train_recorder_odd,
            test_recorder_even,
            test_recorder_odd,
            cv_recorder,
        ],
    )

    # Here the training of the model takes place (even digits only):
    production_pretraining.train_model(
        train_ds=train_even,
        test_ds=test_even,
        batch_size=batch_size,
        epochs=epochs,
    )

    # Gather Reports:
    # Only the two reports below are plotted; the remaining recorders are
    # still gathered, as before.
    train_recorder_even.gather()
    train_recorder_odd.gather()
    test_recorder_even.gather()
    test_report_odd = test_recorder_odd.gather()
    cv_report = cv_recorder.gather()

    # Plot Reports:
    fig, axs = plt.subplots(2, 2, figsize=(8, 6), tight_layout=True)
    odd_test_loss = np.array(test_report_odd["loss"])

    # The untrained reference is optional; without it only the trained model is shown.
    cv_curves = [(cv_report, "pre-trained")]
    if fresh_cv_report is not None:
        cv_curves.append((fresh_cv_report, "random"))

    for report, label in cv_curves:
        axs[0, 0].plot(report["ntk_trace"], 'o', mfc="None", label=label)
        axs[0, 1].plot(report["ntk_entropy"], 'o', mfc="None", label=label)
        axs[1, 0].plot(odd_test_loss, report["ntk_trace"], 'o', mfc="None", label=label)
        axs[1, 1].plot(odd_test_loss, report["ntk_entropy"], 'o', mfc="None", label=label)

    # Loss decreases during training, so the loss axis is inverted.
    axs[1, 0].xaxis.set_inverted(True)
    axs[1, 1].xaxis.set_inverted(True)

    axs[0, 0].set_yscale('log')
    axs[1, 0].set_yscale('log')
    axs[0, 0].set_xscale('log')
    axs[0, 1].set_xscale('log')
    axs[1, 0].set_xscale('log')
    axs[1, 1].set_xscale('log')

    axs[0, 0].set_xlabel("Epoch")
    axs[0, 1].set_xlabel("Epoch")

    axs[1, 0].set_xlabel("Test Loss")
    axs[1, 1].set_xlabel("Test Loss")

    axs[0, 0].set_ylabel("Trace")
    axs[0, 1].set_ylabel("Entropy")
    axs[1, 0].set_ylabel("Trace")
    axs[1, 1].set_ylabel("Entropy")

    axs[0, 0].legend()
    plt.show()