## Data prep for retrieving beliefs for dialogs

**Goal:** Train an embedding model to match dialogs with beliefs and facts
**Method:**
- [x] Use stacked_samsum + dialogsum as datasets
- [x] Prepare datasets
    - [x] remove unnecessary columns
    - [x] remove '#' from dialogsum
    - [x] expand the stacked dataset
    - [x] truncate on the left to 512 tokens
    - [x] combine and train_test_split

### Imports and utils

In [10]:
model_name = "BAAI/bge-small-en-v1.5"
query_prefix = "Represent this sentence for searching relevant passages:"
max_len = 512
next_concept_sep = "\n[NEXT_CONCEPT]\n"
training_input_data = "./data/output.jsonl"
eval_input_data = "./data/eval-output.jsonl"
training_hn_data = "./data/hn-output.jsonl"
eval_size = 12_500
combined_data_path = "./data/combined"

### Constants

In [3]:
%matplotlib inline

from functools import partial
import os
import random

from datasets import load_dataset, concatenate_datasets, load_from_disk
from FlagEmbedding import FlagModel
from FlagEmbedding.baai_general_embedding.finetune.hn_mine import find_knn_neg
import jsonlines as jsonl
import matplotlib.pyplot as plt
import numpy as np
from numpy import dot
from numpy.linalg import norm
from tqdm.auto import tqdm
from transformers import AutoTokenizer

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

### Datasets

#### Initial run

- stacked_samsum: to be used as training
- dialogsum: to be used as testing

#### Final run

- combined: stacked_samsum + dialogsum

In [3]:
datasets = {
    "stacked_samsum": load_dataset(
        "stacked-summaries/stacked-samsum-1024", 
        split="train+validation+test",
    ).remove_columns(['chapter_length', 'summary_length', 'is_stacked',]).filter(
        lambda row: row["dialogue"]
    ).map(
        lambda row: dict(dialogue=row["dialogue"].replace("\r\n", '\n'))
    ),
    "dialogsum": load_dataset(
        "knkarthick/dialogsum", 
        split="train+validation+test",
    ).remove_columns(["id", "topic"]).map(
        lambda row: dict(dialogue=row["dialogue"].replace("#", ''))
    ),
}

In [4]:
def count_tokens(row):

    dialogue = row["dialogue"]
    tokens = tokenizer.encode(dialogue, add_special_tokens=False)

    return dict(token_count=len(tokens))

datasets["stacked_samsum"] = datasets["stacked_samsum"].map(count_tokens)
datasets["dialogsum"] = datasets["dialogsum"].map(count_tokens)

In [5]:
def truncate_dir(dialogue, left=False):
    lines = dialogue.split("\n")
    
    toks_by_line = [
        len(tokenizer.encode(line, add_special_tokens=False))
        for line in lines
    ]

    idx = 0 if left else -1
    
    while sum(toks_by_line) > max_len:
        toks_by_line.pop(idx)
        lines.pop(idx)

    return "\n".join(lines)

def expand_stacked(rows):
    dialogues = rows["dialogue"]
    summaries = rows["summary"]

    final_dialogues = []
    final_summaries = []
    
    for dialogue, summary in zip(dialogues, summaries):
        
        ss = summary.split(next_concept_sep)
        dd = [
            truncate_dir(dialogue, left=(i >= (len(ss) // 2)))
            for i in range(len(ss))
        ]

        final_dialogues += dd
        final_summaries += ss

    return dict(
        dialogue=final_dialogues,
        summary=final_summaries,
        token_count=[None]*len(final_summaries),
    )

datasets["stacked_samsum"] = datasets["stacked_samsum"].map(expand_stacked, batched=True).remove_columns(["token_count"])

In [6]:
combined = concatenate_datasets(list(datasets.values()))

In [7]:
combined = combined.train_test_split(test_size=0.1)

In [8]:
combined.save_to_disk(combined_data_path)

Saving the dataset (0/1 shards):   0%|          | 0/125705 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/13968 [00:00<?, ? examples/s]

In [5]:
# combined = load_from_disk(combined_data_path)

### Prepare dataset for finetuning
[Docs](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune)

Format:
```json
{"query": str, "pos": List[str], "neg":List[str]}
```

Keys:
- query: belief
- pos: list of matching conversations
- neg: list of random conversations from dataset

In [6]:
def pick_random(split="train", far_from=0):
    ds = combined[split]
    ds_len = len(ds)
    mid = ds_len // 2
    which_half = far_from // mid
    
    start = (1 - which_half) * mid
    end = ds_len - which_half * mid
    idx = random.randrange(start, end)
    
    return ds[idx]

In [14]:
with jsonl.open(training_input_data, mode='w') as writer:
    for i, row in enumerate(tqdm(combined["train"], total=len(combined["train"]))):
        query = row["summary"]
        pos = [row["dialogue"]]
    
        neg = [
            pick_random(split="train", far_from=i)["dialogue"]
            for _ in range(3)
        ]
        
        writer.write(dict(query=query, pos=pos, neg=neg))

  0%|          | 0/125705 [00:00<?, ?it/s]

In [11]:
with jsonl.open(eval_input_data, mode='w') as writer:
    for i, row in enumerate(tqdm(combined["test"], total=len(combined["train"]))):
        if i > eval_size:
            break

        query = row["summary"]
        pos = [row["dialogue"]]
    
        neg = [
            pick_random(split="test", far_from=i)["dialogue"]
            for _ in range(3)
        ]
        
        writer.write(dict(query=query, pos=pos, neg=neg))

  0%|          | 0/125705 [00:00<?, ?it/s]

### Mine hard negatives

In [15]:
model = FlagModel(
    model_name,
    query_instruction_for_retrieval=query_prefix,
)

----------using 3*GPUs----------


In [16]:
find_knn_neg(
    model,
    input_file=training_input_data,
    candidate_pool=None,
    output_file=training_hn_data,
    sample_range=list(range(2, 200)),
    negative_number=15,
    use_gpu=True,
)

inferencing embedding for corpus (number=51135)--------------


Inference Embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:45<00:00,  1.58s/it]


inferencing embedding for queries (number=125705)--------------


Inference Embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 164/164 [00:46<00:00,  3.54it/s]


create index and search------------------


Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1965/1965 [00:04<00:00, 428.59it/s]
