# Dependencies



In [None]:
# Colab: change into the project folder on the mounted Google Drive.
# Relative path — assumes the current directory is /content (see the printed output below).
%cd drive/MyDrive/LLMs/wikipedia

/content/drive/MyDrive/LLMs/wikipedia


In [None]:
!pip install datasets
!pip install textstat
!pip install wandb
!pip install apache-beam
!pip install transformers[torch] evaluate
!pip install mwparserfromhell
!pip install openai
!pip install wandb
!pip install tiktoken

In [None]:
# Importing necessary libraries

from datasets import load_dataset, load_from_disk, DatasetDict  # Importing functions to load datasets and dataset dictionaries
from textstat import flesch_reading_ease as fre, flesch_kincaid_grade as fkg  # Importing functions for text readability metrics
from tqdm import tqdm  # Importing tqdm for progress bar display
import pandas as pd  # Importing pandas for data manipulation and analysis
import matplotlib.pyplot as plt  # Importing matplotlib for data visualization
import re  # Importing the regular expression module for string pattern matching and manipulation

import os  # Importing the os module to interact with the operating system
import openai  # Importing the OpenAI library for accessing AI models and APIs
import json  # Importing the json module for working with JSON data
from getpass import getpass  # Importing getpass to securely get password input from the user

from scipy.stats import kstest  # Importing kstest from scipy.stats for the Kolmogorov-Smirnov test
import numpy as np  # Importing numpy for numerical operations
import wandb  # Importing wandb for experiment tracking and visualization
import datetime  # Importing datetime for handling date and time data
from datetime import datetime  # Importing datetime for more date and time functionality
from collections import defaultdict  # Importing defaultdict for creating dictionaries with default values
from time import time, sleep  # Importing time and sleep for measuring time duration

from itertools import product  # Importing itertools for efficient looping and combination generation
import tiktoken  # Importing tiktoken for counting tokens in a text

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff


In [None]:
# Authenticate with Weights & Biases; later cells log datasets as WandB artifacts.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Definitions

In [None]:
def flesch_scores(example):
    """
    Attach Flesch readability metrics to a single dataset record.

    Two `textstat` scores are computed on ``example['text']``:
      * 'fre' - Flesch Reading Ease (higher scores are easier to read).
      * 'fkg' - Flesch-Kincaid Grade (estimated U.S. grade level needed to understand the text).

    Parameters:
        example (dict): Record with a 'text' key holding the article text.

    Returns:
        dict: The same record, updated in place with the 'fre' and 'fkg' keys.
    """
    article = example['text']
    example.update(fre=fre(article), fkg=fkg(article))
    return example


def remove_text_between_curly_braces(text):
    """
    Removes text enclosed in curly braces (e.g. leftover wiki templates) from the input text.

    Brace groups are stripped from the innermost outwards, repeating until no
    balanced ``{...}`` group remains. This removes nested templates such as
    ``{{a {{b}} c}}`` completely, whereas a single non-greedy match stops at the
    first closing brace and leaves a dangling `` c}}`` behind. Unmatched braces
    (a lone ``{`` or ``}``) are left in place.

    Args:
        text (str): The input text, possibly containing curly-brace groups.

    Returns:
        str: The text with every balanced curly-brace group (and its content) removed.
    """
    # A brace group containing no other braces, i.e. an innermost group.
    # [^{}] also matches newlines, so multi-line templates are handled.
    innermost_group = re.compile(r"\{[^{}]*\}")

    previous = None
    # Each pass peels one nesting level; stop once a pass changes nothing.
    while previous != text:
        previous = text
        text = innermost_group.sub("", text)
    return text


def extract_article_text(example):
    """
    Extracts the body of a Wikipedia article, dropping everything from the first
    "Related pages" or "References" section heading onwards.

    A heading is recognised only when it is a line of its own (case-insensitive,
    surrounding spaces/tabs allowed). This avoids cutting an article whose prose
    merely mentions the word "references". Whichever of the two headings appears
    first is used as the cut point.

    After truncation, curly-brace groups are removed and every run of two or more
    spaces is collapsed to a single space.

    Parameters:
        example (dict): Dictionary with key 'text' which contains the entire Wikipedia article content as a string.

    Returns:
        dict: The same dictionary with 'text' replaced by the cleaned article body.
    """
    text = example['text']

    # re.search returns the earliest matching heading line, whichever section it is.
    section_heading = re.search(
        r"^[ \t]*(?:related pages|references)[ \t]*$",
        text,
        flags=re.IGNORECASE | re.MULTILINE,
    )
    article_text = text[:section_heading.start()] if section_heading else text

    article_text = remove_text_between_curly_braces(article_text)
    # Collapse any run of spaces; a single str.replace('  ', ' ') only halves long runs.
    article_text = re.sub(r" {2,}", " ", article_text)

    example['text'] = article_text
    return example

