In [25]:
%load_ext autoreload
%autoreload 2
import pandas as pd
from tqdm import tqdm
import os
import csv
import numpy as np
import utils
from transformers import T5Tokenizer, T5ForConditionalGeneration
import utensil_passing_utils
from datetime import datetime

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
# Utensils available at the root of the interaction tree. Key order matters:
# later cells locate a node's children by the position of each key.
# (The value 1 is presumably a per-utensil count -- TODO confirm.)
remaining_objects = dict.fromkeys(('spatula', 'whisk', 'scissors', 'knife'), 1)
root_node = utensil_passing_utils.Node(
    remaining_objects,
    [],  # presumably the (empty) interaction history of the root -- TODO confirm
    include_trust_change=True,
)

# Grow the tree below the root, then collect the nodes returned by traverse().
utensil_passing_utils.expand_interaction_tree(
    root_node,
    template_generation_fn=utensil_passing_utils.generate_template,
)
all_nodes = root_node.traverse()

1281

In [27]:
# The two completion labels scored for every prompt. They are also appended to
# the results CSV header as the per-answer probability columns.
answer_choices = [
    "A",
    "B",
]

In [28]:
# Column layout of the cached LLM results: a description of the node, the
# prompt, one probability column per answer choice, and the sum of those
# probabilities.
result_header = ['remaining objects', 'history', 'prompt'] + answer_choices
result_header.append('sum of prob')
# NOTE(review): neither flag is read in this cell, and the tree below is built
# with include_trust_change=True regardless of the value set here -- confirm intent.
include_trust_change = False
query_action = True

llm_result_path = "./results/davinci.csv"
# The CSV doubles as a cache: the model is only queried when the file is absent.
if not os.path.exists(llm_result_path):
    remaining_objects = {'spatula': 1, 'whisk': 1, 'scissors': 1, 'knife': 1, }
    root_node = utensil_passing_utils.Node(remaining_objects, [], include_trust_change=True)
    utensil_passing_utils.expand_interaction_tree(root_node, template_generation_fn=utensil_passing_utils.generate_template)
    all_nodes = root_node.traverse()

    print(f"Saving to {llm_result_path}")

    result_dir = os.path.dirname(llm_result_path)
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)

    # Write to a side file and only move it into place once every prompt has
    # been answered. Otherwise a failure half-way through would leave a
    # truncated CSV that the existence check above treats as a complete cache.
    partial_result_path = llm_result_path + ".partial"
    # newline='' is what the csv module expects for files it writes to.
    with open(partial_result_path, 'w', newline='') as llm_result_file:
        writer = csv.writer(llm_result_file)
        writer.writerow(result_header)
        for node in tqdm(all_nodes):
            for prompt in node.prompts:
                # The answer label is substituted into the trailing "{}" slot.
                template = prompt + " {}"
                probs = utils.get_probs_davinci(template, answer_choices)
                # One probability per answer choice, matching result_header.
                writer.writerow([node.remaining_objects, node.history, prompt, *probs, sum(probs)])
    os.replace(partial_result_path, llm_result_path)

In [29]:
llm_result_df = pd.read_csv('./results/davinci.csv')
# Get the probs
# Attach to every node a dict {action: [P(A), P(B)]}, renormalised so that the
# two entries sum to 1. Rows are matched to a node through the string form of
# its history, which is how the 'history' column compares equal below.
history_column = llm_result_df['history']
for node in tqdm(all_nodes):
    probs = {}
    # Filter the frame once per node and reuse the result for all checks.
    rows = llm_result_df[history_column == str(node.history)]
    if len(rows) == 0:
        # No prompt was recorded for this history, so the node must be a leaf.
        assert len(node.children) == 0, f"node without recorded prompts has children: {node.history}"
    else:
        robot_action_list = list(node.remaining_objects.keys())
        # Three children per recorded row, except that a node which still has
        # the knife has one child fewer.
        expected_children = len(rows) * 3
        if 'knife' in node.remaining_objects:
            expected_children -= 1
        assert len(node.children) == expected_children, f"unexpected number of children for {node.history}"
        assert len(rows) == len(robot_action_list), f"expected one row per remaining object for {node.history}"
        for action_idx, action in enumerate(robot_action_list):
            # Rows are expected in the same order as the remaining objects;
            # the prompt check below guards that pairing.
            row = rows.iloc[action_idx]
            assert f"{action}?" in row['prompt'], f"row {action_idx} is not about {action!r}"
            total_prob = row['A'] + row['B']
            probs[action] = [row['A'] / total_prob, row['B'] / total_prob]
    node.probs = probs

100%|██████████| 1281/1281 [00:00<00:00, 3175.51it/s]


