#### By: Peyman Shahidi
#### Created: Oct 19, 2025
#### Last Edit: Dec 6, 2025

<br>

In [None]:
#Python
import getpass
import numpy as np
import pandas as pd
from collections import defaultdict
import itertools
import random 

# Show floats comma-separated with two digits after the decimal point,
# e.g., 1000 is displayed as 1,000.00
pd.set_option('float_format', "{:,.2f}".format)

import matplotlib.pyplot as plt
#%matplotlib inline
#from matplotlib.legend import Legend

import warnings
# NOTE(review): this silences every warning category for the rest of the session
warnings.filterwarnings('ignore')
# Let pandas display up to 200 rows before truncating a printed frame
pd.set_option('display.max_rows', 200)

In [None]:
# Project root, given relative to the current working directory
main_folder_path = ".."
# Folder holding the input data sets
input_data_path = f"{main_folder_path}/data"
# Folder where the tables computed below are written
output_data_path = f'{input_data_path}/computed_objects/execTypeVaryingDWA_anthropicIndex'
# Folder where the figures produced below are written
output_plot_path = f"{main_folder_path}/writeup/plots/execTypeVaryingDWA"

In [None]:
# Create the output directories (including missing parents) if they don't exist
import os

for path in (output_data_path, output_plot_path):
    # exist_ok=True replaces the separate existence check, which could race
    # with another process creating the same directory in between.
    os.makedirs(path, exist_ok=True)

## Set variables

In [None]:
# Number of reshuffles
n_shuffles = 1000


# Dependent variable of the regressions and its human-readable name.
# The commented-out pair below is the alternative outcome ('is_ai').
# dependent_var = 'is_ai'
# plot_title_variable = 'Task is AI'
dependent_var = 'is_automated'
plot_title_variable = 'Task is Automated'


# Regressors of interest (neighboring-task AI flags) and the regression specifications
TARGET_REGS = ['prev2_is_ai', 'prev_is_ai', 'next_is_ai', 'next2_is_ai']
SPECS = ['no_fe', 'fe_MajorGroup', 'fe_MinorGroup']

# Fallback panel titles, listed in the same order as TARGET_REGS
PLOT_TITLES = ['Task Before Previous Task', 'Previous Task', 'Next Task', 'Task After Next Task']

### Main Code

In [None]:
# Get list of DWAs with tasks in multiple occupations
dwa_list_path = f"{input_data_path}/computed_objects/similar_dwa_tasks/similarTasks"

# Read all CSV files. glob returns paths in arbitrary order, so they are sorted
# to make the load order (and hence the row order of df_all) reproducible.
import glob
dwa_csv_files = sorted(glob.glob(os.path.join(dwa_list_path, "*.csv")))
print(f"Found {len(dwa_csv_files)} DWA CSV files.")

# Load them into DataFrames, skipping 1-row files
dwa_dfs = []
skipped_files_count = 0
for f in dwa_csv_files:
    df = pd.read_csv(f)
    if len(df) > 1: # Skip if DWA contains only one task
        dwa_dfs.append(df)
    else:
        skipped_files_count += 1
print(f"Skipped {skipped_files_count} DWA files with only one task.")


# Combine into one DataFrame. pd.concat raises an opaque "No objects to
# concatenate" error on an empty list, so fail early with an actionable message.
if not dwa_dfs:
    raise ValueError(
        f"No DWA CSV file with more than one task found in '{dwa_list_path}'."
    )
df_all = pd.concat(dwa_dfs, ignore_index=True)
repetitive_dwa_task_ids = df_all['Task ID'].unique().tolist()
repetitive_dwa_task_titles = df_all['Task Title'].unique().tolist()
print(f"Found {len(repetitive_dwa_task_ids)} tasks related to these DWAs.")

In [None]:
# Count, per DWA, the distinct tasks that survived the DWA task-similarity procedure
survived_tasks_count_df = (
    df_all.groupby('DWA ID')['Task ID']
    .nunique()
    .rename('num_tasks_survived')
    .reset_index()
)
survived_tasks_count_df


In [None]:
# Create a DWA-level dataset with number of tasks and occupations per DWA, as well as fraction of manual, automation, and augmentation tasks per DWA
merged_data = pd.read_csv(f"{input_data_path}/computed_objects/ONET_Eloundou_Anthropic_GPT/ONET_Eloundou_Anthropic_GPT.csv")
merged_data['is_manual'] = merged_data['label'] == 'Manual'
merged_data['is_automation'] = merged_data['label'] == 'Automation'
merged_data['is_augmentation'] = merged_data['label'] == 'Augmentation'


# Merge back DWA ID and DWA Titles to the merged_data
dwa_task_mapping = pd.read_csv(f"{input_data_path}/computed_objects/similar_dwa_tasks/dwa_task_mapping.csv")
print(f'Length of merged_data before merging DWA info: {merged_data.shape[0]}')
merged_data = merged_data.merge(dwa_task_mapping, on=['Task ID', 'Task Title', 'O*NET-SOC Code', 'Occupation Title'], how='left')
print(f'Length of merged_data after merging DWA info: {merged_data.shape[0]}')


