In [7]:
# Run test_speed_cpu.sh through IPython's shell escape; the script's log output is shown below.
!source test_speed_cpu.sh

2021-02-04 09:56:18 | INFO | root | [1/2] Loaded 1 checkpoints in 1.4s
2021-02-04 09:56:20 | INFO | root | [2/2] Decoder created in 2.3s
== Running the same translation 5 times (after warm up) ==
-- Warm Up --
2021-02-04 09:56:22 | INFO | Transformer | reset_time = 0.062, tune_time = 0.000, decode_time = 1.136
------
- Original text: Companies and LSPs can translate their content with the ModernMT service in many languages directly on their production environment thanks to our simple RESTful API .
------
- Translated text: Le aziende e i LSP possono tradurre il loro contenuto con il servizio di modernità in molte lingue direttamente nel loro ambiente di produzione grazie alla nostra semplice API .
------
- Alignment: [(0, 0), (0, 1), (1, 2), (2, 4), (3, 5), (4, 6), (5, 7), (5, 8), (6, 9), (7, 10), (8, 11), (9, 13), (9, 14), (10, 12), (11, 15), (12, 16), (13, 17), (14, 18), (15, 19), (16, 20), (17, 22), (17, 23), (18, 21), (19, 24), (20, 24), (21, 25), (22, 26), (23, 27), (24, 28)]
-- T

In [3]:
import importlib
import sys
# Make the relative `src` directory importable.
# NOTE(review): presumably this is where the `mmt` package imported below lives — confirm.
sys.path.append('src')

In [4]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each execution,
# so edits to the imported project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [19]:
import torch
import torch.neuron

from mmt import utils
from mmt.checkpoint import CheckpointRegistry
from mmt.decoder import Suggestion, ModelConfig, MMTDecoder
from fairseq.sequence_generator import EnsembleModel

In [28]:
# Notebook-wide configuration: the sentence to translate, the model directory
# and the device passed to the checkpoint registry / decoder.
TEST_TEXT = (
    'Companies and LSPs can translate their content with the ModernMT service in many languages '
    'directly on their production environment thanks to our simple RESTful API .'
)
MODEL_DIR = 'model'
device = None
# Lower-case alias used by the translation cells.
test_text = TEST_TEXT

In [7]:
# Load the model configuration, register every checkpoint it lists and build the decoder.
# MODEL_DIR (from the configuration cell) is used instead of repeating the 'model' literal,
# so the model location is defined in exactly one place.
config = ModelConfig.load(MODEL_DIR)
builder = CheckpointRegistry.Builder()
for name, checkpoint_path in config.checkpoints:
    builder.register(name, checkpoint_path)
# `device` comes from the configuration cell and is passed to both the registry and the decoder.
checkpoints = builder.build(device)
decoder = MMTDecoder(checkpoints, device=device)

In [29]:
# Translate the test sentence without a tuner in two ways — `decoder.translate` and the
# underscore-prefixed `decoder._decode` — and report whether the resulting texts match.
trans_1 = decoder.translate('en', 'it', [test_text])[0]
print(f'- Using [decoder.translate]: {trans_1.text}')
trans_2 = decoder._decode('en', 'it', [test_text])[0]
print(f'- Using [decoder._decode]: {trans_2.text}')
print('------')
if trans_1.text == trans_2.text:
    translate_vs_decode = '=='
else:
    translate_vs_decode = '!='
print(f'Output of [decoder.translate] {translate_vs_decode} [decoder._decode]')

- Using [decoder.translate]: Le aziende e i LSP possono tradurre il loro contenuto con il servizio di modernità in molte lingue direttamente nel loro ambiente di produzione grazie alla nostra semplice API .
- Using [decoder._decode]: Le aziende e i LSP possono tradurre il loro contenuto con il servizio di modernità in molte lingue direttamente nel loro ambiente di produzione grazie alla nostra semplice API .
------
Output of [decoder.translate] == [decoder._decode]


## Attempt to trace the `decoder` object

#### Observations

