In [68]:
import torch
import torch.nn.functional as F
import numpy as np

from torch.utils.data import Dataset, DataLoader 
from tqdm import tqdm

In [69]:
# function to generate a random graph
def construct_random_subgraph(num_nodes, min_edges=2, max_edges=5):
    min_edges-=1
    max_edges-=1
    connections = np.random.uniform(size=(num_nodes, num_nodes))
    connections += np.roll(np.eye(num_nodes), 1, 1)  # make sure every node is connected
    connections -= np.eye(num_nodes)  # make sure no self-connections
    # set strongest connections
    sorted_connections = np.sort(connections)[:, ::-1]
    min_edges = min_edges
    max_edges = max_edges + 1
    min_edges = int(min_edges / 2 + 0.5)  # used to be min_edges // 2
    max_edges = int(max_edges / 2 + 0.5)
    indices = np.stack([np.arange(num_nodes), np.random.randint(min_edges, max_edges, num_nodes)])
    thresholds = sorted_connections[indices[0], indices[1]]
    connections = np.where(connections > thresholds[:, None], 1., 0.).astype(np.float32)
    connections = np.clip(connections + connections.T, 0, 1)
    return connections

# function to generate a random small world graph
def construct_small_world_graph(n_nodes_per_world=6, n_worlds=4):
    n_nodes = n_nodes_per_world * n_worlds
    connections = np.zeros((n_nodes, n_nodes))
    for world_i in range(n_worlds):
        connection = construct_random_subgraph(n_nodes_per_world)
        # add to global graph
        start_i = world_i * n_nodes_per_world
        end_i = (world_i+1) * n_nodes_per_world
        connections[start_i:end_i, start_i:end_i] = connection
        # connect to graph
        if world_i != 0:
            node_from = np.random.randint(start_i, end_i)
            node_to = node_from - n_nodes_per_world
            connections[node_from, node_to] = 1
            connections[node_to, node_from] = 1
    return connections

# function to generate a dead ends graph
def construct_dead_ends_graph():
    # construct the shell-like connectivity
    n_shells = 4
    n_neurons_per_shell = 6
    layers_with_circular_connection = [1]
    size = n_shells * n_neurons_per_shell
    connections = np.zeros((size, size)).astype(np.float32)
    eye = np.eye(n_neurons_per_shell)
    for i in range(n_shells):
        if i in layers_with_circular_connection:
            # add circular connection
            idx = i * n_neurons_per_shell
            connections[idx:idx+n_neurons_per_shell, idx:idx+n_neurons_per_shell] = np.roll(eye, 1, axis=1)

        # add connections to next shell
        if i + 1 < n_shells:
            # connect to outer shell
            idx = i * n_neurons_per_shell
            j = idx + n_neurons_per_shell
            connections[idx:idx+n_neurons_per_shell, j:j+n_neurons_per_shell] = eye

    connections = connections + connections.T
    return connections

# function to generate a grid graph with numerical paths to move to a target
def construct_grid_graph():
    n = 4
    l = 5
    n_nodes = n * l + l // 2

    tmp_n_nodes = 0
    layer_index_bounds = []
    for i in range(l):
        tmp_n_nodes += n if i % 2 == 0 else n + 1
        layer_index_bounds.append(tmp_n_nodes)

    connections = []

    lower_idx = 0
    for layer_idx in range(l):
        upper_idx = layer_index_bounds[layer_idx]
        for i in range(lower_idx, upper_idx):
            if i + 1 < upper_idx:
                connections.append([i, i + 1])
            if layer_idx + 1 < l:
                # not last layer
                next_upper_idx = layer_index_bounds[layer_idx+1]
                dist1 = n if layer_idx % 2 == 0 else n + 1
                dist2 = n + 1 if layer_idx % 2 == 0 else n
                if upper_idx <= i + dist1 < next_upper_idx:
                    connections.append([i, i + dist1])
                if upper_idx <= i + dist2 < next_upper_idx:
                    connections.append([i, i + dist2])
        lower_idx = upper_idx
    connections = np.array(connections)
    i, j = connections.T
    adj = np.zeros((n_nodes, n_nodes))
    adj[i, j] = 1
    adj += adj.T
    return adj

