In [None]:
from datasets import load_dataset
import pandas as pd
import numpy as np
import en_core_web_sm
from IPython import display
from tqdm import tqdm
from collections import namedtuple
from typing import List, Tuple, NamedTuple


In [None]:
def rename_datasets(dataset):
    dataset = dataset.rename_column(dataset.column_names[0], "source")
    dataset = dataset.rename_column(dataset.column_names[1], "target")
    return dataset


def spacy_token(samples: List[str]) -> NamedTuple:
    """
    Compute number of tokens in each row.
    Input: rows of tokens, number of rows.
    Output: mean, standard deviation and median of tokens.
    """

    stats = namedtuple("stats", "mean median std")

    # lens = [len(nlp(token)) for token in tqdm(tokens)]
    tokens = list(tqdm(nlp.pipe(samples, n_process=8), total=len(samples)))
    lens = [len(token) for token in (iter(tokens))]

    lens = np.array(lens)
    stats.lens = lens
    stats.mean = np.mean(lens)
    stats.median = np.median(lens)
    stats.std = np.std(lens)

    return stats


def whitespace_token(samples: List[str]) -> NamedTuple:

    stats = namedtuple("stats", "mean median std lens")

    lens = samples.str.split().str.len()

    lens = np.array(lens)
    stats.lens = lens
    stats.mean = np.mean(lens)
    stats.median = np.median(lens)
    stats.std = np.std(lens)

    return stats


def format_tuning(dataset):
    try:
        if (
            dataset.features["source"].feature._type == "Value"
        ):  # One row of source is a line in article.
            dataset = dataset.to_pandas()

    except AttributeError:
        if len(dataset.features["source"].feature) > 1:
            dataset = pd.DataFrame(dataset["source"])
            dataset = dataset.rename(
                columns={"document": "source", "summary": "target"}
            )

    dataset["source"] = dataset["source"].str.join("")
    dataset["target"] = dataset["target"].str.join("")
    return dataset

def remove_empty(df):
    df.replace(to_replace=r'^\s*$',value=np.nan,regex=True,inplace=True)
    df = df.dropna()
    return df


def stats_cal(
    dataset,
    dataset_name: str,
    tokenization_method: str = "whitespace",
    stats_to_compute: List[str] = [
        "SampleNum",
        "mean",
        "median",
        "std",
        "compression_ratio",
    ],
) -> NamedTuple:

    stats_attr_src = namedtuple("stats_attr", stats_to_compute)
    stats_attr_tg = namedtuple("stats_attr", stats_to_compute)
    stats = namedtuple("stats", "src tg")
    stats.src = stats_attr_src
    stats.tg = stats_attr_tg
    stats.src.SampleNum = dataset.num_rows
    stats.tg.SampleNum = dataset.num_rows

    # Use pandas dataframe to process data and remove samples that contain empty strings. 

    if dataset.features["source"]._type == "Value":  # One row of source is one article.
            dataset = dataset.to_pandas()
            
    elif (
        dataset.features["source"]._type == "Sequence"
    ):  # One row of source is a line in article or combined with article, summary and id.
        dataset = format_tuning(dataset)
    
    dataset = remove_empty(dataset)
    print(dataset.shape)
    
    if tokenization_method == "whitespace":
        stats_src = whitespace_token(dataset["source"])
        stats_tg = whitespace_token(dataset["target"])
    elif tokenization_method == "spacy":
        stats_src = spacy_token(dataset["source"])
        stats_tg = spacy_token(dataset["target"])

    if "SampleNum" in stats_to_compute:
        print(
            f"[{dataset_name}] Number of samples of article or summary: {stats.src.SampleNum}"
        )
    if "mean" in stats_to_compute:
        stats.src.mean = stats_src.mean
        stats.tg.mean = stats_tg.mean
        print(
            f"[{dataset_name}] Mean of article & summary: {stats.src.mean:.2f}, {stats.tg.mean:.2f}"
        )
    if "median" in stats_to_compute:
        stats.src.median = stats_src.median
        stats.tg.median = stats_tg.median
        print(
            f"[{dataset_name}] Median of article & summary: {stats.src.median:.2f}, {stats.tg.median:.2f}"
        )
    if "std" in stats_to_compute:
        stats.src.std = stats_src.std
        stats.tg.std = stats_tg.std
        print(
            f"[{dataset_name}] Standard Deviation of article & summary: {stats.src.std:.2f}, {stats.tg.std:.2f}"
        )
    if "compression_ratio" in stats_to_compute:
        stats.src.compression_ratio = np.mean(stats_src.lens / stats_tg.lens)
        stats.tg.compression_ratio = stats.src.compression_ratio
        print(
            f"[{dataset_name}] ratio of article/summary: {stats.src.compression_ratio:.2f}"
        )

    # print(len(stats_tg.lens[np.where(stats_tg.lens == 0)]))
    return stats


