<a href="https://colab.research.google.com/github/cw118/domain-adapted-nmt/blob/main/3_two_step_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A two-step approach to NMT

## Part 3: the two-step classification and translation approach

### Zero-shot text classifier (step 1)

In [1]:
# Access key for Hugging Face (optional but recommended for public models and datasets).
import os
from google.colab import userdata, drive

# Export the token rather than leaving `userdata.get(...)` as the last expression
# of the cell: a bare trailing expression is echoed into the cell output, which
# would store the secret in the saved notebook.
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')

'hf_<redacted>'

In [2]:
drive.mount('/content/drive')
%cd drive/MyDrive/domain-adapted-nmt/nmt-tools

Mounted at /content/drive
/content/drive/MyDrive/domain-adapted-nmt/nmt-tools


In [3]:
from transformers import pipeline
# Step 1 model: a zero-shot classification pipeline backed by facebook/bart-large-mnli.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [4]:
# Labels offered to the zero-shot classifier. 'none' is the label that the
# rest of the notebook maps to the general-domain model.
candidate_labels = [
    'technology',
    'medicine',
    'none',
]

In [5]:
# Smoke-test the zero-shot classifier on four hand-picked sample sentences.
tech_seq = "The width of the main window in pixels."
med_seq = "Concomitant use of sildenafil with ritonavir dosed as an antiretroviral agent or as a pharmacokinetic enhancer is not recommended and in no instance should sildenafil doses exceed 25 mg in 48 hours (see also section 4.4)."
gen_seq = "The agreed conclusions have also facilitated the mainstreaming of a gender perspective in the work of the Council's functional commissions, and of other intergovernmental bodies."
rand_seq = "Hopefully this will not cause too much trouble, but perfection is impossible and that's alright."

# Classify each sample in order; the resulting tuple is the cell's displayed output.
tuple(
    classifier(sample, candidate_labels)
    for sample in (tech_seq, med_seq, gen_seq, rand_seq)
)

