# Measuring Context Usage with CXMI

This notebook contains the code to measure CXMI for contextual models trained in this libray

Start by setting the path for your checkpoint of interest. This should ideally be a model trained with *dynamic* context size. 
We also need set context size for which we are measuring CXMI. We also need to set the languages in order to load the sentencepiece models.

In [1]:
model_ckpt="/projects/tir5/users/patrick/checkpoints/iwslt2017/en-fr/one_to_five_sampled_1/"
source_context_size=0
target_context_size=1
source_lang="en"
target_lang="fr"

And then load the models and associated files such as the vocabularies into memory

In [52]:
from fairseq.data import data_utils
from fairseq.sequence_scorer import SequenceScorer

from fairseq.data import data_utils
import tqdm
import numpy as np
from statistics import mean, stdev

import sentencepiece as sp

import contextual_mt
from contextual_mt.contextual_dataset import collate as contextual_collate
from contextual_mt.utils import encode, decode, create_context

from contextual_mt.contextual_dataset import collate

import os
import torch

#baseline_ckpt="/projects/tir5/users/patrick/checkpoints/iwslt2017/en-de/baseline_pretrained_3/"
contextual_ckpt="/projects/tir5/users/patrick/checkpoints/iwslt2017/en-fr/one_to_five_sampled_1/"
use_contrastive=False

In [2]:
from fairseq import utils, hub_utils

package = hub_utils.from_pretrained(
    model_ckpt, checkpoint_file="checkpoint_best.pt"
)
models = package["models"]
for model in models:
    model.cuda()
    model.eval()

# load dict, params and generator from task
src_dict = package["task"].src_dict
tgt_dict = package["task"].tgt_dict

scorer = SequenceScorer(tgt_dict)

# load sentencepiece models (assumes they are in the checkpoint dirs)
# FIXME: is there someway to have it in `package`
if os.path.exists(os.path.join(model_ckpt, "spm.model")):
    spm = sp.SentencePieceProcessor()
    spm.Load(os.path.join(model_ckpt, "spm.model"))
    src_spm = spm
    tgt_spm = spm
else:
    src_spm = sp.SentencePieceProcessor()
    src_spm.Load(os.path.join(model_ckpt, f"spm.{source_lang}.model"))
    tgt_spm = sp.SentencePieceProcessor()
    tgt_spm.Load(os.path.join(model_ckpt, f"spm.{target_lang}.model"))

NameError: name 'contextual_ckpt' is not defined

In [54]:
source_file="/projects/tir1/corpora/dialogue_mt/iwslt2017/en-de/test.en-de.en"
target_file="/projects/tir1/corpora/dialogue_mt/iwslt2017/en-de/test.en-de.de"
docids_file="/projects/tir1/corpora/dialogue_mt/iwslt2017/en-de/test.en-de.docids"
batch_size=4