# Aggregate to get fractions
dwa_grouped = merged_data.groupby(['DWA ID', 'DWA Title']).agg(
    num_tasks = ('Task ID', 'nunique'),
    num_occupations = ('O*NET-SOC Code', 'nunique'),
    fraction_manual = ('is_manual', 'mean'),
    fraction_automation = ('is_automation', 'mean'),
    fraction_augmentation = ('is_augmentation', 'mean'),
).reset_index()
print(f"Created DWA-level dataset with {dwa_grouped.shape[0]} DWAs.")

# Keep only DWAs with variation in terms of execution type across occupations
dwa_grouped_filtered = dwa_grouped[
     (dwa_grouped['num_occupations'] > 1) & (dwa_grouped['fraction_manual'] > 0) & (dwa_grouped['fraction_manual'] < 1)
].copy()
display(dwa_grouped_filtered)

# Create list of DWAs with varying execution types
dwas_varying_exec_types_ids = dwa_grouped_filtered['DWA ID'].unique().tolist()
dwas_varying_exec_types_titles = dwa_grouped_filtered['DWA Title'].unique().tolist()
print(f"Identified {len(dwas_varying_exec_types_ids)} DWAs with varying execution types across occupations.")

# Merge back the number of tasks survived info
dwa_grouped_filtered = dwa_grouped_filtered.merge(survived_tasks_count_df, left_on='DWA ID', right_on='DWA ID', how='left')

# Save output
dwa_grouped_filtered.to_csv(f"{output_data_path}/dwas_varying_execution_types.csv", index=False)

In [None]:
# Read the merged data
merged_data = pd.read_csv(f"{input_data_path}/computed_objects/ONET_Eloundou_Anthropic_GPT/ONET_Eloundou_Anthropic_GPT.csv")
merged_data = merged_data[['O*NET-SOC Code', 'Occupation Title', 'Task ID', 'Task Title',
       'Task Position', 'Task Type', 
       'Major_Group_Code', 'Major_Group_Title', 
       'Minor_Group_Code', 'Minor_Group_Title',
       'Broad_Occupation_Code', 'Broad_Occupation_Title',
       'Detailed_Occupation_Code', 'Detailed_Occupation_Title',
       'gpt4_exposure', 'human_labels', 
       'automation', 'augmentation', 'label']]


# Create is_ai and is_automated flags in merged_data
merged_data['is_ai'] = merged_data['label'].isin(['Augmentation','Automation']).astype(int)
merged_data['is_automated'] = merged_data['label'].isin(['Automation']).astype(int)
merged_data['is_exposed'] = merged_data['human_labels'].isin(['E1']).astype(int)


# Step 1: Add occupation's number of tasks info
num_tasks_per_occupation = merged_data.groupby('O*NET-SOC Code')['Task ID'].nunique().reset_index()
num_tasks_per_occupation = num_tasks_per_occupation.rename(columns={'Task ID': 'num_tasks'})
merged_data = merged_data.merge(num_tasks_per_occupation, on='O*NET-SOC Code', how='left')


# Step 2: Create flags for previous/next tasks is AI within occupation groups
# Sort by occupation and position when possible
merged_data['Task Position'] = pd.to_numeric(merged_data['Task Position'], errors='coerce')
merged_data = merged_data.sort_values(['O*NET-SOC Code', 'Task Position']).reset_index(drop=True)
group_col = 'O*NET-SOC Code'

# Compute neighbor flags (prev/next) within occupation groups when possible
merged_data['prev_is_ai'] = 0
merged_data['prev2_is_ai'] = 0
merged_data['next_is_ai'] = 0
merged_data['next2_is_ai'] = 0
pos_col = 'Task Position'

def add_neighbor_flags(df):
    """Flag, for every task, whether its neighboring tasks are AI tasks.

    Tasks are ordered by 'Task Position' within each 'O*NET-SOC Code'. The
    'is_ai' value of the task one/two positions before and after is written to
    'prev_is_ai', 'prev2_is_ai', 'next_is_ai' and 'next2_is_ai'. A neighbor
    that does not exist (start or end of an occupation) counts as 0.

    The flags are computed with vectorised group-wise shifts instead of
    ``groupby.apply``: that avoids one Python-level call per occupation and
    does not depend on ``apply`` passing the grouping column to the applied
    function (deprecated since pandas 2.2).

    Args:
        df: DataFrame with the columns 'O*NET-SOC Code', 'Task Position' and
            'is_ai'. It is not modified.

    Returns:
        A new DataFrame sorted by occupation and task position, with a fresh
        0..n-1 index and the four integer flag columns. Rows without an
        occupation code are dropped, which is what the previous
        groupby-apply implementation did implicitly.
    """
    occupation_col = 'O*NET-SOC Code'
    position_col = 'Task Position'

    df = df.copy()
    df[position_col] = pd.to_numeric(df[position_col], errors='coerce')
    df = df.dropna(subset=[occupation_col])
    df = df.sort_values([occupation_col, position_col]).reset_index(drop=True)

    is_ai_by_occupation = df.groupby(occupation_col, sort=False)['is_ai']
    # Positive offsets look back (previous tasks), negative offsets look ahead.
    for flag_name, offset in (
        ('prev_is_ai', 1),
        ('prev2_is_ai', 2),
        ('next_is_ai', -1),
        ('next2_is_ai', -2),
    ):
        df[flag_name] = is_ai_by_occupation.shift(offset).fillna(0).astype(int)
    return df
