# Dependencies

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd drive/MyDrive/LLMs/ELI5_dataset

/content/drive/MyDrive/LLMs/ELI5_dataset


In [3]:
!pip install datasets --quiet
!pip install textstat --quiet
!pip install wandb --quiet
!pip install redditcleaner --quiet
!pip install huggingface_hub --quiet
!pip install -U sentence-transformers --quiet
!pip install seaborn --quiet
!pip install -U transformers --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.1/105.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━

In [4]:
# Import necessary libraries and modules
import wandb
import torch
import sys
import datasets
import os
import redditcleaner
import re
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from sentence_transformers import SentenceTransformer
from textstat import flesch_reading_ease as fre
from textstat import flesch_kincaid_grade as fkg
from datasets import (
    load_dataset,
    load,
    load_from_disk,
    Dataset,
    concatenate_datasets,
    DatasetDict
)
from itertools import compress
from tqdm import tqdm
from collections import defaultdict
from itertools import combinations
import random
from datetime import datetime
import shutil

# Check for GPU availability and set the device accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"

# Enable inline plotting for Jupyter Notebooks
%matplotlib inline


#Definitions

### Dataset loading helper functions

In [20]:
def load_ds(base_dir):
    """
    Load ds from disk.

    Dynamically detects the number of arrow files in the train, test, and validation directories.

    Parameters:
        base_dir (str): The base directory of the dataset.

    Returns:
        ds (DatasetDict): A DatasetDict containing the dataset splits.

    """

    train_files = [os.path.join(base_dir, "train", file) for file in get_filenames_for_dir(os.path.join(base_dir, "train"))]
    test_files = [os.path.join(base_dir, "test", file) for file in get_filenames_for_dir(os.path.join(base_dir, "test"))]
    validation_files = [os.path.join(base_dir, "validation", file) for file in get_filenames_for_dir(os.path.join(base_dir, "validation"))]

    ds = load_dataset("arrow", data_files={"train": train_files, "test": test_files, "validation": validation_files})

    return ds

def load_split_ds(base_dir):
    """
    Load SFT/RM/RL splits from disk.

    Dynamically detects the number of arrow files in each split and its subsplits.

    Parameters:
        base_dir (str): The base directory of the dataset splits.

    Returns:
        ds (dict[DatasetDict]): A dictionary containing the SFT/RM/RL splits of the dataset.

    """

    splits = ["SFT", "RM", "RL"]
    ds = {}

    for split in splits:
        split_dir = os.path.join(base_dir, f"ds_{split}")

        # Check if train/test/validation subdirectories exist
        if all(os.path.isdir(os.path.join(split_dir, subsplit)) for subsplit in ["train", "test", "validation"]):
            data_files = {
                subsplit: [os.path.join(split_dir, subsplit, file) for file in get_filenames_for_dir(os.path.join(split_dir, subsplit))]
                for subsplit in ["train", "test", "validation"]
            }
        else:  # No subdirectories, just load the files directly
            data_files = [os.path.join(split_dir, 'train', file) for file in get_filenames_for_dir(os.path.join(split_dir, 'train'))]

        ds_split = load_dataset("arrow", data_files=data_files)

        # Store in the master dictionary
        ds[split] = ds_split

    return ds

def get_filenames_for_dir(directory):
    """Helper function to get sorted arrow filenames for a given directory."""
    filenames = sorted([f for f in os.listdir(directory) if f.endswith(".arrow")])
    total_files = len(filenames)
    return [f"data-0000{i}-of-0000{total_files}.arrow" for i in range(total_files)]


### Preprocessing

In [14]:
def preprocess_data(dataset,
                    output_dir='./data/ELI5/ds_preprocessed',
                    save_file=True,
                    log_to_wandb=True,
                    overwrite=False):
    """
    Preprocesses the input dataset by applying various filters,
    then combining the title and body of each post.

    Parameters:
        dataset (Dataset): The input Huggingface dataset to be processed.
        output_dir (str, optional): The path to the directory where the processed dataset will be saved.
            Default is './data/filtered'.
        save_file (bool, optional): If True, saves the processed dataset to the output_file.
            Default is True.
        log_to_wandb (bool, optional): If True, logs the processed dataset as a WandB artifact.
            Default is True.
        overwrite (bool, optional): If True, overwrites the output_file if it already exists.
            Default is False.

    Returns:
        Dataset: The preprocessed dataset.

    """

    if os.path.exists(output_dir) and not overwrite:
        print('Loading filtered datasets.....')
        # If the output_file exists and overwrite is False, load the dataset from disk and return it.
        return load_from_disk(output_dir)

    # List of strings to filter out posts based on their titles
    not_qus = ['AMA', 'megathread', 'Discussion Thread',
               'Ask Anything Wednesday', 'Monday Methods',
               'Tuesday Trivia', 'Monday Mysteries',
               'Theory Thursday', 'Monday Mish-Mash',
               'Media Mondays', 'Wednesday Week in History',
               'Saturday Popular Questions', 'Ask Anything Wednesday',
               'Thursday Focus Historical Fiction', 'Askhistorians Podcast',
               'cross post', 'cross-post', 'crosspost', 'x post', 'x-post', 'x/post',
               'mod post', 'mods', 'moderator','meta',
               'ask me anything', 'meetup',' floating feature', 'twenty-year rule',
               'subreddit', 'Rules Roundtable',
              ]

    # List of question words used to filter out posts without meaningful questions in their titles or selftext
    qu_reqs = ['who', 'what', 'where', 'why', 'when', 'how', '?']

    # Preprocess each example in the dataset using the preprocess_example function
    print('Preprocessing datasets.....')
    dataset = dataset.map(preprocess_example,
                          batched=True, batch_size=64)

    # Filter out posts with 'nsfw' in their titles
    print('Filtering posts.....')
    dataset = dataset.filter(lambda post: 'nsfw' not in post['title'].lower())

    # Filter out posts with '__url_i__' in the title or selftext
    dataset = dataset.filter(lambda post: not contains_url(post['title']) \
                                          and not contains_url(post['selftext']))

    # Filter out posts that do not contain meaningful questions in their titles or selftext
    dataset = dataset.filter(lambda post:
                             not (all(qu_req not in post['title'].lower() for qu_req in qu_reqs)
                                  and all(qu_req not in post['selftext'].lower() for qu_req in qu_reqs)))

    # Filter out posts that do not correspond to questions.
    dataset = dataset.filter(lambda post: not (any(nq.lower() in post['title'].lower() for nq in not_qus)))

    # Combine title and body of remaining posts
    print('Combining post title+body.....')
    dataset = dataset.map(combine_title_body)

    if save_file:
        # Save the processed dataset to the output_file
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        dataset.save_to_disk(output_dir)

        if log_to_wandb:
            # Log the processed dataset as a WandB artifact if log_to_wandb is True
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='preprocess_data',
                            name=f'preprocess_data_{time_stamp}') as run:
                # Initialize a WandB run for logging
                processed_data_art = wandb.Artifact('ELI5_preprocessed', 'dataset')
                processed_data_art.add_dir(output_dir)
                run.log_artifact(processed_data_art)

    return dataset

def preprocess_example(batch):
    """
    Batch preprocess an example dictionary containing 'answers', 'title', and 'selftext' keys.

    The function applies the following preprocessing steps to each element in the example:
    1. Cleans all answers, titles, and selftext using redditcleaner.
    2. Remove any answers that contain "_url_i_" (posts w/urls in title/selftext are filtered later in preprocess_data)
    3. Remove any answers that contain "reddit".
    4. Removes extra whitespaces from answers, titles and selftext.
    5. Capitalize the beginning of each sentence and all instances of the word 'I' in answers.
    5. Truncate answers, titles and selftext at 'edit:', "[update" etc. (refer to truncate_edit_update_thanks for details).
    6. Truncate selftext at 'PS', 'p.s.' etc.
    7. Remove any answers with less than 20 words.
    8. Remove 'eli5', 'ELI 10:' etc at the beginning of the title.

    Only elements with at least one answer after preprocessing are retained.

    Parameters:
        batch (dict): A dictionary containing 'answers', 'title', and 'selftext' keys.

    Returns:
        List[dict]: The preprocessed dictionaries with the above transformations applied.

    Example:
        >>> example = {
                'answers': {'text': [['Visit this website: _url_123_',
                                     'Ask this question on another subreddit',
                                     'Ask this question. edit: a a a a a a a a a a a a a a a a a a a a a a a a a',
                                     'this is an  answer containing at least 20 words and a a a a a a a a a a a'],
                                     ['this question has no answers with at least 20 words']]},
                'title': ['ELI 5: How to use Python?', 'what color is the sky'],
                'selftext': ['Check out this tutorial: _Url_789_ to learn Python. [updated to fix typos]',
                            ['it might be green']],
            }
        >>> preprocess_example(example)
        {
            'answers': {'text': [['this is an answer containing at least 20 words and a a a a a a a a a a a']]},
            'title': ['How to use Python?'],
            'selftext': ['Check out this tutorial: _Url_789_ to learn Python.'],
        }

        Note that the above post will be removed by preprocess_data since it contains a url in the selftext.
    """

    # store processed examples and track elements to retain
    # retained elements must have at least one answer
    retained_idxs = []
    processed_batch = {'answers' : [],
                       'title' : [],
                       'selftext' : []}
    # process examples
    for idx, answers in enumerate(batch['answers']):
        # deduplicate answers
        answers = deduplicate_answers(answers)
        # in each example, filter answers based on criteria in docstring
        valid_indices = [index for index, answer in enumerate(answers['text']) \
                         if len(answer.strip().split(' ')) >= 20 \
                         and not contains_url(answer) \
                         and not contains_reddit_ref(answer) \
                         and not 'reddit' in answer

        ]
        # apply cleaning transformations to the retained answers
        filtered_texts = [answers['text'][index] for index in valid_indices]
        filtered_texts = [' '.join(text.strip().split()) for text in filtered_texts]
        filtered_texts = [capitalize_sentences(text) for text in filtered_texts]
        filtered_texts = [redditcleaner.clean(text) for text in filtered_texts]
        filtered_texts = [truncate_edit_update_thanks(text) for text in filtered_texts]
        filtered_texts = [text.strip() for text in filtered_texts]
        # build filtered answers dict
        filtered_answers = {
            'a_id': [answers['a_id'][index] for index in valid_indices],
            'score': [answers['score'][index] for index in valid_indices],
            'text': filtered_texts,
        }
        processed_batch['answers'].append(filtered_answers)
        # keep track of elements in batch to retain
        if len(filtered_answers['text']) > 0:
            retained_idxs.append(idx)

    # process titles
    for title in batch['title']:
        title = title.strip()
        title = redditcleaner.clean(title)
        title = ' '.join(title.split())
        title = truncate_edit_update_thanks(title)
        title = re.sub(r'^eli\s?\d*[.:-]?', '', title, flags=re.IGNORECASE)
        title = truncate_leading_chars(title)
        title = title.strip()
        processed_batch['title'].append(title)

    # process selftext
    for selftext in batch['selftext']:
        selftext = selftext.strip()
        selftext = redditcleaner.clean(selftext)
        selftext = ' '.join(selftext.split())
        selftext = truncate_edit_update_thanks(selftext)
        selftext = truncate_ps(selftext)
        selftext = selftext.strip()
        processed_batch['selftext'].append(selftext)

    # update batch with processed data
    # only retain examples with at least one answer after processing
    retained_batch = {}
    for key in batch:
        if key in ['answers', 'title', 'selftext']:
            retained_batch[key] = [processed_batch[key][i] for i in retained_idxs]
        else:
            retained_batch[key] = [batch[key][i] for i in retained_idxs]

    return retained_batch

def combine_title_body(example):
    """
    Combines the title and body (selftext) of the input example into a single string and updates the input example.

    Parameters:
        example (dict): The input example containing 'title' and 'selftext' as keys.

    Returns:
        dict: The modified input example with the combined string of the title and body
              under the key 'title_body'.

    """

    # Remove extra spaces and join the words in the 'title' string.
    title = ' '.join(example['title'].split())

    # Remove extra spaces and join the words in the 'selftext' string.
    selftext = ' '.join(example['selftext'].split())

    # Combine the 'title' and 'selftext' strings with a newline separator.
    combined = title + ' ' + selftext

    # Add the combined string under the key 'title_body' in the input example.
    example['title_body'] = combined

    # Return the modified input example.
    return example

def deduplicate_answers(example):
    
    answers = example['answers']
    a_ids = answers['a_id']
    
    # Find the indices of the first occurrence of each id
    seen_ids = set()
    retain_indices = []
    for index, item in enumerate(a_ids):
        if item not in seen_ids:
            seen_ids.add(item)
            retain_indices.append(index)
    
    # Retain elements at the identified indices from each list in the dictionary
    for key, value in answers.items():
        answers[key] = [value[i] for i in retain_indices]
    
    example['answers'] = answers
    return example

def capitalize_sentences(text):
    """
    Capitalizes the first letter of each sentence in the given text and replaces standalone lowercase 'i' with uppercase 'I'.

    Parameters:
    - text (str): The input string containing one or more sentences.

    Returns:
    - str: The input text with each sentence capitalized and standalone lowercase 'i' replaced with uppercase 'I'.

    Example:
    >>> capitalize_sentences("hello. this is a test. like i am")
    'Hello. This is a test. Like I am'
    """
    # Split text by sentence delimiters (.!?), but keep the delimiters
    sentences = re.split(r'(?<=[.!?])\s+', text)

    # Capitalize each sentence
    capitalized_sentences = ' '.join([sentence.strip().capitalize() for sentence in sentences if sentence])

    # Replace lowercase 'i' surrounded by non-word characters with uppercase 'I'
    capitalized_sentences = re.sub(r'\bi\b', 'I', capitalized_sentences)

    return capitalized_sentences

