In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy
from collections import defaultdict
import data

# Datasets overview

In [2]:
from pathlib import Path

# Cached overview table; recomputing it loads and tokenizes every dataset.
OVERVIEW_CSV = Path('results/datasets_ovw.csv')

if OVERVIEW_CSV.exists():
    df = pd.read_csv(OVERVIEW_CSV, index_col=0)
else:
    # NOTE(review): `English` (presumably `spacy.lang.en.English`) and `CountVectorizer`
    # (presumably `sklearn.feature_extraction.text.CountVectorizer`) are not imported in the
    # visible cells, so this branch raises NameError on a fresh kernel without the cached CSV.
    # TODO: add both imports to the imports cell.
    tok_simp = English().tokenizer  # build the tokenizer once; reused by every call below
    simple_tokenizer = lambda x: [str(x) for x in tok_simp(x)]

    class Args:
        """Bare attribute container handed to `data.process_data_and_args`."""

    # Output column for each n-gram order whose distinct-count in the train text is recorded.
    ngram_columns = {1: 'n_tokens', 2: 'n_bigrams', 3: 'n_trigrams'}

    ds = defaultdict(list)
    args = Args()
    ks = sorted(['emotion', 'financial_phrasebank', 'rotten_tomatoes', 'sst2', 'tweet_eval'])
    for k in ks:
        args.dataset = k
        d, args = data.process_data_and_args(args)
        text = d['train'][args.dataset_key_text]
        ds['n_train'].append(len(text))

        # Majority-class fraction of the train labels (equals 1/num_classes when balanced).
        counts = np.unique(d['train']['label'], return_counts=True)[1]
        ds['imbalance'].append(max(counts) / sum(counts))
        ds['num_classes'].append(counts.size)

        ds['n_val'].append(len(d['validation'][args.dataset_key_text]))

        # Number of distinct n-grams in the training text for each n-gram order.
        for n, column in ngram_columns.items():
            vectorizer = CountVectorizer(tokenizer=simple_tokenizer, ngram_range=(n, n))
            vectorizer.fit(text)
            ds[column].append(len(vectorizer.vocabulary_))

    df = pd.DataFrame.from_dict(ds)
    df.index = ks
    # The results directory may not exist yet on a fresh checkout.
    OVERVIEW_CSV.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(OVERVIEW_CSV)

In [3]:
# Show the dataset overview: one row per dataset, one column per statistic.
df

Unnamed: 0,n_train,imbalance,num_classes,n_val,n_tokens,n_bigrams,n_trigrams
emotion,16000,0.335125,6,2000,15165,106201,201404
financial_phrasebank,2313,0.623433,3,1140,7169,28481,39597
rotten_tomatoes,8530,0.5,2,1066,16631,93921,147426
sst2,67349,0.557826,2,872,13887,72501,108800
tweet_eval,9000,0.579667,2,1000,18476,106277,171769


In [18]:
def prep_for_printing(df):
    """Format the dataset-overview table for LaTeX export.

    Args:
        df: Overview frame with one row per dataset. It must contain 'num_classes'
            and 'imbalance' columns; every column whose name does not contain
            'imbalance' is cast to int. The input frame is not modified.

    Returns:
        A new frame with these changes:
        - 'num_classes' and then 'imbalance' are moved to the last two columns.
        - Count columns are ints.
        - 'imbalance' holds two-decimal strings (e.g. '0.50').
        - Columns and index are renamed via ``data.COLUMNS_RENAME_DICT`` and
          ``data.DSETS_RENAME_DICT``.
        - Rows are sorted by the renamed index.
    """
    trailing = ['num_classes', 'imbalance']
    ordered = [col for col in df.columns if col not in trailing] + trailing
    # Selecting columns builds a new frame, so the caller's `df` is left untouched.
    table = df[ordered].infer_objects()

    fraction_cols = [col for col in table.columns if 'imbalance' in col]
    count_cols = [col for col in table.columns if col not in fraction_cols]
    table = table.astype({col: int for col in count_cols}).assign(
        # Fixed two-decimal strings keep the LaTeX column uniform ('0.50', not '0.5').
        **{col: table[col].map('{:.2f}'.format) for col in fraction_cols}
    )

    # Rows end up in alphabetical order of their display names.
    return table.rename(
        columns=data.COLUMNS_RENAME_DICT,
        index=data.DSETS_RENAME_DICT,
    ).sort_index()

# Render the formatted overview as raw LaTeX; print() shows the text verbatim rather than as a repr.
# The thousands-separator float format is scoped to this call instead of being set globally,
# so later cells keep pandas' default float display.
with pd.option_context('display.float_format', '{:,}'.format):
    print(prep_for_printing(df).transpose().to_latex())

\begin{tabular}{llllll}
\toprule
{} & Emotion & Financial phrasebank & Rotten tomatoes &    SST2 & Tweet (Hate) \\
\midrule
Samples (train)         &   16000 &                 2313 &            8530 &   67349 &         9000 \\
Samples (val)           &    2000 &                 1140 &            1066 &     872 &         1000 \\
Unigrams                &   15165 &                 7169 &           16631 &   13887 &        18476 \\
Bigrams                 &  106201 &                28481 &           93921 &   72501 &       106277 \\
Trigrams                &  201404 &                39597 &          147426 &  108800 &       171769 \\
Classes                 &       6 &                    3 &               2 &       2 &            2 \\
Majority class fraction &    0.34 &                 0.62 &             0.5 &    0.56 &         0.58 \\
\bottomrule
\end{tabular}

