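# 0-data.jl
# Task selection, hyper-parameters, and data preprocessing for the transformer example.
# The model pieces referenced below (`embed`, `encoder`, `decoder`, `loss`, `ps`, `opt`,
# and, for the copy task, `vocab`) are expected to be defined by the companion files of
# this example.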

using ArgParse
using Random
Random.seed!(0)

using Transformers             # for `enable_gpu` and `todevice`
using Transformers.Datasets
using Transformers.Datasets: WMT, IWSLT
function parse_commandline()
    s = ArgParseSettings()

    @add_arg_table s begin
        "--gpu", "-g"
            help = "use gpu"
            action = :store_true
        "task"
            help = "task name"
            required = true
            range_tester = x -> x ∈ ["wmt14", "iwslt2016", "copy"]
    end

    return parse_args(ARGS, s)
end

const args = parse_commandline()

enable_gpu(args["gpu"])
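
# Per-task hyper-parameters, data pipeline, and training loop.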
if args["task"] == "copy"
    const N = 2
    const V = 10
    const Smooth = 1e-6
    const Batch = 32
    const lr = 1e-4
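
    # Special symbols are chosen outside the copy vocabulary 1:V.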
    startsym = 11
    endsym = 12
    unksym = 0
    labels = [unksym, startsym, endsym, collect(1:V)...]
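
    # One sample: a random length-10 sequence; the target is the input itself.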
    function gen_data()
        global V
        d = rand(1:V, 10)
        (d, d)
    end
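
    # Add start/end symbols, build the masks, look the tokens up in `vocab`,
    # and move everything to the device.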
    function preprocess(data)
        x, t = data
        x = mkline.(x)
        t = mkline.(t)
        x_mask = getmask(x)
        t_mask = getmask(t)
        x, t = vocab(x, t)
        todevice(x, t, x_mask, t_mask)
    end
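
    # Train on randomly generated copy batches.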
    function train!()
        global Batch
        println("start training")
        model = (embed=embed, encoder=encoder, decoder=decoder)
        for i = 1:320*7
            data = batched([gen_data() for i = 1:Batch])
            x, t, x_mask, t_mask = preprocess(data)
            grad = gradient(ps) do
                loss(model, x, t, x_mask, t_mask)
            end
            i % 8 == 0 && @show loss(model, x, t, x_mask, t_mask)
            update!(opt, ps, grad)
        end
    end
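
    # Wrap a raw integer sequence with the start/end symbols.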
    mkline(x) = [startsym, x..., endsym]
elseif args["task"] == "wmt14" || args["task"] == "iwslt2016"
    const N = 6
    const Smooth = 0.4
    const Batch = 8
    const lr = 1e-6
    const MaxLen = 100
    const task = args["task"]
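
    # Load the training split and vocabulary of the selected dataset.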
    if task == "wmt14"
        wmt14 = WMT.GoogleWMT()
        datas = dataset(Train, wmt14)
        vocab = get_vocab(wmt14)
    else
        iwslt2016 = IWSLT.IWSLT2016(:en, :de)
        datas = dataset(Train, iwslt2016)
        vocab = get_vocab(iwslt2016)
    end
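
    # Special tokens plus every word in the dataset vocabulary.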
    startsym = "<s>"
    endsym = "</s>"
    unksym = "</unk>"
    labels = [unksym, startsym, endsym, collect(keys(vocab))...]
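
    # Wrap each sentence with start/end symbols, build masks with `getmask`,
    # map words to indices through `vocab`, and move everything to the device.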
    function preprocess(batch)
        x = mkline.(batch[1])
        t = mkline.(batch[2])
        x_mask = getmask(x)
        t_mask = getmask(t)
        x, t = vocab(x, t)
        todevice(x, t, x_mask, t_mask)
    end
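
    # Train until `get_batch` returns no more batches.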
    function train!()
        global Batch
        println("start training")
        i = 1
        model = (embed=embed, encoder=encoder, decoder=decoder)
        while (batch = get_batch(datas, Batch)) != []
            x, t, x_mask, t_mask = preprocess(batch)
            grad = gradient(ps) do
                loss(model, x, t, x_mask, t_mask)
            end
            i += 1
            i % 8 == 0 && @show loss(model, x, t, x_mask, t_mask)
            @time update!(opt, ps, grad)
        end
    end
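
    # Turn a sentence into a token sequence capped at `MaxLen`, wrapped with start/end
    # symbols. WMT14 text is split on whitespace; IWSLT text goes through `tokenize`,
    # which is not imported here (presumably a tokenizer such as WordTokenizers.jl).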
    if task == "wmt14"
        function mkline(x)
            global MaxLen
            xi = split(x)
            if length(xi) > MaxLen
                xi = xi[1:MaxLen]
            end
            [startsym, xi..., endsym]
        end
    else
        function mkline(x)
            global MaxLen
            xi = tokenize(x)
            if length(xi) > MaxLen
                xi = xi[1:MaxLen]
            end
            [startsym, xi..., endsym]
        end
    end
else
    error("task not defined")
end