### 📊 Dataset Preparation Summary

To make the large Amazon Polarity dataset (3.6 million training reviews) more manageable for semi-supervised experiments, we drew a class-balanced subset of **4,000 samples** from its training split. The subset is divided as follows:

- **Validation Set**: 400 samples (200 positive, 200 negative)  
- **Labeled Training Set**: 300 samples (150 positive, 150 negative)  
- **Unlabeled Training Set**: 3,300 samples (1,650 positive, 1,650 negative)
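Each split takes the next block of examples from each class's shuffled list, so the per-class sizes above map directly onto contiguous slice ranges. The sketch below derives those ranges from the counts. The helper name `slice_bounds` is hypothetical and not part of the original notebook.

```python
from itertools import accumulate

def slice_bounds(sizes):
    """Turn per-class split sizes into (start, end) slice ranges.

    Hypothetical helper: each split consumes the next `size` items of a
    class's shuffled list, so the boundaries are running sums of the sizes.
    """
    ends = list(accumulate(sizes.values()))
    starts = [0] + ends[:-1]
    return {name: (s, e) for name, s, e in zip(sizes, starts, ends)}

# Per-class counts from the list above (identical for both labels)
per_class = {"validation": 200, "labeled": 150, "unlabeled": 1650}
bounds = slice_bounds(per_class)
print(bounds)  # {'validation': (0, 200), 'labeled': (200, 350), 'unlabeled': (350, 2000)}
print(2 * sum(per_class.values()))  # 4000 samples across both classes
```

These ranges correspond to the `[:200]`, `[200:350]`, and `[350:2000]` slices used in the code cell.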

Each split keeps a **50/50 class balance**, so neither training nor evaluation is skewed toward one sentiment.

The splits are saved as `.parquet` files in a local `data/` directory (relative to the working directory, not the filesystem root `/data`). Parquet keeps the files compact and makes reloading fast during development.

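This setup mirrors a typical semi-supervised scenario: only 300 of the 3,600 training samples (about 8%) are used with their labels. The unlabeled split still carries the original `label` column, but in a semi-supervised setup those labels are withheld from training.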
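One way to use the labeled and unlabeled splits together is self-training (pseudo-labeling). In this approach, you fit on the labeled data, assign labels to the unlabeled points the model is confident about, and refit. The toy loop below illustrates that technique. It is a swapped-in example, not necessarily the method this project uses, and several details are assumptions chosen to keep the sketch short: the 1-D feature values, the nearest-centroid classifier, the confidence formula, and the 0.8 threshold.

```python
def fit_centroids(xs, ys):
    """Nearest-centroid 'model': the mean feature value of each class."""
    centroids = {}
    for label in set(ys):
        vals = [x for x, y in zip(xs, ys) if y == label]
        centroids[label] = sum(vals) / len(vals)
    return centroids


def predict(centroids, x):
    """Return (label, confidence) for a binary problem with labels 0 and 1.

    Confidence (an assumed heuristic) is the distance from the decision
    boundary, i.e. the midpoint between centroids, relative to half the
    centroid gap, capped at 1.
    """
    c0, c1 = centroids[0], centroids[1]
    label = 0 if abs(x - c0) <= abs(x - c1) else 1
    half_gap = abs(c1 - c0) / 2
    if half_gap == 0:
        return label, 0.0
    boundary = (c0 + c1) / 2
    return label, min(1.0, abs(x - boundary) / half_gap)


def self_train(labeled_x, labeled_y, unlabeled_x, threshold=0.8, max_rounds=3):
    """Pseudo-labeling: fit, label confident unlabeled points, refit."""
    xs, ys = list(labeled_x), list(labeled_y)
    pool = list(unlabeled_x)
    for _ in range(max_rounds):
        centroids = fit_centroids(xs, ys)
        remaining = []
        for x in pool:
            label, conf = predict(centroids, x)
            if conf >= threshold:
                xs.append(x)  # accept the pseudo-label as training data
                ys.append(label)
            else:
                remaining.append(x)  # too uncertain; retry next round
        if len(remaining) == len(pool):
            break  # nothing new was pseudo-labeled
        pool = remaining
    return fit_centroids(xs, ys), pool


# Toy data: a few labeled points and a larger unlabeled pool
labeled_x, labeled_y = [-2.0, -1.0, 1.0, 2.0], [0, 0, 1, 1]
unlabeled_x = [-3.0, -1.5, -0.1, 0.1, 1.5, 3.0]

centroids, still_unlabeled = self_train(labeled_x, labeled_y, unlabeled_x)
print(centroids)        # {0: -1.875, 1: 1.875}
print(still_unlabeled)  # [-0.1, 0.1]
```

Points near the decision boundary stay unlabeled, which is the point of the threshold: a higher value adds fewer but cleaner pseudo-labels. With real reviews, the features would come from a text model rather than a single number.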

```python
from datasets import load_dataset, Dataset, DatasetDict
from collections import Counter

# Load the training split and shuffle it reproducibly
dataset = load_dataset("fancyzhx/amazon_polarity", split="train").shuffle(seed=42)

# Collect the first 2,000 examples of each label in shuffled order.
# Stopping early avoids converting all ~3.6M training rows into Python dicts.
N_PER_CLASS = 2000
positives, negatives = [], []
for x in dataset:
    if x["label"] == 1 and len(positives) < N_PER_CLASS:
        positives.append(x)
    elif x["label"] == 0 and len(negatives) < N_PER_CLASS:
        negatives.append(x)
    if len(positives) == N_PER_CLASS and len(negatives) == N_PER_CLASS:
        break

# 200 pos + 200 neg for validation
val_pos = positives[:200]
val_neg = negatives[:200]

# 150 pos + 150 neg for labeled
labeled_pos = positives[200:350]
labeled_neg = negatives[200:350]

# 1650 pos + 1650 neg for unlabeled
unlabeled_pos = positives[350:2000]
unlabeled_neg = negatives[350:2000]

# Convert to HF Datasets and shuffle so classes are interleaved
validation = Dataset.from_list(val_pos + val_neg).shuffle(seed=42)
labeled = Dataset.from_list(labeled_pos + labeled_neg).shuffle(seed=42)
unlabeled = Dataset.from_list(unlabeled_pos + unlabeled_neg).shuffle(seed=42)

# Wrap into a DatasetDict
final_dataset = DatasetDict({
    "validation": validation,
    "labeled": labeled,
    "unlabeled": unlabeled
})

# Check per-class counts and the overall total
print("Validation:", Counter(final_dataset["validation"]["label"]))
print("Labeled:", Counter(final_dataset["labeled"]["label"]))
print("Unlabeled:", Counter(final_dataset["unlabeled"]["label"]))
total = len(validation) + len(labeled) + len(unlabeled)
print("Total samples:", total)
```

```text
Validation: Counter({0: 200, 1: 200})
Labeled: Counter({0: 150, 1: 150})
Unlabeled: Counter({0: 1650, 1: 1650})
Total samples: 4000
```
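The three splits come from disjoint index ranges, so no review should appear in more than one of them. The plain-Python sketch below checks for overlap by keying each example on its `title` and `content` fields. The helper `find_overlap` is hypothetical. It assumes that identical text means the same review, so a reported overlap could also reflect verbatim duplicate reviews in the source data rather than a slicing bug.

```python
def find_overlap(splits):
    """Return (first_split, other_split, key) for every example shared by two splits.

    `splits` maps split names to iterables of row dicts with "title" and
    "content" keys, like the rows produced by iterating a Hugging Face Dataset.
    Duplicates within a single split are ignored.
    """
    seen = {}  # key -> name of the first split it appeared in
    overlaps = []
    for name, rows in splits.items():
        for row in rows:
            key = (row["title"], row["content"])
            if key in seen and seen[key] != name:
                overlaps.append((seen[key], name, key))
            seen.setdefault(key, name)
    return overlaps


# Toy rows standing in for the real splits
toy = {
    "validation": [{"title": "Great", "content": "Loved it"}],
    "labeled": [{"title": "Bad", "content": "Broke fast"}],
    "unlabeled": [{"title": "Great", "content": "Loved it"},
                  {"title": "Ok", "content": "Fine"}],
}
print(find_overlap(toy))  # [('validation', 'unlabeled', ('Great', 'Loved it'))]

# With the notebook's splits (after the cell above has run):
# find_overlap({name: list(split) for name, split in final_dataset.items()})
```

An empty list from the real splits confirms there is no text-level leakage between validation, labeled, and unlabeled data.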


Next, save each split to the `data/` directory as a Parquet file.

```python
import os

# Create the output directory (relative to the working directory) if needed
os.makedirs("data", exist_ok=True)

# Save each split as its own Parquet file
final_dataset["validation"].to_parquet("data/validation.parquet")
final_dataset["labeled"].to_parquet("data/labeled.parquet")
final_dataset["unlabeled"].to_parquet("data/unlabeled.parquet")
```


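`to_parquet` returns the number of bytes written. The notebook echoes only the return value of the last call, which is for the unlabeled split:

```text
1460690
```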