In [208]:
import os
import sys
# Notebooks have no __file__, so make the parent of the current working
# directory importable (presumably the repo root that contains `antra`).
# Resolve it to an absolute path now so a later os.chdir() cannot change
# what the sys.path entry points to.
sys.path.append(os.path.abspath(os.pardir))

from antra.antra import *
import json
import torch

In [209]:
def index2corrdinates(index:int, 
                      grid_size=6):
    """Decode a flat, row-major cell index into a ``(row, col)`` pair.

    Args:
        index: flat cell index. Floor division and modulo are applied
            directly, so a scalar tensor also works and yields tensor parts.
        grid_size: number of cells per row (default 6).

    Returns:
        ``(index // grid_size, index % grid_size)``.
    """
    return index // grid_size, index % grid_size

def get_start_state(grid_size=6):
    """Build the forward function for the initial high-level state ``s0``.

    The returned ``f(agent_positions_batch, target_positions_batch)`` maps
    each batch item to ``[x, y, o, a]``:
      * ``x`` = target row - agent row,
      * ``y`` = target column - agent column,
      * ``o`` = 0 (DOWN) and ``a`` = 0 (NULL).

    Only element ``[0]`` of each batch item is read; it is decoded as a flat
    cell index via ``index2corrdinates``.

    Args:
        grid_size: cells per grid row used to decode flat indices. The
            default of 6 matches ``index2corrdinates``'s default, so existing
            ``get_start_state()`` calls behave as before.

    Returns:
        A callable returning a ``torch.tensor`` with one ``[x, y, o, a]`` row
        per batch item.
    """
    def f(agent_positions_batch, target_positions_batch):
        results = []
        for i in range(len(agent_positions_batch)):
            # Decode each position once rather than once per coordinate.
            agent_row, agent_col = index2corrdinates(agent_positions_batch[i][0], grid_size)
            target_row, target_col = index2corrdinates(target_positions_batch[i][0], grid_size)
            x = target_row - agent_row
            y = target_col - agent_col
            o = 0  # default heading: DOWN
            a = 0  # default action: NULL
            results.append([x, y, o, a])
        return torch.tensor(results)
    return f

def update_state():
    """Build the forward function for one decode step ``s_{t-1} -> s_t``.

    The returned ``f(prev_s)`` reads ``[x, y, o, a]`` from each batch row and
    does two things:

    1. Executes the previous action ``a``. A "walk" shifts the offset
       ``(x, y)`` one unit along heading ``o``; a turn rotates ``o``; NULL
       does nothing.
    2. Picks the next action that steers ``(x, y)`` toward ``(0, 0)``. The
       ``y`` offset is closed first, then ``x``. NULL is chosen once both
       are zero.

    Heading codes: DOWN=0, UP=1, LEFT=2, RIGHT=3.
    Action codes: NULL=0, turn right=3, walk=4, turn left=5.
    Any other code raises ``KeyError``.

    Returns:
        A callable returning a ``torch.tensor`` with one ``[x, y, o, a]`` row
        per batch item.
    """
    # Code -> name, plus the inverse mapping, for headings and actions.
    orientation_names = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
    orientation_codes = {name: code for code, name in orientation_names.items()}
    action_names = {0: "NULL", 3: "turn right", 4: "walk", 5: "turn left"}
    action_codes = {name: code for code, name in action_names.items()}

    # Offset change produced by a single "walk" in each heading.
    walk_delta = {"RIGHT": (1, 0), "LEFT": (-1, 0), "UP": (0, 1), "DOWN": (0, -1)}

    # Heading that results from each turn action.
    turns = {
        "turn left": {"RIGHT": "UP", "LEFT": "DOWN", "UP": "LEFT", "DOWN": "RIGHT"},
        "turn right": {"UP": "RIGHT", "DOWN": "LEFT", "LEFT": "UP", "RIGHT": "DOWN"},
    }

    # For each heading the agent needs to walk in, the action to take from
    # every current heading ("walk" once aligned, otherwise a turn).
    steer = {
        "LEFT": {"LEFT": "walk", "RIGHT": "turn left", "UP": "turn left", "DOWN": "turn right"},
        "RIGHT": {"RIGHT": "walk", "UP": "turn right", "LEFT": "turn left", "DOWN": "turn left"},
        "DOWN": {"DOWN": "walk", "UP": "turn left", "LEFT": "turn left", "RIGHT": "turn right"},
        "UP": {"UP": "walk", "DOWN": "turn left", "RIGHT": "turn left", "LEFT": "turn right"},
    }

    def advance(row):
        """Apply one transition to a single ``[x, y, o, a]`` row."""
        x, y, o, a = (int(row[k]) for k in range(4))
        heading = orientation_names[o]
        action = action_names[a]

        # Execute the previous action.
        if action == "walk":
            dx, dy = walk_delta[heading]
            x, y = x + dx, y + dy
        elif action in turns:
            heading = turns[action][heading]

        # Choose the next action: y offset first, then x.
        if x == 0 and y == 0:
            chosen = "NULL"
        elif y != 0:
            chosen = steer["DOWN" if y > 0 else "UP"][heading]
        else:
            chosen = steer["LEFT" if x > 0 else "RIGHT"][heading]

        return [x, y, orientation_codes[heading], action_codes[chosen]]

    def f(prev_s):
        return torch.tensor([advance(prev_s[i]) for i in range(len(prev_s))])

    return f
            
