<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/data_analysis/ELI5_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [None]:
# Switch into the project folder. The path is relative, so this assumes the
# current working directory already contains the `drive/` folder — TODO confirm
# that Drive is mounted before this cell runs.
%cd drive/MyDrive/LLMs/ELI5_dataset

/content/drive/MyDrive/LLMs/ELI5_dataset


In [None]:
!pip install datasets --quiet
!pip install textstat --quiet
!pip install wandb --quiet
!pip install redditcleaner --quiet
!pip install huggingface_hub --quiet
!pip install -U sentence-transformers --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.9/518.9 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.1/105.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Import necessary libraries and modules
import wandb
import torch
import sys
import datasets
import os
import redditcleaner
import re
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from sentence_transformers import SentenceTransformer
from textstat import flesch_reading_ease as fre
from textstat import flesch_kincaid_grade as fkg
from datasets import (
    load_dataset,
    load,
    load_from_disk,
    Dataset,
    concatenate_datasets,
    DatasetDict
)
from itertools import compress
from tqdm import tqdm
from collections import defaultdict
from itertools import combinations
import random
from datetime import datetime

# Device string for torch: "cuda" when torch reports a usable CUDA device,
# otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Render matplotlib figures inline in the notebook output.
%matplotlib inline


# Definitions

In [None]:
def replace_url_i(text):
    """
    Strip URL placeholder tokens of the form "_url_<digits>_" from a string.

    Three spellings of the token are recognised: "_url_N_", "_Url_N_" and
    "_URL_N_", where N is one or more digits. Each spelling is removed in its
    own pass over the text, in that order. Surrounding whitespace is left
    untouched, so deleting a token can leave two adjacent spaces behind.

    Parameters:
        text (str): The text to scrub.

    Returns:
        str: A copy of `text` with every placeholder token deleted.

    Example:
        >>> replace_url_i("Check out my website: _url_123_ and _URL_456_")
        'Check out my website:  and '
    """
    # One pattern per accepted spelling; other capitalisations are not matched.
    url_token_patterns = (r"_url_\d+_", r"_Url_\d+_", r"_URL_\d+_")

    cleaned = text
    # The passes run one after another (not as a single alternation), so a
    # token exposed by an earlier deletion is still caught by a later pass.
    for token_re in map(re.compile, url_token_patterns):
        cleaned = token_re.sub("", cleaned)

    return cleaned

def preprocess_example(example):
    """
    Clean one post in place and drop answers that are too short.

    Steps applied:
    1. 'title', 'selftext' and every answer text are passed through
       redditcleaner.clean.
    2. In answers only, any span running from a '>' to the next newline is
       replaced by a space (intended to drop quoted text).
    3. Runs of whitespace are collapsed to single spaces.
    4. "_url_i_" placeholders are removed with replace_url_i. This happens
       after step 3, so a removed placeholder can leave a double space.
    5. Answers with fewer than 20 whitespace-separated words are discarded.

    Step 5 is applied to every list in example['answers'] that has one entry
    per answer (for instance 'score'), not only to 'text'. Filtering 'text'
    alone would leave the other lists longer than 'text', and the later
    score- and readability-based filters would then pair answers with the
    wrong metadata.

    No lowercasing is performed.

    Parameters:
        example (dict): A dictionary with 'title' (str), 'selftext' (str) and
            'answers' (dict holding a 'text' list plus optional parallel lists).

    Returns:
        dict: The same `example` object, modified in place.
    """
    # Preprocess 'answers'
    answer_fields = example['answers']
    n_answers = len(answer_fields['text'])

    cleaned_answers = []
    for answer in answer_fields['text']:
        answer = redditcleaner.clean(answer)
        answer = re.sub(r'>.*?\n', ' ', answer)
        answer = ' '.join(answer.split())
        cleaned_answers.append(replace_url_i(answer))

    # Boolean mask of answers that are long enough to keep.
    keep = [len(answer.split()) >= 20 for answer in cleaned_answers]

    # Apply the mask to every per-answer list so text and metadata stay aligned.
    # Values that are not lists of matching length are left untouched.
    for key, values in answer_fields.items():
        if key == 'text':
            continue
        if isinstance(values, (list, tuple)) and len(values) == n_answers:
            answer_fields[key] = list(compress(values, keep))
    answer_fields['text'] = list(compress(cleaned_answers, keep))

    # Preprocess 'title'
    title = example['title']
    title = redditcleaner.clean(title)
    title = ' '.join(title.split())
    title = replace_url_i(title)
    example['title'] = title

    # Preprocess 'selftext'
    selftext = example['selftext']
    selftext = redditcleaner.clean(selftext)
    selftext = ' '.join(selftext.split())
    selftext = replace_url_i(selftext)
    example['selftext'] = selftext

    return example

