In [None]:
import os
import subprocess

import numpy as np
import opennmt
import pyonmttok
import sacrebleu
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
from opennmt.utils import checkpoint as checkpoint_util
from pyonmttok import SentencePieceTokenizer

In [None]:
def count_weights(model):
  """Print the total, trainable and non-trainable parameter counts of a model.

  Args:
    model: an object exposing `trainable_weights` and `non_trainable_weights`
      (e.g. a built Keras / OpenNMT-tf model).
  """
  # Built-in `sum` (start 0) + `int` keeps the counts integral. `np.sum([])`
  # returns the float 0.0, which printed "0.0" for models without
  # non-trainable weights and turned the total into a float ("1,006.0").
  trainable_count = int(sum(K.count_params(w) for w in model.trainable_weights))
  non_trainable_count = int(sum(K.count_params(w) for w in model.non_trainable_weights))

  print('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print('Trainable params: {:,}'.format(trainable_count))
  print('Non-trainable params: {:,}'.format(non_trainable_count))

def _target_sp_model_basename(runner):
  """Return the SentencePiece model name matching the runner's target vocabulary."""
  # NOTE: reads the private `Runner._config` (this notebook already relies on
  # private Runner API via `_finalize_config`). Pass `sp_model_basename` to
  # `compute_scores` explicitly if this lookup is not available.
  config = getattr(runner, "_config", None) or {}
  vocab_path = config.get("data", {}).get("target_vocabulary")
  if not vocab_path:
    raise ValueError(
        "Cannot infer the target SentencePiece model from the runner config; "
        "pass sp_model_basename explicitly.")
  # The notebook saves each SentencePiece model and its vocabulary under the
  # same prefix (<prefix>.model / <prefix>.vocab). rstrip("/") tolerates
  # config entries written with a trailing slash.
  return os.path.splitext(os.path.basename(vocab_path.rstrip("/")))[0]


def _read_lines(path):
  """Return the lines of a UTF-8 text file without their trailing newline."""
  with open(path, encoding="utf-8") as f:
    return [line.rstrip("\n") for line in f]


def compute_scores(runner, features_filename, labels_filename, pred_filename,
                   include_ppl=False, include_ter=False, sp_model_basename=None):
  """Translate a test set with `runner` and score it with sacreBLEU.

  The runner's raw output (space-separated SentencePiece pieces) is written to
  `<stem>.tok`, where `<stem>` is `pred_filename` without its extension. It is
  then detokenized with this notebook's `detokenize` helper into `<stem>.txt`,
  so a `pred_filename` ending in `.txt` ends up holding the detokenized text.

  Args:
    runner: the `opennmt.Runner` of the model to score.
    features_filename: source file passed to `runner.infer`; it must be in the
      same (tokenized) form as the model's training features.
    labels_filename: reference translations. A `.tok` file is first
      detokenized (writing `<labels stem>.txt`); any other file is read as
      plain reference text.
    pred_filename: path whose stem names the prediction files (see above).
    include_ppl: if True, merge in the metrics returned by
      `runner.evaluate(features_filename, labels_filename)`; that call needs
      tokenized labels.
    include_ter: if True, also report sacreBLEU's TER under "ter".
    sp_model_basename: name of the target-side model under
      `sentencepiece_models/`. Defaults to the stem of the runner's
      `data.target_vocabulary`.

  Returns:
    dict with a "bleu" key, plus "ter" and the evaluation metrics on request.

  Raises:
    ValueError: if the SentencePiece model cannot be determined, or if the
      number of predictions and references differ.
  """
  if sp_model_basename is None:
    sp_model_basename = _target_sp_model_basename(runner)

  # splitext (not str.index('.')) so dots in directory names are harmless.
  pred_stem = os.path.splitext(pred_filename)[0]
  runner.infer(features_filename, f"{pred_stem}.tok")
  detokenized_pred_filename = detokenize(sp_model_basename, pred_stem)

  labels_stem, labels_ext = os.path.splitext(labels_filename)
  if labels_ext == ".tok":
    references_filename = detokenize(sp_model_basename, labels_stem)
  else:
    references_filename = labels_filename

  preds = _read_lines(detokenized_pred_filename)
  truth = _read_lines(references_filename)
  if len(preds) != len(truth):
    raise ValueError(
        f"{len(preds)} predictions in {detokenized_pred_filename} but "
        f"{len(truth)} references in {references_filename}")

  scores = {}
  if include_ppl:
    scores.update(runner.evaluate(
        features_file=features_filename,
        labels_file=labels_filename))

  scores["bleu"] = sacrebleu.corpus_bleu(preds, [truth]).score
  if include_ter:
    scores["ter"] = sacrebleu.corpus_ter(preds, [truth]).score

  return scores

def detokenize(model_basename, tokenized_basename):
  """Detokenize a SentencePiece-tokenized file back to plain text.

  Reads `<tokenized_basename>.tok` (space-separated pieces, one sentence per
  line) and writes `<tokenized_basename>.txt`, overwriting it if it exists.

  Args:
    model_basename: model name under `sentencepiece_models/`
      (`<name>.model` and `<name>.vocab` are loaded).
    tokenized_basename: path of the tokenized file without its `.tok` suffix.

  Returns:
    The path of the written `.txt` file.
  """
  model_path = os.path.join("sentencepiece_models", f"{model_basename}.model")
  vocabulary_path = os.path.join("sentencepiece_models", f"{model_basename}.vocab")
  detokenizer = SentencePieceTokenizer(model_path=model_path,
                                       vocabulary_path=vocabulary_path)

  input_path = f"{tokenized_basename}.tok"
  output_path = f"{tokenized_basename}.txt"
  # Explicit UTF-8: the corpora are Catalan/Spanish/Italian (accented text),
  # and open() otherwise uses the platform's locale encoding.
  with open(input_path, encoding="utf-8") as fin:
    with open(output_path, mode="w", encoding="utf-8") as fout:
      for line in fin:
        # split() rather than split(" "): no empty pieces from repeated or
        # trailing spaces, and the newline is dropped.
        fout.write(detokenizer.detokenize(line.split()) + "\n")

  return output_path

def tokenize(input_file, basename, model_basename=None):
  """SentencePiece-tokenize a plain-text file into `<basename>.tok`.

  Every input line yields exactly one output line (blank lines stay blank).
  This keeps parallel source/target files aligned line by line.

  Args:
    input_file: path of the plain-text input.
    basename: output path without the `.tok` suffix. If `model_basename` is
      not given, it is also used as the model name, as before.
    model_basename: optional model name under `sentencepiece_models/`
      (`<name>.model` and `<name>.vocab` are loaded), e.g. "src" when
      tokenizing "src_test".

  Returns:
    The path of the written `.tok` file.
  """
  if model_basename is None:
    model_basename = basename
  model_path = os.path.join("sentencepiece_models", f"{model_basename}.model")
  vocabulary_path = os.path.join("sentencepiece_models", f"{model_basename}.vocab")
  tokenizer = SentencePieceTokenizer(model_path=model_path,
                                     vocabulary_path=vocabulary_path)

  output_path = f"{basename}.tok"
  # Explicit UTF-8: accented Catalan/Spanish/Italian text.
  with open(input_file, encoding="utf-8") as fin:
    with open(output_path, mode="w", encoding="utf-8") as fout:
      for line in fin:
        text = line.rstrip("\n")
        # Dropping blank lines would shift every following line and misalign
        # this file with its parallel counterpart (or predictions with
        # references), so blank lines are written as empty lines.
        tokens = tokenizer.tokenize(text)[0] if text.strip() else []
        fout.write(" ".join(tokens) + "\n")

  return output_path

In [None]:
# Build vocab (uses SentencePiece BPE models, one per corpus side)
#   source = Catalan (ca)
#   pivot  = Spanish (es)
#   target = Italian (it)
# The tokenize/detokenize helpers load sentencepiece_models/<prefix>.model and
# .vocab, and the configs below reference sentencepiece_models/<prefix>.vocab.
VOCAB_SIZE = 32000
SP_MODEL_DIR = "sentencepiece_models"

vocab_jobs = [
    # (save prefix, training corpus)
    ("src", "src_pvt_data/src_train.txt"),
    ("pvt_src", "src_pvt_data/pvt_src_train.txt"),
    ("tgt", "pvt_tgt_data/tgt_train.txt"),
    ("pvt_tgt", "pvt_tgt_data/pvt_tgt_train.txt"),
    ("src_tgt", "src_tgt_data/src_tgt_train.txt"),
    ("tgt_src", "src_tgt_data/tgt_src_train.txt"),
]

# The output directory is not created by anything else in this notebook.
os.makedirs(SP_MODEL_DIR, exist_ok=True)
for prefix, corpus in vocab_jobs:
    # check=True: `!cmd` ignores a non-zero exit status, so a failed build
    # would go unnoticed until much later; this raises immediately instead.
    subprocess.run(
        ["onmt-build-vocab",
         "--sentencepiece", "model_type=bpe",
         "--size", str(VOCAB_SIZE),
         "--save_vocab", os.path.join(SP_MODEL_DIR, prefix),
         corpus],
        check=True,
    )

In [None]:
# Tokenize every split of every corpus side: <side>_<split>.txt -> <side>_<split>.tok
# (test, train, val for each side, in the order listed).
# NOTE(review): `tokenize(input_file, basename)` also uses `basename` as the
# SentencePiece model name (sentencepiece_models/<basename>.model). So e.g.
# "src_tgt_test" looks for src_tgt_test.model, while the build-vocab cell saves
# models under prefixes such as "src_tgt". These inputs/outputs are also
# relative to the working directory. The build-vocab cell, however, reads from
# src_pvt_data/, pvt_tgt_data/ and src_tgt_data/, and the configs expect the
# .tok files there. Confirm the intended model names and paths.
for corpus_side in ("src_tgt", "tgt_src", "src", "pvt_src", "pvt_tgt", "tgt"):
    for split_name in ("test", "train", "val"):
        split_basename = f"{corpus_side}_{split_name}"
        tokenize(f"{split_basename}.txt", split_basename)

In [None]:
def make_nmt_config(model_dir, data_dir, source, target,
                    vocab_dir="sentencepiece_models"):
    """Build an OpenNMT-tf run config for one translation direction.

    Files follow this notebook's layout:
    `<data_dir>/<side>_train.tok` and `<data_dir>/<side>_val.tok` are the
    tokenized corpora, and `<vocab_dir>/<side>.vocab` is the vocabulary, where
    <side> is `source` or `target`.

    Args:
        model_dir: checkpoint directory (used verbatim).
        data_dir: directory holding the tokenized `.tok` files.
        source: prefix of the source-side files and vocabulary (e.g. "src").
        target: prefix of the target-side files and vocabulary (e.g. "pvt_src").
        vocab_dir: directory holding the `.vocab` files.

    Returns:
        A new config dict; no nested dict is shared between calls.
    """
    return {
        "model_dir": model_dir,
        # These entries are file paths, so they must not end in "/":
        # a path such as "src_train.tok/" cannot be opened as a file.
        "data": {
            "train_features_file": f"{data_dir}/{source}_train.tok",
            "train_labels_file": f"{data_dir}/{target}_train.tok",
            "eval_features_file": f"{data_dir}/{source}_val.tok",
            "eval_labels_file": f"{data_dir}/{target}_val.tok",
            "source_vocabulary": f"{vocab_dir}/{source}.vocab",
            "target_vocabulary": f"{vocab_dir}/{target}.vocab",
        },
        "train": {
            "max_step": 25000,
            "save_checkpoints_steps": 500,
            "keep_checkpoint_max": 2,
        },
        "eval": {
            "save_eval_predictions": True,
            # NOTE(review): if `eval.steps` is the evaluation interval (as in
            # the OpenNMT-tf configuration reference), 50000 exceeds
            # max_step=25000. Periodic evaluation and early stopping may then
            # never run during training; confirm the intended value.
            "steps": 50000,
            "max_exports_to_keep": 2,
            "early_stopping": {
                "metric": "loss",
                "min_improvement": 0.1,
                "steps": 100,
            },
        },
    }


config_src_pvt = make_nmt_config(
    "src_pvt_model/", "src_pvt_data", "src", "pvt_src")

config_pvt_tgt = make_nmt_config(
    "/content/pvt_tgt_model/", "/content/pvt_tgt_data", "pvt_tgt", "tgt",
    vocab_dir="/content/sentencepiece_models")

config_src_tgt = make_nmt_config(
    "/content/src_tgt_model/", "/content/src_tgt_data", "src_tgt", "tgt_src",
    vocab_dir="/content/sentencepiece_models")

# Same data as config_src_tgt, but trained from scratch in its own model_dir.
config_baseline = make_nmt_config(
    "/content/baseline_model/", "/content/src_tgt_data", "src_tgt", "tgt_src",
    vocab_dir="/content/sentencepiece_models")

In [None]:
# Noam learning-rate schedule driving a LazyAdam optimizer; this optimizer is
# passed to create_variables() and to the checkpoints in the cells below.
learning_rate = opennmt.schedules.NoamDecay(
    scale=2.0,
    model_dim=512,
    warmup_steps=8000,
)
optimizer = tfa.optimizers.LazyAdam(learning_rate=learning_rate)

In [None]:
# Train the source -> pivot model (Catalan -> Spanish) from scratch.
src_pvt_model = opennmt.models.TransformerBase()
src_pvt_runner = opennmt.Runner(
    src_pvt_model,
    config_src_pvt,
    auto_config=True,
)
# Kept for the restore cell, which reads its 'data', 'params' and 'model_dir'.
sp_config = src_pvt_runner._finalize_config(training=True)

src_pvt_runner.train(
    num_devices=1,
    with_eval=True,
)

In [None]:
# Train the pivot -> target model (Spanish -> Italian) from scratch.
pvt_tgt_model = opennmt.models.TransformerBase()
pvt_tgt_runner = opennmt.Runner(
    pvt_tgt_model,
    config_pvt_tgt,
    auto_config=True,
)
# Kept for the restore cell, which reads its 'data', 'params' and 'model_dir'.
pt_config = pvt_tgt_runner._finalize_config(training=True)

pvt_tgt_runner.train(
    num_devices=1,
    with_eval=True,
)

In [None]:
# Restore both models' weights: initialize each pivot model, create its
# variables, then load the latest checkpoint from its model_dir.
pivot_pairs = ((src_pvt_model, sp_config), (pvt_tgt_model, pt_config))

for pivot_model, pivot_config in pivot_pairs:
    pivot_model.initialize(data_config=pivot_config['data'], params=pivot_config['params'])
    pivot_model.create_variables(optimizer=optimizer)

for pivot_model, pivot_config in pivot_pairs:
    checkpoint_path = pivot_config['model_dir']
    checkpoint = checkpoint_util.Checkpoint.from_config(pivot_config, pivot_model, optimizer=optimizer)
    restored_path = checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
    # NOTE(review): OpenNMT-tf's Checkpoint.restore is expected to return None
    # (logging only a warning) when no checkpoint exists; confirm for the
    # installed version. Fail loudly instead of silently transferring randomly
    # initialized layers in the next cell.
    if restored_path is None:
        raise RuntimeError(
            f"No checkpoint found in {checkpoint_path!r}; train this model first.")

count_weights(src_pvt_model)
count_weights(pvt_tgt_model)

In [None]:
# Transfer weights to src_tgt_model
# Builds a source->target model and replaces its encoder with the restored
# source->pivot encoder. It then replaces its decoder with the restored
# pivot->target decoder and saves a checkpoint for the result.
# NOTE(review): only `encoder` and `decoder` are swapped. If the embeddings
# live in the model's inputters (not shown here), they stay freshly
# initialized — confirm.
# NOTE(review): the source->pivot model was trained with src.vocab and the
# pivot->target model with tgt.vocab. This model instead uses src_tgt.vocab
# and tgt_src.vocab, so any vocabulary-indexed weights carried by the swapped
# layers (e.g. a decoder output projection) would index a different
# vocabulary — confirm.
src_tgt_model = opennmt.models.TransformerBase()
src_tgt_runner = opennmt.Runner(src_tgt_model, config_src_tgt, auto_config=True)
st_config = src_tgt_runner._finalize_config(training=True)

# Initialize and create this model's own variables before any layer is swapped.
src_tgt_model.initialize(data_config=st_config['data'], params=st_config['params'])
src_tgt_model.create_variables(optimizer=optimizer)

# Plain attribute assignment: the very same layer objects are now shared with
# src_pvt_model / pvt_tgt_model (no copy is made).
src_tgt_model.encoder = src_pvt_model.encoder
src_tgt_model.decoder = pvt_tgt_model.decoder

# The checkpoint is created after the swap, so presumably it tracks the grafted
# layers. It is built from st_config (config_src_tgt), whose model_dir the next
# training cell uses.
new_checkpoint = checkpoint_util.Checkpoint.from_config(st_config, src_tgt_model, optimizer=optimizer)
new_checkpoint.save()

In [None]:
# Train the source -> target model starting from the transferred weights.
# A fresh model/runner pair is built on config_src_tgt, whose model_dir holds
# the checkpoint saved by the previous cell.
# NOTE(review): relies on Runner.train resuming from the latest checkpoint in
# model_dir — confirm.
src_tgt_model = opennmt.models.TransformerBase()
src_tgt_runner = opennmt.Runner(
    src_tgt_model,
    config_src_tgt,
    auto_config=True,
)
st_config = src_tgt_runner._finalize_config(training=True)
src_tgt_runner.train(
    num_devices=1,
    with_eval=True,
)

In [None]:
# Baseline: train the source -> target model from scratch (no transferred
# weights) in its own model_dir.
baseline_model = opennmt.models.TransformerBase()
baseline_runner = opennmt.Runner(
    baseline_model,
    config_baseline,
    auto_config=True,
)

baseline_runner.train(
    num_devices=1,
    with_eval=True,
)

In [None]:
# Compute scores on the source->target test set for both models.
# The models were trained on SentencePiece-tokenized `.tok` features (see the
# configs), so inference must read the tokenized test source, not raw text.
# NOTE(review): assumes the test split was tokenized to this path, matching
# the train/val layout in the configs — confirm.
test_features_filename = "/content/src_tgt_data/src_tgt_test.tok"
# Plain-text references, used as-is for BLEU.
test_labels_filename = "/content/src_tgt_data/tgt_src_test.txt"

baseline_scores = compute_scores(
    runner=baseline_runner,
    features_filename=test_features_filename,
    labels_filename=test_labels_filename,
    pred_filename="/content/baseline_pred.txt")

pivot_based_tl_scores = compute_scores(
    runner=src_tgt_runner,
    features_filename=test_features_filename,
    labels_filename=test_labels_filename,
    pred_filename="/content/src_to_tgt_pred.txt")

print(f"============ Baseline Source-Target NMT Evaluation ============\n {baseline_scores}")
print(f"============ Pretrain Source-Target NMT Evaluation ============\n {pivot_based_tl_scores}")