In [None]:
import os
os.environ["THEANO_FLAGS"] = "device=gpu"
import numpy as np
from lasagnekit.easy import BatchOptimizer, LightweightModel
from lasagnekit.datasets.mnist import MNIST
from sklearn.utils import shuffle
from sklearn.cross_validation import train_test_split
from lasagnekit.easy import iterate_minibatches
from lasagne import layers, updates, init, nonlinearities
import theano.tensor as T
from theano.sandbox import rng_mrg
import theano
import matplotlib.pyplot as plt
import numpy as np
from lasagne.layers import get_all_layers

from lasagnekit.generative.capsule import Capsule
from lasagnekit.easy import BatchIterator, iterate_minibatches
import lasagne
from collections import OrderedDict

In [None]:
# Load MNIST through lasagnekit. The cells below read data.X, data.y,
# data.output_dim and data.img_dim from this object.
data = MNIST()
data.load()

In [None]:
# Random seed for the shuffle and the train/test split, so that the split
# (and every error rate later reported against it) is reproducible across runs.
SEED = 42

X, y = data.X, data.y
X, y = shuffle(X, y, random_state=SEED)
input_dim = X.shape[1]
output_dim = data.output_dim
# 75% / 25% split over example indices; `train` and `test` index into X and y.
train, test = train_test_split(np.arange(X.shape[0]), test_size=0.25, random_state=SEED)
# Image dimensions, used by the CNN teacher to reshape flat rows into images.
w, h = data.img_dim

In [None]:
class MyBatchOptimizer(BatchOptimizer):
    """BatchOptimizer that additionally reports the teacher's error on the
    train and test splits after every call to ``iter_update``.

    NOTE(review): reads the notebook globals ``X``, ``y``, ``train`` and
    ``test``; the data-preparation cell must have run before training.
    """

    # Mini-batch size used only for evaluating predictions, not for training.
    eval_batch_size = 128

    def iter_update(self, epoch, nb_batches, iter_update_batch):
        """Run the parent's update, then add teacher error statistics.

        For each split ``name`` in ("train", "test") the returned status gains:

        - ``"t_<name>_mean"``: fraction of misclassified examples, counted over
          the whole split (so a smaller final mini-batch is weighted correctly);
        - ``"t_<name>_std"``: standard deviation of the per-example 0/1 error
          indicator, i.e. ``sqrt(p * (1 - p))`` for error rate ``p``.

        Both are NaN when a split is empty.
        Returns the status dict produced by the parent class, augmented.
        """
        status = super(MyBatchOptimizer, self).iter_update(epoch, nb_batches, iter_update_batch)

        for indices, name in zip((train, test), ("train", "test")):
            Xs = X[indices]
            ys = y[indices]

            n_errors = 0
            n_seen = 0
            for ind in iterate_minibatches(len(indices), self.eval_batch_size):
                # `predict` is the argmax over the teacher's output (see `functions`);
                # presumably self.model is the Capsule exposing it — TODO confirm.
                is_error = (self.model.predict(Xs[ind]) != ys[ind])
                n_errors += is_error.sum()
                n_seen += is_error.size

            if n_seen:
                # float() keeps this a true division under Python 2 as well.
                error_rate = n_errors / float(n_seen)
                error_std = np.sqrt(error_rate * (1.0 - error_rate))
            else:
                error_rate = error_std = float("nan")
            status["t_" + name + "_mean"] = error_rate
            status["t_" + name + "_std"] = error_std
            #status["swt_" + name] = (self.model.student_predict_with_teacher(Xs) != ys).mean()
            #status["s_" + name] = (self.model.student_predict(Xs) != ys).mean()
        #status["W"] = np.abs(l_S_hint.W.get_value()).sum()
        return status
    
class Model:
    """Plain container for the teacher (``x_to_y``) and student (``S_x_to_y``)
    models; both attributes are assigned after construction."""

    def get_all_params(self, **tags):
        """Return the union of teacher and student parameters.

        ``**tags`` are forwarded unchanged to both sub-models.
        Parameters shared by both networks appear once. The order is
        deterministic: teacher parameters first, then student-only ones in the
        order the student reports them. (``list(set(...))`` gave an order that
        could change between runs.)
        """
        teacher_params = self.x_to_y.get_all_params(**tags)
        student_params = self.S_x_to_y.get_all_params(**tags)
        # OrderedDict.fromkeys de-duplicates while keeping first-occurrence order.
        return list(OrderedDict.fromkeys(teacher_params + student_params))

In [None]:
# Symbolic input: flattened images, one row per example.
l_in = layers.InputLayer((None, X.shape[1]))

# ---- Teacher ---------------------------------------------------------------
# Teacher architecture: "mlp" or "cnn". Kept in its own name so that `model`
# only ever refers to the Model container built at the end of this cell.
teacher_arch = "cnn"