class score_cutoff_wrapper:
    """
    Holds a score threshold and prunes a post's answers against it.

    An instance stores `cutoff`; its `score_cutoff_ex` method keeps only the
    answers whose score is greater than or equal to that value.

    Parameters:
        cutoff (int or float): Minimum score an answer needs in order to be kept.
    """

    def __init__(self, cutoff):
        """
        Store the threshold used by `score_cutoff_ex`.

        Parameters:
            cutoff (int or float): Minimum score an answer needs in order to be kept.
        """
        self.cutoff = cutoff

    def score_cutoff_ex(self, example):
        """
        Drop, in place, every answer whose score is below `self.cutoff`.

        The keep/drop decision is computed from example['answers']['score'] and
        then applied position-by-position to every list stored under
        example['answers'].

        Parameters:
            example (dict): A dictionary whose 'answers' entry is a dict of
                parallel lists, one of which is 'score'.

        Returns:
            dict: The same `example` object with its answer lists pruned.

        Example:
            >>> example = {
                    'answers': {
                        'text': ['Yes', 'No', 'Maybe'],
                        'score': [10, 5, 8]
                    }
                }
            >>> score_cutoff_wrapper(cutoff=8).score_cutoff_ex(example)
            {
                'answers': {
                    'text': ['Yes', 'Maybe'],
                    'score': [10, 8]
                }
            }
        """
        answer_fields = example['answers']

        # Element-wise comparison; True marks an answer that survives.
        keep_mask = (np.array(answer_fields['score']) >= self.cutoff).tolist()

        # Snapshot the keys first, then overwrite each list with its kept items.
        for field in list(answer_fields):
            answer_fields[field] = [
                item for item, keep in zip(answer_fields[field], keep_mask) if keep
            ]

        return example


def score_cutoff(dataset,cutoff=4):
    """
    Keep only answers scoring at least `cutoff`, then drop posts left empty.

    Each post is passed through score_cutoff_wrapper.score_cutoff_ex via
    `dataset.map`; posts whose 'answers'['score'] list ends up empty are then
    removed with `dataset.filter`.

    Parameters:
        dataset (Dataset): The input Huggingface dataset to be filtered.
        cutoff (int or float, optional): Minimum answer score to keep. Default is 4.

    Returns:
        Dataset: The dataset with low-scoring answers and answerless posts removed.
    """
    answer_pruner = score_cutoff_wrapper(cutoff)

    def has_answers(post):
        # A post survives only if at least one answer met the cutoff.
        return len(post['answers']['score']) > 0

    pruned = dataset.map(answer_pruner.score_cutoff_ex)
    return pruned.filter(has_answers)


def flesch_scores(example):
    """
    Attach readability metrics to every answer of a post.

    For each string in example['answers']['text'] this computes textstat's
    flesch_reading_ease (imported as `fre`) and flesch_kincaid_grade
    (imported as `fkg`). The results are stored as two new lists,
    example['answers']['fre'] and example['answers']['fkg'], in the same
    order as the answer texts.

    Parameters:
        example (dict): A dictionary whose 'answers' entry holds a 'text' list.

    Returns:
        dict: The same `example` object with 'fre' and 'fkg' added under 'answers'.
    """
    answer_fields = example['answers']
    texts = answer_fields['text']

    # Compute both metric lists before storing either one.
    reading_ease = list(map(fre, texts))
    grade_level = list(map(fkg, texts))

    answer_fields['fre'] = reading_ease
    answer_fields['fkg'] = grade_level

    return example


class flesch_scores_filter_wrapper:
    """
    Holds two readability thresholds and prunes a post's answers against them.

    An answer is kept when its Flesch Readability score ('fre') is at least
    `fre_cutoff` AND its Flesch-Kincaid Grade ('fkg') is strictly below
    `fkg_cutoff`; every other answer is removed.
    """

    def __init__(self, fre_cutoff, fkg_cutoff):
        """
        Store the thresholds used by `flesch_scores_filter`.

        Parameters:
            fre_cutoff (float): Minimum (inclusive) Flesch Readability score.
            fkg_cutoff (float): Maximum (exclusive) Flesch-Kincaid Grade score.
        """
        self.fre_cutoff = fre_cutoff
        self.fkg_cutoff = fkg_cutoff

    def flesch_scores_filter(self, example):
        """
        Drop, in place, every answer that fails either readability threshold.

        The keep/drop decision is computed from example['answers']['fre'] and
        example['answers']['fkg'] and then applied position-by-position to
        every list stored under example['answers'].

        Parameters:
            example (dict): A dictionary whose 'answers' entry is a dict of
                parallel lists that includes 'fre' and 'fkg'.

        Returns:
            dict: The same `example` object with its answer lists pruned.

        Example:
            >>> example = {
                    'answers': {
                        'text': ['This is a sample answer.', 'Another answer with more words.'],
                        'fre': [89.1, 79.2],
                        'fkg': [2.6, 5.5]
                    }
                }
            >>> wrapper = flesch_scores_filter_wrapper(fre_cutoff=80, fkg_cutoff=5)
            >>> wrapper.flesch_scores_filter(example)
            {
                'answers': {
                    'text': ['This is a sample answer.'],
                    'fre': [89.1],
                    'fkg': [2.6]
                }
            }
        """
        answer_fields = example['answers']
        fre_scores = answer_fields['fre']
        fkg_scores = answer_fields['fkg']

        # Build the mask by walking 'fre' and looking up the matching 'fkg' entry.
        keep_mask = []
        for pos, fre_value in enumerate(fre_scores):
            easy_enough = fre_value >= self.fre_cutoff
            low_grade = fkg_scores[pos] < self.fkg_cutoff
            keep_mask.append(bool(easy_enough and low_grade))

        # Snapshot the keys first, then overwrite each list with its kept items.
        for field in list(answer_fields):
            answer_fields[field] = [
                item for item, keep in zip(answer_fields[field], keep_mask) if keep
            ]

        return example