def contains_url(text):
    """
    Helper function for filtering posts with links.

    Parameters:
        text (str): The input text.

    Returns:
        bool: True if the input text contains reddit urls.

    Example:
        >>> replace_url_i("Check out my website: _url_123_ and _URL_456_")
        'Check out my website:  and '
    """
    # Define the regular expression patterns to match "_url_i_" where i is an arbitrary integer
    url_pattern = r"_url_\d+_"

    return re.search(url_pattern, text, flags=re.IGNORECASE)

def contains_reddit_ref(text):
    """
    Helper function for filtering posts with reddit references (r/whatever).

    Parameters:
        text (str): The input text.

    Returns:
        bool: True if the input text contains reddit references.

    Example:
        >>> replace_url_i("Check out my website: _url_123_ and _URL_456_")
        'Check out my website:  and '
    """
    # Define the regular expression patterns to match "_url_i_" where i is an arbitrary integer
    reddit_pattern = r"r\/"

    return re.search(reddit_pattern, text, flags=re.IGNORECASE)


def truncate_edit_update_thanks(text):
    """
    Helper function for truncating posts at any of the following:
    - "edit/update/thank" followed by optional spaces and ":" or "-" anywhere in the text
    - "edit/update/thank" preceded by "<", "[" or "(" anywhere in the text
    - "edit/update/thank" at the beginning of a sentence

    Parameters:
        text (str): The input text.

    Returns:
        truncated_text (str): The input text truncated at the beginning of the first pattern match.

    Example:
        >>> truncate_edit_update_thanks("Hello world! Thanks for listening")
        'Hello world!'
    """
    # edit, update, thanks patterns
    # If the term is at the start of a sentence, return True.
    pattern_sos = r"(?<=[.!?)])\s*\b(edit|edited|update|updated|thanks|thank\syou)\b"
    # If the term is anywhere and is preceded by <, ( or [, return True.
    pattern_anywhere_1 = r"[\[\(<]\s*\b(edit|edited|update|updated|thanks|thank\syou)\b"
    # If the term is anywhere and is followed by optional spaces plus a mandatory colon, period or dash, return True.
    pattern_anywhere_2 = r"\b(edit|edited|update|updated|thanks|thank\syou)\b\s*([:.-])"

    patterns = [pattern_sos, pattern_anywhere_1, pattern_anywhere_2]
    patterns = [re.compile(pattern, flags=re.IGNORECASE) for pattern in patterns]

    # Search for the pattern in the text and truncate if found
    truncated_text = text
    for pattern in patterns:
        match = pattern.search(truncated_text)
        if match:
            # If the pattern is found, truncate the text at the start of the match
            truncated_text = truncated_text[:match.start()]

    return truncated_text


def truncate_ps(text):
    """
    Helper function for truncating posts at the first occurence of 'PS'.

    Parameters:
        text (str): The input text.

    Returns:
        truncated_text (str): The input text truncated at the beginning of the pattern match.

    Example:
        >>> truncate_ps("Hello world! PS I am a computer")
        'Hello world!'
    """
    # search for PS (possibly with periods) followed by any nonword character
    # at the start of a sentence (i.e. after a sentence-ending punct like ., !, ?)
    pattern_ps = re.compile(r"\b[\[\(<]?p\.?s\.?(?=\W|\b)", re.IGNORECASE)

    # Search for the pattern in the question and truncate if found
    truncated_text = text
    match = pattern_ps.search(text)
    if match:
        # If the pattern is found, truncate the question at the start of the match
        truncated_text = text[:match.start()]

    return truncated_text

def truncate_leading_chars(text):
    """
    Removes any leading characters that are not alphabets from the string using regex.
    
    Args:
    - text (str): The input string
    
    Returns:
    - str: String starting from the first letter
    """
    match = re.search(r'[a-zA-Z]', text)
    if match:
        return text[match.start():]
    return text

### Reddit score/Flesch filtering

In [7]:
def apply_score_filtering(dataset,
                    fre_cutoff=60,
                    fkg_cutoff=9,
                    reddit_cutoff=4):
    """
    This function applies flesch_scores_filter_wrapper to a Huggingface dataset.
    Only answers with fre >= fre_cutoff and fkg < fkg_cutoff.
    Posts with no qualifying answers will be removed.

    Parameters:
        dataset (Dataset): Huggingface dataset to be filtered.
        fre_cutoff (float, optional): The cutoff value for Flesch Readability score. Default is 60.
        fkg_cutoff (float, optional): The cutoff value for Flesch-Kincaid Grade score. Default is 9.
        reddit_cutoff (float, optional): The cutoff value for reddit score. Default is 4.

    Returns:
        Dataset: The modified dataset with answers filtered based on the Flesch Readability and Flesch-Kincaid Grade scores.
    """

    # Map the flesch_scores function to calculate Flesch readability scores for each post
    print('Computing flesch scores.....')
    dataset = dataset.map(compute_flesch_scores)

    # Define filter function.
    filter = flesch_scores_filter_wrapper(fre_cutoff, fkg_cutoff)
    # Remove posts with scores below certain thresholds
    print(f'Filtering by flesch score (FRE>={fre_cutoff}, FKG<{fkg_cutoff}).....')
    dataset = dataset.map(filter.flesch_scores_filter)
    # Remove any posts with no valid answers.
    dataset = dataset.filter(lambda post: len(post['answers']['fre']) > 0)

    # Apply score_cutoff function to remove posts with low reddit scores
    if reddit_cutoff:
        print(f'Filtering by reddit score (reddit score>={reddit_cutoff}).....')
        dataset = score_cutoff(dataset, cutoff=reddit_cutoff)

    return dataset


class score_cutoff_wrapper:
    """
    A wrapper class to filter answers based on a cutoff score from an example dictionary.

    This class provides a method to filter the answers in an example based on their corresponding scores.
    Answers with a score greater than or equal to the specified cutoff will be retained, and others will be removed.

    Parameters:
        cutoff (int or float): The cutoff score value to filter answers.
    """

    def __init__(self, cutoff):
        """
        Initialize the score_cutoff_wrapper with the specified cutoff score.

        Parameters:
            cutoff (int or float): The cutoff score value to filter answers.
        """
        self.cutoff = cutoff

    def score_cutoff_ex(self, example):
        """
        Filter the answers in the example based on the cutoff score.

        Parameters:
            example (dict): A dictionary containing 'answers' key with 'text' and 'score' lists.

        Returns:
            dict: The modified example dictionary with answers filtered based on the cutoff score.

        Example:
            >>> example = {
                    'answers': {
                        'text': ['Yes', 'No', 'Maybe'],
                        'score': [10, 5, 8]
                    }
                }
            >>> wrapper = score_cutoff_wrapper(cutoff=8)
            >>> filtered_example = wrapper.score_cutoff_ex(example)
            >>> filtered_example
            {
                'answers': {
                    'text': ['Yes', 'Maybe'],
                    'score': [10, 8]
                }
            }
        """
        scores = example['answers']['score']
        # Find idxs where scores >= cutoff.
        idxs = list(np.array(scores) >= self.cutoff)
        # For each (key,value) pair in dictionary example['answers'] only
        # keep text and metadata for answers with a high enough score.
        for key, val in example['answers'].items():
            example['answers'][key] = list(compress(val, idxs))

        return example


def score_cutoff(dataset,cutoff=4):
    """
    Uses class score_cutoff_wrapper to filter a Huggingface dataset to only keep
    scores above a certain cutoff.

    Parameters:
        dataset (Dataset): The input Huggingface dataset to be filtered.
        cutoff (int or float, optional): The cutoff score value to filter answers. Default is 4.

    Returns:
        Dataset: The modified dataset with answers filtered based on the cutoff score.
    """
    cutoff = score_cutoff_wrapper(cutoff)
    ds = dataset.map(cutoff.score_cutoff_ex)
    ds = ds.filter(lambda post: len(post['answers']['score'])>0)

    return ds


def compute_flesch_scores(example):
    """
    Calculate Flesch Readability scores for each answer in the example.

    This function calculates Flesch Readability scores and Flesch-Kincaid Grade levels for each answer in the example.
    The calculated scores are then added to the example dictionary under the 'fre' (Flesch Readability) and 'fkg'
    (Flesch-Kincaid Grade) keys.

    Parameters:
        example (dict): A dictionary containing 'answers' key with 'text' lists for each answer.

    Returns:
        dict: The modified example dictionary with Flesch Readability and Flesch-Kincaid Grade scores.

    Example:
        >>> example = {
                'answers': {
                    'text': ['This is a sample answer.', 'Another answer with more words.']
                }
            }
        >>> compute_flesch_scores(example)
        {
            'answers': {
                'text': ['This is a sample answer.', 'Another answer with more words.'],
                'fre': [89.1, 79.2],
                'fkg': [2.6, 5.5]
            }
        }
    """

    # Compute Flesch Readability score for each answer.
    fre_scores = [fre(text) for text in example['answers']['text']]
    # Compute Flesch Kincaid Grade level for each answer.
    fkg_scores = [fkg(text) for text in example['answers']['text']]
    # Add corresponding metrics to dictioanry example['answers'].
    example['answers']['fre'] = fre_scores
    example['answers']['fkg'] = fkg_scores

    return example


class flesch_scores_filter_wrapper:
    """
    This class provides a method to filter answers in an example based on Flesch Readability (fre) and
    Flesch-Kincaid Grade (fkg) scores. Answers with fre >= fre_cutoff and fkg < fkg_cutoff will be retained,
    and others will be removed.
    """

    def __init__(self, fre_cutoff, fkg_cutoff):
        """
        Initialize the flesch_scores_filter_wrapper with the specified cutoff scores.

        Parameters:
            fre_cutoff (float): The cutoff value for Flesch Readability score.
            fkg_cutoff (float): The cutoff value for Flesch-Kincaid Grade score.
        """
        self.fre_cutoff = fre_cutoff
        self.fkg_cutoff = fkg_cutoff

    def flesch_scores_filter(self, example):
        """
        Applies filter to specific example using self.fre_cutoff and self.fkg_cutoff.

        Parameters:
            example (dict): A dictionary containing 'answers' key with 'fre' and 'fkg' lists.

        Returns:
            dict: The modified example dictionary with answers filtered based on the cutoff scores.

        Example:
            >>> example = {
                    'answers': {
                        'text': ['This is a sample answer.', 'Another answer with more words.'],
                        'fre': [89.1, 79.2],
                        'fkg': [2.6, 5.5]
                    }
                }
            >>> filter = flesch_scores_filter_wrapper(fre_cutoff=80, fkg_cutoff=5)
            >>> filtered_example = filter.flesch_scores_filter(example)
            >>> filtered_example
            {
                'answers': {
                    'text': ['This is a sample answer.'],
                    'fre': [89.1],
                    'fkg': [2.6]
                }
            }
        """
        fre_scores = example['answers']['fre']
        fkg_scores = example['answers']['fkg']

        idxs = [(fre >= self.fre_cutoff and fkg < self.fkg_cutoff) \
                for fre, fkg in zip(fre_scores, fkg_scores)]

        # Use 'compress' to filter the values based on the boolean mask 'idxs'
        for key, val in example['answers'].items():
            example['answers'][key] = list(compress(val, idxs))

        return example

### Splitting

In [8]:
def split_ds(ds_preprocessed,
             ds_filtered,
             output_dir='./data/ELI5/ds_split',
             save_file=True,
             log_to_wandb=True,
             overwrite=False):
    """
    Splits the datasets into supervised fine-tuning (SFT), reward modeling (RM), and reinforcement learning (RL) subsets.

    Parameters:
        ds_preprocessed (Dataset): The preprocessed dataset containing all examples.
        ds_filtered (Dataset): The score-filtered dataset containing relevant examples for SFT and RM datasets.
        output_dir (str, optional): The directory where the split datasets will be saved.
            Default is 'ds_split'.
        save_file (bool, optional): If True, saves the split datasets to disk. Default is True.
        log_to_wandb (bool, optional): If True, logs the split datasets as a WandB artifact.
            Default is True.
        overwrite (bool, optional): If True, overwrites existing split datasets in the output directory.
            Default is False.

    Returns:
        dict: A dictionary containing the split datasets for SFT, RM, and RL.
    """

    # Check if the split datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'./data/{output_dir}/{split}') for split in ['ds_SFT', 'ds_RM', 'ds_RL'])
        and not overwrite):

        ds_split = {}

        # Load the split datasets from disk and return them.
        print('Loading split datasets.....')
        ds_split['SFT'] = load_from_disk(f'./data/{output_dir}/ds_SFT')
        ds_split['RM'] = load_from_disk(f'./data/{output_dir}/ds_RM')
        ds_split['RL'] = load_from_disk(f'./data/{output_dir}/ds_RL')

        return ds_split

    ds_split = {}

    # Filter examples with multiple answers and single answers separately.
    ds_mult = ds_filtered.filter(lambda post: len(post['answers']['score']) >= 2)
    ds_sing = ds_filtered.filter(lambda post: len(post['answers']['score']) == 1)

    # Process examples with multiple answers using the 'mult_ans_RM_proc' function to retain only answers that
    # will be used for preference modeling. We choose answers with unique scores to avoid ties during preference modeling.
    print('Generating SFT and RM splits.....')
    ds_mult_indexed = ds_mult.map(split_idxs)
    ds_split['RM'] = ds_mult_indexed.map(mult_ans_RM_proc)
    ds_split['RM'] = ds_split['RM'].filter(lambda x: len(x['answers']['score']) > 0)

    # Process examples with multiple answers using the 'mult_ans_SFT_proc' function to retain only duplicate scores' answers.
    # These will be added to SFT dataset.
    ds_SFT_mult = ds_mult_indexed.map(mult_ans_SFT_proc)
    ds_SFT_mult = ds_SFT_mult.filter(lambda x: len(x['answers']['score']) > 0)

    # Form SFT dataset by combining answers for posts with a unique answers and the
    # answers corresponding to the "duplicate indices" for posts with multiple answers.
    ds_split['SFT'] = datasets.DatasetDict()

    for key in ['train', 'validation', 'test']:
        ds_split['SFT'][key] = datasets.concatenate_datasets([ds_SFT_mult[key], ds_sing[key]])

    #Remove reddit posts with a low score from the SFT dataset.
    ds_split['SFT'] = score_cutoff(ds_split['SFT'])

    # Collect the question IDs of examples used in SFT and RM to exclude them from RL.
    q_ids_taken = []

    for ds_ in (ds_split['SFT'], ds_split['RM']):
        for split in ds_:
            q_ids_taken.extend(ds_[split]['q_id'])

    q_ids_taken = set(q_ids_taken)

    # Create the RL subset by excluding examples used in SFT and RM.
    print('Generating RL split.....')
    ds_split['RL'] = ds_preprocessed.filter(lambda post: post['q_id'] not in q_ids_taken)
    ds_split['RL'] = datasets.concatenate_datasets([ds for ds in ds_split['RL'].values()])

    # Save the split datasets to disk.
    if save_file:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        for key, value in ds_split.items():
            value.save_to_disk(output_dir+f'/ds_{key}')

        # Log the split datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='split_data',
                            name=f'split_data_{time_stamp}') as run:

                split_data_art = wandb.Artifact('ELI5_split', 'dataset')
                split_data_art.add_dir(output_dir)
                run.log_artifact(split_data_art)

    # Return the dictionary containing the split datasets for SFT, RM, and RL.
    return ds_split


