<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/data_cleaning/ELI5_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
%cd drive/MyDrive/LLMs/new_dataset

!pip install datasets --quiet
!pip install textstat --quiet
!pip install wandb --quiet
!pip install redditcleaner --quiet
!pip install huggingface_hub --quiet
!pip install -U sentence-transformers --quiet


/content/drive/MyDrive/LLMs/new_dataset
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m492.4/492.4 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.1/105.1 kB[0m [31m779.4 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[

In [2]:
import wandb, torch
import sys
import datasets
import os
import redditcleaner
import re
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from sentence_transformers import SentenceTransformer
from textstat import flesch_reading_ease as fre
from textstat import flesch_kincaid_grade as fkg
from datasets import load_dataset, load, load_from_disk, Dataset
from itertools import compress
from tqdm import tqdm
from collections import defaultdict
from itertools import combinations
import random

device = "cuda" if torch.cuda.is_available() else "cpu"

%matplotlib inline

In [3]:
notebook_login()

#creates the 'results' and 'data' directories, if they don't exist.
if not os.path.exists('results'):
    os.makedirs('results')

if not os.path.exists('data'):
    os.makedirs('data')

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Filtering Data

## Definitions

In [4]:
def replace_url_i(input_string):
    # Define the regular expression pattern to match "_url_i_" where i is an arbitrary integer

    pattern_1 = r"_url_\d+_"  # \d+ matches one or more digits
    pattern_2 = r"_Url_\d+_"
    pattern_3 = r"_URL_\d+_"

    # Use re.sub() to replace all occurrences of the pattern with an empty string
    output_string = re.sub(pattern_1, "", input_string)
    output_string = re.sub(pattern_2, "", output_string)
    output_string = re.sub(pattern_3, "", output_string)

    return output_string

def preprocess_example(example):

    answers = example['answers']['text']
    answers = [redditcleaner.clean(answer) for answer in answers]
    answers = [re.sub('>.*?\n',' ',answer) for answer in answers]
    answers = [' '.join(answer.lower().split()) for answer in answers]
    answers = [replace_url_i(answer) for answer in answers]
    answers = [answer for answer in answers if len(answer.split())>=20]
    example['answers']['text'] = answers

    title = example['title']
    title = redditcleaner.clean(title)
    title = ' '.join(title.split())
    title = replace_url_i(title)
    example['title'] = title

    selftext = example['selftext']
    selftext = redditcleaner.clean(selftext)
    selftext = ' '.join(selftext.lower().split())
    selftext = replace_url_i(selftext)
    example['selftext'] = selftext

    return example

def preprocess_data(dataset):
    dataset = dataset.map(preprocess_example)
    return dataset

class score_cutoff_wrapper:
    def __init__(self,cutoff):
        self.cutoff = cutoff

    def score_cutoff_ex(self,example):
        scores = example['answers']['score']
        idxs = list(np.array(scores) >= self.cutoff)
        for key, val in example['answers'].items():
            example['answers'][key] = list(compress(val,idxs))

        return example


def score_cutoff(dataset,cutoff):
    cutoff = score_cutoff_wrapper(cutoff)
    ds = dataset.map(cutoff.score_cutoff_ex)
    ds = ds.filter(lambda post: len(post['answers']['score'])>0)

    return ds

def flesch_scores(example):

    fre_scores = [fre(text) for text in example['answers']['text']]
    fkg_scores = [fkg(text) for text in example['answers']['text']]
    example['answers']['fre'] = fre_scores
    example['answers']['fkg'] = fkg_scores

    return example

class flesch_scores_filter_wrapper:
    def __init__(self,fre_cutoff, fkg_cutoff):
        self.fre_cutoff = fre_cutoff
        self.fkg_cutoff = fkg_cutoff

    def flesch_scores_filter(self,example):

        fre_scores = example['answers']['fre']
        fkg_scores = example['answers']['fkg']

        idxs = [True if (fre_scores[i]>=self.fre_cutoff
                         and fkg_scores[i]<self.fkg_cutoff) else False
                for i in range(len(fre_scores))]

        for key, val in example['answers'].items():
            example['answers'][key] = list(compress(val,idxs))

        return example

def flesch_scores_cutoff(dataset,fre_cutoff=60,fkg_cutoff=9):
    filter = flesch_scores_filter_wrapper(fre_cutoff, fkg_cutoff)
    ds = dataset.map(filter.flesch_scores_filter)
    ds = ds.filter(lambda post: len(post['answers']['score'])>0)

    return ds

## Code

In [None]:
dataset = load_dataset("vblagoje/lfqa")

In [None]:
dataset_preprocessed = preprocess_data(dataset)
dataset_preprocessed.save_to_disk('./data/preprocessing')

In [None]:
not_qus = ['IAMA','AMA','ama:','megathread','Megathread',
           'Discussion Thread','Discussion thread',
           'discussion Thread','discussion thread',
           'Ask Anything Wednesday','Free-for-All',
           'Free-For-All','[META]','Monday Methods',
           'Tuesday Trivia','Monday Mysteries',
           'Theory Thursday','Monday Mish-Mash',
           'Media Mondays','[META]','Wednesday Week in History',
           'Saturday Popular Questions','Ask Anything Wednesday',
           'Thursday Focus Historical Fiction']

qu_reqs = ['who','what','where','why','when','how','?']

In [None]:
ds_reduced = dataset_preprocessed.filter(lambda post:
                                         not (all(qu_req not in post['title'].lower() for qu_req in qu_reqs)
                                         and all(qu_req not in post['selftext'].lower() for qu_req in qu_reqs)))

ds_reduced = ds_reduced.filter(lambda post:
                                       not (any(nq in post['title'] for nq in not_qus)))



In [None]:
ds_reduced = ds_reduced.map(flesch_scores)

ds_reduced.save_to_disk('./data/reduced_dataset')

In [None]:
ds_filtered = score_cutoff(ds_reduced,4)
ds_filtered = flesch_scores_cutoff(ds_filtered,fkg_cutoff=9)

ds_filtered_mult = ds_filtered.filter(lambda post : len(post['answers']['score'])>=2)
ds_filtered_sing = ds_filtered.filter(lambda post : len(post['answers']['score'])==1)

In [None]:
ds_filtered_mult.save_to_disk('./data/filtered/mult_ans')
ds_filtered_sing.save_to_disk('./data/filtered/sing_ans')

In [None]:
ds_ls = dataset.filter(lambda x:np.max(x['answers']['score'])<=3)

In [None]:
ds_ls_preprocessed = preprocess_data(ds_ls)

ds_ls_reduced = ds_ls_preprocessed.filter(lambda post:
                                         not (all(qu_req not in post['title'].lower() for qu_req in qu_reqs)
                                         and all(qu_req not in post['selftext'].lower() for qu_req in qu_reqs)))

ds_ls_reduced = ds_ls_reduced.filter(lambda post:
                                       not (any(nq in post['title'] for nq in not_qus)))

ds_ls_reduced = ds_ls_reduced.map(flesch_scores)

ds_ls_filtered = flesch_scores_cutoff(ds_ls_reduced,fkg_cutoff=9)


In [None]:
ds_ls_filtered.save_to_disk('./data/ds_low_scores_filtered')

Saving the dataset (0/1 shards):   0%|          | 0/34603 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/64 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/874 [00:00<?, ? examples/s]

In [None]:
ds_ls_filtered['train'].filter(lambda x:x['subreddit']=='explainlikeimfive')

Filter:   0%|          | 0/34603 [00:00<?, ? examples/s]

Dataset({
    features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls'],
    num_rows: 22751
})

In [None]:
ds_ls_filtered['train'][10]

{'q_id': 'e7ptdq',
 'title': 'why don’t all phone services send texts over the internet (like apple’s imessage does)?',
 'selftext': '',
 'document': '',
 'subreddit': 'explainlikeimfive',
 'url': 'https://www.reddit.com/r/explainlikeimfive/comments/e7ptdq/eli5_why_dont_all_phone_services_send_texts_over/',
 'answers': {'a_id': ['fa48tne'],
  'fkg': [6.2],
  'fre': [74.79],
  'score': [3],
  'text': ["the short answer is that they don't because they didn't originally. short message service-messages, as they are called in the gsm standard among other more recent standards, is a technical feature offered in the communication protocol used to communicate to and from the phones. internet traffic relies on one or several other technical substandards that are also offered in the communication protocols. from the phone operators point of view, a sms is awesome. because they have full control. the problem, if you wish, is that since they have full control, they also have pretty pricey business

# SFT and RM Dataset

## Definitions

In [None]:
def split_idxs(example):
    scores = example['answers']['score']
    scores_unique = sorted(set(scores),reverse=True)
    pref_scores_idxs = [scores.index(sc)for sc in scores_unique]
    dupl_scores_idxs = [n for n in range(len(scores)) if n not in pref_scores_idxs]

    example['pref_idxs'] = pref_scores_idxs
    example['dupl_scores_idxs'] = dupl_scores_idxs

    return example

def mult_ans_RL_proc(example):
    pref_scores_idxs = example['pref_idxs']
    for key, val in example['answers'].items():
        example['answers'][key] = [txt for i,txt in enumerate(example['answers'][key]) if i in pref_scores_idxs]
    return example

def mult_ans_SFT_proc(example):
    dupl_scores_idxs = example['dupl_scores_idxs']
    for key, val in example['answers'].items():
        example['answers'][key] = [txt for i,txt in enumerate(example['answers'][key]) if i in dupl_scores_idxs]
    return example

## Code

In [None]:
ds_filtered_indexed = ds_filtered_mult.map(split_idxs)

In [None]:
ds_RL = ds_filtered_indexed.map(mult_ans_RL_proc)
ds_RL = ds_RL.filter(lambda x: len(x['answers']['score'])>0)

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

Filter:   0%|          | 0/29978 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1349 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_SFT_mult = ds_filtered_indexed.map(mult_ans_SFT_proc)
ds_SFT_mult = ds_SFT_mult.filter(lambda x: len(x['answers']['score'])>0)

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

Filter:   0%|          | 0/29978 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1349 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_SFT = datasets.DatasetDict()

In [None]:
for key in ['train','validation','test']:
    ds_SFT[key] = datasets.concatenate_datasets([ds_SFT_mult[key], ds_filtered_sing[key]])

In [None]:
ds_SFT.save_to_disk('./data/ds_SFT')
ds_RL.save_to_disk('./data/ds_RL')

# Embedding Data

## Definitions

In [None]:
def combine_title_body(example):
    title = ' '.join(example['title'].split())
    selftext = ' '.join(example['selftext'].split())

    combined = title +'\n'+selftext

    return {'title_body':combined}

## Code

In [None]:
ds_SFT = load_from_disk('./data/ds_SFT')
ds_RL = load_from_disk('./data/ds_RL')

In [None]:
ds_SFT_emb = ds_SFT.map(combine_title_body)
ds_RL_emb = ds_RL.map(combine_title_body)

Map:   0%|          | 0/63424 [00:00<?, ? examples/s]

Map:   0%|          | 0/940 [00:00<?, ? examples/s]

Map:   0%|          | 0/3157 [00:00<?, ? examples/s]

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
model = SentenceTransformer('all-mpnet-base-v2')

ds_SFT_emb = ds_SFT_emb.map(lambda x:{'qu_emb':
                        model.encode(x['title_body'],
                                     batch_size=32)},
              batched=True)

ds_RL_emb = ds_RL_emb.map(lambda x:{'qu_emb':
                        model.encode(x['title_body'],
                                     batch_size=32)},
              batched=True)

ds_RL_emb = ds_RL_emb.map(lambda x:{'ans_emb':
                        model.encode(x['answers']['text'],
                                     batch_size=32)})

In [None]:
ds_SFT_emb.save_to_disk('./data/embedded/ds_SFT_emb')
ds_RL_emb.save_to_disk('./data/embedded/ds_RL_emb')

# Checking Data Leakage

## SFT

In [None]:
ds_SFT_emb = load_from_disk('./data/embedded/ds_SFT_emb')
ds_SFT_emb.set_format('torch')

In [None]:
vecs_SFT = {}
for split in ['train','validation','test']:
    vecs_SFT[split] = ds_SFT_emb[split]['qu_emb']
    vecs_SFT[split] /= torch.sqrt(torch.sum(vecs_SFT[split]**2,
                                            dim=1,
                                            keepdim=True))

In [None]:
overlap_SFT = {}
splits = ['train','validation','test']

for j in range(1,3):
    for i in range(j):
        overlap_SFT[(splits[i],splits[j])] = torch.matmul(
            vecs_SFT[splits[i]],
            vecs_SFT[splits[j]].T)

In [None]:
idxs_SFT = {}

for j in range(1,3):
    for i in range(j):
        idxs_SFT[(splits[i],splits[j])] = torch.where((overlap_SFT[(splits[i],splits[j])])>=0.55)

In [None]:
ds_compare_SFT = {}

for j in range(1,3):
    for i in range(j):

        idxs_1,idxs_2 = idxs_SFT[(splits[i],splits[j])]

        idxs_1 = idxs_1.numpy()
        idxs_2 = idxs_2.numpy()

        Q_1 = (ds_SFT_emb[splits[i]].select(idxs_1))['title_body']

        Q_2 = (ds_SFT_emb[splits[j]].select(idxs_2))['title_body']

        overlaps = overlap_SFT[(splits[i],splits[j])][idxs_SFT[(splits[i],splits[j])]]

        ds_compare_SFT[splits[i],splits[j]] = Dataset.from_dict(
            {'overlaps':overlaps.numpy(),
             'idxs_1':idxs_1,
             'Q_1':Q_1,
             'idxs_2':idxs_2,
             'Q_2':Q_2}
        )

with open('./data/ds_compare_SFT.pkl', 'wb') as f:
    pickle.dump(ds_compare_SFT,f)

In [None]:
pd.DataFrame(ds_compare_SFT['train','validation'])[['overlaps','Q_1','Q_2']].sort_values(by='overlaps')

## RL

In [None]:
def get_embeddings(model,text):
    embeddings = model.encode(text,show_progress_bar=True)
    return embeddings

class embed_qu_wrapper:
    def __init__(self,checkpoint='all-mpnet-base-v2'):
        self.checkpoint = checkpoint
        self.model = SentenceTransformer(self.checkpoint)

    def embed_qu(self,example):

        text = example['title_body']
        example['embedding'] = get_embeddings(self.model,text)
        return example

def make_pairs(example):
    answers = example['answers']['text']
    scores = example['answers']['score']
    embds = example['ans_emb']

    sc_ans = tuple(zip(scores,answers,
                       embds
                       ))
    pairs = tuple(combinations(sc_ans,2))

    if len(pairs)>10:
        pairs = random.sample(pairs,10)

    pairs = list(map(lambda x: sorted(x,key=lambda y:y[0],
                                 reverse=True),pairs))

    pairs_text = [((sc_pair[0][1]),sc_pair[1][1]) for sc_pair in pairs]
    pairs_emb = [((sc_pair[0][2]),sc_pair[1][2]) for sc_pair in pairs]

    example['pairs'] = {}
    example['pairs']['text'] = pairs_text
    example['pairs']['embs'] = pairs_emb

    return example

    #[((num,pair,emb),(num,pair,emb)), ((num,pair),(num,pair))]

In [None]:
ds_RL_emb = load_from_disk('./data/embedded/ds_RL_emb')
ds_RL_emb
vecs_RL = {}

In [None]:
vecs_RL = defaultdict

for split in ['train','validation','test']:

    vecs_SFT[split] = ds_SFT_emb[split]['qu_emb']
    vecs_SFT[split] /= torch.sqrt(torch.sum(vecs_SFT[split]**2,
                                            dim=1,
                                            keepdim=True))

In [None]:
ds_RL_pairs = ds_RL_emb.map(lambda x:make_pairs(x),
                            remove_columns=ds_RL_emb['train'].column_names)

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_RL_pairs['train'][0]['pairs']['text'][0]

["white is an ever shifting definition. it wasn't that long ago that irish, italians, eastern europeans and jews weren't considered white.",
 "what if you thought of everyone from english speaking countries as a group...let's say hugh jackman, bob marley, and narendra modi (prime minister of india)...you could say that they're all anglos...but they're not the same race."]

In [None]:
ds_RL_pairs.

DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu_emb'],
        num_rows: 63424
    })
    validation: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu_emb'],
        num_rows: 940
    })
    test: Dataset({
        features: ['q_id', 'title', 'selftext', 'document', 'subreddit', 'url', 'answers', 'title_urls', 'selftext_urls', 'answers_urls', 'pref_idxs', 'dupl_scores_idxs', 'title_body', 'qu_emb'],
        num_rows: 3157
    })
})

# Scratch

In [5]:
ds_reduced = load_from_disk('./data/reduced_dataset')

In [20]:
ds_flattened = ds_reduced.flatten()
df_train = pd.DataFrame(ds_flattened['train'])

In [21]:
df_train.columns

Index(['q_id', 'title', 'selftext', 'document', 'subreddit', 'url',
       'answers.a_id', 'answers.fkg', 'answers.fre', 'answers.score',
       'answers.text', 'title_urls', 'selftext_urls', 'answers_urls'],
      dtype='object')

In [37]:
fkg_scores = df_train['answers.fkg'].explode()
fre_scores = df_train['answers.fre'].explode()

In [41]:
scores_summ = pd.DataFrame([fkg_scores.values,fre_scores.values]).T

In [42]:
scores_summ.columns=['fkg','fre']

In [44]:
scores_summ.corr()

Unnamed: 0,fkg,fre
fkg,1.0,-0.89608
fre,-0.89608,1.0


In [45]:
scores_summ.describe()

Unnamed: 0,fkg,fre
count,579078.0,579078.0
mean,9.041773,65.047794
std,3.565566,14.804391
min,-15.7,-605.67
25%,6.8,56.08
50%,8.7,65.73
75%,10.9,74.9
max,172.1,206.84
