<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/data_cleaning/ELI5_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [2]:
%ls drive/MyDrive/LLMs/ELI5_dataset

compare_sents_SFT.pkl  [0m[01;34mdata[0m/  ELI5_analysis.ipynb  [01;34mresults[0m/


In [3]:
%cd drive/MyDrive/LLMs/ELI5_dataset

!pip install datasets --quiet
!pip install textstat --quiet
!pip install wandb --quiet
!pip install redditcleaner --quiet
!pip install huggingface_hub --quiet
!pip install -U sentence-transformers --quiet

/content/drive/MyDrive/LLMs/ELI5_dataset


In [4]:
import wandb, torch
import sys
import datasets
import os
import redditcleaner
import re
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from sentence_transformers import SentenceTransformer
from textstat import flesch_reading_ease as fre
from textstat import flesch_kincaid_grade as fkg
from datasets import (load_dataset,
                      load,
                      load_from_disk,
                      Dataset,
                      concatenate_datasets,
                      DatasetDict)
from itertools import compress
from tqdm import tqdm
from collections import defaultdict
from itertools import combinations
import random

device = "cuda" if torch.cuda.is_available() else "cpu"

%matplotlib inline

In [5]:
notebook_login()

#creates the 'results' and 'data' directories, if they don't exist.
if not os.path.exists('results'):
    os.makedirs('results')

if not os.path.exists('data'):
    os.makedirs('data')

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Filtering Data

## Definitions

In [None]:
def replace_url_i(input_string):
    # Define the regular expression pattern to match "_url_i_" where i is an arbitrary integer

    pattern_1 = r"_url_\d+_"  # \d+ matches one or more digits
    pattern_2 = r"_Url_\d+_"
    pattern_3 = r"_URL_\d+_"

    # Use re.sub() to replace all occurrences of the pattern with an empty string
    output_string = re.sub(pattern_1, "", input_string)
    output_string = re.sub(pattern_2, "", output_string)
    output_string = re.sub(pattern_3, "", output_string)

    return output_string

def preprocess_example(example):

    answers = example['answers']['text']
    answers = [redditcleaner.clean(answer) for answer in answers]
    answers = [re.sub('>.*?\n',' ',answer) for answer in answers]
    answers = [' '.join(answer.lower().split()) for answer in answers]
    answers = [replace_url_i(answer) for answer in answers]
    answers = [answer for answer in answers if len(answer.split())>=20]
    example['answers']['text'] = answers

    title = example['title']
    title = redditcleaner.clean(title)
    title = ' '.join(title.split())
    title = replace_url_i(title)
    example['title'] = title

    selftext = example['selftext']
    selftext = redditcleaner.clean(selftext)
    selftext = ' '.join(selftext.lower().split())
    selftext = replace_url_i(selftext)
    example['selftext'] = selftext

    return example

def preprocess_data(dataset):
    dataset = dataset.map(preprocess_example)
    return dataset

class score_cutoff_wrapper:
    def __init__(self,cutoff):
        self.cutoff = cutoff

    def score_cutoff_ex(self,example):
        scores = example['answers']['score']
        idxs = list(np.array(scores) >= self.cutoff)
        for key, val in example['answers'].items():
            example['answers'][key] = list(compress(val,idxs))

        return example


def score_cutoff(dataset,cutoff):
    cutoff = score_cutoff_wrapper(cutoff)
    ds = dataset.map(cutoff.score_cutoff_ex)
    ds = ds.filter(lambda post: len(post['answers']['score'])>0)

    return ds

def flesch_scores(example):

    fre_scores = [fre(text) for text in example['answers']['text']]
    fkg_scores = [fkg(text) for text in example['answers']['text']]
    example['answers']['fre'] = fre_scores
    example['answers']['fkg'] = fkg_scores

    return example

class flesch_scores_filter_wrapper:
    def __init__(self,fre_cutoff, fkg_cutoff):
        self.fre_cutoff = fre_cutoff
        self.fkg_cutoff = fkg_cutoff

    def flesch_scores_filter(self,example):

        fre_scores = example['answers']['fre']
        fkg_scores = example['answers']['fkg']

        idxs = [True if (fre_scores[i]>=self.fre_cutoff
                         and fkg_scores[i]<self.fkg_cutoff) else False
                for i in range(len(fre_scores))]

        for key, val in example['answers'].items():
            example['answers'][key] = list(compress(val,idxs))

        return example

def flesch_scores_cutoff(dataset,fre_cutoff=60,fkg_cutoff=9):
    filter = flesch_scores_filter_wrapper(fre_cutoff, fkg_cutoff)
    ds = dataset.map(filter.flesch_scores_filter)
    ds = ds.filter(lambda post: len(post['answers']['score'])>0)

    return ds

## Code

In [None]:
dataset = load_dataset("vblagoje/lfqa")

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/687M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
dataset_preprocessed = preprocess_data(dataset)
dataset_preprocessed.save_to_disk('./data/preprocessing')

Map:   0%|          | 0/226147 [00:00<?, ? examples/s]

Map:   0%|          | 0/3020 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/226147 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3020 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
not_qus = ['IAMA','AMA','ama:','megathread','Megathread',
           'Discussion Thread','Discussion thread',
           'discussion Thread','discussion thread',
           'Ask Anything Wednesday','Free-for-All',
           'Free-For-All','[META]','Monday Methods',
           'Tuesday Trivia','Monday Mysteries',
           'Theory Thursday','Monday Mish-Mash',
           'Media Mondays','[META]','Wednesday Week in History',
           'Saturday Popular Questions','Ask Anything Wednesday',
           'Thursday Focus Historical Fiction']

qu_reqs = ['who','what','where','why','when','how','?']

In [None]:
ds_reduced = dataset_preprocessed.filter(lambda post:
                                         not (all(qu_req not in post['title'].lower() for qu_req in qu_reqs)
                                         and all(qu_req not in post['selftext'].lower() for qu_req in qu_reqs)))

ds_reduced = ds_reduced.filter(lambda post:
                                       not (any(nq in post['title'] for nq in not_qus)))

Filter:   0%|          | 0/226147 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3020 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/221802 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2983 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9734 [00:00<?, ? examples/s]

In [None]:
ds_reduced = ds_reduced.map(flesch_scores)

ds_reduced.save_to_disk('./data/reduced_dataset')

Map:   0%|          | 0/221322 [00:00<?, ? examples/s]

Map:   0%|          | 0/2963 [00:00<?, ? examples/s]

Map:   0%|          | 0/9696 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/221322 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2963 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/9696 [00:00<?, ? examples/s]

In [None]:
ds_filtered = score_cutoff(ds_reduced,4)
ds_filtered = flesch_scores_cutoff(ds_filtered,fkg_cutoff=9)

ds_filtered_mult = ds_filtered.filter(lambda post : len(post['answers']['score'])>=2)
ds_filtered_sing = ds_filtered.filter(lambda post : len(post['answers']['score'])==1)

In [None]:
ds_filtered_mult.save_to_disk('./data/filtered/mult_ans')
ds_filtered_sing.save_to_disk('./data/filtered/sing_ans')

Saving the dataset (0/1 shards):   0%|          | 0/29978 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1349 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2381 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/56730 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/595 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2684 [00:00<?, ? examples/s]

# SFT, RM, RL Datasets

## Definitions

In [None]:
def split_idxs(example):
    scores = example['answers']['score']
    scores_unique = sorted(set(scores),reverse=True)
    pref_scores_idxs = [scores.index(sc)for sc in scores_unique]
    dupl_scores_idxs = [n for n in range(len(scores)) if n not in pref_scores_idxs]

    example['pref_idxs'] = pref_scores_idxs
    example['dupl_scores_idxs'] = dupl_scores_idxs

    return example

def mult_ans_RM_proc(example):
    pref_scores_idxs = example['pref_idxs']
    for key, val in example['answers'].items():
        example['answers'][key] = [example['answers'][key][i] for i in pref_scores_idxs]
    return example

def mult_ans_SFT_proc(example):
    dupl_scores_idxs = example['dupl_scores_idxs']
    for key, val in example['answers'].items():
        example['answers'][key] = [example['answers'][key][i] for i in dupl_scores_idxs]
    return example

## Code

In [None]:
ds_filtered_mult = load_from_disk('./data/filtered/mult_ans')
ds_filtered_sing = load_from_disk('./data/filtered/sing_ans')

In [None]:
ds_filtered_indexed = ds_filtered_mult.map(split_idxs)

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_RM = ds_filtered_indexed.map(mult_ans_RM_proc)
ds_RM = ds_RM.filter(lambda x: len(x['answers']['score'])>0)

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

Filter:   0%|          | 0/29978 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1349 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_SFT_mult = ds_filtered_indexed.map(mult_ans_SFT_proc)
ds_SFT_mult = ds_SFT_mult.filter(lambda x: len(x['answers']['score'])>0)

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

Filter:   0%|          | 0/29978 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1349 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_SFT = datasets.DatasetDict()

In [None]:
for key in ['train','validation','test']:
    ds_SFT[key] = datasets.concatenate_datasets([ds_SFT_mult[key], ds_filtered_sing[key]])

In [None]:
q_ids_taken = []

for ds_ in ds_SFT, ds_RM:
    for split in ds_:
        q_ids_taken.extend(ds_[split]['q_id'])

q_ids_taken = set(q_ids_taken)

In [None]:
ds_RL = dataset.filter(lambda post: post['q_id'] not in q_ids_taken)
ds_RL = concatenate_datasets([ds for ds in ds_RL.values()])

Filter:   0%|          | 0/226147 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3020 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
%ls

compare_sents_SFT.pkl  [0m[01;34mdata[0m/  ELI5_analysis.ipynb  [01;34mresults[0m/


In [None]:
ds_SFT.save_to_disk('./data/ds_SFT')
ds_RM.save_to_disk('./data/ds_RM')
ds_RL.save_to_disk('./data/ds_RL')

Saving the dataset (0/1 shards):   0%|          | 0/63424 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/940 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3157 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/29978 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1349 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2381 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/145450 [00:00<?, ? examples/s]

# Embedding Data

## Definitions

In [None]:
def combine_title_body(example):
    title = ' '.join(example['title'].split())
    selftext = ' '.join(example['selftext'].split())

    combined = title +'\n'+selftext

    return {'title_body':combined}

## Code

In [None]:
ds_SFT = load_from_disk('./data/ds_SFT')
ds_RM = load_from_disk('./data/ds_RM')
ds_RL = load_from_disk('./data/ds_RL')

In [None]:
ds_SFT_emb = ds_SFT.map(combine_title_body)
ds_RM_emb = ds_RM.map(combine_title_body)
ds_RL_emb = ds_RL.map(combine_title_body)

In [None]:
model = SentenceTransformer('all-mpnet-base-v2')

#for ds_ in [ds_SFT_emb, ds_RM_emb, ds_RL_emb]:
ds_SFT_emb = ds_SFT_emb.map(lambda x:{'qu_emb':
                        model.encode(x['title_body'],
                                     batch_size=32)},
              batched=True)

ds_RM_emb = ds_RM_emb.map(lambda x:{'qu_emb':
                        model.encode(x['title_body'],
                                     batch_size=32)},
              batched=True)

ds_RL_emb = ds_RL_emb.map(lambda x:{'qu_emb':
                        model.encode(x['title_body'],
                                     batch_size=32)},
              batched=True)

#ds_SFT_emb = ds_SFT_emb.map(lambda x:{'qu_emb':
#                        model.encode(x['title_body'],
#                                     batch_size=32)},
#              batched=True)
#
#ds_RM_emb = ds_RM_emb.map(lambda x:{'qu_emb':
#                        model.encode(x['title_body'],
#                                     batch_size=32)},
#              batched=True)

ds_RM_emb = ds_RM_emb.map(lambda x:{'ans_emb':
                        model.encode(x['answers']['text'],
                                     batch_size=32)})

Downloading (…)a8e1d/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)0bca8e1d/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)e1d/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)a8e1d/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading (…)8e1d/train_script.py:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)bca8e1d/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

Map:   0%|          | 0/940 [00:00<?, ? examples/s]

Map:   0%|          | 0/29978 [00:00<?, ? examples/s]

Map:   0%|          | 0/1349 [00:00<?, ? examples/s]

Map:   0%|          | 0/2381 [00:00<?, ? examples/s]

In [None]:
ds_SFT_emb.save_to_disk('./data/embedded/ds_SFT_emb')
ds_RM_emb.save_to_disk('./data/embedded/ds_RM_emb')
ds_RL_emb.save_to_disk('./data/embedded/ds_RL_emb')

Saving the dataset (0/1 shards):   0%|          | 0/63424 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/940 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3157 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/29978 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1349 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2381 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/145450 [00:00<?, ? examples/s]

In [None]:
%ls data

compare_sents_SFT.pkl  [0m[01;34mds_RL[0m/  [01;34mds_SFT[0m/    [01;34mfiltered[0m/       [01;34mreduced_dataset[0m/
ds_compare_SFT.pkl     [01;34mds_RM[0m/  [01;34membedded[0m/  [01;34mpreprocessing[0m/


# Checking Data Leakage

In [73]:
from datetime import datetime

now = datetime.now()

current_time = now.strftime("%m/%d/%y_%H:%M:%S")
print("Current Time =", current_time)

Current Time = 08/01/23_04:03:14


## SFT

In [4]:
ds_SFT_emb = load_from_disk('./data/embedded/ds_SFT_emb')
ds_SFT_emb.set_format('torch')

In [13]:
vecs_SFT = {}
for split in ['train','validation','test']:
    vecs_SFT[split] = ds_SFT_emb[split]['qu_emb']
    vecs_SFT[split] /= torch.sqrt(torch.sum(vecs_SFT[split]**2,
                                            dim=1,
                                            keepdim=True))

In [7]:
overlap_SFT = {}
idxs_SFT = {}
splits = ['train','validation','test']

for j in range(1,3):
    for i in range(j):
        overlap_SFT[(splits[i],splits[j])] = torch.matmul(
            vecs_SFT[splits[i]],
            vecs_SFT[splits[j]].T)

        idxs_SFT[(splits[i],splits[j])] = torch.where((overlap_SFT[(splits[i],splits[j])])>=0.6)

In [8]:
ds_compare_SFT = {}

for j in range(1,3):
    for i in range(j):

        idxs_1,idxs_2 = idxs_SFT[(splits[i],splits[j])]

        idxs_1 = idxs_1.numpy()
        idxs_2 = idxs_2.numpy()

        Q_1 = (ds_SFT_emb[splits[i]].select(idxs_1))['title_body']

        Q_2 = (ds_SFT_emb[splits[j]].select(idxs_2))['title_body']

        overlaps = overlap_SFT[(splits[i],
                                splits[j])][idxs_SFT[(splits[i],
                                                      splits[j])]]

        ds_compare_SFT[splits[i],splits[j]] = Dataset.from_dict(
            {'overlaps':overlaps.numpy(),
             'idxs_1':idxs_1,
             'Q_1':Q_1,
             'idxs_2':idxs_2,
             'Q_2':Q_2}
        )

with open('./data/ds_compare_SFT.pkl', 'wb') as f:
    pickle.dump(ds_compare_SFT,f)

In [9]:
with open('./data/ds_compare_SFT.pkl', 'rb') as f:
    ds_compare_SFT = pickle.load(f)

In [18]:
df_overlaps_SFT = {}
splits = ['train','validation','test']

for j in range(1,3):
    for i in range(j):
        df_overlaps_SFT[splits[i],splits[j]]=pd.DataFrame(ds_compare_SFT[splits[i],splits[j]])

In [32]:
rem_train_pt_1 = df_overlaps_SFT['train','validation']['idxs_1']
rem_train_pt_1 = rem_train_pt_1.values
rem_train_pt_1 = set(rem_train_pt_1)

rem_train_pt_2 = df_overlaps_SFT['train','test']['idxs_1']
rem_train_pt_2 = rem_train_pt_2.values
rem_train_pt_2 = set(rem_train_pt_2)

rem_train = rem_train_pt_1.union(rem_train_pt_2)

keep_train = set(range(len(ds_SFT_emb['train'])))-rem_train

In [33]:
rem_test = df_overlaps_SFT['validation','test']['idxs_2']
rem_test = rem_test.values
rem_test = set(rem_test)

keep_test = set(range(len(ds_SFT_emb['test'])))-rem_test

In [35]:
ds_SFT_clean = DatasetDict()

ds_SFT_clean['train'] = ds_SFT_emb['train'].select(keep_train)

ds_SFT_clean['validation'] = ds_SFT_emb['validation']

ds_SFT_clean['test'] = ds_SFT_emb['test'].select(keep_test)

ds_SFT_clean.save_to_disk('./data/ds_SFT_clean')

## RM

### Set-up

In [None]:
def make_pairs(example):
    answers = example['answers']['text']
    scores = example['answers']['score']
    embds = example['ans_emb']

    sc_ans = tuple(zip(scores,answers,
                       embds
                       ))
    pairs = tuple(combinations(sc_ans,2))

    if len(pairs)>10:
        pairs = random.sample(pairs,10)

    pairs = list(map(lambda x: sorted(x,key=lambda y:y[0],
                                 reverse=True),pairs))

    pairs_text = [((sc_pair[0][1]),sc_pair[1][1]) for sc_pair in pairs]
    pairs_emb = [((sc_pair[0][2]),sc_pair[1][2]) for sc_pair in pairs]

    example['pairs'] = {}
    example['pairs']['text'] = pairs_text
    example['pairs']['embs'] = pairs_emb

    return example

    #[((num,pair,emb),(num,pair,emb)), ((num,pair),(num,pair))]

In [None]:
ds_RM_emb = load_from_disk('./data/embedded/ds_RM_emb')
ds_RM_emb
vecs_RM = {}

In [None]:
ds_RM_pairs = ds_RM_emb.map(lambda x:make_pairs(x),
                            remove_columns=ds_RM_emb['train'].column_names)

In [None]:
ds_RM_pairs.save_to_disk('./data/RM_pairs')

In [None]:
ds_RM_pairs.set_format('torch')
ds_RM_emb.set_format('torch')

### Questions

In [41]:
ds_RM_pairs = load_from_disk('./data/RM_pairs')
ds_RM_emb = load_from_disk('./data/embedded/ds_RM_emb')

ds_RM_emb.set_format('torch')
ds_RM_pairs.set_format('torch')

In [14]:
vecs_RM_Q = {}
for split in ['train','validation','test']:
    vecs_RM_Q[split] = ds_RM_emb[split]['qu_emb']
    vecs_RM_Q[split] /= torch.sqrt(torch.sum(vecs_RM_Q[split]**2,
                                            dim=1,
                                            keepdim=True))

In [46]:
overlap_RM_qus = {}
splits = ['train','validation','test']
idxs_RM_Q = {}

for j in range(1,3):
    for i in range(j):
        overlap_RM_qus[(splits[i],splits[j])] = torch.matmul(
            vecs_RM_Q[splits[i]],
            vecs_RM_Q[splits[j]].T)

        idxs_RM_Q[(splits[i],splits[j])] = torch.where(
            (overlap_RM_qus[(splits[i],splits[j])])>=0.6)

In [47]:
ds_compare_RM_Q = {}

for j in range(1,3):
    for i in range(j):

        idxs_1,idxs_2 = idxs_RM_Q[(splits[i],splits[j])]

        idxs_1 = idxs_1.numpy()
        idxs_2 = idxs_2.numpy()

        Q_1 = (ds_RM_emb[splits[i]].select(idxs_1))['title_body']

        Q_2 = (ds_RM_emb[splits[j]].select(idxs_2))['title_body']

        overlaps = overlap_RM_qus[(splits[i],splits[j])][
            idxs_RM_Q[(splits[i],splits[j])]]

        ds_compare_RM_Q[splits[i],splits[j]] = Dataset.from_dict(
            {'overlaps':overlaps.numpy(),
             'idxs_1':idxs_1,
             'Q_1':Q_1,
             'idxs_2':idxs_2,
             'Q_2':Q_2}
        )

with open('./data/ds_compare_RM_Q.pkl', 'wb') as f:
    pickle.dump(ds_compare_RM_Q,f)

In [51]:
df_overlaps_RM_Q = {}
splits = ['train','validation','test']

for j in range(1,3):
    for i in range(j):
        df_overlaps_RM_Q[splits[i],splits[j]]=pd.DataFrame(
                            ds_compare_RM_Q[splits[i],splits[j]])

In [52]:
rem_train_RM_pt_1 = df_overlaps_RM_Q['train','validation']['idxs_1']
rem_train_RM_pt_1 = rem_train_RM_pt_1.values
rem_train_RM_pt_1 = set(rem_train_RM_pt_1)

rem_train_RM_pt_2 = df_overlaps_RM_Q['train','test']['idxs_1']
rem_train_RM_pt_2 = rem_train_RM_pt_2.values
rem_train_RM_pt_2 = set(rem_train_RM_pt_2)

rem_train_RM = rem_train_RM_pt_1.union(rem_train_RM_pt_2)

keep_train_RM = set(range(len(ds_RM_emb['train'])))-rem_train_RM

In [53]:
rem_test_RM = df_overlaps_RM_Q['validation','test']['idxs_2']
rem_test_RM = rem_test_RM.values
rem_test_RM = set(rem_test_RM)

keep_test_RM = set(range(len(ds_RM_emb['test'])))-rem_test_RM

In [None]:
ds_RM_clean = DatasetDict()

ds_RM_clean['train'] = ds_RM_emb['train'].select(keep_train_RM)
ds_RM_clean['validation'] = ds_RM_emb['validation']
ds_RM_clean['test'] = ds_RM_emb['test'].select(keep_test_RM)

ds_RM_clean.save_to_disk('./data/ds_RM_clean')

## RL

In [6]:
ds_RL_emb = load_from_disk('./data/embedded/ds_RL_emb')
ds_RL_emb.set_format('torch')

In [8]:
ds_SFT_emb = load_from_disk('./data/embedded/ds_SFT_emb')
ds_SFT_emb.set_format('torch')

ds_RM_emb = load_from_disk('./data/embedded/ds_RM_emb')
ds_RM_emb.set_format('torch')

In [9]:
vecs_RL = ds_RL_emb['qu_emb']
vecs_RL /= torch.sqrt(torch.sum(vecs_RL**2,
                                        dim=1,
                                        keepdim=True))

In [25]:
from tqdm import tqdm

In [34]:
batch_size = 5000
RL_size = vecs_RL.shape[0]
rem_RL = []
start = 0
i=0

while start < RL_size:
    print(f'Working on batch {i+1}/{RL_size//batch_size+1}')

    batch = vecs_RL[start:start+batch_size,:]

    overlap = torch.matmul(vecs_SFT['train'],
                               batch.T)
    idxs = torch.where(overlap>=.6)
    rem_RL.extend(list(idxs[1].numpy()))

    overlap = torch.matmul(vecs_RM_Q['train'],
                               batch.T)
    idxs = torch.where(overlap>=.6)
    rem_RL.extend(list(idxs[1].numpy()))

    start += batch_size
    i += 1


Working on batch 1/30
Working on batch 2/30
Working on batch 3/30
Working on batch 4/30
Working on batch 5/30
Working on batch 6/30
Working on batch 7/30
Working on batch 8/30
Working on batch 9/30
Working on batch 10/30
Working on batch 11/30
Working on batch 12/30
Working on batch 13/30
Working on batch 14/30
Working on batch 15/30
Working on batch 16/30
Working on batch 17/30
Working on batch 18/30
Working on batch 19/30
Working on batch 20/30
Working on batch 21/30
Working on batch 22/30
Working on batch 23/30
Working on batch 24/30
Working on batch 25/30
Working on batch 26/30
Working on batch 27/30
Working on batch 28/30
Working on batch 29/30
Working on batch 30/30


In [35]:
keep_RL = set(range(len(ds_RL_emb)))
keep_RL -= set(rem_RL)

ds_RL_clean = ds_RL_emb.select(keep_RL)

ds_RL_clean.save_to_disk('./data/ds_RL_clean')

Saving the dataset (0/2 shards):   0%|          | 0/140450 [00:00<?, ? examples/s]

# Scratch

In [None]:
ds_reduced = load_from_disk('./data/reduced_dataset')

In [None]:
ds_flattened = ds_reduced.flatten()
df_train = pd.DataFrame(ds_flattened['train'])

In [None]:
df_train.columns

Index(['q_id', 'title', 'selftext', 'document', 'subreddit', 'url',
       'answers.a_id', 'answers.fkg', 'answers.fre', 'answers.score',
       'answers.text', 'title_urls', 'selftext_urls', 'answers_urls'],
      dtype='object')

In [None]:
fkg_scores = df_train['answers.fkg'].explode()
fre_scores = df_train['answers.fre'].explode()

In [None]:
scores_summ = pd.DataFrame([fkg_scores.values,fre_scores.values]).T

In [None]:
scores_summ.columns=['fkg','fre']

In [None]:
scores_summ.corr()

Unnamed: 0,fkg,fre
fkg,1.0,-0.89608
fre,-0.89608,1.0


In [None]:
scores_summ.describe()

Unnamed: 0,fkg,fre
count,579078.0,579078.0
mean,9.041773,65.047794
std,3.565566,14.804391
min,-15.7,-605.67
25%,6.8,56.08
50%,8.7,65.73
75%,10.9,74.9
max,172.1,206.84


In [None]:
ds_RL_emb = load_from_disk('./data/embedded/ds_RL_emb')

In [None]:
from collections import Counter
Counter(list(map(lambda x:len(x['score']),ds_RL_emb['train']['answers'])))


Counter({4: 2375,
         2: 15590,
         3: 5253,
         1: 2302,
         10: 214,
         5: 1377,
         19: 13,
         6: 854,
         9: 305,
         13: 88,
         16: 38,
         8: 446,
         7: 642,
         24: 3,
         12: 123,
         14: 63,
         11: 162,
         20: 7,
         22: 9,
         15: 40,
         17: 23,
         42: 1,
         23: 6,
         18: 25,
         25: 5,
         21: 5,
         26: 1,
         28: 2,
         27: 2,
         31: 1,
         35: 1,
         49: 1,
         30: 1})