
Commit

added MT translation script for inference
Guillaume Lample committed Mar 26, 2019
1 parent 14fe2d4 commit 97aeb8d
Showing 1 changed file with 144 additions and 0 deletions.
144 changes: 144 additions & 0 deletions translate.py
@@ -0,0 +1,144 @@
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Translate sentences from the input stream.
# The model will be faster if sentences are sorted by length.
# Input sentences must have the same tokenization and BPE codes as the ones used in the model.
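#
# Example usage (illustrative; the file names below are placeholders, not part of this commit):
#   cat sentences.bpe.en | python translate.py --exp_name translate \
#     --src_lang en --tgt_lang fr --model_path model.pth --output_path output.fr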
#

import os
import io
import sys
import argparse
import torch

from src.utils import AttrDict
from src.utils import bool_flag, initialize_exp
from src.data.dictionary import Dictionary
from src.model.transformer import TransformerModel

from src.fp16 import network_to_half


def get_parser():
"""
Generate a parameters parser.
"""
# parse parameters
parser = argparse.ArgumentParser(description="Translate sentences")

# main parameters
parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path")
parser.add_argument("--exp_name", type=str, default="", help="Experiment name")
parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
parser.add_argument("--fp16", type=bool_flag, default=False, help="Run model with float16")
parser.add_argument("--batch_size", type=int, default=8, help="Number of sentences per batch")

# model / output paths
parser.add_argument("--model_path", type=str, default="", help="Model path")
parser.add_argument("--output_path", type=str, default="", help="Output path")

# parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)")
# parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count")

# source language / target language
parser.add_argument("--src_lang", type=str, default="", help="Source language")
parser.add_argument("--tgt_lang", type=str, default="", help="Target language")

return parser


def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # reload the pretrained model and its parameters
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))
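    # (these special token indices must match the ones the model was trained with)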

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
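        # cast encoder / decoder weights to half precision for faster GPU inference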
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
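        # empty lines are not allowed: every input line must contain at least one token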
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))

    f = io.open(params.output_path, 'w', encoding='utf-8')

    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
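        # batch has shape (max_len, batch_size): each column is one sentence,
        # wrapped in eos tokens on both sides and right-padded with pad_index,
        # so lengths counts the two eos tokens (hence len(s) + 2)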
        word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
                    for s in src_sent[i:i + params.batch_size]]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
        encoded = encoded.transpose(0, 1)
        decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))
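        # decoded has shape (out_len, batch_size); generation is capped at
        # roughly 1.5x the longest source length plus 10 tokens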

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
            sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
            f.write(target + "\n")

    f.close()


if __name__ == '__main__':

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()

    # check parameters
    assert os.path.isfile(params.model_path)
    assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang
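    # refuse to overwrite an existing output file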
    assert params.output_path and not os.path.isfile(params.output_path)

    # translate
    with torch.no_grad():
        main(params)
