参考Theano官网教程，在鸢尾花数据集上训练逻辑回归多分类模型。

In [1]:
from sklearn.datasets import load_iris
import theano
import theano.tensor as T
import numpy

## 加载数据

In [2]:
datasets = load_iris()
train_set_x = theano.shared(datasets.data[:,(0,2)].astype(dtype=theano.config.floatX), borrow=True)
train_set_y = T.cast(theano.shared(datasets.target.astype(dtype=theano.config.floatX),borrow=True),'int32')
valid_set_x, valid_set_y = train_set_x, train_set_y

## 构建模型

In [3]:
# 输入数据矩阵
x = T.matrix('x')  # data, presented as rasterized images
# 输入数据类别
y = T.ivector('y')  # labels, presented as 1D vector of [int] labels

# 输入向量维度
n_in=2
# 输出向量维度
n_out=3

# 模型参数
W = theano.shared(value=numpy.zeros((n_in, n_out), dtype=theano.config.floatX),
                  name='W',
                  borrow=True)
# initialize the biases b as a vector of n_out 0s
b = theano.shared(value=numpy.zeros((n_out,), dtype=theano.config.floatX),
                  name='b',
                  borrow=True)

# 预测表达式
p_y_given_x = T.nnet.softmax(T.dot(x, W) + b)
y_pred = T.argmax(p_y_given_x, axis=1)


# 损失函数表达式
cost =  -T.mean(T.log(p_y_given_x)[T.arange(y.shape[0]), y])
# 计算梯度
g_W = T.grad(cost=cost, wrt=W)
g_b = T.grad(cost=cost, wrt=b)


# 误差表达式
errors = T.mean(T.neq(y_pred, y))

# 批次索引
index = T.lscalar()

# 一个批次所含实例数量
batch_size=150

# 学习速率
learning_rate=0.01

# 编译验证函数
validate_model = theano.function(
    inputs=[index],
    outputs=errors,
    givens={
        x: valid_set_x[index * batch_size: (index + 1) * batch_size],
        y: valid_set_y[index * batch_size: (index + 1) * batch_size]
    }
)

# 编译训练函数
train_model = theano.function(
    inputs=[index],
    outputs=cost,
    updates=[(W, W - learning_rate * g_W),
           (b, b - learning_rate * g_b)],
    givens={
        x: train_set_x[index * batch_size: (index + 1) * batch_size],
        y: train_set_y[index * batch_size: (index + 1) * batch_size]
    }
)

## 训练模型

In [4]:
# compute number of minibatches for training, validation and testing
n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size
n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] // batch_size

# 周期数
n_epochs=1000
epoch = 0
while epoch < n_epochs:
    epoch = epoch + 1
    for minibatch_index in range(n_train_batches):
        minibatch_avg_cost = train_model(minibatch_index)
        # iteration number
        iter = (epoch - 1) * n_train_batches + minibatch_index
        # compute zero-one loss on validation set
        validation_losses = [validate_model(i)
                                for i in range(n_valid_batches)]
        this_validation_loss = numpy.mean(validation_losses)

        print(
            'epoch %i, minibatch %i/%i, validation error %f %%' %
            (
                epoch,
                minibatch_index + 1,
                n_train_batches,
                this_validation_loss * 100.
            )
        )

epoch 1, minibatch 1/1, validation error 66.666667 %
epoch 2, minibatch 1/1, validation error 66.666667 %
epoch 3, minibatch 1/1, validation error 66.666667 %
epoch 4, minibatch 1/1, validation error 66.666667 %
epoch 5, minibatch 1/1, validation error 66.666667 %
epoch 6, minibatch 1/1, validation error 66.666667 %
epoch 7, minibatch 1/1, validation error 66.666667 %
epoch 8, minibatch 1/1, validation error 66.666667 %
epoch 9, minibatch 1/1, validation error 66.666667 %
epoch 10, minibatch 1/1, validation error 66.666667 %
epoch 11, minibatch 1/1, validation error 66.666667 %
epoch 12, minibatch 1/1, validation error 66.666667 %
epoch 13, minibatch 1/1, validation error 66.666667 %
epoch 14, minibatch 1/1, validation error 66.666667 %
epoch 15, minibatch 1/1, validation error 66.666667 %
epoch 16, minibatch 1/1, validation error 66.666667 %
epoch 17, minibatch 1/1, validation error 66.666667 %
epoch 18, minibatch 1/1, validation error 66.666667 %
epoch 19, minibatch 1/1, validation e

