In [None]:

from datasets import load_dataset, Dataset
import pandas as pd
from pathlib import Path

# Directory (relative to this notebook) where the generated LaTeX table is written.
paper_path = Path("../../papers") / "Distillation-MI-ICLR" / "tables" / "nlp"



In [None]:

def load_specter():
    """Load sentences from the ``embedding-data/specter`` training split.

    For every example only the first two entries of its ``set`` list are
    kept; an example with fewer than two entries raises ``IndexError``.
    NOTE(review): if the sets hold more than two texts, the remaining
    entries are ignored here -- confirm that this is intended.

    :return: DataFrame with a single ``text`` column; exact duplicates are
        dropped (first occurrence kept, original positional index preserved).
    """
    train_split = load_dataset("embedding-data/specter")["train"]

    texts = []
    for example in train_split:
        entries = example["set"]
        first, second = entries[0], entries[1]
        texts.extend((first, second))

    return pd.DataFrame(texts, columns=["text"]).drop_duplicates()


def load_amazon_qa():
    """Load queries and positive answers from ``embedding-data/Amazon-QA``.

    Each training example contributes its ``query`` followed by every entry
    of its ``pos`` list, in order.

    :return: DataFrame with a single ``text`` column, exact duplicates removed.
    """
    train_split = load_dataset("embedding-data/Amazon-QA")["train"]

    texts = [
        text
        for example in train_split
        for text in [example["query"], *example["pos"]]
    ]

    return pd.DataFrame(texts, columns=["text"]).drop_duplicates()


def load_simple_wiki():
    """Load every sentence of the ``embedding-data/simple-wiki`` training split.

    All entries of each example's ``set`` list are flattened in order.

    :return: DataFrame with a single ``text`` column, exact duplicates removed.
    """
    train_split = load_dataset("embedding-data/simple-wiki")["train"]
    sentences = [sentence for example in train_split for sentence in example["set"]]
    frame = pd.DataFrame(sentences, columns=["text"])
    return frame.drop_duplicates()


def load_QQP_triplets():
    """Load all questions from the ``embedding-data/QQP_triplets`` training split.

    Each example's ``set`` dict contributes its ``query`` first, then every
    ``pos`` entry, then every ``neg`` entry.

    :return: DataFrame with a single ``text`` column, exact duplicates removed.
    """
    train_split = load_dataset("embedding-data/QQP_triplets")["train"]

    questions = []
    for example in train_split:
        triplet = example["set"]
        questions.append(triplet["query"])
        questions.extend(triplet["pos"])
        questions.extend(triplet["neg"])

    return pd.DataFrame(questions, columns=["text"]).drop_duplicates()


def load_sentence_compression():
    """Load every sentence of ``embedding-data/sentence-compression`` (train split).

    The entries of each example's ``set`` list are flattened in order.

    :return: DataFrame with a single ``text`` column, exact duplicates removed.
    """
    train_split = load_dataset("embedding-data/sentence-compression")["train"]

    sentences = []
    for example in train_split:
        sentences.extend(example["set"])

    deduplicated = pd.DataFrame(sentences, columns=["text"]).drop_duplicates()
    return deduplicated


def load_altlex():
    """Load every sentence of the ``embedding-data/altlex`` training split.

    :return: DataFrame with a single ``text`` column holding the flattened
        entries of each example's ``set`` list, exact duplicates removed.
    """
    hf_dataset = load_dataset("embedding-data/altlex")
    flattened = [entry for record in hf_dataset["train"] for entry in record["set"]]
    return pd.DataFrame(flattened, columns=["text"]).drop_duplicates()


def load_agnews():
    """Load the article texts of the ``fancyzhx/ag_news`` training split.

    Duplicates are not removed.

    :return: DataFrame with a single ``text`` column.
    """
    train_split = load_dataset("fancyzhx/ag_news")["train"]
    return train_split.to_pandas()[["text"]]


def load_sst2():
    """Load the sentences of the ``stanfordnlp/sst2`` training split.

    The ``sentence`` column is renamed to ``text``; duplicates are kept.

    :return: DataFrame with a single ``text`` column.
    """
    frame = load_dataset("stanfordnlp/sst2")["train"].to_pandas()
    return frame[["sentence"]].rename(columns={"sentence": "text"})


def load_dair_emotion():
    """Load the texts of the ``unsplit`` config of ``dair-ai/emotion``.

    :return: DataFrame with a single ``text`` column (duplicates kept).
    """
    frame = load_dataset("dair-ai/emotion", "unsplit")["train"].to_pandas()
    return frame[["text"]]