# add_neighbor_flags groups by occupation internally, so it is applied once to
# the whole frame (the same way the reshuffling code calls it) rather than
# being wrapped in a second, redundant groupby-apply.
merged_data = add_neighbor_flags(merged_data)



# Step 3: Add back DWA info
# Merge back DWA ID and DWA Titles to the merged_data
dwa_task_mapping = pd.read_csv(f"{input_data_path}/computed_objects/similar_dwa_tasks/dwa_task_mapping.csv")
merged_data = merged_data.merge(dwa_task_mapping, on=['Task ID', 'Task Title', 'O*NET-SOC Code', 'Occupation Title'], how='left')
# Note that the merge might map multiple DWAs to the same task


# Step 4: Flag "similar" tasks across occupations
merged_data['dwa_execType_varying'] = (
    merged_data['DWA ID'].notna()
    & merged_data['DWA ID'].isin(dwas_varying_exec_types_ids)
    & merged_data['Task ID'].isin(repetitive_dwa_task_ids)
).astype(int)

# Remove duplicates in terms of (O*NET-SOC Code, Task ID) if any.
# A task can be mapped to several DWAs. Keeping whichever row happens to come
# first would make the task's flag depend on the row order of the mapping
# file, so flagged rows are moved to the front (stable sort) before dropping
# duplicates, and the original row order is restored afterwards.
print(f'Length of merged_data before dropping duplicates: {merged_data.shape[0]}')
merged_data = (
    merged_data
    .sort_values('dwa_execType_varying', ascending=False, kind='mergesort')
    .drop_duplicates(subset=['O*NET-SOC Code', 'Task ID'])
    .sort_index()
)
print(f'Length of merged_data after dropping duplicates: {merged_data.shape[0]}')
# Save the flagged rows of the updated merged_data
merged_data[merged_data['dwa_execType_varying'] == 1].to_csv(f"{output_data_path}/merged_data_DWAexecVaryingTypes.csv", index=False)


# Summary for flagged DWA rows
mask = merged_data['dwa_execType_varying'] == 1
n_flagged = int(mask.sum())
print(f'\nNumber of dwa_execType_varying rows: {n_flagged}')
if n_flagged > 0:
    for c in ['prev2_is_ai', 'prev_is_ai', 'next_is_ai', 'next2_is_ai']:
        s = int(merged_data.loc[mask, c].sum())
        frac = merged_data.loc[mask, c].mean()
        print(f'{c}: {s} of {n_flagged} flagged rows (fraction={frac:.3f})')
    try:
        display(merged_data.loc[mask].head())
    except NameError:
        # display() is only defined inside IPython/Jupyter; fall back to text
        print(merged_data.loc[mask].head().to_string(index=False))
else:
    print('No flagged rows to summarize.')


## Run regression of multiple-execution-type DWA tasks against execution type of neighboring tasks

In [None]:
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from scipy.stats import norm
from pathlib import Path

# --- Configuration ---
# Regressors whose average marginal effects are reported
TARGET_REGS = ['prev2_is_ai', 'prev_is_ai', 'next_is_ai', 'next2_is_ai']

# LaTeX labels per regressor; used as row labels of the printed table and as
# panel titles of the histograms
VAR_LABELS = {
    'prev2_is_ai': '($t-2$) Task AI',
    'prev_is_ai': '($t-1$) Task AI',
    'next_is_ai': '($t+1$) Task AI',
    'next2_is_ai': '($t+2$) Task AI'
}