In [70]:
torch.randint(0, 1, (0,))

tensor([], dtype=torch.int64)

In [71]:
# Create a dataset of trajectories
class RandomWalkDataset(Dataset):
    def __init__(self, adj_matrix, trajectory_length, num_trajectories, n_items):
        self.n_items = n_items
        self.adj_matrix = adj_matrix
        self.num_trajectories = num_trajectories
        self.trajectory_length = trajectory_length
        self.edges, self.action_indices = edges_from_adjacency(adj_matrix)
        #start_nodes = torch.randint(0, adj_matrix.size(0), (num_trajectories,)).tolist()
        start_nodes = [torch.randint(0, adj_matrix.size(0), (1,)).tolist()[0]] * num_trajectories # same start node for all
        self.data = []
        for node in start_nodes:
            items = (torch.rand(self.adj_matrix.shape[0]) * self.n_items).to(torch.int32)
            trajectory = strict_random_walk(self.adj_matrix, node, self.trajectory_length, self.action_indices, items)
            self.data.append(torch.tensor([(x[0], x[1], x[2]) for x in trajectory]))
    def __len__(self):
        return self.num_trajectories  # Number of trajectories

    def __getitem__(self, idx):
        return self.data[idx]
    
    
# function to generate random walk trajectories on a given graph
def strict_random_walk(adj_matrix, start_node, length, action_indices, items):
    current_node = start_node
    trajectory = []
    for _ in range(length - 1):  # subtract 1 to account for the start node
        neighbors = torch.where(adj_matrix[current_node] > 0)[0].tolist()
        if not neighbors:
            break
        next_node = neighbors[torch.randint(0, len(neighbors), (1,)).item()]
        trajectory.append((items[current_node], action_indices[(current_node, next_node)], items[next_node]))
        current_node = next_node
    return trajectory

# indexing each action for a given adjacency matrix
def edges_from_adjacency(adj_matrix):
    # The input is a given random matrix's adjacency matrix
    # The outputs are:
        # edges: a list of pairs of (start node, end node) for each action
        # action_indices: a dictionary, each key is a pair of(start node, end node),
            # and its corresponding value is this action's index
    # For a pure on-line algorithm, this can also be done by assigning index to unseen actions
    # during random-walk on-line
    n = adj_matrix.shape[0]
    edges = []
    action_idx = 0
    action_indices = {}
    for i in range(n):
        for j in range(i+1, n):  # Only upper triangle
            if adj_matrix[i][j] != 0:
                edges.append((i, j))
                action_indices[(i, j)] = action_idx
                action_idx += 1
                edges.append((j, i))
                action_indices[(j, i)] = action_idx
                action_idx += 1
    return edges, action_indices

In [72]:
class GraphEnv:
    def __init__(self, size=32, n_items=10, env='random', batch_size=15, num_desired_trajectories=10, device=None):
        if env == 'random':
            self.adj_matrix = construct_random_subgraph(size, 2, 5)
        elif env == 'small world':
            self.adj_matrix = construct_small_world_graph()
        elif env == 'dead ends':
            self.adj_matrix = construct_dead_ends_graph()
        elif env == 'grid':
            self.adj_matrix = construct_grid_graph()
        self.adj_matrix = torch.tensor(self.adj_matrix)
        self.size = self.adj_matrix.shape[0]
        self.affordance, self.node_to_action_matrix,\
        self.action_to_node = node_outgoing_actions(self.adj_matrix)
        
        self.affordance = {k: torch.tensor(v).to(device)\
                           for k, v in self.affordance.items()}
        self.node_to_action_matrix = self.node_to_action_matrix.to(device)
        self.action_to_node = {k: torch.tensor(v).to(device) \
                            for k, v in self.action_to_node.items()}
        
        self.n_items = n_items
        self.batch_size = batch_size
        self.num_desired_trajectories = num_desired_trajectories
        self.populate_graph()
        self.gen_dataset()

    # uniformly random observations
    def populate_graph(self):
        self.items = (torch.rand(self.size) * self.n_items).to(torch.int32)

    def gen_dataset(self):
        self.dataset = RandomWalkDataset(self.adj_matrix, self.batch_size,
                                         self.num_desired_trajectories, self.n_items)
        self.n_actions = len(self.dataset.action_indices)
        
