# MAIA Demo

#### Many of MAIA's experiments are available in the [experiment browser](https://multimodal-interpretability.csail.mit.edu/maia/experiment-browser/) ####

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os

import torch

from utils.call_agent import ask_agent
from maia_api import System, Synthetic_System, Tools
from utils.ExperimentEnvironment import ExperimentEnvironment
from utils.DatasetExemplars import DatasetExemplars
from utils.SyntheticExemplars import SyntheticExemplars
from utils.main_utils import *

In [4]:
# Experiment configuration: these notebook globals are read by the cells below.
model = 'resnet152' # we currently support 'resnet152', 'clip-RN50', 'dino_vits8' and "synthetic_neurons" (NEW!)
layer = 'layer4' # for "synthetic_neurons" this will be the operation mode: "mono", "or" or "and" (see paper for details)
unit = 122 # the unit to interpret; passed to System / Synthetic_System and used to index the synthetic labels
setting = 'neuron_description' # selects the user prompt file user_{setting}.txt inside path2prompts
maia_model = 'claude' # points to claude-3-5-sonnet-latest
path2prompts = './prompts/' # directory containing api.txt and the user_*.txt prompt files
path2save = './results/' # passed to the exemplar loaders, Tools and tools.generate_html
path2exemplars = './exemplars/' # root directory handed to the exemplar loaders
device_id = 0 # device index passed to System / Synthetic_System and Tools
# Use the selected CUDA device when available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the later cells visible in this notebook (they pass device_id).
device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") 
text2image = 'flux' # "sd" is for stable-diffusion; passed to Tools as text2image_model_name
p2p_model = 'instdiff' # passed to Tools as p2p_model_name (presumably the image-editing model -- TODO confirm)

In [5]:
def return_Prompt(prompt_path,setting='neuron_description'):
    """Load the system prompt and the task-specific user prompt.

    Args:
        prompt_path: Directory containing ``api.txt`` and
            ``user_<setting>.txt``. A trailing path separator is optional,
            and an empty string means the current working directory.
        setting: Task name used to select the user prompt file.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple with the full text of
        ``api.txt`` and ``user_<setting>.txt`` respectively.

    Raises:
        OSError: If either prompt file cannot be opened (for example
            ``FileNotFoundError`` when it does not exist).
    """
    # os.path.join avoids doubled separators and keeps an empty prompt_path
    # relative (plain string formatting would turn it into '/api.txt').
    system_path = os.path.join(prompt_path, 'api.txt')
    user_path = os.path.join(prompt_path, f'user_{setting}.txt')

    # Read with an explicit encoding so the prompts decode the same way on
    # every platform instead of depending on the locale default.
    with open(system_path, 'r', encoding='utf-8') as file:
        system_prompt = file.read()
    with open(user_path, 'r', encoding='utf-8') as file:
        user_prompt = file.read()
    return system_prompt, user_prompt
    
# Read both prompts for the configured setting from path2prompts.
maia_api, user_query = return_Prompt(path2prompts, setting) # maia_api: system prompt (the MAIA API); user_query: user prompt (the task query)

### MAIA API

In [None]:
# Display the system prompt that was read from api.txt.
print(maia_api)

### Interpretability task

In [None]:
# Display the user prompt that was read from user_{setting}.txt.
print(user_query)

### Initializations

In [None]:
# Build the exemplar provider (`net_dissect`) and the system under study
# (`system`) for the configured model, then wire them into the tools and the
# experiment environment.
if model == "synthetic_neurons":
    # Precomputes synthetic dataset exemplars for tools.dataset_exemplars.
    net_dissect = SyntheticExemplars(os.path.join(path2exemplars, model), path2save, layer)

    # Load the synthetic neuron labels; for synthetic neurons `layer` holds the
    # operation mode, which also names the labels file.
    labels_path = os.path.join('./synthetic-neurons-dataset/labels/', f'{layer}.json')
    with open(labels_path, 'r', encoding='utf-8') as file:
        synthetic_neuron_data = json.load(file)

    # The file is only needed for parsing, so the remaining setup happens after
    # it has been closed.
    gt_label = synthetic_neuron_data[unit]["label"].rsplit('_')
    print("groundtruth label:", gt_label)
    system = Synthetic_System(unit, gt_label, layer, device_id)
else:
    # Precomputes dataset exemplars for tools.dataset_exemplars.
    net_dissect = DatasetExemplars(path2exemplars, path2save, model, layer, [unit])
    # Initialize the system class with the thresholds computed by the exemplar loader.
    system = System(unit, layer, model, device_id, net_dissect.thresholds)

# Initialize the tools class.
tools = Tools(path2save, device_id, net_dissect, text2image_model_name=text2image, p2p_model_name=p2p_model)
# The notebook's globals are handed to the environment together with the system and tools.
experiment_env = ExperimentEnvironment(system, tools, globals())

### MAIA's interpretation experiment

Please note:
This demo does not handle OpenAI API exceptions or bugs in MAIA's code. For error handling, please switch to ```main.py``` (recommended when looping over several units); ```main.py``` also handles saving results.

In [None]:
# Start from an empty log, then seed it with the system prompt and the user query.
tools.experiment_log = []
tools.update_experiment_log(role='system', type="text", type_content=maia_api) # update the experiment log with the system prompt
tools.update_experiment_log(role='user', type="text", type_content=user_query) # update the experiment log with the user prompt
ind = len(tools.experiment_log) # index of the first log entry that has not been plotted yet

# Agent loop: ask MAIA for an experiment, plot the new log entries, run the experiment and
# feed its output back into the log.
# NOTE: there is no iteration cap; the loop only exits once a response contains "[DESCRIPTION]".
while True:
    maia_experiment = ask_agent(maia_model,tools.experiment_log) # ask maia for the next experiment, given the experiment log so far
    tools.update_experiment_log(role='maia', type="text", type_content=str(maia_experiment)) # update the experiment log with maia's response
    plot_results_notebook(tools.experiment_log[ind:]) # plot only the entries added since the last plot
    ind = len(tools.experiment_log) # everything up to here has now been plotted
    if "[DESCRIPTION]" in maia_experiment: break # stop the experiment if the response contains the final description.
    experiment_output = experiment_env.execute_experiment(maia_experiment) # execute the experiment
    # Only log non-empty outputs, as a user message for MAIA's next turn.
    if experiment_output != "":
        tools.update_experiment_log(role='user', type="text", type_content=experiment_output)
# Called once after the loop ends (presumably writes an HTML report under path2save -- TODO confirm).
tools.generate_html(path2save)