In [None]:
#%pip install transformers unbabel-comet sacrebleu statsmodels nbformat plotly bitsandbytes
#%pip uninstall -y torchvision
%pip uninstall -y inseq
%pip install git+https://github.com/inseq-team/inseq.git

In [None]:
import datasets

# Only the SCAT test split is loaded; the FLORES and IWSLT17 loaders are left commented out.
# NOTE(review): later cells call translate()/evaluate() with dataset="flores" or "iwslt17".
# Those calls need the `flores` / `iwslt` objects and the matching `ds` entries that are
# commented out here, so re-enable them before running those cells.
#flores = datasets.load_dataset("gsarti/flores_101", "all", split="devtest")
#iwslt = datasets.load_dataset("gsarti/iwslt2017_context", "iwslt2017-en-fr", split="test")
scat = datasets.load_dataset("inseq/scat", split="test", verification_mode="no_checks")#, download_mode="force_redownload")

# Registry mapping a dataset name to its loaded split; translate() looks datasets up here.
ds = {"scat": scat}#"flores": flores, "iwslt17": iwslt, "scat": scat}

In [None]:
from collections import Counter, OrderedDict
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from random import random

# Used in dataset preprocessing to sample the number of context sentences to include
class OrderedCounter(Counter, OrderedDict):
    """A Counter whose keys keep the order in which they were first seen."""

    def __repr__(self):
        # Render as ClassName(OrderedDict(...)) so the ordering is visible in the repr.
        cls_name = type(self).__name__
        return f"{cls_name}({OrderedDict(self)!r})"

    def __reduce__(self):
        # Pickle as (class, (ordered mapping,)) so the order survives a round trip.
        return type(self), (OrderedDict(self),)

def get_preprocess_dataset(ctx_size, dataset="flores", src_lang="eng"):
    """Build the batched preprocessing function that prepends context to source sentences.

    The returned callable takes a batch (a dict of column name -> list) and returns
    ``{"sentence": [...]}`` where each entry is either the bare source sentence or
    ``"<context><brk> <sentence>"``.

    Args:
        ctx_size: Maximum number of preceding sentences to use as context for
            "flores"/"iwslt17". For "scat" the context column is used as-is whenever
            ``ctx_size > 0`` and dropped when it is 0.
        dataset: One of "flores", "iwslt17" or "scat".
        src_lang: Source language code; its first 3 characters select the FLORES column
            and its first 2 characters select the IWSLT/SCAT fields.

    Returns:
        A function suitable for a batched map over the dataset.

    Raises:
        ValueError: If ``dataset`` is not supported. Previously this case fell through
            and returned ``None``.
    """
    if dataset not in ("flores", "iwslt17", "scat"):
        raise ValueError(f"Not available: {dataset}")

    def preprocess_dataset_seq(examples):
        # Sequential datasets: context is built from the previous sentences of the batch.
        if dataset == "flores":
            inputs = examples[f"sentence_{src_lang[:3]}"]
            doc_keys = examples["URL"]
        elif dataset == "iwslt17":
            inputs = [ex[src_lang[:2]] for ex in examples["translation"]]
            doc_keys = examples["doc_id"]
        else:
            raise ValueError(f"Not available: {dataset}")
        # Position of every sentence inside its document, derived from per-document counts.
        # NOTE: counts are computed per batch and assume the sentences of a document are
        # contiguous and in order; a document split across two batches restarts at 0.
        n_previous = [i for _, v in OrderedCounter(doc_keys).items() for i in range(v)]
        # Never ask for more context than there are preceding sentences in the document.
        n_contexts = [min(ctx_size, n_previous[idx]) for idx in range(len(inputs))]
        context_inputs = []
        for idx, sentence in enumerate(inputs):
            if n_contexts[idx] > 0:
                ctx = " ".join(inputs[idx - n_contexts[idx]:idx])
                context_inputs.append(f"{ctx}<brk> {sentence}")
            else:
                context_inputs.append(sentence)
        return {"sentence": context_inputs}

    def preprocess_dataset_merged(examples):
        # SCAT ships the context already merged in a dedicated column.
        inputs = examples[src_lang[:2]]
        contexts = examples[f"context_{src_lang[:2]}"]
        if ctx_size > 0:
            context_inputs = [f"{ctx}<brk> {sentence}" for ctx, sentence in zip(contexts, inputs)]
        else:
            context_inputs = list(inputs)
        return {"sentence": context_inputs}

    if dataset in ("flores", "iwslt17"):
        return preprocess_dataset_seq
    return preprocess_dataset_merged

def encode(examples, tokenizer):
    """Tokenize the "sentence" column of a batch with truncation and max-length padding."""
    sentences = examples["sentence"]
    tokenizer_options = {"truncation": True, "padding": 'max_length'}
    return tokenizer(sentences, **tokenizer_options)

In [None]:
import torch
import os
from tqdm import tqdm

base_path = "translations/translations"

def translate(cwd, ctx, model_type, dataset="flores", src_lang="eng", use_context=True, has_lang_tag=False, model_name: str = None):
    """Translate a whole dataset with one checkpoint and write one translation per line.

    Output goes to ``<base_path>/<dataset>-<model_id>[-noctx].txt`` where ``base_path`` is
    the module-level variable read at call time.

    Args:
        cwd: Value used for the ``cwd`` part of the default checkpoint name.
        ctx: Value used for the ``ctx`` part of the default checkpoint name; also passed
            as the context size for preprocessing when ``use_context`` is True.
        model_type: Model family tag. With ``model_name=None`` the checkpoint is
            ``context-mt/iwslt17-<model_type>-ctx<ctx>-cwd<cwd>-en-fr``. It also selects
            the batch size (8 for "marian-small", 4 for "marian-big", 1 otherwise).
        dataset: Key into the module-level ``ds`` registry.
        src_lang: Source language code forwarded to the preprocessing function.
        use_context: When False the inputs are built with context size 0 and the output
            file gets a ``-noctx`` suffix.
        has_lang_tag: When True the tokenizer is loaded with en_XX/fr_XX language codes
            and generation is forced to start with the ``fr_XX`` token.
        model_name: Explicit checkpoint name; when given, ``model_type`` is used verbatim
            as the identifier in the output file name.
    """
    if model_name is None:
        model_id = f"{model_type}-ctx{ctx}-cwd{cwd}"
        model_name = f"context-mt/iwslt17-{model_id}-en-fr"
    else:
        model_id = model_type
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if has_lang_tag:
        tok = AutoTokenizer.from_pretrained(model_name, src_lang="en_XX", tgt_lang="fr_XX")
    else:
        tok = AutoTokenizer.from_pretrained(model_name)
    preproc_fn = get_preprocess_dataset(ctx if use_context else 0, dataset=dataset, src_lang=src_lang)
    data_preproc = ds[dataset].map(preproc_fn, batched=True, batch_size = 2000, remove_columns=ds[dataset].column_names)
    data_tokenized = data_preproc.map(lambda x: encode(x, tok), batched=True, remove_columns=["sentence"])
    data_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    batch_size = 8 if "marian-small" in model_type else 4 if "marian-big" in model_type else 1
    dataloader = torch.utils.data.DataLoader(data_tokenized, batch_size=batch_size)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval().to(device)
    # The forced BOS id does not depend on the batch, so resolve it once.
    generate_kwargs = {}
    if has_lang_tag:
        generate_kwargs["forced_bos_token_id"] = tok.lang_code_to_id["fr_XX"]
    print(f"Translating...")
    os.makedirs(base_path, exist_ok=True)
    out_path = os.path.join(base_path, f"{dataset}-{model_id}{'-noctx' if not use_context else ''}.txt")
    # Open in write mode: every call translates the dataset from the start, so appending
    # to a file left by an earlier run would duplicate lines and break the line-by-line
    # alignment with the references that evaluation relies on.
    with open(out_path, 'w', encoding="utf-8") as f:
        for i, batch in enumerate(tqdm(dataloader)):
            batch = {k: v.to(device) for k, v in batch.items()}
            out = model.generate(**batch, **generate_kwargs)
            if use_context:
                # Keep special tokens so that "<brk>" survives decoding, then strip the
                # padding / end-of-sequence / language markers by hand.
                translations = tok.batch_decode(out.to("cpu"), skip_special_tokens=False)
                translations = [t.replace("<pad>", "").replace("</s>", "").replace("fr_XX", "").strip() for t in translations]
            else:
                translations = tok.batch_decode(out.to("cpu"), skip_special_tokens=True)
            if i == 0:
                # Show a couple of outputs as a quick sanity check.
                print(translations[:2])
            for trans in translations:
                f.write(trans + "\n")

# Checkpoint names keyed by the model-type tag used in output file names.
# A later cell passes each value as `model_name` to translate().
default_models = {
    "marian-small": "Helsinki-NLP/opus-mt-en-fr",
    "marian-big": "Helsinki-NLP/opus-mt-tc-big-en-fr",
    "mbart50-1toM": "facebook/mbart-large-50-one-to-many-mmt"
}

In [None]:
# Translate with the default `context-mt/iwslt17-*` Marian checkpoints, once per dataset:
#   * a cwd sweep (cwd = 0..5) at ctx = 4
#   * a ctx sweep (even ctx values) at cwd = 0
# NOTE(review): the first two loops use translate()'s default dataset ("flores"), and the
# "flores"/"iwslt17" runs require those datasets to be present in `ds`.
# NOTE(review): the (cwd=0, ctx=4) configuration is covered by both sweeps on "flores"
# and "scat", so that checkpoint is translated twice into the same output file.

# FLORES: cwd sweep
for cwd in range(6):
    for model_type in ["marian-small", "marian-big"]:
        translate(cwd, 4, model_type)

# FLORES: ctx sweep
for ctx in range(0, 11, 2):
    for model_type in ["marian-small", "marian-big"]:
        translate(0, ctx, model_type)

# IWSLT17: cwd sweep
for cwd in range(6):
    for model_type in ["marian-small", "marian-big"]:
        translate(cwd, 4, model_type, dataset="iwslt17")

# IWSLT17: ctx sweep, starting at ctx = 6 here
for ctx in range(6, 11, 2):
    for model_type in ["marian-small", "marian-big"]:
        translate(0, ctx, model_type, dataset="iwslt17")

# SCAT: cwd sweep
for cwd in range(6):
    for model_type in ["marian-small", "marian-big"]:
        translate(cwd, 4, model_type, dataset="scat")

# SCAT: ctx sweep
for ctx in range(0, 11, 2):
    for model_type in ["marian-small", "marian-big"]:
        translate(0, ctx, model_type, dataset="scat")

In [None]:
# Translate with the `context-mt/iwslt17-mbart50-1toM-ctx4-cwd1-en-fr` checkpoint
# (ctx=4, cwd=1). Only the IWSLT17 run is enabled; the FLORES and SCAT calls are commented out.
for ctx, cwd in zip([4], [1]):#[0, 4], [0, 1]):
    #translate(cwd, ctx, "mbart50-1toM", has_lang_tag=True)
    translate(cwd, ctx, "mbart50-1toM", has_lang_tag=True, dataset="iwslt17")
    #translate(cwd, ctx, "mbart50-1toM", has_lang_tag=True, dataset="scat")

In [None]:
# Translate without context using the checkpoints listed in `default_models`.
# The inner `if` restricts this run to the mBART-50 entry on SCAT.
# cwd/ctx are passed as 0 but are not used for naming because `model_name` is given.
for model_type, model_name in default_models.items():
    if model_type == "mbart50-1toM":
        for dataset in ["scat"]:#["flores"]:#ds.keys():
            translate(0, 0, model_type, has_lang_tag=model_type == "mbart50-1toM", use_context=False, dataset=dataset, model_name=model_name)

