# MAIA Demo

#### Many of MAIA's experiments are available in the [experiment browser](https://multimodal-interpretability.csail.mit.edu/maia/experiment-browser/) ####

In [None]:
import argparse
from getpass import getpass
import os
from tqdm import tqdm
import time
from random import random, uniform
import torch
import json
from call_agent import ask_agent
from IPython import embed
from maia_api import *
import random

### Load the OpenAI API key

If you don't have an OpenAI API key, you can get one by following the instructions [here](https://platform.openai.com/docs/quickstart).
\
\
**Option 1:**
\
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Set your API key as an environment variable (the command below is for bash; see [here](https://platform.openai.com/docs/quickstart) for other operating systems):
```bash
export OPENAI_API_KEY='your-api-key-here'
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Load your API key from an environment variable or secret management service
```python
openai.api_key = os.getenv("OPENAI_API_KEY")
```
**Option 2:**
\
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Load your API key manually (do not save, commit, or share a notebook that contains your real key):
```python
openai.api_key = 'your-api-key-here'
```


In [None]:
# Read the credentials from the environment (option 1 above).
# os.environ.get returns None for a variable that is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORGANIZATION")

In [None]:
# --- Experiment configuration: edit these to study a different unit/model ---
model = 'resnet152' # we currently support 'resnet152', 'clip-RN50', 'dino_vits8'
layer = 'layer4' # layer of `model` that holds the unit passed to System below
unit = 122 # unit index within `layer`
setting = 'neuron_description' # selects <path2prompts>/user_<setting>.txt
maia_model = 'gpt-4-vision-preview' # model name passed to ask_agent
path2prompts = './prompts/'
path2save = './results/'
# The default is a site-specific absolute path; set the MAIA_EXEMPLARS_PATH
# environment variable to point at your own copy of the exemplars instead.
path2exemplars = os.getenv('MAIA_EXEMPLARS_PATH', '/data/vision/torralba/Functions/maia/exemplars/')
device_id = 0
# torch device for `device_id`, falling back to CPU when CUDA is unavailable.
device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") 
text2image = 'sd' # "sd" is for stable-diffusion

In [None]:
def return_Prompt(prompt_path, setting='neuron_description'):
    """Load MAIA's system prompt and the task-specific user prompt.

    Args:
        prompt_path: directory containing ``api.txt`` and
            ``user_<setting>.txt``; a trailing slash is optional.
        setting: task name that selects ``user_<setting>.txt``.

    Returns:
        ``(system_prompt, user_prompt)``, the full text of the two files.

    Raises:
        FileNotFoundError: if either prompt file does not exist.
    """
    def _read(filename):
        # Read one prompt file. Explicit UTF-8 so the prompt text decodes
        # identically on every platform instead of following the locale.
        with open(os.path.join(prompt_path, filename), 'r', encoding='utf-8') as file:
            return file.read()

    sysPrompt = _read('api.txt')
    user_prompt = _read(f'user_{setting}.txt')
    return sysPrompt, user_prompt
    
# Load the system prompt (MAIA's API description) and the user prompt (the task query).
maia_api, user_query = return_Prompt(prompt_path=path2prompts, setting=setting)

### MAIA API

In [None]:
# Show the system prompt that describes MAIA's API.
print(str(maia_api))

### Interpretability task

In [None]:
# Show the user prompt that states the interpretability task.
print(str(user_query))

### Initializations

In [None]:
# Build the three objects MAIA's generated code works with. The classes are not
# defined in this notebook (presumably they come from `from maia_api import *`),
# and their arguments are passed positionally, so keep the argument order as-is.
# Dataset exemplars for the chosen unit (note the single-element list [unit]).
net_dissect = DatasetExemplars(path2exemplars, path2save, model, layer, [unit]) # precomputes dataset exemplars for tools.dataset_exemplars
# The system under study: `unit` in `layer` of `model`. net_dissect.thresholds is
# presumably a per-unit activation threshold derived from the exemplars - TODO confirm in maia_api.
system = System(unit, layer, model, device_id, net_dissect.thresholds) # initialize the system class
# MAIA's toolbox; text2image selects the image generator ("sd" = stable-diffusion, see config).
tools = Tools(path2save, device_id, net_dissect, text2image_model_name=text2image, images_per_prompt=1) # initialize the tools class

In [None]:
# Start from an empty log so re-running this cell resets the conversation.
tools.experiment_log = []
# Seed it with the system prompt (MAIA's API description)...
tools.update_experiment_log(type="text", role='system', type_content=maia_api)
# ...followed by the user prompt (the interpretability task).
tools.update_experiment_log(type="text", role='user', type_content=user_query)

### Utils

In [None]:
def get_code(maia_experiment):
    """Extract the first ```python fenced code block from MAIA's reply.

    Args:
        maia_experiment: MAIA's raw response text.

    Returns:
        The text between the first "```python" marker and the next "```",
        or up to the end of the string if that fence is never closed.

    Raises:
        ValueError: if the response contains no "```python" marker at all
            (this used to surface as a bare IndexError).
    """
    opening_fence = '```python'
    start = maia_experiment.find(opening_fence)
    if start == -1:
        raise ValueError("MAIA's response contains no ```python code block to execute.")
    after_fence = maia_experiment[start + len(opening_fence):]
    # Any later "```python" also begins with "```", so this stops at the end
    # of the first block, matching the previous split-based behaviour.
    return after_fence.split('```', 1)[0]

def execute_maia_experiment(code, system, tools):
    """Run MAIA-generated code, then call the execute_command it defines.

    The code runs in this module's global namespace (the notebook's globals,
    when run here), so it can use everything imported above. It must define
    ``execute_command(system, tools)``, which is then called with the given
    arguments.

    SECURITY: this exec()s model-generated code with the kernel's full
    privileges and no sandboxing; only run it in an environment you trust.

    Raises:
        NameError: if the code does not define ``execute_command``.
    """
    namespace = globals()
    # Globals persist across rounds. Drop the previous round's definition so a
    # reply that forgets to define execute_command fails loudly, instead of
    # silently re-running the previous experiment.
    namespace.pop('execute_command', None)
    exec(compile(code, 'code', 'exec'), namespace)
    command = namespace.get('execute_command')
    if command is None:
        raise NameError("MAIA's code did not define execute_command(system, tools)")
    command(system, tools)

def plot_results_notebook(experiment_log):
    """Render one experiment-log entry in the notebook.

    Args:
        experiment_log: a single log entry (despite the name), a dict with a
            'role' and a 'content'. The content is either a string or a list
            of items shaped like {'type': 'text', 'text': ...} or
            {'type': 'image_url', 'image_url': {'url': ...}}.

    A header naming the speaker is printed first. Text items are printed;
    image items are decoded with str2image and displayed. The URL is
    presumably a base64 data URL, and only the part after the first comma is
    decoded. Items of any other type are skipped.
    """
    # The agent loop logs MAIA's replies with role='maia'. Whether
    # update_experiment_log rewrites that to 'assistant' is not visible here,
    # so both are treated as MAIA.
    if experiment_log['role'] in ('assistant', 'maia'):
        print('\n\n*** MAIA: ***\n\n')
    else:
        print('\n\n*** Experiment Execution: ***\n\n')
    content = experiment_log['content']
    if isinstance(content, str):
        # Plain-string content: iterating it would yield characters, not items.
        print(content)
        return
    for item in content:
        if item['type'] == 'text':
            print(item['text'])
        elif item['type'] == 'image_url':
            data_url = item['image_url']['url']
            display(str2image(data_url.split(',', 1)[1]))

### MAIA's interpretation experiment

Please note: this demo does not handle OpenAI API exceptions. For error handling, use `main.py` instead (recommended when looping over several units); `main.py` also saves the results.

In [None]:
# Agent loop: alternately ask MAIA for an experiment and run it, until MAIA's
# reply contains "[DESCRIPTION]" (its final answer).
# NOTE(review): there is no iteration cap and no API error handling here;
# interrupt the kernel to stop early (main.py handles errors and saving).
while True:
    maia_experiment = ask_agent(maia_model,tools.experiment_log) # ask maia for the next experiment, given everything in the experiment log so far
    tools.update_experiment_log(role='maia', type="text", type_content=str(maia_experiment)) # update the experiment log with maia's response 
    plot_results_notebook(tools.experiment_log[-1]) # plot the result to notebook
    if "[DESCRIPTION]" in maia_experiment: break # stop the experiment if the response contains the final description. 
    maia_code = get_code(maia_experiment) # parse the code by locating Python syntax
    execute_maia_experiment(maia_code, system, tools) # execute the experiment, maia's code should contain tools.update_experiment_log(...) 
    # Shows the newest log entry. If the executed code logged nothing, this is
    # presumably MAIA's own message again.
    plot_results_notebook(tools.experiment_log[-1]) # plot the result to notebook