({'sequence': 'The width of the main window in pixels.',
  'labels': ['technology', 'none', 'medicine'],
  'scores': [0.8571749329566956, 0.09038852900266647, 0.05243648216128349]},
 {'sequence': 'Concomitant use of sildenafil with ritonavir dosed as an antiretroviral agent or as a pharmacokinetic enhancer is not recommended and in no instance should sildenafil doses exceed 25 mg in 48 hours (see also section 4.4).',
  'labels': ['medicine', 'technology', 'none'],
  'scores': [0.9056766033172607, 0.06045923009514809, 0.03386416658759117]},
 {'sequence': "The agreed conclusions have also facilitated the mainstreaming of a gender perspective in the work of the Council's functional commissions, and of other intergovernmental bodies.",
  'labels': ['none', 'technology', 'medicine'],
  'scores': [0.54505854845047, 0.2401156723499298, 0.21482577919960022]},
 {'sequence': "Hopefully this will not cause too much trouble, but perfection is impossible and that's alright.",
  'labels': ['none', '

### Assign text to corresponding NMT model (step 2)

In [6]:
!pip3 install OpenNMT-py sentencepiece ctranslate2
!pip3 install tensorrt sacrebleu

Collecting OpenNMT-py
  Downloading OpenNMT_py-3.5.0-py3-none-any.whl (262 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m262.8/262.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting ctranslate2
  Downloading ctranslate2-4.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.9/35.9 MB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
Collecting configargparse (from OpenNMT-py)
  Downloading ConfigArgParse-1.7-py3-none-any.whl (25 kB)
Collecting waitress (from OpenNMT-py)
  Downloading waitress-3.0.0-py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.7/56.7 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyonmttok<2,>=1.35 (from OpenNMT-py)
  Downloading pyonmttok-1.37.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.0/17.0 MB[0m [31m

In [7]:
# convert to use ctranslate2 inference engine (only needs to be done once!)
!ct2-opennmt-py-converter --model_path models/model-base.enfr_step_10000.pt --output_dir gen-enfr_ct2
!ct2-opennmt-py-converter --model_path models/model-tech.enfr_step_3000.pt --output_dir tech-enfr_ct2
!ct2-opennmt-py-converter --model_path models/model-med.enfr_step_6000.pt --output_dir med-enfr_ct2

In [8]:
import sentencepiece as spm
import ctranslate2

# Load the CTranslate2-converted translators and the existing SentencePiece models
# for the final translation task. Each table is keyed by classifier label.

translators = {
    'none': ctranslate2.Translator("gen-enfr_ct2/", device="cpu"),
    'technology': ctranslate2.Translator("tech-enfr_ct2/", device="cpu"),
    'medicine': ctranslate2.Translator("med-enfr_ct2/", device="cpu"),
}

# Source-side SentencePiece models (used to tokenize the input).
inp_sps = {
    'none': spm.SentencePieceProcessor("source-general.model"),
    'technology': spm.SentencePieceProcessor("source-technology.model"),
    'medicine': spm.SentencePieceProcessor("source-medicine.model"),
}

# Target-side SentencePiece models (used to detokenize the output).
out_sps = {
    'none': spm.SentencePieceProcessor("target-general.model"),
    'technology': spm.SentencePieceProcessor("target-technology.model"),
    'medicine': spm.SentencePieceProcessor("target-medicine.model"),
}

# Per-domain names for the same objects, for direct access.
gen_translator, tech_translator, med_translator = (
    translators['none'], translators['technology'], translators['medicine']
)
gen_inp_sp, tech_inp_sp, med_inp_sp = (
    inp_sps['none'], inp_sps['technology'], inp_sps['medicine']
)
gen_out_sp, tech_out_sp, med_out_sp = (
    out_sps['none'], out_sps['technology'], out_sps['medicine']
)

In [9]:
# Directory (relative to the working directory) containing the en-fr corpus files.
corpora_path = "corpora/enfr"

def translate(seq):
  """Classify `seq`, then translate it with the resources registered for its label.

  Returns a `(category, out)` tuple: the first label reported by the zero-shot
  classifier and the decoded text of the first translation hypothesis.
  """
  # The first entry of 'labels' is taken as the best matching label.
  ranked_labels = classifier(seq, candidate_labels)['labels']
  category = ranked_labels[0]

  # Pick the translator and SentencePiece models registered for that label.
  translator = translators[category]
  source_sp = inp_sps[category]
  target_sp = out_sps[category]

  # Tokenize into string pieces and translate as a one-item batch.
  pieces = source_sp.encode(seq, out_type=str)
  batch_results = translator.translate_batch([pieces])

  # First hypothesis of the only result, decoded back to plain text.
  out = target_sp.decode(batch_results[0].hypotheses[0])
  return category, out

In [10]:
# translate files using two-step approach
def two_step_translate(domain, translate_fn=None, corpora_dir=None):
  """Translate a domain's test file line by line and score the classifier.

  Reads `{corpora_dir}/en-fr-{domain}.en-filtered.en.subword.test`, writes one
  translation per input line to `{corpora_dir}/en-fr-{domain}.2s-translated`
  and returns the fraction of lines whose predicted category matches `domain`.

  Args:
    domain: domain name used in the file names (the notebook uses 'general',
      'technology' and 'medicine'). 'general' is compared against the
      classifier label 'none'; any other value is compared against itself.
    translate_fn: callable taking one sentence and returning
      `(category, translation)`. Defaults to the notebook-level `translate`.
    corpora_dir: directory holding the corpus files. Defaults to the
      notebook-level `corpora_path`.

  Returns:
    Classification accuracy as a float in [0, 1]; 0.0 if the source file
    contains no lines (instead of raising ZeroDivisionError).
  """
  # Resolve the notebook-level defaults lazily so they are only needed when used.
  if translate_fn is None:
    translate_fn = translate
  if corpora_dir is None:
    corpora_dir = corpora_path

  expected_label = 'none' if domain == 'general' else domain
  source_path = f'{corpora_dir}/en-fr-{domain}.en-filtered.en.subword.test'
  output_path = f'{corpora_dir}/en-fr-{domain}.2s-translated'

  score, total = 0, 0

  # Explicit UTF-8 so reading/writing accented text does not depend on the platform default.
  with open(source_path, encoding='utf-8') as source, \
       open(output_path, 'w', encoding='utf-8') as res:
    for line in source:
      # Drop the line terminator so it is not passed to the classifier/tokenizer.
      category, translated = translate_fn(line.rstrip('\n'))
      res.write(translated + '\n')

      if category == expected_label:
        score += 1
      total += 1

  # Guard against an empty source file.
  return score / total if total else 0.0

In [None]:
# Run the two-step pipeline on the general test set and keep the returned score.
gen_class_score = two_step_translate(domain='general')

In [None]:
# Run compute-bleu.py on the general-domain desubworded .fr test file and the two-step output file.
!python3 compute-bleu.py corpora/enfr/en-fr-general.fr-filtered.fr.subword.test.desubword corpora/enfr/en-fr-general.2s-translated

In [None]:
tech_class_score = two_step_translate('technology')
!python3 compute-bleu.py corpora/enfr/en-fr-technology.fr-filtered.fr.subword.test.desubword corpora/enfr/en-fr-technology.2s-translated

In [None]:
med_class_score = two_step_translate('medicine')
!python3 compute-bleu.py corpora/enfr/en-fr-medicine.fr-filtered.fr.subword.test.desubword corpora/enfr/en-fr-medicine.2s-translated

In [None]:
# Preview the first 5 lines of the general-domain two-step output file.
!head -n 5 corpora/enfr/en-fr-general.2s-translated

In [None]:
# Display the scores returned by two_step_translate for general, technology and medicine.
gen_class_score, tech_class_score, med_class_score