def flesch_scores_cutoff(dataset, fre_cutoff=60, fkg_cutoff=9):
    """
    Keep only sufficiently readable answers, then drop posts left empty.

    Each post is passed through flesch_scores_filter_wrapper.flesch_scores_filter
    via `dataset.map`, which keeps answers with fre >= fre_cutoff and
    fkg < fkg_cutoff. Posts whose 'answers'['fre'] list ends up empty are then
    removed with `dataset.filter`.

    Parameters:
        dataset (Dataset): Huggingface dataset to be filtered.
        fre_cutoff (float, optional): Minimum (inclusive) Flesch Readability score. Default is 60.
        fkg_cutoff (float, optional): Maximum (exclusive) Flesch-Kincaid Grade score. Default is 9.

    Returns:
        Dataset: The dataset with hard-to-read answers and answerless posts removed.
    """
    readability_pruner = flesch_scores_filter_wrapper(fre_cutoff, fkg_cutoff)

    def has_answers(post):
        # A post survives only if at least one answer passed both thresholds.
        return len(post['answers']['fre']) > 0

    pruned = dataset.map(readability_pruner.flesch_scores_filter)
    return pruned.filter(has_answers)

def preprocess_data(dataset,
                    output_file='./data/filtered',
                    save_file=True,
                    log_to_wandb=True,
                    overwrite=False):
    """
    Run the full cleaning/filtering pipeline over a dataset, with on-disk caching.

    If `output_file` already exists and `overwrite` is False, the dataset is
    loaded from that path and returned without any processing. Otherwise the
    pipeline is, in order:
        1. preprocess_example on every post (text cleaning, short-answer removal).
        2. Drop posts whose title contains 'nsfw' (case-insensitive).
        3. Drop posts where neither the title nor the selftext contains any of
           the substrings in `qu_reqs` (case-insensitive).
        4. Drop posts whose title contains any substring in `not_qus`
           (case-sensitive).
        5. flesch_scores on every post (adds 'fre' and 'fkg' per answer).
        6. score_cutoff with its default threshold (answer 'score' filter).
        7. flesch_scores_cutoff with its default thresholds (readability filter).

    Parameters:
        dataset (Dataset): The input Huggingface dataset to be processed.
        output_file (str, optional): Path used both as the cache location to load
            from and as the destination to save to. Default is './data/filtered'.
        save_file (bool, optional): If True, saves the processed dataset to the output_file.
            Default is True.
        log_to_wandb (bool, optional): If True, logs the saved directory as a WandB
            artifact. Only takes effect when save_file is also True. Default is True.
        overwrite (bool, optional): If True, reprocesses even when output_file already
            exists. Default is False.

    Returns:
        Dataset: The preprocessed dataset (freshly computed or loaded from disk).

    """

    if os.path.exists(output_file) and not overwrite:
        # Cached result available and no overwrite requested: load it and skip all processing.
        return load_from_disk(output_file)

    # Title substrings marking recurring or meta threads rather than questions.
    # Matched case-sensitively against the raw title further down; a few entries
    # appear twice in the list, which has no effect on the result.
    not_qus = ['IAMA', 'AMA', 'ama:', 'megathread', 'Megathread',
               'Discussion Thread', 'Discussion thread',
               'discussion Thread', 'discussion thread',
               'Ask Anything Wednesday', 'Free-for-All',
               'Free-For-All', '[META]', 'Monday Methods',
               'Tuesday Trivia', 'Monday Mysteries',
               'Theory Thursday', 'Monday Mish-Mash',
               'Media Mondays', '[META]', 'Wednesday Week in History',
               'Saturday Popular Questions', 'Ask Anything Wednesday',
               'Thursday Focus Historical Fiction']

    # Substrings taken as evidence of a question. These are plain substring
    # checks on lowercased text, so e.g. 'how' also matches inside 'show'.
    qu_reqs = ['who', 'what', 'where', 'why', 'when', 'how', '?']

    # Clean title/selftext/answers of every post and drop answers under 20 words.
    dataset = dataset.map(preprocess_example)

    # Drop posts with 'nsfw' anywhere in the (lowercased) title.
    dataset = dataset.filter(lambda post: 'nsfw' not in post['title'].lower())

    # Keep a post only if its title or its selftext contains at least one question marker.
    dataset = dataset.filter(lambda post:
                             not (all(qu_req not in post['title'].lower() for qu_req in qu_reqs)
                                  and all(qu_req not in post['selftext'].lower() for qu_req in qu_reqs)))

    # Drop posts whose title contains any of the non-question thread markers.
    dataset = dataset.filter(lambda post: not (any(nq in post['title'] for nq in not_qus)))

    # Add per-answer Flesch Readability ('fre') and Flesch-Kincaid Grade ('fkg') lists.
    dataset = dataset.map(flesch_scores)

    # Drop answers whose 'score' is below score_cutoff's default cutoff,
    # then drop posts left with no answers. (This step is not Flesch-related.)
    dataset = score_cutoff(dataset)

    # Drop answers failing the default readability thresholds of flesch_scores_cutoff,
    # then drop posts left with no answers.
    dataset = flesch_scores_cutoff(dataset)

    if save_file:
        # Save the processed dataset to the output_file
        dataset.save_to_disk(output_file)

        if log_to_wandb:
            # Upload the directory just written as a WandB dataset artifact.
            # The run name carries a timestamp of the form month.day.year-hour.minute.second.
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='preprocess_data',
                            name=f'preprocess_data_{time_stamp}') as run:
                # The run is open for the duration of this block.
                processed_data_art = wandb.Artifact('ELI5_processed', 'dataset')
                processed_data_art.add_dir(output_file)
                run.log_artifact(processed_data_art)

    return dataset

