# Class Swap Augmentation Overview

This notebook compares the base, truncated, and truncated+augmented datasets. It provides high-level and detailed views of class balance and text-length distributions, with text summaries alongside each visualization for reporting.


In [None]:
from __future__ import annotations

from pathlib import Path
import sys
import warnings

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from loguru import logger

# Suppress Python warnings so rendered cell output stays uncluttered.
warnings.filterwarnings(action="ignore")

# Reset loguru's handlers, then log to stderr at INFO level.
logger.remove()
logger.add(sys.stderr, level="INFO")

# Plot defaults: seaborn theme first, then the default figure size.
sns.set_theme(context="notebook", style="whitegrid")
plt.rcParams.update({"figure.figsize": (10, 6)})


In [None]:
# HuggingFace dataset id and split used as the base corpus.
BASE_HF_DATASET = "kforghani/sentipers"
BASE_HF_SPLIT = "train"

# Local CSV exports of the truncated and the augmented training sets.
TRUNCATED_CSV = Path("data/base/sentipers_train.csv")
AUGMENTED_CSV = Path("data/output/augmented_sentipers_train.csv")

# Integer class label -> display name and one-line description.
CLASS_DEFINITIONS = {
    0: {
        "name": "Very Negative",
        "description": "intense dissatisfaction, frustration, or strong criticism",
    },
    1: {
        "name": "Negative",
        "description": "clear dislike, disappointment, or criticism without extremes",
    },
    2: {
        "name": "Neutral",
        "description": "factual or balanced tone with no strong positive/negative cues",
    },
    3: {
        "name": "Positive",
        "description": "clear approval, satisfaction, or praise without extremes",
    },
    4: {
        "name": "Very Positive",
        "description": "strong enthusiasm, admiration, or high praise",
    },
}

# Flat per-field lookups keyed by label id, convenient for Series.map.
LABEL_NAMES = {label_id: spec["name"] for label_id, spec in CLASS_DEFINITIONS.items()}
LABEL_DESCRIPTIONS = {label_id: spec["description"] for label_id, spec in CLASS_DEFINITIONS.items()}


In [None]:
def load_base_dataset() -> pd.DataFrame:
    """Load the base split from the HuggingFace Hub as a tagged DataFrame.

    Returns:
        A new frame holding the ``text`` and ``label`` columns of the Hub
        dataset plus two constant columns: ``dataset`` (always ``"base"``)
        and ``is_synthetic`` (always ``False``).
    """
    logger.info("Loading base dataset from HuggingFace: {} ({})", BASE_HF_DATASET, BASE_HF_SPLIT)
    hf_ds = load_dataset(BASE_HF_DATASET, split=BASE_HF_SPLIT)
    # assign() returns a fresh frame, so the tag columns are never written onto
    # a column slice of the converted dataset (no SettingWithCopyWarning).
    return hf_ds.to_pandas()[["text", "label"]].assign(dataset="base", is_synthetic=False)

def load_csv_dataset(path: Path, name: str) -> pd.DataFrame:
    """Read a dataset CSV and normalise it to the notebook's common schema.

    Args:
        path: CSV file to read; it must contain ``text`` and ``label`` columns.
        name: Value written to the ``dataset`` column of every row.

    Returns:
        A new frame with ``text``, ``label``, ``dataset`` and a boolean
        ``is_synthetic`` column. When the CSV has no ``is_synthetic`` column
        every row is marked ``False``; missing flag values are also ``False``.

    Raises:
        KeyError: If the CSV lacks a ``text`` or ``label`` column.
    """
    logger.info("Loading dataset from {}", path)
    raw_df = pd.read_csv(path)
    wanted = ["text", "label"] + (["is_synthetic"] if "is_synthetic" in raw_df.columns else [])
    # Copy the column slice so the assignments below write to an independent frame.
    df = raw_df[wanted].copy()
    df["dataset"] = name
    if "is_synthetic" not in df.columns:
        df["is_synthetic"] = False
    flags = df["is_synthetic"]
    # NaN is truthy under astype(bool), which would mark rows with a blank flag
    # as synthetic; mask missing values so they count as not synthetic.
    df["is_synthetic"] = flags.notna() & flags.astype(bool)
    return df

# Materialise the three datasets under comparison, each tagged with its name.
base_df = load_base_dataset()
truncated_df = load_csv_dataset(path=TRUNCATED_CSV, name="truncated")
augmented_df = load_csv_dataset(path=AUGMENTED_CSV, name="augmented")

# Row counts as a quick sanity check on the loads.
logger.info(
    "Loaded base={} truncated={} augmented={}",
    len(base_df),
    len(truncated_df),
    len(augmented_df),
)