def split_idxs(example):
    """
    Splits the indices of scores from the input example's answers into two sets,
    pref_scores_idxs and dupl_scores_idxs.

    pref_scores_idxs = Each index in pref_scores_idxs corresponds
                       to a unique score in example['answers']['score'].


    dupl_scores_idxs = List of indices of example['answers']['score']
                       not found in pref_scores_idxs.

    pref_scores_idx correspond to indices of answers we will use for preference modeling
    since there are no ties in this set.

    dupl_scores_idxs correponds to indices of answers we will use for supervised fine-tuning.


    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'score' as a list.

    Returns:
        dict: The modified input example with 'pref_idxs' and 'dupl_scores_idxs' added.

    """

    # Extract the 'score' list from the 'answers' dictionary in the example.
    scores = example['answers']['score']

    # Sort the unique scores in descending order.
    scores_unique = sorted(set(scores), reverse=True)

    # Get the indices of the preferred scores in the 'scores' list.
    pref_scores_idxs = [scores.index(sc) for sc in scores_unique]

    # Get the indices of duplicate scores in the 'scores' list.
    dupl_scores_idxs = [n for n in range(len(scores)) if n not in pref_scores_idxs]

    # Add the preferred and duplicate scores indices to the input example.
    example['pref_idxs'] = pref_scores_idxs
    example['dupl_scores_idxs'] = dupl_scores_idxs

    # Return the modified example with the added indices.
    return example

def mult_ans_RM_proc(example):
    """
    Processes posts containing multiple answers. Only retains answers that will be used for
    preference modelling.

    Parameters:
        example (dict): The input example containing 'pref_idxs' and 'answers' as dictionary keys.
            'pref_idxs' is a list of indices corresponding to answers we will use for preference modelling.
             Value associated to the key 'answers' is a dictionary containing the text and metadata of the answers.

    Returns:
        dict: The modified input example with 'answers' containing only text and metadata used for preference modeling.

    """


    pref_scores_idxs = example['pref_idxs']

    # Iterate through each key-value pair in the 'answers' dictionary.
    for key, val in example['answers'].items():
        # Update the 'answers' dictionary by keeping only answers to be used for preference modeling.
        example['answers'][key] = [example['answers'][key][i] for i in pref_scores_idxs]

    return example


def mult_ans_SFT_proc(example):
    """
    Processes posts with multiple answers where we only retain answers that will be used for supervised fine-tuning.

    Parameters:
        example (dict): The input example containing 'dupl_scores_idxs' and 'answers' as dictionary keys.
            'dupl_scores_idxs' is a list of indices of duplicate scores, and 'answers' is a dictionary
            with lists of the text of answers and their metadata.

    Returns:
        dict: The modified input example with 'answers' containing only duplicate scores' answers.

    """

    # Retrieve the list of indices of duplicate scores from the 'dupl_scores_idxs' key.
    dupl_scores_idxs = example['dupl_scores_idxs']

    # Iterate through each key-value pair in the 'answers' dictionary.
    for key, val in example['answers'].items():
        # Update the 'answers' dictionary by keeping only the text and meta-data
        # corresponding to duplicate scores' indices.
        example['answers'][key] = [example['answers'][key][i] for i in dupl_scores_idxs]

    return example

### Embedding

In [25]:
def embed_datasets(ds_split,
                   checkpoint='all-mpnet-base-v2',
                   output_dir='./data/ELI5/ds_embedded',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Embeds the datasets using a pre-trained SentenceTransformer model and saves the embeddings to disk.

    Parameters:
        ds_split (dict): A dictionary containing different dataset splits as values (e.g., train, validation).
        checkpoint (str, optional): The name of the SentenceTransformer model checkpoint to use.
            Default is 'all-mpnet-base-v2'.
        output_dir (str, optional): The directory where the embedded datasets will be saved.
            Default is 'embedded'.
        save_file (bool, optional): If True, saves the embedded datasets to disk. Default is True.
        overwrite (bool, optional): If True, overwrites existing embedded datasets in the output directory.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the embedded datasets as a WandB artifact.
            Default is True.

    Returns:
        dict: A dictionary containing the embedded datasets.

    """

    # Check if the embedded datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'./data/{output_dir}/ds_{subset}') for subset in ['SFT', 'RM', 'RL'])
        and not overwrite):

        ds_embedded = {}

        # Load the embedded datasets from disk and return them.
        print('Loading embedded datasets.....')
        for subset in ['SFT', 'RM', 'RL']:
            ds_embedded[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
        return ds_embedded

    # Initialize a dictionary to store the embedded datasets.
    ds_embedded = {}

    # Initialize the SentenceTransformer model.
    model = SentenceTransformer(checkpoint)
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Loop through each dataset split and embed the examples.
    for key in ds_split:
        print(f'Embedding {key} dataset.....')
        ds_embedded[key] = ds_split[key].map(lambda x: {'qu_emb':
                                                           model.encode(x['title_body'],
                                                                        batch_size=64)})

    # Save the embedded datasets to disk.
    if save_file:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        for key, value in ds_embedded.items():
            value.save_to_disk(output_dir+f'/ds_{key}')

        # Log the embedded datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='embed_data',
                            name=f'embed_data_{time_stamp}') as run:

                embed_data_art = wandb.Artifact('ELI5_embedded', 'dataset')
                embed_data_art.add_dir(output_dir)
                run.log_artifact(embed_data_art)

    # Return the dictionary containing the embedded datasets.
    return ds_embedded




### Deduplicating and generating pairs

In [None]:
def dedup_and_make_pairs(ds_embedded,
                         dedup=True,
                         cos_cutoff=0.6,
                         reddit_cutoff=4,
                         batch_size=5000,
                         output_dir='./data/ELI5/ds_pairs',
                         save_file=True,
                         overwrite=False,
                         log_to_wandb=True):
    """
    Cleans the datasets by removing redundant examples based on the similarity
    of embedded vectors if dedup=True. Dataset must have a 'qu_emb' field as
    generated by the embed_datasets function.

    Generates pairs for RM using the deduplicated dataset.

    Parameters:
        ds_embedded (dict): A dictionary containing the embedded datasets for supervised fine-tuning (SFT),
                            reward modeling (RM), and reinforcement learning (RL).
        dedup (bool, optional): If True, removes redundant examples based on the similarity of embedded vectors.
            Default is True.
        cos_cutoff (float, optional): The cos similarity threshold to consider examples as redundant.
            Default is 0.6.
        reddit_cutoff (float, optional): The cutoff value for reddit score.
            Default is 4.
        batch_size (int, optional): The batch size used for processing RL dataset.
            Default is 5000.
        output_dir (str, optional): The directory where the possibly deduped paired datasets will be saved.
            Default is 'ds_pairs_deduped', or 'ds_pairs' if dedup is False.
        save_file (bool, optional): If True, saves the cleaned datasets to disk. Default is True.
        overwrite (bool, optional): If True, overwrites existing cleaned datasets in the output directory.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the cleaned datasets as a WandB artifact.
            Default is True.

    Returns:
        ds_dedup (dict): A dictionary containing the cleaned datasets for SFT, RM, and RL.

    """

    #ds_dedup is a dictionary which contains DatasetDicts as values.
    ds_dedup = {}

    if dedup:
        output_dir += '_deduped'
        # Check if the cleaned datasets already exist in the output directory and overwrite is False.
        if (all(os.path.exists(f'./data/{output_dir}/ds_{subset}') for subset in ['SFT', 'RM', 'RL'])
            and not overwrite):

            # Load the cleaned datasets from disk and return them.
            print('Loading deduplicated datasets.....')
            for subset in ['SFT', 'RM', 'RL']:
                ds_dedup[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
            return ds_dedup

        # Initialize dictionaries to store normalized embedding vectors and overlaps between splits for SFT and RM datasets.
        embed_vecs = {}
        overlaps = {}
        idxs = {}

        # standard splitting of data in supervised learning.
        splits = ['train', 'validation', 'test']

        # Cleaning SFT and RM datasets.
        for subset in ['SFT', "RM"]:
            print(f'Deduplicating {subset} dataset.....')

            # Set the format of dataset to 'torch' to enable torch operations on the embedded vectors.
            ds_embedded[subset].set_format('torch')
            embed_vecs[subset] = {}

            # Normalize the embedded vectors for each split.
            for split in splits:
                embed_vecs[subset][split] = ds_embedded[subset][split]['qu_emb']
                embed_vecs[subset][split] /= torch.sqrt(torch.sum(embed_vecs[subset][split] ** 2,
                                                                dim=1,
                                                                keepdim=True))

            overlaps[subset] = {}
            idxs[subset] = {}

            # Compute the overlaps between splits and store the indices of redundant examples.
            for j in range(1, 3):
                for i in range(j):
                    overlaps[subset][(splits[i], splits[j])] = torch.matmul(
                        embed_vecs[subset][splits[i]],
                        embed_vecs[subset][splits[j]].T
                    )

                    idxs[subset][(splits[i], splits[j])] = torch.where((overlaps[subset][(splits[i], splits[j])]) >= cos_cutoff)

            # Find indices of examples to remove from the training set due to overlap between train and validation splits.
            rm_tr_idxs_temp = idxs[subset]['train', 'validation'][0].numpy()
            rm_tr_idxs_temp = set(rm_tr_idxs_temp)

            # Find indices of examples to remove from the training set due to overlap between train and test splits.
            rm_tr_idxs = idxs[subset]['train', 'test'][0].numpy()
            rm_tr_idxs = set(rm_tr_idxs).union(rm_tr_idxs_temp)

            # Indices to keep in train set.
            keep_train = set(range(len(ds_embedded[subset]['train']))) - rm_tr_idxs

            # Find indices of examples to remove from the test set due to overlap between validation and test splits.
            # Remove examples from test set because it is larger than the validation set.
            rm_test_idxs = idxs[subset]['validation', 'test'][1].numpy()
            rm_test_idxs = set(rm_test_idxs)

            # Indices to keep in test set.
            keep_test = set(range(len(ds_embedded[subset]['test']))) - rm_test_idxs

            # Create a new DatasetDict containing the cleaned subsets for SFT and RM.
            ds_dedup[subset] = datasets.DatasetDict()

            ds_dedup[subset]['train'] = ds_embedded[subset]['train'].select(keep_train)
            ds_dedup[subset]['validation'] = ds_embedded[subset]['validation']
            ds_dedup[subset]['test'] = ds_embedded[subset]['test'].select(keep_test)

        # Cleaning RL dataset.
        print(f'Deduplicating RL dataset.....')

        # Set the format of RL dataset to 'torch' to enable torch operations on the embedded vectors.
        ds_embedded['RL'].set_format('torch')

        # Extract the embedded vectors for the RL dataset.
        embed_vecs['RL'] = ds_embedded['RL']['qu_emb']

        # Normalize the embedded vectors by dividing them by their L2 norm.
        embed_vecs['RL'] /= torch.sqrt(torch.sum(embed_vecs['RL'] ** 2,
                                                dim=1,
                                                keepdim=True))
        # Get the size of the RL dataset (number of examples).
        RL_size = len(ds_embedded['RL'])

        # Create an empty set to store the indices of redundant examples in the RL dataset.
        rem_RL = set()

        # Initialize a variable to keep track of the start index of each batch.
        start = 0

        # Calculate the number of batches based on the batch size.
        num_batches = RL_size // batch_size

        # If the size of RL dataset is not perfectly divisible by batch_size, add one extra batch.
        if RL_size % batch_size != 0:
            num_batches += 1

        # Loop through each batch and compute overlaps with SFT and RM datasets to find redundant examples.
        for k in tqdm(range(num_batches)):

            # Calculate the start and end index of the current batch.
            start = k * batch_size
            end = (k + 1) * batch_size

            # Get the current batch of embedded vectors.
            batch = embed_vecs['RL'][start:start + batch_size, :]

            # Compute overlaps between the current batch and the SFT and RM datasets.
            for subset in ['SFT', 'RM']:
                for split in ['train', 'validation']:
                    overlap = torch.matmul(embed_vecs[subset][split], batch.T)

                    # Find the indices of redundant examples in the current batch.
                    rem_RL_idxs_temp = torch.where(overlap >= cos_cutoff)[1].numpy()

                    # Update the set of indices of redundant examples in the entire RL dataset.
                    rem_RL = rem_RL.union(set(rem_RL_idxs_temp))

        # Create a set containing all the indices of the RL dataset.
        keep_RL = set(range(RL_size))

        # Remove the indices of redundant examples from the set to get non-redundant examples.
        keep_RL -= set(rem_RL)

        # Select non-redundant examples for RL dataset.
        ds_dedup['RL'] = ds_embedded['RL'].select(keep_RL)

    # Apply 'make_pairs' function to the RM dataset to create pairs of answers.
    print('Making pairs for RM.....')
    ds_dedup['RM'] = ds_dedup['RM'].map(lambda x: make_pairs(x, reddit_cutoff))

    # Save the cleaned datasets to disk.
    if save_file:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        for subset in ['SFT', 'RM', 'RL']:
            ds_dedup[subset].save_to_disk(f'{output_dir}/ds_{subset}')

        # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='clean_data',
                            name=f'clean_data_{time_stamp}') as run:

                dedup_data_art = wandb.Artifact('ELI5_deduped', 'dataset')
                dedup_data_art.add_dir(output_dir)
                run.log_artifact(dedup_data_art)

    # Return the dictionary containing the cleaned datasets for SFT, RM, and RL.
    return ds_dedup



def make_pairs(example, score_cutoff=0):
    """
    Creates pairs of answers from the input example based on their scores and updates the example.

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'text' and 'score' lists.
        score_cutoff (float, optional): The minimum score required for the preferred answer. Default is 0.

    Returns:
        dict: The modified input example with the 'pairs_text' key containing the created pairs of answers.

    """

    # Extract the 'text' and 'score' lists from the 'answers' dictionary in the example.
    answers = example['answers']['text']
    scores = example['answers']['score']

    # Create a list of tuples with each tuple containing the score, its corresponding answer and its index.
    sc_ans = tuple(zip(scores, answers))

    # Generate pairs of tuples using combinations() from the 'sc_ans' list.
    sc_pairs = tuple(combinations(sc_ans, 2))

    # If the number of pairs is greater than 10, randomly select 10 pairs from the list.
    if len(sc_pairs) > 10:
        sc_pairs = random.sample(sc_pairs, 10)

    # Sort each pair of tuples based on their score in descending order.
    sc_pairs = list(map(lambda x: sorted(x, key=lambda y: y[0], reverse=True), sc_pairs))

    # Extract the answers from the sorted pairs and create a list of answer pairs.
    pairs_text = [(sc_pair[0][1], sc_pair[1][1]) for sc_pair in sc_pairs if sc_pair[0][0] >= score_cutoff]

    # Add the 'pairs_text' key to the input example with the created answer pairs.
    example['pairs_text'] = pairs_text

    # Return the modified input example.
    return example

### Detoxifying

In [None]:
def detox_datasets(ds_split,
                   cutoff=0.1,
                   output_dir='./data/ELI5/ds_detox',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Cleans the datasets by removing toxic examples as judged by the detoxify library.

    Parameters:
        ds_split (dict): A dictionary containing the dataset split for supervised fine-tuning (SFT),
                            reward modeling (RM), and reinforcement learning (RL).
        cutoff (float, optional): The toxicity score threshold to consider examples as toxic.
            Default is 0.1.
        output_dir (str, optional): The directory where the detoxified datasets will be saved.
            Default is 'nontox'.
        save_file (bool, optional): If True, saves the detoxified datasets to disk. Default is True.
        overwrite (bool, optional): If True, overwrites detoxified cleaned datasets in the output directory.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the detoxified datasets as a WandB artifact.
            Default is True.
    Returns:
        dict: A dictionary containing the detoxified datasets for SFT, RM, and RL.
    """
    # load the toxicity model
    detoxify_model = Detoxify('unbiased')
    detoxify_model.model.to("cuda" if torch.cuda.is_available() else "cpu")
    splits = ['SFT', 'RM', 'RL']

    # Check if the embedded datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'{output_dir}/ds_{split}') for split in splits)
        and not overwrite):

        ds_nontox = {}

        # Load the embedded datasets from disk and return them.
        for split in splits:
            ds_split[split] = load_from_disk(f'./data/{output_dir}/ds_{split}')
        return ds_nontox

    # ds_nontox is a dictionary which contains DatasetDicts as values.
    ds_nontox = {}
    for split in ds_split:
        print(f'Filtering toxic posts in {split} split.....')
        if split in ['SFT', 'RL']:
            # SFT/RL splits: predict toxicity scores for each example's answers
            ds_split[split] = ds_split[split].map(lambda x: \
                                              {'toxicity_scores': \
                                               [detoxify_model.predict(answer['text']) \
                                                for answer in x['answers']]}, \
                                              batched=True, batch_size=256)
            # only keep nontoxic answers
            ds_nontox[split] = ds_split[split].map(lambda x: detox_answers(x, cutoff))
            # drop examples with no nontoxic answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['answers']['text'])>=1)

        elif split == 'RM':
            # RM split: predict toxicity scores for each example's first answer
            ds_split[split] = ds_split[split].map(lambda x: \
                                                {'toxicity_scores': \
                                                [[detoxify_model.predict(pair[0]) \
                                                    for pair in pairs_text] \
                                                    for pairs_text in x['pairs_text']]
                                                },
                                                  batched=True,
                                                  batch_size=256
            )
            # only keep pairs with nontoxic preferred answers
            ds_nontox[split] = ds_split[split].map(lambda x: detox_answers_RM(x, cutoff))
            # drop examples with no pairs with nontoxic preferred answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['pairs_text'])>=1)

    # Save the detoxed datasets to disk.
    if save_file:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        for split in splits:
            ds_nontox[split].save_to_disk(f'{output_dir}/ds_{split}')

        # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='detox_data',
                            name=f'detox_data_{time_stamp}') as run:

                nontox_data_art = wandb.Artifact('ELI5_detox', 'dataset')
                nontox_data_art.add_dir(output_dir)
                run.log_artifact(nontox_data_art)

    return ds_nontox