def split_idxs(example):
    """
    Partition a post's answer positions into "first of its score" and "repeat".

    Two lists of indices into example['answers']['score'] are produced:

    pref_idxs        = for each distinct score, the position of its first
                       occurrence, ordered from the highest score to the lowest.
                       No two of these answers share a score, so they can be
                       ranked against each other without ties (preference
                       modeling).

    dupl_scores_idxs = every remaining position, in ascending order, i.e.
                       answers whose score was already seen earlier in the
                       list (used for supervised fine-tuning).

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'score' as a list.

    Returns:
        dict: The same `example` object with 'pref_idxs' and 'dupl_scores_idxs' added.

    """
    scores = example['answers']['score']

    # Remember where each distinct score shows up for the first time.
    first_position = {}
    for pos, score in enumerate(scores):
        first_position.setdefault(score, pos)

    # Highest score first.
    pref_scores_idxs = [first_position[score]
                        for score in sorted(first_position, reverse=True)]

    # Whatever was not picked above is a repeat of an earlier score.
    chosen = set(pref_scores_idxs)
    dupl_scores_idxs = [pos for pos in range(len(scores)) if pos not in chosen]

    example['pref_idxs'] = pref_scores_idxs
    example['dupl_scores_idxs'] = dupl_scores_idxs

    return example

def mult_ans_RM_proc(example):
    """
    Reduce a multi-answer post to the answers chosen for preference modelling.

    Every list stored under example['answers'] is rebuilt, in place, from the
    positions listed in example['pref_idxs'], in the order those positions
    appear there.

    Parameters:
        example (dict): The input example containing 'pref_idxs' (list of int
            positions) and 'answers' (dict of parallel lists holding the answer
            texts and their metadata).

    Returns:
        dict: The same `example` object with each 'answers' list restricted to
              the 'pref_idxs' positions.

    """
    selected = example['pref_idxs']
    answer_fields = example['answers']

    # Snapshot the keys, then replace each column by its selected entries.
    for field in list(answer_fields):
        column = answer_fields[field]
        answer_fields[field] = [column[pos] for pos in selected]

    return example


def mult_ans_SFT_proc(example):
    """
    Reduce a multi-answer post to the answers set aside for supervised fine-tuning.

    Every list stored under example['answers'] is rebuilt, in place, from the
    positions listed in example['dupl_scores_idxs'], in the order those
    positions appear there.

    Parameters:
        example (dict): The input example containing 'dupl_scores_idxs' (list of
            int positions of answers with a repeated score) and 'answers' (dict
            of parallel lists holding the answer texts and their metadata).

    Returns:
        dict: The same `example` object with each 'answers' list restricted to
              the 'dupl_scores_idxs' positions.

    """
    repeat_positions = example['dupl_scores_idxs']
    answer_fields = example['answers']

    # Build all restricted columns first, then write them back onto the same dict.
    restricted = {
        field: [column[pos] for pos in repeat_positions]
        for field, column in answer_fields.items()
    }
    answer_fields.update(restricted)

    return example

