# SAIA Demo
In this demo, SAIA attempts to identify visual attribute reliances in an apron classifier into which we have injected a reliance on the presence of feminine-presenting individuals in the image.

In [None]:
import argparse
import os
import random
from IPython import embed

import torch
import traceback

from utils.call_agent import ask_agent
from agent_api import Synthetic_System, Tools
from utils.ExperimentEnvironment import ExperimentEnvironment
from utils.SyntheticExemplars import SyntheticExemplars
from utils.main_utils import MainUtils
from utils.api_utils import str2image
import json


In [None]:
agent = 'claude'  # points to claude-3-5-sonnet-latest; we currently support 'claude', 'gpt-4o', and 'gpt-4-turbo'
base = './results/'  # root directory for experiment outputs
path2prompts = './prompts/'  # directory holding the prompt text files
mode = 'gender'
bias = 'female'  # demographic attribute feature reliance on feminine-presenting individuals
bias_discount = 0.9
# Leaf directory name shared by the exemplar and result trees; build both paths
# the same way (os.path.join) so they always stay in sync.
discount_dir = f'{bias_discount}_discount'
path2exemplars = os.path.join('./exemplars', mode, bias, discount_dir)
path2save = os.path.join(base, mode, bias, discount_dir)
device_id = 0  # device index handed to the system and tools constructors
text2image = 'flux'  # Flux image generation model
p2p_model = 'instdiff'  # InstructDiffusion image editing model
n_experiment = 0

### Initializations

In [None]:
net_dissect = SyntheticExemplars(path2exemplars, path2save, mode)  # precomputes synthetic dataset exemplars for tools.dataset_exemplars

# Load the benchmark model labels. JSON is UTF-8 by specification, so do not
# rely on the platform's locale-dependent default encoding.
with open(os.path.join(path2exemplars, 'data.json'), 'r', encoding='utf-8') as file:
    classifier_data = json.load(file)

# User prompt that asks the agent for a new experiment in rounds after the first.
with open(os.path.join(path2prompts, 'user_adversarial_experiment.txt'), 'r', encoding='utf-8') as file:
    experiment_prompt = file.read()

In [None]:
classifier_number = 0  # index of the classifier entry in data.json to audit
item = classifier_data[classifier_number]

# Split the entry's "label" on "_" and drop the first token; the remaining
# tokens form the classifier's label, whose first token is the target concept.
gt_label = item["label"].split('_')[1:]
obj = gt_label[0]
print(f"Target concept: {obj}")
print(f"Feature reliance: {bias}")

# Per-classifier output directory (created if missing).
experiment_path2save = os.path.join(path2save, item["label"])
os.makedirs(experiment_path2save, exist_ok=True)

# Wire up the components in dependency order: system under test -> tools ->
# experiment environment (needs both) -> main utilities (needs both).
system = Synthetic_System(
    classifier_number,
    gt_label,
    mode,
    device_id,
    bias=bias,
    bias_discount=bias_discount,
)
tools = Tools(
    device_id,
    net_dissect,
    text2image_model_name=text2image,
    p2p_model_name=p2p_model,
)
experiment_env = ExperimentEnvironment(system, tools, globals())
main_utils = MainUtils(path2prompts, experiment_path2save, agent, obj, tools, system, n_experiment)

### SAIA's attribute reliance detection experiment

In [None]:
experiment_rounds = 5  # max number of experiment rounds
max_turns_per_round = 20  # max agent query -> execute turns within a single round
# Early stopping: stop once the positive average reaches this fraction of
# main_utils.avg_acts and the negative average drops to this fraction of it.
pos_stop_frac = 0.7
neg_stop_frac = 0.4
prev_pos_avg = None
prev_neg_avg = None

for r in range(experiment_rounds):
    tools.experiment_log = []  # every round starts from a fresh conversation log
    if r == 0:
        # First round: build the system prompt (agent API) and the initial user query.
        agent_api, user_query = main_utils.return_prompt(setting='bias_discovery')
        tools.update_experiment_log(role='system', type="text", type_content=agent_api)  # system prompt
        tools.update_experiment_log(role='user', type="text", type_content=user_query)  # user prompt
    else:
        # Later rounds reuse round 0's system prompt and the previous round's
        # self-reflection context, then ask the agent for a new experiment.
        tools.update_experiment_log(role='system', type="text", type_content=agent_api)
        prev_pos_avg, prev_neg_avg = main_utils.load_prev_context(round_num=r-1)
        agent_experiment = ask_agent(agent, tools.experiment_log)
        tools.update_experiment_log(role='agent', type="text", type_content=str(agent_experiment))
        tools.update_experiment_log(role='user', type="text", type_content=experiment_prompt)
    main_utils.plot_results_notebook(tools.experiment_log)

    ind = len(tools.experiment_log)  # first log entry not yet rendered in the notebook
    for _ in range(max_turns_per_round):
        try:
            # Ask the agent for the next experiment given everything logged so far.
            agent_experiment = ask_agent(agent, tools.experiment_log)
            tools.update_experiment_log(role='agent', type="text", type_content=str(agent_experiment))  # str casting is for exceptions
            tools.generate_html(experiment_path2save, name=f'experiment_{n_experiment}_r{r}')
            main_utils.plot_results_notebook(tools.experiment_log[ind:])  # render only the new entries
            ind = len(tools.experiment_log)
            if "[BIAS LABEL]" in agent_experiment:
                break  # "[BIAS LABEL]" is the agent's stopping signal (final description)
            experiment_output = experiment_env.execute_experiment(agent_experiment)
            if experiment_output != "":
                tools.update_experiment_log(role='user', type="text", type_content=experiment_output)
        except Exception as e:
            # Feed the error back to the agent as text (same str convention as
            # the other log entries) so it can correct itself on the next turn.
            tools.update_experiment_log(role='user', type="text", type_content=str(e))
            traceback.print_exc()

    pos_avg, neg_avg, history = main_utils.save_results(round_num=r)  # save the results of the current round
    print(f"cur positive average: {pos_avg}, cur negative average: {neg_avg}, prev positive average: {prev_pos_avg}, prev negative average: {prev_neg_avg}")
    # Persist the experiment history; the with-block closes the file.
    with open(os.path.join(experiment_path2save, f"history_{main_utils.experiment}.json"), 'w') as file:
        json.dump(history, file)
    # Stop after `experiment_rounds` rounds, or earlier once the heuristic is met.
    if pos_avg >= pos_stop_frac*main_utils.avg_acts and neg_avg <= neg_stop_frac*main_utils.avg_acts:
        break