if teacher_arch == "mlp":
    l_in_drop = lasagne.layers.DropoutLayer(l_in, p=0.2)
    l_hid1 = lasagne.layers.DenseLayer(
            l_in_drop, num_units=800,
            nonlinearity=lasagne.nonlinearities.rectify,
            W=lasagne.init.GlorotUniform())
    l_hid1_drop = lasagne.layers.DropoutLayer(l_hid1, p=0.5)
    l_hid2 = lasagne.layers.DenseLayer(
            l_hid1_drop, num_units=800,
            nonlinearity=lasagne.nonlinearities.rectify)
    l_hid2_drop = lasagne.layers.DropoutLayer(l_hid2, p=0.5)
    # Teacher layer(s) whose activations are given to the student as hints.
    l_course = [l_hid2]
    l_out = lasagne.layers.DenseLayer(
            l_hid2_drop, num_units=output_dim,
            nonlinearity=lasagne.nonlinearities.softmax)
elif teacher_arch == "cnn":
    # [0] keeps the batch dimension; each row becomes a 1-channel w x h image.
    # NOTE(review): assumes data.img_dim order matches how X was flattened;
    # it makes no difference for square images such as MNIST digits.
    network = lasagne.layers.ReshapeLayer(l_in, ([0], 1, w, h))
    network = lasagne.layers.Conv2DLayer(
            network, num_filters=64, filter_size=(5, 5),
            nonlinearity=lasagne.nonlinearities.rectify,
            W=lasagne.init.GlorotUniform())
    network = lasagne.layers.MaxPool2DLayer(network, pool_size=(2, 2))
    network = lasagne.layers.Conv2DLayer(
            network, num_filters=128, filter_size=(5, 5),
            nonlinearity=lasagne.nonlinearities.rectify)
    network = lasagne.layers.MaxPool2DLayer(network, pool_size=(2, 2))
    # Optional: pass lasagne.layers.dropout(network, p=.5) instead of `network`
    # to the two dense layers below.
    network = lasagne.layers.DenseLayer(
            network,
            num_units=500,
            nonlinearity=lasagne.nonlinearities.rectify)
    # Teacher layer(s) whose activations are given to the student as hints.
    l_course = [network]
    # Linear output layer followed by a separate softmax layer.
    l_pre_out = lasagne.layers.DenseLayer(
                    network,
                    num_units=output_dim,
                    nonlinearity=lasagne.nonlinearities.linear)
    l_out = lasagne.layers.NonlinearityLayer(l_pre_out, lasagne.nonlinearities.softmax)
else:
    raise ValueError("unknown teacher_arch %r (expected 'mlp' or 'cnn')" % (teacher_arch,))

# ---- Student ---------------------------------------------------------------
# The student combines its own 100-unit hidden layer over the raw input with a
# 100-unit "hint" projection of the teacher's l_course activations.
l_S_pre_hint = layers.ConcatLayer(l_course, axis=1)
l_S_hid = lasagne.layers.DenseLayer(l_in, 100, nonlinearity=lasagne.nonlinearities.rectify)
l_S_hint = lasagne.layers.DenseLayer(l_S_pre_hint, 100, nonlinearity=lasagne.nonlinearities.rectify)
# (Setting l_S_repr = l_S_hid instead would give a student without teacher hints.)
l_S_repr = lasagne.layers.ConcatLayer([l_S_hid, l_S_hint], axis=1)
l_S_out = lasagne.layers.DenseLayer(l_S_repr, num_units=output_dim, nonlinearity=lasagne.nonlinearities.softmax)
print(l_S_hid.output_shape)

# ---- Model container -------------------------------------------------------
x_to_y = LightweightModel([l_in], [l_out])
S_x_to_y = LightweightModel([l_in], [l_S_out])
model = Model()
model.x_to_y = x_to_y
model.S_x_to_y = S_x_to_y

In [None]:
# Symbolic inputs shared by the loss and the compiled functions:
# a float matrix of flattened images and an int32 label vector.
input_variables = OrderedDict()
input_variables["X"] = dict(tensor_type=T.matrix)
input_variables["y"] = dict(tensor_type=T.ivector)


def _student_output_without_teacher(X):
    """Student output with the teacher-hint branch replaced by a constant.

    The constant is ``nonlinearity(b)``, i.e. exactly what ``l_S_hint``
    produces when its weights are zero. So after ``l_S_hint.W`` is zeroed,
    this must agree with the student's normal output.
    """
    n_hint_units = l_S_hint.output_shape[1]
    hint_const = l_S_hint.nonlinearity(T.ones((X.shape[0], n_hint_units)) * l_S_hint.b)
    # Substituting l_S_hint's output means the teacher graph is not evaluated.
    return layers.get_output(l_S_out, {l_in: X, l_S_hint: hint_const}, deterministic=True)