def node_outgoing_actions(adj_matrix):
    # This function creates several look-up tables for later computation's convecience
    edges, action_indices = edges_from_adjacency(adj_matrix)
    # Use an action index as a key, retrieve its (start node, end node)
    inverse_action_indices = {v: k for k, v in action_indices.items()}
    # Given a node as a key, retrieve all of its available outgoing actions' indexes.
    node_actions = {}
    # Given a pair of (start node, end node), get the action index.
    # Since a index can be 0, this matrix is initialized to be a all -1.
    node_to_action_matrix = -1*torch.ones_like(adj_matrix)
    for edge in edges:
        node_from, node_to = edge
        if node_from not in node_actions:
            node_actions[node_from] = []
        node_actions[node_from].append(action_indices[edge])
        node_to_action_matrix[node_from][node_to] = action_indices[edge]  
    return node_actions, node_to_action_matrix.long(), inverse_action_indices

In [93]:
class Agent(torch.nn.Module):
    def __init__(self, o_size, a_size, z_size, s_dim, device=None):
        super(Agent, self).__init__()
        self.Q = torch.nn.Parameter(1*torch.randn(s_dim, o_size, device=device))
        self.V = torch.nn.Parameter(0.1*torch.randn(s_dim, a_size, device=device))
        self.Z = torch.nn.Parameter(0.1*torch.randn(s_dim, z_size, device=device))
        self.W = torch.nn.Parameter(0.1*torch.randn(a_size, s_dim, device=device))
        self.o_size = o_size
        self.a_size = a_size

        self.device = device

    def forward(self, o_pre, action, o_next):
        prediction_error = self.Q[:,o_next]-(self.Q[:,o_pre]+self.V[:,action])
        return prediction_error
    
    def plan(self, start, goal, env, weight=False):
        a_record = []
        o_record = []
        loc = int(start)
        length = 0
        for i in range(self.o_size):
            o_record.append(loc)
            if loc==goal:
                if weight:
                    return length, o_record
                else:
                    return i, o_record
            loc, action = self.move_one_step(loc, goal, a_record, env.affordance[loc], 
                    env.action_to_node, env.node_to_action_matrix[loc], weight)
            a_record.append(action)
            if weight:
                length += w_connection[o_record[-1],loc]

        if weight:
            return length, o_record
        else:
            return i, o_record
        
    def move_one_step(self, loc, goal, a_record, affordance, action_to_node,
                      next_node_to_action, weight=False, w_connection=None):  
        affordance_vector = torch.zeros(self.a_size, device=self.device)
        affordance_vector[affordance] = 1
        if weight:    
            for a in affordance:
                a = a.item()
                affordance_vector[a]/=(w_connection[action_to_node[a][0],
                                                    action_to_node[a][1]])
        affordance_vector_fix = affordance_vector.clone()
        not_recommended_actions = a_record
        affordance_vector_fix[not_recommended_actions] *= 0.

        delta = self.Q[:,goal]-self.Q[:,loc]
        utility = (self.W@delta) * affordance_vector_fix
        if torch.max(utility)!=0:
            action_idx = torch.argmax(utility).item()
        else:
            utility = (self.V.T@delta) * affordance_vector
            action_idx = torch.argmax(utility).item()

            
        return action_to_node[action_idx][1].item(), action_idx

