<a href="https://colab.research.google.com/github/joeynmt/joeynmt/blob/main/notebooks/torchhub.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import logging
# suppress logging output
logging.disable(logging.CRITICAL)

# Torch hub interface

Author: May Ohta  
Last update: 15. January 2024

---

:warning: Run this notebook on a single GPU or on CPU.

Check if you have a compatible torch version.

In [2]:
import torch

torch.__version__

'2.1.0+cu121'

Install JoeyNMT from the GitHub repository (this notebook was written for v2.3):

In [3]:
!pip install -q git+https://github.com/joeynmt/joeynmt.git
!pip show joeynmt


Load En-De / De-En Transformer models trained on WMT'14 data.

Note: This takes a few minutes.

In [None]:
en2de = torch.hub.load('joeynmt/joeynmt', 'wmt14_ende')
de2en = torch.hub.load('joeynmt/joeynmt', 'wmt14_deen')

## Translate

Pass a list of source sentences to `translate()`.  
The return value is a flat list of strings.

In [5]:
# en->de model, single sentence
en2de.translate(["Hello world!"])

['Hallo Welt!']

In [6]:
# de->en model, multiple sentences
de2en.translate(["Hallo Welt!", "Wie geht's?"])

['Hello world!', 'What is going on?']

## N-best list

Set the `beam_size` and `n_best` options  
(`n_best` must not exceed `beam_size`).

In [7]:
de2en.translate(["Wie geht's?"], beam_size=5, n_best=5)

['What is going on?',
 'How is it?',
 'How is it going?',
 'How is this going?',
 'How can I do it?']

In [8]:
en2de.translate(["Hello world!", "How are you?"], beam_size=5, n_best=5)

['Hallo Welt!',
 'Hello Welt!',
 'Hello world!',
 'Helle Welt!',
 'Hello World!',
 'Wie sind Sie?',
 'Wie sind Sie da?',
 'Wie sind Sie denn?',
 'Wie stehen Sie?',
 'Wie sind Sie denn da?']
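Judging from the output above, `translate()` returns the `n_best` candidates of each input sentence consecutively in one flat list. A minimal sketch of regrouping that list per input sentence, assuming this ordering always holds (`group_nbest` is a hypothetical helper, not part of JoeyNMT):

```python
from typing import List


def group_nbest(flat: List[str], n_best: int) -> List[List[str]]:
    """Split a flat n-best list into one sub-list per source sentence.

    Assumes the candidates of each sentence are stored consecutively,
    `n_best` at a time, as in the `translate()` output above.
    """
    if n_best < 1 or len(flat) % n_best != 0:
        raise ValueError("list length must be a multiple of n_best")
    return [flat[i:i + n_best] for i in range(0, len(flat), n_best)]


# Output of en2de.translate(["Hello world!", "How are you?"], beam_size=5, n_best=5)
flat = ['Hallo Welt!', 'Hello Welt!', 'Hello world!', 'Helle Welt!', 'Hello World!',
        'Wie sind Sie?', 'Wie sind Sie da?', 'Wie sind Sie denn?', 'Wie stehen Sie?',
        'Wie sind Sie denn da?']
for src, candidates in zip(["Hello world!", "How are you?"], group_nbest(flat, 5)):
    print(src, "->", candidates[0])  # best candidate per source sentence
```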

## Scoring

#### 1. Hypothesis scoring
- Call `score()` without a target.
- With greedy decoding (`beam_size == 1`), token-level log-likelihoods are returned.
- With beam search (`beam_size > 1`), sequence-level log-likelihoods are returned.

#### 2. Reference scoring
- Call `score()` with a target.

The return value is a list of `PredictionOutput` objects, one per source sentence, with the fields
- translation: List[str]
- tokens: List[List[str]]
- token_probs: Optional[List[List[float]]]
- sequence_probs: Optional[List[float]]
- attention_probs: Optional[List[List[float]]]

Each field holds a list of length `n_best`. Despite their names, `token_probs` and `sequence_probs` contain log-likelihoods, not probabilities.

In [9]:
beam_size = 1 # Greedy Decoding -> token-level log-likelihood
n_best = 1
sentences = ["Hello world!", "How are you?"]
greedy_out = en2de.score(sentences, beam_size=beam_size, n_best=n_best)

for i, out in enumerate(greedy_out):
    print(f'## sent {i+1} ##')
    for n in range(n_best):
        for score, token in zip(out.token_probs[n], out.tokens[n]):
            print(f'{score:.4f} {token}')
        print()

## sent 1 ##
-1.5346 ▁Hallo
-0.1637 ▁Welt
-0.1131 ▁
-0.0694 !
-0.1075 </s>

## sent 2 ##
-0.2651 ▁Wie
-0.7192 ▁sind
-0.1603 ▁Sie
-0.6431 ▁
-0.1069 ?
-0.0967 </s>
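The token-level scores are log-likelihoods, so they can be aggregated directly. A minimal sketch that sums them and derives a per-token perplexity from the values printed for "Hallo Welt!" above, assuming natural-log scores. This is a common way to summarize token scores, not necessarily how JoeyNMT computes its sequence-level score: the beam-search score reported for the same string in the next cell (-1.4635) differs from the plain sum, so some normalization is presumably applied there.

```python
import math
from typing import List, Tuple


def summarize(logprobs: List[float]) -> Tuple[float, float, float]:
    """Return (sum, mean, perplexity) of token log-likelihoods.

    Assumes natural-log scores; perplexity = exp(-mean log-likelihood).
    """
    total = sum(logprobs)
    mean = total / len(logprobs)
    perplexity = math.exp(-mean)
    return total, mean, perplexity


# token log-likelihoods printed for sentence 1 ("Hallo Welt!"), including </s>
hallo_welt = [-1.5346, -0.1637, -0.1131, -0.0694, -0.1075]
total, mean, ppl = summarize(hallo_welt)
print(f"sum={total:.4f} mean={mean:.4f} ppl={ppl:.4f}")
```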



In [10]:
beam_size = 5 # Beam Search -> sequence-level log-likelihood
n_best = 3
sentences = ['Hello world!', "How are you?"]
beam_out = en2de.score(sentences, beam_size=beam_size, n_best=n_best)

for i, out in enumerate(beam_out):
    print(f'## sent {i+1} ##')
    for n in range(n_best):
        print(f'{n+1} best: {out.sequence_probs[n]:.4f} {out.translation[n]}')
    print()

## sent 1 ##
1 best: -1.4635 Hallo Welt!
2 best: -2.4815 Hello Welt!
3 best: -2.9123 Hello world!

## sent 2 ##
1 best: -1.3846 Wie sind Sie?
2 best: -2.2482 Wie sind Sie da?
3 best: -2.6511 Wie sind Sie denn?



In [11]:
# Reference scoring -> always n_best = 1
ref_scores = en2de.score(
    src=['I like cookies.', 'I like cookies.'],
    trg=['Ich mag Kekse.', 'Ich liebe Kekse.'],
)

n_best = 1
for i, ref_score in enumerate(ref_scores):
    print(f'## sent {i+1} ##')
    for n in range(n_best):
        for s, t in zip(ref_score.token_probs[n], ref_score.tokens[n]):
            print(f'{s:.4f} {t}')
        print()

## sent 1 ##
-0.2305 ▁Ich
-0.3924 ▁mag
-6.8749 ▁Ke
-1.1366 k
-0.0226 se
-0.1327 ▁
-0.0989 .
-0.1135 </s>

## sent 2 ##
-0.2305 ▁Ich
-3.5479 ▁liebe
-6.6423 ▁Ke
-1.3268 k
-0.0343 se
-0.1388 ▁
-0.1114 .
-0.1103 </s>
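Reference scoring makes it possible to compare alternative references for the same source. A minimal sketch that totals the token log-likelihoods printed above and reports the least likely token of each reference (`compare_references` is a hypothetical helper). With these numbers, `▁Ke` is the least likely token in both references, while most of the gap between the two totals comes from "liebe" versus "mag".

```python
from typing import List, Tuple


def compare_references(tokens: List[str], logprobs: List[float]) -> Tuple[float, str, float]:
    """Return (total log-likelihood, least likely token, its log-likelihood)."""
    total = sum(logprobs)
    worst_lp, worst_tok = min(zip(logprobs, tokens))  # lowest log-likelihood wins
    return total, worst_tok, worst_lp


# token scores printed by the reference-scoring cell above
refs = {
    "Ich mag Kekse.": (
        ["▁Ich", "▁mag", "▁Ke", "k", "se", "▁", ".", "</s>"],
        [-0.2305, -0.3924, -6.8749, -1.1366, -0.0226, -0.1327, -0.0989, -0.1135],
    ),
    "Ich liebe Kekse.": (
        ["▁Ich", "▁liebe", "▁Ke", "k", "se", "▁", ".", "</s>"],
        [-0.2305, -3.5479, -6.6423, -1.3268, -0.0343, -0.1388, -0.1114, -0.1103],
    ),
}
results = {ref: compare_references(*data) for ref, data in refs.items()}
# print references from most to least likely
for ref, (total, tok, lp) in sorted(results.items(), key=lambda kv: kv[1][0], reverse=True):
    print(f"{total:.4f} {ref} (least likely token: {tok!r}, {lp:.4f})")
```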



## Plot attention

In [12]:
# get attention probs (the source is English, so use the en->de model)
ref_scores = en2de.score(
    src=['Hello world!'],
    trg=['Hallo Welt!'],
)

# plot
en2de.plot_attention(
    src='Hello world!',
    trg='Hallo Welt!',
    attention_scores=ref_scores[0].attention_probs[0],
)
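Besides plotting, an attention matrix can be reduced to a rough hard alignment by picking, for each target token, the source token with the highest weight. A minimal sketch, assuming the matrix is indexed as `[target position][source position]`; the orientation of `attention_probs` is not documented above, so check it before use. The tokens and weights below are made up for illustration, not model output.

```python
from typing import List, Tuple


def hard_alignment(attention: List[List[float]], src_tokens: List[str],
                   trg_tokens: List[str]) -> List[Tuple[str, str]]:
    """Align each target token to the source token it attends to most (argmax per row)."""
    if len(attention) != len(trg_tokens):
        raise ValueError("expected one attention row per target token")
    pairs = []
    for trg_tok, row in zip(trg_tokens, attention):
        if len(row) != len(src_tokens):
            raise ValueError("expected one attention column per source token")
        best = max(range(len(row)), key=row.__getitem__)  # argmax over source positions
        pairs.append((trg_tok, src_tokens[best]))
    return pairs


# Illustrative (not real) weights: rows = target tokens, columns = source tokens
src = ["▁Hello", "▁world", "!", "</s>"]
trg = ["▁Hallo", "▁Welt", "!", "</s>"]
attn = [
    [0.85, 0.05, 0.05, 0.05],
    [0.10, 0.80, 0.05, 0.05],
    [0.05, 0.05, 0.85, 0.05],
    [0.05, 0.05, 0.10, 0.80],
]
for t, s in hard_alignment(attn, src, trg):
    print(f"{t} -> {s}")
```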

## Prompting

JoeyNMT v2.3 supports multilingual translation with language tags.  
The following model was trained on IWSLT'14 en-de and en-fr sentence pairs with three language tags: `<de>`, `<en>` and `<fr>`.

In [None]:
iwslt14 = torch.hub.load('joeynmt/joeynmt', 'iwslt14_prompt')

#### 1. language tags (multi-task learning)

The src prompt tells the model which language we are translating from, and  
the trg prompt tells it which language to translate into.

In [5]:
# de -> en
iwslt14.translate(
    src=["Hallo Welt!"],
    src_prompt=["<de>"],
    trg_prompt=["<en>"],
    beam_size=5,
    n_best=5,
)

['Hello world!', 'Hello world.', 'Hi World!', 'Hi world!', 'Hello, world!']

In [6]:
# en -> fr
iwslt14.translate(
    src=["How are you?"],
    src_prompt=["<en>"],
    trg_prompt=["<fr>"],
    beam_size=5,
    n_best=5,
)

['Comment es-tu ?',
 'Comment êtes-vous ?',
 'Comment êtes-vous?',
 'Comment allez-vous ?',
 'Comment vous êtes-vous ?']

In [7]:
# en -> de
iwslt14.translate(
    src=["How are you?"],
    src_prompt=["<en>"],
    trg_prompt=["<de>"],
    beam_size=5,
    n_best=5,
)

['Wie geht es Ihnen?',
 "Wie geht's dir?",
 "Wie geht's Ihnen?",
 'Wie sind Sie?',
 'Wie geht es dir?']
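The examples above pass one-element lists, which suggests that `src_prompt` and `trg_prompt` take one entry per source sentence. Under that assumption, a small hypothetical helper (`build_prompts`, not part of JoeyNMT) can build aligned prompt lists for a batch, optionally appending context after the language tag as the following examples do:

```python
from typing import Dict, List, Optional


def build_prompts(src: List[str], src_lang: str, trg_lang: str,
                  src_context: Optional[List[str]] = None,
                  trg_context: Optional[List[str]] = None) -> Dict[str, List[str]]:
    """Return keyword arguments with one src/trg prompt per source sentence.

    Each prompt is a language tag such as "<de>", optionally followed by context text.
    """
    def prompts(lang: str, context: Optional[List[str]]) -> List[str]:
        tag = f"<{lang}>"
        if context is None:
            return [tag] * len(src)
        if len(context) != len(src):
            raise ValueError("context must have one entry per source sentence")
        # an empty context string leaves just the tag
        return [f"{tag} {c}".rstrip() for c in context]

    return {
        "src": src,
        "src_prompt": prompts(src_lang, src_context),
        "trg_prompt": prompts(trg_lang, trg_context),
    }


kwargs = build_prompts(["How are you?", "Hello world!"], "en", "de",
                       trg_context=["du", ""])
print(kwargs)
# With the model loaded, this could be passed on as iwslt14.translate(**kwargs, beam_size=5)
```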

#### 2. trg context

In addition to the language tags, we can put context into the prompt.  
In the example above, the top candidate uses the formal form of address ("Ihnen", the dative of "Sie").  
By adding "du" to the trg prompt, we can ask the model for the more casual form instead.

In [8]:
iwslt14.translate(
    src=["How are you?"],
    src_prompt=["<en>"],
    trg_prompt=["<de> du"],
    beam_size=1,
)

['Wie bist du?']

#### 3. src context

In the following example, the model translates the English pronoun "it" into the neuter form "es" in all top-5 candidates.  
We can give src-side context so that the model gets a hint about the grammatical gender of the referent.

In [9]:
# w/o src context
out = iwslt14.score(
    src=["I'll bring it."],
    src_prompt=["<en>"],
    trg_prompt=["<de>"],
    beam_size=5,
    n_best=5,
)
for n in range(5):
    print(f'{n+1} best: {out[0].sequence_probs[n]:.4f} | {out[0].translation[n]}')

1 best: -1.4792 | Ich werde es mitbringen.
2 best: -1.6045 | Ich bringe es dazu.
3 best: -1.6643 | Ich bringe es mit.
4 best: -1.6840 | Ich bringe es an.
5 best: -1.7278 | Ich bringe es.


In [10]:
# w/ src context
out = iwslt14.score(
    src=["I'll bring it."],
    src_prompt=["<en> There is a camera."],
    trg_prompt=["<de>"],
    beam_size=5,
    n_best=5,
)
for n in range(5):
    print(f'{n+1} best: {out[0].sequence_probs[n]:.4f} | {out[0].translation[n]}')

1 best: -1.6253 | Ich bringe es mit.
2 best: -1.6836 | Ich bringe es an.
3 best: -1.7318 | Ich werde es mitbringen.
4 best: -1.7923 | Ich bringe es.
5 best: -1.8206 | Ich bringe sie mit.


The feminine form "sie" appears at the 5th position of the n-best list  
when we mention "camera" in the src prompt (in German, "Kamera" is a feminine noun).  
By designing the prompt, we can gain more control over the trg translation!

### Training with prompts

You can provide a TSV file containing the columns `src`, `src_prompt`, `trg`, and `trg_prompt`.  
See `test/data/toy/dev.tsv` for an example. There, the preceding sentence is used as the prompt.

In [13]:
import pandas as pd
from pathlib import Path

hub_dir = Path(torch.hub.get_dir()) / "joeynmt_joeynmt_main"
df = pd.read_csv(hub_dir / "test/data/toy/dev.tsv", sep="\t")
df[["src_prompt", "src", "trg_prompt", "trg"]].head(5)

| | src_prompt | src | trg_prompt | trg |
|---|---|---|---|---|
| 0 | `<de>` | Ich freue mich , dass ich da bin . | `<en>` | I’m happy to be here . |
| 1 | `<de> Ich freue mich , dass ich da bin .` | Ja , guten Tag . | `<en> I’m happy to be here .` | Yes , hello . |
| 2 | `<de> Ja , guten Tag .` | Ja , also , was soll Biohacking sein ? | `<en> Yes , hello .` | Yes , so , what is biohacking ? |
| 3 | `<de> Ja , also , was soll Biohacking sein ?` | Ich muss dazu erst mal ein bisschen ausholen ,... | `<en> Yes , so , what is biohacking ?` | I’ll have to provide some background informati... |
| 4 | `<de> Ich muss dazu erst mal ein bisschen ausho...` | Ich studiere Molekularbiologie und beschäftige... | `<en> I’ll have to provide some background info...` | I study molecular biology and have been doing ... |
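The rows above follow a simple pattern: each prompt is the language tag followed by the previous sentence of the same document, and the first sentence gets the bare tag. A minimal sketch that derives such rows from a sentence-aligned document and writes them as TSV with the standard `csv` module. The `make_context_rows` helper is hypothetical and mirrors only the four columns shown above, so compare the result against `test/data/toy/dev.tsv` before training.

```python
import csv
import io
from typing import Dict, List


def make_context_rows(src_sents: List[str], trg_sents: List[str],
                      src_tag: str = "<de>", trg_tag: str = "<en>") -> List[Dict[str, str]]:
    """Use the preceding sentence pair of a document as src/trg prompt context."""
    if len(src_sents) != len(trg_sents):
        raise ValueError("source and target documents must be sentence-aligned")
    rows = []
    for i, (src, trg) in enumerate(zip(src_sents, trg_sents)):
        prev_src = src_sents[i - 1] if i > 0 else ""
        prev_trg = trg_sents[i - 1] if i > 0 else ""
        rows.append({
            "src_prompt": f"{src_tag} {prev_src}".rstrip(),  # bare tag for the first sentence
            "src": src,
            "trg_prompt": f"{trg_tag} {prev_trg}".rstrip(),
            "trg": trg,
        })
    return rows


src_doc = ["Ich freue mich , dass ich da bin .", "Ja , guten Tag ."]
trg_doc = ["I’m happy to be here .", "Yes , hello ."]
rows = make_context_rows(src_doc, trg_doc)

buffer = io.StringIO()  # for a real file: open("dev.tsv", "w", newline="", encoding="utf-8")
writer = csv.DictWriter(buffer, fieldnames=["src_prompt", "src", "trg_prompt", "trg"],
                        delimiter="\t")
writer.writeheader()
writer.writerows(rows)
print(buffer.getvalue())
```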


Also pay attention to the `lang` entries in the data config,  
since both the src and the trg side contain multiple languages.

```yaml
# config.yaml
...
data:
    train: "test/data/iwslt14/train"
    dev: "test/data/iwslt14/validation"
    test: "test/data/iwslt14/test"
    dataset_type: "tsv"
    src:
        lang: "src" # <- instead of "en", "de", etc.
        ...
        
    trg:
        lang: "trg" # <- instead of "en", "de", etc.
        ...
        
    special_symbols:
        unk_token: "<unk>"
        pad_token: "<pad>"
        bos_token: "<s>"
        eos_token: "</s>"
        sep_token: "<sep>" # <- separator token
        unk_id: 0
        pad_id: 1
        bos_id: 2
        eos_id: 3
        sep_id: 4 # <- separator token index
        lang_tags: ["<de>", "<en>"] # <- language tags
    ...
```
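Because the language tags are declared once in `special_symbols.lang_tags`, the data and the config can easily drift apart (for example, the IWSLT'14 model above also uses `<fr>`, which the snippet above does not list). A minimal sanity check, assuming the config has already been parsed into a dict (for instance with a YAML loader) and that every prompt starts with its language tag, as in the TSV above. `check_lang_tags` is a hypothetical helper, not part of JoeyNMT, and the `<fr>` prompt is made up for illustration.

```python
from typing import Dict, Iterable, List


def check_lang_tags(config: Dict, prompts: Iterable[str]) -> List[str]:
    """Return the leading tags found in `prompts` that are missing from lang_tags."""
    declared = set(config["data"]["special_symbols"]["lang_tags"])
    missing: List[str] = []
    for prompt in prompts:
        # the first whitespace-separated token is assumed to be the language tag
        tag = prompt.split(maxsplit=1)[0] if prompt.strip() else ""
        is_tag = tag.startswith("<") and tag.endswith(">")
        if is_tag and tag not in declared and tag not in missing:
            missing.append(tag)
    return missing


# Subset of the config above, as a YAML loader would return it
config = {"data": {"special_symbols": {"lang_tags": ["<de>", "<en>"]}}}
prompts = ["<de>", "<en> I’m happy to be here .", "<fr> Bonjour ."]
print(check_lang_tags(config, prompts))
```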