<a href="https://colab.research.google.com/github/kundajelab/label_shift_experiments/blob/master/cifar10/Download_CIFAR10_models_from_zenodo_and_make_predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import print_function
import keras
from keras.models import load_model
from keras.models import Sequential, Model
print("keras version:", keras.__version__)
import tensorflow as tf
print("tensorflow version:", tf.__version__)
import random
import os
import sys
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D, Activation
from keras import backend as K
from keras.callbacks import EarlyStopping

Using TensorFlow backend.


keras version: 2.2.4
tensorflow version: 1.14.0


In [2]:
batch_size = 128
num_classes = 10
epochs = 50

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

full_x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
full_x_train /= 255
x_test /= 255
x_valid = full_x_train[-10000:]
print('x_train shape:', full_x_train.shape)
print(full_x_train.shape[0], 'train samples')
print(x_valid.shape[0], 'valid samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
full_y_train = keras.utils.to_categorical(y_train, num_classes)
y_valid = full_y_train[-10000:]
y_test = keras.utils.to_categorical(y_test, num_classes)

output_file = "test_labels.txt"
f = open(output_file, 'w')
f.write("\n".join(["\t".join([str(x) for x in y]) for y in y_test]))
f.close()
os.system("gzip -f "+output_file)

output_file = "valid_labels.txt"
f = open(output_file, 'w')
f.write("\n".join(["\t".join([str(x) for x in y]) for y in y_valid]))
f.close()
os.system("gzip -f "+output_file)

output_file = "train_labels.txt"
f = open(output_file, 'w')
f.write("\n".join(["\t".join([str(x) for x in y]) for y in full_y_train]))
f.close()
os.system("gzip -f "+output_file)

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 valid samples
10000 test samples


0

In [3]:
for model_idx,train_set_size in enumerate([250, 500, 1000, 2000, 4000,
                                           8000, 16000]):
    np.random.seed(model_idx*100)
    random.seed(model_idx*100)
    print("On train set size",train_set_size)

    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation("softmax"))

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])
    x_train = full_x_train[:train_set_size] 
    y_train = full_y_train[:train_set_size]
    print("Mean y train:",np.mean(y_train, axis=0))
    print("Mean y valid:",np.mean(y_valid, axis=0))
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_valid, y_valid),
              callbacks=[EarlyStopping(
                monitor='val_loss', patience=5,
                restore_best_weights=True)])
    model.save("model_trainsize-"+str(train_set_size)+".h5")
    
    model_files = [ "model_mnist_set-"+str(train_set_size)+"_seed-"+str(seed)+".h5" for seed in range(0,100,10) ]
    for model_file in model_files:
        pre_softmax_model = Model(input=model.input,
                            output=model.layers[-2].output)
        print("Making predictions on validation set")
        valid_preacts = pre_softmax_model.predict(x_valid)
        print("Making predictions on test set")
        test_preacts = pre_softmax_model.predict(x_test)
        print('Test accuracy:', np.mean(np.argmax(test_preacts,axis=-1)
                                        ==np.argmax(y_test,axis=-1)))
        print('Valid accuracy:', np.mean(np.argmax(valid_preacts,axis=-1)
                                        ==np.argmax(y_valid,axis=-1)))
        sys.stdout.flush()
        test_predictions_file = ("testpreacts_"+model_file.split(".")[0])+".txt"
        print("Saving", test_predictions_file)
        f = open(test_predictions_file,'w')
        for test_preact in test_preacts:
            f.write("\t".join([str(x) for x in test_preact])+"\n") 
        f.close()
        !md5sum $test_predictions_file
        !gzip $test_predictions_file
        
        valid_predictions_file = ("validpreacts_"+model_file.split(".")[0])+".txt"
        print("Saving", valid_predictions_file)
        f = open(valid_predictions_file,'w')
        for valid_preact in valid_preacts:
            f.write("\t".join([str(x) for x in valid_preact])+"\n") 
        f.close()
        !md5sum $valid_predictions_file
        !gzip $valid_predictions_file

