## Neural Agent
- graph-based

In [1]:
import networkx as nx

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device: {}".format(device))
print("torch version: {}".format(torch.__version__))

import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

device: cpu
torch version: 1.6.0


In [2]:
# override the same function from pytorch_geometric
# this provides consistent renaming from node labels to integers
# because the original impl makes a node's __repr__ as key
# whose ordering is subject to change by calling the default conversion function
def from_networkx(G, rmap):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        rmap: A mapping from G's node keys to corresponding integers
    """

    # G = nx.convert_node_labels_to_integers(G)
    G = nx.relabel_nodes(G, rmap)
    G = G.to_directed() if not nx.is_directed(G) else G
    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    data = {}

    for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
        for key, value in feat_dict.items():
            data[str(key)] = [value] if i == 0 else data[str(key)] + [value]

    for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
        for key, value in feat_dict.items():
            data[str(key)] = [value] if i == 0 else data[str(key)] + [value]

    for key, item in data.items():
        try:
            data[key] = torch.tensor(item)
        except ValueError:
            pass

    data['edge_index'] = edge_index.view(2, -1)
    data = torch_geometric.data.Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data

In [3]:
class Net(torch.nn.Module):
    def __init__(self, config):
        super(Net, self).__init__()
        self.conv1 = GCNConv(config["num_node_features"], 16)
        self.conv2 = GCNConv(16, config["num_hidden"])

    def encode(self, data):
        r"""Encode the whole graph and return embedding of every node
        """
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        # return F.log_softmax(x, dim=1)
        return F.relu(x)
    
    def compute_pref(self, cur, tgt, act):
        # note that this function only deals with 1 state 1 target with 1 action list
        # cur: (n=1, num_hidden)
        # tgt: (n=1, num_hidden)
        # act: (num_actions, num_hidden)
        tmp0 = torch.mean(torch.cat([cur, tgt], dim=0), dim=0, keepdim=True) # (1, num_hidden)
        tmp_act = act.view(config["num_hidden"], -1) # (num_hidden, num_actions)
        
        tmp1 = torch.mm(tmp0,tmp_act) # (B=1, num_actions)
        return F.log_softmax(tmp1, dim=1) # (B=1, num_actions)
        

## Pipeline Utils

In [4]:
# 1. roll out an action sequence
# 2. compute reward
# 3. policy gradient back propagation
def rollout(arg_config):
    # note: remember to clear the state
    arg_config["env"].launch_app()
    arg_config["agent"].train()
    
    rollout_outputs = []
    rollout_actions = []
    rollout_action_ids = []
    final_reward = 0.0
    
    raw_graph = from_networkx(arg_config["wtg"].wtg_graph, arg_config["obj2ind"])
    
    for i in range(arg_config["maxn_steps"]):
        time.sleep(1)
        curr_state = arg_config["env"].get_current_state()
        curr_actions = arg_config["env"].get_available_actionable_elements(curr_state)
        n_actions = len(action_list)
        
        if n_actions == 0:
            print("# no action is found, terminate.")
            # penalty
            final_reward = -3
            # no available actions any more
            break
        
        # get corresponding components in wtg
        curr0_state = arg_config["env"].get_wtg_state(arg_config["wtg"])
        curr0_actions = [
            arg_config["env"].get_matching_dynamic_action_to_static_action(curr_actions[dind], arg_config["wtg"])
            for dind in range(len(curr_actions))
        ]
        # TODO: mark the current state out in wtg
        # update a copy of the current graph to reflect the state/environment
        # TODO: should dynamically construct features
        curr_graph = torch_geometric.data.Data(x=arg_config["features"], edge_index=raw_graph.edge_index)
        # then the graph is ready as agent input
        tf_nodes = arg_config["agent"].encode(curr_graph) # (num_nodes, num_hidden)
        
        # then get the embedding of current state
        curr0_states = arg_config["env"].get_wtg_state(arg_config["wtg"])
        # now state list is a list of nodes
        # collapse the state list into concrete representation
        curr1_states = []
        for dind0 in range(len(curr0_states)):
            p = curr0_states[dind0]
            ind0 = arg_config["obj2ind"][p]
            curr1_states.append(tf_nodes[ind0,:])
        tf_state = torch.mean(torch.cat(curr1_states, dim=0), dim=0, keepdim=True) # (1, num_hidden)
        
        # then get the embedding of goal edge(s)
        curr0_goals = arg_config["wtg"].get_goal_edges()
        # now goal list is a list of (node0, node1, edge_id)
        # FIXME: now assuming there's only 1 goal
        # collapse goal edge(s) into concrete representations
        curr1_goals = []
        for dind0 in range(len(curr0_goals)):
            p = curr0_goals[dind0]
            ind0 = arg_config["obj2ind"][p[0]]
            ind1 = arg_config["obj2ind"][p[1]]
            curr1_goals.append(tf_nodes[ind0,:])
            curr1_goals.append(tf_nodes[ind1,:])
        tf_goal = torch.mean(torch.cat(curr1_goals, dim=0), dim=0, keepdim=True) # (1, num_hidden)
    
        # now action list is a list of lists of (node0, node1, edge_id)
        # collapse action list into concrete representations
        curr1_actions = []
        for dind0 in range(len(curr0_actions)):
            plist0 = curr0_actions[dind0]
            rlist0 = []
            for p in plist0:
                ind0 = arg_config["obj2ind"][p[0]]
                ind1 = arg_config["obj2ind"][p[1]]
                rlist0.append(tf_nodes[ind0,:])
                rlist0.append(tf_nodes[ind1,:])
            # (1, num_hidden)
            curr1_actions.append( torch.mean(torch.cat(rlist0, dim=0), dim=0, keepdim=True) )
        tf_actions = torch.cat(curr1_actions, dim=0) # (num_actions, num_hidden)
        
        
        tout_preference = arg_agent.compute_preference(t_state, t_target) # (B=1, embedding_dim)
        tout_action = arg_agent.encode_action_list(t_action) # (B=1, ??, embedding_dim)
        
#         print("tout_preference.shape={}".format(tout_preference.shape))
#         print("tout_action.shape={}".format(tout_action.shape))
        
        # ====> using cosine similarity
        # (n_actions, spec_dims)
        # t0_output = t_output.expand_as(t_pool)
        # t_cos = F.cosine_similarity(t0_output, t_pool, dim=1)
        # t_act = F.log_softmax(t_cos, dim=0)
        # ====> directly mm similarity
        # note: assuming B=1 already
        tout0_preference = tout_preference.view(-1,1) # (embedding_dim, 1)
        tout0_action = tout_action.view(-1, EMBEDDING_DIM) # (n_actions, embedding_dim)
        tout0_mm = torch.mm(tout0_action, tout0_preference)  # (n_actions, 1)
        tout1_mm = tout0_mm.view(-1) # (n_actions,)
        tout2_mm = F.log_softmax(tout1_mm)
#         print("# tout2_mm: {}".format(tout2_mm))
        
        if random.random()<max(0.1, 1.0-ep/20):
            # explore
            selected_action_id = random.choice(list(range(len(action_list))))
            print("# [explore] selected_action_id (rnd): {}, log-sim: {}".format(selected_action_id, tout2_mm[selected_action_id]))
        else:
            # exploit
            probs = tout2_mm.exp().tolist()
            selected_action_id = np.argmax(probs)
            # selected_action_id = torch.argmax(tout2_mm, dim=0).tolist()
            print("# [exploit] selected_action_id (mul): {}, log-sim: {}".format(selected_action_id, tout2_mm[selected_action_id]))
#         else:
#             # exploit
#             probs = tout2_mm.exp().tolist()
#             selected_action_id = random.choices(list(range(len(action_list))), weights=probs, k=1)[0]
#             # selected_action_id = torch.argmax(tout2_mm, dim=0).tolist()
#             print("# [exploit] selected_action_id (mul): {}, log-sim: {}".format(selected_action_id, tout2_mm[selected_action_id]))
        
        # perform action
        arg_env.perform_action(action_list[selected_action_id])
        next_state = arg_env.get_current_state()
        dreward = get_reward0(arg_env, arg_wtg, next_state)
        print("  # r: {}".format(dreward))
        final_reward += dreward
        
        # store the choices
        rollout_outputs.append(tout2_mm)
        rollout_actions.append(action_list)
        rollout_action_ids.append(selected_action_id)
        
        # input("PAUSE")
        
    # here we use the final reward as the cumulative reward
    final_state = arg_env.get_current_state()
    # test whether goal states are reached
    rlist = arg_env.get_reached_goal_states("train")
    print("# final reward000: {}".format(final_reward))
    if len(rlist)>0:
        print("# goal state: {}".format(rlist))

    print("# final reward: {}".format(final_reward))
    
    rollout_loss = []
    current_reward = final_reward
    # reverse from the last to first
    for i in range(len(rollout_outputs))[::-1]:
        rollout_loss.append( current_reward * (-rollout_outputs[i][rollout_action_ids[i]]) )
        current_reward *= 0.8 # decay
    rollout_loss = rollout_loss[::-1]
    
    optimizer.zero_grad()
    loss = sum(rollout_loss)
    loss.backward()
    optimizer.step()

## Top-Level Control Flow

In [5]:
from main import *

CURR_DIR = os.path.dirname(os.getcwd())
OUTPUT_DIR = os.path.join(CURR_DIR, "results")

args = {
#     "path": "../results/test_app_1/testapp_1.apk",
#     "path": "../results/test_app_2/testapp_2.apk",
#     "path": "../test/com.github.cetoolbox_11/app_simple0.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/tmp/Wordpress_394/Wordpress_394.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/tmp/com.zoffcc.applications.aagtl_31/com.zoffcc.applications.aagtl_31.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/tmp/Translate/Translate.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/tmp/com.chmod0.manpages_3/com.chmod0.manpages_3.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/tmp/Book-Catalogue/Book-Catalogue.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/test/out.andFHEM.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/test/out.blue-chat.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/test/out.CallMeter3G-debug.apk",
#     "path": "/Users/joseph/Desktop/UCSB/20summer/MarthaEnv/test/out.Lucid-Browser.apk",
    "path": "../results/test_app_3/testapp_3.apk",
    "output": "../results/",
    "wtginput": "../results/test_app_3/",
    "goalstates": "../results/test_app_3/goals_caller.json",
}

if args["path"] is not None:
    pyaxmlparser_apk = APK(args["path"])
    apk_base_name = os.path.splitext(os.path.basename(args["path"]))[0]
else:
    parser.print_usage()
    sys.exit(1)
    
goal_states = {}
if args["goalstates"] is not None:
    with open(args["goalstates"], 'r') as fp:
        goal_states = json.load(fp)

else:
    parser.print_usage()
    sys.exit(1)

if args["output"] is not None:
    OUTPUT_DIR = args["output"]

output_dir = os.path.join(OUTPUT_DIR, 'exploration_output', apk_base_name)

wtg = None
if args["wtginput"]:
    wtg = args["wtginput"] #os.path.join(args.wtginput, apk_base_name)

if os.path.exists(output_dir):
    rmtree(output_dir)

os.makedirs(output_dir, exist_ok=True)

# Setting the path for log file
log_path = os.path.join(output_dir, 'analysis.log')
log = init_logging('analyzer.%s' % apk_base_name, log_path, file_mode='w', console=True)

# Record analysis start time
now = datetime.datetime.now()
analysis_start_time = now.strftime(DATE_FORMAT)
info('Analysis started at: %s' % analysis_start_time)
start_time = time.time()

# Get the serial for the device attached to ADB
device_serial = get_device_serial(log)

if device_serial is None:
    log.warning("Device is not connected!")
    sys.exit(1)

# Initialize the uiautomator device object using the device serial
uiautomator_device = u2.connect(device_serial)
run_adb_as_root(log, device_serial)
apk_obj = Apk(args["path"], uiautomator_device, output_dir, log, device_serial)
wtg_obj = WTG(wtg, log)
wtg_obj.set_goal_nodes(goal_states)
apk_obj.launch_app()
# to track some goal state at startup, you don't have to do this
apk_obj.clean_logcat()

[W 210727 23:19:35 __init__:208] [pid:25474] atx-agent has something wrong, auto recovering
[D 210727 23:19:35 __init__:292] [pid:25474] device 93MAYS0020Z is online


[36m[#] Analysis started at: 2021-07-27 11:19:35 PM[0m


[I 210727 23:19:35 init:156] uiautomator2 version: 2.16.6
[D 210727 23:19:35 init:295] Real version: [0, 10, 0], Expect version: [0, 10, 0]
[D 210727 23:19:36 init:256] apk-debug package-info: {'package_name': 'com.github.uiautomator', 'version_name': '2.3.3', 'version_code': '2003003', 'flags': ['DEBUGGABLE', 'HAS_CODE', 'ALLOW_CLEAR_USER_DATA', 'ALLOW_BACKUP'], 'first_install_time': datetime.datetime(2021, 7, 28, 2, 16, 46), 'last_update_time': datetime.datetime(2021, 7, 28, 2, 16, 46), 'signature': 'ae17cd86], past signatures:['}
[D 210727 23:19:36 init:257] apk-debug-test package-info: {'package_name': 'com.github.uiautomator.test', 'version_name': '', 'version_code': '', 'flags': ['DEBUGGABLE', 'HAS_CODE', 'ALLOW_CLEAR_USER_DATA', 'ALLOW_BACKUP'], 'first_install_time': datetime.datetime(2021, 7, 28, 2, 16, 57), 'last_update_time': datetime.datetime(2021, 7, 28, 2, 16, 57), 'signature': 'ae17cd86], past signatures:['}
[D 210727 23:19:36 init:167] Shell: ('/data/local/tmp/atx-agent'

In [8]:
nsteps = 4
neural_agent = NeuralAgent(EMBEDDING_DIM).to(device)
optimizer = torch.optim.SGD(neural_agent.parameters(), lr=0.1)
st = time.time()
for ep in range(100000):
    print("# ep{}, time elapsed: {}".format(ep, time.time()-st))
    rollout(apk_obj, neural_agent, optimizer, nsteps, target)
#     random_rollout(apk_obj, nsteps, target)

TypeError: 'int' object is not subscriptable

In [6]:
apk_obj.get_available_actionable_elements(apk_obj.get_current_state())

[D 210328 15:42:39 __init__:600] kill process(ps): uiautomator
[D 210328 15:42:41 __init__:619] uiautomator-v2 is starting ... left: 40.0s
[D 210328 15:42:42 __init__:619] uiautomator-v2 is starting ... left: 38.8s
[D 210328 15:42:44 __init__:619] uiautomator-v2 is starting ... left: 37.5s
[D 210328 15:42:45 __init__:619] uiautomator-v2 is starting ... left: 36.3s
[D 210328 15:42:46 __init__:619] uiautomator-v2 is starting ... left: 35.1s
[D 210328 15:42:47 __init__:619] uiautomator-v2 is starting ... left: 33.8s
[I 210328 15:42:48 __init__:583] uiautomator back to normal


[<gui_elements.GuiElements at 0x7f8af57c9e48>,
 <gui_elements.GuiElements at 0x7f8af57c9ef0>,
 <gui_elements.GuiElements at 0x7f8af57c9eb8>,
 <gui_elements.GuiElements at 0x7f8af57c9f98>,
 <gui_elements.GuiElements at 0x7f8af57c9f60>,
 <gui_elements.GuiElements at 0x7f8af57c9fd0>]

In [7]:
apk_obj.get_wtg_state(wtg_obj)

[<wtg_node.WTGNode at 0x7f8af5794be0>]

In [53]:
# node mapping to integer
# use map1 to from_networkx, and use map0 as reference
nodes0 = sorted(list(wtg_obj.nodes.keys())) # ["n1", "n2", ...]
label2ind = {nodes0[dind]:dind for dind in range(len(nodes0))} # "n?" -> ind
node2ind = {wtg_obj.nodes[nodes0[dind]]:dind for dind in range(len(nodes0))} # obj -> ind
display(label2ind, node2ind)
# construct graph feature tensor
features = torch.tensor([
    [1,0] for dind in range(len(nodes0))
], dtype=torch.float)
display(features)

{'n1': 0,
 'n2': 1,
 'n3': 2,
 'n4': 3,
 'n5': 4,
 'n6': 5,
 'n7': 6,
 'n8': 7,
 'n9': 8}

{<wtg_node.WTGNode at 0x7f8d9376e5c0>: 0,
 <wtg_node.WTGNode at 0x7f8d9376e8d0>: 1,
 <wtg_node.WTGNode at 0x7f8d9376e7f0>: 2,
 <wtg_node.WTGNode at 0x7f8d93765ef0>: 3,
 <wtg_node.WTGNode at 0x7f8d9376e668>: 4,
 <wtg_node.WTGNode at 0x7f8d9376e860>: 5,
 <wtg_node.WTGNode at 0x7f8d9376e978>: 6,
 <wtg_node.WTGNode at 0x7f8d9376e6d8>: 7,
 <wtg_node.WTGNode at 0x7f8d9376e748>: 8}

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]])

In [54]:
tt = from_networkx(wtg_obj.wtg_graph, obj2ind)
tt

Data(edge_index=[2, 102])

In [55]:
zz = torch_geometric.data.Data(x=features, edge_index=tt.edge_index)

In [56]:
zz.num_features

2

In [72]:
zz.to(device)

Data(edge_index=[2, 102], x=[9, 2])

In [3]:
aa = apk_obj.get_available_actionable_elements(apk_obj.get_current_state())
aa

[<gui_elements.GuiElements at 0x7fbb01318ba8>,
 <gui_elements.GuiElements at 0x7fbb01318c50>,
 <gui_elements.GuiElements at 0x7fbb01318c18>,
 <gui_elements.GuiElements at 0x7fbb01318cf8>,
 <gui_elements.GuiElements at 0x7fbb01318cc0>,
 <gui_elements.GuiElements at 0x7fbb01318d30>]

In [4]:
apk_obj.get_matching_dynamic_action_to_static_action(aa[0], wtg_obj)

[(<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2d30>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2b38>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2a20>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  1),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2c50>,
  0)]

In [5]:
apk_obj.get_matching_dynamic_action_to_static_action(aa[1], wtg_obj)

[(<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2d30>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2b38>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2a20>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  1),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2c50>,
  0)]

In [6]:
apk_obj.get_matching_dynamic_action_to_static_action(aa[2], wtg_obj)

[(<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2d30>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2b38>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2a20>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  1),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2c50>,
  0)]

In [7]:
apk_obj.get_matching_dynamic_action_to_static_action(aa[3], wtg_obj)

[(<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2d30>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2b38>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2a20>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  1),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2c50>,
  0)]

In [8]:
apk_obj.get_matching_dynamic_action_to_static_action(aa[4], wtg_obj)

[(<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2d30>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2b38>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2a20>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  1),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2c50>,
  0)]

In [9]:
apk_obj.get_matching_dynamic_action_to_static_action(aa[5], wtg_obj)

[(<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2d30>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2b38>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2a20>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  0),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2ac8>,
  1),
 (<wtg_node.WTGNode at 0x7fbb01111668>,
  <wtg_node.WTGNode at 0x7fbb012f2c50>,
  0)]

In [None]:
apk_obj.get_reached_goal_states("train")

In [None]:
apk_obj.get_current_state()

In [70]:
aa = apk_obj.get_available_actionable_elements(apk_obj.get_current_state())
aa

[<gui_elements.GuiElements at 0x7f8d935a7978>]

In [12]:
apk_obj.get_wtg_state(wtg_obj)

[<wtg_node.WTGNode at 0x7ff5f1908f60>]

In [27]:
wtg_obj.get_goal_edges()

[(<wtg_node.WTGNode at 0x7ff5f1908e10>,
  <wtg_node.WTGNode at 0x7ff5f1908e10>,
  1),
 (<wtg_node.WTGNode at 0x7ff5f1908e10>,
  <wtg_node.WTGNode at 0x7ff5f12b86d8>,
  0)]

In [None]:
wtg_obj.wtg

In [None]:
nx.all_pairs_shortest_path(wtg_obj.wtg)

In [None]:
nx.shortest_path(wtg_obj.wtg, apk_obj.get_wtg_state(wtg_obj)[0], apk_obj.get_wtg_state(wtg_obj)[0])

In [28]:
apk_obj.get_wtg_state(wtg_obj)

[<wtg_node.WTGNode at 0x7ff5f1908d30>]

In [None]:
wtg_obj.wtg_graph

In [36]:
wtg_obj.nodes

{'n4': <wtg_node.WTGNode at 0x7ff5f12b86d8>,
 'n1': <wtg_node.WTGNode at 0x7ff5f12b8630>,
 'n5': <wtg_node.WTGNode at 0x7ff5f1908fd0>,
 'n8': <wtg_node.WTGNode at 0x7ff5f1908f60>,
 'n9': <wtg_node.WTGNode at 0x7ff5f1908ef0>,
 'n3': <wtg_node.WTGNode at 0x7ff5f1908e10>,
 'n6': <wtg_node.WTGNode at 0x7ff5f1908da0>,
 'n2': <wtg_node.WTGNode at 0x7ff5f1908d30>,
 'n7': <wtg_node.WTGNode at 0x7ff5f1908c88>}

In [None]:
wtg_obj.wtg

In [None]:
wtg_obj.wtg.nodes

In [None]:
wtg_obj.wtg_graph

In [None]:
nx.shortest_path(wtg_obj.wtg_graph, apk_obj.get_wtg_state(wtg_obj)[0], apk_obj.get_wtg_state(wtg_obj)[0])

In [None]:
nx.shortest_path_length(wtg_obj.wtg_graph, apk_obj.get_wtg_state(wtg_obj)[0], apk_obj.get_wtg_state(wtg_obj)[0])

In [None]:
wtg_obj.get_goal_edges()

In [None]:
nx.shortest_path_length(
    wtg_obj.wtg_graph, 
    apk_obj.get_wtg_state(wtg_obj)[0],
    wtg_obj.get_goal_edges()[0][1]
)

In [21]:
wtg_obj.wtg_graph

<networkx.classes.multidigraph.MultiDiGraph at 0x7feae079bf98>

In [55]:
wtg_obj.nodes

{'n4': <wtg_node.WTGNode at 0x7ff5f12b86d8>,
 'n1': <wtg_node.WTGNode at 0x7ff5f12b8630>,
 'n5': <wtg_node.WTGNode at 0x7ff5f1908fd0>,
 'n8': <wtg_node.WTGNode at 0x7ff5f1908f60>,
 'n9': <wtg_node.WTGNode at 0x7ff5f1908ef0>,
 'n3': <wtg_node.WTGNode at 0x7ff5f1908e10>,
 'n6': <wtg_node.WTGNode at 0x7ff5f1908da0>,
 'n2': <wtg_node.WTGNode at 0x7ff5f1908d30>,
 'n7': <wtg_node.WTGNode at 0x7ff5f1908c88>}

In [75]:
map0 = {wtg_obj.nodes[dkey]:dkey for dkey in wtg_obj.nodes.keys()}
map0

{<wtg_node.WTGNode at 0x7ff5f12b86d8>: 'n4',
 <wtg_node.WTGNode at 0x7ff5f12b8630>: 'n1',
 <wtg_node.WTGNode at 0x7ff5f1908fd0>: 'n5',
 <wtg_node.WTGNode at 0x7ff5f1908f60>: 'n8',
 <wtg_node.WTGNode at 0x7ff5f1908ef0>: 'n9',
 <wtg_node.WTGNode at 0x7ff5f1908e10>: 'n3',
 <wtg_node.WTGNode at 0x7ff5f1908da0>: 'n6',
 <wtg_node.WTGNode at 0x7ff5f1908d30>: 'n2',
 <wtg_node.WTGNode at 0x7ff5f1908c88>: 'n7'}

In [76]:
qq = nx.relabel_nodes(wtg_obj.wtg_graph, map0)
qq

<networkx.classes.multidigraph.MultiDiGraph at 0x7ff5d5c3bb70>

In [77]:
qq.nodes

NodeView(('n4', 'n1', 'n5', 'n8', 'n9', 'n3', 'n6', 'n2', 'n7'))

In [56]:
tt.nodes

NodeView((<wtg_node.WTGNode object at 0x7ff5f12b86d8>, <wtg_node.WTGNode object at 0x7ff5f12b8630>, <wtg_node.WTGNode object at 0x7ff5f1908fd0>, <wtg_node.WTGNode object at 0x7ff5f1908f60>, <wtg_node.WTGNode object at 0x7ff5f1908ef0>, <wtg_node.WTGNode object at 0x7ff5f1908e10>, <wtg_node.WTGNode object at 0x7ff5f1908da0>, <wtg_node.WTGNode object at 0x7ff5f1908d30>, <wtg_node.WTGNode object at 0x7ff5f1908c88>))

In [13]:
from torch_geometric.utils import from_networkx

In [14]:
tt = wtg_obj.wtg_graph
dd = from_networkx(tt)

In [15]:
dd

Data(edge_index=[2, 102])

In [44]:
tt.nodes(data=True)

NodeDataView({<wtg_node.WTGNode object at 0x7ff5f12b86d8>: {}, <wtg_node.WTGNode object at 0x7ff5f12b8630>: {}, <wtg_node.WTGNode object at 0x7ff5f1908fd0>: {}, <wtg_node.WTGNode object at 0x7ff5f1908f60>: {}, <wtg_node.WTGNode object at 0x7ff5f1908ef0>: {}, <wtg_node.WTGNode object at 0x7ff5f1908e10>: {}, <wtg_node.WTGNode object at 0x7ff5f1908da0>: {}, <wtg_node.WTGNode object at 0x7ff5f1908d30>: {}, <wtg_node.WTGNode object at 0x7ff5f1908c88>: {}})

In [48]:
dir(dd)

['__apply__',
 '__call__',
 '__cat_dim__',
 '__class__',
 '__contains__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__inc__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__num_nodes__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'apply',
 'clone',
 'coalesce',
 'contains_isolated_nodes',
 'contains_self_loops',
 'contiguous',
 'debug',
 'edge_attr',
 'edge_index',
 'face',
 'from_dict',
 'is_coalesced',
 'is_directed',
 'is_undirected',
 'keys',
 'normal',
 'num_edge_features',
 'num_edges',
 'num_faces',
 'num_features',
 'num_node_features',
 'num_nodes',
 'pos',
 'to',
 'to_dict',
 'to_namedtuple',
 'x',
 'y']

In [57]:
tt.nodes

NodeView((<wtg_node.WTGNode object at 0x7ff5f12b86d8>, <wtg_node.WTGNode object at 0x7ff5f12b8630>, <wtg_node.WTGNode object at 0x7ff5f1908fd0>, <wtg_node.WTGNode object at 0x7ff5f1908f60>, <wtg_node.WTGNode object at 0x7ff5f1908ef0>, <wtg_node.WTGNode object at 0x7ff5f1908e10>, <wtg_node.WTGNode object at 0x7ff5f1908da0>, <wtg_node.WTGNode object at 0x7ff5f1908d30>, <wtg_node.WTGNode object at 0x7ff5f1908c88>))

In [52]:
rr = nx.convert_node_labels_to_integers(tt)

In [53]:
rr.nodes

NodeView((0, 1, 2, 3, 4, 5, 6, 7, 8))

In [65]:
rr.nodes

NodeView((0, 1, 2, 3, 4, 5, 6, 7, 8))

In [74]:
list(tt.nodes.keys())[0]

<wtg_node.WTGNode at 0x7ff5f12b86d8>

In [54]:
tt.nodes

NodeView((<wtg_node.WTGNode object at 0x7ff5f12b86d8>, <wtg_node.WTGNode object at 0x7ff5f12b8630>, <wtg_node.WTGNode object at 0x7ff5f1908fd0>, <wtg_node.WTGNode object at 0x7ff5f1908f60>, <wtg_node.WTGNode object at 0x7ff5f1908ef0>, <wtg_node.WTGNode object at 0x7ff5f1908e10>, <wtg_node.WTGNode object at 0x7ff5f1908da0>, <wtg_node.WTGNode object at 0x7ff5f1908d30>, <wtg_node.WTGNode object at 0x7ff5f1908c88>))

In [42]:
for i, (Q, feat_dict) in enumerate(tt.nodes(data=True)):
    print("i={}, Q={}, fd={}".format(i, Q, feat_dict))
    for key, value in feat_dict.items():
        print("key={}, value={}".format(key, value))

TypeError: __str__ returned non-string (type NoneType)

In [18]:
dd.num_node_features

0

In [21]:
dd.num_nodes

9

In [22]:
dd.num_edges

102

In [23]:
dd.is_directed()

True

In [25]:
dd.keys

['edge_index']

In [26]:
dd.edge_index

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 4, 0, 0, 0,
         1, 7, 7, 2, 2, 3, 5, 5, 1, 1, 1, 0, 5, 7, 2, 7, 7, 7, 0, 0, 0, 2, 2, 1,
         1, 1, 5, 5, 5, 3],
        [0, 0, 0, 5, 7, 0, 5, 5, 5, 5, 5, 5, 8, 0, 0, 2, 1, 7, 7, 7, 6, 7, 7, 7,
         1, 0, 0, 2, 2, 6, 5, 5, 5, 8, 0, 0, 1, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1,
         2, 0, 0, 0, 0, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 0, 1, 2, 0,
         0, 0, 0, 1, 2, 3, 4, 5, 0, 0, 1, 0, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 0,
         1, 0, 0, 1, 2, 0]])

In [16]:
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dd.num_node_features, 16)
        self.conv2 = GCNConv(16, dd.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [17]:
model = Net().to(device)

NameError: name 'dataset' is not defined

In [None]:
model.train()
out = model(dd)
out