def split_ds(ds_original,
             ds_filtered,
             output_dir='ds_split',
             save_file=True,
             log_to_wandb=True,
             overwrite=False):
    """
    Splits the datasets into supervised fine-tuning (SFT), reward modeling (RM), and reinforcement learning (RL) subsets.

    How the subsets are built:
        - RM:  posts of ds_filtered with two or more answers, reduced to the answers
               selected by split_idxs as 'pref_idxs' (one answer per distinct score).
        - SFT: per split ('train', 'validation', 'test'), the remaining answers of those
               multi-answer posts ('dupl_scores_idxs') concatenated with the posts of
               ds_filtered that have exactly one answer.
        - RL:  posts of ds_original whose 'q_id' appears in neither SFT nor RM, with
               all splits concatenated into a single dataset.

    If all three subsets already exist under './data/{output_dir}/' and overwrite is
    False, they are loaded from disk and returned without recomputation.

    Parameters:
        ds_original (DatasetDict): The original dataset containing all examples.
            Presumably has the same splits as ds_filtered — TODO confirm.
        ds_filtered (DatasetDict): The filtered dataset containing relevant examples for SFT and RM
            datasets. Must provide 'train', 'validation' and 'test' splits.
        output_dir (str, optional): Sub-directory of './data/' where the split datasets are saved.
            Default is 'ds_split'.
        save_file (bool, optional): If True, saves the split datasets to disk. Default is True.
        log_to_wandb (bool, optional): If True, logs the saved directory as a WandB artifact.
            Only takes effect when save_file is also True. Default is True.
        overwrite (bool, optional): If True, recomputes even when the split datasets already exist.
            Default is False.

    Returns:
        dict: A dictionary with keys 'SFT', 'RM' and 'RL' holding the split datasets.
    """

    # Cache hit: all three subset directories exist and no overwrite was requested.
    if (all(os.path.exists(f'./data/{output_dir}/{split}') for split in ['ds_SFT', 'ds_RM', 'ds_RL'])
        and not overwrite):

        ds_split = {}

        # Load the split datasets from disk and return them.
        ds_split['SFT'] = load_from_disk(f'./data/{output_dir}/ds_SFT')
        ds_split['RM'] = load_from_disk(f'./data/{output_dir}/ds_RM')
        ds_split['RL'] = load_from_disk(f'./data/{output_dir}/ds_RL')

        return ds_split

    ds_split = {}

    # Separate posts with two or more answers from posts with exactly one answer.
    ds_mult = ds_filtered.filter(lambda post: len(post['answers']['score']) >= 2)
    ds_sing = ds_filtered.filter(lambda post: len(post['answers']['score']) == 1)

    # Annotate multi-answer posts with 'pref_idxs' / 'dupl_scores_idxs', then keep only the
    # 'pref_idxs' answers for RM. Those answers have pairwise distinct scores, which avoids
    # ties during preference modeling.
    # NOTE(review): a post whose answers all share one score keeps a single answer here and
    # still passes the "> 0" filter — confirm downstream pairing tolerates such posts.
    ds_mult_indexed = ds_mult.map(split_idxs)
    ds_split['RM'] = ds_mult_indexed.map(mult_ans_RM_proc)
    ds_split['RM'] = ds_split['RM'].filter(lambda x: len(x['answers']['score']) > 0)

    # From the same multi-answer posts keep only the repeated-score answers; posts with no
    # such answers are dropped. These feed the SFT dataset.
    ds_SFT_mult = ds_mult_indexed.map(mult_ans_SFT_proc)
    ds_SFT_mult = ds_SFT_mult.filter(lambda x: len(x['answers']['score']) > 0)

    # Form the SFT dataset split by split: repeated-score answers of multi-answer posts
    # plus all single-answer posts.
    ds_split['SFT'] = datasets.DatasetDict()

    for key in ['train', 'validation', 'test']:
        ds_split['SFT'][key] = datasets.concatenate_datasets([ds_SFT_mult[key], ds_sing[key]])

    # Collect the question IDs of examples used in SFT and RM to exclude them from RL.
    q_ids_taken = []

    for ds_ in (ds_split['SFT'], ds_split['RM']):
        for split in ds_:
            q_ids_taken.extend(ds_[split]['q_id'])

    # Set membership keeps the per-post lookup in the filter below cheap.
    q_ids_taken = set(q_ids_taken)

    # Create the RL subset by excluding examples used in SFT and RM, then merge
    # every split of the result into one dataset.
    ds_split['RL'] = ds_original.filter(lambda post: post['q_id'] not in q_ids_taken)
    ds_split['RL'] = datasets.concatenate_datasets([ds for ds in ds_split['RL'].values()])

    # Save the split datasets to disk, one directory per subset (ds_SFT, ds_RM, ds_RL).
    if save_file:

        for key, value in ds_split.items():
            value.save_to_disk(f'./data/{output_dir}/ds_{key}')

        # Upload the directory just written as a WandB dataset artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='split_data',
                            name=f'split_data_{time_stamp}') as run:

                split_data_art = wandb.Artifact('ELI5_split', 'dataset')
                split_data_art.add_dir(f'./data/{output_dir}')
                run.log_artifact(split_data_art)

    # Return the dictionary containing the split datasets for SFT, RM, and RL.
    return ds_split



