In [2]:
import numpy as np
import theano
import theano.tensor as T
import lasagne
import time
from random import shuffle
import math
from tqdm import tqdm

In [8]:
""" 
Builds and returns a fully connected network (its last layer).

Parameters
    input_X : Theano variable used as the network input
    input_shape : shape of the input layer
    num_layers : number of hidden layers
    num_units : sequence with the number of neurons in each hidden layer
    set_params : if not None, the weight values to load into the network
    output : bool, whether to append the final (output) layer to the network
"""
def fully_connected(input_X, input_shape, num_layers, num_units, set_params=None, output=True,
                    num_classes=10):
    """Build a fully connected Lasagne network and return its last layer.

    Parameters
        input_X : Theano variable passed to the input layer as ``input_var``
        input_shape : shape passed to ``lasagne.layers.InputLayer``
        num_layers : number of hidden layers; must equal ``len(num_units)``
        num_units : sequence with the number of units of each hidden layer, in order
        set_params : if not None, parameter values loaded into the finished
            network via ``lasagne.layers.set_all_param_values``
        output : bool, whether to append the softmax output layer
        num_classes : number of units of the softmax output layer (default 10)

    Returns
        The last layer of the network (the softmax layer when ``output`` is
        true, otherwise the last hidden layer, or the input layer if there
        are no hidden layers).

    Raises
        AssertionError : if ``num_layers`` differs from ``len(num_units)``
    """
    assert num_layers == len(num_units), "num_layers must equal len(num_units)"

    net = lasagne.layers.InputLayer(shape=input_shape, input_var=input_X)

    # Hidden layers: ReLU units with He-normal weights. b=None is passed on
    # purpose, which per the Lasagne docs creates the layer without a bias.
    for units in num_units:
        net = lasagne.layers.DenseLayer(
            net,
            num_units=units,
            nonlinearity=lasagne.nonlinearities.rectify,
            b=None,
            W=lasagne.init.HeNormal(),
        )

    if output:
        # Softmax classifier head, also without bias; W is left to Lasagne's default.
        net = lasagne.layers.DenseLayer(
            net,
            num_units=num_classes,
            nonlinearity=lasagne.nonlinearities.softmax,
            b=None,
        )

    # Identity check instead of `!= None`: comparing a NumPy array with `!=`
    # is elementwise and cannot be used as an `if` condition.
    if set_params is not None:
        lasagne.layers.set_all_param_values(net, set_params)

    return net

In [4]:
"""
Returns the training function and the accuracy function for a network.

Parameters:
    net: the network (its last layer)
    input_X: Theano variable holding the data fed to the network input
    target_y: Theano variable holding the correct answers (labels)
"""
def set_train_fun(net, input_X, target_y):
    """Compile the training and accuracy functions for a network.

    Parameters:
        net: the network (its last layer)
        input_X: Theano variable used as the network input
        target_y: Theano variable holding the correct labels

    Returns:
        (train_fun, accuracy_fun)
        train_fun(X, y) performs one Adam update on the trainable parameters
        and returns [loss, accuracy] for that batch.
        accuracy_fun(X, y) returns the mean accuracy and does not update
        any parameters.
    """
    params = lasagne.layers.get_all_params(net, trainable=True)

    # Training graph: regular forward pass, mean cross-entropy loss.
    train_prediction = lasagne.layers.get_output(net)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction, target_y).mean()
    accuracy = lasagne.objectives.categorical_accuracy(train_prediction, target_y).mean()
    updates = lasagne.updates.adam(loss, params)
    train_fun = theano.function([input_X, target_y], [loss, accuracy], updates=updates)

    # Evaluation graph: deterministic=True asks Lasagne to disable stochastic
    # layers (e.g. dropout), so the reported accuracy does not depend on noise.
    # For networks without such layers the result is unchanged.
    eval_prediction = lasagne.layers.get_output(net, deterministic=True)
    eval_accuracy = lasagne.objectives.categorical_accuracy(eval_prediction, target_y).mean()
    accuracy_fun = theano.function([input_X, target_y], eval_accuracy)

    return train_fun, accuracy_fun

