In [1]:
import pandas as pd
from transformers import GPT2TokenizerFast

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def check_and_add_fix(
    df:pd.DataFrame,
    column:str,
    fix:str,
    option:str
) -> pd.DataFrame:
    """Make sure every value in ``df[column]`` starts or ends with ``fix``.

    Rows whose value does not already carry the prefix/suffix are reported
    on stdout and then updated. The update is written into ``df`` itself
    (in place); the same object is returned for convenience.

    Args:
        df: DataFrame holding the column to check.
        column: Name of the column to inspect and update.
        fix: The prefix or suffix string that every value must carry.
        option: Either ``'prefix'`` or ``'suffix'``.

    Returns:
        ``df``, with ``fix`` added to the rows that were missing it.

    Raises:
        ValueError: If ``option`` is neither ``'prefix'`` nor ``'suffix'``.
    """
    # startswith/endswith (rather than slicing by len(fix)) also behaves
    # correctly for an empty `fix`: a slice of `[-0:]` is the whole string,
    # which would wrongly flag every row as missing an empty suffix.
    # na=False treats non-string / missing values as "does not have the fix".
    if option == 'prefix':
        has_fix = df[column].str.startswith(fix, na=False)
    elif option == 'suffix':
        has_fix = df[column].str.endswith(fix, na=False)
    else:
        # Fail loudly instead of hitting an unbound local further down.
        raise ValueError(f"option must be 'prefix' or 'suffix', got {option!r}")

    no_fix = ~has_fix.astype(bool)
    no_fix_index = df[no_fix].index.to_list()
    print(f'The indices of the samples that do not have the {option}: {no_fix_index}')

    if no_fix_index:
        # add prefix/suffix only to the rows that lack it
        print(f'Adding the {option} to them')
        if option == 'prefix':
            df.loc[no_fix, column] = df.loc[no_fix, column].apply(lambda c: f'{fix}{c}')
        else:
            df.loc[no_fix, column] = df.loc[no_fix, column].apply(lambda c: f'{c}{fix}')

    return df

In [3]:
def get_first_n_tokens(
    text:str,
    n:int,
    tokenizer:GPT2TokenizerFast,
):
    """Truncate ``text`` to its first ``n`` tokens.

    The text is encoded with ``tokenizer.encode``, the resulting id
    sequence is cut to at most ``n`` entries, and the kept ids are turned
    back into text with ``tokenizer.decode``.

    Args:
        text: The text to truncate.
        n: Maximum number of token ids to keep.
        tokenizer: Object providing ``encode`` and ``decode``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the first ``n`` ids.
    """
    token_ids = tokenizer.encode(text)
    kept_ids = token_ids[:n]
    return tokenizer.decode(kept_ids)

In [4]:
def prepare_dataset(
    df:pd.DataFrame, 
    completion_prefix:str=None, 
    completion_suffix:str=None,
    prompt_prefix:str=None,
    prompt_suffix:str=None,
    max_prompt_tokens:int=1800,
    tokenizer:"GPT2TokenizerFast"=None,
):
    """Build a two-column prompt/completion dataset from ``df``.

    Prompts are truncated to their first ``max_prompt_tokens`` tokens, then
    the optional prefixes/suffixes are added to any completion or prompt
    that does not already carry them. The caller's DataFrame is left
    untouched; all changes are made on a copy.

    Args:
        df: Input data; must contain ``'prompt'`` and ``'completion'`` columns.
        completion_prefix: If given, ensured at the start of every completion.
        completion_suffix: If given, ensured at the end of every completion.
        prompt_prefix: If given, ensured at the start of every prompt.
        prompt_suffix: If given, ensured at the end of every prompt.
        max_prompt_tokens: Number of leading tokens to keep from each prompt.
        tokenizer: Tokenizer used for truncation. When omitted, the ``"gpt2"``
            ``GPT2TokenizerFast`` is loaded on each call.

    Returns:
        A new DataFrame with exactly the columns ``['prompt', 'completion']``.

    Raises:
        ValueError: If ``df`` lacks a ``'completion'`` or ``'prompt'`` column.
    """
    # check if required columns are there (explicit raise: asserts vanish under -O)
    for required in ('completion', 'prompt'):
        if required not in df.columns:
            raise ValueError(f"We need a column named '{required}'")

    # Work on a copy of just the columns we return, so the caller's frame
    # is not modified by the truncation or the prefix/suffix edits below.
    df = df[['prompt', 'completion']].copy()

    # keep only the first `max_prompt_tokens` tokens of each prompt
    if tokenizer is None:
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    df['prompt'] = df['prompt'].apply(
        lambda c: get_first_n_tokens(c, max_prompt_tokens, tokenizer)
    )

    print('Formatting completions')
    # prepare completion
    if completion_prefix is not None:
        df = check_and_add_fix(df, 'completion', completion_prefix, 'prefix')
    if completion_suffix is not None:
        df = check_and_add_fix(df, 'completion', completion_suffix, 'suffix')

    print('Formatting prompts')
    # prepare prompt (suffix is added after truncation, so it is never cut off)
    if prompt_prefix is not None:
        df = check_and_add_fix(df, 'prompt', prompt_prefix, 'prefix')
    if prompt_suffix is not None:
        df = check_and_add_fix(df, 'prompt', prompt_suffix, 'suffix')

    return df