def print_stats(
    dataset, dataset_name: str, tokenization_method: str = "whitespace"
) -> None:

    print(f"********{tokenization_method}********")
    stats = stats_cal(dataset, dataset_name, tokenization_method)


def load_data(dataset_name: str, version: str, split_: str = "train"):
    dataset = load_dataset(dataset_name, version, split=split_)
    if dataset_name == "wiki_lingua":
        dataset = dataset.rename_column("article", "source")
    elif dataset_name == "scitldr":
        pass
    else:
        dataset = rename_datasets(dataset)

    return dataset


nlp = en_core_web_sm.load(
    disable=("tok2vec", "tagger", "lemmatizer", "ner")
)  # Disabling components for only tokenization use.

In [None]:
cnn_train = load_data("cnn_dailymail", "3.0.0", "train")

In [12]:
cnn_test = load_data("cnn_dailymail", "3.0.0", "test")

Reusing dataset cnn_dailymail (/home/jli/working_dir/datasets_cache/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)


In [None]:
xsum_test = load_data("xsum", "1.2.0", "test")

In [13]:
print_stats(cnn_test, "cnn_dailymail")

********whitespace********
(11490, 3)
[cnn_dailymail] Number of samples of article or summary: 11490
[cnn_dailymail] Mean of article & summary: 683.51, 55.01
[cnn_dailymail] Median of article & summary: 613.00, 51.00
[cnn_dailymail] Standard Deviation of article & summary: 348.39, 22.52
[cnn_dailymail] ratio of article/summary: 13.45


In [9]:
print_stats(xsum_test, "xsum")

********whitespace********
(11333, 3)
[xsum] Number of samples of article or summary: 11334
[xsum] Mean of article & summary: 376.18, 21.10
[xsum] Median of article & summary: 295.00, 21.00
[xsum] Standard Deviation of article & summary: 308.22, 5.32
[xsum] ratio of article/summary: 18.98


In [14]:
print_stats(xsum_test, "xsum", "spacy")

********spacy********
(11333, 3)


100%|██████████| 11333/11333 [01:27<00:00, 129.88it/s]
100%|██████████| 11333/11333 [00:10<00:00, 1080.93it/s]

[xsum] Number of samples of article or summary: 11334
[xsum] Mean of article & summary: 457.51, 23.95
[xsum] Median of article & summary: 359.00, 24.00
[xsum] Standard Deviation of article & summary: 377.91, 6.00
[xsum] ratio of article/summary: 20.21





In [None]:
# tokens = list(tqdm(nlp.pipe(xsum_train["document"], n_process=8), total=len(xsum_train["document"])))

In [None]:
# df[df.isnull().values == True]

In [None]:
# wiki_en = load_dataset("GEM/wiki_lingua", "en", split="train")
# df = pd.DataFrame(wiki_en)
# df.replace(to_replace=r'^\s*$',value=np.nan,regex=True,inplace=True)
# display.display(df.isnull().sum())

In [11]:
cnn_train.features

{'source': Value(dtype='string', id=None),
 'target': Value(dtype='string', id=None),
 'id': Value(dtype='string', id=None)}