In [1]:
import glob
import os
import fasttext
from itertools import combinations, permutations
from copy import copy
import sentencepiece as spm
from copy import copy
from itertools import islice

In [2]:
def test_model(test_name:str,
               lang_pair:tuple=None,
               test_data:str=None,
               tc:str=None,
               tc_m:str=None,
               sp:str=None,
               sp_m:str=None,
               tr_m:str=None,
               out:str=None,
               checkpoint:int=None,
               include_tags=False):
    """
    Args:
        test_data: Directory contains
        
    """
    dir_name = os.path.basename(os.path.dirname(tr_m))
    model = os.path.basename(tr_m)
    model_name = f'{dir_name}-{model}'
    cd = os.getcwd()
    data_path = os.path.abspath(test_data)
    if not os.path.exists(out):
        os.makedirs(out)
    (src_lang, tgt_lang) = lang_pair
    pair = r"{}-{}".format(src_lang, tgt_lang) 
    tmp = os.path.join(data_path, f"tmp-{test_name}-{pair}-{model}")
    output = os.path.join(cd, out, f'{test_name}-{pair}-{model_name}-{str(checkpoint)}.res')
    if not os.path.exists(tmp):
        os.makedirs(tmp)
    else:
        files = glob.glob(f'{tmp}/*')
        for f in files:
            try:
                os.unlink(f)
            except OSError as e:
                print("Error: %s : %s" % (f, e.strerror))
    if test_name == 'WMT18':
        inp_path = os.path.join(tmp,  f'{pair}.src')
        r = !sacrebleu -t wmt18 --language {pair} --echo src > {inp_path}
    elif test_name == 'WMT20':
        inp_path = f'data/dev/wmt20/{src_lang}-{tgt_lang}.{src_lang}'
        tgt_path = f'data/dev/wmt20/{src_lang}-{tgt_lang}.{tgt_lang}'     
    elif test_name == 'ACCURAT':
        inp_path = f'data/test/test.{src_lang}'
        tgt_path = f'data/test/test.{tgt_lang}'
    else:
        print('Unknown test set')
    tc_path = os.path.join(cd, tc)
    tcm_path = os.path.join(cd, tc_m)
    tc_out = os.path.join(tmp, "tc.out")
    !python {tc_path} {tcm_path} {inp_path} > {tc_out}
    !python {sp} --action split  --model {sp_m} --corpora {tc_out}
    tags = os.path.join(tmp, pair + '.tag')
    sp_name = os.path.basename(sp_m)
    sp_out = os.path.join(tmp, sp_name + '-tc.out')
    with open(sp_out, 'r') as _input, open(tags, 'w') as tag_output:
        for line in _input:
            l = len(line.split())
            tag_output.write(f"{' '.join([tgt_lang] * l)}" + '\n')
    tr_out = os.path.join(tmp, f"{pair}.out")
    if include_tags:
        if checkpoint == 'last':
            last_checkpoint = max([int(param.split('.')[-1]) for param in glob.glob(os.path.join(tr_m, 'params.[0-9][0-9][0-9][0-9][0-9]'))])
            !python -m sockeye.translate --input-factors {tags} --checkpoints {last_checkpoint} --beam-size 6 --batch-size 100 --disable-device-locking --models {tr_m} --input {sp_out} --output {tr_out}
        elif type(checkpoint) == int:
            !python -m sockeye.translate --input-factors {tags} --checkpoints {checkpoint} --beam-size 6 --batch-size 100 --disable-device-locking --models {tr_m} --input {sp_out} --output {tr_out}
        elif checkpoint == None:
            !python -m sockeye.translate --input-factors {tags} --beam-size 6 --batch-size 100 --disable-device-locking --models {tr_m} --input {sp_out} --output {tr_out}
    else:
        if checkpoint == 'last':
            last_checkpoint = max([int(param.split('.')[-1]) for param in glob.glob(os.path.join(tr_m, 'params.[0-9][0-9][0-9][0-9][0-9]'))])
            !python -m sockeye.translate --checkpoints {last_checkpoint} --beam-size 6 --batch-size 100 --disable-device-locking --models {tr_m} --input {sp_out} --output {tr_out}
        elif type(checkpoint) == int:
            !python -m sockeye.translate --checkpoints {checkpoint} --beam-size 6 --batch-size 100 --disable-device-locking --models {tr_m} --input {sp_out} --output {tr_out}
        elif checkpoint == None:
            !python -m sockeye.translate --beam-size 6 --batch-size 100 --disable-device-locking --models {tr_m} --input {sp_out} --output {tr_out}                                                  
    !python {sp} --action restore --corpora {tr_out} --model {sp_m} 
    res_path = os.path.join(tmp, f'de-{sp_name}-{os.path.basename(tr_out)}')
    dir_name = os.path.basename(os.path.dirname(tr_m))
    if test_name == 'WMT18':
        !cat {res_path} | sacrebleu -t wmt18 -l {pair} > {output}
    elif test_name == 'WMT20':
        !cat {res_path} | sacrebleu {tgt_path} > {output}                               
    elif test_name == 'ACCURAT':
        !cat {res_path} | sacrebleu {tgt_path} > {output}
    else:
        print('Unknown test set')
    with open(output) as o:
        print(o.read())