In [9]:
# Report the concrete classes of the decoder and of the two components it holds.
for label, inspected in (
    ('decoder', decoder),
    ('decoder._translator', decoder._translator),
    ('decoder._model', decoder._model),
):
    print(f'Type of [{label}] is {type(inspected)}')

# Tuner function is not used in these experiments
# print(f'Type of `decoder._tuner`: {type(decoder._tuner)}')

Type of [decoder] is <class 'mmt.decoder.MMTDecoder'>
Type of [decoder._translator] is <class 'fairseq.sequence_generator.SequenceGenerator'>
Type of [decoder._model] is <class 'fairseq.models.transformer.TransformerModel'>


- The type of `decoder` is `MMTDecoder`; internally it uses `fairseq` classes.
- Internally, `decoder.translate('en', 'it', [TEST_TEXT])` ultimately calls `decoder._translator.generate([decoder._model], sample)`, i.e. `fairseq.sequence_generator.SequenceGenerator.generate(models, sample)`, where `sample` is the tokenised text. There is also `SequenceGenerator._generate(sample)`, which uses the `models` stored by the `SequenceGenerator` constructor (the `MMTDecoder` constructor sets these correctly as well).

#### Try to use `decoder._translator._generate(sample)` directly

In [30]:
# Build the decode batch for the test sentence, then unpack the three values it returns.
decode_batch = decoder._make_decode_batch([TEST_TEXT])
test_text_encode, input_indexes, sentence_len = decode_batch

In [31]:
# Show the raw sentence followed by the encoded batch, each under its own heading line.
print('Test text:', test_text, sep='\n')
print('Encoded:', test_text_encode, sep='\n')
# print(input_indexes)
# print(sentence_len)

Test text:
Companies and LSPs can translate their content with the ModernMT service in many languages directly on their production environment thanks to our simple RESTful API .
Encoded:
{'net_input': {'src_tokens': tensor([[ 9055,  9632,   518,    22,  4764, 17506,   126,   127, 15470,   144,
          2242,    54,    16, 21163,  6719, 29625,  1519,    18,   278,  3580,
          2352,    34,   144,  1027,   887,  1933,    20,    99,  3127,  9896,
         24193,  6082, 13779,    33,    15,     2]]), 'src_lengths': tensor([36])}}


In [12]:
# Translate once more through `decoder._decode_without_explicit_model` and compare the
# text with the `decoder.translate` result obtained earlier (`trans_1`).
# NOTE(review): per the printed labels this path presumably goes through
# `decoder._translator._generate` — confirm against the decoder source.
trans_3 = decoder._decode_without_explicit_model('en', 'it', [TEST_TEXT])[0]
if trans_1.text == trans_3.text:
    is_equal = "=="
else:
    is_equal = "!="
print(f'- Using [decoder._translator._generate]: {trans_3.text}')
print('------')
print(f'Output of [decoder.translate] {is_equal} [decoder._translator._generate]')

- Using [decoder._translator._generate]: Le aziende e i LSP possono tradurre il loro contenuto con il servizio di modernità in molte lingue direttamente nel loro ambiente di produzione grazie alla nostra semplice API .
------
Output of [decoder.translate] == [decoder._translator._generate]


In [13]:
class GeneratorWrapper(torch.nn.Module):
    """Expose an object's `_decode_from_sample` method as a `torch.nn.Module` forward pass.

    The wrapped object only needs to provide `_decode_from_sample(x)`; according to the
    original author's note, that method is a small addition which calls
    `SequenceGenerator._generate`.
    """

    def __init__(self, generator):
        """Store the object whose `_decode_from_sample` will be invoked by `forward`."""
        super().__init__()
        self.generator = generator

    def forward(self, x):
        """Return whatever the wrapped object's `_decode_from_sample` returns for `x`."""
        decode = self.generator._decode_from_sample
        return decode(x)

***Attempting a JIT or Neuron trace here kills the browser because of the large amount of data the trace generates. To see that output, run the dedicated Python script in a terminal instead: `python mmt_trace.py &> log.txt`. This generates a large log file!***


