<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/data_analysis/ELI5_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [None]:
%cd drive/MyDrive/LLMs/ELI5_dataset

/content/drive/MyDrive/LLMs/ELI5_dataset


In [None]:
%cd drive/MyDrive/LLMs/ELI5_dataset

!pip install datasets --quiet
!pip install textstat --quiet
!pip install wandb --quiet
!pip install redditcleaner --quiet
!pip install huggingface_hub --quiet
!pip install -U sentence-transformers --quiet

In [None]:
import wandb, torch
import sys
import datasets
import os
import redditcleaner
import re
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from sentence_transformers import SentenceTransformer
from textstat import flesch_reading_ease as fre
from textstat import flesch_kincaid_grade as fkg
from datasets import (load_dataset,
                      load,
                      load_from_disk,
                      Dataset,
                      concatenate_datasets,
                      DatasetDict)
from itertools import compress
from tqdm import tqdm
from collections import defaultdict
from itertools import combinations
import random
from datetime import datetime
from tqdm import tqdm

# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

%matplotlib inline

# Definitions

In [None]:
def replace_url_i(text):
    """Strip ``_url_<digits>_`` placeholder tokens from ``text``.

    Three spellings of the token are recognised (``_url_``, ``_Url_`` and
    ``_URL_``). Each spelling is removed in its own pass, in that order, and
    every match is replaced by the empty string.

    Args:
        text: String to clean.

    Returns:
        The string with all matching placeholders removed.
    """
    url_patterns = (r"_url_\d+_", r"_Url_\d+_", r"_URL_\d+_")

    cleaned = text
    # Sequential passes: a later pattern sees the output of the earlier ones.
    for url_pattern in url_patterns:
        cleaned = re.sub(url_pattern, "", cleaned)

    return cleaned

def preprocess_example(example):
    """Clean the answers, title and selftext of a single post in place.

    Every answer text is passed through ``redditcleaner.clean``, has each
    ``>``-to-newline span replaced by a space, is lower-cased, has its
    whitespace collapsed and has ``_url_<n>_`` placeholders removed. Answers
    left with fewer than 20 whitespace-separated words are dropped.

    The drop is applied to *every* field of ``example['answers']`` with the
    same mask, so fields such as ``'score'`` stay aligned with ``'text'``.
    (Filtering only ``'text'`` would leave the other fields at their original
    length and silently mis-pair texts and scores in later filtering steps.)
    This assumes the fields of ``example['answers']`` are parallel per-answer
    lists -- TODO confirm against the dataset schema.

    The title is cleaned and whitespace-normalised but keeps its original
    case; the selftext is cleaned, lower-cased and whitespace-normalised.

    Args:
        example: Mapping with an ``'answers'`` dict (containing at least
            ``'text'``), a ``'title'`` string and a ``'selftext'`` string.

    Returns:
        The same ``example`` object, modified in place.
    """
    answer_fields = example['answers']

    cleaned_texts = []
    for answer in answer_fields['text']:
        answer = redditcleaner.clean(answer)
        answer = re.sub('>.*?\n', ' ', answer)
        answer = ' '.join(answer.lower().split())
        cleaned_texts.append(replace_url_i(answer))

    # One keep/drop flag per answer, decided on the cleaned text.
    keep = [len(answer.split()) >= 20 for answer in cleaned_texts]

    # Apply the same mask to every per-answer field so they stay parallel.
    for key in list(answer_fields):
        values = cleaned_texts if key == 'text' else answer_fields[key]
        example['answers'][key] = [value for value, flag in zip(values, keep) if flag]

    title = example['title']
    title = redditcleaner.clean(title)
    title = ' '.join(title.split())
    title = replace_url_i(title)
    example['title'] = title

    selftext = example['selftext']
    selftext = redditcleaner.clean(selftext)
    selftext = ' '.join(selftext.lower().split())
    selftext = replace_url_i(selftext)
    example['selftext'] = selftext

    return example