def detox_answers(example, cutoff):
    """
    Removes toxic answers from the input example (for SFT and RL).

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'text' and 'toxicity_score' lists.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
           'obscene','severe_toxicity',
           'sexual_explicit','threat',
           'toxicity']

    # get the indices of answers below tox cutoff
    answers = example['answers']
    tox_scores = example['toxicity_scores']
    valid_indices = [idx for idx, _ in enumerate(answers['text']) \
                        if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])

    ]
    
    # build nontox answers dict
    detoxed_answers = {
        'a_id': [answers['a_id'][index] for index in valid_indices],
        'score': [answers['score'][index] for index in valid_indices],
        'text': [answers['text'][index] for index in valid_indices],
    }
    example['answers'] = detoxed_answers
    return example


def detox_answers_RM(example, cutoff):
    """
    Removes answer pairs with toxic preferred answers from the input example.

    Parameters:
        example (dict): The input example with a list of answer pairs stored in the 'pairs_text' key
                        and a list of the toxicity scores of the preferred answer in each pair
                        stored in the 'toxicity_score' key. The two lists must share an index.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
       'obscene','severe_toxicity',
       'sexual_explicit','threat',
       'toxicity']

    # get the indices of answer pairs with preferred answers below tox cutoff
    pairs = example['pairs_text']
    tox_scores = example['toxicity_scores']
    valid_indices = [idx for idx, _ in enumerate(pairs) \
                        if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])

    ]
    
    # filter the pairs
    detoxed_pairs = [pair for idx, pair in enumerate(pairs) if idx in valid_indices]
    example['pairs_text'] = detoxed_pairs
    return example

### Dataset building wrapper

In [None]:
def build_dataset(score_cutoff_dict=None,
                  dedup=True,
                  detox=True,
                  detox_cutoff=0.1,
                  save_intermediates=False,
                  save_final=True,
                  log_to_wandb=True,
                  overwrite=False,
                  output_basename='./data/ELI5'):
    """
    Dataset processing pipeline that performs preprocessing, score filtering,
    splitting, embedding, deduplication and saving of datasets.

    Detoxifying the model must be handled separately due to conflicting
    transformers compatiblity constraints in embedding vs the detoxify library.

    Parameters:
        score_cutoff_dict (dict, optional): Dictionary containing cutoff scores for various metrics.
            Expected keys: 'cutoff_fkg', 'cutoff_fre', and 'cutoff_reddit'.
            If not provided, default values of 9, 60 and 4 are used.
        dedup (bool, optional): Whether to deduplicate the dataset. Defaults to True.
        detox (bool, optional): Whether to detoxify the dataset. Defaults to True.
        detox_cutoff (float, optional): Cutoff score for detoxification. Defaults to 0.1.
        save_intermediates (bool, optional): Whether to save intermediate stages of processed datasets to disk. Defaults to False.
        save_final (bool, optional): Whether to save the final processed datasets to disk. Defaults to True.
        log_to_wandb (bool, optional): Whether to log any saved datasets to Weights & Biases. Defaults to True.
        overwrite (bool, optional): Whether to overwrite existing processed datasets. Defaults to False.
        output_basename (str, optional): Base path for saving datasets. Defaults to './data/ELI5'.

    Returns:
        dict: Dictionary containing the processed datasets with keys 'SFT', 'RM', and 'RL'.
    """

    output_extension = ''

    if dedup:
        output_extension += '_deduped'
    if detox:
        output_extension += f'_detoxed'
    if score_cutoff_dict:
        output_extension += f"__cutoffs_FKG{score_cutoff_dict['cutoff_fkg']}_FRE{score_cutoff_dict['cutoff_fre']}"
        if score_cutoff_dict['cutoff_reddit']:
            output_extension += f'_REDD{score_cutoff_dict["cutoff_reddit"]}'
        else:
            output_extension += f'_REDD4'
    else:
        output_extension += f"__cutoffs_FKG9_FRE60_REDD4"
    if detox:
        output_extension += f'_TOX{detox_cutoff}'

    output_dir = output_basename + output_extension

    # Check if the processed datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'{output_dir}/processed/ds_{subset}') for subset in ['SFT', 'RM', 'RL'])
        and not overwrite):

        ds = {}

        # Load the embedded datasets from disk and return them.
        print('Loading processed datasets.....')
        for subset in ['SFT', 'RM', 'RL']:
            ds[subset] = load_from_disk(f'{output_dir}/ds_{subset}')
        return ds

    # Download original dataset
    print('Downloading dataset.....')
    ds_original = load_dataset("vblagoje/lfqa")

    # Preprocess the dataset
    ds_preprocessed = preprocess_data(ds_original,
                                      output_dir=output_dir+'/ds_preprocessed',
                                      save_file=save_intermediates,
                                      log_to_wandb=log_to_wandb,
                                      overwrite=overwrite)

    # Score filter the dataset
    if score_cutoff_dict:
        ds_filtered = apply_score_filtering(ds_preprocessed,
                                            fre_cutoff=score_cutoff_dict['cutoff_fre'],
                                            fkg_cutoff=score_cutoff_dict['cutoff_fkg'],
                                            reddit_cutoff=score_cutoff_dict.get('cutoff_reddit', None)
        )
    else:
        ds_filtered = apply_score_filtering(ds_preprocessed)

    # Split the dataset
    ds_split = split_ds(ds_preprocessed,
                        ds_filtered,
                        output_dir=output_dir+'/ds_split',
                        save_file=save_intermediates,
                        log_to_wandb=log_to_wandb,
                        overwrite=overwrite)

    # Embed the datasets
    if dedup:
        ds_split = embed_datasets(ds_split,
                                  checkpoint='all-mpnet-base-v2',
                                  output_dir=output_dir+'/ds_embedded',
                                  save_file=save_intermediates,
                                  log_to_wandb=log_to_wandb,
                                  overwrite=overwrite)

    # Deduplicate the datasets and generate pairs (or just generate pairs)
    ds_split = dedup_and_make_pairs(ds_split,
                                    dedup=dedup,
                                    output_dir=output_dir+f'/ds_pairs',
                                    save_file=save_intermediates,
                                    log_to_wandb=log_to_wandb,
                                    overwrite=overwrite)

    # Save the processed datasets to disk
    if save_final:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(f'{output_dir}/processed')
        for key, value in ds_split.items():
            value.save_to_disk(f'{output_dir}/processed/ds_{key}')

        # Log the processed datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='process_data',
                            name=f'process_ELI5_{time_stamp}') as run:

                processed_data_art = wandb.Artifact('ELI5' + output_extension, 'dataset')
                processed_data_art.add_dir(f'{output_dir}/processed')
                run.log_artifact(processed_data_art)

    # Return the dictionary containing the processed datasets.
    return ds_split

# Dataset building code

In [None]:
# to build and detox the dataset:
# - run this cell (to build the dataset)
# - run the next cell (to install the detoxify library)
# - restart the kernel
# - modify values in the following cell to match the call to build_dataset
# - run the last two code cells in this section (to detox the dataset)

ds_split = build_dataset()

In [None]:
# download the toxicity model and remind user to restart the kernel
!pip install git+https://github.com/unitaryai/detoxify.git

from IPython.display import display, HTML, Javascript

