In [1]:
!pip install pandas

Collecting pandas
  Using cached pandas-1.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.5 MB)
Installing collected packages: pandas
Successfully installed pandas-1.3.1
You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.[0m


In [2]:
import os
import torch 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [3]:
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
cos = torch.nn.CosineSimilarity()
import pandas as pd
from fairseq.data import Dictionary
import numpy as np
from fairseq.models.bart import BARTModel
from collections import defaultdict
from numpy import linalg as LA

In [4]:
from typing import List, Dict, Iterator, Tuple, Any, Optional

def forward_decoder(
    ensemble_model,
    tokens,
    encoder_outs,
    incremental_states,
    temperature: float = 1.0,
):
    log_probs = []
    avg_attn: Optional[Tensor] = None
    encoder_out: Optional[EncoderOut] = None
    for i, model in enumerate(ensemble_model.models):
        if ensemble_model.has_encoder():
            encoder_out = encoder_outs[i]
        # decode each model
        if ensemble_model.has_incremental_states():
            decoder_out = _forward_decoder(
                model.decoder,
                tokens,
                encoder_out=encoder_out,
                incremental_state=incremental_states[i],
            )
#             decoder_out = model.decoder.forward(
#                 tokens,
#                 encoder_out=encoder_out,
#                 incremental_state=incremental_states[i],
#             )
        else:
            decoder_out = _forward_decoder(model.decoder, tokens, encoder_out=encoder_out)
#             decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)

        attn: Optional[Tensor] = None
        decoder_len = len(decoder_out)
        if decoder_len > 1 and decoder_out[1] is not None:
            if isinstance(decoder_out[1], Tensor):
                attn = decoder_out[1]
            else:
                attn_holder = decoder_out[1]["attn"]
                if isinstance(attn_holder, Tensor):
                    attn = attn_holder
                elif attn_holder is not None:
                    attn = attn_holder[0]
            if attn is not None:
                attn = attn[:, -1, :]

        decoder_out_tuple = (
            decoder_out[0][:, -1:, :].div_(temperature),
            None if decoder_len <= 1 else decoder_out[1],
        )

        probs = model.get_normalized_probs(
            decoder_out_tuple, log_probs=True, sample=None
        )
        probs = probs[:, -1, :]
        if ensemble_model.models_size == 1:
            return decoder_out, probs, attn

        log_probs.append(probs)
        if attn is not None:
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
    avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
        ensemble_model.models_size
    )
    if avg_attn is not None:
        avg_attn.div_(ensemble_model.models_size)
    return decoder_out, avg_probs, avg_attn

def _generate(
    generator,
    sample,
    prefix_tokens=None,
    constraints=None,
    bos_token=None,
):
    incremental_states = torch.jit.annotate(
        List[Dict[str, Dict[str, Optional[Tensor]]]],
        [
            torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
            for i in range(generator.model.models_size)
        ],
    )
    net_input = sample["net_input"]

    if 'src_tokens' in net_input:
        src_tokens = net_input['src_tokens']
        # length of the source text being the character length except EndOfSentence and pad
        src_lengths = (src_tokens.ne(generator.eos) & src_tokens.ne(generator.pad)).long().sum(dim=1)
    elif 'source' in net_input:
        src_tokens = net_input['source']
        src_lengths = (
            net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1)
            if net_input['padding_mask'] is not None
            else torch.tensor(src_tokens.size(-1)).to(src_tokens)
        )
    else:
        raise Exception('expected src_tokens or source in net input')

    # bsz: total number of sentences in beam
    # Note that src_tokens may have more than 2 dimenions (i.e. audio features)
    bsz, src_len = src_tokens.size()[:2]
    beam_size = generator.beam_size

    if constraints is not None and not generator.search.supports_constraints:
        raise NotImplementedError("Target-side constraints were provided, but search method doesn't support them")

    # Initialize constraints, when active
    generator.search.init_constraints(constraints, beam_size)

    max_len: int = -1
    if generator.match_source_len:
        max_len = src_lengths.max().item()
    else:
        max_len = min(
            int(generator.max_len_a * src_len + generator.max_len_b),
            # exclude the EOS marker
            generator.model.max_decoder_positions() - 1,
        )
    assert (
        generator.min_len <= max_len
    ), "min_len cannot be larger than max_len, please adjust these!"
    # compute the encoder output for each beam
#     encoder_outs = generator.model.forward_encoder(net_input, return_all_hiddens=True)

#     # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
#     new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
#     new_order = new_order.to(src_tokens.device).long()
#     encoder_outs = generator.model.reorder_encoder_out(encoder_outs, new_order)

    encoder_outs = [generator.model.models[0].encoder(src_tokens,  None, return_all_hiddens=True)]
    
    # ensure encoder_outs is a List.
    assert encoder_outs is not None

    # initialize buffers
    scores = (
        torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
    )  # +1 for eos; pad is never chosen for scoring
    tokens = (
        torch.zeros(bsz * beam_size, max_len + 2)
        .to(src_tokens)
        .long()
        .fill_(generator.pad)
    )  # +2 for eos and pad
    tokens[:, 0] = generator.eos if bos_token is None else bos_token
    attn: Optional[Tensor] = None

    # A list that indicates candidates that should be ignored.
    # For example, suppose we're sampling and have already finalized 2/5
    # samples. Then cands_to_ignore would mark 2 positions as being ignored,
    # so that we only finalize the remaining 3 samples.
    cands_to_ignore = (
        torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
    )  # forward and backward-compatible False mask

    # list of completed sentences
    finalized = torch.jit.annotate(
        List[List[Dict[str, Tensor]]],
        [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
    )  # contains lists of dictionaries of infomation about the hypothesis being finalized at each step

    finished = [
        False for i in range(bsz)
    ]  # a boolean array indicating if the sentence at the index is finished or not
    num_remaining_sent = bsz  # number of sentences remaining

    # number of candidate hypos per step
    cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

    # offset arrays for converting between different indexing schemes
    bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
    cand_offsets = torch.arange(0, cand_size).type_as(tokens)

    reorder_state: Optional[Tensor] = None
    batch_idxs: Optional[Tensor] = None
        
    decoder_outs = []
    for step in range(max_len + 1):  # one extra step for EOS marker
#         print(tokens[:, : step + 1])
        decoder_out, lprobs, avg_attn_scores = forward_decoder(
            generator.model,
            tokens[:, : step + 1],
            encoder_outs,
            incremental_states,
            generator.temperature,
        )
        
#         print(lprobs.shape)
#         print(lprobs)
        decoder_outs.append(decoder_out)
        
        lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

        lprobs[:, generator.pad] = -math.inf  # never select pad
        lprobs[:, generator.unk] -= generator.unk_penalty  # apply unk penalty

        # handle max length constraint
        if step >= max_len:
            lprobs[:, : generator.eos] = -math.inf
            lprobs[:, generator.eos + 1 :] = -math.inf

        # handle prefix tokens (possibly with different lengths)
        if (
            prefix_tokens is not None
            and step < prefix_tokens.size(1)
            and step < max_len
        ):
            lprobs, tokens, scores = generator._prefix_tokens(
                step, lprobs, scores, tokens, prefix_tokens, beam_size
            )
        elif step < generator.min_len:
            # minimum length constraint (does not apply if using prefix_tokens)
            lprobs[:, generator.eos] = -math.inf

        # Record attention scores, only support avg_attn_scores is a Tensor
        if avg_attn_scores is not None:
            if attn is None:
                attn = torch.empty(
                    bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                ).to(scores)
            attn[:, :, step + 1].copy_(avg_attn_scores)

        scores = scores.type_as(lprobs)
        eos_bbsz_idx = torch.empty(0).to(
            tokens
        )  # indices of hypothesis ending with eos (finished sentences)
        eos_scores = torch.empty(0).to(
            scores
        )  # scores of hypothesis ending with eos (finished sentences)

        if generator.should_set_src_lengths:
            generator.search.set_src_lengths(src_lengths)

        if generator.no_repeat_ngram_size > 0:
            lprobs = generator._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)

        # Shape: (batch, cand_size)
        cand_scores, cand_indices, cand_beams = generator.search.step(
            step,
            lprobs.view(bsz, -1, generator.vocab_size),
            scores.view(bsz, beam_size, -1)[:, :, :step],
        )

        # cand_bbsz_idx contains beam indices for the top candidate
        # hypotheses, with a range of values: [0, bsz*beam_size),
        # and dimensions: [bsz, cand_size]
        cand_bbsz_idx = cand_beams.add(bbsz_offsets)

        # finalize hypotheses that end in eos
        # Shape of eos_mask: (batch size, beam size)
        eos_mask = cand_indices.eq(generator.eos) & cand_scores.ne(-math.inf)
        eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

        # only consider eos when it's among the top beam_size indices
        # Now we know what beam item(s) to finish
        # Shape: 1d list of absolute-numbered
        eos_bbsz_idx = torch.masked_select(
            cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
        )

        finalized_sents: List[int] = []
        if eos_bbsz_idx.numel() > 0:
            eos_scores = torch.masked_select(
                cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents = generator.finalize_hypos(
                step,
                eos_bbsz_idx,
                eos_scores,
                tokens,
                scores,
                finalized,
                finished,
                beam_size,
                attn,
                src_lengths,
                max_len,
            )
            num_remaining_sent -= len(finalized_sents)

        assert num_remaining_sent >= 0
        if num_remaining_sent == 0:
            break
        assert step < max_len

        # Remove finalized sentences (ones for which {beam_size}
        # finished hypotheses have been generated) from the batch.
        if len(finalized_sents) > 0:
            new_bsz = bsz - len(finalized_sents)

            # construct batch_idxs which holds indices of batches to keep for the next pass
            batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device)
            batch_mask[finalized_sents] = False
            # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
            batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask)

            # Choose the subset of the hypothesized constraints that will continue
            generator.search.prune_sentences(batch_idxs)

            eos_mask = eos_mask[batch_idxs]
            cand_beams = cand_beams[batch_idxs]
            bbsz_offsets.resize_(new_bsz, 1)
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)
            cand_scores = cand_scores[batch_idxs]
            cand_indices = cand_indices[batch_idxs]

            if prefix_tokens is not None:
                prefix_tokens = prefix_tokens[batch_idxs]
            src_lengths = src_lengths[batch_idxs]
            cands_to_ignore = cands_to_ignore[batch_idxs]

            scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
            tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
            if attn is not None:
                attn = attn.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, attn.size(1), -1
                )
            bsz = new_bsz
        else:
            batch_idxs = None

        # Set active_mask so that values > cand_size indicate eos hypos
        # and values < cand_size indicate candidate active hypos.
        # After, the min values per row are the top candidate active hypos

        # Rewrite the operator since the element wise or is not supported in torchscript.

        eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
        active_mask = torch.add(
            eos_mask.type_as(cand_offsets) * cand_size,
            cand_offsets[: eos_mask.size(1)],
        )

        # get the top beam_size active hypotheses, which are just
        # the hypos with the smallest values in active_mask.
        # {active_hypos} indicates which {beam_size} hypotheses
        # from the list of {2 * beam_size} candidates were
        # selected. Shapes: (batch size, beam size)
        new_cands_to_ignore, active_hypos = torch.topk(
            active_mask, k=beam_size, dim=1, largest=False
        )

        # update cands_to_ignore to ignore any finalized hypos.
        cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
        # Make sure there is at least one active item for each sentence in the batch.
        assert (~cands_to_ignore).any(dim=1).all()

        # update cands_to_ignore to ignore any finalized hypos

        # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
        # can be selected more than once).
        active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
        active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

        active_bbsz_idx = active_bbsz_idx.view(-1)
        active_scores = active_scores.view(-1)

        # copy tokens and scores for active hypotheses

        # Set the tokens for each beam (can select the same row more than once)
        tokens[:, : step + 1] = torch.index_select(
            tokens[:, : step + 1], dim=0, index=active_bbsz_idx
        )
        # Select the next token for each of them
        tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
            cand_indices, dim=1, index=active_hypos
        )
        if step > 0:
            scores[:, :step] = torch.index_select(
                scores[:, :step], dim=0, index=active_bbsz_idx
            )
        scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
            cand_scores, dim=1, index=active_hypos
        )

        # Update constraints based on which candidates were selected for the next beam
        generator.search.update_constraints(active_hypos)

        # copy attention for active hypotheses
        if attn is not None:
            attn[:, :, : step + 2] = torch.index_select(
                attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
            )

        # reorder incremental state in decoder
        reorder_state = active_bbsz_idx

    # sort by score descending
    for sent in range(len(finalized)):
        scores = torch.tensor([float(elem["score"].item()) for elem in finalized[sent]])
        _, sorted_scores_indices = torch.sort(scores, descending=True)
        finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
        finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], finalized[sent])
    return encoder_outs, decoder_outs, finalized

from torch import Tensor
import math

In [5]:
def _forward_decoder(
    decoder,
    prev_output_tokens,
    encoder_out,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    full_context_alignment: bool = False,
    alignment_layer: Optional[int] = None,
    alignment_heads: Optional[int] = None,
):
    if alignment_layer is None:
        alignment_layer = decoder.num_layers - 1

    # embed positions
    positions = (
        decoder.embed_positions(
            prev_output_tokens, incremental_state=incremental_state
        )
        if decoder.embed_positions is not None
        else None
    )

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = decoder.embed_scale * decoder.embed_tokens(prev_output_tokens)

    if decoder.quant_noise is not None:
        x = decoder.quant_noise(x)

    if decoder.project_in_dim is not None:
        x = decoder.project_in_dim(x)

    if positions is not None:
        x += positions

    if decoder.layernorm_embedding is not None:
        x = decoder.layernorm_embedding(x)

    x = decoder.dropout_module(x)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    self_attn_padding_mask: Optional[Tensor] = None
    if decoder.cross_self_attention or prev_output_tokens.eq(decoder.padding_idx).any():
        self_attn_padding_mask = prev_output_tokens.eq(decoder.padding_idx)

    # decoder layers
    attn: Optional[Tensor] = None
    inner_states: List[Optional[Tensor]] = [x]
    self_dec_attns = []
    for idx, layer in enumerate(decoder.layers):
        if incremental_state is None and not full_context_alignment:
            self_attn_mask = decoder.buffered_future_mask(x)
        else:
            self_attn_mask = None

        x, layer_attn, self_dec_attn, _ = _forward_layer(
            layer,
            x,
            encoder_out.encoder_out if encoder_out is not None else None,
            encoder_out.encoder_padding_mask if encoder_out is not None else None,
            incremental_state,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
            need_attn=bool((idx == alignment_layer)),
            need_head_weights=bool((idx == alignment_layer)),
        )
        inner_states.append(x)
        self_dec_attns.append(self_dec_attn)
        if layer_attn is not None and idx == alignment_layer:
            attn = layer_attn.float().to(x)

    if attn is not None:
        if alignment_heads is not None:
            attn = attn[:alignment_heads]

        # average probabilities over heads
        attn = attn.mean(dim=0)

    if decoder.layer_norm is not None:
        x = decoder.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if decoder.project_out_dim is not None:
        x = decoder.project_out_dim(x)
    
    x = decoder.output_layer(x)

    return x, {"attn": [attn], "inner_states": inner_states, "self_dec_attns": self_dec_attns}