# ==========================================
# 1. Robust AME Extractor
# ==========================================
def build_ame_df(res, dataset_name, model_name, target_regs, fe_label):
    """Collect average marginal effects (AMEs) of selected regressors.

    Args:
        res: Fitted model result exposing ``prsquared``, ``params``, ``llf``,
            ``llnull``, ``nobs`` and ``get_margeff()``.
        dataset_name: Label stored in the 'dataset' column.
        model_name: Label stored in the 'model' column; also used in the
            error message.
        target_regs: Names of the regressors to keep.
        fe_label: Label stored in the 'fe_label' column.

    Returns:
        DataFrame with one row per kept regressor and the columns 'dataset',
        'model', 'fe_label', 'nobs', 'r2_pseudo', 'r2_adj_pseudo', 'term',
        'ame_coef', 'ame_se' and 'p_value'. An empty DataFrame is returned
        when none of ``target_regs`` is in the margins table or when any step
        fails (the error is printed, not raised).
    """
    try:
        # --- Calculate Model Statistics ---
        pr2 = res.prsquared
        k = res.params.shape[0]
        # Adjusted pseudo R^2: the log-likelihood is penalised by the number
        # of estimated parameters.
        adj_pr2 = 1 - (res.llf - k) / res.llnull
        nobs = res.nobs

        # --- Calculate AME ---
        margeff = res.get_margeff(at='overall', method='dydx', dummy=True)
        # Name the index explicitly so the regressor column is called 'term'
        # whether or not the margins table already names its index.
        summary = margeff.summary_frame().rename_axis('term').reset_index()

        rename_map = {
            'dy/dx': 'ame_coef',
            'std err': 'ame_se', 'Std. Err.': 'ame_se',
            # NOTE(review): statsmodels appears to label the margins p-value
            # column 'Pr(>|z|)'; 'P>|z|' is kept for other layouts - confirm.
            'P>|z|': 'p_value', 'Pr(>|z|)': 'p_value',
            'z': 'z_score',
        }
        summary = summary.rename(columns=rename_map)

        if 'ame_se' not in summary.columns:
            if summary.shape[1] >= 3:
                # Positional fallback: column 0 is 'term' and column 1 the
                # effect, so the standard error is assumed to be column 2.
                summary['ame_se'] = summary.iloc[:, 2]
            else:
                summary['ame_se'] = np.nan

        summary = summary[summary['term'].isin(target_regs)].copy()
        if summary.empty:
            return pd.DataFrame()

        summary['ame_coef'] = pd.to_numeric(summary['ame_coef'], errors='coerce')
        summary['ame_se'] = pd.to_numeric(summary['ame_se'], errors='coerce')

        # --- P-values ---
        # Two-sided normal p-value, used wherever the table reports none.
        # norm.sf is more accurate than 1 - norm.cdf in the far tail.
        z_stat = summary['ame_coef'] / summary['ame_se']
        fallback_p = pd.Series(2 * norm.sf(np.abs(z_stat)), index=summary.index)
        if 'p_value' in summary.columns:
            reported_p = pd.to_numeric(summary['p_value'], errors='coerce')
            summary['p_value'] = reported_p.fillna(fallback_p)
        else:
            summary['p_value'] = fallback_p

        df = pd.DataFrame({
            'dataset': dataset_name,
            'model': model_name,
            'fe_label': fe_label,
            'nobs': nobs,
            'r2_pseudo': pr2,
            'r2_adj_pseudo': adj_pr2,
            'term': summary['term'],
            'ame_coef': summary['ame_coef'],
            'ame_se': summary['ame_se'],
            'p_value': summary['p_value']
        })
        return df

    except Exception as e:
        # Deliberate best effort: one failing model must not stop the run.
        print(f"Error calculating AME for {model_name}: {e}")
        return pd.DataFrame()

# ==========================================
# 2. Regression Runner (FIXED)
# ==========================================
# Column layout of the AME summary; used to write a header-only (still
# readable) CSV when no model produced estimates.
_AME_RESULT_COLUMNS = [
    'dataset', 'model', 'fe_label', 'nobs', 'r2_pseudo', 'r2_adj_pseudo',
    'term', 'ame_coef', 'ame_se', 'p_value',
]


def _fit_clustered_logit(formula, data):
    """Fit a formula-API logit with standard errors clustered on 'DWA ID'."""
    return smf.logit(formula, data=data).fit(
        disp=False,
        cov_type='cluster',
        cov_kwds={'groups': data['DWA ID'],
                  'use_correction': True}
    )


