# Retro[on] - Retro[off] loss diff analysis

In this analysis we investigate whether the difference in loss between RETRO with and without retrieval correlates with the number of tokens that overlap between the retrieved neighbours and the tokens to be predicted.

In [None]:
import sys
# Make the project modules imported below (train_retro, modeling_retro, dataset_retro)
# importable; the relative path is resolved against the notebook's working directory.
sys.path.append('retro/src')

In [None]:
import json
import numpy as np
import torch
import pytorch_lightning as pl
from scipy.stats import spearmanr, pearsonr
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from train_retro import get_retro_dataset_from_spec, retro_collate_fn, RetroModelLMHeadLightning
from modeling_retro import RetroConfig

In [None]:
# Batch size used by the validation DataLoaders (val_dl and bm25_val_dl).
BATCH_SIZE = 10

In [None]:
# Model config for the RETRO[ON] runs. read_text() closes the file, unlike json.load(path.open()).
retro_on_config = RetroConfig(**json.loads(Path('retro/data/model/retro.json').read_text()))

# Validation set with sentence-transformer neighbours: 2 neighbours per chunk, each
# extended by 1 continuation chunk, max_len=1024.
val_ds = get_retro_dataset_from_spec(
    spec_file=Path('retro/data/datasets/MassiveOpenText/val_sentence_transformer_neighbours.spec.json'),
    num_neighbours=2,
    continuation_chunks=1,
    pad_token_idx=retro_on_config.pad_token_idx,
    max_len=1024
)

val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=4)
print(f"Total size of validation set: {len(val_ds)}")

### Unigram token overlap
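The unigram overlap between each input chunk and the neighbours it attends to (including their continuation chunks): the number of the chunk's non-padding tokens that also occur anywhere in those neighbours.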

In [None]:
from dataset_retro import RetroTrainingExample

# A separate dataset instance with the same spec and settings as ``val_ds``; it is
# iterated below to compute the per-chunk unigram overlaps.
overlap_ds = get_retro_dataset_from_spec(
    max_len=1024,
    pad_token_idx=retro_on_config.pad_token_idx,
    continuation_chunks=1,
    num_neighbours=2,
    spec_file=Path('retro/data/datasets/MassiveOpenText') / 'val_sentence_transformer_neighbours.spec.json',
)


def get_attending_chunk_neighbour_unigram_overlap_for_sequence(ex: "RetroTrainingExample", pad_token_idx=0):
    """Count, per chunk, how many of its tokens occur in the neighbours it attends to.

    ``ex.input_ids`` is reshaped into ``num_chunks`` equal-length chunks, where
    ``num_chunks`` is ``ex.neighbour_ids.shape[0]``. Chunk ``i`` (for ``i >= 1``) is
    compared with ``ex.neighbour_ids[i - 1]``, the neighbours retrieved for the
    preceding chunk. This mirrors the one-chunk offset of RETRO's chunked
    cross-attention. Each chunk is cut at its first ``pad_token_idx`` so padding
    is never counted. All neighbours of the preceding chunk, continuation tokens
    included, are pooled into one set of token ids.

    Args:
        ex: Example exposing ``input_ids`` (1-D tensor) and ``neighbour_ids``
            (tensor indexed ``[chunk, neighbour, token]``).
        pad_token_idx: Padding token id. Defaults to 0; pass the configured pad id
            if it differs.

    Returns:
        A float array with one count per chunk from the second onwards
        (length ``num_chunks - 1``), or an empty list for a single-chunk sequence.
    """
    neighbour_ids = ex.neighbour_ids.numpy()
    num_chunks = neighbour_ids.shape[0]
    if num_chunks == 1:
        return []

    chunks = ex.input_ids.numpy().reshape(num_chunks, -1)
    overlaps = np.zeros(num_chunks - 1)
    for i in range(1, num_chunks):
        chunk = chunks[i]
        pad_positions = np.flatnonzero(chunk == pad_token_idx)
        # Only tokens before the first pad are real content.
        chunk_len = pad_positions[0] if pad_positions.size > 0 else chunk.shape[0]
        # np.isin replaces the deprecated np.in1d; test elements are flattened, so all
        # neighbours of chunk i-1 are pooled together.
        overlaps[i - 1] = np.isin(chunk[:chunk_len], neighbour_ids[i - 1]).sum()
    return overlaps

In [None]:
# One unigram-overlap count per attending chunk (every chunk after the first) of
# every sequence, in dataset order.
u_overlaps = []
for example in tqdm(overlap_ds, total=len(overlap_ds)):
    u_overlaps.extend(get_attending_chunk_neighbour_unigram_overlap_for_sequence(example))

print(len(u_overlaps))

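### Embedding distance

The distance between each input chunk and the neighbours it attends to (including their continuation chunks), based on their sentence-transformer representations. It is measured as the largest squared L2 distance over those neighbour embeddings.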

In [None]:
from dataset_retro import ChunkedSequenceDataset, RetroTrainingExample, ShardedChunkedSequenceDataset, \
    ChunkNeighbourDataset, ShardedChunkNeighbourDataset
from torch.utils.data import Dataset
from scipy.spatial.distance import cosine
import dask.array as da
import numpy as np


class ChunkedSequenceDatasetWithIDGetter(ChunkedSequenceDataset):
    """``ChunkedSequenceDataset`` that can report chunk *ids* rather than chunk tokens."""

    def get_chunk_ids(self, chunk_index, shard_range_start, include_continuation_chunks=0):
        """Return the global chunk-id slice for a chunk plus its continuation chunks.

        The slice starts at ``chunk_index`` (local to this shard). It grows by up to
        ``include_continuation_chunks`` following chunks, stopping early at the end
        of the shard or when ``chunk2seq`` shows the next chunk belongs to another
        sequence. Both bounds are shifted by ``shard_range_start`` to make them global.
        """
        stop = chunk_index + 1  # exclusive end, local to this shard
        while (
            stop - chunk_index <= include_continuation_chunks
            and stop < len(self.chunk2seq)
            and self.chunk2seq[stop] == self.chunk2seq[chunk_index]
        ):
            stop += 1
        return slice(shard_range_start + chunk_index, shard_range_start + stop)
        
        
class ShardedChunkedSequenceDatasetWithIDGetter(ShardedChunkedSequenceDataset):
    """Sharded chunk dataset that routes ``get_chunk_ids`` to the shard owning the chunk."""

    def get_chunk_ids(self, chunk_index, include_continuation_chunks: int=0):
        """Return the global chunk-id slice for global ``chunk_index``.

        The owning shard receives the shard-local index together with its range
        start, so the returned slice is expressed in global ids.

        Raises:
            IndexError: if no shard's chunk range contains ``chunk_index``.
        """
        for chunk_range, shard in zip(self.shard_chunk_ranges, self.shards):
            if int(chunk_index) not in chunk_range:
                continue
            return shard.get_chunk_ids(
                chunk_index - chunk_range.start,
                chunk_range.start,
                include_continuation_chunks,
            )
        raise IndexError(f"Chunk with index {chunk_index} not found in index")

        
class ChunkNeighbourDatasetWithIDGetter(ChunkNeighbourDataset):
    """Neighbour dataset that returns neighbour chunk-id slices instead of tokens."""

    def get_neighbours_ids(self, chunk_index: int, num_neighbours: int=None, continuation_chunks: int=1):
        """List one entry per neighbour (the first ``num_neighbours``) of ``chunk_index``.

        Each entry is the slice from ``self.retrieval_dataset.get_chunk_ids`` covering
        the neighbour chunk and its continuation chunks. A neighbour index of -1
        yields ``None``.
        """
        neighbour_slices = []
        for neighbour_chunk_idx in self.neighbours[chunk_index][:num_neighbours]:
            if neighbour_chunk_idx == -1:
                neighbour_slices.append(None)
            else:
                neighbour_slices.append(
                    self.retrieval_dataset.get_chunk_ids(neighbour_chunk_idx, continuation_chunks)
                )
        return neighbour_slices

    
class ShardedChunkNeighbourDatasetWithIDGetter(ShardedChunkNeighbourDataset):
    """Sharded neighbour dataset that forwards ``get_neighbours_ids`` to the owning shard."""

    def get_neighbours_ids(self, chunk_index: int, num_neighbours: int=None, continuation_chunks: int=1):
        """Return neighbour id slices for global ``chunk_index``.

        The owning shard is called with the shard-local index.

        Raises:
            IndexError: if no shard range contains ``chunk_index``.
        """
        for index_range, shard in zip(self.shard_ranges, self.shards):
            if int(chunk_index) not in index_range:
                continue
            return shard.get_neighbours_ids(
                chunk_index - index_range.start, num_neighbours, continuation_chunks
            )
        raise IndexError(f"Neighbours for index {chunk_index} not found")
    
                
class RetroEmbeddingDataset(Dataset):
    """Sentence-transformer embeddings of each sequence's chunks and their neighbours.

    ``dataset[seq_index]`` returns ``(input_embs, neighbour_embs)``:

    * ``input_embs``: the ``val_embeddings`` rows of the sequence's input chunks,
      truncated to ``max_len // chunk_size`` chunks when ``max_len`` is given.
    * ``neighbour_embs``: one list per input chunk. For each neighbour it holds the
      ``ret_embeddings`` rows covering the neighbour chunk and its continuation
      chunks, or ``None`` for a missing neighbour.
    """

    def __init__(
        self,
        input_dataset: ShardedChunkedSequenceDataset,
        neighbour_dataset: ShardedChunkNeighbourDatasetWithIDGetter,
        val_embeddings,
        ret_embeddings,
        num_neighbours=None,
        continuation_chunks=1,
        pad_token_idx=0,
        max_len=None
    ):
        super().__init__()
        self.input_dataset = input_dataset
        self.neighbour_dataset = neighbour_dataset
        self.num_neighbours = num_neighbours
        self.continuation_chunks = continuation_chunks
        # Token length of one neighbour: its own chunk plus the continuation chunks.
        self.neighbour_size = neighbour_dataset.chunk_size * (continuation_chunks + 1)
        self.pad_token_idx = pad_token_idx
        self.max_num_chunks = None if max_len is None else max_len // input_dataset.chunk_size
        self.val_embeddings, self.ret_embeddings = val_embeddings, ret_embeddings

        if max_len is not None:
            assert max_len % input_dataset.chunk_size == 0, \
                "max_len must be a multiple of chunk_size"

        assert input_dataset.num_chunks == len(neighbour_dataset), \
            "The number of chunks in input dataset did not match the number of chunks in neighbour dataset"

    def __len__(self):
        return self.input_dataset.num_sequences

    def __getitem__(self, seq_index: int):
        chunk_indices = self.input_dataset.get_chunk_indices_of_sequence(seq_index)[:self.max_num_chunks]

        # Per input chunk: one retrieval chunk-id slice per neighbour (None if missing).
        per_chunk_slices = [
            self.neighbour_dataset.get_neighbours_ids(
                chunk_index,
                num_neighbours=self.num_neighbours,
                continuation_chunks=self.continuation_chunks,
            )
            for chunk_index in chunk_indices
        ]
        neighbour_embs = [
            [None if slc is None else self.ret_embeddings[slc] for slc in slices]
            for slices in per_chunk_slices
        ]
        return self.val_embeddings[chunk_indices], neighbour_embs
    
    
def get_retro_embedding_dataset_from_spec(
    spec_file: Path,
    retrieval_spec_file: Path,
    num_neighbours=None,
    continuation_chunks=1,
    pad_token_idx=0,
    max_len=None,
) -> RetroEmbeddingDataset:
    """Build a ``RetroEmbeddingDataset`` from a dataset spec and a retrieval spec.

    Args:
        spec_file: JSON spec of the input dataset.
            - ``spec["shards"]`` lists each shard's ``chunks``/``seq2chunk``/``chunk2seq``/
              ``neighbours`` files.
            - ``spec["neighbours"]["index_spec"]`` points to the retrieval index spec.
            - All paths are relative to ``spec_file``'s directory.
        retrieval_spec_file: JSON list of retrieval shards. Each entry's
            ``"embeddings"`` path has ``'../'`` replaced by
            ``'retriever_sentence_transformer/'``. It is then resolved against
            ``spec_file``'s directory.
        num_neighbours, continuation_chunks, pad_token_idx, max_len: forwarded to
            ``RetroEmbeddingDataset``.

    Returns:
        The assembled dataset. Embedding ``.npy`` files are memory-mapped and
        concatenated lazily with dask.
    """
    # json.loads(path.read_text()) closes the file; json.load(path.open()) leaked the handle.
    spec = json.loads(spec_file.read_text())
    base_dir = spec_file.parent

    # input dataset
    input_dataset = ShardedChunkedSequenceDataset([
        ChunkedSequenceDataset(
            chunks=base_dir / shard["chunks"],
            seq2chunk=base_dir / shard["seq2chunk"],
            chunk2seq=base_dir / shard["chunk2seq"]
        )
        for shard in spec["shards"]
    ])

    # retrieval dataset (paths inside the index spec are relative to its directory)
    index_spec = json.loads((base_dir / spec["neighbours"]["index_spec"]).read_text())
    index_base_dir = base_dir / Path(spec["neighbours"]["index_spec"]).parent
    retrieval_dataset = ShardedChunkedSequenceDatasetWithIDGetter([
        ChunkedSequenceDatasetWithIDGetter(
            chunks=index_base_dir / shard["chunks"],
            seq2chunk=index_base_dir / shard["seq2chunk"],
            chunk2seq=index_base_dir / shard["chunk2seq"]
        )
        for shard in index_spec
    ])

    # neighbour dataset
    neighbour_dataset = ShardedChunkNeighbourDatasetWithIDGetter([
        ChunkNeighbourDatasetWithIDGetter(
            neighbours=base_dir / shard["neighbours"],
            retrieval_dataset=retrieval_dataset
        )
        for shard in spec["shards"]
    ])

    # embeddings: each input shard's embeddings file is its neighbours path with every
    # "neighbours" substring replaced by "embeddings".
    val_emb_addrss = [base_dir / shard["neighbours"].replace("neighbours", "embeddings")
                      for shard in spec["shards"]]
    val_embeddings = [np.load(emb_addrs, mmap_mode="r") for emb_addrs in val_emb_addrss]
    val_embeddings = da.concatenate(val_embeddings, axis=0)

    retrieval_spec = json.loads(retrieval_spec_file.read_text())
    ret_emb_addrss = [base_dir / Path(el['embeddings'].replace('../', 'retriever_sentence_transformer/'))
                      for el in retrieval_spec]
    ret_embeddings = [np.load(emb_addrs, mmap_mode="r") for emb_addrs in ret_emb_addrss]
    ret_embeddings = da.concatenate(ret_embeddings, axis=0)

    retro_dataset = RetroEmbeddingDataset(
        input_dataset=input_dataset,
        neighbour_dataset=neighbour_dataset,
        val_embeddings=val_embeddings,
        ret_embeddings=ret_embeddings,
        num_neighbours=num_neighbours,
        continuation_chunks=continuation_chunks,
        pad_token_idx=pad_token_idx,
        max_len=max_len,
    )

    return retro_dataset

# Embedding-level view of the validation sequences and their sentence-transformer neighbours.
cos_ds = get_retro_embedding_dataset_from_spec(
    retrieval_spec_file=Path('retro/data/datasets/MassiveOpenText') / 'retriever_sentence_transformer' / 'val.index.spec.json',
    spec_file=Path('retro/data/datasets/MassiveOpenText') / 'val_sentence_transformer_neighbours.spec.json',
    max_len=1024,
    pad_token_idx=retro_on_config.pad_token_idx,
    continuation_chunks=1,
    num_neighbours=2,
)

In [None]:
def get_attending_chunk_neighbour_l2_distance(inputs_emb, neighbours_emb):
    """Largest squared L2 distance from each chunk to the neighbours it attends to.

    Chunk ``i`` (for ``i >= 1``) is compared with every embedding row in
    ``neighbours_emb[i - 1]``. Those are the neighbours of the preceding chunk,
    including their continuation chunks.

    Args:
        inputs_emb: 2-D array-like (NumPy or dask) with one embedding row per input chunk.
        neighbours_emb: one list per input chunk. Each holds, per neighbour, a 2-D
            array-like of embedding rows, or ``None`` for a missing neighbour.

    Returns:
        A list with one float per chunk from the second onwards. A chunk with no
        neighbour embeddings gets 0.0. Returns an empty list when there is at most
        one chunk or no neighbour lists.
    """
    if len(inputs_emb) <= 1 or len(neighbours_emb) == 0:
        return []

    # Materialise once: a no-op for NumPy input, a single compute() for dask.
    inputs = np.asarray(inputs_emb)
    l2_distances = []
    for i in range(1, len(inputs)):
        squared = [
            ((np.asarray(block) - inputs[i]) ** 2).sum(axis=-1).ravel()
            for block in neighbours_emb[i - 1]
            if block is not None
        ]
        all_distances = np.concatenate(squared) if squared else np.empty(0)
        # Explicit empty check replaces the former bare ``except`` around max([]).
        l2_distances.append(float(all_distances.max()) if all_distances.size else 0.)
    return l2_distances

In [None]:
# One squared-L2 statistic per attending chunk of every sequence. This is intended
# to line up with u_overlaps (both skip each sequence's first chunk); verify the lengths.
# Iterates ``cos_ds``: the previous ``cos_ds_`` was never defined (NameError).
l2_distances = []
for seq in tqdm(cos_ds):
    l2_distances.extend(get_attending_chunk_neighbour_l2_distance(*seq))

### BM25

In [None]:
from dataset_retro import ChunkedSequenceDataset, RetroTrainingExample, ShardedChunkedSequenceDataset, \
    ChunkNeighbourDataset, ShardedChunkNeighbourDataset
from transformers import AutoTokenizer
from torch.utils.data import Dataset

  
bm25_neighbours = np.load('retro/data/datasets/MassiveOpenText/retriever_bm25/neighbours.npy')

# Maps a global input-chunk index to its row in ``bm25_neighbours``.
# RetroBM25Dataset.__getitem__ assigns rows in first-seen order, and
# ChunkNeighbourDatasetBM25.get_neighbours reads them. Both referenced this table
# but it was never defined (NameError).
# NOTE(review): this presumes neighbours.npy rows follow the order in which the
# dataset's chunks are first visited; confirm against the script that produced it.
conv_ids = {}

class ChunkNeighbourDatasetBM25(ChunkNeighbourDataset):
    """Neighbour dataset that serves BM25 neighbours instead of its own ``neighbours`` file.

    Rows come from the module-level ``bm25_neighbours`` array via
    ``conv_ids[global_idx]``, so ``conv_ids`` must already contain ``global_idx``.
    """

    def get_neighbours(self, global_idx: int, num_neighbours: int=None, continuation_chunks: int=1):
        """Return the tokens of each of the first ``num_neighbours`` BM25 neighbours.

        Each entry includes the neighbour's continuation chunks. A neighbour index
        of -1 yields ``None``.
        """
        bm25_row = bm25_neighbours[conv_ids[global_idx]]
        neighbour_tokens = []
        for neighbour_chunk_idx in bm25_row[:num_neighbours]:
            if neighbour_chunk_idx == -1:
                neighbour_tokens.append(None)
                continue
            neighbour_tokens.append(
                self.retrieval_dataset.get_chunk_tokens(
                    neighbour_chunk_idx,
                    include_continuation_chunks=continuation_chunks
                )
            )
        return neighbour_tokens


class ShardedChunkNeighbourDatasetBM25(ShardedChunkNeighbourDataset):
    """Sharded BM25 neighbour dataset that always delegates to its first shard.

    The BM25 lookup goes through module-level tables keyed by the *global* chunk
    index, and every shard shares the same retrieval dataset. Any shard can
    therefore answer, so the global index is passed on unchanged.
    """

    def get_neighbours(self, chunk_index: int, num_neighbours: int=None, continuation_chunks: int=1):
        no_shard = object()
        first_shard = next(iter(self.shards), no_shard)
        if first_shard is no_shard:
            raise IndexError(f"Neighbours for index {chunk_index} not found")
        return first_shard.get_neighbours(chunk_index, num_neighbours, continuation_chunks)

        
class RetroBM25Dataset(Dataset):
    """RETRO examples (input ids, neighbour ids, labels) with BM25-retrieved neighbours.

    Side effect: every input chunk index seen by ``__getitem__`` is registered in the
    module-level ``conv_ids`` dict, in first-seen order, before neighbours are looked
    up. The BM25 neighbour lookup depends on that registration.
    """

    def __init__(
        self,
        input_dataset: ShardedChunkedSequenceDataset,
        neighbour_dataset: ShardedChunkNeighbourDatasetBM25,
        num_neighbours=None,
        continuation_chunks=1,
        pad_token_idx=0,
        max_len=None
    ):
        super().__init__()
        self.input_dataset = input_dataset
        self.neighbour_dataset = neighbour_dataset
        self.num_neighbours = num_neighbours
        self.continuation_chunks = continuation_chunks
        # A neighbour is its own chunk plus ``continuation_chunks`` following chunks.
        self.neighbour_size = neighbour_dataset.chunk_size * (continuation_chunks + 1)
        self.pad_token_idx = pad_token_idx
        self.max_num_chunks = None if max_len is None else max_len // input_dataset.chunk_size

        if max_len is not None:
            assert max_len % input_dataset.chunk_size == 0, \
                "max_len must be a multiple of chunk_size"

        assert input_dataset.num_chunks == len(neighbour_dataset), \
            "The number of chunks in input dataset did not match the number of chunks in neighbour dataset"

    def __len__(self):
        return self.input_dataset.num_sequences

    def __getitem__(self, seq_index: int) -> RetroTrainingExample:
        chunk_indices = self.input_dataset.get_chunk_indices_of_sequence(seq_index)[:self.max_num_chunks]

        # Register unseen chunks before the neighbour lookup, which indexes conv_ids.
        for chunk_index in chunk_indices:
            conv_ids.setdefault(chunk_index, len(conv_ids))

        input_ids = np.concatenate([
            self.input_dataset.get_chunk_tokens(chunk_index) for chunk_index in chunk_indices
        ])

        def pad_neighbour(tokens):
            # A missing neighbour becomes an all-pad row (float here, cast to int32 below).
            if tokens is None:
                return np.ones(self.neighbour_size) * self.pad_token_idx
            return np.pad(tokens, (0, self.neighbour_size - len(tokens)), constant_values=self.pad_token_idx)

        neighbour_ids = np.stack([
            [
                pad_neighbour(tokens)
                for tokens in self.neighbour_dataset.get_neighbours(
                    chunk_index,
                    num_neighbours=self.num_neighbours,
                    continuation_chunks=self.continuation_chunks
                )
            ]
            for chunk_index in chunk_indices
        ])

        # Next-token labels. Pad positions, including the appended final one, become -100.
        labels = np.pad(input_ids[1:], (0, 1), constant_values=self.pad_token_idx).astype(np.int64)
        labels[labels == self.pad_token_idx] = -100

        return RetroTrainingExample(
            torch.from_numpy(input_ids.astype(np.int32)),
            torch.from_numpy(neighbour_ids.astype(np.int32)),
            torch.from_numpy(labels),
        )
        
def get_retro_bm25_dataset_from_spec(
    spec_file: Path,
    num_neighbours=None,
    continuation_chunks=1,
    pad_token_idx=0,
    max_len=None,
):
    """Build a ``RetroBM25Dataset`` from a dataset spec.

    The input and retrieval datasets come from ``spec_file`` exactly as for the
    sentence-transformer dataset. The shard ``neighbours`` files are still passed to
    ``ChunkNeighbourDatasetBM25``, but its neighbour lookup uses the module-level
    ``bm25_neighbours`` table instead.
    NOTE(review): the files presumably only satisfy the base-class constructor and
    the length check; confirm.

    Args:
        spec_file: JSON spec; paths inside it are relative to its directory.
        num_neighbours, continuation_chunks, pad_token_idx, max_len: forwarded to
            ``RetroBM25Dataset``.
    """
    # json.loads(path.read_text()) closes the file; json.load(path.open()) leaked the handle.
    spec = json.loads(spec_file.read_text())
    base_dir = spec_file.parent

    # input dataset
    input_dataset = ShardedChunkedSequenceDataset([
        ChunkedSequenceDataset(
            chunks=base_dir / shard["chunks"],
            seq2chunk=base_dir / shard["seq2chunk"],
            chunk2seq=base_dir / shard["chunk2seq"]
        )
        for shard in spec["shards"]
    ])

    # retrieval dataset (paths inside the index spec are relative to its directory)
    index_spec = json.loads((base_dir / spec["neighbours"]["index_spec"]).read_text())
    index_base_dir = base_dir / Path(spec["neighbours"]["index_spec"]).parent
    retrieval_dataset = ShardedChunkedSequenceDataset([
        ChunkedSequenceDataset(
            chunks=index_base_dir / shard["chunks"],
            seq2chunk=index_base_dir / shard["seq2chunk"],
            chunk2seq=index_base_dir / shard["chunk2seq"]
        )
        for shard in index_spec
    ])

    # neighbour dataset
    neighbour_dataset = ShardedChunkNeighbourDatasetBM25([
        ChunkNeighbourDatasetBM25(
            neighbours=base_dir / shard["neighbours"],
            retrieval_dataset=retrieval_dataset
        )
        for shard in spec["shards"]
    ])

    retro_dataset = RetroBM25Dataset(
        input_dataset=input_dataset,
        neighbour_dataset=neighbour_dataset,
        num_neighbours=num_neighbours,
        continuation_chunks=continuation_chunks,
        pad_token_idx=pad_token_idx,
        max_len=max_len
    )

    return retro_dataset

In [None]:
# Notebook display: shape of the BM25 neighbour table, whose rows are indexed via conv_ids.
bm25_neighbours.shape

In [None]:
# Same validation sequences as val_ds (same spec file), but neighbours are served
# from the BM25 table instead of the sentence-transformer neighbour files.
bm25_val_ds = get_retro_bm25_dataset_from_spec(
    spec_file=Path('retro/data/datasets/MassiveOpenText/val_sentence_transformer_neighbours.spec.json'),
    num_neighbours=2,
    continuation_chunks=1,
    pad_token_idx=retro_on_config.pad_token_idx,
    max_len=1024
)

# ``collate_fn`` is not defined anywhere in this notebook (NameError). Use
# retro_collate_fn, imported from train_retro at the top.
bm25_val_dl = DataLoader(bm25_val_ds, batch_size=BATCH_SIZE, collate_fn=retro_collate_fn, num_workers=4)

### Running the models

In [None]:
class CaptureLossesHook(pl.Callback):
    """Lightning callback that records the output of every test batch, in order.

    Later cells treat each entry of ``losses`` as a ``(batch, seq_len)`` tensor of
    per-token losses.
    NOTE(review): this relies on the module's test step returning such a tensor; confirm.
    """

    def __init__(self):
        super().__init__()
        self.losses = []

    def on_test_batch_end(self, trainer, module, outputs, *args):
        # *args absorbs batch / batch_idx / dataloader_idx, whose presence varies
        # across Lightning versions.
        if isinstance(outputs, torch.Tensor):
            # Keep a detached CPU copy so the whole test set's losses do not pile up
            # in accelerator memory.
            outputs = outputs.detach().cpu()
        self.losses.append(outputs)

**RETRO\[OFF\]:**

In [None]:
# RETRO[OFF]: same config file and checkpoint as RETRO[ON], but ``dec_cca_layers`` is
# emptied. Presumably this means no decoder layer cross-attends to the neighbours.
retro_off_config = RetroConfig(**json.loads(Path('retro/data/model/retro.json').read_text()))
retro_off_config.dec_cca_layers = []
CHECKPOINT_PATH = 'retro/data/model/retro_model.ckpt'

retro_off_loss_hook = CaptureLossesHook()
# Built directly from the checkpoint; no separately initialised model is needed.
# strict=False presumably tolerates checkpoint weights for CCA layers this config omits.
retro_off_model = RetroModelLMHeadLightning.load_from_checkpoint(
    CHECKPOINT_PATH, config=retro_off_config, strict=False
).eval()

trainer = pl.Trainer(
    gpus=-1,
    logger=False,
    callbacks=[retro_off_loss_hook]
)

trainer.test(retro_off_model, dataloaders=val_dl)

**RETRO\[ON\]-ST:**

In [None]:
# RETRO[ON] with sentence-transformer neighbours. trainer.test restores the weights
# from CHECKPOINT_PATH before evaluating.
retro_on_loss_hook = CaptureLossesHook()
retro_on_model = RetroModelLMHeadLightning(retro_on_config)
trainer = pl.Trainer(gpus=-1, logger=False, callbacks=[retro_on_loss_hook])

trainer.test(retro_on_model, dataloaders=val_dl, ckpt_path=CHECKPOINT_PATH)

**RETRO\[ON\]-BM25:**

In [None]:
# Warm-up pass: visit every sequence once, in order, in this process. This fills the
# global conv_ids table deterministically before the DataLoader workers start.
# NOTE(review): with the fork start method, workers inherit a copy of the table.
# An explicit bounded loop is used: RetroBM25Dataset defines no __iter__, so plain
# iteration would rely on __getitem__ raising IndexError past the end.
for seq_idx in range(len(bm25_val_ds)):
    _ = bm25_val_ds[seq_idx]

retro_on_bm25_model = RetroModelLMHeadLightning(retro_on_config)
retro_on_bm25_loss_hook = CaptureLossesHook()
trainer = pl.Trainer(
    gpus=-1,
    logger=False,
    callbacks=[retro_on_bm25_loss_hook]
)

trainer.test(retro_on_bm25_model, dataloaders=bm25_val_dl, ckpt_path=CHECKPOINT_PATH)

In [None]:
# Fraction of input chunks whose first sentence-transformer neighbour is token-for-token
# identical to their first BM25 neighbour.
overlap, total = 0, 0
for seq_idx in range(len(val_ds)):
    st_first = val_ds[seq_idx].neighbour_ids[:, 0]
    bm25_first = bm25_val_ds[seq_idx].neighbour_ids[:, 0]
    overlap += sum((st_first == bm25_first).all(axis=-1))
    total += st_first.shape[0]

print(overlap/total)

In [None]:
def chunker(seq: list, chunk_size: int=64):
    """Lazily split ``seq`` into consecutive slices of ``chunk_size`` items.

    The last slice holds the remainder and may be shorter. Works for any sliceable
    object with ``len`` (lists, 1-D tensors, ...).
    """
    yield from (seq[start:start + chunk_size] for start in range(0, len(seq), chunk_size))

# Flatten the captured per-batch losses into one per-token loss tensor per sequence.
retro_on_losses = [loss for losses in retro_on_loss_hook.losses for loss in losses]
retro_off_losses = [loss for losses in retro_off_loss_hook.losses for loss in losses]

# Mean per-token loss difference (RETRO[OFF] - RETRO[ON]) for each 64-token chunk of
# real tokens. The first chunk of each sequence is skipped, mirroring the overlap
# statistics, which start at the second chunk.
loss_per_chunk = []
for i in range(len(val_ds)):
    # Positions of non-pad input tokens. Compare against the configured pad id
    # rather than assuming it is 0.
    ids_ = (val_ds[i].input_ids != retro_on_config.pad_token_idx).nonzero().reshape(-1)
    retro_on_chunks = chunker(retro_on_losses[i][ids_].cpu())
    retro_off_chunks = chunker(retro_off_losses[i][ids_].cpu())
    for j, (on_chunk, off_chunk) in enumerate(zip(retro_on_chunks, retro_off_chunks)):
        if j == 0:
            continue
        # .item() stores a plain float, so np.exp/scipy below receive a numeric array.
        loss_per_chunk.append((off_chunk - on_chunk).mean().item())
print(len(loss_per_chunk))

In [None]:
def _mean_loss_per_sequence(loss_hook):
    """Average each captured loss row over its last dimension.

    Every position is included. Returns the concatenation over all batches as a
    1-D NumPy array.
    """
    return torch.cat([batch_loss.mean(-1) for batch_loss in loss_hook.losses]).detach().cpu().numpy()


# Note: this rebinds retro_on_losses / retro_off_losses (per-sequence lists in the
# previous cell) to per-sequence mean losses.
retro_on_losses = _mean_loss_per_sequence(retro_on_loss_hook)
retro_off_losses = _mean_loss_per_sequence(retro_off_loss_hook)
avg_loss_diff = retro_off_losses - retro_on_losses

### The correlations

In [None]:
# exp(mean(off - on)) per chunk, i.e. the RETRO[OFF]/RETRO[ON] perplexity ratio if the
# captured losses are token NLLs.
perplexity_ratio = np.exp(loss_per_chunk)
for correlation in (spearmanr, pearsonr):
    print(correlation(u_overlaps, perplexity_ratio))

In [None]:
# Convert to an array (so it can be negated), mapping NaN/-inf to 0. The negated
# distance serves as a similarity score.
l2_distances = np.nan_to_num(l2_distances, neginf=0)
perplexity_ratio = np.exp(loss_per_chunk)
for correlation in (spearmanr, pearsonr):
    print(correlation(-l2_distances, perplexity_ratio))

In [None]:
# How well embedding similarity (negated distance) tracks unigram overlap.
for correlation in (spearmanr, pearsonr):
    print(correlation(-l2_distances, u_overlaps))