In [1]:
-- Importing libraries
include('init.lua')
include('WordIndexer.lua')
include('BatchIterator.lua')

In [2]:
-- Loading input data
--
-- Loads a cached (indexer, data) pair if both cache files exist; otherwise
-- builds a character vocabulary from `textfile`, encodes the whole corpus as a
-- 1-D ByteTensor of vocabulary indices (with a '\n' token after every line),
-- and saves both to the cache files.
--
-- Globals produced (used by later cells): seq_len, indexer, data.
local textfile = 'data/tiny_shakespeare.txt'
local indexfile = 'data/tiny_shakespeare.indexer.th'
local datafile = 'data/tiny_shakespeare.data.th'

local train_split = 0.8
local valid_split = 0.1

seq_len = 40

-- Calls `fn(line)` for every line of the file at `path` (line text without
-- its trailing newline), closing the file afterwards.
local function for_each_line(path, fn)
    local f = assert(io.open(path, "r"))
    for line in f:lines() do
        fn(line)
    end
    f:close()
end

if paths.filep(indexfile) and paths.filep(datafile) then
    print('Loading indexer and data ...')
    indexer = torch.load(indexfile)
    data = torch.load(datafile)
else
    -- Building vocab
    print('Building vocab...')
    indexer = WordIndexer()
    local data_len = 0
    for_each_line(textfile, function(line)
        for c in line:gmatch('.') do
            indexer:add(c)
        end
        -- +1 reserves a slot for the '\n' token written after each line.
        data_len = data_len + #line + 1
    end)
    -- '\n' is registered after all other characters, so it takes the last
    -- vocabulary slot.
    indexer:add('\n')
    print('Total chars: ' .. data_len)
    print('Total vocab: ' .. #indexer)

    -- A ByteTensor holds values 0..255; larger indices would silently wrap.
    -- NOTE(review): assumes WordIndexer assigns indices 1..#indexer — confirm.
    assert(#indexer <= 255,
        'vocabulary too large to store in a ByteTensor: ' .. #indexer)

    -- Creating torch tensor
    print('Creating torch Tensor of data...')
    data = torch.ByteTensor(data_len)
    local cur_pos = 1
    for_each_line(textfile, function(line)
        for c in line:gmatch('.') do
            data[cur_pos] = indexer:index(c)
            cur_pos = cur_pos + 1
        end
        data[cur_pos] = indexer:index('\n')
        cur_pos = cur_pos + 1
    end)

    -- Saving preprocessed data for later
    print('Saving data...')
    torch.save(indexfile, indexer)
    torch.save(datafile, data)
end
-- creating training batch
--
-- Cuts the encoded corpus into consecutive, non-overlapping sequences of
-- `seq_len` tokens. Labels are the corpus shifted left by one token (the
-- final label wraps around to the first token), so label_seqs[k][t] is the
-- token that follows data_seqs[k][t].
-- Globals produced: data (truncated), labels, data_seqs, label_seqs,
-- ntrain, nvalid, ntest.
local len = data:size(1)
-- With fewer than seq_len tokens the truncation below would request an
-- empty range; fail with an explicit message instead.
assert(len >= seq_len,
    'corpus has ' .. len .. ' tokens, fewer than seq_len=' .. seq_len)
if len % seq_len ~= 0 then
    -- Drop the trailing partial sequence so every split has length seq_len.
    data = data:sub(1, seq_len * math.floor(len / seq_len))
end

labels = data:clone()
labels:sub(1, -2):copy(data:sub(2, -1))
labels[-1] = data[1]
data_seqs = data:split(seq_len)
label_seqs = labels:split(seq_len)
ntrain = math.floor(#data_seqs * train_split)
nvalid = math.floor(#data_seqs * valid_split)
-- Whatever the floors leave over goes to the test split.
ntest = #data_seqs - ntrain - nvalid
collectgarbage()
print('Total: ' .. #data_seqs)
print('Train: ' .. ntrain)
print('Valid: ' .. nvalid)
print('Test: ' .. ntest)

Loading indexer and data ...	


Total: 27884	
Train: 22307	
Valid: 2788	
Test: 2789	


In [3]:
-- constructing model
--
-- Character embedding (LookupTable) feeding an LSTM, plus a small MLP head
-- that maps LSTM outputs to log-probabilities over the vocabulary, scored
-- with a negative log-likelihood criterion.
dim_word = 32          -- embedding width per vocabulary entry
num_lstm_layers = 1
dim_cell = 32          -- LSTM hidden width
dim_w = 16             -- hidden width of the classifier MLP

local vocab_size = #indexer

emb = nn.LookupTable(vocab_size, dim_word)
net = lstm.LSTM({
    input_dim = dim_word,
    hidden_dim = dim_cell,
    num_layers = num_lstm_layers,
})
classifier = nn.Sequential()
    :add(nn.Linear(dim_cell, dim_w))
    :add(nn.ReLU())
    :add(nn.Linear(dim_w, vocab_size))
    :add(nn.LogSoftMax())
criterion = nn.ClassNLLCriterion()

In [12]:
local batch_size = 20
local n_epochs = 1

-- Minibatch buffers reused across batches: one row per sequence.
local mb_data = torch.zeros(batch_size, seq_len)
local mb_labels = torch.zeros(batch_size, seq_len)

-- Copies the sequences selected by `mb_idx` into the minibatch buffers.
-- @param mb_idx 1-D tensor of sequence indices (presumably a slice of the
--   epoch's shuffled permutation returned by BatchIterator — confirm).
local set_minibatch_data = function(mb_idx)
    for i = 1, mb_idx:size(1) do
        -- Row i receives the i-th *selected* sequence, not sequence i;
        -- otherwise the epoch shuffle would have no effect.
        local seq_idx = mb_idx[i]
        mb_data[{i, {}}] = data_seqs[seq_idx]
        mb_labels[{i, {}}] = label_seqs[seq_idx]
    end
    -- NOTE(review): if a batch has fewer than batch_size indices, the
    -- remaining rows keep the previous batch's contents — confirm that
    -- BatchIterator's third argument drops incomplete batches.
end

for epoch = 1, n_epochs do
    -- setting up data for this epoch
    local shuffle = torch.randperm(ntrain)
    local batch_iter = BatchIterator(shuffle, batch_size, true)
    while batch_iter:has_next() do
        -- mini batch
        set_minibatch_data(batch_iter:next_batch())
        rep = emb:forward(mb_data)
    end
end


In [13]:
-- Show the shape of the embedding output from the last processed minibatch.
local rep_shape = rep:size()
print(rep_shape)

 20
 40
 32
[torch.LongStorage of size 3]