class score_cutoff_wrapper:
    """Holds a score threshold and filters a post's answers against it."""

    def __init__(self,cutoff):
        # Minimum score (inclusive) an answer needs in order to be kept.
        self.cutoff = cutoff

    def score_cutoff_ex(self,example):
        """Keep only the answers whose score is at least ``self.cutoff``.

        Every field of ``example['answers']`` is filtered with the same
        boolean mask derived from ``example['answers']['score']``.

        Args:
            example: Mapping with an ``'answers'`` dict containing ``'score'``.

        Returns:
            The same ``example`` object, modified in place.
        """
        answer_fields = example['answers']
        keep_mask = (np.array(answer_fields['score']) >= self.cutoff).tolist()

        for field in list(answer_fields):
            answer_fields[field] = list(compress(answer_fields[field], keep_mask))

        return example


def score_cutoff(dataset,cutoff=4):
    """Drop low-scoring answers, then drop posts left without any answer.

    Args:
        dataset: Object exposing ``map`` and ``filter`` (as used below).
        cutoff: Minimum answer score (inclusive) to keep.

    Returns:
        The result of mapping ``score_cutoff_wrapper(cutoff).score_cutoff_ex``
        over ``dataset`` and filtering out posts with an empty score list.
    """
    score_filter = score_cutoff_wrapper(cutoff)
    trimmed = dataset.map(score_filter.score_cutoff_ex)

    return trimmed.filter(lambda post: len(post['answers']['score'])>0)

def flesch_scores(example):
    """Attach per-answer readability scores to a post.

    For each text in ``example['answers']['text']`` the module-level ``fre``
    and ``fkg`` callables are evaluated; the results are stored as the
    parallel lists ``example['answers']['fre']`` and
    ``example['answers']['fkg']``.

    Args:
        example: Mapping with an ``'answers'`` dict containing ``'text'``.

    Returns:
        The same ``example`` object, modified in place.
    """
    answer_fields = example['answers']
    texts = answer_fields['text']

    reading_ease = [fre(body) for body in texts]
    grade_level = [fkg(body) for body in texts]

    answer_fields['fre'] = reading_ease
    answer_fields['fkg'] = grade_level

    return example

class flesch_scores_filter_wrapper:
    """Holds readability thresholds and filters a post's answers against them."""

    def __init__(self,fre_cutoff, fkg_cutoff):
        # An answer is kept when fre >= fre_cutoff and fkg < fkg_cutoff.
        self.fre_cutoff = fre_cutoff
        self.fkg_cutoff = fkg_cutoff

    def flesch_scores_filter(self,example):
        """Keep only the answers that satisfy both readability thresholds.

        An answer survives when its ``'fre'`` value is at least
        ``self.fre_cutoff`` and its ``'fkg'`` value is strictly below
        ``self.fkg_cutoff``. Every field of ``example['answers']`` is filtered
        with the same mask.

        Args:
            example: Mapping with an ``'answers'`` dict containing the
                parallel lists ``'fre'`` and ``'fkg'``.

        Returns:
            The same ``example`` object, modified in place.
        """
        answer_fields = example['answers']
        ease_values = answer_fields['fre']
        grade_values = answer_fields['fkg']

        keep_mask = [bool(ease >= self.fre_cutoff
                          and grade_values[pos] < self.fkg_cutoff)
                     for pos, ease in enumerate(ease_values)]

        for field in list(answer_fields):
            answer_fields[field] = list(compress(answer_fields[field], keep_mask))

        return example

def flesch_scores_cutoff(dataset,fre_cutoff=60,fkg_cutoff=9):
    """Drop answers failing the readability thresholds, then empty posts.

    Args:
        dataset: Object exposing ``map`` and ``filter`` (as used below).
        fre_cutoff: Minimum ``'fre'`` value (inclusive) for an answer to stay.
        fkg_cutoff: Upper bound (exclusive) on ``'fkg'`` for an answer to stay.

    Returns:
        The result of mapping
        ``flesch_scores_filter_wrapper(...).flesch_scores_filter`` over
        ``dataset`` and filtering out posts with an empty score list.
    """
    readability_filter = flesch_scores_filter_wrapper(fre_cutoff, fkg_cutoff)
    trimmed = dataset.map(readability_filter.flesch_scores_filter)

    return trimmed.filter(lambda post: len(post['answers']['score'])>0)