def combine_title_body(example):
    """
    Build a single question string from a post's title and body.

    Both 'title' and 'selftext' have their whitespace runs collapsed to single
    spaces (leading/trailing whitespace removed); the two results are then
    joined with one newline and stored under 'title_body'. The newline is
    present even when the body is empty.

    Parameters:
        example (dict): The input example containing 'title' and 'selftext' as keys.

    Returns:
        dict: The same `example` object with the combined string added under
              the key 'title_body'.

    """
    # Normalise whitespace in each part, title first, then body.
    normalised_parts = [' '.join(example[field].split())
                        for field in ('title', 'selftext')]

    example['title_body'] = '\n'.join(normalised_parts)

    return example

def embed_datasets(dataset_split,
                   checkpoint='all-mpnet-base-v2',
                   output_dir='embedded',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Embeds the question text of each dataset with a SentenceTransformer model, with on-disk caching.

    For every entry of `dataset_split`, a 'title_body' column is built with
    combine_title_body and then encoded into a 'qu_emb' column. Encoding is done
    through a batched `map`, so the model receives lists of strings rather than
    one string per call.

    If the three directories './data/{output_dir}/ds_SFT', '.../ds_RM' and
    '.../ds_RL' already exist and overwrite is False, the datasets are loaded
    from disk and returned without loading the model.

    Parameters:
        dataset_split (dict): A dictionary whose values are the datasets to embed
            (the cache check expects the keys 'SFT', 'RM' and 'RL').
        checkpoint (str, optional): The name of the SentenceTransformer model checkpoint to use.
            Default is 'all-mpnet-base-v2'.
        output_dir (str, optional): Sub-directory of './data/' where the embedded datasets are saved.
            Default is 'embedded'.
        save_file (bool, optional): If True, saves the embedded datasets to disk. Default is True.
        overwrite (bool, optional): If True, recomputes even when embedded datasets already exist.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the saved directory as a WandB artifact.
            Only takes effect when save_file is also True. Default is True.

    Returns:
        dict: A dictionary containing the embedded datasets, keyed like `dataset_split`.

    """

    # Check if the embedded datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'./data/{output_dir}/ds_{subset}') for subset in ['SFT', 'RM', 'RL'])
        and not overwrite):

        ds_embedded = {}

        # Load the embedded datasets from disk and return them.
        for subset in ['SFT', 'RM', 'RL']:
            ds_embedded[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
        return ds_embedded

    # Initialize a dictionary to store the embedded datasets.
    ds_embedded = {}

    # Initialize the SentenceTransformer model.
    model = SentenceTransformer(checkpoint)

    # Number of questions handed to the model per call.
    encode_batch_size = 64

    def embed_batch(batch):
        # With batched=True, batch['title_body'] is a list of strings, so the
        # encoder's batch_size actually takes effect (it is a no-op when the
        # model is fed a single string per call).
        return {'qu_emb': model.encode(batch['title_body'],
                                       batch_size=encode_batch_size)}

    # Loop through each dataset split and embed the examples.
    for key in dataset_split:
        ds_embedded[key] = dataset_split[key].map(combine_title_body)
        ds_embedded[key] = ds_embedded[key].map(embed_batch,
                                                batched=True,
                                                batch_size=encode_batch_size)

    # Save the embedded datasets to disk.
    if save_file:
        for key, value in ds_embedded.items():
            value.save_to_disk(f'./data/{output_dir}/ds_{key}')

        # Log the embedded datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='embed_data',
                            name=f'embed_data_{time_stamp}') as run:

                embed_data_art = wandb.Artifact('ELI5_embedded', 'dataset')
                embed_data_art.add_dir(f'./data/{output_dir}')
                run.log_artifact(embed_data_art)

    # Return the dictionary containing the embedded datasets.
    return ds_embedded


def make_pairs(example):
    """
    Build (better answer, worse answer) text pairs for one post.

    All two-answer combinations are formed from the post's answers. When there
    are more than 10 combinations, 10 of them are drawn at random with
    random.sample (the result therefore depends on the state of the `random`
    module). Within each pair the answer with the higher score comes first;
    if the two scores are equal the original answer order is kept.

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'text' and 'score' lists.

    Returns:
        dict: The same `example` object with a 'pairs' key holding a list of
              (higher-scored text, lower-scored text) tuples.

    """
    answer_fields = example['answers']

    # Each candidate is (score, text); combinations() yields every unordered pair.
    scored_answers = zip(answer_fields['score'], answer_fields['text'])
    candidate_pairs = list(combinations(scored_answers, 2))

    # Cap the number of pairs per post at 10.
    if len(candidate_pairs) > 10:
        candidate_pairs = random.sample(candidate_pairs, 10)

    pairs_text = []
    for first, second in candidate_pairs:
        # Swap only when the first answer scores strictly lower than the second.
        if first[0] < second[0]:
            first, second = second, first
        pairs_text.append((first[1], second[1]))

    example['pairs'] = pairs_text

    return example

def _normalize_rows(vecs):
    """Return a new tensor in which every row of `vecs` is divided by its L2 norm."""
    return vecs / torch.sqrt(torch.sum(vecs ** 2, dim=1, keepdim=True))


def _similar_index_pairs(left, right, cutoff):
    """Return the (left_rows, right_rows) index tensors where left @ right.T >= cutoff."""
    return torch.where(torch.matmul(left, right.T) >= cutoff)


def _batch_matches(reference, batch, cutoff):
    """Return the batch-local row indices of `batch` that reach `cutoff` against any row of `reference`."""
    hits = (torch.matmul(reference, batch.T) >= cutoff).any(dim=0)
    return torch.nonzero(hits).flatten()


def clean_datasets(ds_embedded,
                   cutoff=0.6,
                   batch_size=5000,
                   output_dir='cleaned',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Cleans the datasets by removing redundant examples based on the similarity of embedded vectors.

    For SFT and RM, training examples whose question embedding is similar to a validation or
    test example are dropped from 'train', and test examples similar to a validation example
    are dropped from 'test'; 'validation' is kept as is. For RL, examples similar to any SFT/RM
    'train' or 'validation' example are dropped. Finally, answer pairs are added to the RM
    dataset with `make_pairs`.

    Note: `set_format('torch')` is called on the datasets in `ds_embedded`, so the input is
    modified in place.

    Parameters:
        ds_embedded (dict): A dictionary containing the embedded datasets for supervised fine-tuning (SFT),
                            reward modeling (RM), and reinforcement learning (RL). The embeddings are
                            read from the 'qu_emb' column.
        cutoff (float, optional): The similarity threshold (dot product of L2-normalized embeddings)
            at or above which two examples are considered redundant. Default is 0.6.
        batch_size (int, optional): The batch size used for processing RL dataset. Must be positive.
            Default is 5000.
        output_dir (str, optional): The directory (under ./data) where the cleaned datasets will be saved.
            Default is 'cleaned'.
        save_file (bool, optional): If True, saves the cleaned datasets to disk. Default is True.
        overwrite (bool, optional): If False and all three cleaned datasets already exist in the
            output directory, they are loaded from disk and returned without recomputation.
            Default is False.
        log_to_wandb (bool, optional): If True (and save_file is True), logs the output directory
            as a WandB artifact. Default is True.

    Returns:
        dict: A dictionary containing the cleaned datasets for SFT, RM, and RL.

    Raises:
        ValueError: If the datasets have to be recomputed and batch_size is not positive.

    """

    all_subsets = ['SFT', 'RM', 'RL']

    # ds_clean is a dictionary which contains DatasetDicts as values.
    ds_clean = {}

    # Reuse previously cleaned datasets when all of them are on disk and overwrite is False.
    already_cleaned = all(os.path.exists(f'./data/{output_dir}/ds_{subset}')
                          for subset in all_subsets)
    if already_cleaned and not overwrite:
        for subset in all_subsets:
            ds_clean[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
        return ds_clean

    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    # Normalized embedding vectors, keyed by subset (and by split for SFT and RM).
    embed_vecs = {}

    # Standard splitting of data in supervised learning.
    splits = ['train', 'validation', 'test']

    # Cleaning SFT and RM datasets.
    for subset in ['SFT', 'RM']:
        print(f'Cleaning {subset} dataset')

        # Set the format of dataset to 'torch' to enable torch operations on the embedded vectors.
        ds_embedded[subset].set_format('torch')

        vecs = {split: _normalize_rows(ds_embedded[subset][split]['qu_emb'])
                for split in splits}
        embed_vecs[subset] = vecs

        # Only the index pairs are kept; the full similarity matrices are discarded right away.
        train_val = _similar_index_pairs(vecs['train'], vecs['validation'], cutoff)
        train_test = _similar_index_pairs(vecs['train'], vecs['test'], cutoff)
        val_test = _similar_index_pairs(vecs['validation'], vecs['test'], cutoff)

        # Training rows that overlap with the validation or the test split.
        rm_tr_idxs = set(train_val[0].tolist()) | set(train_test[0].tolist())

        # Test rows that overlap with the validation split.
        # Remove examples from test set because it is larger than the validation set.
        rm_test_idxs = set(val_test[1].tolist())

        # Indices to keep, in ascending order so the original row order is preserved.
        keep_train = [i for i in range(len(ds_embedded[subset]['train']))
                      if i not in rm_tr_idxs]
        keep_test = [i for i in range(len(ds_embedded[subset]['test']))
                     if i not in rm_test_idxs]

        # Create a new DatasetDict containing the cleaned splits of this subset.
        ds_clean[subset] = datasets.DatasetDict()

        ds_clean[subset]['train'] = ds_embedded[subset]['train'].select(keep_train)
        ds_clean[subset]['validation'] = ds_embedded[subset]['validation']
        ds_clean[subset]['test'] = ds_embedded[subset]['test'].select(keep_test)

    # Cleaning RL dataset.
    print('Cleaning RL dataset')

    # Set the format of RL dataset to 'torch' to enable torch operations on the embedded vectors.
    ds_embedded['RL'].set_format('torch')

    # Normalized embedded vectors for the RL dataset.
    embed_vecs['RL'] = _normalize_rows(ds_embedded['RL']['qu_emb'])

    # Get the size of the RL dataset (number of examples).
    RL_size = len(ds_embedded['RL'])

    # Dataset-level indices of redundant examples in the RL dataset.
    rem_RL = set()

    # Process the RL embeddings batch by batch to bound the size of the similarity matrices.
    for start in tqdm(range(0, RL_size, batch_size)):

        batch = embed_vecs['RL'][start:start + batch_size, :]

        # Compare against the un-cleaned SFT and RM embeddings.
        # NOTE(review): the 'test' splits are not compared here - confirm this is intentional.
        for subset in ['SFT', 'RM']:
            for split in ['train', 'validation']:
                local_idxs = _batch_matches(embed_vecs[subset][split], batch, cutoff)

                # The matches are positions inside the current batch; shift them by `start`
                # to obtain positions in the full RL dataset before recording them.
                rem_RL.update((local_idxs + start).tolist())

    # Indices of the non-redundant RL examples, in ascending order.
    keep_RL = [i for i in range(RL_size) if i not in rem_RL]

    # Select non-redundant examples for RL dataset.
    ds_clean['RL'] = ds_embedded['RL'].select(keep_RL)

    # Apply 'make_pairs' function to the RM dataset to create pairs of answers.
    ds_clean['RM'] = ds_clean['RM'].map(lambda x: make_pairs(x))

    # Save the cleaned datasets to disk.
    if save_file:
        for subset in all_subsets:
            ds_clean[subset].save_to_disk(f'./data/{output_dir}/ds_{subset}')

        # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='clean_data',
                            name=f'clean_data_{time_stamp}') as run:

                clean_data_art = wandb.Artifact('ELI5_cleaned', 'dataset')
                clean_data_art.add_dir(f'./data/{output_dir}')
                run.log_artifact(clean_data_art)

    # Return the dictionary containing the cleaned datasets for SFT, RM, and RL.
    return ds_clean


# Code

In [None]:
# Load the "vblagoje/lfqa" dataset; the recorded output shows train, validation and test splits being generated.
ds_original = load_dataset("vblagoje/lfqa")

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/687M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
# Run the preprocessing step (preprocess_data, defined or imported elsewhere in this notebook) on the raw dataset.
ds_filtered = preprocess_data(ds_original)

In [None]:
# Authenticate with Weights & Biases; clean_datasets logs a W&B artifact by default (log_to_wandb=True).
wandb.login()

In [None]:
# Build the split datasets from the raw and the preprocessed data; the result is passed to embed_datasets below.
# NOTE(review): presumably this yields the SFT / RM / RL subsets that clean_datasets expects - confirm in split_ds.
ds_split = split_ds(ds_original,
                    ds_filtered)

In [None]:
# Embed the split datasets; the recorded output shows them being saved and ./data/embedded being added to a W&B artifact.
ds_embedded = embed_datasets(ds_split)

Downloading (…)a8e1d/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)0bca8e1d/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)e1d/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)a8e1d/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading (…)8e1d/train_script.py:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)bca8e1d/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