def get_counter_compgraph(
    max_decode_step,
    cache_results=False
):
    """Build the high-level chain ``s0 -> s1 -> ... -> s{max_decode_step}``.

    Structure of the graph:
      * Two leaves, ``agent_positions_batch`` and ``target_positions_batch``,
        feed node ``s0``, whose forward comes from ``get_start_state()``.
      * Each later node ``s{i}`` applies a fresh ``update_state()`` forward
        to its predecessor.
      * The returned ``ComputationGraph`` is rooted at the final node.

    NOTE(review): ``cache_results`` is accepted but not used here.
    """
    agents = GraphNode.leaf(name="agent_positions_batch")
    targets = GraphNode.leaf(name="target_positions_batch")
    node = GraphNode(agents, targets, name="s0", forward=get_start_state())
    for step in range(1, max_decode_step + 1):
        node = GraphNode(node, name=f"s{step}", forward=update_state())
    return ComputationGraph(node)

In [210]:
# Two batched high-level inputs. Positions are flat cell indices decoded by
# index2corrdinates (default grid_size=6); only element [0] of each row is read.
# hi1: agents at cells 1, 2 and targets at cells 3, 4 -> both start at offset (0, 2).
hi1 = GraphInput({
    "agent_positions_batch": torch.tensor([[1],[2]]),
    "target_positions_batch": torch.tensor([[3],[4]]),
    },
    batched=True, batch_dim=0, cache_results=False
)

# hi2: agents at cells 1, 13 and targets at cells 10, 21 -> start offsets (1, 3) and (1, 2).
hi2 = GraphInput({
    "agent_positions_batch": torch.tensor([[1],[13]]),
    "target_positions_batch": torch.tensor([[10],[21]]),
    },
    batched=True, batch_dim=0, cache_results=False
)
# Unroll 20 decode steps, giving nodes s0 .. s20.
G = get_counter_compgraph(
    max_decode_step=20
)

In [224]:
# Evaluate node s6 (six update steps after s0) for the hi2 batch.
hi2_hidden = G.compute_node("s6", hi2)

In [229]:
# Rows are [x, y, o, a]: both items have reached offset (0, 0) facing LEFT (o=2) with action NULL (a=0).
hi2_hidden

tensor([[0, 0, 2, 0],
        [0, 0, 2, 0]])

In [237]:
# Counterfactual on hi1: override node s1 with hand-picked [x, y, o, a] states,
# [1, 3, DOWN, walk] for item 0 and [1, -1, DOWN, turn right] for item 1.
high_interv = Intervention(
    hi1, {"s1": torch.tensor([[1, 3, 0, 4], [1, -1, 0, 3]])}, 
    cache_results=False,
    cache_base_results=False,
    batched=True
)

In [238]:
# Inspect the intervention: per-item base inputs and the s1 override values.
high_interv

{'base': [(('agent_positions_batch', (1,)), ('target_positions_batch', (3,))),
          (('agent_positions_batch', (2,)), ('target_positions_batch', (4,)))],
 'interv': [(('s1', (1, 3, 0, 4)),), (('s1', (1, -1, 0, 3)),)],
 'locs': {}}

In [239]:
# NOTE(review): the output appears to be (s2 without the intervention, s2 with s1 overridden).
# The first tensor matches hi1's normal s2; the second matches update_state applied to the
# overridden s1. Confirm against antra's intervene_node documentation.
G.intervene_node("s2", high_interv)

(tensor([[0, 1, 0, 4],
         [0, 1, 0, 4]]),
 tensor([[ 1,  2,  0,  4],
         [ 1, -1,  2,  3]]))