In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
from tqdm import tqdm
import os
import csv
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
import table_clearing_experiment_utils
import utils
from datetime import datetime

In [3]:
# Build the full interaction tree for the starting table state so its nodes
# (and the prompts attached to them) can be inspected interactively.
include_trust_change = False
query_action = True

# Ask the model which action to take, or a yes/no question, depending on the flag.
template_generation_fn = (
    table_clearing_experiment_utils.generate_template_action
    if query_action
    else table_clearing_experiment_utils.generate_template
)

# Object name -> count still on the table at the root of the tree.
remaining_objects = {'plastic bottle': 3, 'fish can': 1, 'wine glass': 1}
root_node = table_clearing_experiment_utils.Node(
    remaining_objects,
    [],
    include_trust_change=include_trust_change,
)
table_clearing_experiment_utils.expand_interaction_tree(
    root_node,
    template_generation_fn=template_generation_fn,
)
# NOTE(review): the recorded output of this cell ("1099") does not come from the
# last statement (an assignment) — possibly stale or printed by a helper; confirm.
all_nodes = root_node.traverse()

1099

In [4]:
# Answer labels the LLM is scored on for every prompt; the result CSVs get one
# probability column per label, in this order.
answer_choices = [
    "A",
    "B",
]

In [5]:
# Score every prompt in the interaction tree with Flan-T5-XXL and cache the
# per-answer probabilities as a CSV. The cell does nothing if the CSV already
# exists; delete the file to force regeneration.
result_header = ['remaining objects', 'history', 'prompt'] + answer_choices
result_header.append('sum of prob')
include_trust_change = False
query_action = True

# File-name suffix, kept byte-identical to earlier runs so existing cached CSVs
# are still found. When include_trust_change is False it starts with "_",
# giving names like t5__query_action.csv.
descr = ("trust_change" if include_trust_change else "") + (
    "_query_action" if query_action else "_query_yesno"
)

template_generation_fn = (
    table_clearing_experiment_utils.generate_template_action
    if query_action
    else table_clearing_experiment_utils.generate_template
)

llm_result_path = f"./results/t5_{descr}.csv"
if not os.path.exists(llm_result_path):
    # Per-device memory cap for device_map="auto"; assumes 4 visible GPUs.
    max_memory = {0: "20GIB", 1: "20GIB", 2: "20GIB", 3: "20GIB"}
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="auto", max_memory=max_memory)
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

    remaining_objects = {'plastic bottle': 3, 'fish can': 1, 'wine glass': 1}
    root_node = table_clearing_experiment_utils.Node(remaining_objects, [], include_trust_change=include_trust_change)
    table_clearing_experiment_utils.expand_interaction_tree(root_node, template_generation_fn=template_generation_fn)
    all_nodes = root_node.traverse()

    print(f"Saving to {llm_result_path}")
    print(f"Include trust change {include_trust_change}, Query for {'yes no' if not query_action else 'action'}")

    os.makedirs(os.path.dirname(llm_result_path), exist_ok=True)
    # Write to a side file and rename it into place only after every row is
    # written: an exception part-way through must not leave a truncated CSV
    # that the os.path.exists guard above would later treat as complete.
    partial_path = llm_result_path + ".partial"
    # newline='' is what the csv module requires for files passed to csv.writer.
    with open(partial_path, 'w', newline='') as llm_result_file:
        writer = csv.writer(llm_result_file)
        writer.writerow(result_header)
        for node in tqdm(all_nodes):
            for prompt in node.prompts:
                probs = utils.get_probs_t5([prompt], model, tokenizer, answer_choices)[0]
                # Assumes one probability per entry of answer_choices, matching
                # the header columns, instead of hard-coding two.
                writer.writerow([node.remaining_objects, node.history, prompt, *probs, sum(probs)])
    os.replace(partial_path, llm_result_path)