def display_restart_reminder():
    button_html = """
    <div style="background-color: #f5f5f5; padding: 10px; border: 1px solid #ccc;
                margin-top: 10px; text-align: center;">
        <button onclick="IPython.notebook.kernel.restart();">ALERT: restart the kernel from the notebook menu before proceeding</button>
    </div>
    """
    display(HTML(button_html))

display_restart_reminder()

In [None]:
# CRUCIAL: MODIFY THESE VALUES TO MATCH CALL TO DATASET BUILD WRAPPER
preprocess=True
score_cutoff_dict=None
dedup=True
detox=True
detox_cutoff=0.4
save_intermediates=False
save_final=True
log_to_wandb=True
overwrite=False
output_basename='./data/ELI5'

this cell just contains string and function defs from the definitions section and should not be modified except to reflect changes made there

In [None]:
# @title
# rebuild output dirname
output_extension = ''
if dedup:
    output_extension += '_deduped'
if detox:
    output_extension += f'_detoxed'
if score_cutoff_dict:
    output_extension += f"__cutoffs_FKG{score_cutoff_dict['cutoff_fkg']}_FRE{score_cutoff_dict['cutoff_fre']}"
    if score_cutoff_dict['cutoff_reddit']:
        output_extension += f'REDD{score_cutoff_dict["cutoff_reddit"]}'
    else:
        output_extension += f'REDD4'
else:
    output_extension += f"__cutoffs_FKG9_FRE60_REDD4"
if detox:
    output_extension += f'_TOX{detox_cutoff}'

output_dir = output_basename + output_extension

# env setup
import torch
import transformers
import wandb
import os
import datetime
from detoxify import Detoxify
from datasets import load_from_disk, load_dataset
from tqdm import tqdm

%cd drive/MyDrive/LLMs/ELI5_dataset

# functions from above to be reloaded
def detox_datasets(ds_split,
                   cutoff=0.1,
                   output_dir='./data/ELI5/ds_detox',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Cleans the datasets by removing toxic examples as judged by the detoxify library.

    Parameters:
        ds_split (dict): A dictionary containing the dataset split for supervised fine-tuning (SFT),
                            reward modeling (RM), and reinforcement learning (RL).
        cutoff (float, optional): The toxicity score threshold to consider examples as toxic.
            Default is 0.1.
        output_dir (str, optional): The directory where the detoxified datasets will be saved.
            Default is 'nontox'.
        save_file (bool, optional): If True, saves the detoxified datasets to disk. Default is True.
        overwrite (bool, optional): If True, overwrites detoxified cleaned datasets in the output directory.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the detoxified datasets as a WandB artifact.
            Default is True.
    Returns:
        dict: A dictionary containing the detoxified datasets for SFT, RM, and RL.
    """
    # load the toxicity model
    detoxify_model = Detoxify('unbiased')
    detoxify_model.model.to("cuda" if torch.cuda.is_available() else "cpu")
    splits = ['SFT', 'RM', 'RL']

    # Check if the embedded datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'{output_dir}/ds_{split}') for split in splits)
        and not overwrite):

        ds_nontox = {}

        # Load the embedded datasets from disk and return them.
        for split in splits:
            ds_split[split] = load_from_disk(f'./data/{output_dir}/ds_{split}')
        return ds_nontox

    # ds_nontox is a dictionary which contains DatasetDicts as values.
    ds_nontox = {}
    for split in ds_split:
        print(f'Filtering toxic posts in {split} split.....')
        if split in ['SFT', 'RL']:
            # SFT/RL splits: predict toxicity scores for each example's answers
            ds_split[split] = ds_split[split].map(lambda x: \
                                              {'toxicity_scores': \
                                               [detoxify_model.predict(answer['text']) \
                                                for answer in x['answers']]}, \
                                              batched=True, batch_size=256)
            # only keep nontoxic answers
            ds_nontox[split] = ds_split[split].map(lambda x: detox_answers(x, cutoff))
            # drop examples with no nontoxic answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['answers']['text'])>=1)

        elif split == 'RM':
            # RM split: predict toxicity scores for each example's first answer
            ds_split[split] = ds_split[split].map(lambda x: \
                                                {'toxicity_scores': \
                                                [[detoxify_model.predict(pair[0]) \
                                                    for pair in pairs_text] \
                                                    for pairs_text in x['pairs_text']]
                                                },
                                                  batched=True,
                                                  batch_size=256
            )
            # only keep pairs with nontoxic preferred answers
            ds_nontox[split] = ds_split[split].map(lambda x: detox_answers_RM(x, cutoff))
            # drop examples with no pairs with nontoxic preferred answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['pairs_text'])>=1)

    # Save the detoxed datasets to disk.
    if save_file:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        for split in splits:
            ds_nontox[split].save_to_disk(f'{output_dir}/ds_{split}')

        # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='detox_data',
                            name=f'detox_data_{time_stamp}') as run:

                nontox_data_art = wandb.Artifact('ELI5_detox', 'dataset')
                nontox_data_art.add_dir(output_dir)
                run.log_artifact(nontox_data_art)

    return ds_nontox


def detox_answers(example, cutoff):
    """
    Removes toxic answers from the input example (for SFT and RL).

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'text' and 'toxicity_score' lists.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
           'obscene','severe_toxicity',
           'sexual_explicit','threat',
           'toxicity']

    # get the indices of answers below tox cutoff
    answers = example['answers']
    tox_scores = example['toxicity_scores']
    valid_indices = [idx for idx, _ in enumerate(answers['text']) \
                        if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])

    ]
    
    # build nontox answers dict
    detoxed_answers = {
        'a_id': [answers['a_id'][index] for index in valid_indices],
        'score': [answers['score'][index] for index in valid_indices],
        'text': [answers['text'][index] for index in valid_indices],
    }
    example['answers'] = detoxed_answers
    return example