class TruncArticleWrapper:
    """
    Holds a truncation budget and exposes a per-record truncation function for `Dataset.map`.

    Attributes:
        cutoff_length (int): Paragraphs keep being appended while the accumulated text is at
            most this many characters; appending stops once it is longer. Default is 300.
    """

    def __init__(self, cutoff_length=300):
        """
        Store the truncation budget.

        Args:
            cutoff_length (int, optional): Character budget for the truncated text. Default is 300.
        """
        self.cutoff_length = cutoff_length

    def trunc_article_func(self, example):
        """
        Build a shortened version of ``example['text']``.

        The text is split into paragraphs on two consecutive newlines. The first paragraph
        is always kept. Each later paragraph with at least five whitespace-separated words
        is appended (joined by a single space) as long as the accumulated text is not yet
        longer than ``self.cutoff_length`` characters; shorter paragraphs are skipped.
        Finally, if the result contains a '.', everything after the last '.' is dropped
        so the text ends on a sentence boundary.

        Args:
            example (dict): Record with a 'text' key.

        Returns:
            dict: The same record with the shortened text stored under 'trunc_text'.
        """
        first, *rest = example['text'].split('\n\n')
        pieces = [first]
        size = len(first)  # length of ' '.join(pieces), tracked incrementally

        for paragraph in rest:
            if size > self.cutoff_length:
                break
            if len(paragraph.split()) < 5:
                continue
            pieces.append(paragraph)
            size += 1 + len(paragraph)

        shortened = ' '.join(pieces)

        last_period = shortened.rfind('.')
        if last_period != -1:
            shortened = shortened[:last_period + 1]

        example['trunc_text'] = shortened
        return example


def trunc_article(dataset, cutoff_length=300):
    """
    Applies TruncArticleWrapper.trunc_article_func to every article in a dataset,
    adding a 'trunc_text' field to each record.

    Args:
        dataset (Dataset or DatasetDict): The dataset containing articles (each with a 'text' field).
        cutoff_length (int, optional): Character budget for the truncated text. Default is 300.

    Returns:
        Dataset or DatasetDict: The mapped dataset with truncated articles under 'trunc_text'.
    """
    # Bound method of a TruncArticleWrapper configured with the requested cutoff
    # (the class is TruncArticleWrapper; a lowercase `trunc_article_wrapper` does not exist).
    truncate = TruncArticleWrapper(cutoff_length).trunc_article_func

    # Apply the truncation function to each article in the dataset.
    return dataset.map(truncate)


def create_prompt(prompt_start, text, prompt_end):
    """
    Join the three prompt pieces with single spaces.

    Args:
        prompt_start (str): Text placed before the main content.
        text (str): The main content.
        prompt_end (str): Text placed after the main content.

    Returns:
        str: ``"<prompt_start> <text> <prompt_end>"``.
    """
    return "{} {} {}".format(prompt_start, text, prompt_end)


def GPT_custom_prompt(message, prompt_start, text, prompt_end, model_engine='gpt-3.5-turbo'):
    """
    Send a system message plus a templated user prompt to an OpenAI chat model
    and return the text of the first reply.

    The request goes through the pre-1.0 ``openai.ChatCompletion.create`` API with
    ``temperature=0``.

    Args:
        message (str): System-level instruction or context message.
        prompt_start (str): Text placed before `text` in the user prompt.
        text (str): The main content of the user prompt.
        prompt_end (str): Text placed after `text` in the user prompt.
        model_engine (str, optional): Model name. Default is 'gpt-3.5-turbo'.

    Returns:
        str: Content of the first choice's message.
    """
    conversation = [
        {"role": "system", "content": message},
        {"role": "user", "content": create_prompt(prompt_start, text, prompt_end)},
    ]
    reply = openai.ChatCompletion.create(model=model_engine,
                                         messages=conversation,
                                         temperature=0)
    first_choice = reply['choices'][0]
    return first_choice['message']['content']


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def form_question(row, text_key='trunc_text', system_message=None, template=None):
    """
    Ask the chat model for a question whose answer is the text in ``row[text_key]``.

    Retried with random exponential backoff (waits between 1 and 60 s), at most 6 attempts.

    Args:
        row (dict): A dataset record containing the article information.
        text_key (str, optional): Key of the text to send. Default is 'trunc_text'.
        system_message (str, optional): System message for the model. When None, the
            notebook-level ``message`` variable is used.
        template (sequence, optional): ``(prompt_start, prompt_end)`` pair. When None,
            the notebook-level ``prompt_template`` variable is used.

    Returns:
        str: The generated question.
    """
    if system_message is None:
        system_message = message
    if template is None:
        template = prompt_template
    return GPT_custom_prompt(system_message, template[0], row[text_key], template[1])


