In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import mdpsolver

In [78]:
# Direction name -> index. The same order (left, up, right, down) is used for
# cell walls, cell neighbors and grid-world actions.
DIR_DICT = {name: index for index, name in enumerate(('left', 'up', 'right', 'down'))}

class GridElement():
    """
    A single cell of a grid world.

    Attributes:
        x, y: cell coordinates.
        walls: four booleans ordered [left, up, right, down].
        neighbors: four adjacent cells (or None) in the same order.
        reward: value attached to the cell, or None.
    """

    def __init__(self, x, y, reward=None, walls=None):
        self.x = x
        self.y = y
        if walls is not None:
            assert len(walls) == 4, "Walls must be a list of four boolean values."
            assert all(wall in [True, False] for wall in walls), "Walls must be True or False."
            # Copy so add_wall() never mutates a list the caller (or another
            # cell built from the same list) still holds.
            self.walls = list(walls)
        else:
            self.walls = [False, False, False, False]  # left, up, right, down

        self.neighbors = [None, None, None, None]  # left, up, right, down

        self.reward = reward

    @staticmethod
    def _direction_index(direction):
        """
        Normalise a direction name ('left'/'up'/'right'/'down') or index (0-3).

        Raises:
            ValueError: if `direction` is neither a known name nor 0-3.
        """
        if direction in ('left', 'right', 'up', 'down'):
            direction = DIR_DICT[direction]
        if direction in [0, 1, 2, 3]:
            return direction
        raise ValueError("Direction must be 0 (left), 1 (up), 2 (right), or 3 (down).")

    def add_wall(self, direction):
        """Mark the wall on side `direction` (name or index) as present."""
        self.walls[self._direction_index(direction)] = True

    def add_neighbor(self, direction, neighbor):
        """Set the neighbor on side `direction` (name or index); `neighbor` may be None."""
        self.neighbors[self._direction_index(direction)] = neighbor

    def get_wall_string(self):
        """
        Return a 3x3 ASCII rendering of this cell as [top, middle, bottom].

        Corners are 'x', horizontal walls are '―', vertical walls are '|';
        open sides are spaces.
        """
        walls_str = ['x x', '   ', 'x x']
        if self.walls[0]:
            if self.walls[2]:
                walls_str[1] = '| |'
            else:
                walls_str[1] = '|  '
        elif self.walls[2]:
            walls_str[1] = '  |'
        if self.walls[1]:
            walls_str[0] = 'x―x'
        if self.walls[3]:
            walls_str[2] = 'x―x'

        return walls_str
    
class Grid():
    """
    Rectangular container of GridElement objects addressed by (x, y).

    width/height are inferred from the largest x/y present; positions with no
    element hold None. The internal table ``self.grid`` is indexed [y][x].
    """

    def __init__(self, grid_elements):
        assert isinstance(grid_elements, list), "grid_elements must be a list of GridElement objects."
        assert all(isinstance(element, GridElement) for element in grid_elements), "All elements must be GridElement objects."
        if not grid_elements:
            # max() below would otherwise fail with an opaque "empty sequence" error.
            raise ValueError("grid_elements must contain at least one GridElement.")
        self.grid_elements = grid_elements
        self.width = max(element.x for element in grid_elements) + 1
        self.height = max(element.y for element in grid_elements) + 1
        self.grid = [[None for _ in range(self.width)] for _ in range(self.height)]
        for element in grid_elements:
            self.grid[element.y][element.x] = element

    def get_element(self, x, y):
        """
        Return the element at (x, y), or None if that position is empty.

        Raises:
            IndexError: if (x, y) lies outside the grid.
        """
        if 0 <= x < self.width and 0 <= y < self.height:
            return self.grid[y][x]
        else:
            raise IndexError("Coordinates out of bounds.")

    def _render_row(self, y, part):
        """Render text line `part` (0=top, 1=middle, 2=bottom) of grid row y, newline-terminated."""
        pieces = []
        right_edge = ' '
        for x in range(self.width):
            element = self.get_element(x, y)
            if element is None:
                pieces.append('  ')
                right_edge = ' '
            else:
                wall_row = element.get_wall_string()[part]
                # Each cell contributes its left character and its centre; its
                # right-hand character is drawn by the next cell's left one.
                pieces.append(wall_row[0:2])
                right_edge = wall_row[2]
        # The right-most column closes the line with its own right edge; the
        # newline is emitted even when that cell is missing.
        pieces.append(right_edge)
        pieces.append('\n')
        return ''.join(pieces)

    def __repr__(self):
        """
        Return an ASCII rendering of the grid, top row first.

        Every grid row yields two text lines (top walls, side walls); a final
        line draws the bottom walls of row y = 0. With GridElement cells each
        line is 2 * width + 1 characters wide.
        """
        lines = []
        for y in range(self.height - 1, -1, -1):
            lines.append(self._render_row(y, 0))
            lines.append(self._render_row(y, 1))
        lines.append(self._render_row(0, 2))
        return ''.join(lines)

