In [None]:
###############
## Libraries ##
###############

import os

import tensorflow as tf
import matplotlib.pyplot as plt 
import numpy as np
from tensorflow.keras import datasets, layers, models, losses
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
import keras

# Import Data

In [None]:
##########################
## Load in MNIST Digits ##
##########################

# Open the MNIST archive and list the names of the arrays stored in it.
all_data = np.load("/scratch/gpfs/eysu/src_data/mnist.npz")
print(all_data.files)

# Pull the image and label arrays out of the archive, one pair per split.
x_train, y_train = all_data['x_train'], all_data['y_train']
x_test, y_test = all_data['x_test'], all_data['y_test']

# Report the shapes, one per line: test images, train images, train labels, test labels.
print(x_test.shape, x_train.shape, y_train.shape, y_test.shape, sep="\n")

In [None]:
###############################
## Partition and resize data ##
###############################

# Display name for each class; position i holds the name of digit i.
labels = [str(digit) for digit in range(10)]

# Always start from the raw arrays in `all_data` rather than from the globals this
# cell overwrites, so that re-running the cell gives the same result every time
# (otherwise a second run would rescale, re-slice and re-encode already processed data).
# Scale pixel values by 1/255 and store them as float32.
x_train = all_data['x_train'].astype('float32') / 255
x_test = all_data['x_test'].astype('float32') / 255
y_train = all_data['y_train']
y_test = all_data['y_test']

# Keep the labels exactly as stored in the archive (full training set, before the split below)
y_train_labels = y_train
y_test_labels = y_test

# Further break training data into train / validation sets
# (first 5000 samples go to validation, the rest stay in train)
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]

# Reshape input data from (28, 28) to (28, 28, 1)
w, h = 28, 28
x_train = x_train.reshape(x_train.shape[0], w, h, 1)
x_valid = x_valid.reshape(x_valid.shape[0], w, h, 1)
x_test = x_test.reshape(x_test.shape[0], w, h, 1)

# One-hot encode the validation and test labels.
# y_train is deliberately left as stored: the "Examine any image" cell uses it to index `labels`.
y_valid = tf.keras.utils.to_categorical(y_valid, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

In [None]:
##############################################
## Load, partition, and resize MNIST Digits ##
##############################################

## SAME THING JUST AS A FUNCTION
def loadData(path="/scratch/gpfs/eysu/src_data/mnist.npz", n_valid=5000):
    """Load MNIST from an ``.npz`` archive and split it into train / validation / test sets.

    Parameters
    ----------
    path : str, optional
        Location of the archive. It must contain the arrays ``x_train``,
        ``y_train``, ``x_test`` and ``y_test``. Defaults to the cluster path
        used throughout this notebook.
    n_valid : int, optional
        Number of leading training samples moved into the validation set
        (default 5000); the remaining samples stay in the training set.

    Returns
    -------
    tuple
        ``(x_train, x_valid, x_test, y_train, y_valid, y_test)``.
        Images are float32, divided by 255 and reshaped to ``(N, 28, 28, 1)``.
        Labels are returned as stored in the archive (no one-hot encoding is
        applied here); ``y_valid`` is additionally copied and squeezed.
    """
    # The object returned by np.load for .npz files holds the archive open;
    # the context manager closes it as soon as the four arrays have been read.
    with np.load(path) as all_data:
        x_train = all_data['x_train']
        y_train = all_data['y_train']
        x_test = all_data['x_test']
        y_test = all_data['y_test']

    # Scale pixel values by 1/255 and store them as float32
    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255

    # Further break training data into train / validation sets:
    # the first n_valid samples become the validation set, the rest stay in train
    (x_train, x_valid) = x_train[n_valid:], x_train[:n_valid]
    (y_train, y_valid) = y_train[n_valid:], np.array(y_train[:n_valid]).squeeze()

    # Reshape input data from (28, 28) to (28, 28, 1)
    w, h = 28, 28
    x_train = x_train.reshape(x_train.shape[0], w, h, 1)
    x_valid = x_valid.reshape(x_valid.shape[0], w, h, 1)
    x_test = x_test.reshape(x_test.shape[0], w, h, 1)

    return x_train, x_valid, x_test, y_train, y_valid, y_test

In [None]:
# Examine any image

# Image index: any integer from 0 to len(x_train) - 1
# (54,999 once the 5,000 validation samples have been split off)
img_index = 0
label_index = y_train[img_index]
# Print the label, for example "y = 2 (2)"
print("y = " + str(label_index) + " (" +(labels[label_index]) + ")")
# x_train images carry a trailing channel axis of size 1; drop it so imshow
# receives a plain 2-D (28, 28) array, which is one of its documented input shapes.
plt.imshow(np.squeeze(x_train[img_index]))
plt.show()

# Iterated Retraining By Sampling

In [None]:
##############################
## Sampling Helper Function ##
##############################

def sample(distributions, rng=None):
    """Draw one class index per row from a batch of categorical distributions.

    Parameters
    ----------
    distributions : array-like of shape (N, K)
        Row ``i`` holds the K class probabilities for item ``i``.
        K is read from the input instead of being fixed at 10.
    rng : optional
        Any object exposing ``choice(a, p=...)`` (e.g. ``np.random.default_rng(seed)``
        or ``np.random.RandomState(seed)``). When omitted, the global
        ``np.random`` state is used.

    Returns
    -------
    list
        ``N`` sampled class indices, one per input row (empty list when N == 0).
    """
    distributions = np.asarray(distributions)
    n_classes = distributions.shape[1]
    chooser = np.random if rng is None else rng
    # One independent draw per row, in row order.
    return [chooser.choice(n_classes, p=row) for row in distributions]

In [None]:
##############################################################
## This cell runs the iterated learning training procedure. ##
##############################################################
def _build_pretrained_cnn(weights_file):
    """Build the small CNN used at every iteration, load `weights_file` into it and compile it.

    Architecture: two Conv2D(kernel_size=2, padding='same', relu) -> MaxPooling2D(2)
    -> Dropout(0.3) stages (64 then 32 filters), followed by
    Flatten -> Dense(256, relu) -> Dropout(0.5) -> Dense(10, softmax).
    The model is compiled with categorical cross-entropy, Adam and an accuracy metric.
    """
    model = tf.keras.Sequential()

    # Must define the input shape in the first layer of the neural network
    model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
    model.add(tf.keras.layers.Dropout(0.3))

    model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
    model.add(tf.keras.layers.Dropout(0.3))

    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(256, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))

    # Every iteration starts from the same pretrained (low-shot) weights
    model.load_weights(weights_file)

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


