In [None]:
!pip3 install shap
!pip3 install mosestokenizer
!pip3 install truecase

In [1]:
# Standard library
import json
import os
import sys

# Third-party
import numpy as np
import shap
import torch
import truecase
from IPython.display import HTML, display  # IPython.core.display path is deprecated
from mosestokenizer import MosesDetokenizer, MosesTokenizer
from scipy.stats import pearsonr
from tqdm.notebook import tqdm

# Local: the XMover scorer/explainer modules live in ./xmover
sys.path.append('./xmover')
from scorer import XMOVERScorer
from xmover_explainer import ExplainableXMover

In [2]:
# Language pair and split to explain. The split selects the dataset directory;
# the language pair also names the pickled explanations file.
SRC_LANG, TGT_LANG = 'ro', 'en'
SPLIT = 'dev'
# NOTE(review): not referenced by any of the visible cells.
RESULTS_FNAME = 'results.json'

## Load Dataset

In [3]:
data_dir = f'../../data/{SPLIT}/{SRC_LANG}-{TGT_LANG}-{SPLIT}'


def _read_stripped_lines(path):
    """Read a text file as UTF-8 and return its lines with surrounding whitespace removed.

    The ``with`` block closes the file as soon as it has been read, and the
    explicit encoding makes decoding of non-ASCII text independent of the
    platform's locale default.
    """
    with open(path, encoding='utf-8') as fh:
        return [line.strip() for line in fh]


src = _read_stripped_lines(f'{data_dir}/{SPLIT}.src')
tgt = _read_stripped_lines(f'{data_dir}/{SPLIT}.mt')
# Word-level labels: whitespace-separated integers on each line.
wor = [list(map(int, line.split())) for line in _read_stripped_lines(f'{data_dir}/{SPLIT}.tgt-tags')]
# Sentence-level labels: one float per line.
sen = [float(line) for line in _read_stripped_lines(f'{data_dir}/{SPLIT}.da')]
assert len(src) == len(tgt) == len(wor) == len(sen), 'dataset files are not line-aligned'
dataset = {'src': src, 'tgt': tgt, 'word_labels': wor, 'sent_labels': sen}

## Get XMover Explainer to Rate and Explain
This step can take a long time: on a 6-core workstation with a single RTX 2080 GPU, explaining one translation takes around 3 seconds on average (individual sentences can take 10 seconds or more, as the progress output below shows). Explaining all 1000 cases in the dev set therefore takes roughly an hour.

In [4]:
model = ExplainableXMover(SRC_LANG, TGT_LANG)

# Explain every (source, machine translation) pair.
# NOTE(review): explain() presumably returns a sequence of per-token entries whose
# second element is the SHAP value (the scoring cell negates entry[1]); confirm
# against xmover_explainer.
exps = []
for src_sent, tgt_sent in tqdm(zip(dataset['src'], dataset['tgt']), total=len(dataset['src'])):
    # score = model(src_sent, tgt_sent)  # uncomment this line if you also want the xmover-score
    exps.append(model.explain(src_sent, tgt_sent))

  0%|          | 0/1000 [00:00<?, ?it/s]


Permutation explainer: 2it [00:10, 10.36s/it]               [A

Permutation explainer: 2it [00:16, 16.97s/it]               [A

Permutation explainer: 2it [00:16, 16.65s/it]               [A

Permutation explainer: 2it [00:11, 11.49s/it]               [A

Permutation explainer: 2it [00:10, 10.36s/it]               [A

Permutation explainer: 2it [00:10, 10.55s/it]               [A


In [5]:
# Optional: persist the explanations so the slow explain cell need not be re-run.
import pickle

exps_path = f'{SRC_LANG}-{TGT_LANG}_exps.pkl'
with open(exps_path, 'wb') as exps_file:
    pickle.dump(exps, exps_file)

## Evaluate the Quality of the Explanations

In [8]:
# If you have saved some explanations, you can load them instead of re-running
# the explainer. Only load pickles you produced yourself: unpickling can execute
# arbitrary code.
import pickle

with open(f'{SRC_LANG}-{TGT_LANG}_exps.pkl', 'rb') as exps_file:
    exps = pickle.load(exps_file)

In [9]:
# Negate each entry's SHAP value (entry[1]) so that the tokens with the most
# negative contribution rank highest as likely-incorrect tokens.
exp_scores = [[-entry[1] for entry in exp] for exp in exps]

In [10]:
# Make ../.. (presumably the repository root) importable for scripts.evaluate.
# The membership check keeps re-runs of this cell from growing sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')
from scripts.evaluate import evaluate_word_level

assert len(exp_scores) == len(dataset['word_labels']), 'one score list is needed per sentence'
evaluate_word_level(dataset['word_labels'], exp_scores)

AUC score: 0.638
AP score: 0.464
Recall at top-K: 0.339