def fullgrid_init(width, height, wall_horiz=None, wall_vert=None, reward=None):
    """
    Initialize a dense width x height grid world and return it as a Grid.

    Args:
        width, height: positive ints.
        wall_horiz: optional list of [x, y] coordinates (lists or tuples).
            Each entry puts a horizontal wall between cell (x, y) and the
            cell below it, (x, y - 1).
        wall_vert: optional list of [x, y] coordinates (lists or tuples).
            Each entry puts a vertical wall between cell (x, y) and the cell
            to its left, (x - 1, y).
        reward: optional dict mapping (x, y) -> reward. When given, cells not
            in the dict get reward 0; when None, every cell's reward stays None.

    Returns:
        Grid of GridElement objects with boundary walls, the requested
        interior walls, and neighbor links across every open side.
    """
    assert isinstance(width, int) and width > 0, "Width must be a positive integer."
    assert isinstance(height, int) and height > 0, "Height must be a positive integer."
    assert wall_horiz is None or isinstance(wall_horiz, list), "wall_horiz must be a list of coordinates."
    assert wall_vert is None or isinstance(wall_vert, list), "wall_vert must be a list of coordinates."
    assert reward is None or isinstance(reward, dict), "reward must be a dictionary of coordinates and values."

    # Normalise wall coordinates to sets of tuples: O(1) membership tests, and
    # tuple coordinates work as well as lists (a list [x, y] never compares
    # equal to a tuple (x, y), so raw list comparisons would ignore tuples).
    horiz = set() if wall_horiz is None else {tuple(c) for c in wall_horiz}
    vert = set() if wall_vert is None else {tuple(c) for c in wall_vert}

    elements = []
    for x in range(width):
        for y in range(height):
            element = GridElement(x, y)

            # Outer boundary walls plus any requested interior walls.
            if x == 0 or (x, y) in vert:
                element.add_wall('left')
            if x == width - 1 or (x + 1, y) in vert:
                element.add_wall('right')
            if y == 0 or (x, y) in horiz:
                element.add_wall('down')
            if y == height - 1 or (x, y + 1) in horiz:
                element.add_wall('up')

            if reward is not None:
                element.reward = reward.get((x, y), 0)

            elements.append(element)

    grid = Grid(elements)

    # Link neighbors; any walled side (boundary walls included) gets None.
    for x in range(width):
        for y in range(height):
            element = grid.get_element(x, y)

            upper_neighbor = grid.get_element(x, y + 1) if y < height - 1 and not element.walls[1] else None
            element.add_neighbor('up', upper_neighbor)
            lower_neighbor = grid.get_element(x, y - 1) if y > 0 and not element.walls[3] else None
            element.add_neighbor('down', lower_neighbor)
            left_neighbor = grid.get_element(x - 1, y) if x > 0 and not element.walls[0] else None
            element.add_neighbor('left', left_neighbor)
            right_neighbor = grid.get_element(x + 1, y) if x < width - 1 and not element.walls[2] else None
            element.add_neighbor('right', right_neighbor)

    return grid

def _apply_wind(probs, blocked, strength, target, opposite, crosswind):
    """Apply one wind component of magnitude |strength| blowing toward `target`, in place on `probs`."""
    s = abs(strength)
    # Moves against the wind lose that fraction of their mass to "stay".
    loss = s * probs[opposite]
    probs[opposite] -= loss
    probs[4] += loss
    # If the downwind move is possible, the wind also drags that fraction of
    # the crosswind and stay mass into it.
    if not blocked[target]:
        gain = s * sum(probs[crosswind])
        probs[crosswind] = (1 - s) * probs[crosswind]
        probs[target] += gain