if not use_contrastive:
    # load files needed
    with open(source_file, "r") as src_f:
        srcs = [line.strip() for line in src_f]
    with open(docids_file, "r") as docids_f:
        docids = [int(idx) for idx in docids_f]
    with open(target_file, "r") as tgt_f:
        refs = [line.strip() for line in tgt_f]

    srcs_encoded = [encode(src_l, src_spm, src_dict) for src_l in srcs]
    refs_encoded = [encode(tgt_l, tgt_spm, tgt_dict) for tgt_l in refs]
    # parse lines into list of documents
    documents = []
    prev_docid = None
    for src_l, tgt_l, idx in zip(srcs, refs, docids):
        if prev_docid != idx:
            documents.append([])
        prev_docid = idx
        documents[-1].append((src_l, tgt_l))

    preds = []
    ids = []
    scores = []
    src_context_lines = [[] for _ in range(batch_size)]
    tgt_context_lines = [[] for _ in range(batch_size)]

    # info necessary to create batches and recreate docs
    doc_idx = 0
    current_docs = [None for _ in range(batch_size)]
    current_docs_ids = [-1 for _ in range(batch_size)]
    current_docs_pos = [0 for _ in range(batch_size)]
    baseline_xes = []
    contextual_xes = []
    total_xmi = 0 
    num_samples = 0
    while True:
        batch_map = []
        batch_targets = []
        baseline_samples = []
        contextual_samples = []
        random_samples = []
        for idx in range(batch_size):
            # if any of the docs in the batch has finished replace by a new one
            if current_docs[idx] is None or current_docs_pos[idx] >= len(
                current_docs[idx]
            ):
                if doc_idx < len(documents):
                    current_docs[idx] = documents[doc_idx]
                    current_docs_ids[idx] = doc_idx
                    current_docs_pos[idx] = 0
                    src_context_lines[idx] = []
                    tgt_context_lines[idx] = []
                    doc_idx += 1
                else:
                    current_docs[idx] = None
                    continue

            src_l, tgt_l = current_docs[idx][current_docs_pos[idx]]

            ids.append((current_docs_ids[idx], current_docs_pos[idx]))

            # binarize source and create input with context and target
            source_noeos = encode(src_l, src_spm, src_dict)
            source = torch.stack([*source_noeos, torch.tensor(src_dict.eos())])
            target_noeos = encode(tgt_l, tgt_spm, tgt_dict)
            target = torch.stack([*target_noeos, torch.tensor(tgt_dict.eos())])

            random_src_pool = [srcs_encoded[idx] for idx in np.random.randint(0, len(srcs), size=source_context_size)]
            random_tgt_pool = [refs_encoded[idx] for idx in np.random.randint(0, len(refs), size=target_context_size)]

            baseline_src_context = create_context(
                src_context_lines[idx],
                0,
                break_id=src_dict.index("<brk>"),
                eos_id=src_dict.eos(),
            )
            baseline_tgt_context = create_context(
                tgt_context_lines[idx],
                0,
                break_id=tgt_dict.index("<brk>"),
                eos_id=tgt_dict.eos(),
            )
            random_src_context = create_context(
                random_src_pool,
                source_context_size,
                break_id=src_dict.index("<brk>"),
                eos_id=src_dict.eos()
            )
            random_tgt_context = create_context(
                random_tgt_pool,
                target_context_size,
                break_id=tgt_dict.index("<brk>"),
                eos_id=tgt_dict.eos()
            )

            random_samples.append(
                {
                    "id": 0,
                    "src_context": random_src_context,
                    "source": source,
                    "tgt_context": random_tgt_context,
                    "target": target
                }
            )
            baseline_samples.append(
                {
                    "id": 0,
                    "src_context": baseline_src_context,
                    "source": source,
                    "tgt_context": baseline_tgt_context,
                    "target": target
                }
            )

            contextual_src_context = create_context(
                src_context_lines[idx],
                source_context_size,
                break_id=src_dict.index("<brk>"),
                eos_id=src_dict.eos(),
            )
            contextual_tgt_context = create_context(
                tgt_context_lines[idx],
                target_context_size,
                break_id=tgt_dict.index("<brk>"),
                eos_id=tgt_dict.eos(),
            )
            contextual_samples.append(
                {
                    "id": 0,
                    "src_context": contextual_src_context,
                    "source": source,
                    "tgt_context": contextual_tgt_context,
                    "target": target
                }
            )

            src_context_lines[idx].append(source_noeos)
            tgt_context_lines[idx].append(target_noeos)

            current_docs_pos[idx] += 1

        # while exit condition
        if all(chat is None for chat in current_docs):
            break

        # create batch
        baseline_sample = collate(
            baseline_samples, src_dict.pad(), src_dict.eos()
        )
        baseline_sample = utils.move_to_cuda(baseline_sample)
        contextual_sample = collate(
            contextual_samples, src_dict.pad(), src_dict.eos()
        )
        contextual_sample = utils.move_to_cuda(contextual_sample)

        baseline_output = scorer.generate(contextual_models, baseline_sample)
        contextual_output = scorer.generate(contextual_models, contextual_sample)
        for batch_idx in range(len(baseline_samples)):
            # decode hypothesis
            baseline_xes.append(baseline_output[batch_idx][0]["score"].cpu())
            contextual_xes.append(contextual_output[batch_idx][0]["score"].cpu())
            total_xmi+=baseline_xes[-1] - contextual_xes[-1]
            num_samples+=1
            
    all_baseline_xes.append(baseline_xes)
    all_contextual_xes.append(contextual_xes)

In [55]:
from contextual_mt.docmt_contrastive_eval import load_contrastive

bawden=True
if bawden:
    source_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.current.en"
    target_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.current.fr"
    src_context_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.prev.en"
    tgt_context_file="/home/pfernand/repos/discourse-mt-test-sets/test-sets/lexical_choice.prev.fr"
else:
    source_file="/home/pfernand/repos/ContraPro/contrapro.text.en"
    target_file="/home/pfernand/repos/ContraPro/contrapro.text.de"
    src_context_file="/home/pfernand/repos/ContraPro/contrapro.context.en"
    tgt_context_file="/home/pfernand/repos/ContraPro/contrapro.context.de"


