In [1]:
from __init__ import *

In [None]:
# Plotting Functions for Exp1

def plot_tpr_per_attack(args, results_df):
    """
    Plot TPR against attack strength for every (attack, watermarking method) pair.

    Layout: one row per attack, one column per watermarking method, plus a final
    column that shows the attack's display name. Each model is one line; a shaded
    confidence band is added when both CI bounds are present and finite.

    Parameters:
    - args: namespace providing prompt_dataset and dataset_identifier (used in the
      title) and output_plot (destination of the saved figure)
    - results_df: merged results; reads the columns set_fpr, attack_name, wm_method,
      model_id, attack_strength, tpr_empirical, tpr_ci_lower_percentile and
      tpr_ci_upper_percentile

    Rows with attack_name == 'no_attack' are excluded. Attacks, methods and models are
    ordered as in ATTACK_NAME_MAPPING, METHODS_NAME_MAPPING and MODEL_NAME_MAPPING;
    values missing from those mappings are skipped with a warning (their display
    names/styles would be unavailable).
    """

    def _ordered_by_mapping(values, mapping, kind):
        # Keep the mapping's order and drop (but report) values it does not know.
        present = set(values)
        unknown = present.difference(mapping)
        if unknown:
            print(f"Warning: skipping {kind} not in name mapping: {sorted(map(str, unknown))}")
        return np.array([key for key in mapping if key in present])

    # set_fpr should be the same for all experiments; it is used in the y-axis label.
    fpr_values = results_df['set_fpr'].unique() if 'set_fpr' in results_df.columns else []
    if len(fpr_values) > 1:
        print(f"Warning: multiple set_fpr values {list(fpr_values)}; labelling with {fpr_values[0]}")
    ylabel = f"TPR@FPR={fpr_values[0]}" if len(fpr_values) else "TPR@FPR=0.01"

    # drop the no_attack case: it has no attack strengths to plot against
    results_df = results_df[results_df['attack_name'] != 'no_attack']

    attack_names = _ordered_by_mapping(results_df['attack_name'].unique(), ATTACK_NAME_MAPPING, "attacks")
    wm_methods = _ordered_by_mapping(results_df['wm_method'].unique(), METHODS_NAME_MAPPING, "watermarking methods")
    models = _ordered_by_mapping(results_df['model_id'].unique(), MODEL_NAME_MAPPING, "models")

    # one column per WM method plus one for the attack name; one row per attack
    ncols = len(wm_methods) + 1
    nrows = len(attack_names)
    fs = 10
    fs_xticks = 8
    fs_yticks = 8
    fs_title = 14
    y_adj = 0.937
    title_height_ratio = 0.15
    height_correction = 0
    title = ( 
        f'Performance of watermarking methods under different attacks\n'
        f'for dataset "{args.prompt_dataset}" for experiments in \n'
        f'{args.dataset_identifier}'
    )

    fig, gs, _title_axes = setup_gridspec_figure(
        nrows=nrows, ncols=ncols,
        fs=fs, title=title, fs_title=fs_title,
        y_adj=y_adj, title_height_ratio=title_height_ratio,
        sp_width=2, sp_height=1.75, height_correction=height_correction,
    )

    handles, labels = [], []
    yticks = np.arange(0, 1.1, 0.25)

    for i, attack_name in enumerate(attack_names):  # rows
        attack_info = ATTACK_NAME_MAPPING[attack_name]
        attack_df = results_df[results_df['attack_name'] == attack_name]
        # NOTE(review): the 2*i + 1 row index assumes setup_gridspec_figure interleaves
        # title rows with plot rows — confirm against its implementation.
        axes = [fig.add_subplot(gs[2 * i + 1, j]) for j in range(ncols)]

        # last column holds the attack name instead of a plot
        name_ax = axes[-1]
        name_ax.axis('off')
        name_ax.text(0.1, 0.5, attack_info['name'], fontsize=fs, fontweight="bold", ha="left", va="center")
        if i == 0:
            name_ax.set_title('Attacktype', fontsize=fs)

        for j, wm_method in enumerate(wm_methods):  # columns
            ax = axes[j]
            wm_df = attack_df[attack_df['wm_method'] == wm_method]

            # Set axis direction based on attack type
            if attack_info['order'] == 'low-to-high':
                ax.invert_xaxis()

            if i == 0:
                ax.set_title(METHODS_NAME_MAPPING[wm_method], fontsize=fs)

            ax.set_yticks(yticks)
            ax.set_yticklabels(yticks, fontsize=fs_yticks)
            ax.set_ylim([-0.1, 1.1])
            ax.grid(True, linestyle='--', alpha=0.5)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            if j == 0:  # y-axis label only on the first plot of each row
                ax.set_ylabel(ylabel)
            else:
                plt.setp(ax.get_yticklabels(), visible=False)
                plt.setp(ax.get_yticklines(), visible=False)

            for model in models:  # lines
                model_df = wm_df[wm_df['model_id'] == model]
                if model_df.empty:
                    print(f"Warning: No data for {attack_name}, {wm_method}, {model}\n")
                    continue

                strengths, results, ci_lower, ci_upper = order_attack_strengths(
                    attack_info['order'],
                    model_df['attack_strength'],
                    model_df['tpr_empirical'],
                    model_df['tpr_ci_lower_percentile'],
                    model_df['tpr_ci_upper_percentile'],
                    attack_info['cast_to_int'],
                )

                style = MODEL_NAME_MAPPING[model]
                label = style['name']

                # Plot using actual strength values
                line, = ax.plot(strengths, results,
                                marker=style['marker'],
                                linestyle=style['line'],
                                label=label,
                                color=style['color'])

                # Shade the CI band only when both bounds exist and contain no NaN
                # (the previous `or` let NaN-containing bounds through).
                has_ci = (
                    len(ci_lower) > 0 and len(ci_upper) > 0
                    and not np.isnan(ci_lower).any() and not np.isnan(ci_upper).any()
                )
                if has_ci:
                    ax.fill_between(strengths, ci_lower, ci_upper, color=style['color'], alpha=0.2)

                if label not in labels:
                    handles.append(line)
                    labels.append(label)

                # Use the actual strength values as ticks (the last model drawn sets them)
                ax.set_xticks(strengths)
                ax.set_xticklabels(strengths, fontsize=fs_xticks)

    fig.legend(loc='upper center', bbox_to_anchor=(0.2, 0.4, 0.5, 0.5), ncol=len(models), handles=handles, labels=labels)

    fig.savefig(args.output_plot, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Plot saved to {args.output_plot}")



In [13]:
def plot_tpr_per_metric(args, results_df, metric_name, metric_column, title_suffix, xlabel, xlim):
    """
    Plot TPR against an arbitrary metric column (e.g. CLIP score, FID) instead of attack strength.

    Layout: one row per attack, one column per watermarking method, plus a final column
    with the attack's display name. Each model is one line whose points are sorted by
    the metric; every other point is annotated with its attack strength.

    Parameters:
    - args: namespace providing prompt_dataset and dataset_identifier (used in the
      title) and output_plot (base path of the saved figure)
    - results_df: merged results; reads attack_name, wm_method, model_id,
      attack_strength, tpr_empirical, metric_column and (if present) set_fpr
    - metric_name: String identifier for the metric (appended to the output filename)
    - metric_column: Name of the column to use for x-axis values
    - title_suffix: Text to add to the plot title
    - xlabel: Label for the x-axis (shown under the bottom row of plots)
    - xlim: (min, max) of the metric; xticks_num + 1 ticks are spread evenly over it
      and a 7% margin is added on both sides

    The x-axis direction depends only on the metric name: columns containing "score" or
    "similarity" put higher values on the left; "fid" or "distance" put lower values on
    the left (this rule wins if both match); anything else keeps the default orientation.

    The figure is saved as "<output_plot stem>_<metric_name><ext>". If metric_column is
    missing from results_df, a warning is printed and no figure is produced.
    """
    import os  # local import: only needed for building the output filename

    if metric_column not in results_df.columns:
        print(f"Warning: {metric_column} not found in results; skipping {title_suffix} plot")
        return

    def _ordered_by_mapping(values, mapping, kind):
        # Keep the mapping's order and drop (but report) values it does not know.
        present = set(values)
        unknown = present.difference(mapping)
        if unknown:
            print(f"Warning: skipping {kind} not in name mapping: {sorted(map(str, unknown))}")
        return np.array([key for key in mapping if key in present])

    attack_names = _ordered_by_mapping(results_df['attack_name'].unique(), ATTACK_NAME_MAPPING, "attacks")
    wm_methods = _ordered_by_mapping(results_df['wm_method'].unique(), METHODS_NAME_MAPPING, "watermarking methods")
    models = _ordered_by_mapping(results_df['model_id'].unique(), MODEL_NAME_MAPPING, "models")

    # set_fpr should be identical across experiments; fall back to the historical label otherwise
    fpr_values = results_df['set_fpr'].unique() if 'set_fpr' in results_df.columns else []
    ylabel = f"TPR@FPR={fpr_values[0]}" if len(fpr_values) else "TPR@FPR=0.01"

    # Setup figure with same layout as plot_tpr_per_attack
    ncols = len(wm_methods) + 1  # per method, plus one for the attack name
    nrows = len(attack_names)  # for each attack
    fs = 10
    fs_xticks = 8
    fs_yticks = 8
    fs_title = 14
    y_adj = 0.937
    title_height_ratio = 0.15
    title = (
        f'Watermarking performance vs {title_suffix}\n'
        f'for dataset "{args.prompt_dataset}" for experiments in \n'
        f'{args.dataset_identifier}'
    )

    fig, gs, _title_axes = setup_gridspec_figure(
        nrows=nrows, ncols=ncols,
        fs=fs, title=title, fs_title=fs_title,
        y_adj=y_adj, title_height_ratio=title_height_ratio,
        sp_width=2, sp_height=1.75
    )

    # x ticks: linspace guarantees exactly xticks_num + 1 ticks inside [xmin, xmax]; an
    # arange-based grid can gain an extra tick from float rounding, and set_xticks would
    # then widen the view to show it.
    xticks_num = 3
    x_lo, x_hi = xlim[0], xlim[1]
    xticks = np.round(np.linspace(x_lo, x_hi, xticks_num + 1), 2)
    span = abs(x_hi - x_lo)
    # 7% padding; if every value is identical, pad by a fraction of the value (or 1.0 for zero)
    xlim_buffer = span * 0.07 if span > 0 else (abs(x_lo) * 0.07 or 1.0)
    xlim = (x_lo - xlim_buffer, x_hi + xlim_buffer)

    metric_key = metric_column.lower()
    higher_is_better = "score" in metric_key or "similarity" in metric_key
    lower_is_better = "fid" in metric_key or "distance" in metric_key

    handles, labels = [], []
    last_row = nrows - 1

    for i, attack_name in enumerate(attack_names):  # rows
        attack_df = results_df[results_df['attack_name'] == attack_name]
        # NOTE(review): the 2*i + 1 row index assumes setup_gridspec_figure interleaves
        # title rows with plot rows — confirm against its implementation.
        axes = [fig.add_subplot(gs[2 * i + 1, j]) for j in range(ncols)]

        # last column holds the attack name instead of a plot
        name_ax = axes[-1]
        name_ax.axis('off')
        name_ax.text(0.1, 0.5, ATTACK_NAME_MAPPING[attack_name]['name'], fontsize=fs, fontweight="bold", ha="left", va="center")
        if i == 0:
            name_ax.set_title('Attacktype', fontsize=fs)

        for j, wm_method in enumerate(wm_methods):  # columns
            ax = axes[j]
            wm_df = attack_df[attack_df['wm_method'] == wm_method]

            if i == 0:
                ax.set_title(METHODS_NAME_MAPPING[wm_method], fontsize=fs)

            ax.set_ylim([-0.1, 1.1])
            ax.set_yticks(np.arange(0, 1.1, 0.25))
            ax.set_yticklabels(np.arange(0, 1.1, 0.25), fontsize=fs_yticks)
            # set_xlim fixes the orientation explicitly, so the direction is decided below
            # by the metric rather than by the attack's strength order.
            ax.set_xlim(xlim)
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticks, fontsize=fs_xticks)
            ax.grid(True, linestyle='--', alpha=0.5)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            if j == 0:  # y-axis label only on the first plot of each row
                ax.set_ylabel(ylabel)
            else:
                plt.setp(ax.get_yticklabels(), visible=False)
                plt.setp(ax.get_yticklines(), visible=False)
            if i == last_row:
                ax.set_xlabel(xlabel, fontsize=fs)

            # quality metrics (higher is better): put higher values on the left
            if higher_is_better:
                left, right = ax.get_xlim()
                if left < right:
                    ax.invert_xaxis()
            # distance metrics (lower is better): put lower values on the left
            if lower_is_better:
                left, right = ax.get_xlim()
                if left > right:
                    ax.invert_xaxis()

            for model in models:  # lines
                model_df = wm_df[wm_df['model_id'] == model]
                if model_df.empty:
                    print(f"Warning: No data for {attack_name}, {wm_method}, {model}\n")
                    continue

                # Sort by the metric so the line is drawn left-to-right in metric order
                df_sorted = model_df.sort_values(by=metric_column)
                x_values = df_sorted[metric_column].values
                tpr_values = df_sorted['tpr_empirical'].values
                attack_strengths = df_sorted['attack_strength'].values

                style = MODEL_NAME_MAPPING[model]
                label = style['name']

                line, = ax.plot(x_values, tpr_values,
                                marker=style['marker'],
                                linestyle=style['line'],
                                label=label,
                                color=style['color'])

                # Annotate every other point with its attack strength to avoid clutter
                for x, y, strength in list(zip(x_values, tpr_values, attack_strengths))[::2]:
                    ax.annotate(f"{strength}", (x, y),
                                textcoords="offset points",
                                xytext=(0, 5),
                                ha='center',
                                fontsize=7)

                if label not in labels:
                    handles.append(line)
                    labels.append(label)

    fig.legend(loc='upper center', bbox_to_anchor=(0.2, 0.4, 0.5, 0.5), ncol=len(models), handles=handles, labels=labels)

    # Insert the metric name before the extension instead of string-replacing '.pdf'
    # anywhere in the path (which also overwrote the base plot for non-.pdf paths).
    stem, ext = os.path.splitext(args.output_plot)
    output_plot = f"{stem}_{metric_name}{ext}"
    fig.savefig(output_plot, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"\n{title_suffix} plot saved to {output_plot}")