def preprocess_data(dataset,
                    output_file = './data/filtered',
                    save_file = True,
                    log_to_wandb = True,
                    overwrite = False):
    """Clean and filter the raw dataset, optionally caching and logging it.

    If ``output_file`` already exists and ``overwrite`` is false, the dataset
    stored there is loaded and returned without any processing. Otherwise the
    pipeline is: ``preprocess_example`` on every post; drop posts whose title
    contains ``nsfw`` (case-insensitive); keep posts whose title or selftext
    contains a question marker; drop posts whose title contains one of the
    recurring-thread tags; compute readability scores; apply the default
    score and readability cut-offs.

    Args:
        dataset: Object exposing ``map``, ``filter`` and ``save_to_disk``.
        output_file: Directory used both as cache location and save target.
        save_file: When true, the result is written to ``output_file``.
        log_to_wandb: When true (and ``save_file`` is true), the saved
            directory is logged as a wandb artifact.
        overwrite: When true, ignore an existing ``output_file`` and rebuild.

    Returns:
        The processed (or cached) dataset.
    """
    # Reuse a previous run's output unless a rebuild was requested.
    if not overwrite and os.path.exists(output_file):
        return load_from_disk(output_file)

    # Title tags identifying recurring/meta threads rather than questions
    # (matched case-sensitively against the title).
    not_qus = ['IAMA','AMA','ama:','megathread','Megathread',
           'Discussion Thread','Discussion thread',
           'discussion Thread','discussion thread',
           'Ask Anything Wednesday','Free-for-All',
           'Free-For-All','[META]','Monday Methods',
           'Tuesday Trivia','Monday Mysteries',
           'Theory Thursday','Monday Mish-Mash',
           'Media Mondays','[META]','Wednesday Week in History',
           'Saturday Popular Questions','Ask Anything Wednesday',
           'Thursday Focus Historical Fiction']

    # Substrings taken as evidence that a post asks a question.
    qu_reqs = ['who','what','where','why','when','how','?']

    def is_sfw(post):
        return 'nsfw' not in post['title'].lower()

    def has_question_marker(post):
        title = post['title'].lower()
        body = post['selftext'].lower()
        return (any(marker in title for marker in qu_reqs)
                or any(marker in body for marker in qu_reqs))

    def is_not_recurring_thread(post):
        return not any(tag in post['title'] for tag in not_qus)

    dataset = dataset.map(preprocess_example)
    for keep_post in (is_sfw, has_question_marker, is_not_recurring_thread):
        dataset = dataset.filter(keep_post)

    dataset = dataset.map(flesch_scores)
    dataset = flesch_scores_cutoff(score_cutoff(dataset))

    if not save_file:
        return dataset

    dataset.save_to_disk(output_file)

    if log_to_wandb:
        time_stamp = datetime.now().strftime("%m.%d.%y-%H.%M.%S")
        with wandb.init(project='ELI5_analysis',
                        entity='ft-llmmm',
                        job_type='preprocess_data',
                        name=f'preprocess_data_{time_stamp}') as run:

            processed_data_art = wandb.Artifact('ELI5_processed','dataset')
            processed_data_art.add_dir(output_file)
            run.log_artifact(processed_data_art)

    return dataset

