diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..e98ba75 --- /dev/null +++ b/.gitignore @@ -0,0 +1,44 @@ +data/* +log/* +*.log + +# Compiled Lua sources +luac.out + +# luarocks build files +*.src.rock +*.zip +*.tar.gz + +# Object files +*.o +*.os +*.ko +*.obj +*.elf + +# Precompiled Headers +*.gch +*.pch + +# Libraries +*.lib +*.a +*.la +*.lo +*.def +*.exp + +# Shared objects (inc. Windows DLLs) +*.dll +*.so +*.so.* +*.dylib + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex diff --git a/README.md b/README.md new file mode 100755 index 0000000..42ba7af --- /dev/null +++ b/README.md @@ -0,0 +1,140 @@ +# Neural Conversational Model in Torch + +This is an attempt at implementing [Sequence to Sequence Learning with Neural Networks (seq2seq)](http://arxiv.org/abs/1409.3215) and reproducing the results in [A Neural Conversational Model](http://arxiv.org/abs/1506.05869) (aka the Google chatbot). + +The Google chatbot paper [became famous](http://www.sciencealert.com/google-s-ai-bot-thinks-the-purpose-of-life-is-to-live-forever) after cleverly answering a few philosophical questions, such as: + +> **Human:** What is the purpose of living? +> **Machine:** To live forever. + +## How it works + +The model is based on two [LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory) layers. One for encoding the input sentence into a "thought vector", and another for decoding that vector into a response. This model is called Sequence-to-sequence or seq2seq. + +![seq2seq](https://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png) +_Source: http://googleresearch.blogspot.ca/2015/11/computer-respond-to-this-email.html_ + +In this experiment, we train the seq2seq model with movie dialogs from the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html). The lines are shortened to the first sentence. + +## Sample conversation + +Here's a sample conversation after training for 20 epoch with 50000 examples, using the following command: + +```sh +th train.lua --cuda --dataset 50000 --hiddenSize 1000 +``` + +(Took 3 days to train on my GeForce GTX 780M.) + +For OpenCL, use `--opencl` instead of `--cuda`. To train on CPU, don't provide any of those two. + +> **me:** Hello? +> **bot:** Hi. +> +> **me:** How are you? +> **bot:** I'm fine. +> +> **me:** What's your name? +> **bot:** It's hard to describe. +> +> **me:** How so? +> **bot:** I'm not sure. +> +> **me:** What color is the sky? +> **bot:** It's blue. +> +> **me:** What is your job? +> **bot:** It's not that i'm a fucking werewolf! +> +> **me:** What is the purpose of life? +> **bot:** A gift. +> +> **me:** Are you intelligent? +> **bot:** Yes, well... +> +> **me:** Are you a machine? +> **bot:** That's a lie. +> +> **me:** Are you human? +> **bot:** No, i'm not. +> +> **me:** What are you? +> **bot:** I'm not sure. +> +> **me:** Do you plan on taking over the world? +> **bot:** No, i don't. + +Phew! That was close. Good thing I didn't train it on the full dataset. Please experiment responsibly. + +_(Disclaimer: nonsensical responses have been removed.)_ + +## Installing + +1. [Install Torch](http://torch.ch/docs/getting-started.html). +2. Install the following additional Lua libs: + + ```sh + luarocks install nn + luarocks install rnn + luarocks install penlight + ``` + + To train with CUDA install the latest CUDA drivers, toolkit and run: + + ```sh + luarocks install cutorch + luarocks install cunn + ``` + + To train with opencl install the lastest Opencl torch lib: + + ```sh + luarocks install cltorch + luarocks install clnn + ``` + +3. Download the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html) and extract all the files into data/cornell_movie_dialogs. + +## Training + +```sh +th train.lua [-h / options] +``` + +Use the `--dataset NUMBER` option to control the size of the dataset. Training on the full dataset takes about 5h for a single epoch. + +The model will be saved to `data/model.t7` after each epoch if it has improved (error decreased). + +## Testing + +To load the model and have a conversation: + +```sh +th -i eval.lua --cuda # Skip --cuda if you didn't train with it +# ... +th> say "Hello." +``` + +## License + +MIT License + +Copyright (c) 2016 Marc-Andre Cournoyer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/cornell_movie_dialogs.lua b/cornell_movie_dialogs.lua new file mode 100755 index 0000000..cf298d5 --- /dev/null +++ b/cornell_movie_dialogs.lua @@ -0,0 +1,72 @@ +local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") +local stringx = require "pl.stringx" +local xlua = require "xlua" + +local function parsedLines(file, fields) + local f = assert(io.open(file, 'r')) + + return function() + local line = f:read("*line") + + if line == nil then + f:close() + return + end + + local values = stringx.split(line, " +++$+++ ") + local t = {} + + for i,field in ipairs(fields) do + t[field] = values[i] + end + + return t + end +end + +function CornellMovieDialogs:__init(dir) + self.dir = dir +end + +local MOVIE_LINES_FIELDS = {"lineID","characterID","movieID","character","text"} +local MOVIE_CONVERSATIONS_FIELDS = {"character1ID","character2ID","movieID","utteranceIDs"} +local TOTAL_LINES = 387810 + +local function progress(c) + if c % 10000 == 0 then + xlua.progress(c, TOTAL_LINES) + end +end + +function CornellMovieDialogs:load() + local lines = {} + local conversations = {} + local count = 0 + + print("-- Parsing Cornell movie dialogs data set ...") + + for line in parsedLines(self.dir .. "/movie_lines.txt", MOVIE_LINES_FIELDS) do + lines[line.lineID] = line + line.lineID = nil + -- Remove unused fields + line.characterID = nil + line.movieID = nil + count = count + 1 + progress(count) + end + + for conv in parsedLines(self.dir .. "/movie_conversations.txt", MOVIE_CONVERSATIONS_FIELDS) do + local conversation = {} + local lineIDs = stringx.split(conv.utteranceIDs:sub(3, -3), "', '") + for i,lineID in ipairs(lineIDs) do + table.insert(conversation, lines[lineID]) + end + table.insert(conversations, conversation) + count = count + 1 + progress(count) + end + + xlua.progress(TOTAL_LINES, TOTAL_LINES) + + return conversations +end diff --git a/dataset.lua b/dataset.lua new file mode 100755 index 0000000..32e1c15 --- /dev/null +++ b/dataset.lua @@ -0,0 +1,233 @@ +--[[ +Format movie dialog data as a table of line 1: + + { {word_ids of character1}, {word_ids of character2} } + +Then flips it around and get the dialog from the other character's perspective: + + { {word_ids of character2}, {word_ids of character1} } + +Also builds the vocabulary. +]]-- + +local DataSet = torch.class("neuralconvo.DataSet") +local xlua = require "xlua" +local tokenizer = require "tokenizer" +local list = require "pl.List" + +function DataSet:__init(loader, options) + options = options or {} + + self.examplesFilename = "data/examples.t7" + + -- Discard words with lower frequency then this + self.minWordFreq = options.minWordFreq or 1 + + -- Maximum number of words in an example sentence + self.maxExampleLen = options.maxExampleLen or 25 + + -- Load only first fews examples (approximately) + self.loadFirst = options.loadFirst + + self.examples = {} + self.word2id = {} + self.id2word = {} + self.wordsCount = 0 + + self:load(loader) +end + +function DataSet:load(loader) + local filename = "data/vocab.t7" + + if path.exists(filename) then + print("Loading vocabulary from " .. filename .. " ...") + local data = torch.load(filename) + self.word2id = data.word2id + self.id2word = data.id2word + self.wordsCount = data.wordsCount + self.goToken = data.goToken + self.eosToken = data.eosToken + self.unknownToken = data.unknownToken + self.examplesCount = data.examplesCount + else + print("" .. filename .. " not found") + self:visit(loader:load()) + print("Writing " .. filename .. " ...") + torch.save(filename, { + word2id = self.word2id, + id2word = self.id2word, + wordsCount = self.wordsCount, + goToken = self.goToken, + eosToken = self.eosToken, + unknownToken = self.unknownToken, + examplesCount = self.examplesCount + }) + end +end + +function DataSet:visit(conversations) + -- Table for keeping track of word frequency + self.wordFreq = {} + self.examples = {} + + -- Add magic tokens + self.goToken = self:makeWordId("") -- Start of sequence + self.eosToken = self:makeWordId("") -- End of sequence + self.unknownToken = self:makeWordId("") -- Word dropped from vocabulary + + print("-- Pre-processing data") + + local total = self.loadFirst or #conversations * 2 + + for i, conversation in ipairs(conversations) do + if i > total then break end + self:visitConversation(conversation) + xlua.progress(i, total) + end + + -- Revisit from the perspective of 2nd character + for i, conversation in ipairs(conversations) do + if #conversations + i > total then break end + self:visitConversation(conversation, 2) + xlua.progress(#conversations + i, total) + end + + print("-- Removing low frequency words") + + for i, datum in ipairs(self.examples) do + self:removeLowFreqWords(datum[1]) + self:removeLowFreqWords(datum[2]) + xlua.progress(i, #self.examples) + end + + self.wordFreq = nil + + self.examplesCount = #self.examples + self:writeExamplesToFile() + self.examples = nil + + collectgarbage() +end + +function DataSet:writeExamplesToFile() + print("Writing " .. self.examplesFilename .. " ...") + local file = torch.DiskFile(self.examplesFilename, "w") + + for i, example in ipairs(self.examples) do + file:writeObject(example) + xlua.progress(i, #self.examples) + end + + file:close() +end + +function DataSet:batches(size) + local file = torch.DiskFile(self.examplesFilename, "r") + file:quiet() + local done = false + + return function() + if done then + return + end + + local examples = {} + + for i = 1, size do + local example = file:readObject() + if example == nil then + done = true + file:close() + return examples + end + table.insert(examples, example) + end + + return examples + end +end + +function DataSet:removeLowFreqWords(input) + for i = 1, input:size(1) do + local id = input[i] + local word = self.id2word[id] + + if word == nil then + -- Already removed + input[i] = self.unknownToken + + elseif self.wordFreq[word] < self.minWordFreq then + input[i] = self.unknownToken + + self.word2id[word] = nil + self.id2word[id] = nil + self.wordsCount = self.wordsCount - 1 + end + end +end + +function DataSet:visitConversation(lines, start) + start = start or 1 + + for i = start, #lines, 2 do + local input = lines[i] + local target = lines[i+1] + + if target then + local inputIds = self:visitText(input.text) + local targetIds = self:visitText(target.text, 2) + + if inputIds and targetIds then + -- Revert inputs + inputIds = list.reverse(inputIds) + + table.insert(targetIds, 1, self.goToken) + table.insert(targetIds, self.eosToken) + + table.insert(self.examples, { torch.IntTensor(inputIds), torch.IntTensor(targetIds) }) + end + end + end +end + +function DataSet:visitText(text, additionalTokens) + local words = {} + additionalTokens = additionalTokens or 0 + + if text == "" then + return + end + + for t, word in tokenizer.tokenize(text) do + table.insert(words, self:makeWordId(word)) + -- Only keep the first sentence + if t == "endpunct" or #words >= self.maxExampleLen - additionalTokens then + break + end + end + + if #words == 0 then + return + end + + return words +end + +function DataSet:makeWordId(word) + word = word:lower() + + local id = self.word2id[word] + + if id then + self.wordFreq[word] = self.wordFreq[word] + 1 + else + self.wordsCount = self.wordsCount + 1 + id = self.wordsCount + self.id2word[id] = word + self.word2id[word] = id + self.wordFreq[word] = 1 + end + + return id +end diff --git a/eval.lua b/eval.lua new file mode 100755 index 0000000..86d1772 --- /dev/null +++ b/eval.lua @@ -0,0 +1,77 @@ +require 'neuralconvo' +local tokenizer = require "tokenizer" +local list = require "pl.List" +local options = {} + +if dataset == nil then + cmd = torch.CmdLine() + cmd:text('Options:') + cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA') + cmd:option('--opencl', false, 'use OpenCL. Training must be done on OpenCL') + cmd:option('--debug', false, 'show debug info') + cmd:text() + options = cmd:parse(arg) + + -- Data + dataset = neuralconvo.DataSet() + + -- Enabled CUDA + if options.cuda then + require 'cutorch' + require 'cunn' + elseif options.opencl then + require 'cltorch' + require 'clnn' + end +end + +if model == nil then + print("-- Loading model") + model = torch.load("data/model.t7") +end + +-- Word IDs to sentence +function pred2sent(wordIds, i) + local words = {} + i = i or 1 + + for _, wordId in ipairs(wordIds) do + local word = dataset.id2word[wordId[i]] + table.insert(words, word) + end + + return tokenizer.join(words) +end + +function printProbabilityTable(wordIds, probabilities, num) + print(string.rep("-", num * 22)) + + for p, wordId in ipairs(wordIds) do + local line = "| " + for i = 1, num do + local word = dataset.id2word[wordId[i]] + line = line .. string.format("%-10s(%4d%%)", word, probabilities[p][i] * 100) .. " | " + end + print(line) + end + + print(string.rep("-", num * 22)) +end + +function say(text) + local wordIds = {} + + for t, word in tokenizer.tokenize(text) do + local id = dataset.word2id[word:lower()] or dataset.unknownToken + table.insert(wordIds, id) + end + + local input = torch.Tensor(list.reverse(wordIds)) + local wordIds, probabilities = model:eval(input) + + print(">> " .. pred2sent(wordIds)) + + if options.debug then + printProbabilityTable(wordIds, probabilities, 4) + end +end diff --git a/movie_script_parser.lua b/movie_script_parser.lua new file mode 100755 index 0000000..795d57d --- /dev/null +++ b/movie_script_parser.lua @@ -0,0 +1,116 @@ +local Parser = torch.class("neuralconvo.MovieScriptParser") + +function Parser:parse(file) + local f = assert(io.open(file, 'r')) + self.input = f:read("*all") + f:close() + + self.pos = 0 + self.match = nil + + -- Find start of script + repeat self:acceptLine() until self:accept("
")
+
+  local dialogs = {}
+
+  -- Apply rules until end of script
+  while not self:accept("
") and self:acceptLine() do + local dialog = self:parseDialog() + if dialog then + table.insert(dialogs, dialog) + end + end + + return dialogs +end + +-- Returns true if regexp matches and advance position +function Parser:accept(regexp) + local match = string.match(self.input, "^" .. regexp, self.pos) + if match then + self.pos = self.pos + #match + self.match = match + return true + end +end + +-- Accept anything up to the end of line +function Parser:acceptLine() + return self:accept(".-\n") +end + +function Parser:acceptSep() + while self:accept("") or self:accept(" +") do end + return self:accept("\n") +end + +function Parser:parseDialog() + local dialogs = {} + + repeat + local dialog = self:parseSpeech() + if dialog then + table.insert(dialogs, dialog) + end + until not self:acceptSep() + + if #dialogs > 0 then + return dialogs + end +end + +-- Matches: +-- +-- NAME +-- some nice text +-- more text. +-- +-- or +-- +-- NAME; text +function Parser:parseSpeech() + local name + + self:accept("") + self:accept("") + + -- Get the character name (all caps) + -- TODO remove parenthesis from name + if self:accept(" +") and self:accept("[A-Z][A-Z%- %.%(%)]+") then + name = self.match + else + return + end + + -- Handle inline dialog: `NAME; text` + if self:accept(";") and self:accept("[^\n]+") then + return { + character = name, + text = self.match + } + end + + self:accept("\n") + + if not self:accept("") then + return + end + + -- Get the dialog lines + -- TODO remove parenthesis from text + local lines = {} + while self:accept(" +") do + -- The actual line of dialog + if self:accept("[^\n]+") then + table.insert(lines, self.match) + end + self:accept("\n") + end + + if #lines > 0 then + return { + character = name, + text = table.concat(lines) + } + end +end diff --git a/neuralconvo.lua b/neuralconvo.lua new file mode 100755 index 0000000..3891f16 --- /dev/null +++ b/neuralconvo.lua @@ -0,0 +1,12 @@ +require 'torch' +require 'nn' +require 'rnn' + +neuralconvo = {} + +torch.include('neuralconvo', 'cornell_movie_dialogs.lua') +torch.include('neuralconvo', 'dataset.lua') +torch.include('neuralconvo', 'movie_script_parser.lua') +torch.include('neuralconvo', 'seq2seq.lua') + +return neuralconvo diff --git a/seq2seq.lua b/seq2seq.lua new file mode 100755 index 0000000..3f330e0 --- /dev/null +++ b/seq2seq.lua @@ -0,0 +1,136 @@ +-- Based on https://github.com/Element-Research/rnn/blob/master/examples/encoder-decoder-coupling.lua +local Seq2Seq = torch.class("neuralconvo.Seq2Seq") + +function Seq2Seq:__init(vocabSize, hiddenSize) + self.vocabSize = assert(vocabSize, "vocabSize required at arg #1") + self.hiddenSize = assert(hiddenSize, "hiddenSize required at arg #2") + + self:buildModel() +end + +function Seq2Seq:buildModel() + self.encoder = nn.Sequential() + self.encoder:add(nn.LookupTable(self.vocabSize, self.hiddenSize)) + self.encoder:add(nn.SplitTable(1, 2)) + self.encoderLSTM = nn.LSTM(self.hiddenSize, self.hiddenSize) + self.encoder:add(nn.Sequencer(self.encoderLSTM)) + self.encoder:add(nn.SelectTable(-1)) + + self.decoder = nn.Sequential() + self.decoder:add(nn.LookupTable(self.vocabSize, self.hiddenSize)) + self.decoder:add(nn.SplitTable(1, 2)) + self.decoderLSTM = nn.LSTM(self.hiddenSize, self.hiddenSize) + self.decoder:add(nn.Sequencer(self.decoderLSTM)) + self.decoder:add(nn.Sequencer(nn.Linear(self.hiddenSize, self.vocabSize))) + self.decoder:add(nn.Sequencer(nn.LogSoftMax())) + + self.encoder:zeroGradParameters() + self.decoder:zeroGradParameters() +end + +function Seq2Seq:cuda() + self.encoder:cuda() + self.decoder:cuda() + + if self.criterion then + self.criterion:cuda() + end +end + +function Seq2Seq:cl() + self.encoder:cl() + self.decoder:cl() + + if self.criterion then + self.criterion:cl() + end +end + +--[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]-- +function Seq2Seq:forwardConnect(inputSeqLen) + self.decoderLSTM.userPrevOutput = + nn.rnn.recursiveCopy(self.decoderLSTM.userPrevOutput, self.encoderLSTM.outputs[inputSeqLen]) + self.decoderLSTM.userPrevCell = + nn.rnn.recursiveCopy(self.decoderLSTM.userPrevCell, self.encoderLSTM.cells[inputSeqLen]) +end + +--[[ Backward coupling: Copy decoder gradients to encoder LSTM ]]-- +function Seq2Seq:backwardConnect() + self.encoderLSTM.userNextGradCell = + nn.rnn.recursiveCopy(self.encoderLSTM.userNextGradCell, self.decoderLSTM.userGradPrevCell) + self.encoderLSTM.gradPrevOutput = + nn.rnn.recursiveCopy(self.encoderLSTM.gradPrevOutput, self.decoderLSTM.userGradPrevOutput) +end + +function Seq2Seq:train(input, target) + local encoderInput = input + local decoderInput = target:sub(1, -2) + local decoderTarget = target:sub(2, -1) + + -- Forward pass + local encoderOutput = self.encoder:forward(encoderInput) + self:forwardConnect(encoderInput:size(1)) + local decoderOutput = self.decoder:forward(decoderInput) + local Edecoder = self.criterion:forward(decoderOutput, decoderTarget) + + if Edecoder ~= Edecoder then -- Exist early on bad error + return Edecoder + end + + -- Backward pass + local gEdec = self.criterion:backward(decoderOutput, decoderTarget) + self.decoder:backward(decoderInput, gEdec) + self:backwardConnect() + self.encoder:backward(encoderInput, encoderOutput:zero()) + + self.encoder:updateGradParameters(self.momentum) + self.decoder:updateGradParameters(self.momentum) + self.decoder:updateParameters(self.learningRate) + self.encoder:updateParameters(self.learningRate) + self.encoder:zeroGradParameters() + self.decoder:zeroGradParameters() + + self.decoder:forget() + self.encoder:forget() + + return Edecoder +end + +local MAX_OUTPUT_SIZE = 20 + +function Seq2Seq:eval(input) + assert(self.goToken, "No goToken specified") + assert(self.eosToken, "No eosToken specified") + + self.encoder:forward(input) + self:forwardConnect(input:size(1)) + + local predictions = {} + local probabilities = {} + + -- Forward and all of it's output recursively back to the decoder + local output = {self.goToken} + for i = 1, MAX_OUTPUT_SIZE do + local prediction = self.decoder:forward(torch.Tensor(output))[#output] + -- prediction contains the probabilities for each word IDs. + -- The index of the probability is the word ID. + local prob, wordIds = prediction:topk(5, 1, true, true) + + -- First one is the most likely. + next_output = wordIds[1] + table.insert(output, next_output) + + -- Terminate on EOS token + if next_output == self.eosToken then + break + end + + table.insert(predictions, wordIds) + table.insert(probabilities, prob) + end + + self.decoder:forget() + self.encoder:forget() + + return predictions, probabilities +end diff --git a/tokenizer.lua b/tokenizer.lua new file mode 100755 index 0000000..071f0b0 --- /dev/null +++ b/tokenizer.lua @@ -0,0 +1,55 @@ +local lexer = require "pl.lexer" +local yield = coroutine.yield +local M = {} + +local function word(token) + return yield("word", token) +end + +local function quote(token) + return yield("quote", token) +end + +local function space(token) + return yield("space", token) +end + +local function tag(token) + return yield("tag", token) +end + +local function punct(token) + return yield("punct", token) +end + +local function endpunct(token) + return yield("endpunct", token) +end + +local function unknown(token) + return yield("unknown", token) +end + +function M.tokenize(text) + return lexer.scan(text, { + { "^%s+", space }, + { "^['\"]", quote }, + { "^%w+", word }, + { "^%-+", space }, + { "^[,:;%-]", punct }, + { "^%.+", endpunct }, + { "^[%.%?!]", endpunct }, + { "^", tag }, + { "^.", unknown }, + }, { [space]=true, [tag]=true }) +end + +function M.join(words) + local s = table.concat(words, " ") + s = s:gsub("^%l", string.upper) + s = s:gsub(" (') ", "%1") + s = s:gsub(" ([,:;%-%.%?!])", "%1") + return s +end + +return M \ No newline at end of file diff --git a/train.lua b/train.lua new file mode 100755 index 0000000..5726f3c --- /dev/null +++ b/train.lua @@ -0,0 +1,121 @@ +require 'neuralconvo' +require 'xlua' + +cmd = torch.CmdLine() +cmd:text('Options:') +cmd:option('--dataset', 0, 'approximate size of dataset to use (0 = all)') +cmd:option('--minWordFreq', 1, 'minimum frequency of words kept in vocab') +cmd:option('--cuda', false, 'use CUDA') +cmd:option('--opencl', false, 'use opencl') +cmd:option('--hiddenSize', 300, 'number of hidden units in LSTM') +cmd:option('--learningRate', 0.05, 'learning rate at t=0') +cmd:option('--momentum', 0.9, 'momentum') +cmd:option('--minLR', 0.00001, 'minimum learning rate') +cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR') +cmd:option('--maxEpoch', 50, 'maximum number of epochs to run') +cmd:option('--batchSize', 1000, 'number of examples to load at once') + +cmd:text() +options = cmd:parse(arg) + +if options.dataset == 0 then + options.dataset = nil +end + +-- Data +print("-- Loading dataset") +dataset = neuralconvo.DataSet(neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs"), + { + loadFirst = options.dataset, + minWordFreq = options.minWordFreq + }) + +print("\nDataset stats:") +print(" Vocabulary size: " .. dataset.wordsCount) +print(" Examples: " .. dataset.examplesCount) + +-- Model +model = neuralconvo.Seq2Seq(dataset.wordsCount, options.hiddenSize) +model.goToken = dataset.goToken +model.eosToken = dataset.eosToken + +-- Training parameters +model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion()) +model.learningRate = options.learningRate +model.momentum = options.momentum +local decayFactor = (options.minLR - options.learningRate) / options.saturateEpoch +local minMeanError = nil + +-- Enabled CUDA +if options.cuda then + require 'cutorch' + require 'cunn' + model:cuda() +elseif options.opencl then + require 'cltorch' + require 'clnn' + model:cl() +end + + +-- Run the experiment + +for epoch = 1, options.maxEpoch do + print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch) + print("") + + local errors = torch.Tensor(dataset.examplesCount):fill(0) + local timer = torch.Timer() + + local i = 1 + for examples in dataset:batches(options.batchSize) do + collectgarbage() + + for _, example in ipairs(examples) do + local input, target = unpack(example) + + if options.cuda then + input = input:cuda() + target = target:cuda() + elseif options.opencl then + input = input:cl() + target = target:cl() + end + + local err = model:train(input, target) + + -- Check if error is NaN. If so, it's probably a bug. + if err ~= err then + error("Invalid error! Exiting.") + end + + errors[i] = err + xlua.progress(i, dataset.examplesCount) + i = i + 1 + end + end + + timer:stop() + + print("\nFinished in " .. xlua.formatTime(timer:time().real) .. " " .. (dataset.examplesCount / timer:time().real) .. ' examples/sec.') + print("\nEpoch stats:") + print(" LR= " .. model.learningRate) + print(" Errors: min= " .. errors:min()) + print(" max= " .. errors:max()) + print(" median= " .. errors:median()[1]) + print(" mean= " .. errors:mean()) + print(" std= " .. errors:std()) + + -- Save the model if it improved. + if minMeanError == nil or errors:mean() < minMeanError then + print("\n(Saving model ...)") + torch.save("data/model.t7", model) + minMeanError = errors:mean() + end + + model.learningRate = model.learningRate + decayFactor + model.learningRate = math.max(options.minLR, model.learningRate) +end + +-- Load testing script +require "eval"