In [None]:
# Attempt to trace: wrap the decoder so that it can be handed to a tracer.
gen_wrapper = GeneratorWrapper(decoder)
# The trace calls themselves are left commented out.
# NOTE(review): `sample` is not defined in any cell visible here — presumably it is the batch
# returned by `decoder._make_decode_batch`; confirm before re-enabling these lines.
#jit_gen = torch.jit.trace(gen_wrapper, sample)
#neuron_gen = torch.neuron.trace(gen_wrapper, sample)

  int(self.max_len_a * src_len + self.max_len_b),


### Trace the model only

In [22]:
class ModelWrapper(torch.nn.Module):
    """Wrap a model so that only its `forward_encoder` step runs in `forward`.

    The encoder result is printed for inspection and a constant 1x1 tensor is returned
    in its place, so the wrapper's output does not depend on the input.
    """

    def __init__(self, model):
        """Store the model whose `forward_encoder` will be invoked by `forward`."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run `self.model.forward_encoder(x)`, print the result and return a dummy tensor."""
        encoder_out = self.model.forward_encoder(x)
        print(encoder_out)
        # Placeholder output: the encoder result itself is not returned.
        return torch.Tensor([[0]])

In [23]:
# Wrap the single Transformer model in fairseq's EnsembleModel and JIT-trace the wrapper.
# The example input is the `net_input` entry of the batch built earlier with
# `decoder._make_decode_batch` (`test_text_encode`). The previous `sample` variable is not
# defined by any earlier cell, so this cell raised NameError after a kernel restart.
model_wrapper = ModelWrapper(EnsembleModel([decoder._model]))
jit_model = torch.jit.trace(model_wrapper, test_text_encode['net_input'])

  # Remove the CWD from sys.path while we load stuff.


[EncoderOut(encoder_out=tensor([[[ 1.9738e-02,  5.7294e-02, -1.1316e-01,  ..., -1.5601e-01,
           1.8267e-01, -2.2096e-02]],

        [[ 6.7960e-04, -8.9271e-03, -1.1074e-01,  ..., -1.3071e-01,
           2.8182e-02,  4.4077e-02]],

        [[-4.3508e-02,  4.1582e-02, -1.0471e-01,  ..., -2.0630e-01,
           6.9779e-02, -7.7151e-02]],

        ...,

        [[-2.2303e-01,  3.3949e-01, -3.5549e-01,  ..., -1.4366e-01,
          -2.6553e-01, -1.7634e-01]],

        [[ 1.7761e-02,  1.6541e-02,  2.6702e-02,  ..., -1.6961e-02,
           3.9059e-02, -2.5640e-04]],

        [[ 1.7764e-02,  1.6547e-02,  2.6700e-02,  ..., -1.6976e-02,
           3.9067e-02, -2.5229e-04]]], grad_fn=<NativeLayerNormBackward>), encoder_padding_mask=tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, Fal

In [24]:
# Display the repr of the traced wrapper to inspect its module structure.
jit_model

ModelWrapper(
  original_name=ModelWrapper
  (model): RecursiveScriptModule(
    original_name=EnsembleModel
    (single_model): RecursiveScriptModule(
      original_name=TransformerModel
      (encoder): RecursiveScriptModule(
        original_name=TransformerEncoder
        (dropout_module): RecursiveScriptModule(original_name=FairseqDropout)
        (embed_tokens): RecursiveScriptModule(original_name=Embedding)
        (embed_positions): RecursiveScriptModule(original_name=SinusoidalPositionalEmbedding)
        (layers): RecursiveScriptModule(
          original_name=ModuleList
          (0): RecursiveScriptModule(
            original_name=TransformerEncoderLayer
            (self_attn): RecursiveScriptModule(
              original_name=MultiheadAttention
              (dropout_module): RecursiveScriptModule(original_name=FairseqDropout)
              (k_proj): RecursiveScriptModule(original_name=Linear)
              (v_proj): RecursiveScriptModule(original_name=Linear)
        

In [25]:
# List every attribute name available on the decoder's model (the output is long).
dir(decoder._model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__dataclass',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_name',
 '_is_generation_fast',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_version',
 'add_args',
 'add_module',
 'apply',
 'args',
 'bfloat16',
 'buffers',
 'b