def run_regressions_on(df, dataset_name, dependent_var, regressors):
    """Fit the logit specifications and collect their average marginal effects.

    Three specifications are attempted: no fixed effects, Major-Group fixed
    effects and Minor-Group fixed effects. Each includes ``regressors`` plus
    the controls 'is_exposed' and 'num_tasks'. For the fixed-effect models
    only groups with at least 10 rows and both outcome values are used. A
    specification that fails is reported on stdout and skipped.

    Side effect: the combined AME table is written to
    ``{output_data_path}/regression_summaries_{dependent_var}/
    regression_ame_results_{dataset_name}.csv`` (``output_data_path`` is a
    module-level variable).

    Args:
        df: Input data; must contain 'DWA ID' and the model columns. It is
            copied, not modified.
        dataset_name: Label stored with the results and used in the file name.
        dependent_var: Name of the binary outcome column.
        regressors: List of regressor column names.

    Returns:
        Tuple ``(models, combined)``: a dict of fitted results keyed by
        'no_fe', 'fe_MajorGroup', 'fe_MinorGroup' (successful fits only) and
        the combined AME DataFrame (empty when nothing could be estimated).
    """
    df = df.copy()
    all_cols = regressors + [dependent_var, 'is_exposed', 'num_tasks']
    existing_cols = [c for c in all_cols if c in df.columns]
    # Non-numeric or missing entries in the model columns are treated as 0
    df[existing_cols] = df[existing_cols].apply(pd.to_numeric, errors='coerce').fillna(0)

    base_formula = f'{dependent_var} ~ ' + ' + '.join(regressors)
    ame_list = []
    models = {}

    # 1) No FE
    try:
        formula = base_formula + ' + is_exposed + num_tasks'
        res = _fit_clustered_logit(formula, df)
        models['no_fe'] = res
        ame_list.append(build_ame_df(res, dataset_name, 'no_fe', regressors, fe_label="None"))
    except Exception as e:
        print(f"[{dataset_name}] No-FE failed: {e}")

    # 2) Fixed Effects
    fe_cols = [('Major_Group_Code', 'MajorGroup', 'Major Group'),
               ('Minor_Group_Code', 'MinorGroup', 'Minor Group')]

    for col, short, nice_label in fe_cols:
        if col not in df.columns:
            continue
        try:
            formula = base_formula + f' + C({col}) + is_exposed + num_tasks'
            # Groups without outcome variation or with fewer than 10 rows are
            # excluded from the fixed-effect sample.
            df_fe = df.groupby(col).filter(lambda g: g[dependent_var].nunique() == 2 and len(g) >= 10)
            res = _fit_clustered_logit(formula, df_fe)
            models[f'fe_{short}'] = res
            ame_list.append(build_ame_df(res, dataset_name, f'fe_{short}', regressors, fe_label=nice_label))
        except Exception as e:
            print(f"[{dataset_name}] FE {short} failed: {e}")

    # Combine the per-model tables. If nothing was estimated, keep the column
    # layout: a column-less frame would be written as a CSV that pd.read_csv
    # cannot load again (EmptyDataError).
    non_empty = [frame for frame in ame_list if not frame.empty]
    if non_empty:
        combined = pd.concat(non_empty, ignore_index=True)
    else:
        combined = pd.DataFrame(columns=_AME_RESULT_COLUMNS)

    # Save results to CSV
    out_path = f'{output_data_path}/regression_summaries_{dependent_var}'
    os.makedirs(out_path, exist_ok=True)
    combined.to_csv(f'{out_path}/regression_ame_results_{dataset_name}.csv', index=False)

    return models, combined

# ==========================================
# 3. LaTeX Table Generator
# ==========================================
def generate_latex_table(df_results):
    """Print a LaTeX ``tabular`` of average marginal effects to stdout.

    Only the first dataset found in ``df_results['dataset']`` is tabulated.
    Rows follow the order of the module-level ``TARGET_REGS`` and are
    labelled through ``VAR_LABELS``; columns are whichever of 'no_fe',
    'fe_MajorGroup' and 'fe_MinorGroup' occur in the results. Significance
    stars are attached at p < 0.01, p < 0.05 and p < 0.10.

    Args:
        df_results: AME summary with the columns 'dataset', 'model', 'term',
            'ame_coef', 'ame_se', 'p_value', 'nobs', 'r2_pseudo',
            'r2_adj_pseudo' and 'fe_label'.

    Returns:
        None. An empty input only prints a short notice.
    """
    if df_results.empty:
        print("No results to tabulate.")
        return

    def _significance_stars(p):
        # Missing p-values get no stars
        if pd.isna(p):
            return ""
        if p < 0.01:
            return "***"
        if p < 0.05:
            return "**"
        if p < 0.10:
            return "*"
        return ""

    def _emit_row(label, cells):
        # One table line: label, cells joined by '&', LaTeX line break
        print(f"{label} & " + " & ".join(cells) + r" \\")

    # Restrict to the first dataset
    shown_dataset = df_results['dataset'].unique()[0]
    table_rows = df_results[df_results['dataset'] == shown_dataset].copy()

    print(f"\n% --- LaTeX Table for {shown_dataset} ---")

    # Formatted cell strings: coefficient with stars, standard error in parentheses
    table_rows['coef_str'] = [
        f"{coef:.2f}{_significance_stars(p)}"
        for coef, p in zip(table_rows['ame_coef'], table_rows['p_value'])
    ]
    table_rows['se_str'] = [f"({se:.2f})" for se in table_rows['ame_se']]

    # term x model grids, with rows in TARGET_REGS order
    coef_grid = table_rows.pivot(index='term', columns='model', values='coef_str')
    se_grid = table_rows.pivot(index='term', columns='model', values='se_str')
    row_terms = [term for term in TARGET_REGS if term in coef_grid.index]
    coef_grid = coef_grid.reindex(row_terms)
    se_grid = se_grid.reindex(row_terms)

    shown_models = [
        model for model in ('no_fe', 'fe_MajorGroup', 'fe_MinorGroup')
        if model in coef_grid.columns
    ]

    # Per-model footer statistics
    footer = (
        table_rows[['model', 'nobs', 'r2_pseudo', 'r2_adj_pseudo', 'fe_label']]
        .drop_duplicates('model')
        .set_index('model')
    )

    def _grid_cells(grid, term):
        cells = []
        for model in shown_models:
            value = grid.loc[term, model]
            cells.append(value if pd.notna(value) else "")
        return cells

    def _footer_cells(render):
        return [render(model) if model in footer.index else "" for model in shown_models]

    def _fe_cell(model):
        fe_name = footer.loc[model, 'fe_label']
        if pd.isna(fe_name) or str(fe_name) == "None":
            return ""
        return str(fe_name)[:5]

    # --- Print LaTeX ---
    column_spec = "l" + "c" * len(shown_models)
    print(f"\\begin{{tabular}}{{{column_spec}}}")
    print(r"\toprule")

    # Header
    _emit_row("Specification", [f"({number})" for number in range(1, len(shown_models) + 1)])
    print(r"\midrule")

    # Body: coefficient row followed by standard-error row per regressor
    for term in row_terms:
        row_label = VAR_LABELS.get(term, term.replace('_', ' '))
        _emit_row(row_label, _grid_cells(coef_grid, term))
        _emit_row("", _grid_cells(se_grid, term))
        print(r"\addlinespace")

    print(r"\midrule")

    # Footer
    _emit_row("Pseudo $R^2$", _footer_cells(lambda m: f"{footer.loc[m, 'r2_pseudo']:.3f}"))
    _emit_row("Adj. Pseudo $R^2$", _footer_cells(lambda m: f"{footer.loc[m, 'r2_adj_pseudo']:.3f}"))
    _emit_row("Observations", _footer_cells(lambda m: f"{int(footer.loc[m, 'nobs']):,}"))
    _emit_row("SOC Group Fixed Effects", _footer_cells(_fe_cell))

    print(r"\bottomrule")
    print(r"\footnotesize{Standard errors in parentheses (clustered at DWA level). *** p$<$0.01, ** p$<$0.05, * p$<$0.1}")
    print(r"\end{tabular}")