In [6]:
def _forward_layer(
    layer,
    x,
    encoder_out: Optional[torch.Tensor] = None,
    encoder_padding_mask: Optional[torch.Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    prev_self_attn_state: Optional[List[torch.Tensor]] = None,
    prev_attn_state: Optional[List[torch.Tensor]] = None,
    self_attn_mask: Optional[torch.Tensor] = None,
    self_attn_padding_mask: Optional[torch.Tensor] = None,
    need_attn: bool = False,
    need_head_weights: bool = False,
):
    if need_head_weights:
            need_attn = True

    residual = x
    if layer.normalize_before:
        x = layer.self_attn_layer_norm(x)
    if prev_self_attn_state is not None:
        prev_key, prev_value = prev_self_attn_state[:2]
        saved_state: Dict[str, Optional[Tensor]] = {
            "prev_key": prev_key,
            "prev_value": prev_value,
        }
        if len(prev_self_attn_state) >= 3:
            saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
        assert incremental_state is not None
        layer.self_attn._set_input_buffer(incremental_state, saved_state)
    _self_attn_input_buffer = layer.self_attn._get_input_buffer(incremental_state)
    if layer.cross_self_attention and not (
        incremental_state is not None
        and _self_attn_input_buffer is not None
        and "prev_key" in _self_attn_input_buffer
    ):
        if self_attn_mask is not None:
            assert encoder_out is not None
            self_attn_mask = torch.cat(
                (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
            )
        if self_attn_padding_mask is not None:
            if encoder_padding_mask is None:
                assert encoder_out is not None
                encoder_padding_mask = self_attn_padding_mask.new_zeros(
                    encoder_out.size(1), encoder_out.size(0)
                )
            self_attn_padding_mask = torch.cat(
                (encoder_padding_mask, self_attn_padding_mask), dim=1
            )
        assert encoder_out is not None
        y = torch.cat((encoder_out, x), dim=0)
    else:
        y = x

    x, attn = layer.self_attn(
        query=x,
        key=y,
        value=y,
        key_padding_mask=self_attn_padding_mask,
        incremental_state=incremental_state,
        need_weights=False,
        attn_mask=self_attn_mask,
    )
    x = layer.dropout_module(x)
    x = residual + x
    if not layer.normalize_before:
        x = layer.self_attn_layer_norm(x)
        
    self_dec_attn = copy.copy(x)

    if layer.encoder_attn is not None and encoder_out is not None:
        residual = x
        if layer.normalize_before:
            x = layer.encoder_attn_layer_norm(x)
        if prev_attn_state is not None:
            prev_key, prev_value = prev_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_attn_state[2]
            assert incremental_state is not None
            layer.encoder_attn._set_input_buffer(incremental_state, saved_state)

        x, attn = layer.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=need_attn or (not layer.training and layer.need_attn),
            need_head_weights=need_head_weights,
        )
        x = layer.dropout_module(x)
        x = residual + x
        if not layer.normalize_before:
            x = layer.encoder_attn_layer_norm(x)

    residual = x
    if layer.normalize_before:
        x = layer.final_layer_norm(x)

    x = layer.activation_fn(layer.fc1(x))
    x = layer.activation_dropout_module(x)
    x = layer.fc2(x)
    x = layer.dropout_module(x)
    x = residual + x
    if not layer.normalize_before:
        x = layer.final_layer_norm(x)
    if layer.onnx_trace and incremental_state is not None:
        saved_state = layer.self_attn._get_input_buffer(incremental_state)
        assert saved_state is not None
        if self_attn_padding_mask is not None:
            self_attn_state = [
                saved_state["prev_key"],
                saved_state["prev_value"],
                saved_state["prev_key_padding_mask"],
            ]
        else:
            self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
        return x, attn, self_attn_state, None
    return x, attn, self_dec_attn, None

In [7]:

def _generate_unsupervised(
    enc_generator,
    dec_generator,
    sample,
    prefix_tokens=None,
    constraints=None,
    bos_token=None,
):
    incremental_states = torch.jit.annotate(
        List[Dict[str, Dict[str, Optional[Tensor]]]],
        [
            torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
            for i in range(dec_generator.model.models_size)
        ],
    )
    net_input = sample["net_input"]

    if 'src_tokens' in net_input:
        src_tokens = net_input['src_tokens']
        # length of the source text being the character length except EndOfSentence and pad
        src_lengths = (src_tokens.ne(dec_generator.eos) & src_tokens.ne(dec_generator.pad)).long().sum(dim=1)
    elif 'source' in net_input:
        src_tokens = net_input['source']
        src_lengths = (
            net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1)
            if net_input['padding_mask'] is not None
            else torch.tensor(src_tokens.size(-1)).to(src_tokens)
        )
    else:
        raise Exception('expected src_tokens or source in net input')

    # bsz: total number of sentences in beam
    # Note that src_tokens may have more than 2 dimenions (i.e. audio features)
    bsz, src_len = src_tokens.size()[:2]
    beam_size = dec_generator.beam_size

    if constraints is not None and not dec_generator.search.supports_constraints:
        raise NotImplementedError("Target-side constraints were provided, but search method doesn't support them")

    # Initialize constraints, when active
    dec_generator.search.init_constraints(constraints, beam_size)

    max_len: int = -1
    if dec_generator.match_source_len:
        max_len = src_lengths.max().item()
    else:
        max_len = min(
            int(dec_generator.max_len_a * src_len + dec_generator.max_len_b),
            # exclude the EOS marker
            dec_generator.model.max_decoder_positions() - 1,
        )
    assert (
        dec_generator.min_len <= max_len
    ), "min_len cannot be larger than max_len, please adjust these!"
    # compute the encoder output for each beam
    encoder_outs = enc_generator.model.forward_encoder(net_input)

    # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
    new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
    new_order = new_order.to(src_tokens.device).long()
    encoder_outs = enc_generator.model.reorder_encoder_out(encoder_outs, new_order)

#     encoder_outs = [enc_model.model.encoder(src_tokens,  None, return_all_hiddens=True)]
    
    # ensure encoder_outs is a List.
    assert encoder_outs is not None

    # initialize buffers
    scores = (
        torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
    )  # +1 for eos; pad is never chosen for scoring
    tokens = (
        torch.zeros(bsz * beam_size, max_len + 2)
        .to(src_tokens)
        .long()
        .fill_(dec_generator.pad)
    )  # +2 for eos and pad
    tokens[:, 0] = dec_generator.eos if bos_token is None else bos_token
    attn: Optional[Tensor] = None

    # A list that indicates candidates that should be ignored.
    # For example, suppose we're sampling and have already finalized 2/5
    # samples. Then cands_to_ignore would mark 2 positions as being ignored,
    # so that we only finalize the remaining 3 samples.
    cands_to_ignore = (
        torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
    )  # forward and backward-compatible False mask

    # list of completed sentences
    finalized = torch.jit.annotate(
        List[List[Dict[str, Tensor]]],
        [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
    )  # contains lists of dictionaries of infomation about the hypothesis being finalized at each step

    finished = [
        False for i in range(bsz)
    ]  # a boolean array indicating if the sentence at the index is finished or not
    num_remaining_sent = bsz  # number of sentences remaining

    # number of candidate hypos per step
    cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

    # offset arrays for converting between different indexing schemes
    bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
    cand_offsets = torch.arange(0, cand_size).type_as(tokens)

    reorder_state: Optional[Tensor] = None
    batch_idxs: Optional[Tensor] = None
        
    decoder_outs = []
    for step in range(max_len + 1):  # one extra step for EOS marker
        
        decoder_out, lprobs, avg_attn_scores = forward_decoder(
            dec_generator.model,
            tokens[:, : step + 1],
            encoder_outs,
            incremental_states,
            dec_generator.temperature,
        )
        
        decoder_outs.append(decoder_out)
        
        lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

        lprobs[:, dec_generator.pad] = -math.inf  # never select pad
        lprobs[:, dec_generator.unk] -= dec_generator.unk_penalty  # apply unk penalty

        # handle max length constraint
        if step >= max_len:
            lprobs[:, : dec_generator.eos] = -math.inf
            lprobs[:, dec_generator.eos + 1 :] = -math.inf

        # handle prefix tokens (possibly with different lengths)
        if (
            prefix_tokens is not None
            and step < prefix_tokens.size(1)
            and step < max_len
        ):
            lprobs, tokens, scores = dec_generator._prefix_tokens(
                step, lprobs, scores, tokens, prefix_tokens, beam_size
            )
        elif step < dec_generator.min_len:
            # minimum length constraint (does not apply if using prefix_tokens)
            lprobs[:, dec_generator.eos] = -math.inf

        # Record attention scores, only support avg_attn_scores is a Tensor
        if avg_attn_scores is not None:
            if attn is None:
                attn = torch.empty(
                    bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                ).to(scores)
            attn[:, :, step + 1].copy_(avg_attn_scores)

        scores = scores.type_as(lprobs)
        eos_bbsz_idx = torch.empty(0).to(
            tokens
        )  # indices of hypothesis ending with eos (finished sentences)
        eos_scores = torch.empty(0).to(
            scores
        )  # scores of hypothesis ending with eos (finished sentences)

        if dec_generator.should_set_src_lengths:
            dec_generator.search.set_src_lengths(src_lengths)

        if dec_generator.no_repeat_ngram_size > 0:
            lprobs = dec_generator._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)

        # Shape: (batch, cand_size)
        cand_scores, cand_indices, cand_beams = dec_generator.search.step(
            step,
            lprobs.view(bsz, -1, dec_generator.vocab_size),
            scores.view(bsz, beam_size, -1)[:, :, :step],
        )

        # cand_bbsz_idx contains beam indices for the top candidate
        # hypotheses, with a range of values: [0, bsz*beam_size),
        # and dimensions: [bsz, cand_size]
        cand_bbsz_idx = cand_beams.add(bbsz_offsets)

        # finalize hypotheses that end in eos
        # Shape of eos_mask: (batch size, beam size)
        eos_mask = cand_indices.eq(dec_generator.eos) & cand_scores.ne(-math.inf)
        eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

        # only consider eos when it's among the top beam_size indices
        # Now we know what beam item(s) to finish
        # Shape: 1d list of absolute-numbered
        eos_bbsz_idx = torch.masked_select(
            cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
        )

        finalized_sents: List[int] = []
        if eos_bbsz_idx.numel() > 0:
            eos_scores = torch.masked_select(
                cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents = dec_generator.finalize_hypos(
                step,
                eos_bbsz_idx,
                eos_scores,
                tokens,
                scores,
                finalized,
                finished,
                beam_size,
                attn,
                src_lengths,
                max_len,
            )
            num_remaining_sent -= len(finalized_sents)

        assert num_remaining_sent >= 0
        if num_remaining_sent == 0:
            break
        assert step < max_len

        # Remove finalized sentences (ones for which {beam_size}
        # finished hypotheses have been generated) from the batch.
        if len(finalized_sents) > 0:
            new_bsz = bsz - len(finalized_sents)

            # construct batch_idxs which holds indices of batches to keep for the next pass
            batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device)
            batch_mask[finalized_sents] = False
            # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
            batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask)

            # Choose the subset of the hypothesized constraints that will continue
            dec_generator.search.prune_sentences(batch_idxs)

            eos_mask = eos_mask[batch_idxs]
            cand_beams = cand_beams[batch_idxs]
            bbsz_offsets.resize_(new_bsz, 1)
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)
            cand_scores = cand_scores[batch_idxs]
            cand_indices = cand_indices[batch_idxs]

            if prefix_tokens is not None:
                prefix_tokens = prefix_tokens[batch_idxs]
            src_lengths = src_lengths[batch_idxs]
            cands_to_ignore = cands_to_ignore[batch_idxs]

            scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
            tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
            if attn is not None:
                attn = attn.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, attn.size(1), -1
                )
            bsz = new_bsz
        else:
            batch_idxs = None

        # Set active_mask so that values > cand_size indicate eos hypos
        # and values < cand_size indicate candidate active hypos.
        # After, the min values per row are the top candidate active hypos

        # Rewrite the operator since the element wise or is not supported in torchscript.

        eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
        active_mask = torch.add(
            eos_mask.type_as(cand_offsets) * cand_size,
            cand_offsets[: eos_mask.size(1)],
        )

        # get the top beam_size active hypotheses, which are just
        # the hypos with the smallest values in active_mask.
        # {active_hypos} indicates which {beam_size} hypotheses
        # from the list of {2 * beam_size} candidates were
        # selected. Shapes: (batch size, beam size)
        new_cands_to_ignore, active_hypos = torch.topk(
            active_mask, k=beam_size, dim=1, largest=False
        )

        # update cands_to_ignore to ignore any finalized hypos.
        cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
        # Make sure there is at least one active item for each sentence in the batch.
        assert (~cands_to_ignore).any(dim=1).all()

        # update cands_to_ignore to ignore any finalized hypos

        # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
        # can be selected more than once).
        active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
        active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

        active_bbsz_idx = active_bbsz_idx.view(-1)
        active_scores = active_scores.view(-1)

        # copy tokens and scores for active hypotheses

        # Set the tokens for each beam (can select the same row more than once)
        tokens[:, : step + 1] = torch.index_select(
            tokens[:, : step + 1], dim=0, index=active_bbsz_idx
        )
        # Select the next token for each of them
        tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
            cand_indices, dim=1, index=active_hypos
        )
        if step > 0:
            scores[:, :step] = torch.index_select(
                scores[:, :step], dim=0, index=active_bbsz_idx
            )
        scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
            cand_scores, dim=1, index=active_hypos
        )

        # Update constraints based on which candidates were selected for the next beam
        dec_generator.search.update_constraints(active_hypos)

        # copy attention for active hypotheses
        if attn is not None:
            attn[:, :, : step + 2] = torch.index_select(
                attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
            )

        # reorder incremental state in decoder
        reorder_state = active_bbsz_idx

    # sort by score descending
    for sent in range(len(finalized)):
        scores = torch.tensor([float(elem["score"].item()) for elem in finalized[sent]])
        _, sorted_scores_indices = torch.sort(scores, descending=True)
        finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
        finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], finalized[sent])
    return encoder_outs, decoder_outs, finalized