def split_idxs(example):
    """Record which answers carry a first-seen score and which repeat one.

    ``example['pref_idxs']`` receives, for each distinct score in descending
    score order, the position of the first answer with that score.
    ``example['dupl_scores_idxs']`` receives all remaining positions in
    ascending order.

    Args:
        example: Mapping with an ``'answers'`` dict containing ``'score'``.

    Returns:
        The same ``example`` object with the two index lists added.
    """
    scores = example['answers']['score']

    # Position of the first occurrence of every distinct score.
    first_position = {}
    for pos, value in enumerate(scores):
        first_position.setdefault(value, pos)

    preferred = [first_position[value]
                 for value in sorted(first_position, reverse=True)]
    preferred_lookup = set(preferred)
    repeated = [pos for pos in range(len(scores)) if pos not in preferred_lookup]

    example['pref_idxs'] = preferred
    example['dupl_scores_idxs'] = repeated

    return example

def mult_ans_RM_proc(example):
    """Reduce every answer field to the positions listed in ``pref_idxs``.

    Args:
        example: Mapping with ``'pref_idxs'`` (list of positions) and an
            ``'answers'`` dict of indexable per-answer fields.

    Returns:
        The same ``example`` object, modified in place; each answer field is
        re-ordered to follow ``pref_idxs``.
    """
    chosen = example['pref_idxs']
    answer_fields = example['answers']
    for field, values in list(answer_fields.items()):
        answer_fields[field] = [values[pos] for pos in chosen]
    return example

def mult_ans_SFT_proc(example):
    """Reduce every answer field to the positions in ``dupl_scores_idxs``.

    Args:
        example: Mapping with ``'dupl_scores_idxs'`` (list of positions) and
            an ``'answers'`` dict of indexable per-answer fields.

    Returns:
        The same ``example`` object, modified in place; each answer field
        keeps only the listed positions, in the listed order.
    """
    chosen = example['dupl_scores_idxs']
    answer_fields = example['answers']
    for field, values in list(answer_fields.items()):
        answer_fields[field] = [values[pos] for pos in chosen]
    return example

def split_ds(ds_original,
             ds_filtered,
             output_dir='ds_split',
             save_file=True,
             log_to_wandb = True,
             overwrite = False):
    """Split the data into SFT, RM and RL subsets.

    If all three subset directories already exist under
    ``./data/{output_dir}`` and ``overwrite`` is false, they are loaded from
    disk and returned. Otherwise:

    * RM: posts from ``ds_filtered`` with at least two answers, reduced to
      the answers selected by ``split_idxs``/``mult_ans_RM_proc``.
    * SFT: the remaining (repeated-score) answers of those posts, plus all
      posts with exactly one answer, concatenated per split.
    * RL: every post of ``ds_original`` whose ``q_id`` appears in neither of
      the above, with all splits concatenated into one dataset.

    Args:
        ds_original: Unfiltered dataset (mapping of split name to dataset).
        ds_filtered: Output of the preprocessing step, indexable by
            ``'train'``, ``'validation'`` and ``'test'``.
        output_dir: Sub-directory of ``./data`` used for caching and saving.
        save_file: When true, each subset is saved to disk.
        log_to_wandb: When true (and ``save_file`` is true), the saved
            directory is logged as a wandb artifact.
        overwrite: When true, ignore cached subsets and rebuild.

    Returns:
        Dict with keys ``'SFT'``, ``'RM'`` and ``'RL'``.
    """
    cache_ready = all(os.path.exists(f'./data/{output_dir}/{split}')
                      for split in ['ds_SFT','ds_RM','ds_RL'])

    if cache_ready and not overwrite:
        return {name: load_from_disk(f'./data/{output_dir}/ds_{name}')
                for name in ('SFT','RM','RL')}

    def has_answers(post):
        return len(post['answers']['score'])>0

    multi_answer = ds_filtered.filter(lambda post : len(post['answers']['score'])>=2)
    single_answer = ds_filtered.filter(lambda post : len(post['answers']['score'])==1)

    multi_indexed = multi_answer.map(split_idxs)

    ds_split = {}
    ds_split['RM'] = multi_indexed.map(mult_ans_RM_proc).filter(has_answers)

    sft_from_multi = multi_indexed.map(mult_ans_SFT_proc).filter(has_answers)

    sft = datasets.DatasetDict()
    for key in ['train','validation','test']:
        sft[key] = datasets.concatenate_datasets([sft_from_multi[key],
                                                  single_answer[key]])
    ds_split['SFT'] = sft

    # Question ids already claimed by the SFT or RM subsets.
    q_ids_taken = set()
    for name in ('SFT','RM'):
        claimed = ds_split[name]
        for split in claimed:
            q_ids_taken.update(claimed[split]['q_id'])

    unclaimed = ds_original.filter(lambda post: post['q_id'] not in q_ids_taken)
    ds_split['RL'] = concatenate_datasets(list(unclaimed.values()))

    if not save_file:
        return ds_split

    for key,value in ds_split.items():
        value.save_to_disk(f'./data/{output_dir}/ds_{key}')

    if log_to_wandb:
        time_stamp = datetime.now().strftime("%m.%d.%y-%H.%M.%S")
        with wandb.init(project='ELI5_analysis',
                        entity='ft-llmmm',
                        job_type='split_data',
                        name=f'split_data_{time_stamp}') as run:

            split_data_art = wandb.Artifact('ELI5_split','dataset')
            split_data_art.add_dir(f'./data/{output_dir}')
            run.log_artifact(split_data_art)

    return ds_split