def reassign_nextstate_probs(move_probs, blocked, controllability, wind):
    """
    Apply controllability and wind to a set of move probabilities.

    Args:
        move_probs: 5 probabilities ordered [left, up, right, down, stay]
            (array or sequence). It is not modified.
        blocked: 4 booleans [left, up, right, down]; True where the move is
            walled off.
        controllability: (x, y) pair; a fraction 1 - c of the horizontal or
            vertical move mass is turned into "stay".
        wind: (x, y) pair; x > 0 blows right, x < 0 left, y > 0 up, y < 0 down.
            The magnitude is the fraction of mass shifted.

    Returns:
        A new float numpy array of 5 probabilities. Total mass is preserved.
    """
    # Work on a float copy so the caller's array is untouched and lists work too.
    probs = np.array(move_probs, dtype=float)
    control_x, control_y = controllability[0], controllability[1]
    wind_x, wind_y = wind[0], wind[1]

    # Inertia: a (1 - control) share of each axis's move mass becomes "stay".
    control_x_loss = (1 - control_x) * sum(probs[[0, 2]])
    probs[0] = control_x * probs[0]
    probs[2] = control_x * probs[2]
    probs[4] += control_x_loss
    control_y_loss = (1 - control_y) * sum(probs[[1, 3]])
    probs[1] = control_y * probs[1]
    probs[3] = control_y * probs[3]
    probs[4] += control_y_loss

    # Horizontal wind first, then vertical (order matters for the result).
    if wind_x > 0:  # right wind
        _apply_wind(probs, blocked, wind_x, target=2, opposite=0, crosswind=[1, 3, 4])
    elif wind_x < 0:  # left wind
        _apply_wind(probs, blocked, wind_x, target=0, opposite=2, crosswind=[1, 3, 4])

    if wind_y > 0:  # up wind
        _apply_wind(probs, blocked, wind_y, target=1, opposite=3, crosswind=[0, 2, 4])
    elif wind_y < 0:  # down wind
        _apply_wind(probs, blocked, wind_y, target=3, opposite=1, crosswind=[0, 2, 4])

    return probs


class GridWorld():
    """
    A class representing a grid world MDP.

    States are indexed as ``y * width + x``; actions are 0 (left), 1 (up),
    2 (right), 3 (down), matching the wall/neighbor order of the grid cells.
    """

    def __init__(self, width, height, wall_horiz=None, wall_vert=None, reward=None, controllability=None, wind=None):
        """
        Initialize a grid world with given width and height.

        Args:
            width, height: grid size in cells.
            wall_horiz, wall_vert, reward: forwarded to fullgrid_init().
            controllability: [x, y] pair; a (1 - c) share of an intended move
                along that axis becomes "stay". Defaults to [1.0, 1.0].
            wind: [x, y] pair of wind strengths (the sign gives the direction).
                Defaults to [0.0, 0.0].
        """
        self.width = width
        self.height = height
        self.n_states = width * height
        self.n_actions = 4  # left, up, right, down

        self.grid = fullgrid_init(width, height, wall_horiz, wall_vert, reward)

        # None sentinels: a mutable list default would be shared by every instance.
        self.controllability = [1.0, 1.0] if controllability is None else controllability
        self.wind = [0.0, 0.0] if wind is None else wind

    def get_TPM(self):
        """
        Get the Transition Probability Matrix (TPM) for the grid world.

        Returns:
            float32 tensor of shape [state, action, next_state].
        """
        tpm = torch.zeros((self.n_states, self.n_actions, self.n_states), dtype=torch.float32)
        for x_curr in range(self.width):
            for y_curr in range(self.height):
                curr_element = self.grid.get_element(x_curr, y_curr)
                curr_index = y_curr * self.width + x_curr

                neighbours = curr_element.neighbors  # left, up, right, down
                # Next-state index for each of the 5 outcomes
                # (left, up, right, down, stay); None where a wall blocks the move.
                target_indices = [n.x + n.y * self.width if n is not None else None for n in neighbours]
                target_indices.append(curr_index)
                blocked = [n is None for n in neighbours]

                for direction in range(self.n_actions):
                    default_move_probs = np.zeros(5)
                    # The intended move succeeds if open; otherwise the agent stays put.
                    default_move_probs[4 if blocked[direction] else direction] = 1.0

                    # Apply controllability and wind to the move probabilities.
                    move_probs = reassign_nextstate_probs(default_move_probs, blocked, self.controllability, self.wind)

                    for target, prob in zip(target_indices, move_probs):
                        if target is not None and prob > 0:
                            tpm[curr_index, direction, target] = prob

        return tpm

    def get_TPM_sparse(self):
        """
        Get a sparse TPM in the format of "MDPSolver".

        Returns:
            List of (state, action, next_state, probability) tuples for every
            strictly positive entry, in row-major order.
        """
        tpm = self.get_TPM()
        # nonzero() returns indices in row-major order, i.e. the same ordering
        # as nested loops over (state, action, next_state).
        states, actions, next_states = torch.nonzero(tpm > 0, as_tuple=True)
        values = tpm[states, actions, next_states].tolist()
        return [
            (i, j, k, float(v))
            for i, j, k, v in zip(states.tolist(), actions.tolist(), next_states.tolist(), values)
        ]

    def get_rewards_vec(self):
        """
        Get the rewards for each state in the grid world.

        Returns:
            float32 tensor of shape [state]; cells whose reward is None give 0.
        """
        rewards = torch.zeros(self.n_states, dtype=torch.float32)
        for x in range(self.width):
            for y in range(self.height):
                element = self.grid.get_element(x, y)
                if element.reward is not None:
                    index = y * self.width + x
                    rewards[index] = element.reward

        return rewards

    def get_rewards_list(self):
        """
        Get per-state, per-action rewards as a nested list.

        Returns:
            List of length n_states, ordered by state index y * width + x (the
            same order as get_TPM / get_rewards_vec). Each entry repeats the
            cell's reward once per action (4 entries); None rewards become 0.
        """
        rewards = []
        # Iterate y-outer so list position equals the state index y * width + x.
        for y in range(self.height):
            for x in range(self.width):
                element = self.grid.get_element(x, y)
                value = element.reward if element.reward is not None else 0
                rewards.append([value] * self.n_actions)

        return rewards