In [8]:
def gram_linear(x):
  """Compute Gram (kernel) matrix for a linear kernel.

  Args:
    x: A num_examples x num_features matrix of features.

  Returns:
    A num_examples x num_examples Gram matrix of examples.
  """
  return x.dot(x.T)


def gram_rbf(x, threshold=1.0):
  """Compute Gram (kernel) matrix for an RBF kernel.

  Args:
    x: A num_examples x num_features matrix of features.
    threshold: Fraction of median Euclidean distance to use as RBF kernel
      bandwidth. (This is the heuristic we use in the paper. There are other
      possible ways to set the bandwidth; we didn't try them.)

  Returns:
    A num_examples x num_examples Gram matrix of examples.
  """
  dot_products = x.dot(x.T)
  sq_norms = np.diag(dot_products)
  sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :]
  sq_median_distance = np.median(sq_distances)
  return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance))


def center_gram(gram, unbiased=False):
  """Center a symmetric Gram matrix.

  This is equvialent to centering the (possibly infinite-dimensional) features
  induced by the kernel before computing the Gram matrix.

  Args:
    gram: A num_examples x num_examples symmetric matrix.
    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
      estimate of HSIC. Note that this estimator may be negative.

  Returns:
    A symmetric matrix with centered columns and rows.
  """
  if not np.allclose(gram, gram.T):
    raise ValueError('Input must be a symmetric matrix.')
  gram = gram.copy()

  if unbiased:
    # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
    # L. (2014). Partial distance correlation with methods for dissimilarities.
    # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
    # stable than the alternative from Song et al. (2007).
    n = gram.shape[0]
    np.fill_diagonal(gram, 0)
    means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
    means -= np.sum(means) / (2 * (n - 1))
    gram -= means[:, None]
    gram -= means[None, :]
    np.fill_diagonal(gram, 0)
  else:
    means = np.mean(gram, 0, dtype=np.float64)
    means -= np.mean(means) / 2
    gram -= means[:, None]
    gram -= means[None, :]

  return gram


def cka(gram_x, gram_y, debiased=False):
  """Compute CKA.

  Args:
    gram_x: A num_examples x num_examples Gram matrix.
    gram_y: A num_examples x num_examples Gram matrix.
    debiased: Use unbiased estimator of HSIC. CKA may still be biased.

  Returns:
    The value of CKA between X and Y.
  """
  gram_x = center_gram(gram_x, unbiased=debiased)
  gram_y = center_gram(gram_y, unbiased=debiased)

  # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or
  # n*(n-3) (unbiased variant), but this cancels for CKA.
  scaled_hsic = gram_x.ravel().dot(gram_y.ravel())

  normalization_x = np.linalg.norm(gram_x)
  normalization_y = np.linalg.norm(gram_y)
  return scaled_hsic / (normalization_x * normalization_y)


