sample.lua
--[[
This file samples words from a trained word-level language model.
The code is based on the implementation in
https://github.com/oxford-cs-ml-2015/practical6
]]--
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'lfs'
require 'util.OneHot'
require 'util.misc'
cmd = torch.CmdLine()
cmd:text()
cmd:text('Sample from a character-level language model')
cmd:text()
cmd:text('Options')
-- required:
cmd:argument('-model','model checkpoint to use for sampling')
-- optional parameters
cmd:option('-seed',123,'random number generator\'s seed')
cmd:option('-sample',1,'0 to use argmax at each timestep, 1 to sample at each timestep')
cmd:option('-primetext',"the",'prime text used to seed the state of the LSTM before sampling')
cmd:option('-length',7,'number of words to sample')
cmd:option('-temperature',1,'temperature of sampling')
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:text()
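-- A minimal usage sketch (the checkpoint filename below is hypothetical; pass
-- whichever .t7 file your training run wrote, typically under cv/):
--   th sample.lua cv/checkpoint.t7 -primetext "the" -length 20 -temperature 0.8
--   th sample.lua cv/checkpoint.t7 -sample 0    (greedy argmax decoding)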
-- parse input params
opt = cmd:parse(arg)
if opt.gpuid >= 0 then
    print('using CUDA on GPU ' .. opt.gpuid .. '...')
    require 'cutorch'
    require 'cunn'
    cutorch.setDevice(opt.gpuid + 1) -- note +1: cutorch device ids are 1-indexed
end
torch.manualSeed(opt.seed)
-- load the model checkpoint
if not lfs.attributes(opt.model, 'mode') then
    print('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?')
    os.exit(1) -- bail out here instead of crashing inside torch.load below
end
checkpoint = torch.load(opt.model)
local vocab = checkpoint.vocab
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end -- invert the vocab: id -> word, for printing samples
protos = checkpoint.protos
local rnn_idx = #protos.softmax.modules - 1
opt.rnn_size = protos.softmax.modules[rnn_idx].weight:size(2)
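-- Note: the softmax proto is assumed here to be Linear -> LogSoftMax, so
-- modules[#modules - 1] is the Linear layer; its weight has shape
-- (vocab_size x rnn_size), and size(2) recovers the hidden size the
-- checkpoint was trained with.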
-- initialize the rnn state
local current_state, state_predict_index
local model = checkpoint.opt.model
print('creating an LSTM...')
local num_layers = checkpoint.opt.num_layers or 1 -- the 'or 1' fallback is for backward compatibility with old checkpoints
current_state = {}
for L=1,num_layers do
    -- c and h for all layers
    local h_init = torch.zeros(1, opt.rnn_size)
    if opt.gpuid >= 0 then h_init = h_init:cuda() end
    table.insert(current_state, h_init:clone()) -- cell state c
    table.insert(current_state, h_init:clone()) -- hidden state h
end
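-- current_state now holds 2*num_layers zero tensors, interleaved as
-- {c_1, h_1, c_2, h_2, ...}: the cell and hidden state of each LSTM layer.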
state_predict_index = #current_state -- last one is the top h
local seed_text = opt.primetext
local prev_char
protos.rnn:evaluate() -- put in eval mode so that dropout works properly
-- do a few seeded timesteps; if a prime word isn't in the training vocabulary, fail with a clear error
print('seeding with ' .. seed_text)
for c in seed_text:gmatch('%w+') do
    if vocab[c] == nil then
        print('Error: prime word "' .. c .. '" is not in the vocabulary')
        os.exit(1)
    end
    prev_char = torch.Tensor{vocab[c]}
    if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
    local embedding = protos.embed:forward(prev_char)
    current_state = protos.rnn:forward{embedding, unpack(current_state)}
    if type(current_state) ~= 'table' then current_state = {current_state} end
end
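-- At this point current_state carries the LSTM memory of the prime text;
-- each iteration below predicts a word from the top hidden state and feeds
-- the chosen word back in as the next input.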
-- start sampling/argmaxing
for i=1, opt.length do
    -- softmax from the previous timestep; scaling the top hidden state by
    -- 1/temperature flattens (T > 1) or sharpens (T < 1) the distribution
    local next_h = current_state[state_predict_index]
    next_h = next_h / opt.temperature
    local log_probs = protos.softmax:forward(next_h)
    if opt.sample == 0 then
        -- use argmax: always pick the most likely next word
        local _, prev_char_ = log_probs:max(2)
        prev_char = prev_char_:resize(1)
    else
        -- use sampling: draw the next word from the predicted distribution
        local probs = torch.exp(log_probs):squeeze()
        prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
    end
    -- forward the rnn for the next word
    local embedding = protos.embed:forward(prev_char)
    current_state = protos.rnn:forward{embedding, unpack(current_state)}
    if type(current_state) ~= 'table' then current_state = {current_state} end
    io.write(ivocab[prev_char[1]] .. ' ')
end
io.write('\n') io.flush()