In [24]:
import numpy as np
import keras

def calc_feat_mean_stds(name, x):
    """Summarise each feature (column) of ``x`` by its mean and standard deviation.

    Args:
        name: Label copied into every record so that record sets produced
            for different inputs can still be told apart after merging.
        x: Array-like reduced along axis 0. The per-feature statistics are
            iterated, so ``x`` needs at least two dimensions
            (presumably ``(samples, features)`` -- confirm against callers).

    Returns:
        list[dict]: one ``{"feat_id", "mean", "std", "name"}`` record per
        entry along the first remaining axis, in feature order.
    """
    means = np.mean(x, axis=0)
    stds = np.std(x, axis=0)
    # `name` used to be accepted but dropped; the other calc_* helpers in
    # this file tag their records, so do the same here.
    return [
        {"feat_id": feat_id, "mean": mean, "std": std, "name": name}
        for feat_id, (mean, std) in enumerate(zip(means, stds))
    ]

def calc_data_dist(name, x):
    """Build a normalised histogram of every value in ``x``.

    ``np.histogram`` is called with its default binning on the flattened
    input; each count is divided by the total count.

    Args:
        name: Label stored in every record.
        x: Array-like of numeric values.

    Returns:
        list[dict]: one ``{"bin", "value", "name"}`` record per bin, where
        ``bin`` is the bin's left edge and ``value`` its share of the total.
    """
    counts, edges = np.histogram(x)
    total = sum(counts)
    # edges has one more entry than counts; dropping the last edge pairs
    # each count with the left edge of its bin.
    return [
        {"bin": left_edge, "value": count / total, "name": name}
        for left_edge, count in zip(edges[:-1], counts)
    ]