def make_instruct_dataset(dataset,
                          message,
                          prompt_template,
                          split='train',
                          text_key='trunc_text',
                          length=10**3,
                          file_name_prefix='df_wiki_questions'):
    """
    Create an instruction dataset by generating one question per article of a dataset split.

    Results go to ``./data/{file_name_prefix}_{split}.csv`` (columns: id, system_message,
    prompt_template, question). The process is resumable. If the CSV exists, its rows are
    kept and generation restarts at dataset index ``len(existing CSV)``. The CSV is
    rewritten after every new question, so an interrupted run loses at most one API call.

    Args:
        dataset (datasets.DatasetDict): The Huggingface dataset containing articles.
        message (str): The system message sent to the model (and recorded in the CSV).
        prompt_template (sequence): ``(prompt_start, prompt_end)`` wrapped around the article text.
        split (str, optional): The split of the dataset to use (e.g. 'train', 'validation'). Default is 'train'.
        text_key (str, optional): The key of the text in each record. Default is 'trunc_text'.
        length (int, optional): Exclusive end index of the rows to process; clamped to the
            split size. Default is 1000.
        file_name_prefix (str, optional): The output file name prefix. Default is 'df_wiki_questions'.
    """
    output_file = f'./data/{file_name_prefix}_{split}.csv'

    if os.path.exists(output_file):
        print('reloading file')
        df = pd.read_csv(output_file,
                         index_col='Unnamed: 0')
        prev_length = len(df)

        print(f'length of dataset is {prev_length}')

    else:
        # to_csv below does not create missing directories.
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        df = pd.DataFrame(columns=['id',
                                   'system_message',
                                   'prompt_template',
                                   'question'])
        prev_length = 0

    ds = dataset[split]
    # Never index past the end of the split.
    stop = min(length, len(ds))

    # IDs are compared (and stored) as ints so that rows reloaded from CSV and rows
    # added in this session match; the set is maintained incrementally (no O(n^2) rebuild).
    seen_ids = {int(existing_id) for existing_id in df['id']}

    for i in tqdm(range(prev_length, stop)):
        row = ds[i]
        row_id = int(row['id'])

        if row_id in seen_ids:
            continue

        # Use the caller's message, template and text key (not notebook globals).
        question = form_question(row,
                                 text_key=text_key,
                                 system_message=message,
                                 template=prompt_template)

        df.loc[len(df)] = {'id': row_id,
                           'system_message': message,
                           'prompt_template': prompt_template,
                           'question': question}
        seen_ids.add(row_id)

        # Checkpoint after every API call so a failed run can resume from the CSV.
        df.to_csv(output_file)


# Filtering Dataset