In [6]:
# Score every prompt in the interaction tree with the davinci model (via
# utils.get_probs_davinci) and cache the per-answer probabilities as a CSV.
# The cell does nothing if the CSV already exists; delete it to regenerate.
result_header = ['remaining objects', 'history', 'prompt'] + answer_choices
result_header.append('sum of prob')
include_trust_change = False
query_action = True

# File-name suffix, kept byte-identical to earlier runs so existing cached CSVs
# are still found. When include_trust_change is False it starts with "_",
# giving names like davinci__query_action.csv.
descr = ("trust_change" if include_trust_change else "") + (
    "_query_action" if query_action else "_query_yesno"
)

template_generation_fn = (
    table_clearing_experiment_utils.generate_template_action
    if query_action
    else table_clearing_experiment_utils.generate_template
)

llm_result_path = f"./results/davinci_{descr}.csv"
if not os.path.exists(llm_result_path):
    remaining_objects = {'plastic bottle': 3, 'fish can': 1, 'wine glass': 1}
    root_node = table_clearing_experiment_utils.Node(remaining_objects, [], include_trust_change=include_trust_change)
    table_clearing_experiment_utils.expand_interaction_tree(root_node, template_generation_fn=template_generation_fn)
    all_nodes = root_node.traverse()

    print(f"Saving to {llm_result_path}")
    print(f"Include trust change {include_trust_change}, Query for {'yes no' if not query_action else 'action'}")

    os.makedirs(os.path.dirname(llm_result_path), exist_ok=True)
    # Write to a side file and rename it into place only after every row is
    # written: an exception part-way through (e.g. a failed request) must not
    # leave a truncated CSV that the os.path.exists guard would treat as done.
    partial_path = llm_result_path + ".partial"
    # newline='' is what the csv module requires for files passed to csv.writer.
    with open(partial_path, 'w', newline='') as llm_result_file:
        writer = csv.writer(llm_result_file)
        writer.writerow(result_header)
        for node in tqdm(all_nodes):
            for prompt in node.prompts:
                # NOTE(review): presumably get_probs_davinci fills the "{}" slot
                # with each answer choice — confirm in utils.
                template = prompt + " {}"
                probs = utils.get_probs_davinci(template, answer_choices)
                # Assumes one probability per entry of answer_choices, matching
                # the header columns, instead of hard-coding two.
                writer.writerow([node.remaining_objects, node.history, prompt, *probs, sum(probs)])
    os.replace(partial_path, llm_result_path)

In [58]:
# Evaluate different policies in the simulated environment, one per cached
# result CSV in ./results (processed in sorted filename order).
# Note that the davinci policies are not very distinguishable
results_dir = './results'
for csv_path in sorted(os.listdir(results_dir)):
    # os.listdir returns every entry, including subdirectories (e.g. a stray
    # .ipynb_checkpoints) and temp files; only finished CSVs are evaluated.
    if not csv_path.endswith('.csv'):
        continue
    table_clearing_experiment_utils.evaluate_policy(f'{results_dir}/{csv_path}')

*************** results ***********************
#### return #####
mean return:  6.1388581549909995
std:  0.03358096569498668
##### intervention ratio #####
intervene ratio:  [0.07583333 0.2011     0.3603    ]
./results/davinci__query_action.csv
*************** results ***********************
#### return #####
mean return:  6.138182573410001
std:  0.03411672375918422
##### intervention ratio #####
intervene ratio:  [0.07833333 0.1944     0.3628    ]
./results/davinci__query_yesno.csv
*************** results ***********************
#### return #####
mean return:  6.148677582649
std:  0.03380614817922158
##### intervention ratio #####
intervene ratio:  [0.07413333 0.2114     0.3517    ]
./results/davinci_trust_change_query_action.csv
*************** results ***********************
#### return #####
mean return:  6.173598900674001
std:  0.033996323440301617
##### intervention ratio #####
intervene ratio:  [0.07443333 0.1975     0.3523    ]
./results/davinci_trust_change_query_yesno.csv
***