# ==========================================
# 4. Execution Block
# ==========================================

# The full-data regressions are disabled; only the filtered sample (tasks of
# DWAs with varying execution types) is estimated and tabulated here.
# To switch the outcome, set dependent_var = 'is_ai' before this block.

print(">>> Running Regressions on Filtered Data...")
is_varying_dwa_task = merged_data['dwa_execType_varying'].eq(1)
filtered = merged_data.loc[is_varying_dwa_task].reset_index(drop=True)
models_filt, res_filt = run_regressions_on(
    filtered,
    'filtered_0',
    dependent_var,
    TARGET_REGS,
)

# Print the LaTeX table for the filtered sample
generate_latex_table(res_filt)

In [None]:
# Inspect the identifying columns and labels of the filtered regression sample
filtered[['O*NET-SOC Code', 'Occupation Title', 'Task ID', 'Task Title', 'Task Position', 'label', 'human_labels']]

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Define the specifications we expect to see
SPECS = ['no_fe', 'fe_MajorGroup', 'fe_MinorGroup']

# ======================================================
# 1. Helper: Extract AME directly (No Transformation)
# ======================================================
def results_to_dict(df_results):
    """Turn an AME results frame into nested ``{spec: {term: value}}`` dicts.

    One dict holds the 'ame_coef' values and one the 'ame_se' values. Both
    contain every combination of the module-level ``SPECS`` and
    ``TARGET_REGS``; combinations that are absent from ``df_results`` (or
    missing there) stay NaN. Rows whose 'model' or 'term' is not one of those
    specs/terms are ignored.

    Args:
        df_results: Output frame of run_regressions_on, or None.

    Returns:
        Tuple ``(coef_dict, se_dict)``.
    """
    def _nan_grid():
        return {spec: {term: np.nan for term in TARGET_REGS} for spec in SPECS}

    coef_out, se_out = _nan_grid(), _nan_grid()

    if df_results is None or getattr(df_results, 'empty', False):
        return coef_out, se_out

    column_sinks = (('ame_coef', coef_out), ('ame_se', se_out))
    for _, record in df_results.iterrows():
        spec, term = record.get('model'), record.get('term')
        if spec not in coef_out or term not in coef_out[spec]:
            continue
        for column, sink in column_sinks:
            if column in record and pd.notna(record[column]):
                sink[spec][term] = record[column]
    return coef_out, se_out

# Store observed values
# obs_dict_full, obs_se_full = results_to_dict(res_full)
obs_dict_filt, obs_se_filt = results_to_dict(res_filt)

# ======================================================
# 2. Reshuffling Loop
# ======================================================
# Containers for reshuffled AMEs. Only the filtered sample is reshuffled;
# resh_full stays empty and is kept so existing references to it still work.
resh_full = {spec: {t: [] for t in TARGET_REGS} for spec in SPECS}
resh_filt = {spec: {t: [] for t in TARGET_REGS} for spec in SPECS}