On train set size 250





Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


Mean y train: [0.12  0.136 0.092 0.108 0.096 0.068 0.096 0.104 0.076 0.104]
Mean y valid: [0.0991 0.1064 0.099  0.103  0.0983 0.0915 0.0967 0.109  0.1009 0.0961]
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Train on 250 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50




Making predictions on validation set
Making predictions on test set
Test accuracy: 0.8225
Valid accuracy: 0.833
Saving testpreacts_model_mnist_set-250_seed-0.txt
ce628bdc0b41bb65f596ff615507d791  testpreacts_model_mnist_set-250_seed-0.txt
Saving validpreacts_model_mnist_set-250_seed-0.txt
cff467e2eb54c4ff177a5fd8a12302c5  validpreacts_model_mnist_set-250_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.8225
Valid accuracy: 0.833
Saving testpreacts_model_mnist_set-250_seed-10.txt
ce628bdc0b41bb65f596ff615507d791  testpreacts_model_mnist_set-250_seed-10.txt
Saving validpreacts_model_mnist_set-250_seed-10.txt
cff467e2eb54c4ff177a5fd8a12302c5  validpreacts_model_mnist_set-250_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.8225
Valid accuracy: 0.833
Saving testpreacts_model_mnist_set-250_seed-20.txt
ce628bdc0b41bb65f596ff615507d791  testpreacts_model_mnist_set-250_seed-20.txt
Saving validpreacts



Making predictions on validation set
Making predictions on test set
Test accuracy: 0.8889
Valid accuracy: 0.8934
Saving testpreacts_model_mnist_set-500_seed-0.txt
e616f18a2fb58e78edd6c26ab881ddc0  testpreacts_model_mnist_set-500_seed-0.txt
Saving validpreacts_model_mnist_set-500_seed-0.txt
aed6687664c2895b9601d1f8f52a9a9a  validpreacts_model_mnist_set-500_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.8889
Valid accuracy: 0.8934
Saving testpreacts_model_mnist_set-500_seed-10.txt
e616f18a2fb58e78edd6c26ab881ddc0  testpreacts_model_mnist_set-500_seed-10.txt
Saving validpreacts_model_mnist_set-500_seed-10.txt
aed6687664c2895b9601d1f8f52a9a9a  validpreacts_model_mnist_set-500_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.8889
Valid accuracy: 0.8934
Saving testpreacts_model_mnist_set-500_seed-20.txt
e616f18a2fb58e78edd6c26ab881ddc0  testpreacts_model_mnist_set-500_seed-20.txt
Saving validprea



Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9405
Valid accuracy: 0.9447
Saving testpreacts_model_mnist_set-1000_seed-0.txt
0389f95c8e61ffa06d99b42e0ffd1961  testpreacts_model_mnist_set-1000_seed-0.txt
Saving validpreacts_model_mnist_set-1000_seed-0.txt
ea20023fbbd59ee83c008a751d2e1f51  validpreacts_model_mnist_set-1000_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9405
Valid accuracy: 0.9447
Saving testpreacts_model_mnist_set-1000_seed-10.txt
0389f95c8e61ffa06d99b42e0ffd1961  testpreacts_model_mnist_set-1000_seed-10.txt
Saving validpreacts_model_mnist_set-1000_seed-10.txt
ea20023fbbd59ee83c008a751d2e1f51  validpreacts_model_mnist_set-1000_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9405
Valid accuracy: 0.9447
Saving testpreacts_model_mnist_set-1000_seed-20.txt
0389f95c8e61ffa06d99b42e0ffd1961  testpreacts_model_mnist_set-1000_seed-20.txt
Saving



Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9573
Valid accuracy: 0.9592
Saving testpreacts_model_mnist_set-2000_seed-0.txt
8bd385cf0205edaa12708fbc3f345ccf  testpreacts_model_mnist_set-2000_seed-0.txt
Saving validpreacts_model_mnist_set-2000_seed-0.txt
2a5ebfa81a8e9142ea71035f853f764d  validpreacts_model_mnist_set-2000_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9573
Valid accuracy: 0.9592
Saving testpreacts_model_mnist_set-2000_seed-10.txt
8bd385cf0205edaa12708fbc3f345ccf  testpreacts_model_mnist_set-2000_seed-10.txt
Saving validpreacts_model_mnist_set-2000_seed-10.txt
2a5ebfa81a8e9142ea71035f853f764d  validpreacts_model_mnist_set-2000_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9573
Valid accuracy: 0.9592
Saving testpreacts_model_mnist_set-2000_seed-20.txt
8bd385cf0205edaa12708fbc3f345ccf  testpreacts_model_mnist_set-2000_seed-20.txt
Saving