for END_IDX in range(5, 6):
    # Number of iterations in the serial reproduction
    MAX_ITER = 1000
    # Number of epochs per training run
    EPOCHS = 10

    save_path = "/scratch/gpfs/eysu/Sampling/pretrained2_" + str(MAX_ITER) + "/"
    weight_path = "/scratch/gpfs/eysu/low_shot_weights/" + str(END_IDX) + "/"
    # Checkpoints and .npy results are written here; make sure the directory exists
    os.makedirs(save_path, exist_ok=True)

    x_train, x_valid, x_test, y_train, y_valid, y_test = loadData()

    # Column t holds the training labels in effect at iteration t
    # (column 0 = original labels, column t + 1 = labels sampled after iteration t)
    all_labels = np.zeros((x_train.shape[0], MAX_ITER + 1))

    for iteration in range(MAX_ITER):
        # If iteration is seed, train on original target vectors, else, train on labels sampled at time t-1
        if iteration == 0:
            # Save the label and then one-hot encode the labels
            all_labels[:, 0] = y_train
            y_train = tf.keras.utils.to_categorical(y_train, 10)
            y_valid = tf.keras.utils.to_categorical(y_valid, 10)
            y_test = tf.keras.utils.to_categorical(y_test, 10)

            mpth = 'model.weights.best.hdf5'
            y_hat_test_name = 'y_hat_test_seed'
            y_hat_train_name = 'y_hat_train_seed'
        else:
            # Key step: set new targets to the one-hot labels sampled in the previous iteration
            y_train = new_train
            mpth = 'model.weights.best.' + 'iter' + str(iteration) + '.hdf5'
            y_hat_test_name = 'y_hat_test_' + 'iter' + str(iteration)
            y_hat_train_name = 'y_hat_train_' + 'iter' + str(iteration)

        # A new model is created on every pass; reset Keras' global state first so
        # memory does not keep growing over the MAX_ITER iterations.
        tf.keras.backend.clear_session()
        model = _build_pretrained_cnn(weight_path + 'model.weights.best.pretrain.hdf5')
        # model.summary()

        # Save checkpoints: with save_best_only=True only the best epoch is kept.
        # NOTE(review): no `monitor` is passed, so Keras' documented default (val_loss)
        # decides which epoch is "best" -- not validation accuracy.
        checkpointer = ModelCheckpoint(filepath= save_path + mpth, verbose = 1, save_best_only=True) #True
        # Train the model
        model.fit(x_train,
                 y_train,
                 batch_size=64,
                 epochs=EPOCHS,
                 validation_data=(x_valid, y_valid),
                 callbacks=[checkpointer])

        # Predictions come from the weights held after the final epoch; the best
        # checkpoint is only loaded further down, before the test evaluation.
        # NOTE(review): confirm that sampling from the last-epoch model is intended.
        y_hat = model.predict(x_train) #feed back serial reproduction targets
        y_hat_test = model.predict(x_test)

        #### START OF SAMPLING ####

        # use helper function to sample label for every image in train
        new_labels = np.array(sample(y_hat))

        # store new labels for all images under its corresponding iteration
        all_labels[:, iteration + 1] = new_labels

        # one-hot encode the sampled labels and set this as the next training target
        new_train = tf.keras.utils.to_categorical(new_labels, 10)

        #### END OF SAMPLING ####

        # Load the best checkpoint written during this iteration's fit
        model.load_weights(save_path + mpth)
        # Evaluate the model on test set
        score = model.evaluate(x_test, y_test, verbose=0)
        # Print test accuracy
        print('\n', 'Test accuracy:', score[1])

        # Save results for each iteration in the serial reproduction chain.
        # Note: the array stored under y_hat_train_name is y_train, i.e. the targets
        # this iteration was trained on (not this iteration's predictions y_hat).
        np.save(save_path + y_hat_train_name + '.npy', y_train)
        print(save_path + y_hat_train_name)

        np.save(save_path + y_hat_test_name + '.npy', y_hat_test)
        print(save_path + y_hat_test_name)

    np.save(save_path + 'labels.npy', all_labels)
    print('Saved labels!')


# Scratch work