def detox_answers_RM(example, cutoff):
    """
    Removes answer pairs with toxic preferred answers from the input example.

    Parameters:
        example (dict): The input example with a list of answer pairs stored in the 'pairs_text' key
                        and a list of the toxicity scores of the preferred answer in each pair
                        stored in the 'toxicity_score' key. The two lists must share an index.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
       'obscene','severe_toxicity',
       'sexual_explicit','threat',
       'toxicity']

    # get the indices of answer pairs with preferred answers below tox cutoff
    pairs = example['pairs_text']
    tox_scores = example['toxicity_scores']
    valid_indices = [idx for idx, _ in enumerate(pairs) \
                        if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])

    ]
    
    # filter the pairs
    detoxed_pairs = [pair for idx, pair in enumerate(pairs) if idx in valid_indices]
    example['pairs_text'] = detoxed_pairs
    return example

def load_ds(base_dir):
    """
    Load ds from disk.

    Dynamically detects the number of arrow files in the train, test, and validation directories.

    Parameters:
        base_dir (str): The base directory of the dataset.

    Returns:
        ds (DatasetDict): A DatasetDict containing the dataset splits.

    """

    train_files = [os.path.join(base_dir, "train", file) for file in get_filenames_for_dir(os.path.join(base_dir, "train"))]
    test_files = [os.path.join(base_dir, "test", file) for file in get_filenames_for_dir(os.path.join(base_dir, "test"))]
    validation_files = [os.path.join(base_dir, "validation", file) for file in get_filenames_for_dir(os.path.join(base_dir, "validation"))]

    ds = load_dataset("arrow", data_files={"train": train_files, "test": test_files, "validation": validation_files})

    return ds

def load_split_ds(base_dir):
    """
    Load SFT/RM/RL splits from disk.

    Dynamically detects the number of arrow files in each split and its subsplits.

    Parameters:
        base_dir (str): The base directory of the dataset splits.

    Returns:
        ds (dict[DatasetDict]): A dictionary containing the SFT/RM/RL splits of the dataset.

    """

    splits = ["SFT", "RM", "RL"]
    ds = {}

    for split in splits:
        split_dir = os.path.join(base_dir, f"ds_{split}")

        # Check if train/test/validation subdirectories exist
        if all(os.path.isdir(os.path.join(split_dir, subsplit)) for subsplit in ["train", "test", "validation"]):
            data_files = {
                subsplit: [os.path.join(split_dir, subsplit, file) for file in get_filenames_for_dir(os.path.join(split_dir, subsplit))]
                for subsplit in ["train", "test", "validation"]
            }
        else:  # No subdirectories, just load the files directly
            data_files = [os.path.join(split_dir, 'train', file) for file in get_filenames_for_dir(os.path.join(split_dir, 'train'))]

        ds_split = load_dataset("arrow", data_files=data_files)

        # Store in the master dictionary
        ds[split] = ds_split

    return ds

def get_filenames_for_dir(directory):
    """Helper function to get sorted arrow filenames for a given directory."""
    filenames = sorted([f for f in os.listdir(directory) if f.endswith(".arrow")])
    total_files = len(filenames)
    return [f"data-0000{i}-of-0000{total_files}.arrow" for i in range(total_files)]

In [None]:
# load ds to detox
ds_to_load = 'ds_pairs' if not deduped else 'ds_pairs_deduped'

stripped_output_dir = '/'.join(output_dir.split('/')[-2:])
ds_deduped = load_split_ds(stripped_output_dir+'/'+ds_to_load)

# detoxify the datasets
ds_detox = detox_datasets(ds_deduped,
                            cutoff=detox_cutoff,
                            output_dir=output_dir+'/ds_detox',
                            save_file=save_intermediates,
                            log_to_wandb=log_to_wandb,
                            overwrite=overwrite)

# save final dataset to disk
if save_final:
    if overwrite and os.path.exists(output_dir):
            shutil.rmtree(f'{output_dir}/processed')
    for split in splits:
        ds_detox[split].save_to_disk(f'{output_dir}/processed/ds_{split}')

    # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
    if log_to_wandb:
        now = datetime.now()
        time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
        with wandb.init(project='ELI5_analysis',
                        entity='ft-llmmm',
                        job_type='detox_data',
                        name=f'detox_data_{time_stamp}') as run:

            nontox_data_art = wandb.Artifact('ELI5' + output_extension, 'dataset')
            nontox_data_art.add_dir(os.path.join(output_dir, 'processed'))
            run.log_artifact(nontox_data_art)

# Walkthrough of ds building code

### walkthrough setup

In [None]:
score_cutoff_dict=None
dedup=True
detox=True
detox_cutoff=0.4
save_intermediates=False
save_final=True
log_to_wandb=True
overwrite=True
output_basename='./data/ELI5'

In [None]:
output_extension = ''
if dedup:
    output_extension += '_deduped'
if detox:
    output_extension += f'_detoxed'
if score_cutoff_dict:
    output_extension += f"__cutoffs_FKG{score_cutoff_dict['cutoff_fkg']}_FRE{score_cutoff_dict['cutoff_fre']}"
    if score_cutoff_dict['cutoff_reddit']:
        output_extension += f'REDD{score_cutoff_dict["cutoff_reddit"]}'
    else:
        output_extension += f'REDD4'
else:
    output_extension += f"__cutoffs_FKG9_FRE60_REDD4"
if detox:
    output_extension += f'_TOX{detox_cutoff}'



output_dir = output_basename + output_extension

# Check if the processed datasets already exist in the output directory and overwrite is False.
if (all(os.path.exists(f'{output_dir}/processed/ds_{subset}') for subset in ['SFT', 'RM', 'RL'])
    and not overwrite):

    ds = {}

    # Load the embedded datasets from disk and return them.
    print('Loading processed datasets.....')
    for subset in ['SFT', 'RM', 'RL']:
        ds[subset] = load_from_disk(f'{output_dir}/ds_{subset}')

print(output_dir)

./data/ELI5_deduped_detoxed__cutoffs_FKG9_FRE60_REDD4_TOX0.4


In [None]:
# Download original dataset
print('Downloading dataset.....')
ds_original = load_dataset("vblagoje/lfqa")

Downloading dataset.....


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/687M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

### walkthrough of preprocess, score filter, split, embed, dedup

In [None]:
# Preprocess the dataset
ds_preprocessed = preprocess_data(ds_original,
                            output_dir=output_dir+'/ds_preprocessed',
                            save_file=save_intermediates,
                            log_to_wandb=False,
                            overwrite=overwrite)

Preprocessing datasets.....


Map:   0%|          | 0/226147 [00:00<?, ? examples/s]

Map:   0%|          | 0/3020 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Filtering posts.....


Filter:   0%|          | 0/185400 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2593 [00:00<?, ? examples/s]

Filter:   0%|          | 0/8371 [00:00<?, ? examples/s]

Filter:   0%|          | 0/185316 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2591 [00:00<?, ? examples/s]

Filter:   0%|          | 0/8364 [00:00<?, ? examples/s]

Filter:   0%|          | 0/181269 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2528 [00:00<?, ? examples/s]

Filter:   0%|          | 0/8092 [00:00<?, ? examples/s]

Filter:   0%|          | 0/177664 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2490 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7857 [00:00<?, ? examples/s]

Combining post title+body.....


Map:   0%|          | 0/176369 [00:00<?, ? examples/s]

Map:   0%|          | 0/2467 [00:00<?, ? examples/s]

Map:   0%|          | 0/7788 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/176369 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2467 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7788 [00:00<?, ? examples/s]

In [None]:
# Score filter the dataset
if score_cutoff_dict:
    ds_filtered = apply_score_filtering(ds_preprocessed,
                                fre_cutoff=score_cutoff_dict['cutoff_fre'],
                                fkg_cutoff=score_cutoff_dict['cutoff_fkg'],
                                reddit_cutoff=score_cutoff_dict.get('cutoff_reddit', None)
                                )
else:
    ds_filtered = apply_score_filtering(ds_preprocessed)

Computing flesch scores.....


Map:   0%|          | 0/176369 [00:00<?, ? examples/s]

Map:   0%|          | 0/2467 [00:00<?, ? examples/s]

Map:   0%|          | 0/7788 [00:00<?, ? examples/s]

Filtering by flesch score (FRE>=60, FKG<9).....


Map:   0%|          | 0/176369 [00:00<?, ? examples/s]

Map:   0%|          | 0/2467 [00:00<?, ? examples/s]

Map:   0%|          | 0/7788 [00:00<?, ? examples/s]

Filter:   0%|          | 0/176369 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2467 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7788 [00:00<?, ? examples/s]

In [None]:
# Split the dataset
ds_split = split_ds(ds_preprocessed,
                    ds_filtered,
                    output_dir=output_dir+'/ds_split',
                    save_file=save_intermediates,
                    log_to_wandb=log_to_wandb,
                    overwrite=overwrite)

Filter:   0%|          | 0/112234 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1872 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5373 [00:00<?, ? examples/s]

Filter:   0%|          | 0/112234 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1872 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5373 [00:00<?, ? examples/s]

Generating SFT and RM splits.....


Map:   0%|          | 0/47296 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/2862 [00:00<?, ? examples/s]

Map:   0%|          | 0/47296 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/2862 [00:00<?, ? examples/s]

Filter:   0%|          | 0/47296 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2862 [00:00<?, ? examples/s]

Map:   0%|          | 0/47296 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/2862 [00:00<?, ? examples/s]

Filter:   0%|          | 0/47296 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2862 [00:00<?, ? examples/s]

Map:   0%|          | 0/83084 [00:00<?, ? examples/s]

Map:   0%|          | 0/1161 [00:00<?, ? examples/s]

Map:   0%|          | 0/3583 [00:00<?, ? examples/s]

Filter:   0%|          | 0/83084 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1161 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3583 [00:00<?, ? examples/s]

Generating RL split.....


Filter:   0%|          | 0/176369 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2467 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7788 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/47296 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1358 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2862 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/37225 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1896 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/101457 [00:00<?, ? examples/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Adding directory to artifact (./data/ELI5_deduped_detoxed__cutoffs_FKG9_FRE60_REDD4_TOX0.4/ds_split)... Done. 0.9s


In [None]:
# Embed the datasets
ds_embed = embed_datasets(ds_embed,
                            checkpoint='all-mpnet-base-v2',
                            output_dir=output_dir+'/ds_embedded',
                            save_file=save_intermediates,
                            log_to_wandb=log_to_wandb,
                            overwrite=overwrite)

Downloading (…)a8e1d/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)0bca8e1d/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)e1d/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)a8e1d/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading (…)8e1d/train_script.py:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)bca8e1d/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

Embedding RM dataset.....


Map:   0%|          | 0/47296 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/2862 [00:00<?, ? examples/s]

Embedding SFT dataset.....


Map:   0%|          | 0/37225 [00:00<?, ? examples/s]

Map:   0%|          | 0/684 [00:00<?, ? examples/s]

Map:   0%|          | 0/1896 [00:00<?, ? examples/s]

Embedding RL dataset.....


Map:   0%|          | 0/101457 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/47296 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1358 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2862 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/37225 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1896 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/101457 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Currently logged in as: [33mblm3000[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (./data/ELI5_deduped_detoxed__cutoffs_FKG9_FRE60_REDD4_TOX0.4/ds_embedded)... Done. 3.6s


In [None]:
# Deduplicate the datasets and generate pairs (or just generate pairs)
ds_dedup = dedup_and_make_pairs(ds_embed,
                                dedup=dedup,
                                output_dir=output_dir+'/ds_pairs',
                                save_file=save_intermediates,
                                log_to_wandb=log_to_wandb,
                                overwrite=overwrite)

Deduplicating SFT dataset.....
Deduplicating RM dataset.....
Deduplicating RL dataset.....


100%|██████████| 21/21 [01:06<00:00,  3.14s/it]


Making pairs for RM.....


Map:   0%|          | 0/40126 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/2731 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/34866 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1858 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/40126 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1358 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2731 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/96457 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/ELI5_deduped_detoxed__cutoffs_FKG9_FRE60_REDD4_TOX0.4/ds_pairs_deduped)... Done. 3.3s


### walkthrough of detox

In [None]:
# load the toxicity model and remind user to restart the kernel
!pip install git+https://github.com/unitaryai/detoxify.git

from IPython.display import display, HTML, Javascript

def display_restart_reminder():
    button_html = """
    <div style="background-color: #f5f5f5; padding: 10px; border: 1px solid #ccc;
                margin-top: 10px; text-align: center;">
        <button onclick="IPython.notebook.kernel.restart();">ALERT: restart the kernel from the notebook menu before proceeding</button>
    </div>
    """
    display(HTML(button_html))

display_restart_reminder()

Collecting git+https://github.com/unitaryai/detoxify.git
  Cloning https://github.com/unitaryai/detoxify.git to /tmp/pip-req-build-2v41hw1m
  Running command git clone --filter=blob:none --quiet https://github.com/unitaryai/detoxify.git /tmp/pip-req-build-2v41hw1m
  Resolved https://github.com/unitaryai/detoxify.git to commit c9ffbac22d97ed63fb4f7862a9bdb33006c94aeb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting transformers==4.30.0 (from detoxify==0.5.1)
  Downloading transformers-4.30.0-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.30.0->detoxify==0.5.1)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

start from here once kernel has been restarted

In [None]:
import torch
import transformers
import wandb
import os
import datetime
from detoxify import Detoxify
from datasets import load_from_disk, load_dataset
from tqdm import tqdm

%cd drive/MyDrive/LLMs/ELI5_dataset

In [None]:
# CRUCIAL: values must match call to dataset build wrapper
preprocess=True
score_cutoff_dict=None
dedup=True
detox=True
detox_cutoff=0.4
save_intermediates=False
save_final=True
log_to_wandb=True
overwrite=False
output_basename='./data/ELI5'

# rebuild output dirname
output_extension = ''
if dedup:
    output_extension += '_deduped'
if detox:
    output_extension += f'_detoxed'
if score_cutoff_dict:
    output_extension += f"__cutoffs_FKG{score_cutoff_dict['cutoff_fkg']}_FRE{score_cutoff_dict['cutoff_fre']}"
    if score_cutoff_dict['cutoff_reddit']:
        output_extension += f'REDD{score_cutoff_dict["cutoff_reddit"]}'
    else:
        output_extension += f'REDD4'
else:
    output_extension += f"__cutoffs_FKG9_FRE60_REDD4"
if detox:
    output_extension += f'_TOX{detox_cutoff}'

output_dir = output_basename + output_extension

this cell just contains functions from above that need to be reloaded

In [None]:
# @title
def detox_datasets(ds_split,
                   cutoff=0.1,
                   output_dir='./data/ELI5/ds_detox',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Cleans the datasets by removing toxic examples as judged by the detoxify library.

    Parameters:
        ds_split (dict): A dictionary containing the dataset split for supervised fine-tuning (SFT),
                            reward modeling (RM), and reinforcement learning (RL).
        cutoff (float, optional): The toxicity score threshold to consider examples as toxic.
            Default is 0.1.
        output_dir (str, optional): The directory where the detoxified datasets will be saved.
            Default is 'nontox'.
        save_file (bool, optional): If True, saves the detoxified datasets to disk. Default is True.
        overwrite (bool, optional): If True, overwrites detoxified cleaned datasets in the output directory.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the detoxified datasets as a WandB artifact.
            Default is True.
    Returns:
        dict: A dictionary containing the detoxified datasets for SFT, RM, and RL.
    """
    # load the toxicity model
    detoxify_model = Detoxify('unbiased')
    detoxify_model.model.to("cuda" if torch.cuda.is_available() else "cpu")
    splits = ['SFT', 'RM', 'RL']

    # Check if the embedded datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'{output_dir}/ds_{split}') for split in splits)
        and not overwrite):

        ds_nontox = {}

        # Load the embedded datasets from disk and return them.
        for split in splits:
            ds_split[split] = load_from_disk(f'./data/{output_dir}/ds_{split}')
        return ds_nontox

    # ds_nontox is a dictionary which contains DatasetDicts as values.
    ds_nontox = {}
    for split in ds_split:
        print(f'Filtering toxic posts in {split} split.....')
        if split in ['SFT', 'RL']:
            # SFT/RL splits: predict toxicity scores for each example's answers
            ds_split[split] = ds_split[split].map(lambda x: \
                                              {'toxicity_scores': \
                                               [detoxify_model.predict(answer['text']) \
                                                for answer in x['answers']]}, \
                                              batched=True, batch_size=256)
            # only keep nontoxic answers
            ds_nontox[split] = ds_split[split].map(lambda x: detox_answers(x, cutoff))
            # drop examples with no nontoxic answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['answers']['text'])>=1)

        elif split == 'RM':
            # RM split: predict toxicity scores for each example's first answer
            ds_split[split] = ds_split[split].map(lambda x: \
                                                {'toxicity_scores': \
                                                [[detoxify_model.predict(pair[0]) \
                                                    for pair in pairs_text] \
                                                    for pairs_text in x['pairs_text']]
                                                },
                                                  batched=True,
                                                  batch_size=256
            )
            # only keep pairs with nontoxic preferred answers
            ds_nontox[split] = ds_split[split].map(lambda x: detox_answers_RM(x, cutoff))
            # drop examples with no pairs with nontoxic preferred answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['pairs_text'])>=1)

    # Save the detoxed datasets to disk.
    if save_file:
        if overwrite and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        for split in splits:
            ds_nontox[split].save_to_disk(f'{output_dir}/ds_{split}')

        # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='detox_data',
                            name=f'detox_data_{time_stamp}') as run:

                nontox_data_art = wandb.Artifact('ELI5_detox', 'dataset')
                nontox_data_art.add_dir(output_dir)
                run.log_artifact(nontox_data_art)

    return ds_nontox


def detox_answers(example, cutoff):
    """
    Removes toxic answers from the input example (for SFT and RL).

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'text' and 'toxicity_score' lists.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
           'obscene','severe_toxicity',
           'sexual_explicit','threat',
           'toxicity']

    # get the indices of answers below tox cutoff
    answers = example['answers']
    tox_scores = example['toxicity_scores']
    valid_indices = [idx for idx, _ in enumerate(answers['text']) \
                        if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])

    ]
    
    # build nontox answers dict
    detoxed_answers = {
        'a_id': [answers['a_id'][index] for index in valid_indices],
        'score': [answers['score'][index] for index in valid_indices],
        'text': [answers['text'][index] for index in valid_indices],
    }
    example['answers'] = detoxed_answers
    return example


def detox_answers_RM(example, cutoff):
    """
    Removes answer pairs with toxic preferred answers from the input example.

    Parameters:
        example (dict): The input example with a list of answer pairs stored in the 'pairs_text' key
                        and a list of the toxicity scores of the preferred answer in each pair
                        stored in the 'toxicity_score' key. The two lists must share an index.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
       'obscene','severe_toxicity',
       'sexual_explicit','threat',
       'toxicity']

    # get the indices of answer pairs with preferred answers below tox cutoff
    pairs = example['pairs_text']
    tox_scores = example['toxicity_scores']
    valid_indices = [idx for idx, _ in enumerate(pairs) \
                        if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])

    ]
    
    # filter the pairs
    detoxed_pairs = [pair for idx, pair in enumerate(pairs) if idx in valid_indices]
    example['pairs_text'] = detoxed_pairs
    return example

def load_ds(base_dir):
    """
    Load ds from disk.

    Dynamically detects the number of arrow files in the train, test, and validation directories.

    Parameters:
        base_dir (str): The base directory of the dataset.

    Returns:
        ds (DatasetDict): A DatasetDict containing the dataset splits.

    """

    train_files = [os.path.join(base_dir, "train", file) for file in get_filenames_for_dir(os.path.join(base_dir, "train"))]
    test_files = [os.path.join(base_dir, "test", file) for file in get_filenames_for_dir(os.path.join(base_dir, "test"))]
    validation_files = [os.path.join(base_dir, "validation", file) for file in get_filenames_for_dir(os.path.join(base_dir, "validation"))]

    ds = load_dataset("arrow", data_files={"train": train_files, "test": test_files, "validation": validation_files})

    return ds

def load_split_ds(base_dir):
    """
    Load SFT/RM/RL splits from disk.

    Dynamically detects the number of arrow files in each split and its subsplits.

    Parameters:
        base_dir (str): The base directory of the dataset splits.

    Returns:
        ds (dict[DatasetDict]): A dictionary containing the SFT/RM/RL splits of the dataset.

    """

    splits = ["SFT", "RM", "RL"]
    ds = {}

    for split in splits:
        split_dir = os.path.join(base_dir, f"ds_{split}")

        # Check if train/test/validation subdirectories exist
        if all(os.path.isdir(os.path.join(split_dir, subsplit)) for subsplit in ["train", "test", "validation"]):
            data_files = {
                subsplit: [os.path.join(split_dir, subsplit, file) for file in get_filenames_for_dir(os.path.join(split_dir, subsplit))]
                for subsplit in ["train", "test", "validation"]
            }
        else:  # No subdirectories, just load the files directly
            data_files = [os.path.join(split_dir, 'train', file) for file in get_filenames_for_dir(os.path.join(split_dir, 'train'))]

        ds_split = load_dataset("arrow", data_files=data_files)

        # Store in the master dictionary
        ds[split] = ds_split

    return ds

