In [None]:
import numpy
from theano import tensor
from fuel.streams import DataStream
from fuel.schemes import ShuffledScheme
from fuel.datasets.iris import Iris
from fuel.transformers import Mapping
from blocks.bricks import Linear, Softmax
from blocks.bricks.cost import MisclassificationRate
from blocks.initialization import Uniform, Constant
from blocks.graph import ComputationGraph
from blocks.algorithms import GradientDescent, Scale
from blocks.main_loop import MainLoop
from blocks.extensions import Timing, FinishAfter, Printing
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.extras.extensions.plot import Plot

# Parameters

Hyper-parameters for training a softmax regression classifier on the Iris dataset.

In [None]:
# Optimisation settings
learning_rate = 1e-2   # fixed step size used by the Scale step rule
batch_size = 32        # examples per minibatch
nepochs = 300          # training stops after this many epochs

# Problem shape; must agree with the Iris data loaded below
nfeatures = 4          # input dimension of the linear layer
nclasses = 3           # number of target classes / one-hot width

# Data

In [None]:
# Use the whole Iris dataset and read it in shuffled minibatches.
# NOTE(review): there is no held-out split, so any error monitored later is
# training error only.
dataset = Iris(which_sets=('all',))
scheme = ShuffledScheme(
    examples=dataset.num_examples,
    batch_size=batch_size,
)
stream = DataStream(
    dataset,
    iteration_scheme=scheme,
)

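## One-hot representation

Convert each batch's integer class labels into one-hot rows so they match the `targets` matrix the model expects.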

In [None]:
def one_hot(data, num_classes=None):
    """Replace the integer class labels of a (features, targets) batch by one-hot rows.

    Parameters
    ----------
    data : sequence
        Batch as yielded by the Fuel stream. ``data[0]`` is the features and
        ``data[1]`` holds one integer class label per example (any array
        shape that flattens to one label per row, e.g. ``(batch, 1)``).
        Any further entries are ignored.
    num_classes : int, optional
        Width of the one-hot encoding. Defaults to the notebook-level
        ``nclasses``.

    Returns
    -------
    tuple
        ``(features, one_hot_targets)``. ``features`` is passed through
        unchanged; ``one_hot_targets`` is an ``int64`` array of shape
        ``(n_labels, num_classes)``.

    Raises
    ------
    IndexError
        If a label is ``>= num_classes``.
    """
    if num_classes is None:
        num_classes = nclasses
    features, targets = data[0], data[1]
    labels = numpy.asarray(targets).ravel()
    # Explicit int64 to match the model's `lmatrix` (int64) target variable;
    # `dtype=int` is platform-dependent (e.g. int32 on Windows, NumPy < 2).
    basis = numpy.eye(num_classes, dtype=numpy.int64)
    return features, basis[labels]
# Wrap the shuffled stream so every batch passes through `one_hot`.
stream = Mapping(data_stream=stream, mapping=one_hot)

# Model

In [None]:
# Multinomial logistic regression: one affine layer followed by a softmax.
# The symbolic variable names presumably need to match the stream's source
# names ('features', 'targets') so Blocks can feed batches into them.
x = tensor.matrix('features')
y = tensor.lmatrix('targets')  # one-hot class matrix produced by `one_hot`

# Zero initialisation is acceptable here: there is a single linear layer,
# so there is no hidden-unit symmetry that random init would need to break.
linear = Linear(
    input_dim=nfeatures,
    output_dim=nclasses,
    weights_init=Constant(0),
    biases_init=Constant(0),
)
linear.initialize()
linear_output = linear.apply(x)  # unnormalised class scores
softmax = Softmax()
y_hat = softmax.apply(linear_output)  # class probabilities

## Cost and error

In [None]:
# Training objective: mean categorical cross-entropy of the softmax, computed
# from the linear outputs and the one-hot target matrix `y`.
cost = softmax.categorical_cross_entropy(y, linear_output).mean()
cost.name = 'cost'

# MisclassificationRate compares integer class indices, so recover them from
# the one-hot rows with a per-row argmax. Unlike `y.nonzero()[1]`, this always
# yields exactly one label per example, keeping it aligned with `y_hat` rows.
error = MisclassificationRate().apply(y.argmax(axis=1), y_hat)
error.name = 'error'

# Algorithm

In [None]:
# Plain minibatch gradient descent over every parameter in the cost graph,
# using a fixed step size of `learning_rate`.
cg = ComputationGraph(cost)
algorithm = GradientDescent(
    cost=cost,
    parameters=cg.parameters,
    step_rule=Scale(learning_rate=learning_rate),
)

# Extensions

In [None]:
# Track cost and error on the training batches after every batch.
# NOTE(review): 'tra' looks like a truncated 'train'; the logged channel
# names presumably become 'tra_cost' / 'tra_error' — rename deliberately.
monitor = TrainingDataMonitoring(
    [cost, error],
    prefix='tra',
    after_batch=True,
)
extensions = [
    monitor,
    Printing(),
    Timing(),
    FinishAfter(after_n_epochs=nepochs),  # stop after `nepochs` epochs
]

# Main loop

In [None]:
# Train: the extensions above control monitoring, printing and stopping.
loop = MainLoop(
    algorithm=algorithm,
    data_stream=stream,
    extensions=extensions,
)
loop.run()

In [None]:
# Print the timing profile held by the main loop.
profile = loop.profile
profile.report()