In [30]:
# Check the trust dynamics
# For every parent/child pair in which the child itself has children, compare
# entry 0 of the action's probabilities before (parent) and after (child) the
# child's last history event, and count the moves in the unexpected direction.
violation = 0
violation_amount = 0
total = 0
for node in all_nodes:
    for child_node in node.children:
        if len(child_node.children) == 0:
            # Children without children of their own are not checked.
            continue
        for action, dist in child_node.probs.items():
            last_event = child_node.history[-1]
            before, after = node.probs[action][0], dist[0]
            if last_event[1] or not last_event[-1]:
                # Intervened, or stayed put and failed:
                # an increase counts as a violation.
                violated = after > before
            else:
                # Stayed put and succeeded: a decrease counts as a violation.
                violated = after < before
            if violated:
                violation += 1
                violation_amount += abs(after - before)
            total += 1
print(violation, total)
print(violation_amount)

20 498
2.9909489720437346e-05


In [33]:
# Get the policy
# One (initially empty) table of option values per tree node; the tables are
# filled in by get_value_of_node.
q_dict = dict((tree_node, {}) for tree_node in all_nodes)
# Discount factor applied to values of successor nodes.
discount = 0.99
# This is ok because we have a tree structure
def get_value_of_node(cur_node: "utensil_passing_utils.Node"):
    """Return the value of ``cur_node`` and record its option values.

    Recursively evaluates the subtree below ``cur_node``. For every remaining
    object ``a`` it stores in the module-level ``q_dict[cur_node]``:

    * ``"{a}_succ"`` / ``"{a}_fail"`` for every object other than the knife,
    * ``"knife"`` for the knife,

    each being the probability-weighted mix of the intervene outcome and the
    stay-put outcome, using ``cur_node.probs[a] == [stay_put, intervene]`` and
    the module-level ``discount``.

    Children are located by position: object number ``i`` owns children
    ``3*i`` (stay put, success), ``3*i + 1`` (stay put, failure) and
    ``3*i + 2`` (intervene); for the knife, ``3*i + 1`` is the intervene child.
    NOTE(review): this layout only lines up if the knife is the last remaining
    object -- confirm against expand_interaction_tree.

    Returns:
        0 for a node without children, otherwise the maximum of the values
        stored in ``q_dict[cur_node]``.
    """
    robot_action_list = list(cur_node.remaining_objects.keys())
    if len(cur_node.children) == 0:
        return 0

    immediate_reward = 1
    node_q = q_dict[cur_node]
    for a_idx, a in enumerate(robot_action_list):
        stay_put_prob = cur_node.probs[a][0]
        intervene_prob = cur_node.probs[a][1]

        if a != "knife":
            stay_put_succ_node = cur_node.children[a_idx * 3]
            stay_put_fail_node = cur_node.children[a_idx * 3 + 1]
            intervene_node = cur_node.children[a_idx * 3 + 2]
            # The intervene subtree is shared by the _fail and _succ options.
            # Evaluate it once: recursing into it for each option would double
            # the work at every level of the tree.
            intervene_value = get_value_of_node(intervene_node)
            node_q[f"{a}_fail"] = intervene_prob * discount * intervene_value + stay_put_prob * (-1 + discount * get_value_of_node(stay_put_fail_node))
            node_q[f"{a}_succ"] = intervene_prob * discount * intervene_value + stay_put_prob * (immediate_reward + discount * get_value_of_node(stay_put_succ_node))
        else:
            intervene_node = cur_node.children[a_idx * 3 + 1]
            # Failing on knife incurs a penalty of 10 and terminate
            # NOTE(review): this penalty is multiplied by `discount`, while the
            # +1 / -1 rewards above are not -- confirm this is intended.
            node_q[a] = intervene_prob * discount * get_value_of_node(intervene_node) + stay_put_prob * discount * -10

    return np.max(list(node_q.values()))
            
# Follow the first child four levels down from the root.
# (Whether this lands on an actual leaf depends on the tree -- not checked here.)
leaf_node = root_node
for depth in range(4):
    leaf_node = leaf_node.children[0]
# Evaluate the whole tree; this fills q_dict and the cell displays the root's value.
get_value_of_node(root_node)

0.11353785142560353

In [35]:
# Greedy rollout: starting at the root, print the option with the highest
# value in q_dict at each node and descend into the matching stay-put child,
# until a node without children is reached.
cur_node = root_node
while len(cur_node.children) != 0:
    action = max(q_dict[cur_node], key=q_dict[cur_node].get)
    # print(q_dict[cur_node])
    print(action)
    robot_action_list = list(cur_node.remaining_objects.keys())
    if action != "knife":
        # Option keys look like "<object>_succ" / "<object>_fail". Split on the
        # last underscore only, so object names containing "_" still parse.
        # (Not named `object`: that would shadow the builtin for later cells.)
        utensil, result = action.rsplit('_', 1)
        action_idx = robot_action_list.index(utensil)
        result_idx = 0 if result == 'succ' else 1
        cur_node = cur_node.children[action_idx * 3 + result_idx]
    else:
        # Same positional layout as get_value_of_node: the knife's stay-put
        # child sits at 3 * (position of knife). A fixed children[0] is only
        # correct when the knife is the first remaining object.
        action_idx = robot_action_list.index(action)
        cur_node = cur_node.children[action_idx * 3]

spatula_succ
whisk_succ
scissors_fail
knife