if use_contrastive:
    # load files
    srcs, all_tgts, tgt_labels, srcs_contexts, tgts_contexts = load_contrastive(
        source_file, target_file, src_context_file, tgt_context_file, dataset="bawden" if bawden else "contrapro"
    )
    baseline_xes = []
    contextual_xes = []
    corrects = []
    b_corrects = []
    for src, src_ctx, contr_tgts, tgt_ctx in zip(srcs, srcs_contexts, all_tgts, tgts_contexts):
        src = encode(src, src_spm, src_dict)
        src_ctx = [encode(ctx, src_spm, src_dict) for ctx in src_ctx]
        contr_tgts = [encode(tgt, tgt_spm, tgt_dict) for tgt in contr_tgts]
        tgt_ctx = [encode(ctx, tgt_spm, tgt_dict) for ctx in tgt_ctx]
        baseline_samples = []
        contextual_samples = []

        for tgt in contr_tgts:
            baseline_src_context = create_context(
                src_ctx,
                0,
                break_id=src_dict.index("<brk>"),
                eos_id=src_dict.eos(),
            )
            baseline_tgt_context = create_context(
                tgt_ctx,
                0,
                break_id=tgt_dict.index("<brk>"),
                eos_id=tgt_dict.eos(),
            )
            contextual_src_context = create_context(
                src_ctx,
                source_context_size,
                break_id=src_dict.index("<brk>"),
                eos_id=src_dict.eos(),
            )
            contextual_tgt_context = create_context(
                tgt_ctx,
                target_context_size,
                break_id=tgt_dict.index("<brk>"),
                eos_id=tgt_dict.eos())

            full_src = torch.cat([src, torch.tensor([src_dict.eos()])])
            full_tgt = torch.cat([tgt, torch.tensor([tgt_dict.eos()])])
            baseline_sample = {
                "id": 0,
                "source": full_src,
                "src_context": baseline_src_context,
                "target": full_tgt,
                "tgt_context": baseline_tgt_context,
            }
            contextual_sample = {
                "id": 0,
                "source": full_src,
                "src_context": contextual_src_context,
                "target": full_tgt,
                "tgt_context": contextual_tgt_context,
            }
            baseline_samples.append(baseline_sample)
            contextual_samples.append(contextual_sample)

        baseline_sample = contextual_collate(
            baseline_samples,
            pad_id=src_dict.pad(),
            eos_id=src_dict.eos(),
        )
        contextual_sample = contextual_collate(
            contextual_samples,
            pad_id=src_dict.pad(),
            eos_id=src_dict.eos()
        )

        baseline_sample = utils.move_to_cuda(baseline_sample)
        contextual_sample = utils.move_to_cuda(contextual_sample)

        baseline_out = scorer.generate(contextual_models, baseline_sample)
        contextual_out = scorer.generate(contextual_models, contextual_sample)

        scores = [h[0]["score"] for h in contextual_out]

        most_likely = torch.argmax(torch.stack(scores))
        correct = most_likely == 0
        baseline_correct = torch.argmax(torch.stack([h[0]["score"] for h in baseline_out])) == 0

        b_corrects.append(baseline_correct)
        corrects.append(correct)
        baseline_xes.append(baseline_out[0][0]["score"])
        contextual_xes.append(contextual_out[0][0]["score"])

    all_baseline_xes.append(baseline_xes)
    all_contextual_xes.append(contextual_xes)
    corrects = np.stack([correct.cpu().numpy() for correct in corrects])
    b_corrects = np.stack([b_correct.cpu().numpy() for b_correct in b_corrects])




In [56]:
print(len(all_baseline_xes))
def ensemble_xes_average(all_xes):
    all_xes = zip(*all_xes)
    avg_xes = []
    for samples in all_xes:
        avg_xes.append(np.log(np.mean([np.exp(xe.cpu()) for xe in samples])))
    return avg_xes

def calculate_xmi(model1_xes, model2_xes):
    total_xmi = 0
    num_samples = 0
    for m1_xe, m2_xe in zip(model1_xes, model2_xes):
        total_xmi += m1_xe - m2_xe
        num_samples += 1
    return -total_xmi/num_samples

avg_baseline_xes = ensemble_xes_average(all_baseline_xes)
avg_contextual_xes = ensemble_xes_average(all_contextual_xes)
print(calculate_xmi(avg_baseline_xes, avg_contextual_xes))

1
-0.7144931929047442


In [57]:
import scipy
print(f"Total Acc: {np.stack(corrects).mean().item()}")
print(f"Total Acc: {np.stack(b_corrects).mean().item()}")
print(scipy.stats.pointbiserialr(np.stack(corrects), [cxe - bxe for cxe, bxe in zip(avg_contextual_xes, avg_baseline_xes)]))
print(scipy.stats.pointbiserialr(np.stack([not b_c and c for b_c, c in zip(b_corrects, corrects)]), [cxe - bxe for cxe, bxe in zip(avg_contextual_xes, avg_baseline_xes)]))
print(scipy.stats.pearsonr(avg_contextual_xes, [cxe - bxe for cxe, bxe in zip(avg_contextual_xes, avg_baseline_xes)]))


Total Acc: 0.57
Total Acc: 0.5


ValueError: x and y must have the same length.

In [25]:
m1_xmes=[cxe - bxe for cxe, bxe in zip(avg_contextual_xes, avg_baseline_xes)]
#m2_xmes=[cxe - bxe for cxe, bxe in zip(avg_contextual_xes, avg_baseline_xes)]


for _, i in sorted([(x2 - x1, i) for x1, x2, i in zip(m1_xmes, m2_xmes, ids)], reverse=True):
    print(f"current: {documents[i[0]][i[1]]}")
    print(f"context: {documents[i[0]][i[1]-1]}")
    input()



NameError: name 'm1_xmes' is not defined