# Measuring Context Usage with CXMI

This notebook contains the code to measure CXMI for contextual models trained with this library.

Start by setting the path to the checkpoint of interest. Ideally, this should be a model trained with *dynamic* context size.
We also need to set the context size for which we are measuring CXMI, as well as the source and target languages, which are used to load the sentencepiece models.

In [1]:
model_ckpt="/projects/tir5/users/patrick/checkpoints/iwslt2017/en-fr/one_to_five_1/"
source_context_size=0
target_context_size=1
source_lang="en"
target_lang="fr"

Then load the models and associated files, such as the vocabularies, into memory.

In [2]:
import os
import sentencepiece as sp
import contextual_mt
from fairseq import utils, hub_utils

package = hub_utils.from_pretrained(
    model_ckpt, checkpoint_file="checkpoint_best.pt"
)
models = package["models"]
for model in models:
    model.cuda()
    model.eval()

# load dict, params and generator from task
src_dict = package["task"].src_dict
tgt_dict = package["task"].tgt_dict

# load sentencepiece models (assumes they are in the checkpoint dirs)
# FIXME: is there some way to have it in `package`?
if os.path.exists(os.path.join(model_ckpt, "spm.model")):
    spm = sp.SentencePieceProcessor()
    spm.Load(os.path.join(model_ckpt, "spm.model"))
    src_spm = spm
    tgt_spm = spm
else:
    src_spm = sp.SentencePieceProcessor()
    src_spm.Load(os.path.join(model_ckpt, f"spm.{source_lang}.model"))
    tgt_spm = sp.SentencePieceProcessor()
    tgt_spm.Load(os.path.join(model_ckpt, f"spm.{target_lang}.model"))

## Measuring CXMI

To measure CXMI, we need a held-out dataset. Currently, two types of dataset are supported:

* A standard dataset
* A contrastive dataset

### Standard Dataset 

To measure the CXMI for a standard dataset, define the source, target and docids files.

In [3]:
source_file="/projects/tir1/corpora/dialogue_mt/iwslt2017/en-fr/test.en-fr.en"
target_file="/projects/tir1/corpora/dialogue_mt/iwslt2017/en-fr/test.en-fr.fr"
docids_file="/projects/tir1/corpora/dialogue_mt/iwslt2017/en-fr/test.en-fr.docids"
batch_size=8
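
The three files are line-aligned: line *i* of the docids file gives the document that sentence pair *i* belongs to. As a rough illustration of how such files can be grouped into documents, here is a minimal sketch. The helper `group_by_docid` is hypothetical and is not the library's `parse_documents`, whose exact behaviour may differ; it also assumes that all sentences of a document appear on consecutive lines.

```python
from itertools import groupby


def group_by_docid(srcs, refs, docids):
    """Group line-aligned sentence pairs into documents, keeping corpus order.

    Hypothetical helper for illustration only; assumes that sentences of the
    same document are contiguous in the files.
    """
    rows = zip(docids, srcs, refs)
    return [
        [(src, ref) for _, src, ref in group]
        for _, group in groupby(rows, key=lambda row: row[0])
    ]


# Toy data: two sentences in document 0, one sentence in document 1.
docs = group_by_docid(
    ["Hello.", "How are you?", "Bye."],
    ["Bonjour.", "Comment ça va ?", "Au revoir."],
    [0, 0, 1],
)
print(len(docs))   # number of documents
print(docs[0][1])  # second sentence pair of the first document
```

Each document here is a list of `(source, reference)` tuples, so the sentence preceding position `j` in a document is simply the entry at `j - 1`.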

Then run the following cell to compute the corpus-level CXMI.

In [4]:
from contextual_mt.utils import parse_documents
from contextual_mt.docmt_cxmi import compute_cxmi
import numpy as np

# load files needed
with open(source_file, "r") as src_f:
    srcs = [line.strip() for line in src_f]
with open(docids_file, "r") as docids_f:
    docids = [int(idx) for idx in docids_f]
with open(target_file, "r") as tgt_f:
    refs = [line.strip() for line in tgt_f]

documents = parse_documents(srcs, refs, docids)
sample_cxmis, ids = compute_cxmi(
        documents,
        models,
        src_spm,
        src_dict,
        tgt_spm,
        tgt_dict,
        source_context_size,
        target_context_size,
        batch_size=batch_size
)
print(np.mean(sample_cxmis))

0.009629978
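
The number above is the mean of the per-sample values. Conceptually, each per-sample value is the score the model assigns to the reference when it sees context minus the score it assigns without context, which mirrors the subtraction used in the contrastive cell of this notebook. The sketch below shows that arithmetic on made-up log-probabilities. The function `cxmi_from_scores` is a hypothetical name, not part of `contextual_mt`, and how the library normalises scores internally is not assumed here.

```python
import numpy as np


def cxmi_from_scores(contextual_logprobs, baseline_logprobs):
    """Return per-sample CXMI estimates and their corpus-level mean.

    Hypothetical helper: both inputs are per-sample log-probabilities of the
    same references, scored with and without context respectively.
    """
    contextual = np.asarray(contextual_logprobs, dtype=float)
    baseline = np.asarray(baseline_logprobs, dtype=float)
    per_sample = contextual - baseline
    return per_sample, float(per_sample.mean())


# Made-up scores for three sentences.
per_sample, corpus_cxmi = cxmi_from_scores(
    [-1.0, -2.0, -0.5],  # with context
    [-1.5, -2.0, -1.0],  # without context
)
print(per_sample)
print(corpus_cxmi)
```

A positive value means the context made the reference more likely under the model; a value near zero suggests the context was largely ignored for that sample.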


### Contrastive Dataset

To compute CXMI for either ContraPro or Bawden's contrastive dataset, define the dataset files and run the following cell.

In [24]:
import torch
from contextual_mt.docmt_contrastive_eval import load_contrastive
from contextual_mt.contextual_dataset import collate
from contextual_mt.utils import encode, decode, create_context
from fairseq.sequence_scorer import SequenceScorer

bawden=True
source_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.current.en"
target_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.current.fr"
src_context_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.prev.en"
tgt_context_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.prev.fr"
#source_file="/home/pfernand/repos/ContraPro/contrapro.text.en"
#target_file="/home/pfernand/repos/ContraPro/contrapro.text.de"
#src_context_file="/home/pfernand/repos/ContraPro/contrapro.context.en"
#tgt_context_file="/home/pfernand/repos/ContraPro/contrapro.context.de"


# load files
srcs, all_tgts, tgt_labels, srcs_contexts, tgts_contexts = load_contrastive(
    source_file, target_file, src_context_file, tgt_context_file, dataset="bawden" if bawden else "contrapro"
)

scorer = SequenceScorer(tgt_dict)
sample_cxmis = []
corrects = []
b_corrects = []
for src, src_ctx, contr_tgts, tgt_ctx in zip(srcs, srcs_contexts, all_tgts, tgts_contexts):
    src = encode(src, src_spm, src_dict)
    src_ctx = [encode(ctx, src_spm, src_dict) for ctx in src_ctx]
    contr_tgts = [encode(tgt, tgt_spm, tgt_dict) for tgt in contr_tgts]
    tgt_ctx = [encode(ctx, tgt_spm, tgt_dict) for ctx in tgt_ctx]
    baseline_samples = []
    contextual_samples = []

    for tgt in contr_tgts:
        baseline_src_context = create_context(
            src_ctx,
            0,
            break_id=src_dict.index("<brk>"),
            eos_id=src_dict.eos(),
        )
        baseline_tgt_context = create_context(
            tgt_ctx,
            0,
            break_id=tgt_dict.index("<brk>"),
            eos_id=tgt_dict.eos(),
        )
        contextual_src_context = create_context(
            src_ctx,
            source_context_size,
            break_id=src_dict.index("<brk>"),
            eos_id=src_dict.eos(),
        )
        contextual_tgt_context = create_context(
            tgt_ctx,
            target_context_size,
            break_id=tgt_dict.index("<brk>"),
            eos_id=tgt_dict.eos())

        full_src = torch.cat([src, torch.tensor([src_dict.eos()])])
        full_tgt = torch.cat([tgt, torch.tensor([tgt_dict.eos()])])
        baseline_sample = {
            "id": 0,
            "source": full_src,
            "src_context": baseline_src_context,
            "target": full_tgt,
            "tgt_context": baseline_tgt_context,
        }
        contextual_sample = {
            "id": 0,
            "source": full_src,
            "src_context": contextual_src_context,
            "target": full_tgt,
            "tgt_context": contextual_tgt_context,
        }
        baseline_samples.append(baseline_sample)
        contextual_samples.append(contextual_sample)

    baseline_sample = collate(
        baseline_samples,
        pad_id=src_dict.pad(),
        eos_id=src_dict.eos(),
    )
    contextual_sample = collate(
        contextual_samples,
        pad_id=src_dict.pad(),
        eos_id=src_dict.eos()
    )

    baseline_sample = utils.move_to_cuda(baseline_sample)
    contextual_sample = utils.move_to_cuda(contextual_sample)

    baseline_out = scorer.generate(models, baseline_sample)
    contextual_out = scorer.generate(models, contextual_sample)

    scores = [h[0]["score"] for h in contextual_out]

    most_likely = torch.argmax(torch.stack(scores))
    correct = most_likely == 0
    baseline_correct = torch.argmax(torch.stack([h[0]["score"] for h in baseline_out])) == 0

    b_corrects.append(baseline_correct)
    corrects.append(correct)
    sample_cxmis.append(contextual_out[0][0]["score"].cpu() - baseline_out[0][0]["score"].cpu())

corrects = np.stack([correct.cpu().numpy() for correct in corrects])
b_corrects = np.stack([b_correct.cpu().numpy() for b_correct in b_corrects])
print(np.mean(sample_cxmis))


0.022161597
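
The cell above also fills `corrects` and `b_corrects`, which record whether the model ranked the correct translation (stored at index 0 of each candidate list) above its contrastive alternatives, with and without context. As a small sketch of that ranking rule on plain numbers, assuming higher scores are better and using the hypothetical helper name `contrastive_accuracy`:

```python
import numpy as np


def contrastive_accuracy(all_scores):
    """Fraction of examples whose highest-scoring candidate is index 0.

    Hypothetical helper: `all_scores` holds one list of candidate scores per
    example, with the correct translation always first.
    """
    hits = [int(np.argmax(scores)) == 0 for scores in all_scores]
    return float(np.mean(hits))


# Made-up scores: the correct candidate wins in the first and third example.
accuracy = contrastive_accuracy(
    [[-0.2, -0.9], [-1.1, -0.4], [-0.3, -0.5, -0.8]]
)
print(accuracy)
```

Applying the same rule to the arrays from the cell, `corrects.mean()` and `b_corrects.mean()` would give the contextual and baseline accuracies.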


#### Measuring Correlations

To measure the correlation of the *per-sample* CXMI with the performance on samples that require context, run:

In [27]:
from scipy import stats

binary_vars = np.stack([not b_c and c for b_c, c in zip(b_corrects, corrects)])
print(stats.pointbiserialr(binary_vars, sample_cxmis))

PointbiserialrResult(correlation=0.025460697287605066, pvalue=0.7204408014491894)
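
Here the binary variable marks samples that the model gets wrong without context but right with it. The point-biserial coefficient is the Pearson correlation computed with one binary variable, so it can be read the same way. The sketch below checks that equivalence on made-up data; the arrays are illustrative only and are not taken from the datasets above.

```python
import numpy as np
from scipy import stats

# Made-up data: 1 marks samples fixed by context, 0 marks all other samples.
binary = np.array([0, 0, 1, 1, 0, 1])
values = np.array([0.1, -0.2, 0.4, 0.3, 0.0, 0.5])

r_pb, p_pb = stats.pointbiserialr(binary, values)
r_pearson, p_pearson = stats.pearsonr(binary, values)

print(r_pb, p_pb)
print(r_pearson, p_pearson)
```

A coefficient close to zero with a large p-value, as in the result above, gives no evidence of a linear relationship between per-sample CXMI and whether context fixed the prediction.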


### Analysing Samples

**TODO**

In [5]:
for _, i in sorted(zip(sample_cxmis, ids), reverse=True):
    print(f"current: {documents[i[0]][i[1]]}")
    print(f"context: {documents[i[0]][i[1]-1]}")
    input()

current: ('Genius.', 'Génial.')
context: ("I am learning that it's a genius idea to use a pair of barbecue tongs  to pick up things that you dropped.   I'm learning that nifty trick where you can charge  your mobile phone battery from your chair battery.", "J'apprends que c'est une idée géniale d'utiliser une pince à barbecue pour ramasser les choses qu'on a laissé tomber.  J'apprends ce truc génial pour charger la batterie de son téléphone portable grâce à celle de son fauteuil.")

current: ('This is the thing about symbols.', 'avec les symboles :')
context: ('This is the thing about postmodernism.', "C'est le problème avec le post-modernisme,")


KeyboardInterrupt: Interrupted by user


