In [1]:
require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')

require 'nn'

require 'LanguageModel'
require 'util.DataLoader'

local utils = require 'util.utils'
local cmd = torch.CmdLine()

cmd:option('-checkpoint', '')
cmd:option('-split', 'val')
cmd:option('-gpu', -1)
cmd:option('-gpu_backend', 'cuda')

local myarg = {}
myarg[1]='-checkpoint'
myarg[2]='cv/checkpoint_16000.t7'
myarg[5]='-gpu'
myarg[6]=-1

opt = cmd:parse(myarg)

In [2]:
-- Set up GPU stuff
dtype = 'torch.FloatTensor'
if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu + 1)
  dtype = 'torch.CudaTensor'
  print(string.format('Running with CUDA on GPU %d', opt.gpu))
elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
  require 'cltorch'
  require 'clnn'
  cltorch.setDevice(opt.gpu + 1)
  dtype = torch.Tensor():cl():type()
  print(string.format('Running with OpenCL on GPU %d', opt.gpu))
else
  -- Memory benchmarking is only supported in CUDA mode
  print 'Running in CPU mode'
end

Running in CPU mode	


In [14]:
-- Load the checkpoint and model
checkpoint = torch.load(opt.checkpoint)
model = checkpoint.model
model:type(dtype)
model.dropout = 0.0 -- this actually does NOT disable dropout in the model; as the dropout layer is already added

In [19]:
--local utf8 = require 'lua-utf8'

local sen = '__ent_numeric 种 设备 , 其 包括 至少 __ent_numeric 个 片状 部件 ( __ent_numeric ) , 在 运动 方向 ( F ) 上 看 时 , 该 片状 部件 具有 在 板面 内 延伸 的 后端 片状 部分 和 前部 部分 。'

local x = model:encode_string(sen):type(dtype)
local senLen = x:size(1)
x = x:reshape(1, senLen, -1)

model:resetStates()
local scores = model:forward(x):view(senLen, -1):type(dtype)
scores = nn.LogSoftMax():forward(scores):type(dtype)
local sum = 0.0
for i=1,senLen-1 do
  print(i, x[1][i], sum)
  sum = sum + scores[i][x[1][i+1]]
end
--sum = sum + scores[senLen][x1][]]
print(sum)
local avgNLL = -sum / senLen
print(avgNLL)
local perp = torch.exp(avgNLL)
print(string.format('Perplexity: %f', perp))

1	1	0	
2	1	-1.7792618274689	
3	2	-1.7797123984492	
4	3	-1.7797226791754	
5	4	-1.7797402884353	
6	1	-1.7797629063643	
7	3	-1.7798042782342	
8	5	-1.7799482720548	
9	6	-1.7799482720548	
10	2	-1.7891230697387	
11	7	-1.7891230697387	
12	8	-1.7891230697387	
13	9	-1.7891230697387	
14	

10	-1.7927537697942	
15	11	-6.8372588891179	
16	

10	-6.8916490401716	
17	12	-13.066644462535	
18	13	-14.02299860377	
19	10	-14.04921529513	
20	14	-16.741634347254	
21	10	-16.748998678533	
22	15	-20.915171182958	
23	10	-21.958461082784	
24	16	-28.801147259084	
25	17	-29.585054732172	
26	10	-29.592092870513	
27	18	-36.467625497618	
28	19	-39.082425473967	
29	10	-39.120321596751	
30	1	-39.324067721853	
31	1	-39.326044105703	
32	2	-39.326044105703	
33	3	-39.326044105703	
34	4	-39.326044105703	
35	1	-39.326085028191	
36	3	-39.326970173057	


37	5	-39.326970173057	
38	6	-39.32927769866	
39	2	-39.332306900305	
40	7	-39.332306900305	
41	8	-39.332315062224	
42	9	-39.332335155294	
43	10	-39.333493513008	
44	20	-40.735694092651	
45	10	-40.741002364502	
46	21	-47.083348079072	
47	22	-56.724081797944	
48	10	-56.792514948593	
49	23	-63.215706972824	
50	24	-67.310357718216	
51	10	-67.317708744155	
52	25	-72.145963443862	
53	10	-72.148052075598	
54	1	-73.557580569479	
55	1	-73.561130330152	
56	2	-73.561135253034	
57	3	-73.561135253034	
58	4	-73.561135253034	
59	1	-73.56113909822	
60	3	-73.561791244222	
61	5	-73.561794467405	


62	6	-73.561794467405	
63	2	-73.59222963652	
64	7	-73.59222963652	
65	8	-73.59222963652	
66	9	-73.59222963652	
67	10	-73.595828637848	
68	26	-75.032799110184	
69	10	-75.292957857618	
70	14	-77.612684324751	
71	10	-77.623481334554	
72	27	-81.810157359945	
73	10	-81.829287213641	
74	28	-87.826969785052	
75	29	-89.138147754031	
76	10	-89.349090931373	
77	30	-94.465973255592	
78	31	-97.062296268897	
79	10	-97.12317874785	
80	25	-102.62969377394	
81	10	-102.63022228722	
82	32	-108.84081217293	
83	10	-110.52736433987	
84	26	-114.00110205655	
85	10	

-114.09496999924	
86	33	-119.36266351883	
87	10	-119.97852177565	
88	34	-127.76025338118	
89	10	-128.19866723959	
90	35	-135.19497318213	
91	10	-135.93309832756	
92	14	-137.22005380337	
93	10	-137.23178883847	
94	36	-142.17613812742	
95	10	-142.17843084895	
96	21	-148.18224888407	
97	22	-156.87557964884	
98	10	-157.01251549386	
99	23	-163.54126162194	
100	24	-167.90416998529	
101	10	-167.91866488681	
102	37	-173.53395505176	
103	38	-173.57477183194	
104	10	-173.58339260165	
105	27	-180.42788217609	
106	10	-180.58968458716	
107	39	-187.92866907661	


108	40	-192.54480276649	
109	10	-192.54821866002	
110	41	-196.26929061856	
111	10	-196.82015924897	
112	42	-204.35829096283	
113	43	-206.65846591439	
114	10	-206.66629121449	
115	44	-209.41805122044	
116	10	-209.42835013844	
117	45	-215.9638344524	
118	46	-221.48058431126	
119	10	-221.49459704287	
120	21	-228.37754162676	
121	22	-235.20477923281	
122	10	-235.35129124947	
123	23	-242.7493801815	
124	47	-243.07185150213	
125	10	-243.08095600079	
126	48	-247.60324813794	
127	10	-247.60797535019	
128	49	-254.26450620727	
129	23	-260.11169897156	
130	10	-260.13888148913	
131	23	-266.72852973589	
132	47	-268.05826799521	
133	10	-268.07091259533	


-271.61745500141	
2.0269959328464	
Perplexity: 7.591247	
