# WMT16

* Segment-level Pearson correlation with human scores

In [1]:
# install dependencies

!pip -q install datasets
!pip -q install bert_score
!pip -q install git+https://github.com/google-research/bleurt.git
!pip -q install unbabel-comet
!pip -q install transformers
!pip -q install POT

[K     |████████████████████████████████| 365 kB 37.0 MB/s 
[K     |████████████████████████████████| 141 kB 72.2 MB/s 
[K     |████████████████████████████████| 212 kB 68.7 MB/s 
[K     |████████████████████████████████| 115 kB 73.5 MB/s 
[K     |████████████████████████████████| 101 kB 12.1 MB/s 
[K     |████████████████████████████████| 596 kB 68.6 MB/s 
[K     |████████████████████████████████| 127 kB 62.7 MB/s 
[K     |████████████████████████████████| 60 kB 1.5 MB/s 
[K     |████████████████████████████████| 4.7 MB 37.9 MB/s 
[K     |████████████████████████████████| 6.6 MB 16.0 MB/s 
[K     |████████████████████████████████| 352 kB 11.8 MB/s 
[K     |████████████████████████████████| 1.3 MB 60.4 MB/s 
[?25h  Building wheel for BLEURT (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 64 kB 2.7 MB/s 
[K     |████████████████████████████████| 116 kB 72.9 MB/s 
[K     |████████████████████████████████| 409 kB 62.0 MB/s 
[K     |██████████████

In [2]:
!git clone https://github.com/drehero/geneval

Cloning into 'geneval'...
remote: Enumerating objects: 582, done.[K
remote: Counting objects: 100% (232/232), done.[K
remote: Compressing objects: 100% (159/159), done.[K
remote: Total 582 (delta 102), reused 187 (delta 63), pack-reused 350[K
Receiving objects: 100% (582/582), 53.29 MiB | 11.59 MiB/s, done.
Resolving deltas: 100% (246/246), done.
Checking out files: 100% (192/192), done.


In [3]:
# Mount Google Drive; the result CSVs are written to and read from
# /content/drive/MyDrive/results/wmt16/ later in this notebook.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [4]:
import pathlib

import datasets
import numpy as np
import pandas as pd
from scipy.stats import pearsonr

from geneval.geneval.data.wmt import WMT16

In [5]:
# import metric config
from geneval.reproduction.configs import baryscore_config as config

In [6]:
# Override the batch size in the metric's compute kwargs; these kwargs are
# copied and forwarded to `scorer.compute` in the scoring loop.
config.compute_args.update(batch_size=64)

In [17]:
# Root directory (on the mounted Drive) under which per-metric result CSVs live.
# Plain string literal: the original f-string had no placeholders.
out_path = pathlib.Path("/content/drive/MyDrive/results/wmt16/")
# WMT16 language pairs to score and correlate.
lang_pairs = ["cs-en", "de-en", "fi-en", "ru-en"]

In [18]:
# Instantiate the metric from the path and loader kwargs supplied by the
# imported config module.
# NOTE(review): `datasets.load_metric` is deprecated in recent `datasets`
# releases (metrics moved to the `evaluate` package) — confirm the installed
# version still provides it.
scorer = datasets.load_metric(config.metric_path, **config.load_args)

In [19]:
# Score every language pair with the configured metric and store the
# segment-level metric scores next to the human scores, one CSV per pair.

# `DataFrame.to_csv` does not create missing parent directories, so make sure
# the per-metric output folder exists before the (expensive) scoring starts.
result_dir = out_path / config.metric_name
result_dir.mkdir(parents=True, exist_ok=True)

for lang_pair in lang_pairs:
    # load data
    wmt = WMT16(lang_pair)

    # compute score
    # Work on a copy so the per-pair inputs are not written back into the
    # shared `config.compute_args`.
    args = config.compute_args.copy()
    if config.uses_reference:
        args["references"] = wmt.references
    if config.uses_source:
        args["sources"] = wmt.sources

    scores = scorer.compute(
        predictions=wmt.translations,
        **args
    )

    # save
    df = pd.DataFrame({
        "translation": wmt.translations,
        "reference": wmt.references,
        "source": wmt.sources,
        "human_score": wmt.scores,
        # When the config names a score key, pick that entry out of the
        # metric's result; otherwise the result itself is used as the column.
        "metric_score": scores[config.score_name] if config.score_name is not None else scores
    })
    # File name: "<lang_pair>[-<model or config name>].csv". The correlation
    # cell rebuilds the same name to read the file back.
    if "model_type" in args:
        fn = f"{lang_pair}-{args['model_type'].split('/')[-1]}.csv"
    elif "config_name" in config.load_args:
        fn = f"{lang_pair}-{config.load_args['config_name'].split('/')[-1]}.csv"
    else:
        fn = f"{lang_pair}.csv"
    df.to_csv(result_dir / fn, index=False)

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

KeyboardInterrupt: ignored

In [20]:
# Read each language pair's saved scores back from disk and record the
# Pearson correlation between the metric scores and the human scores.
results = {}
for lang_pair in lang_pairs:
    # Rebuild the CSV name: language pair, optionally followed by the last
    # path component of the model type or config name.
    if "model_type" in config.compute_args:
        tag = config.compute_args["model_type"].split("/")[-1]
        fn = f"{lang_pair}-{tag}.csv"
    elif "config_name" in config.load_args:
        tag = config.load_args["config_name"].split("/")[-1]
        fn = f"{lang_pair}-{tag}.csv"
    else:
        fn = f"{lang_pair}.csv"
    df = pd.read_csv(out_path / config.metric_name / fn)
    # Element 0 of pearsonr's result is the correlation coefficient.
    results[lang_pair] = pearsonr(df["metric_score"], df["human_score"])[0]

In [21]:
# Display the per-language-pair Pearson correlations (metric vs. human scores).
results

{'cs-en': -0.7513220224201244,
 'de-en': -0.7299728767472832,
 'fi-en': -0.7680262081558806,
 'ru-en': -0.7302746298921875}