In [1]:
# dataset

In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer
from collections import Counter
import numpy as np

# Shared tokenizer used by analyze_dataset below: every "token length" reported
# in this notebook is a count of GPT-2 tokens.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")


def analyze_dataset(name, subset=None, text_keys=("sentence", "text", "content", "review_body")):
    """Load a Hugging Face dataset and print summary statistics for one split.

    The ``train`` split is analysed when present, otherwise ``validation``.
    For that split this prints the number of samples, the longest example
    measured in GPT-2 tokens (module-level ``tokenizer``, no special tokens),
    and the number of samples per integer ``label`` value.

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        subset: Optional configuration name passed as the second argument to
            ``load_dataset`` (e.g. ``"sst2"`` for ``"glue"``).
        text_keys: Candidate names for the text column, tried in order; the
            first one that is a column of the chosen split is used. Any
            sequence of strings works; the default is a tuple so it cannot be
            mutated between calls.

    Raises:
        ValueError: If the dataset has neither a ``train`` nor a
            ``validation`` split, or none of ``text_keys`` is a column of it.
    """
    # ----- Load dataset -----
    if subset:
        ds = load_dataset(name, subset)
        tag = f"{name} / {subset}"
    else:
        ds = load_dataset(name)
        tag = name

    # ----- Pick the split to analyse (train preferred, validation fallback) -----
    for split_name in ("train", "validation"):
        if split_name in ds:
            split = ds[split_name]
            break
    else:
        raise ValueError(f"No train or validation split found for {tag}")

    # ----- Resolve the text column once from the schema -----
    # Checking column names up front avoids a per-row key search and fails
    # fast (even on an empty split) when no candidate key matches.
    text_key = next((key for key in text_keys if key in split.column_names), None)
    if text_key is None:
        raise ValueError(f"Could not find text field in dataset {tag}")

    # ----- Extract labels and texts column-wise -----
    # Column access avoids building a dict for every row on two separate
    # passes, which matters on multi-million-row splits.
    label_counts = Counter(int(label) for label in split["label"])
    texts = split[text_key]

    # ----- Compute token lengths -----
    # Only the maximum is reported, so lengths are not stored; default=0
    # keeps an empty split from crashing max().
    max_tokens = max(
        (len(tokenizer.encode(text, add_special_tokens=False)) for text in texts),
        default=0,
    )

    # ----- Print results -----
    print("\n==============================================")
    print(f"DATASET: {tag}")
    print("----------------------------------------------")
    # Name the split actually used so a validation fallback is not reported as "Train".
    print(f"{split_name.capitalize()} samples: {len(split)}")
    print(f"Max token length in dataset: {max_tokens}")
    print("Class distribution:")
    # Sort by label so the printed order is stable regardless of row order.
    for cls, count in sorted(label_counts.items()):
        print(f"  Class {cls}: {count}")
    print("==============================================\n")


# ======================================================
# RUN ANALYSIS FOR ALL BINARY SENTIMENT DATASETS
# ======================================================
# Each entry is (dataset name, extra keyword arguments for analyze_dataset).
# Entries without "text_keys" rely on analyze_dataset's default candidate list.
for ds_name, ds_kwargs in (
    ("glue", {"subset": "sst2"}),                     # 1. SST-2
    # 2. Financial PhraseBank (50% agreement) -- disabled.
    # NOTE(review): this dataset is presumably 3-class (neg/neu/pos), not
    # binary like the rest of this list -- confirm before re-enabling.
    # ("takala/financial_phrasebank", {"subset": "sentences_50agree"}),
    ("imdb", {"text_keys": ["text"]}),                # 3. IMDb
    ("rotten_tomatoes", {"text_keys": ["text"]}),     # 4. Rotten Tomatoes
    ("yelp_polarity", {"text_keys": ["text"]}),       # 5. Yelp Polarity
    ("amazon_polarity", {"text_keys": ["content"]}),  # 6. Amazon Polarity
):
    analyze_dataset(ds_name, **ds_kwargs)


DATASET: glue / sst2
----------------------------------------------
Train samples: 67349
Max token length in dataset: 65
Class distribution:
  Class 0: 29780
  Class 1: 37569



Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors



DATASET: imdb
----------------------------------------------
Train samples: 25000
Max token length in dataset: 3097
Class distribution:
  Class 0: 12500
  Class 1: 12500


DATASET: rotten_tomatoes
----------------------------------------------
Train samples: 8530
Max token length in dataset: 78
Class distribution:
  Class 1: 4265
  Class 0: 4265


DATASET: yelp_polarity
----------------------------------------------
Train samples: 560000
Max token length in dataset: 2200
Class distribution:
  Class 0: 280000
  Class 1: 280000


DATASET: amazon_polarity
----------------------------------------------
Train samples: 3600000
Max token length in dataset: 659
Class distribution:
  Class 1: 1800000
  Class 0: 1800000