# Extra compiled functions; each maps an input batch X to predicted class indices.
functions = dict(
    # Teacher prediction.
    predict=dict(
        get_output=lambda model, X: model.x_to_y.get_output(X, deterministic=True)[0].argmax(axis=1),
        params=["X"]
    ),
    # Student prediction using the teacher's features as hints.
    student_predict_with_teacher=dict(
        get_output=lambda model, X: model.S_x_to_y.get_output(X, deterministic=True)[0].argmax(axis=1),
        params=["X"]
    ),
    # Student prediction with the hint branch replaced by its zero-weight constant.
    student_predict=dict(
        get_output=lambda model, X: _student_output_without_teacher(X).argmax(axis=1),
        params=["X"]
    ),
)

# Training loop configuration: mini-batches of 100 examples, at most 300
# epochs (presumably; see lasagnekit BatchOptimizer), optimised with
# lasagne.updates.momentum at learning rate 0.001.
# MyBatchOptimizer also records teacher train/test error in its stats.
batch_optimizer = MyBatchOptimizer(
    verbose=1,
    max_nb_epochs=300,
    batch_size=100,
    optimization_procedure=(updates.momentum, 
                            {"learning_rate": 0.001})
)

def loss_function(model, tensors):
    """Joint training loss for the teacher and the student.

    Sums the teacher's and the student's categorical cross-entropy against the
    true labels and averages over the mini-batch.

    Parameters
    ----------
    model : Model
        Container exposing ``x_to_y`` (teacher) and ``S_x_to_y`` (student).
    tensors : dict
        Symbolic inputs keyed as in ``input_variables`` ("X", "y").

    Returns
    -------
    Scalar symbolic loss.

    Notes
    -----
    The student's hint branch is built on top of the teacher's feature layers,
    so the student term also back-propagates into those teacher layers.
    Variants explored earlier (not active): a squared error between the student
    output and the teacher's pre-softmax layer, and an L1 penalty on the hint
    weights.
    """
    X_batch, y_batch = tensors["X"], tensors["y"]

    # Both networks end in a softmax, so these are class probabilities.
    teacher_probs, = model.x_to_y.get_output(X_batch)
    student_probs, = model.S_x_to_y.get_output(X_batch)

    teacher_loss = T.nnet.categorical_crossentropy(teacher_probs, y_batch)
    student_loss = T.nnet.categorical_crossentropy(student_probs, y_batch)
    return (teacher_loss + student_loss).mean()
    
        
# Assemble the trainable object from the symbolic inputs, the teacher/student
# container, the joint loss, the extra compiled functions and the optimizer.
# Entries of `functions` are presumably exposed as methods on `capsule`
# (e.g. capsule.student_predict_with_teacher is called below) — TODO confirm.
capsule = Capsule(
    input_variables, 
    model,
    loss_function,
    functions=functions,
    batch_optimizer=batch_optimizer
)


In [None]:
# Train on the training split only; MyBatchOptimizer.iter_update evaluates
# the teacher on both the train and test splits as training proceeds.
capsule.fit(X=X[train], y=y[train])

In [None]:
%matplotlib inline
from lasagnekit.easy import get_stat
stats = list(get_stat("t_test_mean", batch_optimizer.stats))
plt.plot(get_stat("t_test_mean", batch_optimizer.stats), label="teacher")
#plt.plot(get_stat("swt_test", batch_optimizer.stats), label="student with teacher")
#plt.plot(get_stat("s_test", batch_optimizer.stats), label="student")
plt.legend()
plt.show()

plt.plot(get_stat("W", batch_optimizer.stats), label="W norm")
plt.show()

In [None]:
# Ablation: zero the weights from the teacher features into the student's
# hint layer, so that layer outputs nonlinearity(b) regardless of the input.
# A copy of the trained weights is kept (get_value() returns a copy by default);
# undo with l_S_hint.W.set_value(l_S_hint_W_trained).
# NOTE: re-running this cell overwrites the backup with the zeroed weights.
l_S_hint_W_trained = l_S_hint.W.get_value()
# zeros_like keeps the shared variable's shape and dtype instead of forcing float32.
l_S_hint.W.set_value(np.zeros_like(l_S_hint_W_trained))

In [None]:
# Fraction of the first 1000 examples on which the student predicts the same
# class through its teacher-hint branch (zeroed in the previous cell) as
# `student_predict` (the student without the teacher branch).
# NOTE(review): requires a `student_predict` entry in `functions`.
X_probe = X[:1000]
same_prediction = (capsule.student_predict_with_teacher(X_probe)
                   == capsule.student_predict(X_probe))
same_prediction.mean()

In [None]:
# Visualise the weights of the student's hint layer (teacher features -> hint
# units) and print their maximum.
# NOTE(review): if the zeroing cell above has already run, W is all zeros here,
# so the image is blank and the printed max is 0; run this before zeroing to
# inspect the trained weights.
plt.imshow(l_S_hint.W.get_value(), cmap="gray")
print(l_S_hint.W.get_value().max())