In [14]:
args = Namespace()
args.exp_name = 'exp1'

# specify which experimental setup we want to plot
args.num_imgs = 200
args.prompt_dataset = 'coco'

# for exp1, we merge results over wmch_16 for Flux and wmch_4 for SD
args.dataset_identifier = [
    f'num_{args.num_imgs}_fpr_0.01_cfg_3.0_wmch_16',
    f'num_{args.num_imgs}_fpr_0.01_cfg_3.0_wmch_4',
]

# create the output directory and filenames; all outputs are named after the first identifier
args.input_dir = os.path.join('experiments', args.exp_name)
args.output_dir = os.path.join('experiments', args.exp_name, '_results', args.prompt_dataset, args.dataset_identifier[0])
os.makedirs(args.output_dir, exist_ok=True)
args.output_plot = os.path.join(args.output_dir, args.dataset_identifier[0] + '_plot.pdf')
args.merged_result_csv = os.path.join(args.output_dir, args.dataset_identifier[0] + '_merged.csv')

# merged results already created in 5_merge_results.py
results_df = pd.read_csv(args.merged_result_csv)

# 1. plot TPR vs attack strength (currently disabled)
#plot_tpr_per_attack(args, results_df)

# 2.-5. plot TPR vs each quality/watermark metric; the x-range spans the metric's observed values
METRIC_PLOTS = [
    dict(metric_name="clip_score", metric_column="clip_score_wm",
         title_suffix="CLIP similarity score", xlabel="CLIP score (↑)"),
    dict(metric_name="wm_diff", metric_column="wm_diff",
         title_suffix="Abs. Mean Difference (original - recovered)", xlabel="Diff (↓)"),
    dict(metric_name="fid_coco", metric_column="fid_wm_coco",
         title_suffix="FID (WM vs COCO)", xlabel="FID (↓)"),
    dict(metric_name="fid_wm_nowm", metric_column="fid_wm_nowm",
         title_suffix="FID (WM vs NOWM)", xlabel="FID (↓)"),
]

for spec in METRIC_PLOTS:
    metric_values = results_df[spec["metric_column"]]
    plot_tpr_per_metric(
        args,
        results_df,
        xlim=[metric_values.min(), metric_values.max()],
        **spec,
    )


CLIP similarity score plot saved to experiments/exp1/_results/coco/num_200_fpr_0.01_cfg_3.0_wmch_16/num_200_fpr_0.01_cfg_3.0_wmch_16_plot_clip_score.pdf

Abs. Mean Difference (originial - recovered) plot saved to experiments/exp1/_results/coco/num_200_fpr_0.01_cfg_3.0_wmch_16/num_200_fpr_0.01_cfg_3.0_wmch_16_plot_wm_diff.pdf

FID (WM vs COCO) plot saved to experiments/exp1/_results/coco/num_200_fpr_0.01_cfg_3.0_wmch_16/num_200_fpr_0.01_cfg_3.0_wmch_16_plot_fid_coco.pdf

FID (WM vs NOWM) plot saved to experiments/exp1/_results/coco/num_200_fpr_0.01_cfg_3.0_wmch_16/num_200_fpr_0.01_cfg_3.0_wmch_16_plot_fid_wm_nowm.pdf