def load_snli(dedup=False):
    """Load premises and hypotheses from the ``stanfordnlp/snli`` training split.

    All premises come first, followed by all hypotheses. They are stacked into
    one ``text`` column with a fresh, unique ``0..n-1`` index.

    NOTE(review): SNLI presumably reuses each premise for several hypotheses
    (confirm against the data), so the premise half may contain many repeated
    sentences. The ``embedding-data`` loaders in this notebook drop duplicates.
    Pass ``dedup=True`` to count SNLI the same way. The default keeps every row,
    so the existing sample counts are unchanged.

    :param dedup: if True, drop exact duplicate texts (first occurrence kept).
    :return: DataFrame with a single ``text`` column.
    """
    train_split = load_dataset("stanfordnlp/snli")["train"]

    # Convert via to_pandas() like the other Hugging Face loaders in this notebook.
    frame = train_split.to_pandas()

    # ignore_index: without it both halves carry labels 0..n-1, so every
    # index label would appear twice in the result.
    df = pd.concat(
        [frame["premise"], frame["hypothesis"]], ignore_index=True
    ).to_frame(name="text")

    if dedup:
        df = df.drop_duplicates()

    return df


def tweet_eval():
    """Load the tweet texts of the ``emoji`` config of ``cardiffnlp/tweet_eval``.

    :return: DataFrame with a single ``text`` column (duplicates kept).
    """
    train_split = load_dataset("cardiffnlp/tweet_eval", "emoji")["train"]
    tweets = train_split.to_pandas()
    return tweets[["text"]]


def load_imdb():
    """Load the review texts of the ``stanfordnlp/imdb`` training split.

    :return: DataFrame with a single ``text`` column (duplicates kept).
    """
    reviews = load_dataset("stanfordnlp/imdb")["train"].to_pandas()
    return reviews.loc[:, ["text"]]


In [None]:
# Load every training corpus once and tabulate how many texts each one yields.

# Display name -> (loaded DataFrame, Hugging Face dataset id used to build the URL)
datasets = {
    "SPECTER": (load_specter(), "embedding-data/SPECTER"),
    "Amazon-QA": (load_amazon_qa(), "embedding-data/Amazon-QA"),
    "Simple-wiki": (load_simple_wiki(), "embedding-data/simple-wiki"),
    "QQP_triplets": (load_QQP_triplets(), "embedding-data/QQP_triplets"),
    "Sentence-compression": (load_sentence_compression(), "embedding-data/sentence-compression"),
    "Altlex": (load_altlex(), "embedding-data/altlex"),
    "AG-news": (load_agnews(), "fancyzhx/ag_news"),
    "SST2": (load_sst2(), "stanfordnlp/sst2"),
    "DAIR-emotion": (load_dair_emotion(), "dair-ai/emotion"),
    "SNLI": (load_snli(), "stanfordnlp/snli"),
    "Tweet_eval": (tweet_eval(), "cardiffnlp/tweet_eval"),
    "IMDB": (load_imdb(), "stanfordnlp/imdb"),
}

# Base URL of the Hugging Face dataset hub.
hfbase = "https://huggingface.co/datasets/"

# One (Name, LaTeX \url{...}, row count) record per dataset.
data = [
    (name, f"\\url{{{hfbase}{url}}}", len(frame))
    for name, (frame, url) in datasets.items()
]

df = pd.DataFrame(data, columns=["Name", "URL", "Number of samples"])

    




In [None]:
# Append a "Total" row that sums the per-dataset sample counts.
# The cell is safe to re-run: any existing "Total" row is filtered out and
# `data` is not mutated. Running it twice therefore neither duplicates the
# row nor double-counts the previous total.
per_dataset_rows = [row for row in data if row[0] != "Total"]
total_samples = sum(count for _, _, count in per_dataset_rows)

df = pd.DataFrame(
    per_dataset_rows + [("Total", "Total", total_samples)],
    columns=["Name", "URL", "Number of samples"],
)

In [None]:
# Render the per-dataset sample counts, indexed by their Hugging Face URL, as a
# LaTeX table and write it to the paper's tables directory.
style = df[["URL", "Number of samples"]].set_index("URL").style
# Only the count column is formatted. The URL index already holds raw LaTeX
# (\url{...}) and is intentionally left unescaped.
style = style.format(escape="latex", subset=["Number of samples"])

latex = style.to_latex(
    caption="Number of samples in each dataset",
    clines="skip-last;index",
    hrules=True,
)

# Wrap the tabular in \resizebox so the table fits the text width.
latex = latex.replace("\\begin{tabular}", "\\centering \\resizebox{\\textwidth}{!}{ \\begin{tabular}")
latex = latex.replace("\\end{tabular}", "\\end{tabular}\n}")

print(latex)

# Explicit encoding so the output does not depend on the platform's locale.
(paper_path / "training_datasets.tex").write_text(latex, encoding="utf-8")
    