class PartiallyObservableCML(torch.nn.Module):
    def __init__(self, n_states, n_obs, n_act, embedding_dim):
        super().__init__()
        self.n_states = n_states
        self.n_obs = n_obs
        self.n_act = n_act
        self.embedding_dim = embedding_dim

        self.U = torch.nn.Parameter(torch.randn(embedding_dim, n_obs), requires_grad=False)
        self.Q = torch.nn.Parameter(torch.randn(embedding_dim, n_states))
        self.V = torch.nn.Parameter(torch.randn(embedding_dim, n_act))

        self.reset()

    # assume we start in the same state every time
    def reset(self):
        self.M = torch.zeros(self.embedding_dim, self.embedding_dim)
        self.z = self.Q[:, 0] # arbitrary state selection

    def forward(self, o, a, o_next):
        print(a.shape, self.z.shape, self.V[:, a].shape)
        z_pred = self.z.detach() + self.V[:, a]
        z_next = self.closest_state(z_pred.detach(), metric='euclidean')
        
        x_pred = self.M.T @ z_next
        x_next = self.U[:, o_next]
        print(self.M.shape, z_next.shape, x_pred.shape)

        self.z = z_next
        self.M += torch.outer(z_next, x_next).detach()

        return z_pred, z_next, x_pred, x_next

    def closest_state(self, s, metric='dot'):
        if metric == 'dot':
            return self.Q[torch.argmax(self.Q.T @ s)]
        elif metric == 'euclidean':
            s = s.unsqueeze(1)
            print("Q", self.Q.shape, s.shape)
            print("L", torch.linalg.norm(self.Q.detach() - s, dim=0).shape)
            return self.Q[:, torch.argmin(torch.linalg.norm(self.Q.detach() - s, dim=0))]
        else:
            raise NotImplementedError

In [94]:
def train_model(model: PartiallyObservableCML, dataloader, epochs, norm=False):
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()
    losses = []
    for epoch in tqdm(range(epochs), desc="Epochs"):
        loss_over_epoch = 0
        for trajectory in dataloader:
            model.reset()
            
            # "break in" traj w/o backprop first
            for i in range(trajectory.shape[0]):
                o, a, o_next = trajectory[0,i,0], trajectory[0,i,1], trajectory[0,i,2]
                z_pred, z_next, x_pred, x_next = model(o, a, o_next)

                loss1 = loss_fn(z_pred, z_next)
                loss2 = loss_fn(x_pred, x_next)

                optim.zero_grad()
                loss1.backward()
                optim.step()

                optim.zero_grad()
                loss2.backward()
                optim.step()

                loss_over_epoch += loss1.item() + loss2.item()

            if norm:
                model.V.data = model.V / torch.norm(model.V, dim=0)
            
        losses.append(loss_over_epoch)
        print(f"Epoch {epoch} | Loss: {loss_over_epoch}")
    return losses

In [95]:
torch.autograd.set_detect_anomaly(True)

n_nodes = 32
batch_size = 32
state_dim = 1000
epochs = 10
n_obs = 10

num_desired_trajectories=200
# choose env from "random", "small world" or "dead ends"
env = GraphEnv(size=n_nodes, n_items=n_obs, env='random', batch_size=batch_size, num_desired_trajectories=num_desired_trajectories)

dataset = RandomWalkDataset(env.adj_matrix, batch_size, num_desired_trajectories, n_obs)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
model = PartiallyObservableCML(n_states=env.size, n_obs=n_obs, n_act=env.n_actions, embedding_dim=state_dim)


loss_record= train_model(model, dataloader, epochs, norm=False)

  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/calvin/miniforge3/envs/pytorch/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/calvin/miniforge3/envs/pytorch/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/calvin/miniforge3/envs/pytorch/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/Users/calvin/miniforge3/envs/pytorch/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/Users/calvin/miniforge3/envs/pytorch/lib/python3.12/asyncio/base_events.py", line 638, in run_forever
    self._run_once()
  File "/Users/calvin/miniforge3/envs/pytorch/lib/python3.12/asyncio/base_events.py", line 1971, in _run_once
    handle._run()
  File "/Users/calvin/mini

torch.Size([]) torch.Size([1000]) torch.Size([1000])
Q torch.Size([1000, 32]) torch.Size([1000, 1])
L torch.Size([32])
x False False True





RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1000, 1000]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!