In [5]:
def iterate_minibatches(X, y, batchsize, shuffle=True, drop_last=True):
    """Yield (X_batch, y_batch) pairs of ``batchsize`` samples each.

    Parameters:
        X, y: arrays indexed with the same integer index array, so each
            yielded pair stays aligned
        batchsize: number of samples per batch; must be positive
        shuffle: if true (default), visit the samples in a random order drawn
            from NumPy's global random state; pass False for evaluation
        drop_last: if true (default), samples left over after the last full
            batch are skipped; if false they are yielded as a smaller batch

    Raises:
        ValueError: if ``batchsize`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batchsize <= 0:
        raise ValueError("batchsize must be a positive integer, got {!r}".format(batchsize))

    n_samples = len(X)
    indices = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(indices)

    # With drop_last the range stops early enough that every slice is full.
    stop = n_samples - batchsize + 1 if drop_last else n_samples
    for start_idx in range(0, stop, batchsize):
        excerpt = indices[start_idx:start_idx + batchsize]
        yield X[excerpt], y[excerpt]

In [6]:
def train(net, train_fun, accuracy_fun, X_train, y_train, X_val, y_val, num_epoch=50, batch_size=50):
    """Run the training loop and print loss/accuracy after every epoch.

    Parameters:
        net: the network; not used by this function, kept for the call signature
        train_fun: callable (inputs, targets) -> (loss, accuracy) that performs
            one training step on a mini-batch
        accuracy_fun: callable (inputs, targets) -> accuracy for a mini-batch
        X_train, y_train: training data and labels
        X_val, y_val: validation data and labels
        num_epoch: number of passes over the training data
        batch_size: mini-batch size

    The printed values are averages of the per-batch results. A metric for
    which no batch was processed is printed as ``nan`` instead of raising
    ZeroDivisionError.
    """
    nan = float("nan")

    # Accuracy is a per-sample average, so the validation batch may be shrunk
    # safely. Without this a validation set smaller than batch_size yields no
    # batch from iterate_minibatches as defined in this file.
    val_batch_size = min(batch_size, len(X_val))

    for epoch in tqdm(range(num_epoch)):
        # In each epoch, we do a full pass over the training data:
        train_err = 0
        train_acc = 0
        train_batches = 0
        start_time = time.time()
        for inputs, targets in iterate_minibatches(X_train, y_train, batch_size):
            train_err_batch, train_acc_batch = train_fun(inputs, targets)
            train_err += train_err_batch
            train_acc += train_acc_batch
            train_batches += 1

        # And a full pass over the validation data:
        val_acc = 0
        val_batches = 0
        if val_batch_size > 0:
            for inputs, targets in iterate_minibatches(X_val, y_val, val_batch_size):
                val_acc += accuracy_fun(inputs, targets)
                val_batches += 1

        mean_train_err = train_err / train_batches if train_batches else nan
        mean_train_acc = train_acc / train_batches if train_batches else nan
        mean_val_acc = val_acc / val_batches if val_batches else nan

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epoch, time.time() - start_time))

        print("  training loss (in-iteration):\t\t{:.6f}".format(mean_train_err))
        print("  train accuracy:\t\t{:.2f} %".format(mean_train_acc * 100))
        print("  validation accuracy:\t\t{:.2f} %".format(mean_val_acc * 100))

In [7]:
def test(X_test, y_test, accuracy_fun):
    test_acc = 0
    test_batches = 0
    for batch in iterate_minibatches(X_test, y_test, 500):
        inputs, targets = batch
        acc = accuracy_fun(inputs, targets)
        test_acc += acc
        test_batches += 1
    print("Final results:")
    print("  test accuracy:\t\t{:.2f} %".format(
        test_acc / test_batches * 100))