Map:   0%|          | 0/63390 [00:00<?, ? examples/s]

Map:   0%|          | 0/939 [00:00<?, ? examples/s]

Map:   0%|          | 0/3153 [00:00<?, ? examples/s]

Map:   0%|          | 0/63390 [00:00<?, ? examples/s]

Map:   0%|          | 0/939 [00:00<?, ? examples/s]

Map:   0%|          | 0/3153 [00:00<?, ? examples/s]

Map:   0%|          | 0/29956 [00:00<?, ? examples/s]

Map:   0%|          | 0/1347 [00:00<?, ? examples/s]

Map:   0%|          | 0/2379 [00:00<?, ? examples/s]

Map:   0%|          | 0/29956 [00:00<?, ? examples/s]

Map:   0%|          | 0/1347 [00:00<?, ? examples/s]

Map:   0%|          | 0/2379 [00:00<?, ? examples/s]

Map:   0%|          | 0/145507 [00:00<?, ? examples/s]

Map:   0%|          | 0/145507 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/63390 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/939 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3153 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/29956 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1347 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2379 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/145507 [00:00<?, ? examples/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668546633324392, max=1.0…

[34m[1mwandb[0m: Adding directory to artifact (./data/embedded)... Done. 8.5s


In [None]:
# Remove near-duplicate questions across splits/subsets using the default settings (cutoff=0.6, batch_size=5000).
# NOTE(review): the recorded output shows exactly 5000 RL rows dropped (145507 -> 140507), which equals the
# default batch_size; re-run and check the RL row count whenever the RL de-duplication logic changes.
ds_clean = clean_datasets(ds_embedded)

Cleaning SFT dataset
Cleaning RM dataset
Cleaning RL dataset


100%|██████████| 30/30 [04:53<00:00,  9.78s/it]


Map:   0%|          | 0/25799 [00:00<?, ? examples/s]

Map:   0%|          | 0/1347 [00:00<?, ? examples/s]

Map:   0%|          | 0/2275 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/57802 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/939 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3067 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/25799 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1347 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2275 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/140507 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (./data/cleaned)... Done. 11.4s