def combine_title_body(example):
    """Join a post's title and selftext into one newline-separated string.

    Both fields have their whitespace collapsed to single spaces first.

    Args:
        example: Mapping with string fields ``'title'`` and ``'selftext'``.

    Returns:
        A new dict ``{'title_body': '<title>\\n<selftext>'}``.
    """
    normalised = [' '.join(example[field].split())
                  for field in ('title', 'selftext')]

    return {'title_body': '\n'.join(normalised)}

def embed_datasets(dataset_split,
                   checkpoint ='all-mpnet-base-v2',
                   output_dir = 'embedded',
                   save_file = True,
                   overwrite = False,
                   log_to_wandb = True):
    """Add a sentence embedding of title + body to every post.

    If the ``SFT``, ``RM`` and ``RL`` directories already exist under
    ``./data/{output_dir}`` and ``overwrite`` is false, they are loaded from
    disk and returned. Otherwise each dataset in ``dataset_split`` gets a
    ``'title_body'`` column (see ``combine_title_body``) and a ``'qu_emb'``
    column holding ``SentenceTransformer(checkpoint).encode`` of that text.

    The embedding ``map`` runs in batched mode so that the encoder receives
    lists of texts; with a row-at-a-time ``map`` the encoder would be called
    with one string per call and its ``batch_size`` argument would have no
    effect.

    Args:
        dataset_split: Mapping from subset name to a dataset exposing ``map``.
        checkpoint: Model name passed to ``SentenceTransformer``.
        output_dir: Sub-directory of ``./data`` used for caching and saving.
        save_file: When true, each embedded subset is saved to disk.
        overwrite: When true, ignore cached subsets and recompute.
        log_to_wandb: When true (and ``save_file`` is true), the saved
            directory is logged as a wandb artifact.

    Returns:
        Dict mapping subset name to the embedded dataset.
    """
    if (all(os.path.exists(f'./data/{output_dir}/ds_{subset}') for subset in ['SFT','RM','RL'])
        and not overwrite):

        ds_embedded = {}

        for subset in ['SFT','RM','RL']:
            ds_embedded[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
        return ds_embedded

    ds_embedded = {}
    model = SentenceTransformer(checkpoint)

    def embed_batch(batch):
        # In batched mode ``batch['title_body']`` is a list of strings, so the
        # encoder can actually process them 64 at a time.
        return {'qu_emb': model.encode(batch['title_body'], batch_size=64)}

    for key in dataset_split:
        with_text = dataset_split[key].map(combine_title_body)
        ds_embedded[key] = with_text.map(embed_batch, batched=True)

    if save_file:

        for key,value in ds_embedded.items():
            value.save_to_disk(f'./data/{output_dir}/ds_{key}')

        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='embed_data',
                            name=f'embed_data_{time_stamp}') as run:

                embed_data_art=wandb.Artifact('ELI5_embedded','dataset')

                embed_data_art.add_dir(f'./data/{output_dir}')
                run.log_artifact(embed_data_art)

    return ds_embedded