In [10]:
%%time
test_model('WMT18',
           ('en','et'),
           'data/test/baseline-3/',
           'scripts/truecaser/applytc.py',
           'models/preproc-models/tc-en',
           'scripts/word-pieces.py',
           'models/preproc-models/sp',
           'models/baselines/enet_saveoften/',
           'data/test/baseline-3/out',
           182,
          False)

2021-02-01 17:19:39.586821: processed 2000 lines
02/01/2021 05:19:40 PM INFO: Loading model
02/01/2021 05:19:40 PM INFO: Splitting file /gpfs/space/home/kolesnyk/nmt/data/test/baseline-3/tmp-WMT18-en-et-/tc.out
[INFO:sockeye.utils] Sockeye version 2.3.2, commit 26c02b1016b0937714ecd4ab367a6a67761ef2df, path /gpfs/space/home/kolesnyk/.conda/envs/gpu_sockeye/lib/python3.7/site-packages/sockeye/__init__.py
[INFO:sockeye.utils] MXNet version 1.7.0, path /gpfs/space/home/kolesnyk/.conda/envs/gpu_sockeye/lib/python3.7/site-packages/mxnet/__init__.py
[INFO:sockeye.utils] Command: /gpfs/space/home/kolesnyk/.conda/envs/gpu_sockeye/lib/python3.7/site-packages/sockeye/translate.py --checkpoints 182 --beam-size 6 --batch-size 100 --disable-device-locking --models models/baselines/enet_saveoften/ --input /gpfs/space/home/kolesnyk/nmt/data/test/baseline-3/tmp-WMT18-en-et-/sp-tc.out --output /gpfs/space/home/kolesnyk/nmt/data/test/baseline-3/tmp-WMT18-en-et-/en-et.out
[INFO:sockeye.utils] Arguments: 

#### for test in ['WMT20', 'WMT18', 'ACCURAT']:
    if test == 'WMT18':
        pairs = [ ('en', 'et'), ('et', 'en')]
    elif test == 'WMT20':
        pairs = [('en', 'ru'), ('ru', 'en')]
    elif test == 'ACCURAT':
        pairs = [('ru', 'et'), ('et', 'ru')]
    print(pairs)
    for (src, tgt) in pairs:
        for model in glob.glob('models/backtranslate/*[2-8]m_[0-1]'):
            name = os.path.basename(model)
            temp_dir = os.path.join('data/test/bt-1', name)
            out_dir = os.path.join('data/test/bt-1/out', name)
            tc_model = "models/preproc-models/tc" f'-{src}'
            test_model(test,
                       (src, tgt),
                       temp_dir,
                       'scripts/truecaser/applytc.py',
                       tc_model,
                       'scripts/word-pieces.py',
                       'models/preproc-models/sp',
                       model,
                       out_dir,
                      None)


In [None]:
# Evaluate every baseline model on each selected test set, in both translation
# directions. Language pairs are looked up in a table so that an unlisted test
# name raises KeyError instead of reusing the previous iteration's pairs.
pairs_by_test = {
    'WMT18': [('en', 'et'), ('et', 'en')],
    'WMT20': [('en', 'ru'), ('ru', 'en')],
    'ACCURAT': [('ru', 'et'), ('et', 'ru')],
}

for test in ['WMT20', 'WMT18']:
    pairs = pairs_by_test[test]
    print(pairs)
    for (src, tgt) in pairs:
        # sorted(): glob returns entries in arbitrary order; sorting keeps the
        # run order stable between executions.
        for model in sorted(glob.glob('models/baselines/*')):
            if not os.path.isdir(model):
                # Only directories are passed on as models; skip stray files.
                continue
            name = os.path.basename(model)
            temp_dir = os.path.join('data/test/baseline-3', name)
            out_dir = os.path.join('data/test/baseline-3/out', name)
            tc_model = f'models/preproc-models/tc-{src}'
            test_model(test,
                       (src, tgt),
                       temp_dir,
                       'scripts/truecaser/applytc.py',
                       tc_model,
                       'scripts/word-pieces.py',
                       'models/preproc-models/sp',
                       model,
                       out_dir,
                       checkpoint=None)