def _debiased_dot_product_similarity_helper(
    xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y,
    n):
  """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
  # This formula can be derived by manipulating the unbiased estimator from
  # Song et al. (2007).
  return (
      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))


def feature_space_linear_cka(features_x, features_y, debiased=False):
  """Compute CKA with a linear kernel, in feature space.

  This is typically faster than computing the Gram matrix when there are fewer
  features than examples.

  Args:
    features_x: A num_examples x num_features matrix of features.
    features_y: A num_examples x num_features matrix of features.
    debiased: Use unbiased estimator of dot product similarity. CKA may still be
      biased. Note that this estimator may be negative.

  Returns:
    The value of CKA between X and Y.
  """
  features_x = features_x - np.mean(features_x, 0, keepdims=True)
  features_y = features_y - np.mean(features_y, 0, keepdims=True)

  dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2
  normalization_x = np.linalg.norm(features_x.T.dot(features_x))
  normalization_y = np.linalg.norm(features_y.T.dot(features_y))

  if debiased:
    n = features_x.shape[0]
    # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array.
    sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x)
    sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y)
    squared_norm_x = np.sum(sum_squared_rows_x)
    squared_norm_y = np.sum(sum_squared_rows_y)

    dot_product_similarity = _debiased_dot_product_similarity_helper(
        dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
        squared_norm_x, squared_norm_y, n)
    normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(
        normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,
        squared_norm_x, squared_norm_x, n))
    normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(
        normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,
        squared_norm_y, squared_norm_y, n))

  return dot_product_similarity / (normalization_x * normalization_y)

In [9]:
def load_model(path):
    data = torch.load(path)
    return data["model"]

def load_data(path):
    data = torch.load(path)
    return data

def load_dict(path: str) -> Dictionary:
    d = Dictionary.load(path)
    # for l in langs:
    d.add_symbol("<mask>")
    return d

def sent_to_ids(d, sent):
    tokens = sent.split()
    return tokens_to_ids(d, tokens)

def dec_cal_mean(bart, d, sent):
    token_ids = torch.tensor(sent_to_ids(d, sent), device=DEVICE)
    all_layers = bart.extract_features(token_ids, return_all_hiddens=True)
    return [x[0].mean(dim=0).tolist() for x in all_layers]

def enc_cal_mean(bart, d, sent):
    token_ids = torch.tensor([sent_to_ids(d, sent).tolist()], device=DEVICE)
    encoder_layers = bart.model.encoder(token_ids,  None, return_all_hiddens=True)
    embedding = encoder_layers.encoder_embedding
    encoder_out = encoder_layers.encoder_out
    all_layers = encoder_layers.encoder_states
    
    mean_layers = [torch.squeeze(embedding).mean(dim=0).tolist()]
    for x in all_layers:
        mean_layers.append(torch.squeeze(x).mean(dim=0).tolist())
    mean_layers.append(torch.squeeze(encoder_out).mean(dim=0).tolist())
        
    return mean_layers

def linear_HSIC(X, Y):
    L_X = np.dot(X, X.T)
    L_Y = np.dot(Y, Y.T)
    return np.sum(centering(L_X) * centering(L_Y))

def centering(K):
    n = K.shape[0]
    unit = np.ones([n, n])
    I = np.eye(n)
    H = I - unit / n

    return np.dot(np.dot(H, K), H)

def linear_CKA(X, Y):
    hsic = linear_HSIC(X, Y)
    var1 = np.sqrt(linear_HSIC(X, X))
    var2 = np.sqrt(linear_HSIC(Y, Y))

    return hsic / (var1 * var2)

def cal_cka_sim(matrix1, matrix2):
    dot = np.dot(matrix2.T, matrix1)
    LA_norm = LA.norm(dot) ** 2
    norm1 = LA.norm(np.dot(matrix1.T, matrix1))
    norm2 = LA.norm(np.dot(matrix2.T, matrix2))
    return LA_norm / (norm1 * norm2)

def enc_cka_sim(jp_bart, ft, d, sentences):
    jp_bart_d = defaultdict(list)
    ft_d = defaultdict(list)
    
    for n, sent in enumerate(sentences):
        
        if (n+1) % 200 == 0:
            print(f"done {n+1}")
        sent = sent.strip()
        jp_bart_vecs = enc_cal_mean(bart=jp_bart, d=d, sent=sent)
        ft_vecs = enc_cal_mean(bart=ft, d=d, sent=sent)
        
        for n, (jp_vec, ft_vec) in enumerate(zip(jp_bart_vecs, ft_vecs)):
            jp_bart_d[n].append(jp_vec)
            ft_d[n].append(ft_vec)
    
    for key in jp_bart_d.keys():
        jp_bart_matrix = np.array(jp_bart_d[key])
        ft_matrix = np.array(ft_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("v2", key, linear_CKA(jp_bart_matrix, ft_matrix))
        
    
def dec_cka_sim(jp_bart, ft, d, sentences):
    jp_bart_d = defaultdict(list)
    ft_d = defaultdict(list)
    
    for n, sent in enumerate(sentences):
        if (n+1) % 200 == 0:
            print(f"done {n+1}")
        sent = sent.strip()
        jp_bart_vecs = dec_cal_mean(bart=jp_bart, d=d, sent=sent)
        ft_vecs = dec_cal_mean(bart=ft, d=d, sent=sent)
        
        for n, (jp_vec, ft_vec) in enumerate(zip(jp_bart_vecs, ft_vecs)):
            jp_bart_d[n].append(jp_vec)
            ft_d[n].append(ft_vec)
    
    for key in jp_bart_d.keys():
        jp_bart_matrix = np.array(jp_bart_d[key])
        ft_matrix = np.array(ft_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("v2", key, linear_CKA(jp_bart_matrix, ft_matrix))

def tokens_to_ids(d, tokens):
    idxs = []
    for token in tokens:
        idx = d.index(token)
        idxs.append(idx)
    return torch.tensor(idxs)

def ids_to_tokens(d, idxs):
    tokens = []
    for idx in idxs:
        token = d[idx]
        tokens.append(token)
    return tokens

def load_bart(path, model_name):
    bart = BARTModel.from_pretrained(path, checkpoint_file=model_name)
    bart.eval()
    return bart

import copy

def generate(tokens, model, beam: int = 1,  **kwargs):
    sample = build_sample(tokens)
    # build dec_generator using current args as well as any kwargs
    gen_args = copy.copy(model.args)
    gen_args.beam = beam
    for k, v in kwargs.items():
        setattr(gen_args, k, v)
    generator = model.task.build_generator([model.model], gen_args)
    encoder_outs, decoder_outs, translations = _generate(
        generator,
        sample,
        prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(model.task.source_dictionary.bos()),
    )
    
    def getarg(name, default):
        return getattr(gen_args, name, getattr(model.args, name, default))

    # Process top predictions
    hypos = [x[0] for x in translations]
    hypos = [v for _, v in sorted(zip(sample['id'].tolist(), hypos))]
    return encoder_outs, decoder_outs, hypos

def generate_unsupervised(tokens, enc_model, dec_model, beam: int = 1,  **kwargs):
    sample = build_sample(tokens)
    # build generator using current args as well as any kwargs
    enc_gen_args = copy.copy(enc_model.args)
    enc_gen_args.beam = beam
    
    dec_gen_args = copy.copy(dec_model.args)
    dec_gen_args.beam = beam
    
    for k, v in kwargs.items():
        setattr(gen_args, k, v)
    enc_generator = enc_model.task.build_generator([enc_model.model], enc_gen_args)
    dec_generator = dec_model.task.build_generator([dec_model.model], dec_gen_args)
    encoder_outs, decoder_outs, translations = _generate_unsupervised(
        enc_generator,
        dec_generator,
        sample,
        prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(dec_model.task.source_dictionary.bos()),
    )
    
    def getarg(name, default):
        return getattr(gen_args, name, getattr(model.args, name, default))

    # Process top predictions
#     print(translations)
    hypos = [x[0] for x in translations]
    hypos = [v for _, v in sorted(zip(sample['id'].tolist(), hypos))]
    return encoder_outs, decoder_outs, hypos

def build_sample(tokens):
    d = {'id': torch.tensor([0]),
         'nsentences': len(tokens),
         'ntokens': len(tokens[0]),
         'net_input': {'src_tokens': tokens,
                      'src_lengths': torch.tensor(len(tokens[0]))},
         'target': None}
    return d

def encdec_cal_mean(model, d, sent):
    token_ids = torch.tensor([sent_to_ids(d, sent).tolist()], device=DEVICE)
    encoder_outs, decoder_outs, outs = generate(token_ids, model)
    
    enc_mean_vecs = enc_mean(encoder_outs[0])
    inner_mean_layers, self_attn_mean_layers = dec_mean(decoder_outs)
    
    return enc_mean_vecs, {"inner_mean_layers": inner_mean_layers, "self_attn_mean_layers": self_attn_mean_layers}

def encdec_cal_mean_v2(model, d, sent):
    token_ids = torch.tensor([sent_to_ids(d, sent).tolist()], device=DEVICE)
    encoder_outs, decoder_outs, outs = generate(token_ids, model)
    
    enc_mean_vecs = enc_mean_v2(encoder_outs[0])
    dec_mean_vecs = dec_mean_v2(decoder_outs)
    
    return enc_mean_vecs, dec_mean_vecs
    
def enc_mean(encoder_layers):
    embedding = encoder_layers.encoder_embedding
    encoder_out = encoder_layers.encoder_out
    all_layers = encoder_layers.encoder_states
    
    mean_layers = [torch.squeeze(embedding).mean(dim=0).tolist()]
    for x in all_layers:
        mean_layers.append(torch.squeeze(x).mean(dim=0).tolist())
    mean_layers.append(torch.squeeze(encoder_out).mean(dim=0).tolist())
        
    return mean_layers

def enc_mean_v2(encoder_layers):
    embedding = encoder_layers.encoder_embedding
    encoder_out = encoder_layers.encoder_out
    all_layers = encoder_layers.encoder_states
    
    mean_layers = [torch.squeeze(embedding).mean(dim=0).tolist()[1:-1]]
    for x in all_layers:
        mean_layers.append(torch.squeeze(x).mean(dim=0).tolist()[1:-1])
    mean_layers.append(torch.squeeze(encoder_out).mean(dim=0).tolist()[1:-1])
        
    return mean_layers

def dec_mean(dec_outs):
    inner_mean_layers = []
    self_attn_mean_layers = []
    
    for layer in range(len(dec_outs[0][1]["inner_states"])):
        tensors = [x[1]["inner_states"][layer][0][0].tolist() for x in dec_outs]
        inner_mean_layers.append(torch.tensor(tensors).mean(dim=0).tolist())
        
    for layer in range(len(dec_outs[0][1]["self_dec_attns"])):
        tensors = [x[1]["self_dec_attns"][layer][0][0].tolist() for x in dec_outs]
        self_attn_mean_layers.append(torch.tensor(tensors).mean(dim=0).tolist())
        
    return inner_mean_layers, self_attn_mean_layers

def dec_mean_v2(dec_outs):
    mean_layers = []
    
    for layer in range(len(dec_outs[0][1]["inner_states"])):
        tensors = [x[1]["inner_states"][layer][0][0].tolist() for x in dec_outs]
        mean_layers.append(torch.tensor(tensors).mean(dim=0).tolist()[1:])
    
    return mean_layers

def encdec_cka_sim(pre, ft, pre_d, ft_d, pre_sentences, ft_sentences):
    pre_enc_d = defaultdict(list)
    pre_dec_d = defaultdict(list)
    pre_dec_self_attn_d = defaultdict(list)
    
    ft_dec_d = defaultdict(list)
    ft_enc_d = defaultdict(list)
    ft_dec_self_attn_d = defaultdict(list)
    
    for n, (pre_sent, ft_sent) in enumerate(zip(pre_sentences, ft_sentences)):
        if (n+1) % 200 == 0:
            print(f"done {n+1}")
        pre_sent = pre_sent.strip()
        ft_sent = ft_sent.strip()
        pre_enc_vecs, pre_dec_vecs = encdec_cal_mean(model=pre, d=pre_d, sent=pre_sent)
        ft_enc_vecs, ft_dec_vecs = encdec_cal_mean(model=ft, d=ft_d, sent=ft_sent)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_enc_vecs, ft_enc_vecs)):
            pre_enc_d[n].append(pre_vec)
            ft_enc_d[n].append(ft_vec)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_dec_vecs["inner_mean_layers"], \
                                                  ft_dec_vecs["inner_mean_layers"])):
            pre_dec_d[n].append(pre_vec)
            ft_dec_d[n].append(ft_vec)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_dec_vecs["self_attn_mean_layers"], \
                                                  ft_dec_vecs["self_attn_mean_layers"])):
            pre_dec_self_attn_d[n].append(pre_vec)
            ft_dec_self_attn_d[n].append(ft_vec)
    
    print("Encoder CKA")
    for key in pre_enc_d.keys():
        pre_matrix = np.array(pre_enc_d[key])
        ft_matrix = np.array(ft_enc_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("Layer", key, linear_CKA(pre_matrix, ft_matrix))
    
    print("\nDecoder CKA")
    for key in pre_dec_d.keys():
        pre_matrix = np.array(pre_dec_d[key])
        ft_matrix = np.array(ft_dec_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("Layer", key, linear_CKA(pre_matrix, ft_matrix))
    
    print("\nDecoder Self Attention CKA")
    for key in pre_dec_self_attn_d.keys():
        pre_matrix = np.array(pre_dec_self_attn_d[key])
        ft_matrix = np.array(ft_dec_self_attn_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("Layer", key, linear_CKA(pre_matrix, ft_matrix))


def encdec_cka_sim_rm_special_tokens(pre, ft, pre_d, ft_d, pre_sentences, ft_sentences, debiased=False):
    pre_enc_d = defaultdict(list)
    pre_dec_d = defaultdict(list)
    
    ft_dec_d = defaultdict(list)
    ft_enc_d = defaultdict(list)
    
    for n, (pre_sent, ft_sent) in enumerate(zip(pre_sentences, ft_sentences)):
        if (n+1) % 200 == 0:
            print(f"done {n+1}")
        pre_sent = pre_sent.strip()
        ft_sent = ft_sent.strip()
        pre_enc_vecs, pre_dec_vecs = encdec_cal_mean_v2(model=pre, d=pre_d, sent=pre_sent)
        ft_enc_vecs, ft_dec_vecs = encdec_cal_mean_v2(model=ft, d=ft_d, sent=ft_sent)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_enc_vecs, ft_enc_vecs)):
            pre_enc_d[n].append(pre_vec)
            ft_enc_d[n].append(ft_vec)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_dec_vecs, ft_dec_vecs)):
            pre_dec_d[n].append(pre_vec)
            ft_dec_d[n].append(ft_vec)
    
    print("Encoder CKA")
    for key in pre_enc_d.keys():
        pre_matrix = np.array(pre_enc_d[key])
        ft_matrix = np.array(ft_enc_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("Layer", key, feature_space_linear_cka(pre_matrix, ft_matrix, debiased=debiased))
    
    print("\nDecoder CKA")
    for key in pre_dec_d.keys():
        pre_matrix = np.array(pre_dec_d[key])
        ft_matrix = np.array(ft_dec_d[key])
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
        print("Layer", key, feature_space_linear_cka(pre_matrix, ft_matrix, debiased=debiased))
    
def cal_layer_sim(ft, pre, embs, padding_masks):
    
    enc_d = defaultdict(list)
    dec_d = defaultdict(list)
    
    for emb, padding_mask in zip(embs, padding_masks):
        
        for idx, (ft_layer, pre_layer) in enumerate(zip(ft.model.encoder.layers, pre.model.encoder.layers)):
            ft_outs = ft_layer(emb, padding_mask)
            pre_outs = pre_layer(emb, padding_mask)
            enc_d[idx].append(cos(ft_outs[0], pre_outs[0]).tolist()[0])

        for idx, (ft_layer, pre_layer) in enumerate(zip(ft.model.decoder.layers, \
                                                        pre.model.decoder.layers)):

            ft_outs = ft_layer(
                        emb,
                        None,
                        None,
                        None,
                        self_attn_mask=None,
                        self_attn_padding_mask=padding_mask,
                        need_attn=bool((idx==ft.model.decoder.num_layers-1)),
                        need_head_weights=bool((idx==ft.model.decoder.num_layers-1))
                      )

            pre_outs = pre_layer(
                        emb,
                        None,
                        None,
                        None,
                        self_attn_mask=None,
                        self_attn_padding_mask=padding_mask,
                        need_attn=bool((idx==pre.model.decoder.num_layers-1)),
                        need_head_weights=bool((idx==pre.model.decoder.num_layers-1))
                      )
            
#             print(cos(ft_outs[0], pre_outs[0]))
            dec_d[idx].append(cos(ft_outs[0][0], pre_outs[0][0]).tolist()[0])
            
    print("Encoder")
    for key, values in enc_d.items():
        print(key, sum(values)/len(values))
    print("\nDecoder")
    for key, values in dec_d.items():
        print(key, sum(values)/len(values))

        
def fusion_model(model1_path, model2_path, save_path, dec_flag=False):
    data1 = load_data(model1_path)
    model1 = data1["model"]
    
    data2 = load_data(model2_path)
    model2 = data2["model"]
    
    for key in model1.keys():
        if key == 'encoder.layer_norm.weight' or key == 'encoder.layer_norm.bias':
            print(key)
            model1[key] = model2[key]
#         if "decoder.layers.0" in key or "decoder.layers.1" in key or \
#            "decoder.layers.2" in key or "decoder.layers.4" in key:        
#             model1[key] = model2[key]
#         elif "decoder.layers.5" in key and dec_flag:
#             print(key)
#             model1[key] = model2[key]
    
#     data1["args"].share_all_embeddings = False
#     print(data1["args"].share_all_embeddings)
    torch.save(data1, save_path)

    
def fusion_enc_dec(enc_path, dec_path, save_path):
    enc = load_data(enc_path)
    enc_model = enc["model"]
    
    dec = load_data(dec_path)
    dec_model = dec["model"]
    
    enc_keys = list(enc_model.keys())
        
    for key in enc_keys:
        if key not in dec_model:
            enc_model.pop(key)
            continue
        elif "encoder." in key:
            continue
            
        enc_model[key] = dec_model[key]
    
    enc["args"].share_all_embeddings = False
#     print(data1["args"].share_all_embeddings)
    torch.save(enc, save_path)
    
    
def initialize_encoder_only(model_path, save_path, seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    data = load_data(model_path)
    model = data["model"]

    for key in model.keys():
        if 'encoder.' in key or 'emb' in key \
            or key == "decoder.output_projection.weight" or ".version" in key:
            continue
        
        matrix = model[key].to('cpu').detach().numpy().copy()
        std = matrix.std()
        mean = matrix.mean()
        shape = matrix.shape
#         print(key, mean, std)
        model[key] = torch.normal(
            mean, std, size=shape, dtype=model[key].dtype, layout=model[key].layout, device=model[key].device
        )
        
    torch.save(data, save_path)

def initialize_decoder_only(model_path, save_path, seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    data = load_data(model_path)
    model = data["model"]

    for key in model.keys():
        if 'decoder.' in key or 'emb' in key \
            or key == "decoder.output_projection.weight" or ".version" in key:
            continue
        
        matrix = model[key].to('cpu').detach().numpy().copy()
        std = matrix.std()
        mean = matrix.mean()
        shape = matrix.shape
#         print(key, mean, std)
        model[key] = torch.normal(
            mean, std, size=shape, dtype=model[key].dtype, layout=model[key].layout, device=model[key].device
        )
        
    torch.save(data, save_path)

def sents_to_encodervecs(model, sents, d, layer=-1):
    vecs = []
    for sent in sents:
        token_ids = torch.tensor([sent_to_ids(d=d, sent=sent).tolist()], device=DEVICE)
        enc_outs = model.model.encoder(token_ids,  None, return_all_hiddens=True)
        enc_vecs = enc_mean(enc_outs)
        vecs.append(enc_vecs[layer])
    return vecs

def precision_topk(preds, topk):
    num = 0
    for n, x in enumerate(preds):
        if n in x[:topk]:
            num += 1
    return num

def sent_similarity_topk(model1, model2, sents1, sents2, d1, d2):
    length = len(sents1)
    
    sent_vecs1 = torch.tensor(sents_to_encodervecs(model=model1, sents=sents1, d=d1))
    sent_vecs2 = torch.tensor(sents_to_encodervecs(model=model2, sents=sents2, d=d2))
    print("done cal vecs")
#     print(sent_vecs2)
    sims = [cos(sent_vecs1[n:n+1], sent_vecs2).to('cpu').detach().numpy().copy() for n in range(len(sent_vecs1))]
    print(sims)
    preds = [np.argsort(x)[::-1] for x in sims]
    
    topk = 1
    num = precision_topk(preds=preds, topk=topk)
    print(topk, num, num/length)
    
    topk = 10
    num = precision_topk(preds=preds, topk=topk)
    print(topk, num, num/length)

# nmt -> layers
# bert -> layer
def fusion_bertenc(bert_path, dec_path, save_path):
    bert = torch.load(bert_path)
    
    dec = load_data(dec_path)
    dec_model = dec["model"]
    
    # bert embs -> dec_model
    key = "bert.embeddings.word_embeddings.weight"
    dec_model["encoder.embed_tokens.weight"] = bert[key]
    dec_model["decoder.embed_tokens.weight"] = bert[key]
    dec_model["decoder.output_projection.weight"] = bert[key]
    print(dec_model["decoder.output_projection.weight"].shape)
    print(dec_model["decoder.embed_tokens.weight"].shape)
    print(dec_model["decoder.output_projection.weight"].shape)
    
    # bert layers -> dec_model
    for layer in range(12):
        bert_keys = [x for x in bert.keys() if f"layer.{layer}" in x]
        dec_keys = [x for x in dec_model.keys() if f"encoder.layers.{layer}" in x]
        for bert_key, dec_key in zip(bert_keys, dec_keys):
            dec_model[dec_key] = bert[bert_key]
    
    dec_keys = [x for x in dec_model.keys() if "decoder.layers" in x]
    for key in dec_keys:
        matrix = dec_model[key].to('cpu').detach().numpy().copy()
        std = matrix.std()
        mean = matrix.mean()
        shape = matrix.shape
#         print(key, mean, std)
        dec_model[key] = torch.normal(
            mean, std, size=shape, 
            dtype=dec_model[key].dtype, 
            layout=dec_model[key].layout, 
            device=dec_model[key].device
        )
    
#     enc["args"].share_all_embeddings = False
#     print(data1["args"].share_all_embeddings)
    torch.save(dec, save_path)

In [8]:
fusion_bertenc(
    bert_path="../pretrained_bart/trim/koja_trimed_bert.bin", 
    dec_path="../ja-ko/bert/base/checkpoints/checkpoint1.pt", 
    save_path="../pretrained_bart/trim/jabert_enc.pt"
)

KeyboardInterrupt: 

In [57]:
bert_trim = torch.load("../pretrained_bart/trim/koja_trimed_bert.bin")
bert = torch.load("../pretrained_bart/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin")

In [61]:
index = 10
bert["bert.embeddings.word_embeddings.weight"][1575].tolist() == bert_trim["bert.embeddings.word_embeddings.weight"][108].tolist()

True

In [57]:
fusion_enc_dec(
    enc_path="../pretrained_bart/trim/jaen_ja_bart_base.pt", 
    dec_path="../pretrained_bart/trim/jaen_en_bart_base.pt", 
    save_path="../pretrained_bart/trim/jaenc_endec_bart.pt"
)

In [31]:
fusion_enc_dec(
    enc_path="../pretrained_bart/trim/jaen_en_bart_base.pt", 
    dec_path="../pretrained_bart/trim/jaen_ja_bart_base.pt", 
    save_path="../pretrained_bart/trim/enenc_jadec_bart.pt"
)

In [17]:
# init_decoder = load_model("../../jpBART/japanese_bart_base_1.1/test/init_decoder.pt")
# init_encoder = load_model("../../jpBART/japanese_bart_base_1.1/test/init_encoder.pt")
ja_model = load_model("../pretrained_bart/trim/jaen_ja_bart_base.pt")

In [18]:
en_model = load_model("../pretrained_bart/trim/jaen_en_bart_base.pt")

In [36]:
type(en_model)

collections.OrderedDict

In [59]:
fusion_jaen = load_model("../pretrained_bart/trim/jaenc_endec_bart.pt")

In [60]:
[x for x in fusion_jaen.keys() if x not in en_model.keys()]

[]

In [14]:
[x for x in base_model.keys() if x not in en_model.keys()]

['encoder.layer_norm.weight',
 'encoder.layer_norm.bias',
 'decoder.layer_norm.weight',
 'decoder.layer_norm.bias',
 'decoder.output_projection.weight']

In [12]:
initialize_encoder_only(
    model_path="../../jpBART/japanese_bart_base_1.1/test/model.pt", \
    save_path="../../jpBART/japanese_bart_base_1.1/test/init_encoder.pt", \
    seed=0
)

In [21]:
initialize_decoder_only(
    model_path="../../jpBART/japanese_bart_base_1.1/test/model.pt", \
    save_path="../../jpBART/japanese_bart_base_1.1/test/init_decoder.pt", \
    seed=0
)

In [93]:
initialize_encoder_only(
    model_path="../../jpBART/en-ja/ja-bart/model.pt", \
    save_path="../../jpBART/en-ja/ja-bart/init_encoder.pt", \
    seed=0
)

In [22]:
initialize_decoder_only(
    model_path="../../jpBART/en-ja/ja-bart/model.pt", \
    save_path="../../jpBART/en-ja/ja-bart/init_decoder.pt", \
    seed=0
)

In [10]:
base_model = load_model("../../jpBART/japanese_bart_base_1.1/test/model.pt")

In [36]:
len(dec_outs[-1][1]['inner_states'])

7

In [None]:
tensors = [x[1]["inner_states"][layer][0][0].tolist() for x in dec_outs]

In [None]:
torch.tensor(tensors).mean(dim=0)

## Load models

In [10]:
# enBART and fine-tuned by jaen or enja
enbart_ft_enja_path = "../en-ja/enbart/checkpoints"
enbart_ft_enja_name = "checkpoint_best.pt"
enbart_ft_enja = load_bart(path=enbart_ft_enja_path, model_name=enbart_ft_enja_name)
enbart_ft_enja = enbart_ft_enja.to(DEVICE)

enbart_ft_jaen_path = "../ja-en/enbart/v2/checkpoints"
enbart_ft_jaen_name = "checkpoint_last.pt"
enbart_ft_jaen = load_bart(path=enbart_ft_jaen_path, model_name=enbart_ft_jaen_name)
enbart_ft_jaen = enbart_ft_jaen.to(DEVICE)

enbart_enja_path = "../pretrained_bart/trim/enbart_jaen"
enbart_enja_name = "jaen_en_bart_base.pt"
enbart_enja = load_bart(path=enbart_enja_path, model_name=enbart_enja_name)
enbart_enja = enbart_enja.to(DEVICE)

In [11]:
# enBART and fine-tuned by fren or enfr
enbart_ft_enfr_path = "../en-fr/bart/1M/checkpoints"
enbart_ft_enfr_name = "checkpoint_best.pt"
enbart_ft_enfr = load_bart(path=enbart_ft_enfr_path, model_name=enbart_ft_enfr_name)
enbart_ft_enfr = enbart_ft_enfr.to(DEVICE)

enbart_ft_fren_path = "../fr-en/bart/1M/v2/checkpoints"
enbart_ft_fren_name = "checkpoint_best.pt"
enbart_ft_fren = load_bart(path=enbart_ft_fren_path, model_name=enbart_ft_fren_name)
enbart_ft_fren = enbart_ft_fren.to(DEVICE)

enbart_enfr_path = "../pretrained_bart/trim"
enbart_enfr_name = "enfr_enbart_random_sampling.pt"
enbart_enfr = load_bart(path=enbart_enfr_path, model_name=enbart_enfr_name)
enbart_enfr = enbart_enfr.to(DEVICE)

In [12]:
# load dict of ft enBART
enja_enbart_d = load_dict("../pretrained_bart/trim/enbart_jaen/dict.txt")
enfr_enbart_d = load_dict("../pretrained_bart/trim/dict.txt")

In [10]:
ft_jako_path = "../ja-ko/bart/checkpoints"
ft_jako_name = "checkpoint_best.pt"
ft_jako= load_bart(path=ft_jako_path, model_name=ft_jako_name).to(DEVICE)

ft_koja_path = "../ko-ja/bart/checkpoints"
ft_koja_name = "checkpoint_best.pt"
ft_koja = load_bart(path=ft_koja_path, model_name=ft_koja_name).to(DEVICE)

ft_enja_path = "../en-ja/bart/v2/checkpoints"
ft_enja_name = "checkpoint_best.pt"
ft_enja = load_bart(path=ft_enja_path, model_name=ft_enja_name).to(DEVICE)

ft_jaen_path = "../ja-en/bart/checkpoints"
ft_jaen_name = "checkpoint_best.pt"
ft_jaen = load_bart(path=ft_jaen_path, model_name=ft_jaen_name).to(DEVICE)

jabart_jako_path = "../pretrained_bart/trim/jabart_jako"
jabart_jako_name = "ja_bart_base.pt"
jabart_jako = load_bart(path=jabart_jako_path, model_name=jabart_jako_name).to(DEVICE)

# ft_jako_fa_path = "../ja-ko/bart/fastalign/checkpoints"
# ft_jako_fa_name = "checkpoint_best.pt"
# ft_jako_fa = load_bart(path=ft_jako_fa_path, model_name=ft_jako_fa_name).to(DEVICE)

# ft_koja_fa_path = "../ko-ja/bart/fastalign/checkpoints"
# ft_koja_fa_name = "checkpoint_best.pt"
# ft_koja_fa = load_bart(path=ft_koja_fa_path, model_name=ft_koja_fa_name).to(DEVICE)

# jabart_jako_fa_path = "../pretrained_bart/muse/koja"
# jabart_jako_fa_name = "fastalign_bart.pt"
# jabart_jako_fa = load_bart(path=jabart_jako_fa_path, model_name=jabart_jako_fa_name).to(DEVICE)

In [11]:
# load dict of ft jaBART
koja_d = load_dict("../pretrained_bart/trim/jabart_jako/dict.txt")
enja_d = load_dict("../en-ja/bart/v2/checkpoints/dict.en.txt")

## Load sentences

In [13]:
# enja sentences
# file_path = "../data/enja/enBART/dev.en"
# with open(file_path, "r") as f:
#     en_enbart_sentences = f.readlines()
    
# file_path = "../data/enja/enBART/dev.ja"
# with open(file_path, "r") as f:d
#     ja_enbart_sentences = f.readlines()

# file_path = "../data/enja_v2/dev.en"
# with open(file_path, "r") as f:
#     jaen_sentences_en = f.readlines()
    
# file_path = "../data/enja_v2/dev.ja"
# with open(file_path, "r") as f:
#     jaen_sentences_ja = f.readlines()
    
# enbart sentences
file_path = "../data/enja_2/enBART/dev.ja"
with open(file_path, "r") as f:
    enbart_jaen_sentences_ja = f.readlines()

file_path = "../data/enja_2/enBART/dev.en"
with open(file_path, "r") as f:
    enbart_jaen_sentences_en = f.readlines()
    
file_path = "../data/enfr/random/dev.fr"
with open(file_path, "r") as f:
    enbart_fren_sentences_fr = f.readlines()

file_path = "../data/enfr/random/dev.en"
with open(file_path, "r") as f:
    enbart_fren_sentences_en = f.readlines()

In [12]:
file_path = "../../dev_head10.sp.ko"
with open(file_path, "r") as f:
    ko_koen_setntences = f.readlines()

In [13]:
# koja sentencef
file_path = "../data/dev.ja"
with open(file_path , "r") as f:
    jako_sentences_ja = f.readlines()

file_path = "../data/dev.ko"
with open(file_path , "r") as f:
    jako_sentences_ko = f.readlines()

In [25]:
koenja_bart = load_bart("../koen-ja/trained", "bart.pt").to(DEVICE)
koenja_base = load_bart("../koen-ja/trained", "base.pt").to(DEVICE)
# koenja_d = load_dict("../koen-ja/trained/dict.ja.txt")

In [22]:
type(koenja_bart)

fairseq.models.bart.hub_interface.BARTHubInterface

## Sentence similarity topk

### Korean and Japanese 

In [27]:
sent_similarity_topk(
    model1=koenja_base, model2=koenja_base, \
    sents1=ko_koen_setntences, sents2=en_jabart_sentences[:10] , \
    d1=koenja_d, d2=koenja_d
)

done cal vecs
[array([0.94447064, 0.8985309 , 0.77088374, 0.876096  , 0.91257375,
       0.87748945, 0.8765159 , 0.8220344 , 0.9025764 , 0.8711979 ],
      dtype=float32), array([0.90635896, 0.9177929 , 0.7812545 , 0.8878749 , 0.903947  ,
       0.8812882 , 0.8821026 , 0.8397099 , 0.9103645 , 0.88529974],
      dtype=float32), array([0.8803854 , 0.8578151 , 0.88624156, 0.90691185, 0.89150316,
       0.8919578 , 0.8877258 , 0.8240379 , 0.87416893, 0.861958  ],
      dtype=float32), array([0.9092634 , 0.8979603 , 0.8384323 , 0.9440382 , 0.91345626,
       0.91033113, 0.88732105, 0.838293  , 0.9128608 , 0.87709403],
      dtype=float32), array([0.9058174 , 0.8721152 , 0.75757927, 0.87674737, 0.94528437,
       0.8844552 , 0.88667333, 0.8213956 , 0.8935128 , 0.87877375],
      dtype=float32), array([0.89609975, 0.8907674 , 0.8317682 , 0.90842515, 0.91415983,
       0.94131553, 0.88462543, 0.83679104, 0.91248083, 0.8736946 ],
      dtype=float32), array([0.8919687 , 0.86697805, 0.80068225, 

In [26]:
sent_similarity_topk(
    model1=koenja_bart, model2=koenja_bart, \
    sents1=ko_koen_setntences, sents2=en_jabart_sentences[:10] , \
    d1=koenja_d, d2=koenja_d
)

done cal vecs
[array([0.62733096, 0.45849878, 0.37524152, 0.3588468 , 0.4197688 ,
       0.4074183 , 0.3803748 , 0.29117808, 0.36261916, 0.28462642],
      dtype=float32), array([0.40138695, 0.49521008, 0.3316309 , 0.34301957, 0.3484587 ,
       0.39920777, 0.31767103, 0.29428077, 0.3458485 , 0.30130962],
      dtype=float32), array([0.32192665, 0.35079235, 0.79198956, 0.5534144 , 0.53239304,
       0.5197051 , 0.4423126 , 0.3429982 , 0.33676696, 0.273505  ],
      dtype=float32), array([0.4476252 , 0.4222133 , 0.5505032 , 0.6753596 , 0.48915303,
       0.51628774, 0.41857058, 0.34983388, 0.44404706, 0.3060086 ],
      dtype=float32), array([0.41691366, 0.30735722, 0.37516597, 0.3803054 , 0.6869408 ,
       0.41474575, 0.36312482, 0.3180371 , 0.3709499 , 0.30966386],
      dtype=float32), array([0.34724092, 0.41902518, 0.5200223 , 0.5049714 , 0.50964004,
       0.71659416, 0.38403064, 0.31761035, 0.39487284, 0.29666388],
      dtype=float32), array([0.35167575, 0.35623065, 0.36319512, 

In [101]:
# koja (korean), jako (japanse)
sent_similarity_topk(
    model1=ft_koja, model2=ft_jako, \
    sents1=ko_koen_setntences, sents2=ja_enbart_sentences[:10] , \
    d1=koja_d, d2=koja_d
)

done cal vecs
[array([0.62593573, 0.44740704, 0.35251287, 0.41480747, 0.4190161 ,
       0.42780688, 0.3480967 , 0.43163708, 0.412306  , 0.349921  ],
      dtype=float32), array([0.41523075, 0.59781337, 0.31864694, 0.39931047, 0.39677545,
       0.42269796, 0.40781367, 0.40736115, 0.41286013, 0.39752078],
      dtype=float32), array([0.3015234 , 0.3140416 , 0.6738399 , 0.49981257, 0.37643817,
       0.4894787 , 0.34595373, 0.3578904 , 0.2984012 , 0.3152542 ],
      dtype=float32), array([0.4262741 , 0.44368464, 0.5446184 , 0.71414346, 0.45696956,
       0.56215847, 0.43032426, 0.41910085, 0.4427546 , 0.3417266 ],
      dtype=float32), array([0.39567402, 0.35154027, 0.3489403 , 0.41428366, 0.66027933,
       0.3902354 , 0.3205379 , 0.3841722 , 0.36703518, 0.37736973],
      dtype=float32), array([0.34219906, 0.4184179 , 0.5136139 , 0.5771184 , 0.43466803,
       0.6834449 , 0.38274866, 0.4112187 , 0.4157455 , 0.40587112],
      dtype=float32), array([0.38760516, 0.3872572 , 0.3614854 , 

In [91]:
# koja (korean), jako (japanse)
sent_similarity_topk(model1=ft_koja, model2=ft_jako, \
                     sents1=ko_jabart_sentences, sents2=ja_jabart_sentences , \
                     d1=koja_d, d2=koja_d)

done cal vecs
1 1741 0.8705
10 1943 0.9715


In [91]:
# koja (korean), jaBART (japanse)
sent_similarity_topk(model1=ft_koja, model2=bart_jako, \
                     sents1=jako_sentences, sents2=koja_sentences, \
                     d1=koja_d, d2=koja_d)

done cal vecs
1 1050 0.525
10 1645 0.8225


In [92]:
# koja (korean), jaen(japanse)
sent_similarity_topk(model1=ft_koja, model2=ft_jaen, \
                     sents1=jako_sentences, sents2=koja_sentences, \
                     d1=koja_d, d2=enja_d)

done cal vecs
1 1630 0.815
10 1926 0.963


### English and Japanese

In [93]:
# enja (english), jaen (japanse)
sent_similarity_topk(model1=ft_enja, model2=ft_jaen, \
                     sents1=jaen_sentences, sents2=enja_sentences, \
                     d1=enja_d, d2=enja_d)

done cal vecs
1 1478 0.739
10 1777 0.8885


In [94]:
# enja (english), jaBART (japanse)
sent_similarity_topk(model1=ft_enja, model2=bart_enja, \
                     sents1=jaen_sentences, sents2=enja_sentences, \
                     d1=enja_d, d2=enja_d)

done cal vecs
1 1137 0.5685
10 1670 0.835


In [95]:
# enja (english), jako (japanse)
sent_similarity_topk(model1=ft_enja, model2=ft_jako, \
                     sents1=jaen_sentences, sents2=enja_sentences, \
                     d1=enja_d, d2=koja_d)

done cal vecs
1 1334 0.667
10 1738 0.869


In [101]:
fusion_model(
    model1_path=f"{ft_jako_path}/{ft_jako_name}", 
    model2_path=f"{ft_jaen_path}/{ft_jaen_name}", 
    save_path="test/fusion_test"
)

In [75]:
fusion_model(
    model1_path=f"{ft_jako_path}/{ft_jako_name}", 
    model2_path=f"{ft_jaen_path}/{ft_jaen_name}", 
    save_path="fusion_jako_jaen"
)

False


In [85]:
fusion_model(
    model1_path=f"{ft_jako_path}/{ft_jako_name}", 
    model2_path=f"{ft_jaen_path}/{ft_jaen_name}", 
    save_path="fusion_jako_jaen_dec",
    dec_flag=True
)

decoder.layers.5.self_attn.k_proj.weight
decoder.layers.5.self_attn.k_proj.bias
decoder.layers.5.self_attn.v_proj.weight
decoder.layers.5.self_attn.v_proj.bias
decoder.layers.5.self_attn.q_proj.weight
decoder.layers.5.self_attn.q_proj.bias
decoder.layers.5.self_attn.out_proj.weight
decoder.layers.5.self_attn.out_proj.bias
decoder.layers.5.self_attn_layer_norm.weight
decoder.layers.5.self_attn_layer_norm.bias
decoder.layers.5.encoder_attn.k_proj.weight
decoder.layers.5.encoder_attn.k_proj.bias
decoder.layers.5.encoder_attn.v_proj.weight
decoder.layers.5.encoder_attn.v_proj.bias
decoder.layers.5.encoder_attn.q_proj.weight
decoder.layers.5.encoder_attn.q_proj.bias
decoder.layers.5.encoder_attn.out_proj.weight
decoder.layers.5.encoder_attn.out_proj.bias
decoder.layers.5.encoder_attn_layer_norm.weight
decoder.layers.5.encoder_attn_layer_norm.bias
decoder.layers.5.fc1.weight
decoder.layers.5.fc1.bias
decoder.layers.5.fc2.weight
decoder.layers.5.fc2.bias
decoder.layers.5.final_layer_norm.weig

In [14]:
bart_ende_path = "../../enBART/bart.base/trim"
bart_ende_name = "new_model.pt"
bart_ende = load_bart(path=bart_ende_path, model_name=bart_ende_name)
bart_ende = bart_ende.to(DEVICE)

In [15]:
ft_deen_path = "../../enBART/ft_model"
ft_deen_name = "checkpoint_best.pt"
ft_deen = load_bart(path=ft_deen_path, model_name=ft_deen_name)
ft_deen = ft_deen.to(DEVICE)

In [16]:
ft_jaen_path = "../ja-en/enbart/checkpoints"
ft_jaen_name = "checkpoint_best.pt"
ft_jaen_enbart = load_bart(path=ft_jaen_path, model_name=ft_jaen_name)
ft_jaen_enbart = ft_jaen_enbart.to(DEVICE)

In [17]:
ende_d = load_dict(ft_deen_path + "/" + "dict.en.txt")

## CKA

In [37]:
sent = enbart_fren_sentences_en[0]
token_ids = sent_to_ids(d=enfr_enbart_d, sent=sent)
print(sent)
print(token_ids)
generated = generate(token_ids.unsqueeze(0).to(DEVICE), enbart_ft_enfr)
" ".join(ids_to_tokens(enfr_enbart_d, generated[-1][0]['tokens']))

5122 14305 286 2489 318 262 717 286 428 39210 13

tensor([ 1429,  3678,     9,   224,    23,     6,   246,     9,    35, 13064,
            5])


'<s> ▁de ▁droits 11 ▁c 6 395 ▁le ▁premier ▁de ▁ce ▁millénaire 13 </s>'

In [36]:
sent = enbart_jaen_sentences_ja[0]
token_ids = sent_to_ids(d=enja_enbart_d, sent=sent)
print(sent)
print(token_ids)
generated = generate(token_ids.unsqueeze(0).to(DEVICE), enbart_ft_jaen)
" ".join(ids_to_tokens(enja_enbart_d, generated[-1][0]['tokens']))

▁しかし ▁、 ▁変 位 ▁ベクトル ▁計測 ▁を ▁行う ▁場合 ▁は ▁、 ▁その ▁支配 ▁的な ▁変 位 ▁方向 ▁を ▁検出 ▁する ▁必要 ▁が ▁ない ▁ので ▁便利 である ▁。

tensor([ 1653,     4,  2146,   925,  2663,  4095,    10,   594,   118,     8,
            4,   117, 12279,   332,  2146,   925,   222,    10,   218,    15,
          341,    17,    93,  1448, 12837,   529,     7])


'<s> 8875 341 286 262 29358 15879 318 11282 11 2158 11 1201 340 318 407 3306 284 4886 262 11410 29358 4571 15370 13 </s>'

In [19]:
# enBART, jaen
encdec_cka_sim(
    pre=enbart_enja, ft=enbart_ft_jaen, \
    pre_d=enja_enbart_d, ft_d=enja_enbart_d, \
    pre_sentences=enbart_jaen_sentences_en[:1000], ft_sentences=enbart_jaen_sentences_ja[:1000]
)

done 200
done 400
done 600
done 800
done 1000
Encoder CKA
Layer 0 0.32811857454424265
Layer 1 0.5546532730269308
Layer 2 0.5995680525914352
Layer 3 0.6204308990722559
Layer 4 0.6279079745304093
Layer 5 0.5875804892141104
Layer 6 0.6083146674879423
Layer 7 0.6083146674879423

Decoder CKA
Layer 0 0.09744597521954425
Layer 1 0.15249698280794938
Layer 2 0.210876683315914
Layer 3 0.23154576349387865
Layer 4 0.22852527316730292
Layer 5 0.2448341766431118
Layer 6 0.3007433409915509

Decoder Self Attention CKA
Layer 0 0.0939722852607291
Layer 1 0.15361276708531657
Layer 2 0.2028390013781639
Layer 3 0.22414851803754945
Layer 4 0.25860800875306866
Layer 5 0.23398130362536687


In [25]:
# enBART, enja
encdec_cka_sim(
    pre=enbart_enja, ft=enbart_ft_enja, \
    pre_d=enja_enbart_d, ft_d=enja_enbart_d, \
    pre_sentences=enbart_jaen_sentences_en[:1000], ft_sentences=enbart_jaen_sentences_en[:1000]
)

done 200
done 400
done 600
done 800
done 1000
Encoder CKA
Layer 0 0.9710718630659055
Layer 1 0.8120210146657719
Layer 2 0.8006389337341203
Layer 3 0.7654746487379535
Layer 4 0.7380350953301268
Layer 5 0.7333817872153834
Layer 6 0.7386213187787545
Layer 7 0.7386213187787545

Decoder CKA
Layer 0 0.06769563267976879
Layer 1 0.12355517600508875
Layer 2 0.1915150468153713
Layer 3 0.23190734801412766
Layer 4 0.22419532676016332
Layer 5 0.24174943072293711
Layer 6 0.4389424708551309

Decoder Self Attention CKA
Layer 0 0.08121777781972517
Layer 1 0.14379190379947235
Layer 2 0.2034657652326823
Layer 3 0.22121958176162368
Layer 4 0.2543349435979615
Layer 5 0.23425001394704456


In [14]:
def encdec_cka_sim_test(
    pre, ft, 
    pre_d, ft_d, 
    pre_sentences, ft_sentences
):
    pre_enc_d = defaultdict(list)
    pre_dec_d = defaultdict(list)
    pre_dec_self_attn_d = defaultdict(list)
    
    ft_dec_d = defaultdict(list)
    ft_enc_d = defaultdict(list)
    ft_dec_self_attn_d = defaultdict(list)
    
    for n, (pre_sent, ft_sent) in enumerate(zip(pre_sentences, ft_sentences)):
        if (n+1) % 200 == 0:
            print(f"done {n+1}")
        pre_sent = pre_sent.strip()
        ft_sent = ft_sent.strip()
#         print(pre_sent)
        pre_enc_vecs, pre_dec_vecs = encdec_cal_mean(model=pre, d=pre_d, sent=pre_sent)
#         print(ft_sent)
        ft_enc_vecs, ft_dec_vecs = encdec_cal_mean(model=ft, d=ft_d, sent=ft_sent)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_enc_vecs, ft_enc_vecs)):
            pre_enc_d[n].append(pre_vec)
            ft_enc_d[n].append(ft_vec)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_dec_vecs["inner_mean_layers"], \
                                                  ft_dec_vecs["inner_mean_layers"])):
            pre_dec_d[n].append(pre_vec)
            ft_dec_d[n].append(ft_vec)
        
        for n, (pre_vec, ft_vec) in enumerate(zip(pre_dec_vecs["self_attn_mean_layers"], \
                                                  ft_dec_vecs["self_attn_mean_layers"])):
            pre_dec_self_attn_d[n].append(pre_vec)
            ft_dec_self_attn_d[n].append(ft_vec)
    
    return pre_enc_d, ft_enc_d, pre_dec_d, ft_dec_d, pre_dec_self_attn_d, ft_dec_self_attn_d 
    
#     print("Encoder CKA")
#     for key in pre_enc_d.keys():
#         pre_matrix = np.array(pre_enc_d[key])
#         ft_matrix = np.array(ft_enc_d[key])
# #         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
#         print("Layer", key, linear_CKA(pre_matrix, ft_matrix))
    
#     print("\nDecoder CKA")
#     for key in pre_dec_d.keys():
#         pre_matrix = np.array(pre_dec_d[key])
#         ft_matrix = np.array(ft_dec_d[key])
# #         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
#         print("Layer", key, linear_CKA(pre_matrix, ft_matrix))
    
#     print("\nDecoder Self Attention CKA")
#     for key in pre_dec_self_attn_d.keys():
#         pre_matrix = np.array(pre_dec_self_attn_d[key])
#         ft_matrix = np.array(ft_dec_self_attn_d[key])
# #         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
#         print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

In [15]:
# enBART, enfr
pre_enc_d, ft_enc_d, pre_dec_d, ft_dec_d, pre_dec_self_attn_d, ft_dec_self_attn_d  = encdec_cka_sim_test(
    pre=enbart_enfr, ft=enbart_ft_enfr, \
    pre_d=enfr_enbart_d, ft_d=enfr_enbart_d, \
    pre_sentences=enbart_fren_sentences_en[:1000], ft_sentences=enbart_fren_sentences_en[:1000]
)

import copy
pre_enc_d_ = copy.deepcopy(pre_enc_d)
ft_enc_d_ = copy.deepcopy(ft_enc_d)
for key in pre_enc_d_.keys():
    pre_enc_d_[key].pop(301)
    ft_enc_d_[key].pop(301)
    
# encoder
print('encoder')
for key in pre_enc_d_.keys():
    pre_matrix = np.array(pre_enc_d_[key])
    ft_matrix = np.array(ft_enc_d_[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

    
# decoder
print('decoder')
for key in pre_dec_d.keys():
    pre_matrix = np.array(pre_dec_d[key])
    ft_matrix = np.array(ft_dec_d[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

# decoder
print('self attn')
for key in pre_dec_self_attn_d.keys():
    pre_matrix = np.array(pre_dec_self_attn_d[key])
    ft_matrix = np.array(ft_dec_self_attn_d[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

done 200
done 400
done 600
done 800
done 1000
encoder
Layer 0 0.9729578273360813
Layer 1 0.863921782443766
Layer 2 0.83642350711511
Layer 3 0.8197397063118146
Layer 4 0.7979322622420596
Layer 5 0.7767335145898732
Layer 6 0.7917862164124667
Layer 7 0.7917862164124667
decoder
Layer 0 0.13771677898439774
Layer 1 0.2756676511382894
Layer 2 0.479574223248925
Layer 3 0.5311533728022519
Layer 4 0.4342821369419174
Layer 5 0.37446123791379327
Layer 6 0.4570160011679051
self attn
Layer 0 0.1658540749843371
Layer 1 0.3927683910431462
Layer 2 0.5151794108942672
Layer 3 0.5453780886867194
Layer 4 0.5104035620170355
Layer 5 0.39448659546226084


In [23]:
# enBART, fren
pre_enc_d, ft_enc_d, pre_dec_d, ft_dec_d, pre_dec_self_attn_d, ft_dec_self_attn_d  = encdec_cka_sim_test(
    pre=enbart_enfr, ft=enbart_ft_fren, \
    pre_d=enfr_enbart_d, ft_d=enfr_enbart_d, \
    pre_sentences=enbart_fren_sentences_en[:1000], ft_sentences=enbart_fren_sentences_fr[:1000]
)

done 200
done 400
done 600
done 800
done 1000


In [24]:
import copy
pre_enc_d_ = copy.deepcopy(pre_enc_d)
ft_enc_d_ = copy.deepcopy(ft_enc_d)
for key in pre_enc_d_.keys():
    pre_enc_d_[key].pop(301)
    ft_enc_d_[key].pop(301)
    
# encoder
print('encoder')
for key in pre_enc_d_.keys():
    pre_matrix = np.array(pre_enc_d_[key])
    ft_matrix = np.array(ft_enc_d_[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

    
# decoder
print('decoder')
for key in pre_dec_d.keys():
    pre_matrix = np.array(pre_dec_d[key])
    ft_matrix = np.array(ft_dec_d[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

# decoder
print('self attn')
for key in pre_dec_self_attn_d.keys():
    pre_matrix = np.array(pre_dec_self_attn_d[key])
    ft_matrix = np.array(ft_dec_self_attn_d[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

encoder
Layer 0 0.1752605944275357
Layer 1 0.6512307162908527
Layer 2 0.6710861384363259
Layer 3 0.6888394556049736
Layer 4 0.6883013824515215
Layer 5 0.6883653756778685
Layer 6 0.702137791005405
Layer 7 0.702137791005405
decoder
Layer 0 0.1904847206331985
Layer 1 0.3102052891602589
Layer 2 0.5059126944672591
Layer 3 0.5266502616683636
Layer 4 0.4802726042683419
Layer 5 0.38671790971708636
Layer 6 0.47487189395441165
self attn
Layer 0 0.19369339189243265
Layer 1 0.41890349497784435
Layer 2 0.5157724904991667
Layer 3 0.5434187574018567
Layer 4 0.501903158446825
Layer 5 0.4457863845303657


In [31]:
# encoder
for key in list(pre_enc_d_.keys()):
    pre_matrix = np.array(pre_enc_d_[key][:100])
    ft_matrix = np.array(ft_enc_d_[key][:100])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

Layer 0 0.3686674118939949
Layer 1 0.764543460352847
Layer 2 0.7732998699589144
Layer 3 0.7783773902712067
Layer 4 0.7768667884210361
Layer 5 0.7649257051404549
Layer 6 0.75241364353149
Layer 7 0.75241364353149


In [29]:
# encoder
for key in list(pre_enc_d_.keys()):
    pre_matrix = np.array(pre_enc_d_[key][:1000])
    ft_matrix = np.array(ft_enc_d_[key][:1000])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

Layer 0 0.1751167957811808
Layer 1 0.6485063065422028
Layer 2 0.6682326277343577
Layer 3 0.6869086617683943
Layer 4 0.6869959584973313
Layer 5 0.6876304094033615
Layer 6 0.701408444348908
Layer 7 0.701408444348908


In [23]:
# encoder
for key in pre_enc_d_.keys():
    pre_matrix = np.array(pre_enc_d_[key])
    ft_matrix = np.array(ft_enc_d_[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

Layer 0 0.1961328198444898
Layer 1 0.6416696873840658
Layer 2 0.6561150719027569
Layer 3 0.6778685181706523
Layer 4 0.6813434228322837
Layer 5 0.6833915482410283
Layer 6 0.6987504653861911
Layer 7 0.6987504653861911


In [24]:
# decoder
for key in pre_dec_d.keys():
    pre_matrix = np.array(pre_dec_d[key])
    ft_matrix = np.array(ft_dec_d[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

Layer 0 0.17552321581938815
Layer 1 0.28511377663220694
Layer 2 0.47311819037920333
Layer 3 0.5033558407798207
Layer 4 0.4611915764592244
Layer 5 0.3678856021885661
Layer 6 0.465555820709266


In [25]:
# decoder
for key in pre_dec_self_attn_d.keys():
    pre_matrix = np.array(pre_dec_self_attn_d[key])
    ft_matrix = np.array(ft_dec_self_attn_d[key])
#     print(pre_matrix.shape, ft_matrix.shape)
#         print("v1", key, cal_cka_sim(jp_bart_matrix, ft_matrix))
    print("Layer", key, linear_CKA(pre_matrix, ft_matrix))

Layer 0 0.1837217050701935
Layer 1 0.3856648766373253
Layer 2 0.4846951074641052
Layer 3 0.5186218163180184
Layer 4 0.48464657473181366
Layer 5 0.4293935656848947


In [18]:
# enBART, enfr
encdec_cka_sim(
    pre=enbart_enfr, ft=enbart_ft_enfr, \
    pre_d=enfr_enbart_d, ft_d=enfr_enbart_d, \
    pre_sentences=enbart_fren_sentences_en[:100], ft_sentences=enbart_fren_sentences_en[:100]
)

Encoder CKA
Layer 0 0.9780712565411755
Layer 1 0.8937279811922552
Layer 2 0.8658691974817965
Layer 3 0.8453532725587504
Layer 4 0.8418310600469964
Layer 5 0.8242200640133998
Layer 6 0.8231984675617963
Layer 7 0.8231984675617963

Decoder CKA
Layer 0 0.25038420598361627
Layer 1 0.343640572874969
Layer 2 0.5370430129241208
Layer 3 0.5962060892498431
Layer 4 0.5279244853339019
Layer 5 0.45512418680039474
Layer 6 0.6828338541700857

Decoder Self Attention CKA
Layer 0 0.2546367824372348
Layer 1 0.4726495006136006
Layer 2 0.5777921159012347
Layer 3 0.6183637663639528
Layer 4 0.6065769455032194
Layer 5 0.5498576622801526


In [156]:
# jaBART, jako
encdec_cka_sim(
    pre=jabart_jako, ft=ft_jako, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=jako_sentences_ja, ft_sentences=jako_sentences_ja
)

done 200
done 400
done 600
done 800
done 1000
done 1200
done 1400
done 1600
done 1800
done 2000
Encoder CKA
Layer 0 0.9733945818732439
Layer 1 0.8687681675939442
Layer 2 0.8434158785275693
Layer 3 0.9370607349495887
Layer 4 0.9614628335582258
Layer 5 0.9700674107880242
Layer 6 0.9635520520292242
Layer 7 0.7176681915201348

Decoder CKA
Layer 0 0.16763322071215492
Layer 1 0.5132706661887682
Layer 2 0.5282737227516087
Layer 3 0.5554090535647359
Layer 4 0.519369563986752
Layer 5 0.43157912157737083
Layer 6 0.02994465836785132

Decoder Self Attention CKA
Layer 0 0.29671080036439834
Layer 1 0.5727299236081045
Layer 2 0.5479109091949836
Layer 3 0.49175863056010855
Layer 4 0.415207337167861
Layer 5 0.222258803687259


In [150]:
# jaBART, jako
encdec_cka_sim(
    pre=jabart_jako, ft=ft_jako, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=jako_sentences_ja[:100], ft_sentences=jako_sentences_ja[:100]
)

Encoder CKA
Layer 0 0.9759973298327697
Layer 1 0.9256172266120095
Layer 2 0.8741730657158969
Layer 3 0.9465301895243656
Layer 4 0.969123673766228
Layer 5 0.9790643359818927
Layer 6 0.9741754062303096
Layer 7 0.8699713363496407

Decoder CKA
Layer 0 0.26566675837396675
Layer 1 0.6244883834420452
Layer 2 0.6167888731680147
Layer 3 0.6883192185988544
Layer 4 0.6888598779958519
Layer 5 0.6522734899382354
Layer 6 0.07072162457780592

Decoder Self Attention CKA
Layer 0 0.37322840753426545
Layer 1 0.6726726791306216
Layer 2 0.659283002403375
Layer 3 0.6400832383024861
Layer 4 0.6084594033549899
Layer 5 0.3837667729528462


In [155]:
# jaBART, jako
encdec_cka_sim(
    pre=jabart_jako, ft=ft_jako, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=jaen_sentences_ja[:100], ft_sentences=jaen_sentences_ja[:100]
)

Encoder CKA
Layer 0 0.9729115896163111
Layer 1 0.9135987907704255
Layer 2 0.821561152461635
Layer 3 0.8680825866974448
Layer 4 0.9170677291728615
Layer 5 0.9297949430319115
Layer 6 0.9092694707275283
Layer 7 0.8168259950394288

Decoder CKA
Layer 0 0.4915655752079639
Layer 1 0.6809627218543579
Layer 2 0.6104667347317372
Layer 3 0.6451872520219702
Layer 4 0.5917880699782185
Layer 5 0.5479433708283096
Layer 6 0.09042188181138364

Decoder Self Attention CKA
Layer 0 0.5492166624244063
Layer 1 0.7044201948904427
Layer 2 0.6110072979431107
Layer 3 0.5474638884843045
Layer 4 0.49048689881554514
Layer 5 0.3571971711707886


In [154]:
# jaBART, jaen
encdec_cka_sim(
    pre=jabart_jako, ft=ft_jaen, \
    pre_d=koja_d, ft_d=enja_d, \
    pre_sentences=jaen_sentences_ja[:100], ft_sentences=jaen_sentences_ja[:100]
)

Encoder CKA
Layer 0 0.970523551005553
Layer 1 0.9151731549531593
Layer 2 0.8277985826173108
Layer 3 0.9280565330505587
Layer 4 0.960762716719473
Layer 5 0.9697891275931673
Layer 6 0.9642266387392431
Layer 7 0.8444771900088857

Decoder CKA
Layer 0 0.4263150415348768
Layer 1 0.571452449384793
Layer 2 0.6663581064078352
Layer 3 0.7303663734183071
Layer 4 0.6982537767593614
Layer 5 0.6197065847507601
Layer 6 0.08705909603272607

Decoder Self Attention CKA
Layer 0 0.48895318702558804
Layer 1 0.6127027821668646
Layer 2 0.704060883088779
Layer 3 0.6916327937374782
Layer 4 0.6116822922622422
Layer 5 0.4720994915017967


In [157]:
# jaBART, jaen
encdec_cka_sim(
    pre=jabart_jako, ft=ft_jaen, \
    pre_d=koja_d, ft_d=enja_d, \
    pre_sentences=jako_sentences_ja, ft_sentences=jako_sentences_ja
)

done 200
done 400
done 600
done 800
done 1000
done 1200
done 1400
done 1600
done 1800
done 2000
Encoder CKA
Layer 0 0.9698217565326317
Layer 1 0.8682271338603668
Layer 2 0.830867335717296
Layer 3 0.9593129674841185
Layer 4 0.9793769867646098
Layer 5 0.9828426449643142
Layer 6 0.9808061057361167
Layer 7 0.7834330659491489

Decoder CKA
Layer 0 0.16921201369717648
Layer 1 0.437809268078041
Layer 2 0.5526726631509171
Layer 3 0.6140158325860415
Layer 4 0.6323702549712223
Layer 5 0.525754290394626
Layer 6 0.049906383588926094

Decoder Self Attention CKA
Layer 0 0.2743104183527964
Layer 1 0.48853681863761567
Layer 2 0.5983830092594575
Layer 3 0.6143103270959407
Layer 4 0.5513991431027618
Layer 5 0.3425840837702103


In [48]:
# jaen, jaBART
encdec_cka_sim(
    pre=jabart_jako, ft=ft_jaen, \
    pre_d=koja_d, ft_d=enja_d, \
    pre_sentences=ja_jaenbart_sentences[:100], ft_sentences=ja_jaenbart_sentences[:100]
)

Encoder CKA
Layer 0 0.970523551005553
Layer 1 0.9151731549531593
Layer 2 0.8277985826173108
Layer 3 0.9280565330505587
Layer 4 0.960762716719473
Layer 5 0.9697891275931673
Layer 6 0.9642266387392431
Layer 7 0.8444771900088857

Decoder CKA
Layer 0 0.4263150415348768
Layer 1 0.571452449384793
Layer 2 0.6663581064078352
Layer 3 0.7303663734183071
Layer 4 0.6982537767593614
Layer 5 0.6197065847507601
Layer 6 0.08705909603272607


In [26]:
# jaBART, jako_fa
encdec_cka_sim(
    pre=jabart_jako_fa, ft=ft_jako_fa, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ja_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.9674408329461288
Layer 1 0.9343761404528619
Layer 2 0.75086902735202
Layer 3 0.7450048805260591
Layer 4 0.8005402722944482
Layer 5 0.832185747041194
Layer 6 0.8449792487178245
Layer 7 0.8687638817449904

Decoder CKA
Layer 0 0.3829109775723496
Layer 1 0.6179080335070457
Layer 2 0.567610266163832
Layer 3 0.6538575803829098
Layer 4 0.6969924748737887
Layer 5 0.6247609008165051
Layer 6 0.051791402605285794


In [31]:
# jaBART, jako_fa
encdec_cka_sim(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ko_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.5093697481332495
Layer 1 0.8607331305958683
Layer 2 0.7021390384058981
Layer 3 0.5819528201946342
Layer 4 0.4928926286905733
Layer 5 0.49713961149770364
Layer 6 0.5818514551276623
Layer 7 0.8340159722100928

Decoder CKA
Layer 0 0.5510239926259972
Layer 1 0.7549550905931844
Layer 2 0.6667061462110943
Layer 3 0.7398014686717572
Layer 4 0.735674058974253
Layer 5 0.6852968467697668
Layer 6 0.07426608661909906


In [27]:
# jaBART, koja_fa
encdec_cka_sim(
    pre=jabart_jako_fa, ft=ft_koja_fa, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ko_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.486273220571243
Layer 1 0.8595038398785774
Layer 2 0.7033624169068726
Layer 3 0.5827525385954249
Layer 4 0.4929243158033353
Layer 5 0.49678832290368524
Layer 6 0.5807955722898139
Layer 7 0.8324206411740162

Decoder CKA
Layer 0 0.4713994422494003
Layer 1 0.6870161440890267
Layer 2 0.6200601975537867
Layer 3 0.7012665521444087
Layer 4 0.6967926584543497
Layer 5 0.6513957015346789
Layer 6 0.07202165970481228


In [152]:
# enja, jaBART
encdec_cka_sim(
    pre=jabart_jako, ft=ft_enja, 
    pre_d=koja_d, ft_d=enja_d, 
    pre_sentences=jaen_sentences_ja[:100], ft_sentences=jaen_sentences_en[:100]
)

Encoder CKA
Layer 0 0.2133956818154692
Layer 1 0.7509003912833928
Layer 2 0.6922122203834341
Layer 3 0.6069842822146548
Layer 4 0.5733573434807498
Layer 5 0.5827822053694163
Layer 6 0.6114356266047001
Layer 7 0.778814152549551

Decoder CKA
Layer 0 0.6823823424997126
Layer 1 0.6796465853609647
Layer 2 0.4943844554526894
Layer 3 0.6165263221615128
Layer 4 0.7040800035040596
Layer 5 0.6710977798490118
Layer 6 0.13769230941600222

Decoder Self Attention CKA
Layer 0 0.6551121034156565
Layer 1 0.7128061060160942
Layer 2 0.5553544668569662
Layer 3 0.678281670753409
Layer 4 0.6748196715070189
Layer 5 0.5423299911263483


In [45]:
# enja, jaBART
encdec_cka_sim(
    pre=jabart_jako, ft=ft_enja, 
    pre_d=koja_d, ft_d=enja_d, 
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=en_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.2133956818154692
Layer 1 0.7509003912833928
Layer 2 0.6922122203834341
Layer 3 0.6069842822146548
Layer 4 0.5733573434807498
Layer 5 0.5827822053694163
Layer 6 0.6114356266047001
Layer 7 0.778814152549551

Decoder CKA
Layer 0 0.6823823424997126
Layer 1 0.6796465853609647
Layer 2 0.4943844554526894
Layer 3 0.6165263221615128
Layer 4 0.7040800035040596
Layer 5 0.6710977798490118
Layer 6 0.13769230941600222


In [46]:
encdec_cka_sim_rm_special_tokens(
    pre=jabart_jako, ft=ft_enja, 
    pre_d=koja_d, ft_d=enja_d, 
    pre_sentences=ja_jaenbart_sentences[:100], ft_sentences=en_jaenbart_sentences[:100]
)

Encoder CKA
Layer 0 0.21473489816154617
Layer 1 0.7498205167981
Layer 2 0.691505666374513
Layer 3 0.6068007578536602
Layer 4 0.573210229228657
Layer 5 0.5827459905508178
Layer 6 0.6114135420006851
Layer 7 0.7790729070819575

Decoder CKA
Layer 0 0.6824959605189248
Layer 1 0.6797730053523041
Layer 2 0.49437331951835245
Layer 3 0.6168067794802606
Layer 4 0.7047705761430938
Layer 5 0.6709161777008913
Layer 6 0.1376089437361425


In [147]:
# enja, jaen
encdec_cka_sim(pre=ft_jaen, ft=ft_enja, pre_d=enja_d, ft_d=enja_d, pre_sentences=enja_sentences[:100], ft_sentences=jaen_sentences[:100])

Encoder CKA
Layer 0 0.44956301094439244
Layer 1 0.7614944188015746
Layer 2 0.6400164042235462
Layer 3 0.5321559040794689
Layer 4 0.5105673216760417
Layer 5 0.5171827178674153
Layer 6 0.553711656819959
Layer 7 0.8278560807765531

Decoder CKA
Layer 0 0.638503693999009
Layer 1 0.5501354377787423
Layer 2 0.4896567254842808
Layer 3 0.5422355298174486
Layer 4 0.6054068802227315
Layer 5 0.700545378921386
Layer 6 0.6799855957357414


In [153]:
# koja, jaBART
encdec_cka_sim(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=jako_sentences_ja[:100], ft_sentences=jako_sentences_ko[:100]
)

Encoder CKA
Layer 0 0.3942145081540587
Layer 1 0.829636910217655
Layer 2 0.8718990500765171
Layer 3 0.9163439671916835
Layer 4 0.9247556100999105
Layer 5 0.9253361288367854
Layer 6 0.9351865740131241
Layer 7 0.8551808660096432

Decoder CKA
Layer 0 0.5268097265235973
Layer 1 0.7539800268719123
Layer 2 0.7297802524794936
Layer 3 0.7843447525690713
Layer 4 0.7566099680784231
Layer 5 0.7428515601808281
Layer 6 0.08962926898552133

Decoder Self Attention CKA
Layer 0 0.5117977862552471
Layer 1 0.7677623124738384
Layer 2 0.7748265491152246
Layer 3 0.7346492567630117
Layer 4 0.6835754858209302
Layer 5 0.5693995840357118


In [142]:
# koja, jaBART
encdec_cka_sim(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ko_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.39813758667059446
Layer 1 0.8372665742577943
Layer 2 0.884557799175998
Layer 3 0.9167199884185456
Layer 4 0.9193502656476504
Layer 5 0.9218732232526243
Layer 6 0.9340356606040892
Layer 7 0.8588061707660819

Decoder CKA
Layer 0 0.5050652878275668
Layer 1 0.7533884376494239
Layer 2 0.7310251120235626
Layer 3 0.7696943537806099
Layer 4 0.7340218602447084
Layer 5 0.7219580688844796
Layer 6 0.06652651826563986


In [22]:
# koja, jaBART
encdec_cka_sim_rm_special_tokens(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ko_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.39748979105451354
Layer 1 0.8296825798725999
Layer 2 0.8719219192291665
Layer 3 0.9163728779845367
Layer 4 0.9247754575194317
Layer 5 0.9253441818207621
Layer 6 0.9352153326508204
Layer 7 0.8556609981027861

Decoder CKA
Layer 0 0.5292802317066537
Layer 1 0.7537436719615525
Layer 2 0.7294155346030693
Layer 3 0.7841146349196392
Layer 4 0.7570655568098495
Layer 5 0.7434283779149016
Layer 6 0.0895921941697945


In [59]:
encdec_cka_sim(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ko_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.39813758667059446
Layer 1 0.8372665742577943
Layer 2 0.884557799175998
Layer 3 0.9167199884185456
Layer 4 0.9193502656476504
Layer 5 0.9218732232526243
Layer 6 0.9340356606040892
Layer 7 0.8588061707660819

Decoder CKA
Layer 0 0.5050652878275668
Layer 1 0.7533884376494239
Layer 2 0.7310251120235626
Layer 3 0.7696943537806099
Layer 4 0.7340218602447084
Layer 5 0.7219580688844796
Layer 6 0.06652651826563986


In [60]:
# koja, jaBART use last token's hidden states
encdec_cka_sim_v2(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:100], ft_sentences=ko_jabart_sentences[:100]
)

Encoder CKA
Layer 0 0.39813758667059446
Layer 1 0.8372665742577943
Layer 2 0.884557799175998
Layer 3 0.9167199884185456
Layer 4 0.9193502656476504
Layer 5 0.9218732232526243
Layer 6 0.9340356606040892
Layer 7 0.8588061707660819

Decoder CKA
Layer 0 0.18737031889213474
Layer 1 0.27824082522955756
Layer 2 0.2912734439356014
Layer 3 0.3114805956735732
Layer 4 0.3444135556508579
Layer 5 0.293009997229996
Layer 6 0.0256687581392078


In [64]:
encdec_cka_sim(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[:20], ft_sentences=ko_jabart_sentences[:20]
)

Encoder CKA
Layer 0 0.5447266789603019
Layer 1 0.9408569768496966
Layer 2 0.89988828928416
Layer 3 0.882193235810978
Layer 4 0.8761381932163314
Layer 5 0.8806489404925734
Layer 6 0.9109798977908943
Layer 7 0.9425632468949836

Decoder CKA
Layer 0 0.6717251359159946
Layer 1 0.8700682889912937
Layer 2 0.837079933312262
Layer 3 0.8903882350217927
Layer 4 0.8628185237805014
Layer 5 0.8448014053877598
Layer 6 0.4144622661368667


In [62]:
encdec_cka_sim_v2(
    pre=jabart_jako, ft=ft_koja, \
    pre_d=koja_d, ft_d=koja_d, \
    pre_sentences=ja_jabart_sentences[-100:], ft_sentences=ko_jabart_sentences[-100:]
)

Encoder CKA
Layer 0 0.3692928894893377
Layer 1 0.816572963563203
Layer 2 0.8919192527464466
Layer 3 0.9184192672084958
Layer 4 0.9162769446453053
Layer 5 0.9181849883188986
Layer 6 0.9320291458613279
Layer 7 0.8327011045213755

Decoder CKA
Layer 0 0.13916792912710638
Layer 1 0.2121558555510714
Layer 2 0.2224732547177916
Layer 3 0.24894895052957788
Layer 4 0.2772183862086337
Layer 5 0.21853501853480337
Layer 6 0.03117152116482811


In [169]:
encdec_cka_sim_v2(pre=bart_jako, ft=ft_koja, \
                  pre_d=koja_d, ft_d=koja_d, \
                  pre_sentences=koja_sentences[:100], ft_sentences=jako_sentences[:100])

Encoder CKA
Layer 0 0.39813758667059446
Layer 1 0.8372665742577943
Layer 2 0.884557799175998
Layer 3 0.9167199884185456
Layer 4 0.9193502656476504
Layer 5 0.9218732232526243
Layer 6 0.9340356606040892
Layer 7 0.8588061707660819

Decoder CKA
Layer 0 0.18737031889213474
Layer 1 0.27824082522955756
Layer 2 0.2912734439356014
Layer 3 0.3114805956735732
Layer 4 0.3444135556508579
Layer 5 0.293009997229996
Layer 6 0.0256687581392078


In [143]:
# koja, jako
encdec_cka_sim(pre=ft_jako, ft=ft_koja, pre_d=koja_d, ft_d=koja_d, pre_sentences=koja_sentences[:100], ft_sentences=jako_sentences[:100])

Encoder CKA
Layer 0 0.38914966417195
Layer 1 0.8772289821922453
Layer 2 0.8899276493629373
Layer 3 0.9037599286008311
Layer 4 0.9058940246759116
Layer 5 0.9117364861394909
Layer 6 0.9338582542674801
Layer 7 0.8843983840721953

Decoder CKA
Layer 0 0.8255794427451447
Layer 1 0.7307286521922709
Layer 2 0.7551462865527659
Layer 3 0.7159976288355673
Layer 4 0.7050116593377683
Layer 5 0.7698804841836703
Layer 6 0.48371331175894106


In [101]:
# jako, jaBART
encdec_cka_sim(pre=bart_jako, ft=ft_jako, pre_d=koja_d, ft_d=koja_d, sentences=koja_sentences[:100])

Encoder CKA
Layer 0 0.9764038300133336
Layer 1 0.9262711358440008
Layer 2 0.8891990692662358
Layer 3 0.9495674060863042
Layer 4 0.9690189805392235
Layer 5 0.9769565647363784
Layer 6 0.9769848289594428
Layer 7 0.8691340463003271

Decoder CKA
Layer 0 0.2885973356828191
Layer 1 0.561022844039608
Layer 2 0.5292809256793927
Layer 3 0.6442883136936155
Layer 4 0.7444158038803855
Layer 5 0.6832025402794522
Layer 6 0.06761801081006032


In [102]:
# jako, jaen
encdec_cka_sim(pre=ft_jako, ft=ft_jaen, pre_d=koja_d, ft_d=enja_d, sentences=koja_sentences[:100])

Encoder CKA
Layer 0 0.9897062022986469
Layer 1 0.942404347304907
Layer 2 0.9360558006020253
Layer 3 0.9749580252041363
Layer 4 0.9854106431689292
Layer 5 0.9884419785774842
Layer 6 0.9847088999092097
Layer 7 0.8975083491452324

Decoder CKA
Layer 0 0.7752163404416937
Layer 1 0.7797717512225465
Layer 2 0.8455108030791875
Layer 3 0.8263716658756463
Layer 4 0.7835157133452053
Layer 5 0.7382126413634155
Layer 6 0.41533271093317037


## enBART

In [19]:
# enBART, enja
encdec_cka_sim(
    pre=enbart_enja, ft=enbart_ft_enja, \
    pre_d=enja_enbart_d, ft_d=enja_enbart_d, \
    pre_sentences=en_enbart_sentences[:1000], ft_sentences=en_enbart_sentences[:1000]
)

done 200
done 400
done 600
done 800
done 1000
Encoder CKA
Layer 0 0.9710718630659055
Layer 1 0.8120210146657719
Layer 2 0.8006389337341203
Layer 3 0.7654746487379535
Layer 4 0.7380350953301268
Layer 5 0.7333817872153834
Layer 6 0.7386213187787545
Layer 7 0.7386213187787545

Decoder CKA
Layer 0 0.06769563267976879
Layer 1 0.12355517600508875
Layer 2 0.1915150468153713
Layer 3 0.23190734801412766
Layer 4 0.22419532676016332
Layer 5 0.24174943072293711
Layer 6 0.4389424708551309


In [189]:
encdec_cka_sim(pre=ft_enja_enbart, ft=bart_ende, pre_d=enja_enbart_d, ft_d=ende_d, sentences=en_enbart_sentences[:100])

Encoder CKA
Layer 0 0.9800959491456113
Layer 1 0.9063465790962039
Layer 2 0.8859278730303163
Layer 3 0.8680285143785702
Layer 4 0.8561370800893219
Layer 5 0.8507612922158924
Layer 6 0.8457999253847499
Layer 7 0.8457999253847499

Decoder CKA
Layer 0 0.18947321394560013
Layer 1 0.29860176561342905
Layer 2 0.4003624704114809
Layer 3 0.45474652714640984
Layer 4 0.41357241889775465
Layer 5 0.38817485409671454
Layer 6 0.6640995950456667


In [18]:
# enBART, jaen
encdec_cka_sim(
    pre=enbart_enja, ft=enbart_ft_jaen, \
    pre_d=enja_enbart_d, ft_d=enja_enbart_d, \
    pre_sentences=en_enbart_sentences[:1000], ft_sentences=ja_enbart_sentences[:1000]
)

done 200
done 400
done 600
done 800
done 1000
Encoder CKA
Layer 0 0.3266183627050019
Layer 1 0.5554550327393063
Layer 2 0.5996869103216395
Layer 3 0.6228156618648397
Layer 4 0.6345839533936144
Layer 5 0.5826970221310772
Layer 6 0.6178093575017386
Layer 7 0.6178093575017386

Decoder CKA
Layer 0 0.10103534028961636
Layer 1 0.15979227332956158
Layer 2 0.21093274378601296
Layer 3 0.2341980910356678
Layer 4 0.2375313805561715
Layer 5 0.26754144559807225
Layer 6 0.4110353001871192


In [152]:
encdec_cka_sim(pre=ft_jaen_enbart, ft=bart_ende, \
               pre_d=enja_enbart_d, ft_d=ende_d, \
               pre_sentences=ja_enbart_sentences[:100], ft_sentences=en_enbart_sentences[:100])

Encoder CKA
Layer 0 0.4826819965841025
Layer 1 0.7448068701601226
Layer 2 0.7608820498717873
Layer 3 0.774781274391684
Layer 4 0.7782604994262258
Layer 5 0.7560034084652245
Layer 6 0.7619852822661689
Layer 7 0.7619852822661689

Decoder CKA
Layer 0 0.24741065684638075
Layer 1 0.358124774393418
Layer 2 0.4407861368273513
Layer 3 0.486088296263293
Layer 4 0.47788928752847304
Layer 5 0.47284600780030955
Layer 6 0.6871021252518612


In [153]:
encdec_cka_sim(pre=ft_jaen_enbart, ft=bart_ende, \
               pre_d=enja_enbart_d, ft_d=ende_d, \
               pre_sentences=en_enbart_sentences[:100], ft_sentences=en_enbart_sentences[:100])

Encoder CKA
Layer 0 0.9840656777237077
Layer 1 0.5583370897039305
Layer 2 0.505575642587246
Layer 3 0.41657622538682904
Layer 4 0.3298949865619421
Layer 5 0.2737503780868102
Layer 6 0.23969355799771241
Layer 7 0.23969355799771241

Decoder CKA
Layer 0 0.09492613189329874
Layer 1 0.13209263859133985
Layer 2 0.16080890859662506
Layer 3 0.17255819495854252
Layer 4 0.174431708171571
Layer 5 0.20196991993492058
Layer 6 0.29839598438325493


## ENJA

In [6]:
bart_enja_path = "../../jpBART/en-ja/ja-bart"
bart_enja_name = "model.pt"
bart_enja = load_bart(path=bart_enja_path, model_name=bart_enja_name)
bart_enja.to(DEVICE)

ft_jaen_path = "../../jpBART/en-ja/jaen/ft_model"
ft_jaen_name = "checkpoint_best.pt"
ft_jaen = load_bart(path=ft_jaen_path, model_name=ft_jaen_name)
ft_jaen.to(DEVICE)

print("loaded models")

enja_d = load_dict("../../jpBART/en-ja/ja-bart/dict.txt")
file_path = "../../jpBART/en-ja/data/dev.ja"
with open(file_path , "r") as f:
    enja_sentences = f.readlines()

loaded models


In [7]:
file_path = "../../jpBART/en-ja/data/dev.en"
with open(file_path , "r") as f:
    jaen_sentences = f.readlines()

In [8]:
ft_enja_path = "../../jpBART/en-ja/ft_model"
ft_enja_name = "checkpoint_best.pt"
ft_enja = load_bart(path=ft_enja_path, model_name=ft_enja_name)
ft_enja.to(DEVICE)
print("loaded models")

loaded models


In [9]:
file_path = "../../jpBART/data/dev.ja"
with open(file_path , "r") as f:
    koja_sentences = f.readlines()

In [10]:
file_path = "../../jpBART/data/dev.ko"
with open(file_path , "r") as f:
    jako_sentences = f.readlines()

## KOJA

In [11]:
bart_jako_path = "../../jpBART/japanese_bart_base_1.1/test"
bart_jako_name = "model.pt"
ft_jako_path = "../../jpBART/ja-ko/bart/ft/"
ft_jako_name = "checkpoint_best.pt"
bart_jako = load_bart(path=bart_jako_path, model_name=bart_jako_name)
bart_jako.to(DEVICE)
ft_jako = load_bart(path=ft_jako_path, model_name=ft_jako_name)
ft_jako.to(DEVICE)
koja_d = load_dict("../../jpBART/japanese_bart_base_1.1/test/dict.txt")
print("loaded models")

loaded models


In [12]:
ft_koja_path = "../../jpBART/v3/ft"
ft_koja_name = "checkpoint_best.pt"
ft_koja = load_bart(path=ft_koja_path, model_name=ft_koja_name)
ft_koja.to(DEVICE)
print("loaded models")

loaded models


## JAJA

In [142]:
ft_jaja_path = "../../jpBART/jaja/bart/ft"
ft_jaja_name = "checkpoint_best.pt"
ft_jaja = load_bart(path=ft_jaja_path, model_name=ft_jaja_name)
ft_jaja.to(DEVICE)
print("loaded models")

loaded models