In [None]:
# One row per dataset: record count and share of synthetic rows.
summary = pd.DataFrame(
    [
        {
            "dataset": dataset_name,
            "records": len(frame),
            # Derived from the flags actually loaded rather than assumed to be
            # zero; an empty frame has no mean, so it is reported as 0.
            "synthetic_ratio": float(frame["is_synthetic"].mean()) if len(frame) else 0.0,
        }
        for dataset_name, frame in (
            ("base", base_df),
            ("truncated", truncated_df),
            ("augmented", augmented_df),
        )
    ]
)
# Express the fraction as a percentage rounded to two decimals.
summary["synthetic_ratio"] = (summary["synthetic_ratio"] * 100).round(2)
summary


## Class Distribution (High-Level)
This compares label counts across base, truncated, and augmented datasets.


In [None]:
def add_label_name(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with readable class columns added.

    ``label_name`` and ``label_description`` are looked up from the ``label``
    column via LABEL_NAMES and LABEL_DESCRIPTIONS. The input frame is not
    modified.
    """
    labels = df["label"]
    return df.assign(
        label_name=labels.map(LABEL_NAMES),
        label_description=labels.map(LABEL_DESCRIPTIONS),
    )

def class_distribution(df: pd.DataFrame) -> pd.DataFrame:
    """Count rows per (dataset, label) pair and express them as percentages.

    Returns:
        One row per (dataset, label) pair present in ``df`` with columns
        ``dataset``, ``label``, ``count``, ``percent`` (share within the
        dataset, rounded to one decimal), ``label_name`` and
        ``label_description``.
    """
    per_class = (
        df.groupby(["dataset", "label"])
        .size()
        .reset_index(name="count")
    )
    # Broadcast each dataset's total back onto its rows for the division.
    dataset_totals = per_class.groupby("dataset")["count"].transform("sum")
    return per_class.assign(
        percent=(per_class["count"] / dataset_totals * 100).round(1),
        label_name=per_class["label"].map(LABEL_NAMES),
        label_description=per_class["label"].map(LABEL_DESCRIPTIONS),
    )

# Stack the three tagged datasets into one long frame with readable labels.
combined_df = add_label_name(
    pd.concat([base_df, truncated_df, augmented_df], ignore_index=True)
)
class_counts = class_distribution(combined_df)
# Display only; class_counts itself keeps its original row order.
class_counts.sort_values(by=["dataset", "label"])


In [None]:
# Loguru fills "{}" placeholders str.format-style, so a percent sign needs no
# escaping; the printf-style "%%" was being logged verbatim as two characters.
logger.info("Class distribution summary (counts, %) for each dataset:\n{}", class_counts)

# Grouped bars: one cluster per class, one bar per dataset.
sns.barplot(data=class_counts, x="label_name", y="count", hue="dataset")
plt.title("Class Distribution by Dataset")
plt.xlabel("Class")
plt.ylabel("Count")
plt.xticks(rotation=30, ha="right")
plt.tight_layout()
plt.show()


## Augmented Dataset: Base vs Synthetic Balance
This breaks down class counts inside the augmented dataset by origin (base vs synthetic).


In [None]:
# Labelled copy of the augmented data; augmented_df itself is left unchanged.
augmented_only = augmented_df.assign(
    label_name=augmented_df["label"].map(LABEL_NAMES)
)

# Row count for every (class, synthetic flag) combination present in the data.
aug_split = (
    augmented_only
    .groupby(["label_name", "is_synthetic"])
    .size()
    .reset_index(name="count")
)
# Readable origin tag derived from the boolean flag.
aug_split["origin"] = np.where(aug_split["is_synthetic"], "synthetic", "base")
aug_split


In [None]:
logger.info("Augmented dataset origin split:\n{}", aug_split)

# Wide table: one row per class, one column per origin; absent pairs count as 0.
pivot = aug_split.pivot(index="label_name", columns="origin", values="count").fillna(0)
# The pivot index comes out alphabetical; reorder the rows by label id so the
# bars follow the ordinal sentiment scale (classes absent from the data are skipped).
class_order = [LABEL_NAMES[label_id] for label_id in sorted(LABEL_NAMES)]
pivot = pivot.reindex([class_name for class_name in class_order if class_name in pivot.index])
# Keep a fixed base/synthetic column order so the stack order and colours stay
# stable even when one origin is missing.
pivot = pivot.reindex(columns=["base", "synthetic"], fill_value=0)
pivot.plot(kind="bar", stacked=True)
plt.title("Augmented Dataset: Base vs Synthetic Counts")
plt.xlabel("Class")
plt.ylabel("Count")
plt.xticks(rotation=30, ha="right")
plt.tight_layout()
plt.show()


## Text Length Distributions (High-Level)
We compare text-length distributions across datasets, measured both in characters and in whitespace-separated words.


In [None]:
def add_length_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with per-row text length columns.

    Adds ``char_len`` (number of characters in ``text``) and ``word_len``
    (number of whitespace-separated tokens in ``text``). Missing text is
    treated as an empty string, so both lengths are 0 instead of NaN and
    downstream percentile computations are not turned into NaN.

    Args:
        df: Frame with a ``text`` column. It is not modified.

    Returns:
        A new frame with integer ``char_len`` and ``word_len`` columns added.
    """
    df = df.copy()
    # Fill before casting: astype(str) alone would turn NaN into the text "nan".
    text = df["text"].fillna("").astype(str)
    df["char_len"] = text.str.len()
    df["word_len"] = text.str.split().str.len()
    return df

combined_len_df = add_length_features(combined_df)

# Named aggregation fixes the output column names directly, instead of renaming
# the "<lambda_0>" label that pandas generates for anonymous functions.
# Series.quantile skips NaN (np.percentile would return NaN) and uses the same
# linear interpolation by default.
length_summary = combined_len_df.groupby("dataset").agg(
    char_len_mean=("char_len", "mean"),
    char_len_median=("char_len", "median"),
    char_len_p95=("char_len", lambda s: s.quantile(0.95)),
    word_len_mean=("word_len", "mean"),
    word_len_median=("word_len", "median"),
    word_len_p95=("word_len", lambda s: s.quantile(0.95)),
)
length_summary


In [None]:
logger.info("Text length summary (mean/median/p95):\n{}", length_summary)

# Side-by-side histograms: characters on the left, words on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))
hist_panels = (
    ("char_len", "Character Length Distribution", "Characters"),
    ("word_len", "Word Length Distribution", "Words"),
)
for panel_ax, (length_col, panel_title, panel_xlabel) in zip(axes, hist_panels):
    sns.histplot(data=combined_len_df, x=length_col, hue="dataset", bins=40, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel(panel_xlabel)

plt.tight_layout()
plt.show()


## Synthetic vs Base Text Length (Augmented Dataset)
This highlights whether synthetic samples differ in length from base records.


In [None]:
aug_len_df = add_length_features(augmented_df)
# Readable origin tag derived from the boolean synthetic flag.
aug_len_df["origin"] = np.where(aug_len_df["is_synthetic"], "synthetic", "base")

# Named aggregation fixes the output column names directly, instead of renaming
# the "<lambda_0>" label that pandas generates for anonymous functions.
# Series.quantile skips NaN (np.percentile would return NaN) and uses the same
# linear interpolation by default.
origin_summary = aug_len_df.groupby("origin").agg(
    char_len_mean=("char_len", "mean"),
    char_len_median=("char_len", "median"),
    char_len_p95=("char_len", lambda s: s.quantile(0.95)),
    word_len_mean=("word_len", "mean"),
    word_len_median=("word_len", "median"),
    word_len_p95=("word_len", lambda s: s.quantile(0.95)),
)
origin_summary


In [None]:
logger.info("Synthetic vs base length summary:\n{}", origin_summary)

# Two panels sharing one figure: characters on the left, words on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))
char_ax, word_ax = axes

sns.boxplot(data=aug_len_df, x="origin", y="char_len", ax=char_ax)
char_ax.set_title("Character Length: Base vs Synthetic")

sns.boxplot(data=aug_len_df, x="origin", y="word_len", ax=word_ax)
word_ax.set_title("Word Length: Base vs Synthetic")

plt.tight_layout()
plt.show()


## Text Length by Class (Base vs Synthetic)
This compares per-class length distributions between base and synthetic records.


In [None]:
aug_len_df["label_name"] = aug_len_df["label"].map(LABEL_NAMES)

# Without an explicit order seaborn lays out string categories in order of
# first appearance in the data. Fix the class axis to the ordinal sentiment
# scale (by label id) and the hue to base-then-synthetic so both panels match.
class_order = [LABEL_NAMES[label_id] for label_id in sorted(LABEL_NAMES)]
origin_order = ["base", "synthetic"]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
class_panels = (
    ("char_len", "Character Length by Class", "Characters"),
    ("word_len", "Word Length by Class", "Words"),
)
for panel_ax, (length_col, panel_title, panel_ylabel) in zip(axes, class_panels):
    sns.boxplot(
        data=aug_len_df,
        x="label_name",
        y=length_col,
        hue="origin",
        order=class_order,
        hue_order=origin_order,
        ax=panel_ax,
    )
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Class")
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.tick_params(axis="x", rotation=30)

plt.tight_layout()
plt.show()


## Notes for Reporting
- Use the class distribution table and chart to highlight balance changes after augmentation.
- Use length summaries to describe how synthetic data compares in length to base data.
- Use per-class length boxplots to call out any label-specific artifacts introduced by augmentation.