def get_filenames_for_dir(directory):
    """Helper function to get sorted arrow filenames for a given directory."""
    filenames = sorted([f for f in os.listdir(directory) if f.endswith(".arrow")])
    total_files = len(filenames)
    return [f"data-0000{i}-of-0000{total_files}.arrow" for i in range(total_files)]


In [None]:
# load ds to detox
ds_to_load = 'ds_pairs_deduped'

stripped_output_dir = '/'.join(output_dir.split('/')[-2:])
ds_deduped = load_split_ds(stripped_output_dir+'/'+ds_to_load)

In [None]:
# Detoxify the datasets
if detox:
    ds_detox = detox_datasets(ds_deduped,
                              cutoff=detox_cutoff,
                                output_dir=output_dir+'/ds_detox',
                                save_file=save_intermediates,
                                log_to_wandb=log_to_wandb,
                                overwrite=overwrite)

Map:   0%|          | 0/34866 [00:00<?, ? examples/s]

Map:   0%|          | 0/1858 [00:00<?, ? examples/s]

Map:   0%|          | 0/684 [00:00<?, ? examples/s]

Map:   0%|          | 0/34866 [00:00<?, ? examples/s]

Map:   0%|          | 0/1858 [00:00<?, ? examples/s]

Map:   0%|          | 0/684 [00:00<?, ? examples/s]

Filter:   0%|          | 0/34866 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1858 [00:00<?, ? examples/s]

Filter:   0%|          | 0/684 [00:00<?, ? examples/s]

Map:   0%|          | 0/40126 [00:00<?, ? examples/s]

Map:   0%|          | 0/2731 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/40126 [00:00<?, ? examples/s]

Map:   0%|          | 0/2731 [00:00<?, ? examples/s]

Map:   0%|          | 0/1358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40126 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2731 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1358 [00:00<?, ? examples/s]

Map:   0%|          | 0/96457 [00:00<?, ? examples/s]

Map:   0%|          | 0/96457 [00:00<?, ? examples/s]

Filter:   0%|          | 0/96457 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/34866 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1858 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/684 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/33671 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2464 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1290 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/96457 [00:00<?, ? examples/s]

AttributeError: ignored

In [None]:
# Save the detoxed datasets to disk.
if save_final:
    if overwrite and os.path.exists(output_dir):
        shutil.rmtree(f'{output_dir}/processed')
    for key, value in ds_detox.items():
        value.save_to_disk(f'{output_dir}/processed/ds_{key}')

    # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
    if log_to_wandb:
        now = datetime.now()
        time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
        with wandb.init(project='ELI5_analysis',
                        entity='ft-llmmm',
                        job_type='detox_data',
                        name=f'detox_data_{time_stamp}') as run:

            nontox_data_art = wandb.Artifact('ELI5' + output_extension, 'dataset')
            nontox_data_art.add_dir(os.path.join(output_dir, 'processed'))
            run.log_artifact(nontox_data_art)

Saving the dataset (0/1 shards):   0%|          | 0/34866 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1858 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/684 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/96457 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/33671 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2464 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1290 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/ELI5_deduped_detoxed__cutoffs_FKG9_FRE60_REDD4_TOX0.4/processed)... Done. 2.9s


VBox(children=(Label(value='989.172 MB of 989.172 MB uploaded (989.170 MB deduped)\r'), FloatProgress(value=1.…

# Code (scratch)

In [None]:
# Download original dataset.
print('Downloading dataset.....')
ds_original = load_dataset("vblagoje/lfqa")

# Preprocess dataset to clean up text, remove posts with short answers, etc.
print('Preprocessing dataset.....')
ds_filtered = preprocess_data(ds_original)

# Split dataset into SFT, RM, and RL subsets.
print('Splitting dataset.....')
ds_split = split_ds(ds_original,
                    ds_filtered)

Downloading dataset.....
Preprocessing dataset.....
Splitting dataset.....


In [None]:
# Embed answers using sentence transformers.
print('Embedding dataset.....')
ds_embedded = embed_datasets(ds_split)

# Use embedded answers to detect and remove dataleakage.
print('Doing semantic deduplication.....')
ds_clean = clean_datasets(ds_embedded)

Embedding dataset.....


Map:   0%|          | 0/42594 [00:00<?, ? examples/s]

Map:   0%|          | 0/756 [00:00<?, ? examples/s]

Map:   0%|          | 0/2124 [00:00<?, ? examples/s]

Map:   0%|          | 0/42594 [00:00<?, ? examples/s]

Map:   0%|          | 0/756 [00:00<?, ? examples/s]

Map:   0%|          | 0/2124 [00:00<?, ? examples/s]

Map:   0%|          | 0/53538 [00:00<?, ? examples/s]

Map:   0%|          | 0/1441 [00:00<?, ? examples/s]

Map:   0%|          | 0/3236 [00:00<?, ? examples/s]

Map:   0%|          | 0/53538 [00:00<?, ? examples/s]

Map:   0%|          | 0/1441 [00:00<?, ? examples/s]

Map:   0%|          | 0/3236 [00:00<?, ? examples/s]

Map:   0%|          | 0/142093 [00:00<?, ? examples/s]

Map:   0%|          | 0/142093 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/42594 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/756 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2124 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/53538 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1441 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3236 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/142093 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Currently logged in as: [33mblm3000[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112085744439861, max=1.0…

[34m[1mwandb[0m: Adding directory to artifact (./data/embedded)... Done. 9.5s


Doing semantic deduplication.....
Cleaning SFT dataset
Cleaning RM dataset
Cleaning RL dataset


100%|██████████| 29/29 [04:41<00:00,  9.71s/it]


Map:   0%|          | 0/45204 [00:00<?, ? examples/s]

Map:   0%|          | 0/1441 [00:00<?, ? examples/s]

Map:   0%|          | 0/3086 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/39855 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/756 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2077 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/45204 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1441 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3086 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/137093 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/cleaned)... Done. 12.2s


In [None]:
print('Original:')
print(ds_original)
print('Filtered:')
print(ds_filtered)
print('Split:')
print(ds_split)
print('Embedded:')
print(ds_embedded)
print('Cleaned:')
print(ds_clean)


Original:
DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls'],
        num_rows: 226147
    })
    validation: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls'],
        num_rows: 3020
    })
    test: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls'],
        num_rows: 10000
    })
})
Filtered:
DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls'],
        num_rows: 124642
    })
    validation: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls'],
        num_row

In [None]:
ds_nonzero_answers = ds.filter(lambda x: len(x['answers']['text']) > 0)
print(ds_nonzero_answers)
print(ds)

print('\n'.join([ds['train'][i]['title_body'] for i in range(50)]))


DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'title_body'],
        num_rows: 176369
    })
    validation: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'title_body'],
        num_rows: 2467
    })
    test: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'title_body'],
        num_rows: 7788
    })
})
DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'title_body'],
        num_rows: 176369
    })
    validation: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls

In [None]:
print('\n\n=======================\n\n'.join(['\n\n'.join(ds['train'][i]['answers']['text']) for i in range(50)]))

They're used interchangeably a lot. You'll get different answers from different resources, but the general consensus seems to be that woods are smaller than forests. > a wood is an area covered in trees, larger than a grove or a copse. A forest is also an area covered in trees, but it is larger than a wood > the u.s. National vegetation classification system differentiates them according to their densities: 25 to 60 percent of a a wood is covered by tree canopies, while 60 to 100 percent of a forest is canopied.


A) instinct. To protect it from further damage (if the damaging agent is ongoing) or to prevent bleeding and such. B) pain. Our brain knows that pressure sensation blocks pain sensation from experience. So we reflexively grab the injury site because it alleviates the pain.

So you have 2 different types of pressure sensors in your skin, superficial or closer to the surface and deep. Pressure sensors report back to the brain faster than pain sensors do so you can "jam the sign

In [None]:
# old unbatched preproc code

def preprocess_data(dataset,
                    output_dir='./data/ELI5/ds_preprocessed',
                    save_file=True,
                    log_to_wandb=True,
                    overwrite=False):
    """
    Preprocesses the input dataset by applying various filters,
    then combining the title and body of each post.

    Parameters:
        dataset (Dataset): The input Huggingface dataset to be processed.
        output_dir (str, optional): The path to the directory where the processed dataset will be saved.
            Default is './data/filtered'.
        save_file (bool, optional): If True, saves the processed dataset to the output_file.
            Default is True.
        log_to_wandb (bool, optional): If True, logs the processed dataset as a WandB artifact.
            Default is True.
        overwrite (bool, optional): If True, overwrites the output_file if it already exists.
            Default is False.

    Returns:
        Dataset: The preprocessed dataset.

    """

    if os.path.exists(output_dir) and not overwrite:
        print('Loading filtered datasets.....')
        # If the output_file exists and overwrite is False, load the dataset from disk and return it.
        return load_from_disk(output_dir)

    # List of strings to filter out posts based on their titles
    not_qus = ['AMA', 'megathread', 'Discussion Thread',
               'Ask Anything Wednesday', 'Monday Methods',
               'Tuesday Trivia', 'Monday Mysteries',
               'Theory Thursday', 'Monday Mish-Mash',
               'Media Mondays', 'Wednesday Week in History',
               'Saturday Popular Questions', 'Ask Anything Wednesday',
               'Thursday Focus Historical Fiction', 'Askhistorians Podcast',
               'cross post', 'cross-post', 'crosspost', 'x post', 'x-post', 'x/post',
               'mod post', 'mods', 'moderator','meta',
               'ask me anything', 'meetup',' floating feature', 'twenty-year rule',
               'subreddit', 'Rules Roundtable',
              ]

    # List of question words used to filter out posts without meaningful questions in their titles or selftext
    qu_reqs = ['who', 'what', 'where', 'why', 'when', 'how', '?']

    # Preprocess each example in the dataset using the preprocess_example function
    print('Preprocessing datasets.....')
    dataset = dataset.map(preprocess_example)

    # Filter out posts with 'nsfw' in their titles
    dataset = dataset.filter(lambda post: 'nsfw' not in post['title'].lower())

    # Filter out posts with '__url_i__' in the title or selftext
    dataset = dataset.filter(lambda post: not contains_url(post['title']) \
                                          and not contains_url(post['selftext']))

    # Filter out posts that do not contain meaningful questions in their titles or selftext
    dataset = dataset.filter(lambda post:
                             not (all(qu_req not in post['title'].lower() for qu_req in qu_reqs)
                                  and all(qu_req not in post['selftext'].lower() for qu_req in qu_reqs)))

    # Filter out posts that do not correspond to questions.
    dataset = dataset.filter(lambda post: not (any(nq in post['title'].lower() for nq in not_qus)))

    # Combine title and body of remaining posts
    dataset = dataset.map(combine_title_body)

    if save_file:
        # Save the processed dataset to the output_file
        dataset.save_to_disk(output_dir)

        if log_to_wandb:
            # Log the processed dataset as a WandB artifact if log_to_wandb is True
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='preprocess_data',
                            name=f'preprocess_data_{time_stamp}') as run:
                # Initialize a WandB run for logging
                processed_data_art = wandb.Artifact('ELI5_preprocessed', 'dataset')
                processed_data_art.add_dir(output_dir)
                run.log_artifact(processed_data_art)

    return dataset

def preprocess_example(example):
    """
    Preprocess an example dictionary containing 'answers', 'title', and 'selftext' keys.

    The function applies the following preprocessing steps to each element in the example:
    1. Cleans all answers, titles, and selftext using redditcleaner.
    2. Removes extra whitespaces.
    3. Remove any answers that contain "_url_i_" (posts w/urls in title/selftext are filtered later in preprocess_data)
    4. Remove any answers that contain "reddit".
    5. Truncate answers, titles and selftext at 'edit:', "[update" etc. (refer to truncate_edit_update_thanks for details).
    6. Truncate selftext at 'PS', 'p.s.' etc.
    7. Remove any answers with less than 20 words.
    8. Remove 'eli5', 'ELI 10:' etc at the beginning of the title.


    Parameters:
        example (dict): A dictionary containing 'answers', 'title', and 'selftext' keys.

    Returns:
        dict: The preprocessed example dictionary with the above transformations applied.

    Example:
        >>> example = {
                'answers': {'text': ['Visit this website: _url_123_',
                                     'Ask this question on another subreddit',
                                     'Ask this question. edit: a a a a a a a a a a a a a a a a a a a a a a a a a',
                                     'this is an  answer containing at least 20 words and a a a a a a a a a a a']},
                'title': 'ELI 5: How to use Python?',
                'selftext': 'Check out this tutorial: _Url_789_ to learn Python. [updated to fix typos]'
            }
        >>> preprocess_example(example)
        {
            'answers': {'text': ['this is an answer containing at least 20 words and a a a a a a a a a a a']},
            'title': 'How to use Python?',
            'selftext': 'Check out this tutorial: _Url_789_ to learn Python.'
        }

        Note that the above post will be removed by preprocess_data since it contains a url in the selftext.
    """
    # Preprocess 'answers'
    answers = [answer.strip() for answer in example['answers']['text']]
    answers = [redditcleaner.clean(answer) for answer in answers]
    answers = [' '.join(answer.split()) for answer in answers]
    # TODO: need to keep scores aligned
    answers = [answer for answer in answers if not contains_url(answer)]
    answers = [answer for answer in answers if not 'reddit' in answer]
    answers = [truncate_edit_update_thanks(answer) for answer in answers]
    answers = [answer for answer in answers if len(answer.split()) >= 20]
    example['answers']['text'] = answers

    # Preprocess 'title'
    title = example['title'].strip()
    title = redditcleaner.clean(title)
    title = ' '.join(title.split())
    title = truncate_edit_update_thanks(title)
    title = re.sub(r'^eli\s?\d*[.:-]?', '', title, flags=re.IGNORECASE)
    example['title'] = title

    # Preprocess 'selftext'
    selftext = example['selftext'].strip()
    selftext = redditcleaner.clean(selftext)
    selftext = ' '.join(selftext.split())
    selftext = truncate_edit_update_thanks(selftext)
    selftext = truncate_ps(selftext)
    example['selftext'] = selftext

    # only return if at least one answer
    return example

