-
Notifications
You must be signed in to change notification settings - Fork 2
/
mg2p.py
136 lines (129 loc) · 7.12 KB
/
mg2p.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
from tools.model import G2PModel
import argparse
import sys
parser = argparse.ArgumentParser()
parser.add_argument('name', help="Path to model")
'''
parser.add_argument('-t', '--tokens',
nargs='*',
default=[],
help='Artificial tokens to add to the beginning of each source-side line, in practice always the langid feature (default: none of them)')
'''
parser.add_argument('-src_features',
nargs='*',
default=[],
help='Character-level features to concatenate to the source input at each time step (default: none of them)')
parser.add_argument('-tgt_features',
nargs='*',
default=[],
help='Character-level features to concatenate to the target input at each time step (default: none of them)')
parser.add_argument('-preprocess', action='store_true',
help='Apply torch preprocessing to the training and validation data')
parser.add_argument('-train', action='store_true',
help='Train the model')
parser.add_argument('-translate', action='store_true',
help='Translate the model')
parser.add_argument('-train_config', default=None,
help='OpenNMT parameters for training')
parser.add_argument('-l', '--lang',
nargs='*',
default=None,
help='If preprocessing, languages for which to select data (default: all)')
parser.add_argument('-s', '--script',
nargs='*',
default=None,
help='If preprocessing, scripts for which to select data (default: all)')
opt = parser.parse_args()
HIGH_RESOURCE = ['ady', 'afr', 'ain', 'amh', 'ang', 'ara', 'arc', 'ast',
'aze', 'bak', 'ben', 'bre', 'bul', 'cat', 'ces', 'cym',
'dan', 'deu', 'dsb', 'ell', 'eng', 'epo', 'eus', 'fao',
'fas', 'fin', 'fra', 'gla', 'gle', 'hbs', 'heb', 'hin',
'hun', 'hye', 'ido', 'isl', 'ita', 'jbo', 'jpn', 'kat',
'kbd', 'kor', 'kur', 'lao', 'lat', 'lav', 'lit', 'ltz',
'mkd', 'mlt', 'msa', 'mya', 'nan', 'nci', 'nld', 'nno',
'nob', 'oci', 'pol', 'por', 'pus', 'ron', 'rus', 'san',
'scn', 'sco', 'sga', 'slk', 'slv', 'spa', 'sqi', 'swe',
'syc', 'tel', 'tgk', 'tgl', 'tha', 'tur', 'ukr', 'urd',
'vie', 'vol', 'yid', 'yue', 'zho']
ADAPTED = ['aar', 'abk', 'abq', 'ace', 'ach', 'ady', 'afr', 'agr',
'aka', 'akl', 'akz', 'ale', 'alt', 'ami', 'aqc', 'ara',
'arg', 'arw', 'arz', 'asm', 'ava', 'aym', 'aze', 'bak',
'bal', 'bam', 'bcl', 'bel', 'ben', 'bis', 'bod', 'bos',
'bre', 'bua', 'bug', 'bul', 'cat', 'ceb', 'ces', 'cha',
'che', 'chk', 'chm', 'cho', 'chv', 'cic', 'cjs', 'cor',
'crh', 'cym', 'dan', 'dar', 'deu', 'dsb', 'eng', 'est',
'eus', 'ewe', 'fao', 'fas', 'fij', 'fil', 'fin', 'fra',
'frr', 'fry', 'fur', 'gaa', 'gag', 'gla', 'gle', 'glg',
'grc', 'grn', 'gsw', 'guj', 'hak', 'hat', 'hau', 'haw',
'hbs', 'heb', 'hil', 'hin', 'hit', 'hrv', 'hun', 'iba',
'ilo', 'ind', 'inh', 'isl', 'ita', 'jam', 'jav', 'kaa',
'kab', 'kal', 'kan', 'kaz', 'kbd', 'kea', 'ket', 'khb',
'kin', 'kir', 'kjh', 'kom', 'kum', 'kur', 'lat', 'lav',
'lin', 'lit', 'lld', 'lug', 'luo', 'lus', 'lzz', 'mah',
'mal', 'mar', 'mkd', 'mlg', 'mlt', 'mnk', 'mns', 'moh',
'mon', 'mri', 'msa', 'mus', 'mww', 'mya', 'myv', 'mzn',
'nah', 'nap', 'nau', 'nci', 'nds', 'nep', 'new', 'nia',
'niu', 'nld', 'nob', 'non', 'nor', 'nso', 'oci', 'oss',
'osx', 'pag', 'pam', 'pan', 'pau', 'pol', 'pon', 'por',
'ppl', 'prs', 'pus', 'que', 'roh', 'rom', 'ron', 'rtm',
'rus', 'ryu', 'sac', 'sah', 'san', 'sat', 'scn', 'sei',
'slv', 'sme', 'sna', 'snd', 'som', 'sot', 'spa', 'sqi',
'srd', 'srp', 'sun', 'swa', 'swe', 'tam', 'tat', 'tay',
'tel', 'tgk', 'tgl', 'tir', 'tkl', 'tly', 'tpi', 'tsn',
'tuk', 'tur', 'tvl', 'twi', 'tyv', 'udm', 'uig', 'ukr',
'umb', 'unk', 'urd', 'uzb', 'vie', 'wbp', 'wol', 'wuu',
'xal', 'xho', 'xmf', 'yap', 'yid', 'yij', 'yor', 'yua',
'yue', 'zha', 'zho', 'zul', 'zza']
TYPO = ['aar', 'aau', 'abk', 'abq', 'abt', 'ace', 'ach', 'adj', 'ady',
'afr', 'agm', 'agr', 'aht', 'ain', 'aka', 'akl', 'akz', 'ale',
'alp', 'alr', 'alt', 'amh', 'amn', 'amp', 'apn', 'apy', 'aqc',
'are', 'arg', 'arn', 'arw', 'arz', 'asm', 'ava', 'ayl', 'bak',
'bam', 'ban', 'bar', 'bbc', 'bcl', 'bdr', 'ben', 'blc', 'bod',
'bor', 'bre', 'brg', 'bug', 'bul', 'cat', 'ccc', 'ceb', 'ces',
'cha', 'che', 'chl', 'cho', 'chr', 'chv', 'cic', 'cjs', 'ckt',
'cmn', 'com', 'cqd', 'crg', 'crh', 'cri', 'dbl', 'deu', 'dob',
'dru', 'duj', 'ell', 'eng', 'epo', 'eto', 'eus', 'evn', 'ewe',
'fij', 'fil', 'fin', 'fra', 'frr', 'fur', 'gaa', 'gag', 'gle',
'glg', 'gqn', 'gub', 'guj', 'gvf', 'hak', 'hau', 'haw', 'hdn',
'heb', 'hil', 'hin', 'hrv', 'hts', 'hun', 'hye', 'iba', 'ilo',
'ind', 'inh', 'isl', 'ita', 'itl', 'jam', 'jav', 'jpn', 'kaa',
'kab', 'kac', 'kal', 'kan', 'kat', 'kay', 'kaz', 'kbd', 'kca',
'kea', 'ket', 'kgp', 'khb', 'khm', 'kij', 'kin', 'kir', 'kjh',
'kmg', 'kmv', 'kor', 'kpv', 'krc', 'kum', 'kut', 'kxo', 'lac',
'lao', 'lav', 'led', 'lez', 'lin', 'lit', 'lkt', 'ltz', 'lug',
'luo', 'lus', 'lzz', 'mal', 'mar', 'mco', 'mkd', 'mlt', 'mlv',
'mnk', 'mns', 'mri', 'mtq', 'mww', 'mya', 'myp', 'mzn', 'nab',
'naq', 'nav', 'nch', 'nds', 'nep', 'new', 'nhg', 'nia', 'nio',
'niv', 'nld', 'nob', 'ntj', 'oss', 'pac', 'pag', 'pam', 'pan',
'pap', 'pau', 'pdt', 'pjt', 'pny', 'pol', 'pon', 'por', 'ppl',
'pre', 'rap', 'rif', 'ron', 'run', 'rus', 'sah', 'san', 'sat',
'sco', 'sei', 'shn', 'sin', 'slv', 'sna', 'snd', 'som', 'spa',
'squ', 'srp', 'stp', 'str', 'sun', 'swe', 'swh', 'tam', 'tat',
'tay', 'tel', 'tgl', 'tha', 'tir', 'tli', 'tpi', 'tqw', 'tsi',
'tuk', 'tur', 'twf', 'tyv', 'tzm', 'ude', 'uig', 'ukr', 'umb',
'unk', 'unm', 'urb', 'urd', 'vie', 'vma', 'wiy', 'wol', 'wuu',
'xal', 'xho', 'xmf', 'yap', 'yii', 'yor', 'yrk', 'yua', 'yue',
'zai', 'zpq', 'zul']
def main():
if not any([opt.preprocess, opt.train, opt.translate]):
print('Specify at least one action (preprocess, train, test)')
sys.exit()
if opt.lang == ['high']:
lang = HIGH_RESOURCE
elif opt.lang == ['adapted']:
lang = ADAPTED
elif opt.lang == ['typo']:
lang = TYPO
else:
lang = opt.lang
model = G2PModel(opt.name, train_langs=lang, train_scripts=opt.script,
src_features=opt.src_features, tgt_features=opt.tgt_features)
if opt.preprocess:
model.preprocess()
if opt.train:
model.train(opt.train_config)
if opt.translate:
model.translate()
if __name__ == '__main__':
main()