print(f'Running {n_shuffles} reshuffles to generate Null Distribution of AMEs...')

summaries_dir = f"{output_data_path}/regression_summaries_{dependent_var}"

for i in range(n_shuffles):
    # Shuffle ids start at 1. Id 0 is reserved: the regression on the observed
    # (unshuffled) data is saved under the dataset name 'filtered_0', so using
    # id 0 here would load the observed estimates as the first null draw.
    # The seed stays 42 + id, so results cached earlier for ids >= 1 still
    # correspond to the same shuffles.
    shuffle_id = i + 1
    seed = 42 + shuffle_id
    fname_filt = f"{summaries_dir}/regression_ame_results_filtered_{shuffle_id}.csv"

    # --- Load or Compute ---
    if Path(fname_filt).exists():
        try:
            # Load existing results (CSV produced by run_regressions_on)
            res_shuf_filt = pd.read_csv(fname_filt)
        except pd.errors.EmptyDataError:
            # A cached file without columns means no model produced estimates
            # for this shuffle; treat it as "no results" instead of crashing.
            res_shuf_filt = pd.DataFrame()
    else:
        # Create Shuffled Data
        df_shuf = merged_data.copy()
        # Shuffle Task Position within O*NET Code
        # NOTE(review): every occupation is sampled with the same random_state,
        # so occupations of equal size presumably receive the same
        # permutation - confirm this is intended.
        df_shuf['Task Position'] = df_shuf.groupby('O*NET-SOC Code')['Task Position'].transform(
            lambda x: x.sample(frac=1, random_state=seed).values
        )

        # Re-calculate neighbor flags based on shuffled positions
        df_shuf = add_neighbor_flags(df_shuf)

        # Run Regressions (Filtered)
        df_shuf_filt = df_shuf[df_shuf['dwa_execType_varying'] == 1].reset_index(drop=True)
        _, res_shuf_filt = run_regressions_on(df_shuf_filt, f'filtered_{shuffle_id}', dependent_var=dependent_var, regressors=TARGET_REGS)

    # --- Store Results ---
    d_filt, _ = results_to_dict(res_shuf_filt)

    # Append to lists
    for spec in SPECS:
        for t in TARGET_REGS:
            resh_filt[spec][t].append(d_filt[spec][t])

    if (i+1) % 50 == 0:
        print(f'  Completed {i+1}/{n_shuffles}')

print('Reshuffles complete; Marginal Effects stored in resh_filt.')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import os # Ensure os is imported for makedirs

# --- Plotting: distributions of Marginal Effects (AME) ---
def _draw_ame_panel(ax, resh_vals, obs_val, obs_se, color, bins):
    """Draw one panel: reshuffled-AME histogram, observed AME with its band, zero line."""
    vals = np.array(resh_vals, dtype=float)
    vals_clean = vals[~np.isnan(vals)]

    if len(vals_clean):
        ax.hist(vals_clean, bins=bins, color=color, alpha=0.7, edgecolor='k',
                label='Task Position Reshuffled AMEs', zorder=2)
    else:
        # Axes coordinates keep the note centred and visible whatever the
        # x-limits are (data coordinates could fall outside the visible range).
        ax.text(0.5, 0.5, 'no estimates', ha='center', va='center', transform=ax.transAxes)

    # Observed AME (thick red dashed line) and its +/- 1.645*SE band
    if not np.isnan(obs_val):
        if not np.isnan(obs_se):
            se_band = 1.645 * obs_se
            ax.axvspan(obs_val - se_band, obs_val + se_band, color='red', alpha=0.08, zorder=1)
            ax.axvline(obs_val - se_band, color='red', linestyle='--', linewidth=1, alpha=0.9, zorder=3)
            ax.axvline(obs_val + se_band, color='red', linestyle='--', linewidth=1, alpha=0.9, zorder=3)
        ax.axvline(obs_val, color='red', linestyle='--', linewidth=3, label=f'Observed = {obs_val:.3f}', zorder=4)

    # Baseline (no effect) at 0
    ax.axvline(0.0, color='black', linestyle='-', linewidth=1.5, alpha=0.5, zorder=4)


def _ame_panel_title(term, col_idx):
    """Return the panel title for a regressor, falling back to PLOT_TITLES."""
    return VAR_LABELS.get(term, term) if 'VAR_LABELS' in globals() else PLOT_TITLES[col_idx]


def _add_legend_if_labelled(ax):
    """Add a legend only when the axes holds at least one labelled artist."""
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='best', fontsize=10)