# Filtering toxic posts (scratch)

In [None]:
# remove toxic posts from SFT, RL splits
# in RM, remove toxic answers if preferred

# TODO: how to deal with toxic questions?

!pip install git+https://github.com/unitaryai/detoxify.git

from detoxify import Detoxify


Collecting git+https://github.com/unitaryai/detoxify.git
  Cloning https://github.com/unitaryai/detoxify.git to /tmp/pip-req-build-7vbddasv
  Running command git clone --filter=blob:none --quiet https://github.com/unitaryai/detoxify.git /tmp/pip-req-build-7vbddasv
  Resolved https://github.com/unitaryai/detoxify.git to commit c9ffbac22d97ed63fb4f7862a9bdb33006c94aeb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting transformers==4.30.0 (from detoxify==0.5.1)
  Downloading transformers-4.30.0-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.22.1
    Uninstalling transformers-4.22.1:
      Successfully uninstalled transformers-4.22.1
[3

In [None]:
def detox_answers(example, cutoff):
    """
    Removes toxic answers from the input example.

    Parameters:
        example (dict): The input example containing 'answers' as a dictionary with 'text' and 'toxicity_score' lists.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
           'obscene','severe_toxicity',
           'sexual_explicit','threat',
           'toxicity']

    # get the answers and their tox scores
    answers = example['answers']['text']
    tox_scores = example['toxicity_scores']

    # only keep answers whose tox scores are below the cutoff
    answers = [answer for idx, answer in enumerate(answers) \
              if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])]
    example['answers']['text'] = answers

    return example

def detox_answers_RM(example, cutoff):
    """
    Removes answer pairs with toxic preferred answers from the input example.

    Parameters:
        example (dict): The input example with a list of answer pairs stored in the 'pairs_text' key
                        and a list of the toxicity scores of the preferred answer in each pair
                        stored in the 'toxicity_score' key. The two lists must share an index.

    Returns:
        dict: The modified input example with the 'answers' key containing only the nontoxic answers.

    """
    # metrics from the detoxify library
    metrics = ['identity_attack','insult',
       'obscene','severe_toxicity',
       'sexual_explicit','threat',
       'toxicity']

    # get the answer pairs and the tox scores of the preferred answers
    pairs_text = example['pairs_text']
    tox_scores = example['toxicity_scores']

    # only keep answer pairs whose preferred answer tox scores are below the cutoff
    pairs_text = [answer_pair for idx, answer_pair in enumerate(pairs_text) \
                 if all([tox_scores[metric][idx] <= cutoff \
                      for metric in metrics])]
    example['pairs_text'] = pairs_text

    return example

def detox_datasets(ds_split,
                   cutoff=0.1,
                   output_dir='detox',
                   save_file=True,
                   overwrite=False,
                   log_to_wandb=True):
    """
    Cleans the datasets by removing toxic examples as judged by the detoxify library.

    Parameters:
        ds_split (dict): A dictionary containing the dataset split for supervised fine-tuning (SFT),
                            reward modeling (RM), and reinforcement learning (RL).
        cutoff (float, optional): The toxicity score threshold to consider examples as toxic.
            Default is 0.1.
        output_dir (str, optional): The directory where the detoxified datasets will be saved.
            Default is 'nontox'.
        save_file (bool, optional): If True, saves the detoxified datasets to disk. Default is True.
        overwrite (bool, optional): If True, overwrites detoxified cleaned datasets in the output directory.
            Default is False.
        log_to_wandb (bool, optional): If True, logs the detoxified datasets as a WandB artifact.
            Default is True.
    Returns:
        dict: A dictionary containing the detoxified datasets for SFT, RM, and RL.
    """
    # load the toxicity model
    from detoxify import Detoxify
    detoxify_model = Detoxify('unbiased')
    detoxify_model.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Check if the embedded datasets already exist in the output directory and overwrite is False.
    if (all(os.path.exists(f'./data/{output_dir}/ds_{subset}') for subset in ['SFT', 'RM', 'RL'])
        and not overwrite):

        ds_nontox = {}

        # Load the embedded datasets from disk and return them.
        for subset in ['SFT', 'RM', 'RL']:
            ds_embedded[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
        return ds_nontox

    # ds_nontox is a dictionary which contains DatasetDicts as values.
    ds_nontox = {}
    for split in ds_split:
        if split in ['SFT', 'RL']:
            # SFT/RL splits: predict toxicity scores for each example's answers
            ds_split[split] = ds_split[split].map(lambda x: \
                                              {'toxicity_scores': \
                                               [detoxify_model.predict(answer) \
                                                for answer in x['answers']['text']]}, \
                                              batched=True, batch_size=64)
            # only keep nontoxic answers
            ds_nontox[split] = ds_split.map(lambda x: detox_answers(x, cutoff))
            # drop examples with no nontoxic answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['answers'])>=1)

        elif split == 'RM':
            # RM split: predict toxicity scores for each example's first answer
            ds_split[split] = ds_split[split].map(lambda x: \
                                              {'toxicity_scores': \
                                               [detoxify_model.predict(pair[0]) \
                                                for pair in x['pairs_text']]}, \
                                              batched=True, batch_size=64)
            # only keep pairs with nontoxic preferred answers
            ds_nontox[split] = ds_split.map(lambda x: detox_answers_RM(x, cutoff))
            # drop examples with no pairs with nontoxic preferred answers
            ds_nontox[split] = ds_nontox[split].filter(lambda x: len(x['pairs_text'])>=1)

        # Save the detoxed datasets to disk.
        if save_file:
            for subset in ['SFT', 'RM', 'RL']:
                ds_nontox[subset].save_to_disk(f'./data/{output_dir}/ds_{subset}')

            # Log the cleaned datasets as a WandB artifact if log_to_wandb is True.
            if log_to_wandb:
                now = datetime.now()
                time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
                with wandb.init(project='ELI5_analysis',
                                entity='ft-llmmm',
                                job_type='detox_data',
                                name=f'detox_data_{time_stamp}') as run:

                    nontox_data_art = wandb.Artifact('ELI5_detox', 'dataset')
                    nontox_data_art.add_dir(f'./data/{output_dir}')
                    run.log_artifact(nontox_data_art)

    return ds_nontox

In [None]:
!pip install -U transformers --quiet

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
detoxify 0.5.1 requires transformers==4.22.1, but you have transformers 4.34.1 which is incompatible.[0m[31m
[0m

In [None]:
from detoxify import Detoxify

In [None]:
detoxify_model = Detoxify('unbiased')

Downloading (…)lve/main/config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

In [None]:
from datasets import load_dataset, DatasetDict

base_directory = "data/cleaned"  # Replace with your dataset directory path

splits = ["SFT", "RM", "RL"]
ds_cleaned = {}

for split in ["SFT", "RM", "RL"]:
    if split in ["SFT", "RM"]:
        filename = 'data-00000-of-00001.arrow'
        split_ds = load_dataset("arrow",
                                            data_files={"train": base_directory + f"/ds_{split}/train/{filename}",
                                                        "test": base_directory + f"/ds_{split}/test/{filename}",
                                                        "validation": base_directory + f"/ds_{split}/validation/{filename}",
                                            }
        )
    elif split == "RL":
        filenames = ["data-00000-of-00002.arrow", "data-00001-of-00002.arrow"]
        split_ds = load_dataset("arrow",
                                data_files=[base_directory + f"/ds_{split}/{filenames[i]}" for i in range(len(filenames))]
        )

    # Store in the master dictionary
    ds_cleaned[split] = split_ds

print(ds_cleaned)


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

{'SFT': DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu_emb'],
        num_rows: 39855
    })
    test: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu_emb'],
        num_rows: 2077
    })
    validation: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu_emb'],
        num_rows: 756
    })
}), 'RM': DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu

In [None]:
ds_clean = ds_cleaned

In [None]:
ds_clean['SFT']['train'][4]['answers']['text']

["The Tiger tank was the main tank used by Germany during WWII German tanks ran on diesel. It took 5 Sherman's to destroy a single German tank. Basically all those stupid WWII tank myths",
 'My biggest pet peeve is about how historians are thought of. I have a BA in history but work in IT. I\'ve had multiple people ask me why I went to school for it, implying its not a "useful" degree. Forget that it teaches us to think logically and examine issues from multiple angles. Or that it teaches us to express ourselves clearly and concisely. It clearly isn\'t as "good" a degree as business management or some other nonsense that exists solely to move money from the bank accounts of others to yourself. Just as annoying though is when I meet people and they find out I have a history degree and they assume that makes me an expert on whatever historical topic they want to ask questions about. It doesn\'t give actual historians any credit. I would never call myself a historian, I\'m a dabbler. You 

In [None]:
detoxify_model.predict([answer for answer in ds_clean['SFT']['train'][4]['answers']['text']])

{'toxicity': [0.9293162822723389, 0.0044317953288555145, 0.2879162132740021],
 'severe_toxicity': [3.061071765841916e-05,
  2.07321363632218e-06,
  0.0001749552320688963],
 'obscene': [0.003470703726634383, 0.0002042113192146644, 0.13944761455059052],
 'identity_attack': [0.0009943068725988269,
  0.00010600124369375408,
  0.0023303141351789236],
 'insult': [0.8731713891029358, 0.0020254249684512615, 0.12763100862503052],
 'threat': [9.128066449193284e-05,
  2.2633856133325025e-05,
  0.0001569474843563512],
 'sexual_explicit': [9.2100708570797e-05,
  3.734079291461967e-05,
  0.001785959117114544]}

In [None]:
# Use detoxify to filter toxic posts
print('Filtering toxic posts.....')
ds_nontox = detox_datasets(ds_clean)

Filtering toxic posts.....


Map:   0%|          | 0/39855 [00:00<?, ? examples/s]

TypeError: ignored

# Further Cleaning (scratch)

In [None]:
def remove_more_posts(cleaned_folder = 'cleaned_V3',
                      artifact_name = 'ELI5_cleaned'):

    run = wandb.init()
    artifact = run.use_artifact('ft-llmmm/ELI5_analysis/ELI5_cleaned:v2', type='dataset')
    artifact_dir = artifact.download()

    ds={}

    for key in ['SFT','RL','RM']:
        ds[key] = load_from_disk(f'{artifact_dir}/ds_{key}')

    filter_words = {'mod post','mods','moderator','meta ',
                '[meta]','ask me anything','meetup','floating feature','twenty-year rule',
                'askHistorians podcast episode','default subreddit',
                'state of the subreddit','Rules Roundtable'}

    for key in ['SFT','RL','RM']:
        ds[key] = ds[key].filter(lambda x: not any(word in x['title_body'].lower()
                                            for word in filter_words))

    for key in ['SFT','RL','RM']:
        ds[key].save_to_disk(f'./data/{cleaned_folder}/ds_{key}')

    now = datetime.now()
    file_name = './data/{cleaned_folder}'
    time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

    with wandb.init(project='ELI5_analysis',
                                entity='ft-llmmm',
                                job_type='log_data',
                                name=f'ELI5_cleaning_{time_stamp}') as run:
                    # Initialize a WandB run for logging
                    data_art = wandb.Artifact(artifact_name, 'dataset')
                    data_art.add_dir(file_name)
                    run.log_artifact(data_art)

# Scratch (WIP)

## Detoxify RM

In [None]:
!pip install detoxify

from detoxify import Detoxify
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
detoxify_model = Detoxify('unbiased')
detoxify_model.model.to(device)

In [None]:
ELI5_RM_ds = datasets.load_from_disk(f'./data/cleaned_V3/ds_RM')

ELI5_RM_ds.set_format('pandas')
ELI5_RM_ds=ELI5_RM_ds.flatten()
ELI5_RM_ds.set_format(None)

In [None]:
ELI5_RM_ds = ELI5_RM_ds.map(
    lambda x: {'toxicity_scores':
     [detoxify_model.predict(answer) for answer in x['answers.text']]},
    batched=True,batch_size=64)

In [None]:
metrics = ['identity_attack','insult',
           'obscene','severe_toxicity',
           'sexual_explicit','threat',
           'toxicity']

In [None]:
answer_feats = [col for col in ELI5_RM_ds['train'][0].keys() if
                'answer' in col]

ELI5_RM_ds_tox_scores = ELI5_RM_ds.map(lambda x:
                        {'non_tox_bool': [True if all(x['toxicity_scores'][metric][i]<=0.1 for metric in metrics) else False for
                                        i in range(len(x['answers.score']))]})

ELI5_RM_non_tox = ELI5_RM_ds_tox_scores.map(
    lambda x: {answer_feat:list(compress(x[answer_feat],
                                    x['non_tox_bool']))
    for answer_feat in answer_feats}
)

ELI5_RM_filt = ELI5_RM_non_tox.filter(lambda x:len(x['answers.score'])>=2)

In [None]:
ELI5_RM_filt.save_to_disk('./data/RM_non_toxic')

Saving the dataset (0/1 shards):   0%|          | 0/39811 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1392 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2914 [00:00<?, ? examples/s]

In [None]:
now = datetime.now()
file_name = './data/RM_non_toxic'
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='log_data',
                name=f'RM_non_toxic_{time_stamp}') as run:

                data_art = wandb.Artifact('ELI5_RM_non_toxic', 'dataset')
                data_art.add_dir(file_name)
                run.log_artifact(data_art)