In [1]:
require 'hdf5';
require 'nn';

In [2]:
function build_model(dwin, nchar, nclass, hid1, hid2)
    -- Model with skip layer from Bengio, standards parameters
    -- should be:
    -- dwin = 5
    -- hid1 = 30
    -- hid2 = 100

    -- To store the whole model
    local dnnlm = nn.Sequential()

    -- Layer to embedd (and put the words along the window into one vector)
    local LT = nn.Sequential()
    local LT_ = nn.LookupTable(nchar,hid1)
    LT:add(LT_)
    LT:add(nn.View(-1, hid1*dwin))

    dnnlm:add(LT)

    local concat = nn.ConcatTable()

    local lin_tanh = nn.Sequential()
    lin_tanh:add(nn.Linear(hid1*dwin,hid2))
    lin_tanh:add(nn.Tanh())

    local id = nn.Identity()

    concat:add(lin_tanh)
    concat:add(id)

    dnnlm:add(concat)
    dnnlm:add(nn.JoinTable(2))
    dnnlm:add(nn.Linear(hid1*dwin + hid2, nclass))
    dnnlm:add(nn.LogSoftMax())

    -- Loss
    local criterion = nn.ClassNLLCriterion()

    return dnnlm, criterion
end

function compute_perplexity(gram_input, nnlm, N)
    local perp = 0
    local context = torch.zeros(N-1)
    local probability = torch.zeros(2)
    -- Do not predict for the last char
    --for i=1,gram_input:size(1)-N do
    local size=gram_input:size(1) - (N-1)
    for i=1,size do
        context:copy(gram_input:narrow(1,i,N-1))
        -- Line where the model appears
        probability:copy(nnlm:forward(context))
        if gram_input[i+(N-1)] == 1 then
            right_proba = probability[1]
        else
            right_proba = probability[2]
        end
        perp = perp + right_proba
    end
    perp = math.exp(-perp/size)
    return perp
end

function train_model(train_input, train_output, dnnlm, criterion, dwin, nclass, eta, nEpochs, batchSize, valid, N)
    -- Train the model with a mini batch SGD
    -- standard parameters are
    -- nEpochs = 1
    -- batchSize = 32
    -- eta = 0.01

    -- To store the loss
    local av_L = 0

    -- Memory allocation
    local inputs_batch = torch.DoubleTensor(batchSize,dwin)
    local targets_batch = torch.DoubleTensor(batchSize)
    local outputs = torch.DoubleTensor(batchSize, nclass)
    local df_do = torch.DoubleTensor(batchSize, nclass)
    
    local train_perplexity = torch.DoubleTensor(nEpochs)
    local valid_perplexity = torch.DoubleTensor(nEpochs)
    for i = 1, nEpochs do
        -- timing the epoch
        local timer = torch.Timer()

        av_L = 0
        
        -- max renorm of the lookup table
        dnnlm:get(1):get(1).weight:renorm(2,1,1)
        
        -- mini batch loop
        for t = 1, train_input:size(1), batchSize do
            -- Mini batch data
            local current_batch_size = math.min(batchSize,train_input:size(1)-t)
            inputs_batch:narrow(1,1,current_batch_size):copy(train_input:narrow(1,t,current_batch_size))
            targets_batch:narrow(1,1,current_batch_size):copy(train_output:narrow(1,t,current_batch_size))
            
            -- reset gradients
            dnnlm:zeroGradParameters()
            --gradParameters:zero()

            -- Forward pass (selection of inputs_batch in case the batch is not full, ie last batch)
            outputs:narrow(1,1,current_batch_size):copy(dnnlm:forward(inputs_batch:narrow(1,1,current_batch_size)))

            -- Average loss computation
            local f = criterion:forward(outputs:narrow(1,1,current_batch_size), targets_batch:narrow(1,1,current_batch_size))
            av_L = av_L +f

            -- Backward pass
            df_do:narrow(1,1,current_batch_size):copy(criterion:backward(outputs:narrow(1,1,current_batch_size), targets_batch:narrow(1,1,current_batch_size)))
            dnnlm:backward(inputs_batch:narrow(1,1,current_batch_size), df_do:narrow(1,1,current_batch_size))
            dnnlm:updateParameters(eta)
            
        end
            
        print('Epoch '..i..': '..timer:time().real)
        print('Average Loss: '..av_L/math.floor(train_input:size(1)/batchSize))
        
        train_perplexity[i] = math.exp(av_L/math.floor(train_input:size(1)/batchSize))
        valid_perplexity[i] = compute_perplexity(valid, dnnlm, N)
    end

    return train_perplexity, valid_perplexity

end

In [3]:
myFile = hdf5.open('../data_preprocessed/4-grams.hdf5','r')
data = myFile:all()
myFile:close()

N = 4

train_input = data['input_matrix_train']
train_output = data['output_matrix_train']
input_data_train = data['input_data_train']

input_data_valid = data['input_data_valid_nospace']:clone()

input_data_test = data['input_data_test']:clone()

In [6]:
torch.manualSeed(1)
nnlm1, crit = build_model(N-1, 49, 2, 16, 80)
perp_train_3, perp_valid_3 = train_model(train_input, train_output, nnlm1, crit, N-1, 2, 0.01, 15, 20, input_data_valid, N)

Epoch 1: 20.040657043457	
Average Loss: 0.2999279016966	


[string "function build_model(dwin, nchar, nclass, hid..."]:119: attempt to index global 'valid_perplexity' (a nil value)
stack traceback:
	[string "function build_model(dwin, nchar, nclass, hid..."]:119: in function 'train_model'
	[string "perp_train_3, perp_valid_3 = train_model(trai..."]:1: in main chunk
	[C]: in function 'xpcall'
	.../virgileaudi/torch/install/share/lua/5.1/itorch/main.lua:179: in function <.../virgileaudi/torch/install/share/lua/5.1/itorch/main.lua:143>
	.../virgileaudi/torch/install/share/lua/5.1/lzmq/poller.lua:75: in function 'poll'
	...rgileaudi/torch/install/share/lua/5.1/lzmq/impl/loop.lua:307: in function 'poll'
	...rgileaudi/torch/install/share/lua/5.1/lzmq/impl/loop.lua:325: in function 'sleep_ex'
	...rgileaudi/torch/install/share/lua/5.1/lzmq/impl/loop.lua:370: in function 'start'
	.../virgileaudi/torch/install/share/lua/5.1/itorch/main.lua:350: in main chunk
	[C]: in function 'require'
	(command line):1: in main chunk
	[C]: at 0x010289eb50: 

In [None]:
myFile = hdf5.open('../data_preprocessed/4-grams.hdf5','r')
data = myFile:all()
myFile:close()

N = 4

train_input = data['input_matrix_train']
train_output = data['output_matrix_train']
input_data_train = data['input_data_train']

input_data_valid = data['input_data_valid_nospace']:clone()

input_data_test = data['input_data_test']:clone()