epoch 269, minibatch 1/1, validation error 24.666667 %
epoch 270, minibatch 1/1, validation error 24.666667 %
epoch 271, minibatch 1/1, validation error 24.666667 %
epoch 272, minibatch 1/1, validation error 24.666667 %
epoch 273, minibatch 1/1, validation error 24.666667 %
epoch 274, minibatch 1/1, validation error 24.666667 %
epoch 275, minibatch 1/1, validation error 24.666667 %
epoch 276, minibatch 1/1, validation error 24.666667 %
epoch 277, minibatch 1/1, validation error 24.666667 %
epoch 278, minibatch 1/1, validation error 24.666667 %
epoch 279, minibatch 1/1, validation error 24.666667 %
epoch 280, minibatch 1/1, validation error 24.666667 %
epoch 281, minibatch 1/1, validation error 24.666667 %
epoch 282, minibatch 1/1, validation error 24.666667 %
epoch 283, minibatch 1/1, validation error 24.666667 %
epoch 284, minibatch 1/1, validation error 24.666667 %
epoch 285, minibatch 1/1, validation error 24.666667 %
epoch 286, minibatch 1/1, validation error 24.666667 %
epoch 287,

epoch 536, minibatch 1/1, validation error 13.333333 %
epoch 537, minibatch 1/1, validation error 13.333333 %
epoch 538, minibatch 1/1, validation error 13.333333 %
epoch 539, minibatch 1/1, validation error 13.333333 %
epoch 540, minibatch 1/1, validation error 13.333333 %
epoch 541, minibatch 1/1, validation error 13.333333 %
epoch 542, minibatch 1/1, validation error 13.333333 %
epoch 543, minibatch 1/1, validation error 13.333333 %
epoch 544, minibatch 1/1, validation error 13.333333 %
epoch 545, minibatch 1/1, validation error 13.333333 %
epoch 546, minibatch 1/1, validation error 13.333333 %
epoch 547, minibatch 1/1, validation error 13.333333 %
epoch 548, minibatch 1/1, validation error 12.666667 %
epoch 549, minibatch 1/1, validation error 12.666667 %
epoch 550, minibatch 1/1, validation error 12.666667 %
epoch 551, minibatch 1/1, validation error 12.000000 %
epoch 552, minibatch 1/1, validation error 12.000000 %
epoch 553, minibatch 1/1, validation error 12.000000 %
epoch 554,

epoch 780, minibatch 1/1, validation error 10.000000 %
epoch 781, minibatch 1/1, validation error 10.000000 %
epoch 782, minibatch 1/1, validation error 10.000000 %
epoch 783, minibatch 1/1, validation error 10.000000 %
epoch 784, minibatch 1/1, validation error 10.000000 %
epoch 785, minibatch 1/1, validation error 10.000000 %
epoch 786, minibatch 1/1, validation error 10.000000 %
epoch 787, minibatch 1/1, validation error 10.000000 %
epoch 788, minibatch 1/1, validation error 10.000000 %
epoch 789, minibatch 1/1, validation error 10.000000 %
epoch 790, minibatch 1/1, validation error 10.000000 %
epoch 791, minibatch 1/1, validation error 10.000000 %
epoch 792, minibatch 1/1, validation error 10.000000 %
epoch 793, minibatch 1/1, validation error 10.000000 %
epoch 794, minibatch 1/1, validation error 10.000000 %
epoch 795, minibatch 1/1, validation error 10.000000 %
epoch 796, minibatch 1/1, validation error 10.000000 %
epoch 797, minibatch 1/1, validation error 10.000000 %
epoch 798,