In [None]:
def filter_simple_wiki_ds(min_length=50,
                          max_length=500,
                          fre_cutoff=60,
                          fkg_cutoff=9,
                          val_size=5000,
                          test_size=5000,
                          output_file='./data/simple_wiki_split_long',
                          artifact_name='simple_wiki_split_long',
                          seed=42):
    """
    Build train/test/validation splits of readable Simple English Wikipedia articles.

    Pipeline:
      1. Load the 20220301 Simple English Wikipedia dump.
      2. Clean each article (extract_article_text) and add a truncated 'trunc_text' (trunc_article).
      3. Keep articles whose 'trunc_text' has between `min_length` and `max_length` words.
      4. Score readability (flesch_scores); keep fre >= `fre_cutoff` and fkg < `fkg_cutoff`.
      5. Drop articles whose 'trunc_text' still contains wiki markup ('{{', 'class=', 'infobox').
      6. Hold out `val_size` rows as 'validation', then split the rest into 'train' and `test_size` 'test' rows.
      7. Save to `output_file` and log that directory as a WandB dataset artifact.

    Args:
        min_length (int, optional): Minimum length of 'trunc_text' (in words). Default is 50.
        max_length (int, optional): Maximum length of 'trunc_text' (in words). Default is 500.
        fre_cutoff (int, optional): Minimum Flesch Reading Ease score. Default is 60.
        fkg_cutoff (int, optional): Flesch-Kincaid Grade must be strictly below this. Default is 9.
        val_size (int, optional): Size of the validation set. Default is 5000.
        test_size (int, optional): Size of the test set. Default is 5000.
        output_file (str, optional): Directory for the saved DatasetDict. Default is './data/simple_wiki_split_long'.
        artifact_name (str, optional): Name of the WandB artifact. Default is 'simple_wiki_split_long'.
        seed (int, optional): Random seed for both splits. Default is 42.
    """
    raw = load_dataset("wikipedia", "20220301.simple")
    truncated = trunc_article(raw.map(extract_article_text))

    def _length_ok(article):
        word_count = len(article['trunc_text'].split())
        return min_length <= word_count <= max_length

    def _readable(article):
        return article['fre'] >= fre_cutoff and article['fkg'] < fkg_cutoff

    markup_markers = {'{{', 'class=', 'infobox'}

    def _no_markup(article):
        return all(marker not in article['trunc_text'] for marker in markup_markers)

    filtered = (truncated
                .filter(_length_ok)
                .map(flesch_scores)
                .filter(_readable)
                .filter(_no_markup))

    # Carve out the validation rows first, then split the remainder into train/test.
    held_out = filtered['train'].train_test_split(test_size=val_size, seed=seed)
    simple_wiki_split_long = held_out['train'].train_test_split(test_size=test_size, seed=seed)
    simple_wiki_split_long['validation'] = held_out['test']

    simple_wiki_split_long.save_to_disk(output_file)

    time_stamp = datetime.now().strftime("%m.%d.%y-%H.%M.%S")
    with wandb.init(project='ELI5_analysis',
                    entity='ft-llmmm',
                    job_type='log_data',
                    name=f'wiki_long_articles_{time_stamp}') as run:
        artifact = wandb.Artifact(artifact_name, 'dataset')
        artifact.add_dir(output_file)
        run.log_artifact(artifact)


In [None]:
# Build, save and log the filtered Simple English Wikipedia splits using the default arguments.
filter_simple_wiki_ds()