In [None]:
# Translate SCAT and FLORES with the `context-mt/scat-marian-small-ctx4-cwd1-en-fr`
# checkpoint; the marian-big equivalents are commented out.
# NOTE(review): the "flores" call requires "flores" to be present in `ds`.
#translate(1, 4, "marian-big-scat", dataset="scat", model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr")
#translate(1, 4, "marian-big-scat", dataset="flores", model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr")
translate(1, 4, "marian-small-scat", dataset="scat", model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr")
translate(1, 4, "marian-small-scat", dataset="flores", model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr")

In [None]:
# Redirect translate() output to the no-context folder. `base_path` is a module-level
# variable that translate() reads at call time, so this affects every later call until
# it is reassigned.
base_path = "translations/translations_noctx"

#for cwd in range(6):
#    for model_type in ["marian-small", "marian-big"]:
#        translate(cwd, 4, model_type, dataset="scat", use_context=False)
#
#for ctx in range(0, 11, 2):
#    for model_type in ["marian-small", "marian-big"]:
#        translate(0, ctx, model_type, dataset="scat", use_context=False)
#
#for ctx, cwd in zip([4], [1]):#[0, 4], [0, 1]):
#    translate(cwd, ctx, "mbart50-1toM", has_lang_tag=True, use_context=False, dataset="scat")

#translate(1, 4, "marian-big-scat", dataset="scat", use_context=True, model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr")
#translate(1, 4, "marian-small-scat", dataset="scat", use_context=True, model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr")
#translate(1, 4, "marian-big-scat", dataset="scat", use_context=False, model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr")
# Only this run is enabled: the SCAT marian-small checkpoint on SCAT inputs without context.
translate(1, 4, "marian-small-scat", dataset="scat", use_context=False, model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr")

In [None]:
# Translate SCAT with the "*-target-*" checkpoints; only the mBART-50 run is enabled.
# NOTE(review): output goes to whatever `base_path` currently holds, which depends on
# which earlier cells were executed.
#translate(0, 4, "marian-small-scat-target", dataset="scat", use_context=True, model_name="context-mt/scat-marian-small-target-ctx4-cwd0-en-fr")
#translate(0, 4, "marian-big-scat-target", dataset="scat", use_context=True, model_name="context-mt/scat-marian-big-target-ctx4-cwd0-en-fr")
translate(0, 4, "mbart50-1toM-scat-target", dataset="scat", use_context=True, has_lang_tag=True, model_name="context-mt/scat-mbart50-1toM-target-ctx4-cwd0-en-fr")

In [None]:
# Translate SCAT with the SCAT mBART-50 checkpoint, with and then without context.
# `base_path` is switched per iteration so each run lands in the matching folder;
# it is left at the "_noctx" folder once the loop ends.
for use_context in [True, False]:
    base_path = f"translations/translations{'_noctx' if not use_context else ''}"
    #translate(1, 4, "marian-small-scat", dataset="scat", use_context=use_context, model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr")
    #translate(1, 4, "marian-big-scat", dataset="scat", use_context=use_context, model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr")
    translate(1, 4, "mbart50-1toM-scat", dataset="scat", use_context=use_context, has_lang_tag=True, model_name="context-mt/scat-mbart50-1toM-ctx4-cwd1-en-fr")

In [None]:
# No-context runs for the "*-target-*" checkpoints. All translate() calls are commented
# out, so executing this cell only reassigns `base_path`.
base_path = "translations/translations_noctx"
#translate(0, 4, "marian-small-scat-target", dataset="scat", use_context=False, model_name="context-mt/scat-marian-small-target-ctx4-cwd0-en-fr")
#translate(0, 4, "marian-big-scat-target", dataset="scat", use_context=False, model_name="context-mt/scat-marian-big-target-ctx4-cwd0-en-fr")
#translate(0, 4, "mbart50-1toM-scat-target", dataset="scat", use_context=False, has_lang_tag=True, model_name="context-mt/scat-mbart50-1toM-target-ctx4-cwd0-en-fr")

In [9]:
from typing import List, Tuple
import re

def get_aligned_gender_annotations(ref_text, contrast_ref_text, mt_text) -> List[int]:
    """Check whether the words that distinguish the reference from its contrast appear in the MT output.

    The reference and the contrastive reference are split into word tokens and compared
    position by position; every reference token that differs from the contrast token at
    the same position is a keyword. Keywords are then looked up case-insensitively in the
    MT output, and each MT token can satisfy at most one keyword.

    Args:
        ref_text: Reference translation.
        contrast_ref_text: Contrastive reference translation.
        mt_text: Machine translation output. Any non-string value is treated as a
            missing translation.

    Returns:
        One entry per keyword: 1 if it was found in the MT output, 0 otherwise.
        A non-string ``mt_text`` yields ``[0]`` regardless of the number of keywords.
    """
    word_pattern = r'\w+\b'
    ref_tok = re.findall(word_pattern, ref_text)
    contrast_ref_tok = re.findall(word_pattern, contrast_ref_text)
    if not isinstance(mt_text, str):
        return [0]
    mt_tok = [x.lower() for x in re.findall(word_pattern, mt_text)]
    # zip() stops at the shorter token list, so trailing unmatched tokens are ignored.
    keywords = [ref for ref, con in zip(ref_tok, contrast_ref_tok) if ref != con]
    out = []
    for kw in keywords:
        lowered = kw.lower()
        if lowered in mt_tok:
            out.append(1)
            # Consume the matched token so repeated keywords need repeated matches.
            mt_tok.remove(lowered)
        else:
            out.append(0)
    return out

In [None]:
from sacrebleu.metrics import BLEU
import os
from tqdm import tqdm

from comet import download_model, load_from_checkpoint

# Download and load the COMET checkpoint; evaluate(metric="comet") uses this
# module-level `model` object.
model_path = download_model("Unbabel/wmt22-comet-da")
model = load_from_checkpoint(model_path)

# Raise every logger registered so far to WARNING to cut down the output noise.
import logging
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
    logger.setLevel(logging.WARNING)

def evaluate(cwd = 0, ctx = 0, model_type = "", dataset="flores", src_lang="eng", tgt_lang="fra", metric="bleu", use_context=True, model_name = None, use_target_context = False):
    """Score a saved translation file against the references and print the result.

    The hypotheses are read from
    ``translations/translations[_noctx]/<dataset>-<model_id>[-noctx].txt``.

    Args:
        cwd, ctx: Used only to build the default model identifier when ``model_name``
            is None.
        model_type: Model family tag; used verbatim as the identifier when
            ``model_name`` is given.
        dataset: "flores", "iwslt17" or "scat". The matching module-level dataset
            object (``flores``, ``iwslt`` or ``scat``) must be defined.
        src_lang, tgt_lang: Language codes; the first 3 characters select FLORES columns
            and the first 2 characters select IWSLT/SCAT fields.
        metric: "bleu", "comet" or "accuracy" ("accuracy" is SCAT-only).
        use_context: Selects the output folder and the ``-noctx`` file suffix.
        model_name: Only checked for being None; the checkpoint itself is not loaded here.
        use_target_context: When True, everything up to the first "<brk>" is dropped
            from each hypothesis before scoring.

    Raises:
        ValueError: For an unsupported dataset, or for ``metric="accuracy"`` on a
            dataset other than "scat".
    """
    if model_name is None:
        model_id = f"{model_type}-ctx{ctx}-cwd{cwd}"
    else:
        model_id = model_type
    if dataset == "flores":
        src = flores[f"sentence_{src_lang[:3]}"]
        refs = flores[f"sentence_{tgt_lang[:3]}"]
    elif dataset == "iwslt17":
        src = [ex[src_lang[:2]] for ex in iwslt["translation"]]
        refs = [ex[tgt_lang[:2]] for ex in iwslt["translation"]]
    elif dataset == "scat":
        src = scat[src_lang[:2]]
        refs = scat[tgt_lang[:2]]
    else:
        # Fail with a clear message instead of a NameError on `src` further down.
        raise ValueError(f"Not available: {dataset}")
    base_path = "translations/translations" if use_context else "translations/translations_noctx"
    run_name = f"{model_id}{'' if use_context else '-noctx'}"
    with open(os.path.join(base_path, f"{dataset}-{run_name}.txt"), 'r', encoding="utf-8") as f:
        # Drop the line terminator so it does not leak into the scored text.
        hyps = [line.rstrip("\n") for line in f]
    if len(hyps) != len(refs):
        # zip() below would silently truncate to the shorter side; make that visible.
        print(f"WARNING: {dataset} {run_name}: {len(hyps)} translations for {len(refs)} references")
    if use_target_context:
        # Remove target context for evaluation
        hyps = [s.split("<brk>")[1].strip() if "<brk>" in s else s for s in hyps]
    if metric == "comet":
        comet_out = model.predict([{"src": s, "mt": m, "ref": r} for s, m, r in zip(src, hyps, refs)], batch_size=8, gpus=1)
        print(dataset, run_name, "COMET", comet_out.system_score)
    if metric == "bleu":
        bleu = BLEU()
        print(dataset, run_name, bleu.corpus_score(hyps, [refs]))
    if metric == "accuracy":
        if dataset != "scat":
            raise ValueError("Only scat dataset supports accuracy metric")
        tot_keywords, tot_correct = 0, 0
        for curr_ref, curr_ref_contrast, curr_mt in tqdm(zip(refs, scat[f"contrast_{tgt_lang[:2]}"], hyps), desc="Aligned accuracy", total=len(refs)):
            matches = get_aligned_gender_annotations(curr_ref, curr_ref_contrast, curr_mt)
            tot_keywords += len(matches)
            tot_correct += sum(1 for x in matches if x == 1)
        # Guard against an empty keyword set (e.g. an empty translation file).
        accuracy = round(tot_correct / tot_keywords, 4) if tot_keywords else float("nan")
        print(dataset, run_name, "align_match_acc", accuracy)

In [None]:
# Evaluate the saved translations for the cwd/ctx sweeps, the mBART (ctx=4, cwd=1) run and
# the default models. The commented-out list literals show the wider settings; as written
# the loops only cover use_context=True, dataset="scat" and metric="accuracy".
for use_context in [True]:#[True, False]:
    # NOTE(review): `data` is assigned but not used, because the dataset loop below
    # iterates over a hard-coded ["scat"].
    if not use_context:
        data = ["scat"]
    else:
        data = ["iwslt17", "flores", "scat"]
    for dataset in ["scat"]:#data:
        for model_type in ["marian-small", "marian-big", "mbart50-1toM"]:
            for metric in ["accuracy"]:#["bleu", "comet"]:
                # cwd sweep at ctx = 4
                for cwd in range(6):
                    evaluate(cwd, 4, model_type, dataset=dataset, metric=metric, use_context=use_context)
                print("*" * 20)
                # ctx sweep at cwd = 0
                for ctx in range(0, 11, 2):
                    evaluate(0, ctx, model_type, dataset=dataset, metric=metric, use_context=use_context)
            print("-" * 20)
        print("=" * 20)
        if dataset != "iwslt17":
            for metric in ["accuracy"]:#["bleu", "comet"]:
                for ctx, cwd in zip([4], [1]):
                    evaluate(cwd, ctx, "mbart50-1toM", dataset=dataset, metric=metric, use_context=use_context)
                print("*" * 20)
            print("=" * 20)
        if use_context:
            # Default models: passing model_name makes evaluate() use model_type as the file id.
            for model_type, model_name in default_models.items():
                if model_type == "mbart50-1toM" and dataset == "iwslt17":
                    continue
                for metric in ["accuracy"]:#["bleu", "comet"]:
                    evaluate(0,0, model_type, dataset=dataset, metric=metric, use_context=use_context, model_name=model_name)
                print("*" * 20)
    print("")

In [None]:
# Evaluate the "*-target-*" checkpoints on SCAT with all three metrics.
# use_target_context=True makes evaluate() drop the text before "<brk>" in each hypothesis.
for metric in ["bleu", "comet", "accuracy"]:
    for use_context in [True]:#[True, False]:
        evaluate(0, 4, "marian-small-scat-target", dataset="scat", metric=metric, use_context=use_context, use_target_context=True, model_name="context-mt/scat-marian-small-target-ctx4-cwd0-en-fr")
        evaluate(0, 4, "marian-big-scat-target", dataset="scat", metric=metric, use_context=use_context, use_target_context=True, model_name="context-mt/scat-marian-big-target-ctx4-cwd0-en-fr")
        evaluate(0, 4, "mbart50-1toM-scat-target", dataset="scat", metric=metric, use_context=use_context, use_target_context=True, model_name="context-mt/scat-mbart50-1toM-target-ctx4-cwd0-en-fr")

In [None]:
# Evaluate the no-context translations of the default models on SCAT.
# The loop variable `use_context` is not used: the call passes use_context=False directly.
for metric in ["bleu", "comet", "accuracy"]:
    for use_context in [False]:#[True, False]:
        for model_type, model_name in default_models.items():
            evaluate(0, 0, model_type, dataset="scat", metric=metric, use_context=False, use_target_context=False, model_name=model_name)

In [None]:
# Evaluate the SCAT checkpoints (ctx=4, cwd=1) on SCAT with all three metrics.
# The cwd/ctx positional arguments are ignored by evaluate() whenever model_name is
# given, so the 0 passed for the mBART call does not change which file is read.
for metric in ["bleu", "comet", "accuracy"]:
    for use_context in [True]:#[True, False]:
        evaluate(1, 4, "marian-small-scat", dataset="scat", metric=metric, use_context=use_context, use_target_context=False, model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr")
        evaluate(1, 4, "marian-big-scat", dataset="scat", metric=metric, use_context=use_context, use_target_context=False, model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr")
        evaluate(0, 4, "mbart50-1toM-scat", dataset="scat", metric=metric, use_context=use_context, use_target_context=False, model_name="context-mt/scat-mbart50-1toM-ctx4-cwd1-en-fr")

Next steps:

- ~~Compute `comet` scores for all models on the three datasets (already implemented, only need running)~~

- Translate with mBART before and after fine-tuning, and evaluate its performance on the three datasets.

- Produce `noctx` translations setting context to 0 for all available models on the `scat` test set - this will be used as reference output for each model when no additional context is provided.

- Extract context reference translation for every example. This is used as input to compute the context sensitivity metrics alongside the contrast sources (see call to `model.attribute` below)

- Compute and save context-sensitivity scores for every token in every example in SCAT test set. Settings: P-CXMI, full KL-Div, truncated KL-Div at 90% probability. Format: dict `{"words": list of tokens, "scores": list of floats}`

- Compute Accuracy/AUPRC for various threshold values on the target pronoun for available models, evaluate results.

In [1]:
import re
from typing import List, Tuple

def tokenize(text: str, is_tagged: bool = False):
    """Split text into word and punctuation tokens, dropping whitespace.

    The text is first cut into whitespace-delimited chunks, then each chunk is cut into
    word / non-word pieces. With ``is_tagged=True`` the markers ``<p>``, ``</p>``,
    ``<hon>`` and ``<hoff>`` are kept as single tokens.
    """
    if is_tagged:
        chunk_pattern = r"(<p>|</p>|<hon>|<hoff>|\S+)"
        piece_pattern = r"(<p>|</p>|<hon>|<hoff>|\w+)"
    else:
        chunk_pattern = r"(\S+)"
        piece_pattern = r"(\w+)"
    tokens = []
    for chunk in re.split(chunk_pattern, text):
        for piece in re.split(piece_pattern, chunk):
            # re.split yields empty / whitespace-only separators; skip them.
            if piece.strip():
                tokens.append(piece)
    return tokens

def tokenize_model(text: str, model):
    """Return the model's target-side tokens for `text`, without special tokens.

    "<pad>", "</s>" and "fr_XX" tokens are dropped and the "▁" marker is removed from
    the remaining tokens.
    """
    skipped = ("<pad>", "</s>", "fr_XX")
    encoded = model.encode(text, as_targets=True)
    pieces = []
    for piece in encoded.input_tokens[0]:
        if piece in skipped:
            continue
        pieces.append(piece.replace("▁", ""))
    return pieces

def get_tokens_with_cue_target_tags(txt_tag: str, txt_clean: str):
    """Derive per-token cue and target flags for the clean text from its tagged version.

    The tagged and clean token streams are walked in parallel. ``<hon>`` / ``<hoff>``
    open and close a cue span, ``<p>`` / ``</p>`` open and close a target span. A token
    inside both spans is flagged as cue only.

    Returns:
        ``(clean_tokens, cue_tags, target_tags)`` where both tag lists hold 0/1 values
        and have one entry per clean token.

    Raises:
        ValueError: If the two token streams diverge on something that is not a tag.
    """
    plain = tokenize(txt_clean)
    marked = tokenize(txt_tag, is_tagged=True)
    cue_tags = [0] * len(plain)
    target_tags = [0] * len(plain)
    in_cue = False
    in_target = False
    marked_pos = 0
    plain_pos = 0
    while marked_pos < len(marked) and plain_pos < len(plain):
        current = marked[marked_pos]
        if current == plain[plain_pos]:
            # Same token on both sides: record the active span, then advance both cursors.
            if in_cue:
                cue_tags[plain_pos] = 1
            elif in_target:
                target_tags[plain_pos] = 1
            marked_pos += 1
            plain_pos += 1
            continue
        # Otherwise the tagged stream must be sitting on a span marker.
        if current == "<p>":
            in_target = True
        elif current == "<hon>":
            in_cue = True
        elif current == "</p>":
            in_target = False
        elif current == "<hoff>":
            in_cue = False
        else:
            print(marked[marked_pos], plain[plain_pos])
            raise ValueError(f"Something went wrong\nTagged:{marked}\nUntagged:{plain}")
        marked_pos += 1
    return plain, cue_tags, target_tags

def get_subword_alignments(src: str, tgt: str) -> List[Tuple[int, int]]:
    """Align the tokens of two space-tokenized strings that share the same characters.

    Both strings must be identical once single spaces are removed; they may only differ
    in where the token boundaries fall.

    Returns:
        A list of ``(src_token_index, tgt_token_index)`` pairs, one per alignment step.

    Raises:
        AssertionError: If the two strings differ after removing spaces.
        ValueError: If neither current token is contained in the other.
    """
    assert "".join(src.split(" ")) == "".join(tgt.split(" ")), f"SRC: {''.join(src.split())}\nTGT: {''.join(tgt.split())}\n"
    # Splitting on a single space keeps empty tokens, so indices stay in sync with
    # token lists that contain empty entries.
    left = src.strip().split(" ")
    right = tgt.strip().split(" ")
    pairs = []
    i = 0
    j = 0
    while i < len(left):
        a = left[i]
        b = right[j]
        if a != b and a not in b and b not in a:
            raise ValueError(f"ERR: {a} =!= {b}")
        pairs.append((i, j))
        if a == b:
            # Both tokens fully consumed.
            i += 1
            j += 1
        elif a in b:
            # Source token is part of the target token: consume it from the target.
            right[j] = b.replace(a, "", 1)
            i += 1
        else:
            # Target token is part of the source token: consume it from the source.
            left[i] = a.replace(b, "", 1)
            j += 1
    return pairs

def propagate_tags(tok_tgt, tags, alignments):
    """Project word-level 0/1 tags onto model tokens through (token_idx, word_idx) pairs.

    A model token receives 1 when at least one word aligned to it is tagged 1;
    tokens without a tagged aligned word stay at 0.
    """
    projected = [0] * len(tok_tgt)
    for model_pos, word_pos in alignments:
        if tags[word_pos] != 1:
            continue
        projected[model_pos] = 1
    return projected

def get_model_cue_target_tags(tagged, untagged, model):
    """Compute cue and target flags for every model token of `untagged`.

    Word-level flags are read from the tagged text and projected onto the model's
    tokenization through a character-based alignment.

    Returns:
        ``(cue_tags, target_tags)``, two 0/1 lists with one entry per model token.

    Raises:
        ValueError: If the model tokens and the word tokens do not spell the same text.
    """
    subwords = tokenize_model(untagged, model)
    words, cue_tags, target_tags = get_tokens_with_cue_target_tags(tagged, untagged)
    joined_subwords = " ".join(subwords)
    joined_words = " ".join(words)
    try:
        alignments = get_subword_alignments(joined_subwords, joined_words)
    except AssertionError:
        # Re-raise with both tokenizations attached for debugging.
        raise ValueError(subwords, words)
    return (
        propagate_tags(subwords, cue_tags, alignments),
        propagate_tags(subwords, target_tags, alignments),
    )

In [2]:
import datasets

# Load the SCAT test split again so the analysis cells below can run without the
# translation cells above; this repeats the load done near the top of the notebook.
scat = datasets.load_dataset("inseq/scat", split="test", verification_mode="no_checks")#, download_mode="force_redownload")
# Dataset registry keyed by name; only SCAT is registered here.
ds = {"scat": scat}#"flores": flores, "iwslt17": iwslt, "scat": scat}

No config specified, defaulting to: scat/sentences
Found cached dataset scat (/home/gsarti/.cache/huggingface/datasets/inseq___scat/sentences/0.0.0/d0361f176cca9a1b65c6bf59ea1a94ab2c131b30c80e3bffab4f103c0e9406dd)


In [3]:
from typing import Optional, List, Tuple
import torch
from inseq.attr.step_functions import StepFunctionArgs, _get_contrast_output
from inseq.data import FeatureAttributionInput
from inseq.utils import logits_kl_divergence

def kl_div_per_layer_fn(
    args: StepFunctionArgs,
    contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
    contrast_sources: Optional[FeatureAttributionInput] = None,
    contrast_targets: Optional[FeatureAttributionInput] = None,
    contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
    top_k: int = 0,
    top_p: float = 1.0,
    min_tokens_to_keep: int = 1,
):
    """Logit-lens KL divergence between the original and the contrastive forward pass.

    For every entry of ``decoder_hidden_states`` except the last one, the hidden state at
    the final position is projected with the model's ``lm_head`` (plus
    ``final_logits_bias``) and the KL divergence between the original and contrastive
    logits is computed with ``logits_kl_divergence``.

    Returns:
        The per-layer divergences stacked along dimension 1.
    """
    attribution_model = args.attribution_model
    seq2seq = attribution_model.model

    def project_last_position(hidden_state):
        # Logit lens: project the hidden state of the last position to vocabulary logits.
        return seq2seq.lm_head(hidden_state[:, -1, :]) + seq2seq.final_logits_bias

    original_output = attribution_model.get_forward_output(
        attribution_model.formatter.convert_args_to_batch(args),
        output_hidden_states=True,
    )
    contrast_output = _get_contrast_output(
        args,
        contrast_sources=contrast_sources,
        contrast_target_prefixes=contrast_target_prefixes,
        contrast_targets=contrast_targets,
        contrast_targets_alignments=contrast_targets_alignments,
        output_hidden_states=True,
    )
    original_states = original_output.decoder_hidden_states
    contrast_states = contrast_output.decoder_hidden_states
    # The last entry of the hidden-state tuple is left out.
    per_layer_kl = [
        logits_kl_divergence(
            original_logits=project_last_position(original_states[layer]),
            contrast_logits=project_last_position(contrast_states[layer]),
            top_p=top_p,
            top_k=top_k,
            min_tokens_to_keep=min_tokens_to_keep,
        )
        for layer in range(len(original_states) - 1)
    ]
    return torch.stack(per_layer_kl, dim=1)

import inseq

# Register the function defined above under the "kl_div_per_layer" identifier so it can
# be requested by name in `step_scores`; overwrite=True allows re-running this cell.
# Since outputs are still probabilities, contiguous tokens can still be aggregated using product
# NOTE(review): the registered function returns KL divergences, not probabilities --
# confirm that the aggregation assumption in the line above still applies.
inseq.register_step_function(
    fn=kl_div_per_layer_fn,
    identifier="kl_div_per_layer",
    overwrite=True
)

In [4]:
import inseq
from inseq import AttributionModel
import pandas as pd
from tqdm import tqdm
import torch

from typing import List, Dict, Optional, Callable, Tuple

from simalign import SentenceAligner
import stanza

# SimAlign sentence aligner instance.
aligner = SentenceAligner(model="bert", token_type="bpe", matching_methods="mai")
# Stanza English pipeline with only the tokenizer; get_formatted_examples() uses it to
# split the English context into sentences.
nlp = stanza.Pipeline(lang='en', processors='tokenize', download_method=None)

# Checkpoint names keyed by model-type tag; repeats the mapping defined earlier in the
# notebook so this part can run on its own.
default_models = {
    "marian-small": "Helsinki-NLP/opus-mt-en-fr",
    "marian-big": "Helsinki-NLP/opus-mt-tc-big-en-fr",
    "mbart50-1toM": "facebook/mbart-large-50-one-to-many-mmt"
}


def get_model_id_and_name(
    cwd: int = 1, ctx: int = 4, model_type: str = "", dataset: str = "scat", model_name: str = None
) -> Tuple[str, str]:
    """Resolve the short model identifier and the checkpoint name.

    With an explicit ``model_name`` the pair ``(model_type, model_name)`` is returned
    unchanged. Otherwise the identifier is ``<dataset>-<model_type>-ctx<ctx>-cwd<cwd>``
    and the checkpoint name is ``context-mt/<identifier>-en-fr``.

    Raises:
        ValueError: If neither ``model_name`` nor a non-empty ``model_type`` is given.
    """
    if model_name is not None:
        return model_type, model_name
    if not model_type:
        raise ValueError("Must specify model_type or model_name")
    model_id = f"{dataset}-{model_type}-ctx{ctx}-cwd{cwd}"
    return model_id, f"context-mt/{model_id}-en-fr"


def get_formatted_examples(
    model_name: str = None,
    force_gen: bool = False,
    has_context: bool = True,
    has_lang_tag: bool = False,
    has_target_context: bool = False,
    start_idx: int = None,
    max_idx: int = None,
) -> List[Dict[str, str]]:
    """Build the per-example inputs for the attribution functions from the SCAT examples.

    Args:
        model_name: Checkpoint loaded through ``inseq.load_model``.
        force_gen: When True no generation is run; the dataset's French sentence and
            French context are used as target and target context.
        has_context: When True the source context and the sentence are joined with
            "<brk> "; otherwise they are joined with a space, and the context is
            translated sentence by sentence and forced as the decoder prefix.
        has_lang_tag: Configure en_XX/fr_XX language codes and force ``fr_XX`` as the
            first generated token.
        has_target_context: When True special tokens are kept while decoding; if
            ``has_context`` is also True, the text up to the first "<brk>" is taken as
            the target context.
        start_idx: First example index to process (defaults to 0).
        max_idx: Index at which processing stops (defaults to the dataset size).

    Returns:
        One dict per processed example with the keys "src_en", "tgt_fr", "src_en_ctx",
        "tgt_fr_ctx", "src_en_with_tags", "orig_fr" and "orig_fr_with_tags".
    """
    if max_idx is None:
        max_idx = len(scat)
    if start_idx is None:
        start_idx = 0
    model = inseq.load_model(model_name, "saliency")
    # Options shared by every generation call. Per-example options are added to a copy
    # so that nothing leaks from one example into the next.
    base_generate_kwargs = {}
    if has_lang_tag:
        model.tokenizer.src_lang = "en_XX"
        model.tokenizer.tgt_lang = "fr_XX"
        base_generate_kwargs["forced_bos_token_id"] = model.tokenizer.lang_code_to_id["fr_XX"]
    source_separator = "<brk> " if has_context else " "
    examples = []
    for idx, ex in tqdm(enumerate(scat), total=max_idx):
        if idx < start_idx:
            continue
        if idx >= max_idx:
            break
        # Built for every example, including the force_gen=True path, which previously
        # reached the dict below with this name undefined.
        # NOTE(review): the force_gen path follows the same has_context formatting rule
        # as the generation path -- confirm this is the intended source format.
        contrast_sources = ex["context_en"] + source_separator + ex["en"]
        if force_gen:
            tgt = ex["fr"]
            ctx_tgt = ex["context_fr"]
        else:
            ctx_tgt = None
            decoder_input = None
            generate_kwargs = dict(base_generate_kwargs)
            if not has_context:
                # Translate the context sentence by sentence and force it as the
                # decoder prefix. Only the shared options are passed here, so the
                # forced prefix of a previous example is never reused.
                ctx_tgt = " ".join([
                    model.generate(s.text, max_new_tokens=128, **base_generate_kwargs)[0]
                    for s in nlp(ex["context_en"]).sentences
                ])
                decoder_input = model.encode(ctx_tgt, as_targets=True).to(model.device)
                decoder_input_ids = decoder_input.input_ids
                if has_lang_tag:
                    lang_id_tensor = torch.tensor([model.tokenizer.lang_code_to_id["fr_XX"]]).to(model.device)
                    # Prepend the language id to the forced prefix along the sequence dimension.
                    decoder_input_ids = torch.cat((lang_id_tensor.unsqueeze(0), decoder_input_ids), dim=1)
                generate_kwargs["decoder_input_ids"] = decoder_input_ids
            encoded_sources = model.encode(contrast_sources, as_targets=False).to(model.device)
            generation_out = model.model.generate(
                input_ids=encoded_sources.input_ids,
                attention_mask=encoded_sources.attention_mask,
                return_dict_in_generate=True,
                **generate_kwargs,
            )
            # Move inputs back to the CPU and drop the generation output to release GPU memory.
            encoded_sources = encoded_sources.to("cpu")
            if decoder_input is not None:
                decoder_input = decoder_input.to("cpu")
            ctx_gen = model.tokenizer.batch_decode(
                generation_out.sequences, skip_special_tokens=False if has_target_context else True
            )[0].replace("<pad>", "").replace("</s>", "")
            del generation_out
            torch.cuda.empty_cache()
            if has_context and has_target_context:
                ctx_tgt = ctx_gen.split("<brk>")[0].strip() + "<brk>"
            # Cut the target context off the front of the generated text.
            start_pos = len(ctx_tgt) if ctx_tgt else 0
            tgt = ctx_gen[start_pos:]
        examples.append({
            "src_en": ex["en"],
            "tgt_fr": tgt,
            "src_en_ctx": contrast_sources,
            "tgt_fr_ctx": ctx_tgt,
            "src_en_with_tags": ex["en_with_tags"],
            "orig_fr": ex["fr"],
            "orig_fr_with_tags": ex["fr_with_tags"],
        })
        if idx < 3:
            # Print the first few examples as a sanity check.
            print(f"FULL EXAMPLE: {ex}")
            print(f"SRC: {ex['en']}")
            print(f"TGT: {tgt}")
            print(f"SRC CTX: {contrast_sources}")
            print(f"TGT CTX: {ctx_tgt}")
    return examples


def base_attribute_fn(ex: Dict[str, str], model: AttributionModel, idx: int) -> pd.DataFrame:
    """Compute the base step scores for one example and return them as a DataFrame.

    Requests the "probability", "contrast_prob", "pcxmi" and "kl_divergence" step scores,
    contrasting the plain source against the source with context and the target context
    prefix. The result has one row per step-score entry, with "example_idx" as the first
    column and "token_idx" / "token" taken from the score keys.
    """
    requested_scores = ["probability", "contrast_prob", "pcxmi", "kl_divergence"]
    attribution = model.attribute(
        ex["src_en"],
        ex["tgt_fr"],
        attribute_target=True,
        step_scores=requested_scores,
        contrast_sources=ex["src_en_ctx"],
        contrast_target_prefixes=ex["tgt_fr_ctx"],
        show_progress=False,
    )
    step_scores = attribution.get_scores_dicts()[0]["step_scores"]
    index_names = {"level_0": "token_idx", "level_1": "token"}
    scores_df = pd.DataFrame(step_scores).transpose().reset_index().rename(columns=index_names)
    scores_df.insert(0, "example_idx", idx)
    return scores_df


def top_p_attribute_fn(ex: Dict[str, str], model: AttributionModel, idx: int) -> pd.DataFrame:
    """Compute the KL divergence and top-p size for several top-p thresholds.

    The attribution is run once per threshold (0.1, 0.3, 0.5, 0.7, 0.9). The scores of
    each run are stored in the columns ``kl_div_<pct>`` and ``top_p_size_<pct>`` of a
    single DataFrame, where ``<pct>`` is ``int(top_p * 100)``.
    """
    merged = None
    for top_p in [0.1, 0.3, 0.5, 0.7, 0.9]:
        pct = int(top_p * 100)
        kl_col = f"kl_div_{pct}"
        size_col = f"top_p_size_{pct}"
        attribution = model.attribute(
            ex["src_en"],
            ex["tgt_fr"],
            attribute_target=True,
            step_scores=["kl_divergence", "top_p_size"],
            contrast_sources=ex["src_en_ctx"],
            contrast_target_prefixes=ex["tgt_fr_ctx"],
            show_progress=False,
            top_p=top_p,
        )
        run_df = pd.DataFrame(attribution.get_scores_dicts()[0]["step_scores"])
        run_df = run_df.transpose().reset_index().rename(columns={"level_0": "token_idx", "level_1": "token"})
        run_df.insert(0, "example_idx", idx)
        if merged is None:
            # The first threshold provides the frame, index columns included.
            merged = run_df.rename(columns={"kl_divergence": kl_col, "top_p_size": size_col})
        else:
            # Later thresholds only contribute their two score columns.
            merged[kl_col] = run_df["kl_divergence"]
            merged[size_col] = run_df["top_p_size"]
    return merged


def logit_lens_attribute_fn(ex: Dict[str, str], model: AttributionModel, idx: int) -> pd.DataFrame:
    """Compute the per-layer KL divergence scores for one example.

    The "kl_div_per_layer" step score is split along its first dimension into one score
    per layer (``kl_div_l0``, ``kl_div_l1``, ...) before conversion to a DataFrame with
    "example_idx", "token_idx" and "token" columns.
    """
    attribution = model.attribute(
        ex["src_en"],
        ex["tgt_fr"],
        attribute_target=True,
        step_scores=["kl_div_per_layer"],
        contrast_sources=ex["src_en_ctx"],
        contrast_target_prefixes=ex["tgt_fr_ctx"],
        show_progress=False,
    )
    sequence_scores = attribution[0].step_scores
    per_layer = sequence_scores["kl_div_per_layer"]
    for layer_idx in range(per_layer.shape[0]):
        sequence_scores[f"kl_div_l{layer_idx}"] = per_layer[layer_idx, :]
    # Drop the combined score so that only the per-layer entries are exported.
    del sequence_scores["kl_div_per_layer"]
    layer_df = pd.DataFrame(attribution.get_scores_dicts()[0]["step_scores"])
    layer_df = layer_df.transpose().reset_index().rename(columns={"level_0": "token_idx", "level_1": "token"})
    layer_df.insert(0, "example_idx", idx)
    return layer_df


def input_contributions_attribute_fn(ex: Dict[str, str], model: AttributionModel, idx: int) -> pd.DataFrame:
    """Attribute the contextual translation and aggregate contributions per input span.

    The contextual input (``src_en_ctx`` and, when present, ``tgt_fr_ctx``
    followed by ``tgt_fr``) is attributed with ``contrast_prob_diff`` against
    the context-less pair (``src_en`` / ``tgt_fr``). Token-level attributions
    are then aggregated into spans (context, ``<brk>``, current sentence, EOS
    and, for models using language tags, the tag) and stored as step scores
    named ``src_*_attr`` / ``tgt_*_attr``.

    Args:
        ex: Example record with keys ``src_en``, ``src_en_ctx``, ``tgt_fr``
            and ``tgt_fr_ctx`` (the latter may be ``None``/NaN).
        model: Attribution model used to encode and attribute the example.
        idx: Example index, stored in the ``example_idx`` column.

    Returns:
        One row per attributed target token with the aggregated span
        contributions, or ``None`` when the example is skipped (missing
        source context or non-string target).

    Raises:
        AssertionError: If the aggregated source does not have the expected
            number of spans, or the span contributions do not sum to one.
    """
    has_target_context = ex["tgt_fr_ctx"] is not None and pd.notnull(ex["tgt_fr_ctx"])
    # Handle missing source context: an empty context leaves the source starting
    # with <brk>, while a missing TSV field is read back as a non-string (NaN).
    full_src = ex["src_en_ctx"]
    if not isinstance(full_src, str) or full_src.startswith("<brk>") or not isinstance(ex["tgt_fr"], str):
        print(f"Skipping example {idx}")
        return None
    full_tgt = ex["tgt_fr_ctx"] + " " + ex["tgt_fr"].strip() if has_target_context else ex["tgt_fr"].strip()
    # Only encode the target context when there is one: encoding None/NaN is
    # meaningless and the resulting tokens were never used in that case.
    offset = 0
    if has_target_context:
        tgt_fr_ctx_tokens = model.encode(ex["tgt_fr_ctx"], as_targets=True).input_tokens[0]
        offset = len(tgt_fr_ctx_tokens) - 1  # pad
        if tgt_fr_ctx_tokens[1] == "fr_XX":
            offset -= 1
    curr_len = len(model.encode(ex["tgt_fr"], as_targets=True).input_tokens[0]) - 1 # pad
    out = model.attribute(
        full_src,
        full_tgt,
        attribute_target=True,
        show_progress=False,
        attr_pos_start=offset if has_target_context else None,
        attributed_fn="contrast_prob_diff",
        contrast_sources=ex["src_en"],
        contrast_targets=ex["tgt_fr"].strip(),
        # Pair each position of the contextual target with the matching
        # position of the context-less target.
        contrast_targets_alignments=[
            (idx_full, idx_curr)
            for idx_curr, idx_full in enumerate(range(offset, offset + curr_len), start=1 if has_target_context else 0)
        ],
    )
    has_lang_tag = out[0].source[0].token == "en_XX" and out[0].target[0].token == "fr_XX"
    aggr_args = {}
    src_brk_idx = [t.token for t in out[0].source].index("<brk>")
    lang_tag_offset = 1 if has_lang_tag else 0
    # Source spans: context (before <brk>) and current sentence (after <brk>, before the last token).
    aggr_args["source_spans"] = [(lang_tag_offset,src_brk_idx), (src_brk_idx+1,len(out[0].source) - 1)]
    if has_target_context:
        special_tok = 'fr_XX → <brk>' if has_lang_tag else '<brk>'
        tgt_brk_idx = [t.token for t in out[0].target].index(special_tok)
        aggr_args["target_spans"] = [(lang_tag_offset,tgt_brk_idx)]
    aggr_out = out.aggregate("spans", **aggr_args).aggregate()
    assert aggr_out[0].source_attributions.size(0) == 4 + lang_tag_offset, (
        f"Expected {4 + lang_tag_offset} source tokens but found {aggr_out[0].source_attributions.size(0)} "
        f"instead: {aggr_out[0].source}"
    )
    if has_lang_tag:
        aggr_out[0].step_scores["src_langtag_attr"] = aggr_out[0].source_attributions[0, :]
    aggr_out[0].step_scores["src_ctx_attr"] = aggr_out[0].source_attributions[0 + lang_tag_offset, :]
    aggr_out[0].step_scores["src_brk_attr"] = aggr_out[0].source_attributions[1 + lang_tag_offset, :]
    aggr_out[0].step_scores["src_curr_attr"] = aggr_out[0].source_attributions[2 + lang_tag_offset, :]
    aggr_out[0].step_scores["src_eos_attr"] = aggr_out[0].source_attributions[3 + lang_tag_offset, :]
    tgt_curr_start_idx = 0
    if has_lang_tag:
        aggr_out[0].step_scores["tgt_langtag_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx, :]
        tgt_curr_start_idx += 1
    if has_target_context:
        aggr_out[0].step_scores["tgt_ctx_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx, :]
        aggr_out[0].step_scores["tgt_brk_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx + 1, :]
        tgt_curr_start_idx += 2
    # Everything after the context spans belongs to the current target sentence.
    aggr_out[0].step_scores["tgt_curr_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx:, :].nansum(axis=0)
    # Sanity check: the span contributions of every generation step must sum to 1.
    assert torch.allclose(torch.stack(list(aggr_out[0].step_scores.values()), dim=1).nansum(axis=1), torch.ones_like(aggr_out[0].step_scores["src_ctx_attr"]))
    df = pd.DataFrame(aggr_out.get_scores_dicts(do_aggregation=False)[0]["step_scores"])
    df = df.transpose().reset_index().rename(columns={"level_0": "token_idx", "level_1": "token"})
    if has_target_context:
        # Renumber tokens from 0 so indices refer to the attributed (current) sentence only.
        df["token_idx"] = list(range(len(aggr_out[0].step_scores["src_ctx_attr"])))
    df.insert(0, "example_idx", idx)
    return df


def context_sensitive_span_identification_scores(
    examples_path: Optional[str] = None,
    cwd: int = 1,
    ctx: int = 4,
    model_type: str = "",
    dataset: str = "scat",
    model_name: Optional[str] = None,
    force_gen: bool = False,
    has_context: bool = True,
    has_lang_tag: bool = False,
    has_target_context: bool = False,
    start_idx: int = 0,
    max_idx: Optional[int] = None,
    add_tags: bool = True,
    attribute_fn: Callable[[Dict[str, str], AttributionModel, int], Optional[pd.DataFrame]] = base_attribute_fn,
) -> Optional[pd.DataFrame]:
    """Score every example with ``attribute_fn`` and checkpoint the results to a TSV file.

    Examples are either read from ``examples_path`` (tab-separated) or built
    with ``get_formatted_examples``. After every processed example the
    accumulated scores are written to
    ``translations/scores/temp/<model_id>-scores-<gold|gen>.tsv`` so an
    interrupted run can be resumed with ``start_idx``.

    Args:
        examples_path: Optional TSV file with pre-processed examples. When
            given, ``model_type`` is used as the model id in the output path.
        cwd, ctx, model_type, dataset, model_name: Forwarded to
            ``get_model_id_and_name`` when ``examples_path`` is not given.
            ``model_name`` is also the checkpoint loaded with Inseq.
        force_gen: Forwarded to ``get_formatted_examples``; also selects the
            ``gold`` (True) or ``gen`` (False) suffix of the output file and
            the tagging strategy.
        has_context, has_target_context: Forwarded to ``get_formatted_examples``.
        has_lang_tag: Set the tokenizer languages to ``en_XX``/``fr_XX`` and
            prepend a 0 tag for the language tag token.
        start_idx: First example index to process. When greater than 0 the
            previously checkpointed scores are loaded and extended.
        max_idx: Exclusive upper bound on example indices (default: all).
        add_tags: Whether to add the ``is_supporting_context`` and
            ``is_context_sensitive`` columns.
        attribute_fn: Function mapping ``(example, model, idx)`` to a scores
            frame, or ``None`` to skip the example.

    Returns:
        The accumulated scores frame (``None`` if no example produced scores).
    """
    if examples_path is None:
        model_id, model_name = get_model_id_and_name(cwd=cwd, ctx=ctx, model_type=model_type, dataset=dataset, model_name=model_name)
        examples = get_formatted_examples(
            model_name = model_name, force_gen = force_gen, has_context = has_context, has_lang_tag = has_lang_tag, has_target_context = has_target_context, start_idx = start_idx, max_idx = max_idx
        )
    else:
        examples = pd.read_csv(examples_path, sep="\t").to_dict("records")
        model_id = model_type
    scores_path = f"translations/scores/temp/{model_id}-scores-{'gold' if force_gen else 'gen'}.tsv"
    scores_df = None
    if start_idx > 0:
        # Resume from the checkpoint written by a previous (partial) run.
        scores_df = pd.read_csv(scores_path, sep="\t")
    if max_idx is None:
        max_idx = len(examples)
    model = inseq.load_model(model_name, "saliency")
    if has_lang_tag:
        model.tokenizer.src_lang = "en_XX"
        model.tokenizer.tgt_lang = "fr_XX"
    for idx, ex in enumerate(examples):
        if idx < start_idx:
            continue
        if idx >= max_idx:
            break
        df = attribute_fn(ex, model, idx)
        if df is None:
            continue
        if add_tags:
            try:
                if force_gen:
                    cue_tags, target_tags = get_model_cue_target_tags(ex["orig_fr_with_tags"], ex["orig_fr"], model)
                else:
                    word_tok_ctx_gen = tokenize(ex["fr"])
                    sub_tok_ctx_gen = tokenize_model(ex["fr"], model)
                    # Get cue and target tags on the gold word-tokenized text
                    tok_gold_ref, gold_word_cue_tags, gold_word_target_tags = get_tokens_with_cue_target_tags(ex["orig_fr_with_tags"], ex["tgt_fr"])
                    # Align the word-tokenized model generation to the word-tokenized gold text
                    ctx_gen_to_gold_ref_alignments = aligner.get_word_aligns(word_tok_ctx_gen, tok_gold_ref)["itermax"]
                    # Tags on the model-generated word level translation
                    ctx_gen_word_cue_tags = propagate_tags(word_tok_ctx_gen, gold_word_cue_tags, ctx_gen_to_gold_ref_alignments)
                    ctx_gen_word_target_tags = propagate_tags(word_tok_ctx_gen, gold_word_target_tags, ctx_gen_to_gold_ref_alignments)
                    # Align the subword- and word-tokenized model generations
                    try:
                        sub_to_word_ctx_gen_alignments = get_subword_alignments(" ".join(sub_tok_ctx_gen), " ".join(word_tok_ctx_gen))
                    except AssertionError as align_err:
                        raise ValueError(sub_tok_ctx_gen, word_tok_ctx_gen) from align_err
                    # Propagate word-level tags on model generation to subword level.
                    cue_tags = propagate_tags(sub_tok_ctx_gen, ctx_gen_word_cue_tags, sub_to_word_ctx_gen_alignments)
                    target_tags = propagate_tags(sub_tok_ctx_gen, ctx_gen_word_target_tags, sub_to_word_ctx_gen_alignments)
                    if idx < 3:
                        # Show the first few taggings for a quick visual sanity check.
                        print([(x, y) for x, y in zip(sub_tok_ctx_gen, cue_tags)])
                        print([(x, y) for x, y in zip(sub_tok_ctx_gen, target_tags)])
                # Add </s> token tag
                cue_tags += [0]
                target_tags += [0]
                if has_lang_tag:
                    cue_tags = [0] + cue_tags
                    target_tags = [0] + target_tags
                df["is_supporting_context"] = cue_tags
                df["is_context_sensitive"] = target_tags
            except Exception as err:
                # Best effort: drop examples whose tags cannot be aligned. The error gets its
                # own name so it does not rebind (and then unset) the loop variable `ex`.
                print(f"Excluding example {idx} due to error {err}")
                continue
        if scores_df is None:
            scores_df = df
        else:
            scores_df = pd.concat([scores_df, df], axis=0)
        # Checkpoint after every example so that the run can be resumed via start_idx.
        scores_df.to_csv(scores_path, index=False, sep="\t")
    return scores_df

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2023-08-04 17:12:38,497 - simalign.simalign - INFO - Initialized the EmbeddingLoader with model: bert-base-multilingual-cased
2023-08-04 17:12:3

In [None]:
#idx = "context-mt/scat-mbart50-1toM-target-ctx4-cwd0-en-fr"
#examples = get_formatted_examples(
#    model_name = idx,
#    has_context=True,
#    has_target_context=True,
#    has_lang_tag=True,
#    force_gen=False,
#)

In [None]:
#df = pd.DataFrame(examples)
#df.to_csv(f"translations/processed_examples/{idx.split('/')[1]}.tsv", sep="\t", index=False)

In [None]:
# Input-contribution scores for the Marian big model, using the pre-processed
# examples TSV. Tags are not added (add_tags=False).
context_sensitive_span_identification_scores(
    model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr",
    model_type="scat-marian-big-ctx4-cwd1",
    examples_path="translations/processed_examples/scat-marian-big-ctx4-cwd1-en-fr.tsv",
    attribute_fn=input_contributions_attribute_fn,
    has_context=True,
    has_target_context=False,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [5]:
#import pandas as pd
#
#df = pd.read_csv("translations/scores/scat-marian-big-target-ctx4-cwd0-scores-gen.tsv", sep="\t")
#cols = ["src_ctx_attr", "src_brk_attr", "src_curr_attr", "src_eos_attr", "tgt_ctx_attr", "tgt_brk_attr", "tgt_curr_attr"]
#df = df[[x for x in df.columns if x not in cols]]
#df.to_csv("translations/scores/scat-marian-big-target-ctx4-cwd0-scores-gen.tsv", sep="\t", index=False)

In [None]:
# Input-contribution scores for the Marian big target-context model, using the
# pre-processed examples TSV. Tags are not added (add_tags=False).
context_sensitive_span_identification_scores(
    model_name="context-mt/scat-marian-big-target-ctx4-cwd0-en-fr",
    model_type="scat-marian-big-target-ctx4-cwd0",
    examples_path="translations/processed_examples/scat-marian-big-target-ctx4-cwd0-en-fr.tsv",
    attribute_fn=input_contributions_attribute_fn,
    has_context=True,
    has_target_context=True,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Input-contribution scores for the Marian small model, using the
# pre-processed examples TSV. Tags are not added (add_tags=False).
context_sensitive_span_identification_scores(
    model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr",
    model_type="scat-marian-small-ctx4-cwd1",
    examples_path="translations/processed_examples/scat-marian-small-ctx4-cwd1-en-fr.tsv",
    attribute_fn=input_contributions_attribute_fn,
    has_context=True,
    has_target_context=False,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Input-contribution scores for the Marian small target-context model, using
# the pre-processed examples TSV. Tags are not added (add_tags=False).
context_sensitive_span_identification_scores(
    model_name="context-mt/scat-marian-small-target-ctx4-cwd0-en-fr",
    model_type="scat-marian-small-target-ctx4-cwd0",
    examples_path="translations/processed_examples/scat-marian-small-target-ctx4-cwd0-en-fr.tsv",
    attribute_fn=input_contributions_attribute_fn,
    has_context=True,
    has_target_context=True,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Input-contribution scores for the mBART50 1-to-many model, using the
# pre-processed examples TSV. has_lang_tag=True sets the tokenizer languages
# to en_XX/fr_XX. Tags are not added (add_tags=False).
context_sensitive_span_identification_scores(
    model_name="context-mt/scat-mbart50-1toM-ctx4-cwd1-en-fr",
    model_type="scat-mbart50-1toM-ctx4-cwd1",
    examples_path="translations/processed_examples/scat-mbart50-1toM-ctx4-cwd1-en-fr.tsv",
    attribute_fn=input_contributions_attribute_fn,
    has_context=True,
    has_target_context=False,
    has_lang_tag=True,
    force_gen=False,
    add_tags=False, #! important
)

In [5]:
# Input-contribution scores for the mBART50 target-context model, resumed from
# example 319: with start_idx > 0 the function first reloads the scores
# checkpointed under translations/scores/temp/ and appends to them.
context_sensitive_span_identification_scores(
    model_name="context-mt/scat-mbart50-1toM-target-ctx4-cwd0-en-fr",
    model_type="scat-mbart50-1toM-target-ctx4-cwd0",
    examples_path="translations/processed_examples/scat-mbart50-1toM-target-ctx4-cwd0-en-fr.tsv",
    attribute_fn=input_contributions_attribute_fn,
    has_context=True,
    has_target_context=True,
    has_lang_tag=True,
    force_gen=False,
    add_tags=False, #! important
    start_idx=319,
)

Skipping example 319
Skipping example 369
Skipping example 389
Skipping example 423
Skipping example 444
Skipping example 461
Skipping example 499


In [None]:
# Scores with the default attribute_fn for the Marian big model. No
# examples_path is given, so examples are built by get_formatted_examples.
context_sensitive_span_identification_scores(
    cwd=1, ctx=4,
    model_type="scat-marian-big-ctx4-cwd1",
    model_name="context-mt/scat-marian-big-ctx4-cwd1-en-fr",
    has_context=True,
    has_target_context=False,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Scores with the default attribute_fn for the Marian big target-context model.
# No examples_path is given, so examples are built by get_formatted_examples.
context_sensitive_span_identification_scores(
    cwd=0, ctx=4,
    model_type="scat-marian-big-target-ctx4-cwd0",
    model_name="context-mt/scat-marian-big-target-ctx4-cwd0-en-fr",
    has_context=True,
    has_target_context=True,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Scores with the default attribute_fn for the Marian small target-context
# model. No examples_path is given, so examples are built by get_formatted_examples.
context_sensitive_span_identification_scores(
    cwd=0, ctx=4,
    model_type="scat-marian-small-target-ctx4-cwd0",
    model_name="context-mt/scat-marian-small-target-ctx4-cwd0-en-fr",
    has_context=True,
    has_target_context=True,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Scores with the default attribute_fn for the Marian small model. No
# examples_path is given, so examples are built by get_formatted_examples.
context_sensitive_span_identification_scores(
    cwd=1, ctx=4,
    model_type="scat-marian-small-ctx4-cwd1",
    model_name="context-mt/scat-marian-small-ctx4-cwd1-en-fr",
    has_context=True,
    has_target_context=False,
    has_lang_tag=False,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Scores with the default attribute_fn for the mBART50 target-context model
# (language tags enabled). No examples_path is given, so examples are built by
# get_formatted_examples.
context_sensitive_span_identification_scores(
    cwd=0, ctx=4,
    model_type="scat-mbart50-1toM-target-ctx4-cwd0",
    model_name="context-mt/scat-mbart50-1toM-target-ctx4-cwd0-en-fr",
    has_context=True,
    has_target_context=True,
    has_lang_tag=True,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Scores with the default attribute_fn for the mBART50 model (language tags
# enabled). No examples_path is given, so examples are built by
# get_formatted_examples.
context_sensitive_span_identification_scores(
    cwd=1, ctx=4,
    model_type="scat-mbart50-1toM-ctx4-cwd1",
    model_name="context-mt/scat-mbart50-1toM-ctx4-cwd1-en-fr",
    has_context=True,
    has_target_context=False,
    has_lang_tag=True,
    force_gen=False,
    add_tags=False, #! important
)

In [None]:
# Run the scoring for the default models; the filter below currently restricts
# the run to the "mbart50-1toM" entry. cwd/ctx are 0 and both context flags are
# off; add_tags keeps its default (True).
for model_type, model_name in default_models.items():
    if model_type == "mbart50-1toM":
        context_sensitive_span_identification_scores(
            cwd=0, ctx=0,
            model_type=model_type,
            model_name=model_name,
            has_context=False,
            has_target_context=False,
            has_lang_tag=model_type == "mbart50-1toM",
            force_gen=False,
        )

In [None]:
import pandas as pd
import plotly.express as px
import numpy as np

def load_scores_df(path: str):
    """Load a scores TSV and enrich it with SCAT texts and per-example statistics.

    Added columns:
      * ``source_context`` / ``source`` / ``target_context``: SCAT fields
        looked up by ``example_idx`` in the module-level ``scat`` dataset.
      * ``target``: the example's tokens joined back into a string, with
        ``▁`` replaced by a space.
      * ``<metric>_mean`` / ``<metric>_std`` / ``<metric>_zscore`` for
        ``kl_divergence`` and ``pcxmi``, computed within each example.
      * ``contrast_prob_diff``: ``contrast_prob - probability``.

    ``is_context_sensitive`` is cast to bool.

    Args:
        path: Path of a tab-separated scores file with at least the columns
            ``example_idx``, ``token``, ``kl_divergence``, ``pcxmi``,
            ``contrast_prob``, ``probability`` and ``is_context_sensitive``.

    Returns:
        The enriched scores frame.
    """
    df = pd.read_csv(path, sep="\t")
    # Fetch each dataset column once instead of once per row.
    context_en = scat["context_en"]
    current_en = scat["en"]
    context_fr = scat["context_fr"]
    df["source_context"] = [context_en[idx] for idx in df["example_idx"]]
    df["source"] = [current_en[idx] for idx in df["example_idx"]]
    df["target_context"] = [context_fr[idx] for idx in df["example_idx"]]
    df["target"] = df.groupby("example_idx")["token"].transform(lambda x: "".join(x.str.replace("▁", " ")))
    for metric in ("kl_divergence", "pcxmi"):
        # String aggregator names select the pandas group-wise mean/std; passing
        # np.mean/np.std is deprecated for transform and is scheduled to change meaning.
        df[f"{metric}_mean"] = df.groupby("example_idx")[metric].transform("mean")
        df[f"{metric}_std"] = df.groupby("example_idx")[metric].transform("std")
        df[f"{metric}_zscore"] = (df[metric] - df[f"{metric}_mean"]) / df[f"{metric}_std"]
    df["contrast_prob_diff"] = df["contrast_prob"] - df["probability"]
    df["is_context_sensitive"] = df["is_context_sensitive"].astype(bool)
    return df

In [None]:
# Load the gold-target scores of the Marian big model and show how many tokens
# fall in each is_context_sensitive class.
df_gold = load_scores_df("translations/scores/scat-marian-big-ctx4-cwd1-scores-gold.tsv")
print(df_gold["is_context_sensitive"].value_counts())
df_gold.head()

In [None]:
# Scatter plots of raw and per-example z-scored KL divergence, colored by
# is_context_sensitive, saved as standalone HTML files.
fig = px.scatter(df_gold, x=df_gold.index, y='kl_divergence', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
fig.write_html("translations/scores/scat-marian-big-ctx4-cwd1-scores-kldiv.html")
fig = px.scatter(df_gold, x=df_gold.index, y='kl_divergence_zscore', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
fig.write_html("translations/scores/scat-marian-big-ctx4-cwd1-scores-kldiv-zscore.html")

In [None]:
# Same plots as for KL divergence, here for raw and per-example z-scored P-CXMI.
fig = px.scatter(df_gold, x=df_gold.index, y='pcxmi', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
fig.write_html("translations/scores/scat-marian-big-ctx4-cwd1-scores-pcxmi.html")
fig = px.scatter(df_gold, x=df_gold.index, y='pcxmi_zscore', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
fig.write_html("translations/scores/scat-marian-big-ctx4-cwd1-scores-pcxmi-zscore.html")

In [None]:
# Load the scores computed on model generations (the "gen" file) and show the
# class balance of is_context_sensitive.
df_gen = load_scores_df("translations/scores/scat-marian-big-ctx4-cwd1-scores-gen.tsv")
#df_gen = load_scores_df("translations/scores/mbart50-1toM-scores-gen.tsv")
print(df_gen["is_context_sensitive"].value_counts())
df_gen.head()

In [None]:
# Split the generation scores at example index 250 and count the supporting
# context tokens on each side.
# NOTE(review): the names suggest examples < 250 are inter-sentential and the
# rest intra-sentential — confirm against the dataset ordering.
df_gen_inter = df_gen[df_gen.example_idx < 250]
df_gen_intra = df_gen[df_gen.example_idx >= 250]
df_gen_inter.is_supporting_context.sum(), df_gen_intra.is_supporting_context.sum()

In [None]:
from scipy import stats
import numpy as np

# Compare the KL-divergence values of context-sensitive and non-sensitive
# tokens with a Kolmogorov-Smirnov test.
df_gen_sensitive = df_gen[df_gen["is_context_sensitive"] == True]
df_gen_not_sensitive = df_gen[df_gen["is_context_sensitive"] == False]
# NOTE(review): `rng` is created without a seed and is not used in this cell.
rng = np.random.default_rng()
stats.kstest(df_gen_sensitive["kl_divergence"].values, df_gen_not_sensitive["kl_divergence"].values)

In [None]:
#fig = px.scatter(df_gold, x=df_gold.index, y='kl_divergence', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
#fig.show()

In [None]:
# KL divergence per token for the df_gen_inter subset, colored by context sensitivity.
fig = px.scatter(df_gen_inter, x=df_gen_inter.index, y='kl_divergence', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
fig.show()

In [None]:
# Save whichever figure is currently bound to `fig`.
# NOTE(review): the file name mentions mbart50-1toM while df_gen above is loaded
# from the scat-marian-big scores — confirm this is the intended figure.
fig.write_html("translations/scores/mbart50-1toM-scores-gen.html")

In [None]:
# Per-example z-scored P-CXMI for all generation scores, colored by context sensitivity.
fig = px.scatter(df_gen, x=df_gen.index, y='pcxmi_zscore', color='is_context_sensitive', hover_data=['token', 'source_context', "source", "target_context", "target"], trendline="ols")
fig.show()

In [None]:
# Merge the four highlighted SCAT files (source/target context and current
# sentences) into a single file with one example per line, using the
# <SRC_CTX>/<SRC>/<TGT_CTX>/<TGT> markers as field separators.
split = "test"

with open(f"filtered_scat/scat/highlighted.{split}.context.en") as f:
    orig_ctx_en = f.readlines()
with open(f"filtered_scat/scat/highlighted.{split}.context.fr") as f:
    orig_ctx_fr = f.readlines()
with open(f"filtered_scat/scat/highlighted.{split}.en") as f:
    orig_tgt_en = f.readlines()
with open(f"filtered_scat/scat/highlighted.{split}.fr") as f:
    orig_tgt_fr = f.readlines()

# zip() stops at the shortest file, so the four files are assumed to be line-aligned.
with open(f"highlighted.{split}.full", "w") as f:
    for ctx_en, ctx_fr, tgt_en, tgt_fr in zip(orig_ctx_en, orig_ctx_fr, orig_tgt_en, orig_tgt_fr):
        f.write(f"<SRC_CTX> {ctx_en.strip()} <SRC> {tgt_en.strip()} <TGT_CTX> {ctx_fr.strip()} <TGT> {tgt_fr.strip()}\n")

In [None]:
# Read the merged examples back, one stripped line per example, and show the count.
texts = []
with open(f"highlighted.test.full") as f:
    for line in f:
        texts.append(line.strip())
len(texts)

In [None]:
# Keep only examples with well-formed tags.
# NOTE: this rebinds `texts`, so re-running the cell filters the already filtered list.
texts = [
    t for t in texts 
    # Current target sentence: exactly one <hon>...<hoff> pair and one <p>...</p> pair.
    if t.split("<TGT>")[1].count("<hon>") == 1 and t.split("<TGT>")[1].count("<hoff>") == 1 and t.split("<TGT>")[1].count("<p>") == 1 and t.split("<TGT>")[1].count("</p>") == 1
    # Current source sentence: exactly one <p>...</p> pair.
    and t.split("<SRC>")[1].split("<TGT_CTX>")[0].count("<p>") == 1 and t.split("<SRC>")[1].split("<TGT_CTX>")[0].count("</p>") == 1
    # No <hon> before <TGT>, and a single <hon>/<hoff> in the whole line.
    and "<hon>" not in t.split("<TGT>")[0] and t.count("<hon>") == 1 and t.count("<hoff>") == 1
]
len(texts)

In [None]:
def get_label(txt: str):
    """Return the span label of a tagged snippet: "P" for <p>, "H" for <hon>, else "O"."""
    if "<p>" in txt:
        return "P"
    if "<hon>" in txt:
        return "H"
    return "O"

def find_replace_tags(text: str):
    """Strip <p>/<hon> markup from ``text`` and return character-level span annotations.

    Returns a tuple ``(clean_text, spans)`` where every span is a dict with
    the ``start``/``end`` offsets of the formerly tagged content in
    ``clean_text`` and its ``label`` (see ``get_label``). Spans covering a
    single character or only whitespace are not reported, although their
    markup is still removed from the text.
    """
    found = [
        (m.group(1), m.group(0), m.start(), m.end())
        for m in re.finditer(r'(?:<p>|<hon>)([^<]+)(?:</p>|<hoff>)', text)
    ]
    spans = []
    removed_so_far = 0  # markup characters already deleted to the left of the current match
    for content, tagged, raw_start, raw_end in found:
        markup_len = len(tagged) - len(content)
        text = text.replace(tagged, content, 1)
        span_start = raw_start - removed_so_far
        span_end = raw_end - removed_so_far - markup_len
        removed_so_far += markup_len
        if span_end - span_start > 1 and content.strip() and span_end > span_start:
            spans.append({"start": span_start, "end": span_end, "label": get_label(tagged)})
    return text, spans

In [None]:
import re
import pandas as pd

# Both patterns match the markup/field markers as single units; the first
# otherwise matches runs of non-space characters, the second runs of word characters.
pattern_nonspace = r"(<p>|</p>|<hon>|<hoff>|<SRC_CTX>|<SRC>|<TGT_CTX>|<TGT>|\S+)"
pattern_word = r"(<p>|</p>|<hon>|<hoff>|<SRC_CTX>|<SRC>|<TGT_CTX>|<TGT>|\w+)"

def tokenize(text: str):
    """Split ``text`` into word and punctuation tokens, dropping the span markup.

    The text is first cut into whitespace-delimited chunks and every chunk is
    then cut at word boundaries. Whitespace-only pieces and the span tags
    (<p>, </p>, <hon>, <hoff>) are discarded; the field markers
    (<SRC_CTX>, <SRC>, <TGT_CTX>, <TGT>) are kept as tokens.
    """
    span_markup = ["<p>", "</p>", "<hon>", "<hoff>"]
    tokens = []
    for chunk in re.split(pattern_nonspace, text):
        for piece in re.split(pattern_word, chunk):
            if not piece.strip():
                continue
            if piece in span_markup:
                continue
            tokens.append(piece)
    return tokens

# Tokens are computed on the tagged text (tokenize drops the span tags), then
# the tags are stripped from the text and turned into character-level spans.
tokens = [tokenize(text) for text in texts]
# NOTE: rebinds `texts` to the cleaned strings; re-running the cell will find no tags.
texts, tags = zip(*[find_replace_tags(text) for text in texts])
df = pd.DataFrame({"text": texts, "prediction": tags, "tokens": tokens})
df.to_json("scat_argilla_target.json", orient="records", lines=True)

In [None]:
import os

import argilla as rg
from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("private-demos/scat_argilla_target", split="train", download_mode="force_redownload")

# Read the API key from the environment instead of hard-coding it; the previous
# value is kept as fallback so existing setups keep working.
rg.init(
    api_url="https://private-demos-argilla-test.hf.space",
    api_key=os.environ.get("ARGILLA_API_KEY", "admin.apikey"),
)

# read in dataset, assuming it's a dataset for token classification
dataset_rg = rg.read_datasets(dataset, task="TokenClassification")

# log the dataset
rg.log(dataset_rg, "scat_target")

In [None]:
import os

import argilla as rg

# Read the API key from the environment instead of hard-coding it; the previous
# value is kept as fallback so existing setups keep working.
rg.init(
    api_url="https://private-demos-argilla-test.hf.space",
    api_key=os.environ.get("ARGILLA_API_KEY", "admin.apikey"),
)
# Download the annotated records and keep only those marked as validated.
df = rg.load("scat_target").to_datasets().to_pandas()
df = df[df["status"] == "Validated"]
df.shape

In [None]:
# Re-insert the <p>...</p> / <hon>...<hoff> markup around the validated
# annotation spans of every record.
tagged_texts = []
for _, row in df.iterrows():
    spans = sorted(((dic["start"], dic["end"], dic["label"]) for dic in row["annotation"]), key=lambda span: span[0])
    text = row["text"]
    inserted = 0  # number of markup characters already added to the left
    for span_start, span_end, label in spans:
        if label == "P":
            open_tag, close_tag = "<p>", "</p>"
        elif label == "H":
            open_tag, close_tag = "<hon>", "<hoff>"
        else:
            # Other labels leave the text untouched.
            continue
        begin, finish = span_start + inserted, span_end + inserted
        text = text[:begin] + open_tag + text[begin:finish] + close_tag + text[finish:]
        inserted += len(open_tag) + len(close_tag)
    tagged_texts.append(text)

In [None]:
# Split every re-tagged line back into its four fields using the field markers.
ctx_en = [t.split("<SRC_CTX>")[1].split("<SRC>")[0].strip() for t in tagged_texts]
tgt_en = [t.split("<SRC>")[1].split("<TGT_CTX>")[0].strip() for t in tagged_texts]
ctx_fr = [t.split("<TGT_CTX>")[1].split("<TGT>")[0].strip() for t in tagged_texts]
tgt_fr = [t.split("<TGT>")[1].strip() for t in tagged_texts]
# All four lists must have the same length.
len(ctx_en), len(tgt_en), len(ctx_fr), len(tgt_fr)

In [None]:
# Append the four fields to the filtered SCAT files, one example per line.
# NOTE: the files are opened in append mode ("a+"), so re-running this cell
# writes the same lines again.
with open("filtered.test.context.en", "a+") as f:
    for line in ctx_en:
        f.write(line + "\n")
with open("filtered.test.en", "a+") as f:
    for line in tgt_en:
        f.write(line + "\n")
with open("filtered.test.context.fr", "a+") as f:
    for line in ctx_fr:
        f.write(line + "\n")
with open("filtered.test.fr", "a+") as f:
    for line in tgt_fr:
        f.write(line + "\n")

Next steps:

- Add columns `is_context_sensitive` and `is_supporting_context` to scores dataframes. First, map tags to tokens tokenized using `re.split` on word boundaries. Then use code from `peviz` to remap scores to subword tokens following model tokenization. `</s>` is assigned an extra 0 at the end by default.

- Add `likelihood_ratio` metric to Inseq and allow `top_p` for `kl_divergence` metric. Add `kl_div_0.90` and `likelihood_ratio` columns to scores dataframes.

- Plot scores distribution for different metrics

SCAT Issues:
- Malformed tags (easy to fix)
- Impersonal "it" vs. ambiguous "it" (hard to fix automatically)
- Wrong annotations, bias on context (cannot be fixed automatically)

Method's issues:
- Should the gold reference be used as the target for comparison? Alternative: generate the target with the model, align it with the gold reference, and measure the metric on the generated target.

In [None]:
import inseq

# Load the target-context Marian model with the "dummy" attribution method;
# the attribution method is chosen per call further down.
model = inseq.load_model("context-mt/scat-marian-big-target-ctx4-cwd0-en-fr", "dummy")

In [None]:
ex = scat[11]
src = ex["context_en"] + "<brk> " + ex["en"]
# Keep only the text generated after <brk>, i.e. the translation of the current sentence.
ctx_gen = model.generate(src, max_new_tokens=512, skip_special_tokens=False)[0].split("<brk>")[1]
# Remove the end-of-sequence token as a literal substring: str.strip("</s>") would
# strip any of the characters '<', '/', 's', '>' from both ends and could eat
# letters of the translation itself (e.g. a final "s").
ctx_gen = ctx_gen.replace("</s>", "").strip()
# Context-less translation, hard-coded here instead of being generated.
noctx_gen = "S'ils le font, ils ne se vantent pas aussi fort." #model.generate(ex["en"], max_new_tokens=512)[0]
print(ctx_gen)
print(noctx_gen)

In [None]:
# Show the target-side tokenization of both translations, to compare them token by token.
print(model.encode(ctx_gen, as_targets=True).input_tokens[0])
print(model.encode(noctx_gen, as_targets=True).input_tokens[0])

In [None]:
# Contrastive attribution of the contextual translation against the
# context-less one, with alignments between the two targets set to "auto".
out_tgt = model.attribute(
    src,
    ctx_gen,
    method="input_x_gradient",
    attribute_target=True,
    attributed_fn="contrast_prob_diff",
    step_scores=["contrast_prob_diff"],
    contrast_targets=noctx_gen,
    contrast_targets_alignments="auto"
    
)
# Weight the attributions by the contrast_prob_diff step score before displaying them.
out_tgt.weight_attributions("contrast_prob_diff")
out_tgt.show()

In [None]:
import inseq
# Rebinds `model` to the Marian big model (the cwd1 checkpoint, without "target"
# in its name) with the saliency attribution method.
model = inseq.load_model("context-mt/scat-marian-big-ctx4-cwd1-en-fr", "saliency")

In [None]:
# Scratch cell: manual contrastive attribution on a single hand-written example.
# NOTE: the first src/tgt pair below is overwritten by the second example
# further down; only the second one is actually attributed.
src = "Yes. And an egg, but only once a week. How's that? Good. Where did these eggs come from?<brk> Are they the hotel's?"
tgt = "Oui. Et un oeuf, mais seulement une fois par semaine. Comment va-t-il? Bon. D'où proviennent ces oeuf?<brk> Ils sont à l'hôtel?"

has_pad = True

tgt_prefix, tgt_curr = tgt.split("<brk>")
tgt_prefix = tgt_prefix.strip() + "<brk>"
tgt_curr = tgt_curr.strip()

src = "Okay, okay. Calm down. First of all, you did the right thing by hiding under this table. Secondly, your man is here. I'm gonna take care of this for us. I've been playing Xbox for years.<brk> I'm really good at fixing it when it freezes."
#tgt_prefix = "D'accord, d'accord. Calme-toi. D'abord, tu as fait ce qu'il fallait en te cachant sous cette table. Deuxièmement, ton homme est là. Je vais m'en occuper pour nous. Je joue à la Xbox depuis des années.<brk>"
tgt_curr = "Je suis vraiment doué pour la réparer quand elle gèle."
# No target context in this configuration: the target is the current sentence only.
tgt = tgt_curr #tgt_prefix + " " + tgt_curr

src_curr = src.split("<brk>")[1].strip()
offset = 0 #len(model.encode(tgt_prefix, as_targets=True).input_tokens[0])
curr_len = len(model.encode(tgt_curr, as_targets=True).input_tokens[0])

# Discount one token from the encoded length when has_pad is set.
if has_pad:
#    offset -= 1
    curr_len -= 1

out = model.attribute(
    src,
    tgt,
    attribute_target=True,
    show_progress=False,
    #attr_pos_start=offset,
    step_scores=["probability", "contrast_prob", "contrast_prob_diff"],
    attributed_fn="contrast_prob_diff",
    contrast_sources=src_curr,
    contrast_targets=tgt_curr,
    # With offset 0 this pairs every target position with itself.
    contrast_targets_alignments=[(idx_full, idx_curr) for idx_curr, idx_full in enumerate(range(offset, offset + curr_len), start=0)],
)
out.show(normalize=False)

In [1]:
import pandas as pd
import torch
from typing import Dict

def contributions_attribute_fn(ex: Dict[str, str], idx: int) -> pd.DataFrame:
    """Attribute one example contrastively and aggregate contributions per input span.

    Uses the module-level ``model``. The contextual input is attributed with
    ``contrast_prob_diff`` against the context-less pair, token attributions
    are aggregated into spans and stored as ``src_*_attr`` / ``tgt_*_attr``
    step scores, and the aggregated attribution is displayed.

    NOTE(review): the function returns the aggregated attribution output
    (``aggr_out``), not a ``pd.DataFrame`` as the annotation says, and ``idx``
    is not used in the body.

    Args:
        ex: Example record with keys ``src_en``, ``src_en_ctx``, ``tgt_fr``
            and ``tgt_fr_ctx`` (``None`` when there is no target context).
        idx: Example index (unused).

    Returns:
        The aggregated attribution output with the additional step scores.
    """
    has_target_context = ex["tgt_fr_ctx"] is not None
    full_src = ex["src_en_ctx"] + " " + ex["src_en"]
    full_tgt = ex["tgt_fr_ctx"] + " " + ex["tgt_fr"] if ex["tgt_fr_ctx"] else ex["tgt_fr"]
    # Number of target-context tokens to skip before attribution starts.
    offset = len(model.encode(ex["tgt_fr_ctx"], as_targets=True).input_tokens[0]) - 1 if has_target_context else 0 # pad
    curr_len = len(model.encode(ex["tgt_fr"], as_targets=True).input_tokens[0]) - 1 # pad
    out = model.attribute(
        full_src,
        full_tgt,
        attribute_target=True,
        show_progress=False,
        attr_pos_start=offset if has_target_context else None,
        attributed_fn="contrast_prob_diff",
        contrast_sources=ex["src_en"],
        contrast_targets=ex["tgt_fr"],
        # Pair each position of the contextual target with the matching
        # position of the context-less target.
        contrast_targets_alignments=[
            (idx_full, idx_curr) 
            for idx_curr, idx_full in enumerate(range(offset, offset + curr_len), start=1 if has_target_context else 0)
        ],
    )
    has_lang_tag = out[0].source[0].token == "en_XX" and out[0].target[0].token == "fr_XX"
    aggr_args = {}
    src_brk_idx = [t.token for t in out[0].source].index("<brk>")
    lang_tag_offset = 1 if has_lang_tag else 0
    # Source spans: context (before <brk>) and current sentence (after <brk>, before the last token).
    aggr_args["source_spans"] = [(lang_tag_offset,src_brk_idx), (src_brk_idx+1,len(out[0].source) - 1)]
    if has_target_context:
        tgt_brk_idx = [t.token for t in out[0].target].index("<brk>")
        aggr_args["target_spans"] = [(lang_tag_offset,tgt_brk_idx)]
    aggr_out = out.aggregate("spans", **aggr_args).aggregate()
    assert aggr_out[0].source_attributions.size(0) == 4 + lang_tag_offset, (
        f"Expected {4 + lang_tag_offset} source tokens but found {aggr_out[0].source_attributions.size(0)} "
        f"instead: {aggr_out[0].source}"
    )
    if has_lang_tag:
        aggr_out[0].step_scores["src_langtag_attr"] = aggr_out[0].source_attributions[0, :]
    aggr_out[0].step_scores["src_ctx_attr"] = aggr_out[0].source_attributions[0 + lang_tag_offset, :]
    aggr_out[0].step_scores["src_brk_attr"] = aggr_out[0].source_attributions[1 + lang_tag_offset, :]
    aggr_out[0].step_scores["src_curr_attr"] = aggr_out[0].source_attributions[2 + lang_tag_offset, :]
    aggr_out[0].step_scores["src_eos_attr"] = aggr_out[0].source_attributions[3 + lang_tag_offset, :]
    tgt_curr_start_idx = 0
    if has_lang_tag:
        aggr_out[0].step_scores["tgt_langtag_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx, :]
        tgt_curr_start_idx += 1
    if has_target_context:
        aggr_out[0].step_scores["tgt_ctx_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx, :]
        aggr_out[0].step_scores["tgt_brk_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx + 1, :]
        tgt_curr_start_idx += 2
    # Everything after the context spans belongs to the current target sentence.
    aggr_out[0].step_scores["tgt_curr_attr"] = aggr_out[0].target_attributions[tgt_curr_start_idx:, :].nansum(axis=0)
    aggr_out.show(do_aggregation=False)
    # Sanity check: the span contributions of every generation step must sum to 1.
    assert torch.allclose(torch.stack(list(aggr_out[0].step_scores.values()), dim=1).nansum(axis=1), torch.ones_like(aggr_out[0].step_scores["src_ctx_attr"]))
    return aggr_out

In [None]:
import inseq

# Load the model that contributions_attribute_fn picks up through the
# module-level name `model`; the language codes are set on the tokenizer.
#model = inseq.load_model("context-mt/scat-marian-big-ctx4-cwd1-en-fr", "saliency")
model = inseq.load_model("context-mt/scat-mbart50-1toM-ctx4-cwd1-en-fr", "saliency")
model.tokenizer.src_lang = "en_XX"
model.tokenizer.tgt_lang = "fr_XX"

# Example without target context (tgt_fr_ctx is None).
ex_no_tgt_ctx = {
    "src_en": "I'm really good at fixing it when it freezes.",
    "src_en_ctx": "Okay, okay. Calm down. First of all, you did the right thing by hiding under this table. Secondly, your man is here. I'm gonna take care of this for us. I've been playing Xbox for years.<brk>",
    "tgt_fr": "Je suis vraiment doué pour la réparer quand elle gèle.",
    "tgt_fr_ctx": None,
}

out = contributions_attribute_fn(ex_no_tgt_ctx, 0)

#model = inseq.load_model("context-mt/scat-marian-big-target-ctx4-cwd0-en-fr", "saliency")

# Same example with a target context; the call that would use it is commented out below.
ex_tgt_ctx = {
    "src_en": "I'm really good at fixing it when it freezes.",
    "src_en_ctx": "Okay, okay. Calm down. First of all, you did the right thing by hiding under this table. Secondly, your man is here. I'm gonna take care of this for us. I've been playing Xbox for years.<brk>",
    "tgt_fr": "Je suis vraiment doué pour la réparer quand elle gèle.",
    "tgt_fr_ctx": "D'accord, d'accord. Calme-toi. D'abord, tu as fait ce qu'il fallait en te cachant sous cette table. Deuxièmement, ton homme est là. Je vais m'en occuper pour nous. Je joue à la Xbox depuis des années.<brk>"
}

#out = contributions_attribute_fn(ex_tgt_ctx, 0)

In [10]:
# Add an `is_correct` column to the scores file: 1 for examples where the model
# correctly disambiguates the gender in the target sentence, 0 otherwise.
# Meant for analyzing data folds, not as a feature (it is unavailable at test time).
import pandas as pd

model_name = "scat-mbart50-1toM-target-ctx4-cwd0"
examples = pd.read_csv(f"translations/processed_examples/{model_name}-en-fr.tsv", sep="\t")
# An example counts as correct when at least one aligned gender annotation is positive.
is_correct = [
    1 if sum(get_aligned_gender_annotations(ref, contrast, mt)) > 0 else 0
    for ref, contrast, mt in zip(scat["fr"], scat["contrast_fr"], examples["tgt_fr"])
]
correct_df = pd.DataFrame({"example_idx": list(range(len(is_correct))), "is_correct": is_correct})
df = pd.read_csv(f"translations/scores/{model_name}-scores-gen.tsv", sep="\t")
# Drop a stale column from a previous run before merging the fresh one.
if "is_correct" in df.columns:
    df = df.drop(columns=["is_correct"])
df = df.merge(correct_df, on="example_idx", how="left")
df.to_csv(f"translations/scores/temp/{model_name}-scores-gen.tsv", sep="\t", index=False)