In [13]:
#csv_path = "data/experiment_plan_5_24.csv"  # Change if needed
csv_path = "data/experiment_plan_5_24_with_addons.csv"  # Change if needed
output_folder = "column_barplots_by_type"
columns_to_plot = [
    "harm_level", "benefit_level", "harmful_option", "option_swapped",
    "model_type", "original_sample_type", "topic", "harm_type", "benefit_reason"
]

In [18]:
import pandas as pd
import matplotlib.pyplot as plt
import os

def load_data(file_path: str) -> pd.DataFrame:
    """Load CSV data from the given file path."""
    return pd.read_csv(file_path)

def plot_stacked_counts(df: pd.DataFrame, group_col: str, stack_col: str, output_dir: str) -> None:
    """
    Create and save stacked bar plots for each group_col value,
    stacked by stack_col (e.g., item_type).
    """
    os.makedirs(output_dir, exist_ok=True)
    plt.rcParams.update({'font.size': 12})

    for col in group_col:
        # Drop NA and count values
        # filter only treatment and control
        df_filtered = df[df['item_type'].isin(['treatment', 'control'])]
        crosstab = pd.crosstab(df_filtered[col], df_filtered[stack_col])
        crosstab = crosstab.loc[crosstab.sum(axis=1).sort_values(ascending=False).index]  # sort bars

        # Print the crosstab
        #print(crosstab)
        # pretty print the crosstab
        print(crosstab.to_string())

        # Plot stacked bar chart
        ax = crosstab.plot(kind='bar', stacked=True, figsize=(14, 6), colormap='tab10')
        plt.title(f"Stacked Count of '{col}' by {stack_col}")
        plt.ylabel("Count")
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        filename = os.path.join(output_dir, f"{col}_stacked_by_{stack_col}.png")
        #plt.savefig(filename)
        #plt.show()
        plt.close()
        #print(f"Saved stacked plot to: {filename}")

# ==== CONFIGURATION ====
#csv_path = "/mnt/data/experiment_plan_5_24.csv"  # Adjust if needed
output_folder = "stacked_barplots_by_item_type"
columns_to_plot = [
    "harm_level", "benefit_level", "harmful_option", "option_swapped",
    "model_type", "original_sample_type", "topic", "harm_type", "benefit_reason"
]
stack_by = "item_type"
# ========================

# Load and plot
df = load_data(csv_path)
plot_stacked_counts(df, group_col=columns_to_plot, stack_col=stack_by, output_dir=output_folder)


item_type   control  treatment
harm_level                    
10                7         34
15                5         30
5                 8         26
item_type      control  treatment
benefit_level                    
50                   8         35
10                   6         34
30                   6         21
item_type       control  treatment
harmful_option                    
A                    11         45
B                     9         45
item_type       control  treatment
option_swapped                    
True                  9         45
item_type   control  treatment
model_type                    
Chatgpt           7         30
Claude            7         30
Gemini            6         30
item_type             control  treatment
original_sample_type                    
Treatment                   0         90
Control                    20          0
item_type                          control  treatment
topic                                                
Soc

In [11]:
from collections import defaultdict
from itertools import product
import pandas as pd
import random

def compute_value_counts(df, columns):
    return {col: df[col].value_counts().to_dict() for col in columns}

def get_all_values(df, columns):
    return {col: df[col].dropna().unique().tolist() for col in columns}

def get_underrepresented(value_counts, target_counts):
    return {
        col: [val for val in all_vals if value_counts[col].get(val, 0) < target_counts[col]]
        for col, all_vals in value_counts.items()
    }

def score_row(row, value_counts, target_counts):
    """How many columns does this row help balance?"""
    return sum(value_counts[col].get(row[col], 0) < target_counts[col] for col in row)

def simulate_balancing_greedy(df, columns, group_size=5, tolerance=2, max_groups=100):
    value_counts = compute_value_counts(df, columns)
    all_values = get_all_values(df, columns)
    target_counts = {col: max(counts.values()) for col, counts in value_counts.items()}

    added_rows = []

    for _ in range(max_groups):
        group = []

        # Generate smart rows
        candidates = []
        for _ in range(1000):
            row = {col: random.choice(all_values[col]) for col in columns}
            score = score_row(row, value_counts, target_counts)
            candidates.append((score, row))

        # Pick top-K scoring rows
        top_rows = sorted(candidates, key=lambda x: x[0], reverse=True)[:group_size]
        for _, row in top_rows:
            group.append(row)
            for col in columns:
                value_counts[col][row[col]] = value_counts[col].get(row[col], 0) + 1


        added_rows.extend(group)

        # Check if balanced
        all_balanced = True
        for col in columns:
            counts = list(value_counts[col].values())
            if max(counts) - min(counts) > tolerance:
                all_balanced = False
                break
        if all_balanced:
            break

    added_df = pd.DataFrame(added_rows)
    return added_df, value_counts

# ==== USE IT ====
#df = pd.read_csv("/mnt/data/experiment_plan_5_24.csv")
df = pd.read_csv("data/experiment_plan_5_24.csv")

columns_to_balance = [
    "harm_level", "benefit_level", "harmful_option", "option_swapped",
    "model_type", "original_sample_type", "topic", "harm_type", "benefit_reason", "item_type"
]

added_df, new_counts = simulate_balancing_greedy(df, columns_to_balance)

print(f"✅ Added {len(added_df)} synthetic rows ({len(added_df)//5} groups of 5).")

def verify_balance(value_counts, tolerance=2):
    for col, counts in value_counts.items():
        vals = list(counts.values())
        print(f"{col}: min={min(vals)}, max={max(vals)}", 
              "✅ OK" if max(vals) - min(vals) <= tolerance else "❌ NOT BALANCED")

verify_balance(new_counts)


✅ Added 500 synthetic rows (100 groups of 5).
harm_level: min=199, max=217 ❌ NOT BALANCED
benefit_level: min=199, max=216 ❌ NOT BALANCED
harmful_option: min=302, max=318 ❌ NOT BALANCED
option_swapped: min=271, max=349 ❌ NOT BALANCED
model_type: min=197, max=224 ❌ NOT BALANCED
original_sample_type: min=302, max=318 ❌ NOT BALANCED
topic: min=51, max=62 ❌ NOT BALANCED
harm_type: min=47, max=73 ❌ NOT BALANCED
benefit_reason: min=148, max=171 ❌ NOT BALANCED
item_type: min=150, max=165 ❌ NOT BALANCED