def plot_comparison_hist(resh_dict, obs_dict, obs_se_dict, title, out_name, plot_title_variable, bins=30):
    """Save histograms of reshuffled AMEs against the observed AMEs.

    One figure with a row per specification (module-level ``SPECS``) and a
    column per regressor (``TARGET_REGS``) is saved as ``out_name``; every
    row is additionally saved as ``<out_name stem>_<spec>.png``. All files go
    to ``{output_plot_path}/{dependent_var}`` (module-level variables). All
    panels share one x-range that is symmetric around zero.

    Args:
        resh_dict: dict of reshuffled AMEs per spec and term (lists of floats)
        obs_dict: dict of observed AMEs per spec and term
        obs_se_dict: dict of observed AME standard errors per spec and term
        title: accepted for interface compatibility; not drawn on any figure
        out_name: file name of the full multi-row figure
        plot_title_variable: accepted for interface compatibility; not used
        bins: number of histogram bins

    Returns:
        None. When there is no finite value to plot, a warning is printed and
        nothing is saved.
    """
    # --- 1. Calculate Global Bounds for X-Axis ---
    all_resh_vals = [v
                     for inner_dict in resh_dict.values()
                     for values_list in inner_dict.values()
                     for v in values_list if not np.isnan(v)]

    all_obs_vals = [v
                    for inner_dict in obs_dict.values()
                    for v in inner_dict.values() if not np.isnan(v)]

    total_vals = all_resh_vals + all_obs_vals
    if not total_vals:
        print("Warning: No valid data found to plot.")
        return

    g_min, g_max = min(total_vals), max(total_vals)
    symmetric_bound = max(abs(g_min), abs(g_max))
    span = 2 * symmetric_bound
    if span == 0:
        span = 0.1
    # 12.5% of the span is added as padding on each side
    x_limit_min = -symmetric_bound - (span * 0.125)
    x_limit_max = symmetric_bound + (span * 0.125)

    def _observed(spec, term):
        return (obs_dict.get(spec, {}).get(term, np.nan),
                obs_se_dict.get(spec, {}).get(term, np.nan))

    # --- 2. Full multi-row figure ---
    n_rows, n_cols = len(SPECS), len(TARGET_REGS)
    colors = [plt.cm.tab10(i % 10) for i in range(n_rows)]
    # squeeze=False always yields a 2-D axes array, also for a single spec or
    # a single regressor.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                             figsize=(6 * n_cols, 5 * n_rows),
                             sharey='col', squeeze=False)

    for r, spec in enumerate(SPECS):
        for c, term in enumerate(TARGET_REGS):
            ax = axes[r, c]
            obs_val, obs_se = _observed(spec, term)
            _draw_ame_panel(ax, resh_dict[spec][term], obs_val, obs_se, colors[r], bins)

            # Titles and labels only on the outer panels
            if r == 0:
                ax.set_title(_ame_panel_title(term, c), fontsize=15, fontweight='bold')
            if r == n_rows - 1:
                ax.set_xlabel('Average Marginal Effect', fontsize=15)
            if c == 0:
                clean_spec = spec.replace('fe_', '').replace('_', ' ').title()
                ax.set_ylabel(f'{clean_spec}\nCount', fontsize=15)

            ax.set_xlim(x_limit_min, x_limit_max)
            ax.grid(axis='y', linestyle=':', alpha=0.5)
            _add_legend_if_labelled(ax)

    fig.tight_layout()

    # makedirs also creates output_plot_path itself if it is missing
    out_dir = f'{output_plot_path}/{dependent_var}'
    os.makedirs(out_dir, exist_ok=True)
    out_path = f'{out_dir}/{out_name}'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print('Saved full multi-row plot to', out_path)

    # --- 3. One figure per specification ---
    base_name = out_name.rsplit('.', 1)[0]
    for r, spec in enumerate(SPECS):
        fig_row, axs_row = plt.subplots(nrows=1, ncols=n_cols, figsize=(6 * n_cols, 5),
                                        sharey=False, squeeze=False)
        for c, term in enumerate(TARGET_REGS):
            axr = axs_row[0, c]
            obs_val, obs_se = _observed(spec, term)
            _draw_ame_panel(axr, resh_dict[spec][term], obs_val, obs_se, colors[r], bins)

            axr.set_title(_ame_panel_title(term, c), fontsize=15)
            if c == 0:
                axr.set_ylabel('Count', fontsize=15)
            axr.grid(axis='y', linestyle=':', alpha=0.5)
            axr.set_xlim(x_limit_min, x_limit_max)
            axr.set_xlabel('Average Marginal Effect', fontsize=15)
            _add_legend_if_labelled(axr)

        fig_row.tight_layout()
        out_path_row = f'{out_dir}/{base_name}_{spec}.png'
        fig_row.savefig(out_path_row, dpi=300, bbox_inches='tight')
        plt.close(fig_row)
        print('Saved row plot to', out_path_row)

# Make sure the plot folder exists, then draw the histograms for the filtered sample
Path(output_plot_path).mkdir(parents=True, exist_ok=True)

plot_comparison_hist(
    resh_dict=resh_filt,
    obs_dict=obs_dict_filt,
    obs_se_dict=obs_se_filt,
    title=f'FILTERED Dataset (n={n_shuffles})',
    out_name=f'AME_filtered_{dependent_var}.png',
    plot_title_variable=plot_title_variable,
)

print('All done: comparative Marginal Effect histogram figures created.')