Making predictions on validation set
Making predictions on test set
Test accuracy: 0.969
Valid accuracy: 0.971
Saving testpreacts_model_mnist_set-4000_seed-0.txt
392689ed5924e7aa62afebea683de69d  testpreacts_model_mnist_set-4000_seed-0.txt
Saving validpreacts_model_mnist_set-4000_seed-0.txt
488d12e69ab40fc4fa794214efff0645  validpreacts_model_mnist_set-4000_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.969
Valid accuracy: 0.971
Saving testpreacts_model_mnist_set-4000_seed-10.txt
392689ed5924e7aa62afebea683de69d  testpreacts_model_mnist_set-4000_seed-10.txt
Saving validpreacts_model_mnist_set-4000_seed-10.txt
488d12e69ab40fc4fa794214efff0645  validpreacts_model_mnist_set-4000_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.969
Valid accuracy: 0.971
Saving testpreacts_model_mnist_set-4000_seed-20.txt
392689ed5924e7aa62afebea683de69d  testpreacts_model_mnist_set-4000_seed-20.txt
Saving valid



Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9774
Valid accuracy: 0.9799
Saving testpreacts_model_mnist_set-8000_seed-0.txt
6ce7f26314d6563adcfb37dd1a94b5ad  testpreacts_model_mnist_set-8000_seed-0.txt
Saving validpreacts_model_mnist_set-8000_seed-0.txt
bea6213a04a983c8dfd633e9666414b9  validpreacts_model_mnist_set-8000_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9774
Valid accuracy: 0.9799
Saving testpreacts_model_mnist_set-8000_seed-10.txt
6ce7f26314d6563adcfb37dd1a94b5ad  testpreacts_model_mnist_set-8000_seed-10.txt
Saving validpreacts_model_mnist_set-8000_seed-10.txt
bea6213a04a983c8dfd633e9666414b9  validpreacts_model_mnist_set-8000_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9774
Valid accuracy: 0.9799
Saving testpreacts_model_mnist_set-8000_seed-20.txt
6ce7f26314d6563adcfb37dd1a94b5ad  testpreacts_model_mnist_set-8000_seed-20.txt
Saving



Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9863
Valid accuracy: 0.9866
Saving testpreacts_model_mnist_set-16000_seed-0.txt
d404a5e69e63b69e90a5d62d8fc29b87  testpreacts_model_mnist_set-16000_seed-0.txt
Saving validpreacts_model_mnist_set-16000_seed-0.txt
f150050585ec3ea569056e36a56bbb8c  validpreacts_model_mnist_set-16000_seed-0.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9863
Valid accuracy: 0.9866
Saving testpreacts_model_mnist_set-16000_seed-10.txt
d404a5e69e63b69e90a5d62d8fc29b87  testpreacts_model_mnist_set-16000_seed-10.txt
Saving validpreacts_model_mnist_set-16000_seed-10.txt
f150050585ec3ea569056e36a56bbb8c  validpreacts_model_mnist_set-16000_seed-10.txt
Making predictions on validation set
Making predictions on test set
Test accuracy: 0.9863
Valid accuracy: 0.9866
Saving testpreacts_model_mnist_set-16000_seed-20.txt
d404a5e69e63b69e90a5d62d8fc29b87  testpreacts_model_mnist_set-16000_seed-20.