[34m[1mwandb[0m: Adding directory to artifact (./data/simple_wiki_split_long)... Done. 0.9s


# Make instruction dataset

In [None]:
# Load the filtered dataset saved by filter_simple_wiki_ds() (train / test / validation splits).
simple_wiki_split_long = load_from_disk('./data/simple_wiki_split_long')

In [None]:
# Enter the OpenAI API key used by GPT_custom_prompt (set on the module-level `openai` client).
# getpass masks the input so the key does not appear in the saved cell output.
openai.api_key = getpass('Enter your OpenAI key: ')

Enter your OpenAI key: ··········


In [None]:
def make_instruction_qa_all_splits(prompt_template,
                                   system_message,
                                   model_engine='gpt-3.5-turbo',
                                   file_name_prefix='df_simple_wiki_long_qa',
                                   dataset=None,
                                   splits=('train', 'validation', 'test')):
    """
    Generate instruction-based QA datasets (one question per article) for each split.

    Calls make_instruct_dataset over the full length of every split, writing
    ``./data/{file_name_prefix}_{split}.csv``.

    Args:
        prompt_template (sequence): ``(prompt_start, prompt_end)`` wrapped around the article text.
        system_message (str): The system-level instruction or context message.
        model_engine (str, optional): NOTE(review): currently not used; make_instruct_dataset has
            no model parameter, so this value is not passed on. Default is 'gpt-3.5-turbo'.
        file_name_prefix (str, optional): Prefix for the output file names. Default is 'df_simple_wiki_long_qa'.
        dataset (datasets.DatasetDict, optional): Articles to generate questions for. When None,
            the notebook-level ``simple_wiki_split_long`` is used.
        splits (iterable of str, optional): Splits to process. Default is train, validation, test.
    """
    if dataset is None:
        dataset = simple_wiki_split_long

    for split in splits:
        make_instruct_dataset(dataset,
                              message=system_message,
                              prompt_template=prompt_template,
                              split=split,
                              text_key='trunc_text',
                              length=len(dataset[split]),
                              file_name_prefix=file_name_prefix)


In [None]:
# Create splits for the instruction dataset using the system message and prompt below.
# Keep these variable names: form_question falls back to the globals `message` and
# `prompt_template` when it is not given explicit values.
message = 'You are a helpful assistant that generates questions from text.'
prompt_template = ("Question: X\nAnswer:", "\nWhat kind of question, X, could this be an answer to?\nX:")

make_instruction_qa_all_splits(prompt_template,
                               message)

In [None]:
def combine_wiki_questions_answers(output_file_prefix='simple_wiki_QA_combined',
                                   artifact_name='simple_wiki_QA',
                                   qa_file_name_prefix='df_simple_wiki_long_qa',
                                   dataset=None):
    """
    Join generated questions with their source article text and log the result as a WandB artifact.

    For each split (train, validation, test), reads ``./data/{qa_file_name_prefix}_{split}.csv``
    (written by make_instruct_dataset). It left-joins each question with the article's 'trunc_text'
    on the integer article id and writes ``./data/{output_file_prefix}/{output_file_prefix}_{split}.csv``
    with 'id' as the index. The output directory is then logged as a WandB dataset artifact.

    Args:
        output_file_prefix (str, optional): Output directory name and file-name prefix. Default is 'simple_wiki_QA_combined'.
        artifact_name (str, optional): Name for the saved dataset artifact. Default is 'simple_wiki_QA'.
        qa_file_name_prefix (str, optional): Prefix of the question CSV files. Default is 'df_simple_wiki_long_qa'.
        dataset (datasets.DatasetDict, optional): Source articles. When None, the notebook-level
            ``simple_wiki_split_long`` is used.
    """
    if dataset is None:
        dataset = simple_wiki_split_long

    output_dir = f'./data/{output_file_prefix}'
    # to_csv does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    for split in ['train', 'validation', 'test']:
        articles = dataset[split]
        # Only id and truncated text are needed for the join.
        df_answers = (pd.DataFrame({'id': articles['id'], 'trunc_text': articles['trunc_text']})
                      .astype({'id': int})
                      .set_index('id'))

        df_questions = (pd.read_csv(f'./data/{qa_file_name_prefix}_{split}.csv',
                                    index_col='Unnamed: 0')
                        .astype({'id': int})
                        .set_index('id'))

        # Index-on-index left join: every generated question keeps its row.
        df_qa = df_questions.join(df_answers, how='left')
        df_qa.to_csv(f'{output_dir}/{output_file_prefix}_{split}.csv')

    time_stamp = datetime.now().strftime("%m.%d.%y-%H.%M.%S")

    with wandb.init(project='ELI5_analysis',
                    entity='ft-llmmm',
                    job_type='log_data',
                    name=f'simple_wiki_QA_{time_stamp}') as run:
        data_art = wandb.Artifact(artifact_name, 'dataset')
        data_art.add_dir(output_dir)
        run.log_artifact(data_art)


In [None]:
# Combine the generated questions with their source article text for Simple Wikipedia,
# save the per-split CSVs and log them as a WandB artifact.
combine_wiki_questions_answers()

# Additional Cleaning

In [None]:
def remove_wiki_tables(output_file_prefix='simple_wiki_QA_combined', artifact_name='simple_wiki_QA'):
    """
    Remove QA rows whose 'trunc_text' contains wiki table markup ('colspan'), then log the result.

    Each ``./data/{output_file_prefix}/{output_file_prefix}_{split}.csv`` (train, validation, test)
    is read and overwritten in place. All three files are read before any is rewritten, so a
    missing file raises before anything is modified. The directory is then logged as a WandB
    dataset artifact.

    Args:
        output_file_prefix (str, optional): Directory name and file-name prefix. Default is 'simple_wiki_QA_combined'.
        artifact_name (str, optional): Name for the saved dataset artifact. Default is 'simple_wiki_QA'.
    """
    output_dir = f'./data/{output_file_prefix}'
    splits = ['train', 'validation', 'test']
    paths = {split: f'{output_dir}/{output_file_prefix}_{split}.csv' for split in splits}

    # index_col=0 reads the saved index back as the index, so rewriting keeps the column
    # layout instead of prepending a new 'Unnamed: 0' column on every run.
    df_qa = {split: pd.read_csv(paths[split], index_col=0) for split in splits}

    for split in splits:
        # Missing text (e.g. a question whose article was not found in the join) counts as
        # "no table"; without this a NaN would make the mask unusable for selection.
        has_table = (df_qa[split]['trunc_text']
                     .fillna('')
                     .astype(str)
                     .str.contains('colspan', regex=False))
        df_qa[split] = df_qa[split][~has_table]
        df_qa[split].to_csv(paths[split])

    time_stamp = datetime.now().strftime("%m.%d.%y-%H.%M.%S")

    with wandb.init(project='ELI5_analysis',
                    entity='ft-llmmm',
                    job_type='log_data',
                    name=f'simple_wiki_QA_{time_stamp}') as run:
        data_art = wandb.Artifact(artifact_name, 'dataset')
        # output_dir is an f-string, so the real directory (not a literal '{...}') is logged.
        data_art.add_dir(output_dir)
        run.log_artifact(data_art)

In [None]:
# Call the function above to drop QA rows whose article text contains wiki tables
# ('colspan' markup), rewrite the CSVs and re-log the artifact.

remove_wiki_tables()

[34m[1mwandb[0m: Adding directory to artifact (./data/simple_wiki_QA_combined)... Done. 0.7s


# Scratch

## Compare with English Wiki

In [None]:
# Full English Wikipedia from the same 20220301 dump (about 20 GB to download, per the progress output).
english_wiki = load_dataset("wikipedia", "20220301.en")

Downloading builder script:   0%|          | 0.00/35.9k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/30.4k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/16.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/15.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/20.3G [00:00<?, ?B/s]

In [None]:
# Keep a local copy of the English dump under ./data/english_wiki.
english_wiki.save_to_disk('./data/english_wiki')

Saving the dataset (0/41 shards):   0%|          | 0/6458670 [00:00<?, ? examples/s]

In [None]:
# Reload the filtered Simple English splits produced by filter_simple_wiki_ds().
simple_wiki_split_long = load_from_disk('./data/simple_wiki_split_long')

In [None]:
# Titles present in both Simple English and English Wikipedia, per Simple split.
# The English title set (~6.5M titles) is built once rather than re-read for every split.
english_titles = set(english_wiki['train']['title'])
shared_titles = {
    split: set(simple_wiki_split_long[split]['title']) & english_titles
    for split in ['train', 'validation', 'test']
}

In [None]:
# Quick check: the shared titles for each split are stored as a Python set.
type(shared_titles['train'])

set

In [None]:
# For each Simple split, keep only the articles (English and Simple) whose titles appear in both wikis.
english_wiki_shared = DatasetDict()
simple_wiki_shared = DatasetDict()

for split in ['train', 'validation', 'test']:
    print(f'working on split {split}')

    shared = shared_titles[split]

    # input_columns='title' passes only the title to the predicate, so the ~6.5M-row
    # English dump is filtered without decoding each article's full text.
    english_wiki_shared[split] = english_wiki['train'].filter(lambda title: title in shared,
                                                              input_columns='title')

    simple_wiki_shared[split] = simple_wiki_split_long[split].filter(lambda title: title in shared,
                                                                     input_columns='title')

working on split train
working on split validation


Filter:   0%|          | 0/6458670 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5000 [00:00<?, ? examples/s]

working on split test


Filter:   0%|          | 0/6458670 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [None]:
import os

# exist_ok=True so re-running the cell does not raise FileExistsError.
os.makedirs('./data/simple_en_overlap', exist_ok=True)

In [None]:
# Save both overlap subsets (DatasetDicts keyed by split) under ./data/simple_en_overlap.
english_wiki_shared.save_to_disk('./data/simple_en_overlap/english_shared')
simple_wiki_shared.save_to_disk('./data/simple_en_overlap/simple_shared')

Saving the dataset (0/1 shards):   0%|          | 0/57518 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4403 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4377 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/57518 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4403 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4377 [00:00<?, ? examples/s]

In [None]:
# Log the English/Simple overlap directory as a single WandB dataset artifact.
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
file_name = './data/simple_en_overlap'

# Initialize a WandB run for logging
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='log_data',
                name=f'shared_articles_en_simple_{time_stamp}') as run:
    data_art = wandb.Artifact('shared_articles_en_simple', 'dataset')
    data_art.add_dir(file_name)
    run.log_artifact(data_art)

[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (./data/simple_en_overlap)... Done. 6.4s