def make_pairs(example):
    """Build (higher-scored, lower-scored) answer-text pairs for a post.

    All 2-combinations of the post's answers are formed; if there are more
    than 10, a random sample of 10 is taken with ``random.sample``. Within
    each pair the text of the answer with the larger score comes first (on a
    tie the original order is kept). The list of text pairs is stored in
    ``example['pairs']``.

    Args:
        example: Mapping with an ``'answers'`` dict containing the parallel
            lists ``'text'`` and ``'score'``.

    Returns:
        The same ``example`` object with ``'pairs'`` added.
    """
    answer_fields = example['answers']
    scored_answers = tuple(zip(answer_fields['score'], answer_fields['text']))

    candidates = tuple(combinations(scored_answers, 2))
    if len(candidates)>10:
        candidates = random.sample(candidates,10)

    ordered_texts = []
    for first, second in candidates:
        # Swap only when the second answer strictly outscores the first.
        if second[0] > first[0]:
            first, second = second, first
        ordered_texts.append((first[1], second[1]))

    example['pairs'] = ordered_texts

    return example

def clean_datasets(ds_embedded,
                   cutoff = 0.6,
                   batch_size = 5000,
                   output_dir = 'cleaned',
                   save_file=True,
                   overwrite = False,
                   log_to_wandb = True
                   ):
    """Remove near-duplicate questions across splits and subsets.

    If the ``SFT``, ``RM`` and ``RL`` directories already exist under
    ``./data/{output_dir}`` and ``overwrite`` is false, they are loaded from
    disk and returned. Otherwise, using dot products of row-normalised
    ``'qu_emb'`` vectors as similarity:

    * For ``SFT`` and ``RM``: a train row is dropped when its similarity to
      any validation or test row is ``>= cutoff``; a test row is dropped when
      its similarity to any validation row is ``>= cutoff``; the validation
      split is kept unchanged.
    * For ``RL``: a row is dropped when its similarity to any train or
      validation row of ``SFT`` or ``RM`` is ``>= cutoff``. The comparison is
      done in chunks of ``batch_size`` RL rows.
    * Finally ``make_pairs`` is mapped over the cleaned ``RM`` subset.

    Note: ``set_format('torch')`` is called on the datasets in
    ``ds_embedded``, so the caller's objects have their format changed.

    Args:
        ds_embedded: Dict with ``'SFT'`` and ``'RM'`` (indexable by
            ``'train'``/``'validation'``/``'test'``) and ``'RL'``; each
            dataset has a ``'qu_emb'`` column.
        cutoff: Similarity threshold (inclusive) above which two questions
            are treated as overlapping.
        batch_size: Number of RL rows compared per chunk.
        output_dir: Sub-directory of ``./data`` used for caching and saving.
        save_file: When true, each cleaned subset is saved to disk.
        overwrite: When true, ignore cached subsets and recompute.
        log_to_wandb: When true (and ``save_file`` is true), the saved
            directory is logged as a wandb artifact.

    Returns:
        Dict with keys ``'SFT'``, ``'RM'`` and ``'RL'``.
    """
    if (all(os.path.exists(f'./data/{output_dir}/ds_{subset}') for subset in ['SFT','RM','RL'])
        and not overwrite):

        ds_clean = {}

        for subset in ['SFT','RM','RL']:
            ds_clean[subset] = load_from_disk(f'./data/{output_dir}/ds_{subset}')
        return ds_clean

    def _unit_rows(vecs):
        # Scale every row to unit length so a dot product is a cosine similarity.
        return vecs / torch.sqrt(torch.sum(vecs**2, dim=1, keepdim=True))

    def _overlap_idxs(rows, cols):
        # (row_idxs, col_idxs) of all pairs whose similarity reaches the cutoff.
        return torch.where(torch.matmul(rows, cols.T) >= cutoff)

    def _complement(size, removed):
        # Sorted positions in range(size) not in `removed`; a sorted list keeps
        # the surviving rows in their original order when passed to `select`.
        return [pos for pos in range(size) if pos not in removed]

    embed_vecs = {}
    splits = ['train','validation','test']
    ds_clean = {}

    for subset in ['SFT',"RM"]:
        print(f'Cleaning {subset} dataset')

        ds_embedded[subset].set_format('torch')
        embed_vecs[subset] = {split: _unit_rows(ds_embedded[subset][split]['qu_emb'])
                              for split in splits}
        vecs = embed_vecs[subset]

        # Train rows that are too close to a validation or test row.
        rm_tr_idxs = set(_overlap_idxs(vecs['train'], vecs['validation'])[0].tolist())
        rm_tr_idxs |= set(_overlap_idxs(vecs['train'], vecs['test'])[0].tolist())

        # Test rows that are too close to a validation row.
        rm_test_idxs = set(_overlap_idxs(vecs['validation'], vecs['test'])[1].tolist())

        keep_train = _complement(len(ds_embedded[subset]['train']), rm_tr_idxs)
        keep_test = _complement(len(ds_embedded[subset]['test']), rm_test_idxs)

        ds_clean[subset] = DatasetDict()

        ds_clean[subset]['train'] = ds_embedded[subset]['train'].select(keep_train)
        ds_clean[subset]['validation'] = ds_embedded[subset]['validation']
        ds_clean[subset]['test'] = ds_embedded[subset]['test'].select(keep_test)

    print(f'Cleaning RL dataset')
    ds_embedded['RL'].set_format('torch')
    embed_vecs['RL'] = _unit_rows(ds_embedded['RL']['qu_emb'])

    RL_size = len(ds_embedded['RL'])
    rem_RL = set()

    # One iteration per chunk of `batch_size` RL rows (the last may be shorter).
    for start in tqdm(range(0, RL_size, batch_size)):

        batch = embed_vecs['RL'][start:start+batch_size,:]

        for subset in ['SFT','RM']:
            for split in ['train','validation']:
                batch_cols = _overlap_idxs(embed_vecs[subset][split], batch)[1]
                # Column indices are relative to this chunk; shift them by
                # `start` so they refer to rows of the full RL dataset.
                rem_RL.update((batch_cols + start).tolist())

    keep_RL = _complement(RL_size, rem_RL)

    ds_clean['RL'] = ds_embedded['RL'].select(keep_RL)

    ds_clean['RM'] = ds_clean['RM'].map(lambda x:make_pairs(x))

    if save_file:
        for subset in ['SFT','RM','RL']:
            ds_clean[subset].save_to_disk(f'./data/{output_dir}/ds_{subset}')

        if log_to_wandb:
            now = datetime.now()
            time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
            with wandb.init(project='ELI5_analysis',
                            entity='ft-llmmm',
                            job_type='clean_data',
                            name=f'clean_data_{time_stamp}') as run:

                clean_data_art=wandb.Artifact('ELI5_cleaned','dataset')
                clean_data_art.add_dir(f'./data/{output_dir}')
                run.log_artifact(clean_data_art)

    return ds_clean

# Code

In [None]:
# Load the "vblagoje/lfqa" dataset via `datasets.load_dataset`.
ds_original = load_dataset("vblagoje/lfqa")

In [None]:
# Clean and filter the raw posts; with the default arguments the result is
# cached under ./data/filtered and reloaded from there on later runs.
ds_filtered = preprocess_data(ds_original)

In [None]:
# Partition the data into the SFT, RM and RL subsets (returned as a dict).
ds_split = split_ds(ds_original,
                    ds_filtered)

In [None]:
# Add a 'qu_emb' sentence-embedding column (title + body) to every subset.
ds_embedded = embed_datasets(ds_split)

In [None]:
# Drop questions whose embeddings overlap across splits/subsets (default cutoff 0.6).
ds_clean = clean_datasets(ds_embedded)