
Commit

did some code clean up and fixed bugs
nexusapoorvacus committed Mar 12, 2018
1 parent 595e8a4 commit a7a407c
Showing 4 changed files with 60 additions and 54 deletions.
11 changes: 7 additions & 4 deletions graph.py
@@ -47,15 +47,18 @@ def get_entity_by_name(self, entity):
     def variation_based_traversal(self, subject_name, object_name, previously_mined_attributes=[], max_num_to_return=-1):
         subject_node = self.get_entity_by_name(subject_name)
         object_node = self.get_entity_by_name(object_name)
+        if subject_node == None:
+            return [], []
         attributes_to_return = {}
         for a_edge in subject_node.attribute_edges:
             if self.attribute_nodes[subject_node.attribute_edges[a_edge].attribute_id].ID not in previously_mined_attributes:
                 attributes_to_return[subject_node.attribute_edges[a_edge].attribute_id] = subject_node.attribute_edges[a_edge].multiplicity
         predicates_to_return = {}
-        for p_edge in subject_node.predicate_edges:
-            subject_node.attribute_edges[a_edge].multiplicity
-            if subject_node.predicate_edges[p_edge].object_id == object_node.ID:
-                predicates_to_return[subject_node.predicate_edges[p_edge].predicate_id] = subject_node.predicate_edges[p_edge].multiplicity
+        if object_node != None:
+            for p_edge in subject_node.predicate_edges:
+                subject_node.attribute_edges[a_edge].multiplicity
+                if subject_node.predicate_edges[p_edge].object_id == object_node.ID:
+                    predicates_to_return[subject_node.predicate_edges[p_edge].predicate_id] = subject_node.predicate_edges[p_edge].multiplicity
 
         attributes_to_return = sorted(attributes_to_return.items(), key=lambda x: -x[1])
         predicates_to_return = sorted(predicates_to_return.items(), key=lambda x: -x[1])
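Note: with the two graph.py fixes above, variation_based_traversal returns empty action sets when the subject entity is unknown and only mines predicate edges once the object node resolves. The following is a minimal sketch of that control flow, not the repository's exact code: the stray subject_node.attribute_edges[a_edge].multiplicity expression carried over in the diff is left out, and the final return is assumed since it falls below the hunk.

    def variation_based_traversal(self, subject_name, object_name, previously_mined_attributes=[], max_num_to_return=-1):
        # Resolve both endpoints of the candidate relationship.
        subject_node = self.get_entity_by_name(subject_name)
        object_node = self.get_entity_by_name(object_name)
        if subject_node is None:
            return [], []  # nothing to traverse from

        # Attributes attached to the subject that have not been mined yet.
        attributes_to_return = {}
        for a_edge in subject_node.attribute_edges:
            edge = subject_node.attribute_edges[a_edge]
            if self.attribute_nodes[edge.attribute_id].ID not in previously_mined_attributes:
                attributes_to_return[edge.attribute_id] = edge.multiplicity

        # Predicate edges are only meaningful when the object entity is known.
        predicates_to_return = {}
        if object_node is not None:
            for p_edge in subject_node.predicate_edges:
                edge = subject_node.predicate_edges[p_edge]
                if edge.object_id == object_node.ID:
                    predicates_to_return[edge.predicate_id] = edge.multiplicity

        # Rank both adaptive action sets by multiplicity, highest first.
        attributes_to_return = sorted(attributes_to_return.items(), key=lambda x: -x[1])
        predicates_to_return = sorted(predicates_to_return.items(), key=lambda x: -x[1])
        return attributes_to_return, predicates_to_return  # assumed: the actual return sits below the hunk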
4 changes: 2 additions & 2 deletions image_state.py
@@ -98,7 +98,7 @@ def step(self, attribute_action, predicate_action, next_object_action):
 
         gt_subject_index = self.overlaps(self.current_subject)
         if gt_subject_index != -1: #overlap
-            if pred_attribute_name in self.gt_scene_graph["labels"]["objects"][gt_subject_index]["attributes"]:
+            if "attributes" in self.gt_scene_graph["labels"]["objects"][gt_subject_index] and pred_attribute_name in self.gt_scene_graph["labels"]["objects"][gt_subject_index]["attributes"]:
                 reward_attribute = 1
         gt_object_index = self.overlaps(self.current_object)
         if gt_object_index != -1:
@@ -113,7 +113,7 @@ def step(self, attribute_action, predicate_action, next_object_action):
         gt_new_object_index = self.overlaps(next_object_action)
         #self.explored_entities.append(new_object_index)
         if gt_new_object_index != -1:
-            if next_object_action == gt_new_object_index:
+            if gt_new_object_index not in self.explored_entities:
                 reward_next_object = 5
                 self.current_object = next_object_action
         else:
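Note: both image_state.py changes tighten the reward checks in step() rather than change the reward values. A minimal sketch of the intent, using the names from the diff (the surrounding step() body is assumed):

    # Attribute reward: only look up "attributes" when the matched ground-truth object actually
    # carries that key, so objects annotated without attributes no longer raise a KeyError.
    gt_object = self.gt_scene_graph["labels"]["objects"][gt_subject_index]
    if "attributes" in gt_object and pred_attribute_name in gt_object["attributes"]:
        reward_attribute = 1

    # Next-object reward: reward moving to a ground-truth entity that has not been explored yet,
    # instead of comparing a proposal index against a ground-truth index.
    if gt_new_object_index != -1 and gt_new_object_index not in self.explored_entities:
        reward_next_object = 5
        self.current_object = next_object_action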
98 changes: 50 additions & 48 deletions main.py
@@ -83,9 +83,7 @@ def train(parameters):
state_vector = create_state_vector(im_state)
subject_id = im_state.current_subject
object_id = im_state.current_object
print("Done!")
if type(state_vector) == type(None):
import pdb; pdb.set_trace()
if im_state.current_subject == None:
break
else:
@@ -95,6 +93,7 @@
continue

# perform variation structured traveral scheme to get adaptive actions
print("Creating adaptive action sets...")
subject_name = entity_to_aliases(im_state.entity_classes[subject_id])
object_name = entity_to_aliases(im_state.entity_classes[object_id])
subject_bbox = im_state.entity_proposals[subject_id]
@@ -104,28 +103,13 @@
next_object_adaptive_actions = find_object_neighbors(subject_bbox, im_state.entity_proposals, previously_mined_next_objects)

# creating state + action vectors to feed in DQN
attr_action_len = len(attribute_adaptive_actions)
pred_action_len = len(predicate_adaptive_actions)
next_object_action_len = len(next_object_adaptive_actions)

identity_attr = torch.autograd.Variable(torch.from_numpy(np.identity(attr_action_len)).float()) if attr_action_len > 0 else None
identity_pred = torch.autograd.Variable(torch.from_numpy(np.identity(pred_action_len)).float()) if pred_action_len > 0 else None
identity_next_object = torch.autograd.Variable(torch.from_numpy(np.identity(next_object_action_len)).float()) if next_object_action_len > 0 else None
print("Creating state + action vectors to pass into DQN...")
attribute_state_vectors = create_state_action_vector(state_vector, attribute_adaptive_actions)
predicate_state_vectors = create_state_action_vector(state_vector, predicate_adaptive_actions)
next_object_state_vectors = create_state_action_vector(state_vector, next_object_adaptive_actions)

if torch.cuda.is_available():
identity_attr = identity_attr.cuda() if attr_action_len > 0 else None
identity_pred = identity_pred.cuda() if pred_action_len > 0 else None
identity_next_object = identity_next_object.cuda() if next_object_action_len > 0 else None

attribute_state_vectors, predicate_state_vectors, next_object_state_vectors = None, None, None
if attr_action_len > 0:
attribute_state_vectors = torch.cat([state_vector.repeat(attr_action_len, 1), identity_attr], 1)
if pred_action_len > 0:
predicate_state_vectors = torch.cat([state_vector.repeat(pred_action_len, 1), identity_pred], 1)
if next_object_action_len > 0:
next_object_state_vectors = torch.cat([state_vector.repeat(next_object_action_len, 1), identity_next_object], 1)

# choose action using epsilon greedy
print("Choose action using epsilon greedy...")
attribute_action, predicate_action, next_object_action = None, None, None
if type(attribute_state_vectors) != type(None):
attribute_action = choose_action_epsilon_greedy(attribute_state_vectors, attribute_adaptive_actions, model_attribute_main, parameters["epsilon"], training=replay_buffer.can_sample())
@@ -134,20 +118,28 @@
if type(next_object_state_vectors) != type(None):
next_object_action = choose_action_epsilon_greedy(next_object_state_vectors, next_object_adaptive_actions, model_next_object_main, parameters["epsilon"], training=replay_buffer.can_sample())
# step image_state
print("Step state environment using action...")
attribute_reward, predicate_reward, next_object_reward, done = im_state.step(attribute_action, predicate_action, next_object_action)
next_state = create_state_vector(im_state)
image_states[image_name] = next_state

next_state = create_state_vector(im_state)
im_state = image_states[image_name]
# decay epsilon
if parameters["epsilon"] > parameters["epsilon_end"]:
parameters["epsilon"] = parameters["epsilon"] * parameters["epsilon_anneal_rate"]

# add transition tuple to replay buffer
# TODO: placeholder for where to continue debugging
replay_buffer.push(state_vector, next_state, attribute_adaptive_actions, predicate_adaptive_actions, next_object_adaptive_actions, attribute_reward, predicate_reward, next_object_reward, done)
print("Adding transition tuple to replay buffer...")
subject_name_1 = entity_to_aliases(im_state.entity_classes[im_state.current_subject])
object_name_1 = entity_to_aliases(im_state.entity_classes[im_state.current_object])
previously_mined_attributes_1 = im_state.current_scene_graph["objects"][im_state.current_subject]["attributes"]
previously_mined_next_objects_1 = im_state.objects_explored_per_subject[im_state.current_subject]
attribute_adaptive_actions_1, predicate_adaptive_actions_1 = semantic_action_graph.variation_based_traversal(subject_name_1, object_name_1, previously_mined_attributes)
next_object_adaptive_actions_1 = find_object_neighbors(im_state.entity_proposals[im_state.current_subject], im_state.entity_proposals, previously_mined_next_objects)

replay_buffer.push(state_vector, next_state, attribute_adaptive_actions, predicate_adaptive_actions, next_object_adaptive_actions, attribute_reward, predicate_reward, next_object_reward, attribute_adaptive_actions_1, predicate_adaptive_actions_1, next_object_adaptive_actions_1, done)

# sample minibatch if replay_buffer has enough samples
if replay_buffer.can_sample():
print("Sample minibatch of transitions...")
minibatch_transitions = replay_buffer.sample(parameters["batch_size"])
main_q_attribute_list, main_q_predicate_list, main_q_next_object_list = [], [], []
target_q_attribute_list, target_q_predicate_list, target_q_next_object_list = [], [], []
@@ -158,26 +150,24 @@
target_q_predicate = transition.predicate_reward
target_q_next_object = transition.target_q_next_object
else:

next_state_attribute = torch.concat([transition.next_state,
torch.from_numpy(np.identity(len(transition.next_state_attribute_actions)))], 1)
next_state_predicate = torch.concat([transition.next_state,
torch.from_numpy(np.identity(len(transition.next_state_predicate_actions)))], 1)
next_state_next_object = torch.concat([transition.next_state,
torch.from_numpy(np.identity(len(transition.next_state_next_object_actions)))], 1)
target_q_attribute = transition.attribute_reward + discount_factor * torch.max(model_attribute_target(next_state_attribute))
target_q_predicate = transition.predicate_reward + discount_factor * torch.max(model_predicate_target(next_state_predicate))
target_q_next_object = transition.next_object_reward + discount_factor * torch.max(model_next_object_target(next_state_next_object))
next_state_attribute = create_state_action_vector(transition.next_state, transition.next_state_attribute_actions)
next_state_predicate = create_state_action_vector(transition.next_state, transition.next_state_predicate_actions)
next_state_next_object = create_state_action_vector(transition.next_state, transition.next_state_next_object_actions)
import pdb; pdb.set_trace()
if type(next_state_attribute) != type(None):
target_q_attribute = transition.attribute_reward + parameters["discount_factor"] * torch.max(model_attribute_target(next_state_attribute))
if type(next_state_predicate) != type(None):
target_q_predicate = transition.predicate_reward + parameters["discount_factor"] * torch.max(model_predicate_target(next_state_predicate))
if type(next_state_next_object) != type(None):
target_q_next_object = transition.next_object_reward + parameters["discount_factor"] * torch.max(model_next_object_target(next_state_next_object))
# compute loss
main_state_attribute = torch.concat([transition.state,
torch.from_numpy(np.identity(len(transition.attribute_actions)))], 1)
main_state_predicate = torch.concat([transition.state,
torch.from_numpy(np.identity(len(transition.predicate_actions)))], 1)
main_state_next_object = torch.concat([transition.state,
torch.from_numpy(np.identity(len(transition.next_object_actions)))], 1)
main_q_attribute = transition.attribute_reward + discount_factor * torch.max(model_attribute(main_state_attribute))
main_q_predicate = transition.predicate_reward + discount_factor * torch.max(model_predicate(main_state_predicate))
main_q_next_object = transition.next_object_reward + discount_factor * torch.max(model_next_object(main_state_next_object))
main_state_attribute = create_state_action_vector(transition.state, transition.attribute_actions)
main_state_predicate = create_state_action_vector(transition.state, transition.predicate_actions)
main_state_next_object = create_state_action_vector(transition.state, transition.next_object_actions)

main_q_attribute = transition.attribute_reward + parameters["discount_factor"] * torch.max(model_attribute_main(main_state_attribute))
main_q_predicate = transition.predicate_reward + parameters["discount_factor"] * torch.max(model_predicate_main(main_state_predicate))
main_q_next_object = transition.next_object_reward + parameters["discount_factor"] * torch.max(model_next_object_main(main_state_next_object))

# add to q value lists
target_q_attribute_list.append(target_q_attribute)
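Note: the per-head targets computed above follow the usual bootstrapped Q-learning rule, target = reward + discount_factor * max over the next state's adaptive actions of the target network's value. A compact sketch for the attribute head; gamma stands for parameters["discount_factor"], and treating a terminal or action-less next state as "target = reward" is an assumption, since the hunk leaves that branch implicit.

    gamma = parameters["discount_factor"]
    next_sa = create_state_action_vector(transition.next_state,
                                         transition.next_state_attribute_actions)
    if transition.done or next_sa is None:
        # no bootstrapping on terminal transitions or empty adaptive action sets
        target_q_attribute = transition.attribute_reward
    else:
        # the target network scores every (next_state, action) row; keep the best one
        target_q_attribute = transition.attribute_reward + gamma * torch.max(model_attribute_target(next_sa))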
@@ -254,9 +244,21 @@ def create_state_vector(image_state):
curr_object_feature = image_state.entity_features[image_state.current_object]
return torch.cat([torch.squeeze(image_state.image_feature), torch.squeeze(curr_subject_feature), torch.squeeze(curr_object_feature)])

def create_state_action_vector(state_vector, action_set):
len_action_set = len(action_set)
if len_action_set == 0:
return None
else:
identity = torch.autograd.Variable(torch.from_numpy(np.identity(len_action_set)).float())
if torch.cuda.is_available():
identity = identity.cuda()
model_input = torch.cat([state_vector.repeat(len_action_set, 1), identity], 1)
return model_input.view(1, 1, model_input.size(0), model_input.size(1))

def choose_action_epsilon_greedy(state, adaptive_action_set, model, epsilon, training=False):
sample = random.random()
if sample > epsilon and training: # exploit
import pdb; pdb.set_trace()
return adaptive_action_set[model(state).data.max(1)]
else: # explore
return random.choice(adaptive_action_set)
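Note: create_state_action_vector tiles the state vector once per available action and appends a one-hot action block, so each row is one (state, action) pair; the result is reshaped to (1, 1, num_actions, state_dim + num_actions) for the DQN, and an empty adaptive action set yields None. A small usage sketch under the same old-style Variable API; the 5-feature toy state and the attribute names are made up for illustration.

    import torch

    state = torch.autograd.Variable(torch.zeros(5))  # toy state vector with 5 features
    if torch.cuda.is_available():
        state = state.cuda()                         # keep the state on the same device as the identity block

    actions = ["red", "round", "shiny"]              # toy adaptive attribute action set
    sa = create_state_action_vector(state, actions)
    print(sa.size())                                 # torch.Size([1, 1, 3, 8]): 3 actions, 5 + 3 columns
    print(create_state_action_vector(state, []))     # None: the caller skips empty action sets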
@@ -282,15 +284,15 @@ def update_target(main_model, target_model):
parser.add_argument("--train", help="trains model", action="store_true")
parser.add_argument("--test", help="evaluates model", action="store_true")
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs to train on")
parser.add_argument("--batch_size", type=int, default=64, help="batch size to use")
parser.add_argument("--batch_size", type=int, default=4, help="batch size to use")
parser.add_argument("--discount_factor", type=float, default=0.9, help="discount factor")
parser.add_argument("--learning_rate", type=float, default=0.0007, help="learning rate")
parser.add_argument("--epsilon", type=float, default=1, help="epsilon starting value (used in epsilon greedy)")
parser.add_argument("--epsilon_anneal_rate", type=float, default=0.045, help="factor to anneal epsilon by")
parser.add_argument("--epsilon_end", type=float, default=0.1, help="minimum value of epsilon (when we can stop annealing)")
parser.add_argument("--target_update_frequency", type=int, default=10000, help="how often to update the target")
parser.add_argument("--replay_buffer_capacity", type=int, default=20000, help="maximum size of the replay buffer")
parser.add_argument("--replay_buffer_minimum_number_samples", type=int, default=500, help="Minimum replay buffer size before we can sample")
parser.add_argument("--replay_buffer_minimum_number_samples", type=int, default=8, help="Minimum replay buffer size before we can sample")
parser.add_argument("--object_detection_threshold", type=float, default=0.005, help="threshold for Faster RCNN module when detecting objects")
parser.add_argument("--maximum_num_entities_per_image", type=int, default=10, help="maximum number of entities to explore per image")
parser.add_argument("--maximum_adaptive_action_space_size", type=int, default=20, help="maximum size of adaptive_action space")
1 change: 1 addition & 0 deletions replay_buffer.py
@@ -1,4 +1,5 @@
 from collections import namedtuple
+import random
 
 Transition = namedtuple('Transition', ('state', 'next_state', 'attribute_actions', 'predicate_actions',
                                        'next_object_actions', 'attribute_reward', 'predicate_reward',
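Note: main.py now pushes the next state's adaptive action sets along with the rewards, so a sampled transition carries everything the target networks need. The Transition fields beyond the ones visible above are an inference from that push call and from how main.py reads transitions back; treat this as a hypothetical sketch, not the file's actual contents.

    from collections import namedtuple

    # Fields up to 'predicate_reward' are visible in the diff; the trailing fields are inferred (hypothetical).
    Transition = namedtuple('Transition', ('state', 'next_state', 'attribute_actions', 'predicate_actions',
                                           'next_object_actions', 'attribute_reward', 'predicate_reward',
                                           'next_object_reward', 'next_state_attribute_actions',
                                           'next_state_predicate_actions', 'next_state_next_object_actions',
                                           'done'))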