def calc_all_data_dist_2d(name, x, grid_size=8):
    """Histogram the values found at a coarse grid of positions in ``x``.

    Positions are sampled every ``step`` entries along axes 1 and 2, where
    ``step = x.shape[1] // grid_size`` (at least 1). The same stride is used
    for both axes. When the axis length is not a multiple of ``grid_size``
    the grid can be slightly larger than ``grid_size`` per side (a length of
    28 gives step 3 and therefore 9 positions).

    Args:
        name: Label stored in every record.
        x: Array indexed as ``x[:, row, col]``; needs at least three axes.
        grid_size: Target number of sampled positions per axis. Defaults to
            8, the value that was previously hard-coded.

    Returns:
        list[dict]: for every sampled position, one
        ``{"x", "y", "bin", "value", "name"}`` record per histogram bin.
        ``x``/``y`` are grid indices (not pixel offsets), ``bin`` is the
        bin's left edge and ``value`` the raw count.

    Raises:
        ValueError: if ``grid_size`` is not positive.
    """
    if grid_size <= 0:
        raise ValueError("grid_size must be a positive integer")

    # Inputs narrower than grid_size used to yield step == 0 and crash with
    # ZeroDivisionError below; fall back to sampling every position.
    step = max(1, x.shape[1] // grid_size)

    records = []
    for i in range(x.shape[1] // step):
        for j in range(x.shape[2] // step):
            counts, bin_edges = np.histogram(x[:, i * step, j * step])
            records.extend(
                {"x": i, "y": j, "bin": left_edge, "value": count, "name": name}
                for left_edge, count in zip(bin_edges[:-1], counts)
            )
    return records
  
def log_all_data_dist_2d(key, x):
    """Store the grid-sampled histograms of ``x`` in the W&B run summary.

    ``key`` is used both as the summary key and as the label passed to
    ``calc_all_data_dist_2d``.
    """
    wandb.summary[key] = calc_all_data_dist_2d(key, x)

def log_data_dist(key, x):
    """Store the normalised histogram of ``x`` in the W&B run summary.

    ``key`` is used both as the summary key and as the label passed to
    ``calc_data_dist``.
    """
    # Note: wandb.Histogram(x) was tried here earlier and left disabled.
    wandb.summary[key] = calc_data_dist(key, x)
    
def log_data_pred(key, x):
    """Sum ``x`` over axis 0 and store the totals in the W&B run summary.

    The summary entry is a list of ``{"bin": index, "value": total}``
    records, one per position along the first axis of the summed array.
    """
    # TODO: the multidimensional case still needs handling (only the first
    # axis of the summed array is enumerated).
    totals = np.sum(x, axis=0)
    wandb.summary[key] = [
        {"bin": idx, "value": totals[idx]} for idx in range(totals.shape[0])
    ]
    
def log_data_dists(key, x1, x2):
    """Store the normalised histograms of two inputs under one summary key.

    Args:
        key: W&B summary key to write.
        x1: First input; its records are labelled ``"x1"``.
        x2: Second input; its records are labelled ``"x2"``.

    The summary entry is a single list holding the ``"x1"`` records followed
    by the ``"x2"`` records, as produced by ``calc_data_dist``.
    """
    arr1 = calc_data_dist("x1", x1)
    # Previously x1 was passed here as well, so the "x2" series was just a
    # copy of the first distribution.
    arr2 = calc_data_dist("x2", x2)

    # Keep a plain list of dicts, like the other log_* helpers, instead of
    # wrapping the records in a numpy object array.
    wandb.summary[key] = list(arr1) + list(arr2)
    
def log_feat_mean_stds(key, x):
    """Store per-feature mean/std records for ``x`` in the W&B run summary.

    ``key`` is used both as the summary key and as the label passed to
    ``calc_feat_mean_stds``.
    """
    # The result used to be bound to `arr1` while the undefined name `arr`
    # was stored, which raises NameError on a fresh kernel.
    wandb.summary[key] = calc_feat_mean_stds(key, x)


def log_feats_mean_stds(key, x1, x2):
    """Store per-feature mean/std records of two inputs under one summary key.

    Args:
        key: W&B summary key to write.
        x1: First input; its records are labelled ``"x1"``.
        x2: Second input; its records are labelled ``"x2"``.

    The summary entry is a single list holding the ``"x1"`` records followed
    by the ``"x2"`` records. Every record carries a ``"name"`` field naming
    the input it came from.
    """
    combined = []
    # The second series used to be computed from x1 again; iterate over the
    # real (label, data) pairs instead.
    for label, data in (("x1", x1), ("x2", x2)):
        for record in calc_feat_mean_stds(label, data):
            # Both series share the same feat_id values, so tag each record
            # with its source. A "name" already present in the record wins.
            combined.append({"name": label, **record})

    # Plain list of dicts, like the other log_* helpers, rather than a numpy
    # object array.
    wandb.summary[key] = combined

def log_2d_data_dist(name, X_train):
    """Store a histogram for every ``(row, col)`` position of ``X_train``.

    Args:
        name: W&B summary key to write.
        X_train: Array indexed as ``X_train[:, row, col]``; needs at least
            three axes.

    The summary entry is a flat list with one
    ``{"x", "y", "bin", "value"}`` record per histogram bin per position,
    where ``x``/``y`` are the row/column indices, ``bin`` is the bin's left
    edge and ``value`` the raw count.
    """
    # Take the grid size from the data instead of assuming 28x28, which
    # raised IndexError on smaller inputs and ignored the rest of larger ones.
    n_rows, n_cols = X_train.shape[1], X_train.shape[2]

    records = []
    for i in range(n_rows):
        for j in range(n_cols):
            counts, bin_edges = np.histogram(X_train[:, i, j])
            records.extend(
                {"x": i, "y": j, "bin": left_edge, "value": count}
                for left_edge, count in zip(bin_edges[:-1], counts)
            )

    wandb.summary[name] = records
    
def log_mnist_data(X_train, y_train, X_test, y_test):
    """Log distribution summaries for a train/test split to W&B.

    Writes, in order: overall value histograms for all four arrays,
    grid-sampled histograms for the two image arrays, and per-feature
    mean/std records for the flattened image arrays.

    The image arrays are flattened to ``(shape[0], shape[1] * shape[2])``,
    so they need at least three axes and any further axes must have size 1.
    """
    overall = (
        ("x_train", X_train),
        ("x_test", X_test),
        ("y_train", y_train),
        ("y_test", y_test),
    )
    for summary_key, data in overall:
        log_data_dist(summary_key, data)

    for summary_key, images in (("x_train_all", X_train), ("x_test_all", X_test)):
        log_all_data_dist_2d(summary_key, images)

    # Flatten both image sets before logging any feature statistics.
    n_train, train_h, train_w = X_train.shape[:3]
    n_test, test_h, test_w = X_test.shape[:3]
    train_flat = X_train.reshape((n_train, train_h * train_w))
    test_flat = X_test.reshape((n_test, test_h * test_w))

    log_feat_mean_stds("x_train_mean_std", train_flat)
    log_feat_mean_stds("x_test_mean_std", test_flat)
    
    
class LogPredictionCallback(keras.callbacks.Callback):
    """Keras callback that logs input, target and prediction summaries to W&B.

    At the end of every epoch the callback writes four W&B summary entries,
    each key prefixed with ``name_prefix + "_"``:

    * ``x``           -- value histogram of the sample inputs
    * ``y``           -- per-column totals of the sample targets
    * ``y_pred``      -- per-column totals of the model's predictions
    * ``y_pred_dist`` -- value histogram of the model's predictions

    Args:
        sample_data: Indexable pair ``(inputs, targets)``.
        name_prefix: String prepended (followed by ``"_"``) to every key.
    """

    def __init__(self, sample_data, name_prefix=""):
        # Let the Keras base class initialise its own state before adding ours.
        super(LogPredictionCallback, self).__init__()
        self.sample_data = sample_data
        self.name_prefix = name_prefix

    def on_epoch_end(self, epoch, logs=None):
        """Predict once on the sample inputs and log all four summaries."""
        inputs, targets = self.sample_data[0], self.sample_data[1]

        # Use the model this callback is attached to (not a notebook-level
        # global) and reuse one prediction pass; the old code ran predict
        # three times per epoch and discarded the first result.
        predictions = self.model.predict(inputs)

        prefix = self.name_prefix + "_"
        log_data_dist(prefix + "x", inputs)
        log_data_pred(prefix + "y", targets)
        log_data_pred(prefix + "y_pred", predictions)
        log_data_dist(prefix + "y_pred_dist", predictions)


        


In [32]:
# Original MNIST -busted!

from keras.utils import to_categorical
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten

import wandb
from wandb.keras import WandbCallback

# logging code
# Start the W&B run first: the wandb.summary writes below and inside the
# callbacks are made after this point.
run = wandb.init(project="debug-mnist", entity="l2k2")
config = run.config
wandb.summary["dataset"] = "orig-mnist"

# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
img_width = X_train.shape[1]
img_height = X_train.shape[2]

# scale pixel values by 1/255
X_train = X_train.astype('float32')
X_train /= 255.
X_test = X_test.astype('float32')
X_test /= 255.

# reshape data: append a trailing axis of size 1 for the Conv2D input
X_train = X_train.reshape(
    X_train.shape[0], X_train.shape[1], X_train.shape[2], 1)
X_test = X_test.reshape(
    X_test.shape[0], X_test.shape[1], X_test.shape[2], 1)

# one hot encode outputs
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
labels = range(10)

num_classes = y_train.shape[1]

# create model
model = Sequential()
model.add(Conv2D(32,
                 (3,3),
                 input_shape=(28, 28, 1),
                 activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(num_classes, activation="softmax"))
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])
# Train on the one-hot labels built above. The previous target name,
# `y_train_cat`, is never defined in this notebook and only resolved through
# stale kernel state.
model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test), 
          callbacks=[LogPredictionCallback((X_train, y_train), "train"), 
                     LogPredictionCallback((X_test, y_test), "test"),
                     WandbCallback()])

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
 6592/60000 [==>...........................] - ETA: 24s - loss: 0.0126 - acc: 0.9968

KeyboardInterrupt: 