In [None]:
# Suppress every Python warning for the rest of the session. This keeps cell
# output short, but it also hides deprecation notices from the libraries below.
import warnings
warnings.filterwarnings("ignore")

from dotenv import load_dotenv
import os
# Load environment variables via python-dotenv (presumably from a local .env file).
load_dotenv()
# Hugging Face access token; os.getenv returns None when HF_TOKEN_read is not set.
HF_TOKEN_read = os.getenv("HF_TOKEN_read")

In [None]:
from huggingface_hub import login
# Authenticate against the Hugging Face Hub with the token read from the environment.
# NOTE(review): HF_TOKEN_read is None when the variable is missing — confirm how
# login() behaves in that case before running this unattended.
login(token = HF_TOKEN_read)

In [None]:
from datasets import load_dataset

# Load the raw corpus by its repository id and print the returned object
# to inspect what was loaded.
pretraining_dataset = load_dataset("BaekSeungJu/Ophthalmology-PubMed-Corpus")
print(pretraining_dataset)

In [4]:
# Keep only the "text" column; every cleaning step below reads x['text'] and nothing else.
dataset = pretraining_dataset.select_columns(
    ["text"]
)

In [None]:
# Spot check: print the text of the train-split example at index 9.
print(dataset["train"][9]["text"])

## 2. Data cleaning

In the cells below, you'll carry out the following cleaning steps:
1. Filter out samples that are too short
2. Remove repetitions within a single text example
3. Remove duplicated documents
4. Quality filter to remove non-English texts 

In [None]:
# Row count before any filtering, as a baseline for the cleaning steps below.
dataset.num_rows

### Remove repeated text within training examples

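Here you'll filter out examples that contain too much repeated text. An example is dropped when more than 30% of its paragraphs, or more than 20% of its characters, are repeats of an earlier paragraph in the same example.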

In [7]:
def find_duplicates(paragraphs):
    """
    Count how often items of `paragraphs` repeat an earlier item.

    The first occurrence of an item is never counted. Every later
    occurrence adds one to the element count and `len(item)` to the
    character count.

    Returns a tuple `(duplicate_elements, duplicate_chars)`.
    """
    seen = set()
    repeat_lengths = []
    for paragraph in paragraphs:
        if paragraph not in seen:
            # First time we meet this paragraph: remember it, count nothing.
            seen.add(paragraph)
            continue
        repeat_lengths.append(len(paragraph))
    return len(repeat_lengths), sum(repeat_lengths)

In [8]:
import re

def paragraph_repetition_filter(x, max_paragraph_ratio=0.3, max_char_ratio=0.2):
    """
    Returns False iff a page has too many repetitions.

    Args:
        x: A dataset example; only `x['text']` is read.
        max_paragraph_ratio: Highest tolerated share of paragraphs that
            repeat an earlier paragraph of the same text.
        max_char_ratio: Highest tolerated share of the text's characters
            that belong to such repeated paragraphs.

    Returns:
        True to keep the example, False to drop it. An empty (or missing)
        text has no repetitions and is therefore kept; removing empty
        examples is left to a separate length filter.
    """
    text = x['text']
    if not text:
        # Nothing to measure, and len(text) == 0 would make the character
        # ratio below divide by zero.
        return True
    paragraphs = re.split(r"\n{2,}", text.strip())                        # Split by paragraphs (2 or more newlines)
    paragraphs_duplicates, char_duplicates = find_duplicates(paragraphs)  # Find number of duplicates in paragraphs
    # str.split always yields at least one item, so len(paragraphs) >= 1 here.
    if paragraphs_duplicates / len(paragraphs) > max_paragraph_ratio:
        return False
    # The denominator is the raw text length (including paragraph separators),
    # as in the original heuristic.
    if char_duplicates / len(text) > max_char_ratio:
        return False
    return True

In [None]:
# Keep only the examples for which paragraph_repetition_filter returns True.
# load_from_cache_file=False asks `datasets` to recompute the filter rather than
# reuse a previously cached result.
dataset = dataset.filter(
    paragraph_repetition_filter,
    load_from_cache_file=False
)

# Row count after the repetition filter.
dataset.num_rows

### Deduplication

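In this section, you'll remove duplicate examples from the entire dataset (in contrast to the previous step, where you were only looking for repeated text within each example). Two examples count as duplicates when their text is exactly identical; only the first occurrence is kept.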

In [None]:
def deduplication(ds):
    """
    Drop every example whose `text` exactly matches an earlier example.

    The first occurrence of each text is kept; all later occurrences are
    filtered out. Returns the result of `ds.filter(...)`.
    """
    seen_texts = set()

    def is_first_occurrence(example):
        """Return True only the first time this exact text is encountered."""
        text = example['text']
        if text in seen_texts:
            return False
        seen_texts.add(text)
        return True

    # num_proc=1: presumably required so that every example is checked against
    # the same in-memory `seen_texts` set.
    return ds.filter(is_first_occurrence, load_from_cache_file=False, num_proc=1)

# Remove exact-duplicate texts, then show the row count after deduplication.
dataset = deduplication(dataset)
dataset.num_rows

### Quality filter - Language

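Here you'll remove any text examples that are in a language other than English. The code here uses a language-identification model loaded with the fastText library; an example is kept only when the model's top prediction is English with a confidence score above 0.4. You can read about fastText [here](https://fasttext.cc/).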

In [None]:
import urllib
from fasttext.FastText import _FastText

def english_language_filter(ds, model_path='./models/L2_language_model.bin',
                            min_score=0.4, language="en"):
    """
    Keep only the examples whose text is detected as `language`.

    Args:
        ds: Dataset object exposing `.filter(...)`; only `x['text']` of each
            example is read.
        model_path: Path of the fastText language-identification model file.
        min_score: The top prediction's score must be strictly greater than
            this value for the example to be kept.
        language: Language code to keep (the part of the predicted label
            that follows the last "__"). Change this if building a model
            in another language.

    Returns:
        The filtered dataset returned by `ds.filter(...)`.
    """
    # load language detection model
    model = _FastText(model_path)

    def is_english(x):
        """Return True when the example's text is confidently in `language`."""
        # The model is fed a single line. Join lines with a space (not the
        # empty string) so the last word of one line is not glued to the
        # first word of the next.
        text = (x['text'] or "").replace("\n", " ")
        if not text.strip():
            # Blank or missing text carries no language signal; drop it
            # without consulting the model.
            return False

        # Predict language of the text and probability
        labels, scores = model.predict(text)
        if len(labels) == 0:
            return False

        # Labels look like "__label__en"; keep what follows the last "__".
        detected = labels[0].rsplit("__", 1)[-1]
        # Index the score explicitly and coerce to a plain bool so the
        # predicate never returns an array-like object.
        return bool(scores[0] > min_score and detected == language)

    ds = ds.filter(is_english, load_from_cache_file=False, num_proc=1)
    return ds

# Drop examples that are not detected as English.
dataset = english_language_filter(dataset)

# Row count after the language filter.
dataset.num_rows

In [None]:
# Print the dataset object after all cleaning steps as a final summary.
print(dataset)

## 3. Save the dataset to disk

Read more about the parquet data format [here](https://parquet.apache.org/).

In [None]:
# Output folder for the cleaned corpus (relative to the notebook's working directory).
directory = "./Pre-Training-Dataset"

# exist_ok=True makes this a no-op when the folder is already there, so no
# separate existence check is needed and re-running the cell is safe.
os.makedirs(directory, exist_ok=True)

# Write the cleaned train split as a single Parquet file.
file_path = os.path.join(directory, "Preprocessed_pretrain_Dataset.parquet")
dataset["train"].to_parquet(file_path)