In [79]:
gridworld = GridWorld(
    3,
    3,
    wall_horiz=[[0, 1]],
    wall_vert=[[1, 1]],
    reward={(2, 2): 1},
    controllability=[1.0, 1.0],
    wind=[-0.1, -0.1],
)
print(gridworld.grid)

# Keep the rendered maze; print_state() overlays the agent on it.
grid_str = str(gridworld.grid)

tpm = gridworld.get_TPM()
rewards = gridworld.get_rewards_list()
for item in (tpm, rewards):
    print(item)

x―x―x―x
|     |
x x x x
| |   |
x―x x x
|     |
x―x―x―x

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.1000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.1000, 0.0900, 0.0000, 0.0000, 0.8100, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0100, 0.0900, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.1000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000]],

        [[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.1000, 0.0900, 0.0000, 0.0000, 0.8100, 0.0000, 0.0000,
          0.0000],
         

In [5]:
def interact(state, tpm, action):
    """
    Sample one environment step.

    Args:
        state: one-hot tensor over states (its argmax is the current state).
        tpm: tensor [state, action, next_state] of transition probabilities.
        action: integer action index.

    Returns:
        A new float32 one-hot tensor of length tpm.shape[2] for the sampled
        next state.
    """
    current = int(torch.argmax(state).item())
    next_state = torch.multinomial(tpm[current, action, :], 1).item()
    # Size the one-hot vector from the TPM rather than a hard-coded state count.
    one_hot = torch.zeros(tpm.shape[2], dtype=torch.float32)
    one_hot[next_state] = 1.0
    return one_hot

def print_state(grid_str, state, width=3, height=3):
    """
    Print `grid_str` with the agent drawn as 'o' at the cell encoded by `state`.

    Args:
        grid_str: Grid rendering with two text lines per grid row (top row
            first) plus a bottom border, and two characters per cell column.
            Any existing 'o' is erased first.
        state: one-hot tensor; its argmax is the state index y * width + x.
        width, height: grid size in cells (default 3x3).
    """
    grid_str = grid_str.replace('o', ' ')
    index = int(torch.argmax(state).item())
    x, y = index % width, index // width
    col = 1 + 2 * x
    # Middle (side-wall) line of grid row y; rows are drawn from y = height - 1 down.
    row = 2 * height - (1 + 2 * y)
    lines = grid_str.split('\n')
    lines[row] = lines[row][:col] + 'o' + lines[row][col + 1:]

    print('\n'.join(lines))

In [6]:
# One-hot start state: the agent begins in state 0, i.e. cell (0, 0).
state = torch.tensor([1.0] + [0.0] * 8, dtype=torch.float32)

print_state(grid_str, state)

x―x―x―x
|     |
x x x x
| |   |
x―x x x
|o    |
x―x―x―x



In [7]:
action = 3  # down; the chosen action is passed to interact() instead of a stray literal
state = interact(state, tpm, action)
print_state(grid_str, state)

x―x―x―x
|     |
x x x x
| |   |
x―x x x
|o    |
x―x―x―x



In [72]:
GAMMA = 0.9

mdl = mdpsolver.model()

# GridWorld exposes per-state, per-action rewards via get_rewards_list()
# (one 4-entry row per state); there is no get_rewards() method.
mdl.mdp(
    discount=GAMMA,
    rewards=gridworld.get_rewards_list(),
    tranMatElementwise=gridworld.get_TPM_sparse(),
)

mdl.solve()

print(mdl.getValueVector())
print(mdl.getPolicy())

def mdpsolver_policy_to_matrix(policy, n_actions=4):
    """
    Convert a deterministic policy into a one-hot [state, action] matrix.

    Args:
        policy: sequence with one action index per state (e.g. mdpsolver's
            getPolicy() output); non-integer values are truncated like int().
        n_actions: number of action columns (default 4: left, up, right, down).

    Returns:
        float32 tensor of shape [len(policy), n_actions] with a single 1.0 per row.
    """
    actions = torch.as_tensor(np.asarray(policy).astype(int), dtype=torch.long)
    policy_matrix = torch.zeros(actions.shape[0], n_actions)
    policy_matrix[torch.arange(actions.shape[0]), actions] = 1.0
    return policy_matrix

greedy_actions = mdl.getPolicy()
policy_matrix = mdpsolver_policy_to_matrix(greedy_actions)
print(policy_matrix)

[4.866296707602909, 5.467087772011805, 6.148192474557321, 5.474177367212733, 6.2912092378483235, 7.075654828582033, 6.150002154960483, 7.077059082365852, 8.155168412061817]
[2, 1, 1, 1, 1, 1, 2, 2, 1]
tensor([[0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.]])


In [82]:
def calculate_marginalised_T(tpm, policy):
    """
    Calculate the policy-marginalised transition matrix
    T[s, s'] = sum_a pi(a|s) * P(s'|s, a).

    Args:
        tpm: tensor of shape [state, action, next_state].
        policy: tensor of shape [state, action] holding pi(a|s).

    Returns:
        float32 tensor of shape [state, next_state].
    """
    # Broadcast the policy over next_state and reduce over actions in one
    # vectorised op instead of an O(S^2) Python loop of tiny tensor sums.
    marginal = (policy.unsqueeze(-1) * tpm).sum(dim=1)
    return marginal.to(torch.float32)



T = calculate_marginalised_T(tpm, policy_matrix)
print(T)

# M = (I - GAMMA * T)^-1, so that V = M @ r for state rewards r.
identity = torch.eye(T.shape[0])
M = torch.linalg.inv(identity - GAMMA * T)
print(M)

rew_vec = gridworld.get_rewards_vec()
print(rew_vec)

V = M @ rew_vec
print(V)

tensor([[0.1000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1000, 0.0900, 0.0000, 0.0000, 0.8100, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.1000, 0.0900, 0.0000, 0.0000, 0.8100, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],
        [0.0000, 0.0100, 0.0000, 0.0000, 0.0900, 0.0000, 0.0000, 0.9000, 0.0000],
        [0.0000, 0.0000, 0.0190, 0.0000, 0.0900, 0.0810, 0.0000, 0.0000, 0.8100],
        [0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0000, 0.0900, 0.8100, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0090, 0.0810, 0.8100],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0900, 0.8100]])
tensor([[1.2050e+00, 1.0730e+00, 8.9207e-03, 1.2697e-03, 1.0236e+00, 4.7942e-01,
         1.2838e-02, 1.3296e+00, 4.8663e+00],
        [1.1922e-01, 1.2055e+00, 1.0022e-02, 1.4264e-03, 1.1500e+00, 5.3861e-01,
         1.4423e-02, 1.4937e+00, 5.4671e+00],
       