In [2]:
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Dict, Discrete, Box
import torch

import os 
import sys

sys.path.append(os.path.abspath('../..'))

from environment.env import POMDPDeformedGridworld

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NN(nn.Module):
    """MLP observation model mapping a (pos, theta) pair to a score in (0, 1).

    ``pos`` and ``theta`` are concatenated along dim 1 and must together carry
    6 features (a 2-value position and a 4-value theta in this notebook).
    Three ReLU hidden layers of width 128 feed a single sigmoid output unit.
    The attribute names ``f1``..``f4`` are the keys of saved state dicts, so
    they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        hidden = 128
        self.f1 = nn.Linear(6, hidden)
        self.f2 = nn.Linear(hidden, hidden)
        self.f3 = nn.Linear(hidden, hidden)
        self.f4 = nn.Linear(hidden, 1)

    def forward(self, pos, theta):
        """Return sigmoid scores of shape (batch, 1), one per (pos, theta) row."""
        h = torch.cat((pos, theta), dim=1)
        for layer in (self.f1, self.f2, self.f3):
            h = torch.relu(layer(h))
        return torch.sigmoid(self.f4(h))

# Build the observation network and restore its trained weights onto the CPU.
# weights_only=True restricts torch.load to tensors/state-dict content.
obs_model = NN()
obs_model.load_state_dict(
    torch.load(
        'obs_model_0.pth',
        map_location='cpu',
        weights_only=True,
    )
)
obs_model.eval()

# Smoke test: a batch of one (pos, theta) pair; theta [sx, 0, 0, sy].
obs_model(torch.tensor([[0.5, 0.5]]), torch.tensor([[0.5, 0.0, 0.0, 0.5]]))

tensor([[0.4451]], grad_fn=<SigmoidBackward0>)

# DISCRETE BELIEF UPDATE TESTING

In [10]:
class POMDPAgent():
    """Agent that keeps a discrete belief over the environment's two stretch parameters.

    The belief is a categorical distribution over a ``discretization x
    discretization`` grid of candidate stretch pairs, each coordinate taken
    from ``np.linspace(.4, 1, discretization)``. Shear is not part of the
    belief; it is set to ``[0, 0]`` whenever the environment is probed.

    Args:
        env: environment exposing ``transformation_matrix``, ``step``,
            ``get_state``, ``render``, ``set_deformation`` and ``observe``.
        discretization: number of grid values per stretch axis.
        update: ``'discrete_modelled'`` (likelihood from ``obs_model``) or
            ``'discrete_exact'`` (likelihood from re-simulating ``env``).
        obs_model: callable ``obs_model(pos, theta)`` returning observation
            probabilities; required when ``update == 'discrete_modelled'``.

    Raises:
        AssertionError: if ``'discrete_modelled'`` is requested without a model.
        NotImplementedError: if ``update`` is not a supported mode.
    """

    def __init__(self, env: 'POMDPDeformedGridworld', discretization=10, update='discrete_exact', obs_model=None):
        self.env = env

        stretch = np.linspace(.4, 1, discretization)
        # shear = np.linspace(0,0, discretization)
        xa, xb = np.meshgrid(stretch, stretch)  # , shear, shear
        # Row i is (xa.ravel()[i], xb.ravel()[i]): row-major over the meshgrid,
        # so the first coordinate varies fastest.
        positions = np.column_stack([xa.ravel(), xb.ravel()])  # ya.ravel(), yb.ravel()
        # No trailing comma here: a 1-tuple would add a spurious leading dim and
        # make len(...) == 1, leaving the "uniform" prior unnormalised.
        self.belief_points = torch.tensor(positions, dtype=torch.float32)
        n_points = self.belief_points.shape[0]
        self.belief_values = torch.ones(n_points, dtype=torch.float32) / n_points

        if update == 'discrete_modelled':
            assert obs_model is not None, f'Need an observation model for discrete_modelled belief update, given {obs_model}'
            self.obs_model = obs_model
            self.belief_update = self.discrete_belief_update

        elif update == 'discrete_exact':
            self.belief_update = self.exact_belief_update
        else:
            raise NotImplementedError('Only discrete belief update is supported')

        # Stretch read off the diagonal of the env's transformation matrix at
        # construction time; restored after exact updates probe other stretches.
        self.original_def = env.transformation_matrix[0][0], env.transformation_matrix[1][1]

    def act(self):
        """Read an integer action from stdin, step the env, print the state, update the belief."""
        action = input('Enter action: ')
        pomdp_state, reward, terminated, truncated, info = self.env.step(int(action))
        print(pomdp_state)
        self.belief_update(pomdp_state)

    @staticmethod
    def _normalize(weights):
        """Return ``weights / weights.sum()``, refusing to silently produce NaNs.

        Raises:
            ValueError: if the total mass is zero, negative or not finite
                (e.g. the observation is impossible under every belief point).
        """
        total = weights.sum()
        if not torch.isfinite(total) or total <= 0:
            raise ValueError(f'Cannot normalize belief: total mass is {float(total)}')
        return weights / total

    def discrete_belief_update(self, pomdp_state):
        """Bayesian update of the grid belief using the learned observation model.

        Args:
            pomdp_state: dict with ``'pos'`` (presumably a tensor of shape (2,),
                since ``.repeat(n, 1)`` is used — TODO confirm) and ``'obs'``
                (a binary observation scored with a Bernoulli likelihood).

        Raises:
            ValueError: if the posterior has zero or non-finite total mass.
        """
        pos = pomdp_state['pos']
        obs = pomdp_state['obs']

        # No autograd: otherwise belief_values would carry a graph that grows
        # with every update and keeps every step's activations alive.
        with torch.no_grad():
            batch_pos = pos.repeat(len(self.belief_points), 1)

            # The model takes a 4-value theta; only the two stretches are believed over.
            theta = torch.cat([self.belief_points, torch.zeros(len(self.belief_points), 2)], dim=1)
            # Reorder [sx, sy, 0, 0] -> [sx, 0, 0, sy], presumably a row-major
            # flattened 2x2 transformation matrix — TODO confirm against the
            # obs model's training inputs.
            theta = theta[:, [0, 3, 2, 1]]

            predictions = self.obs_model(batch_pos, theta)
            likelihood = torch.exp(torch.distributions.Bernoulli(predictions).log_prob(obs))

            self.belief_values = self._normalize(likelihood.squeeze() * self.belief_values)

    def exact_belief_update(self, pomdp_state):
        """Bayesian update using the environment itself as the observation model.

        Each belief point is loaded into the env (shear ``[0, 0]``) and gets
        likelihood 1 if ``env.observe(pos)`` equals ``obs`` everywhere, else 0.
        The construction-time stretch is restored afterwards, even on failure
        (shear is reset to ``[0, 0]``, not to any previous value).

        Raises:
            ValueError: if probing a belief point fails (original error chained),
                or if no belief point is consistent with the observation.
        """
        obs = pomdp_state['obs']
        pos = pomdp_state['pos']

        likelihood = []
        try:
            for x in self.belief_points:
                try:
                    self.env.set_deformation([x[0], x[1]], [0, 0])  # stretch, shear format
                    match = torch.all(torch.tensor(self.env.observe(list(pos))) == obs)
                    likelihood.append(bool(match.item()))
                except Exception as err:
                    raise ValueError('Invalid belief point x', x) from err
        finally:
            self.env.set_deformation(self.original_def, [0, 0])

        likelihood = torch.tensor(likelihood, dtype=torch.float32)
        self.belief_values = self._normalize(likelihood * self.belief_values)

    def render_act(self):
        """For testing belief convergence: update from the env's current state, then render."""
        pomdp_state = self.env.get_state()
        self.belief_update(pomdp_state)
        self.env.render()

In [11]:
from PIL import Image
import torch
from io import BytesIO

# Modify belief_plot to save as an image
def belief_plot(agent, grid_shape=None):
    """Render ``agent.belief_values`` as a heat map and return it as a PIL image.

    Args:
        agent: object with a 1-D ``belief_values`` tensor laid out row-major
            over a 2-D grid (as built by ``POMDPAgent``).
        grid_shape: optional ``(rows, cols)`` to reshape the belief into.
            Defaults to a square grid inferred from the number of belief
            values, so any square discretization works (not only 50x50).

    Returns:
        PIL.Image.Image: PNG rendering of the figure, backed by an in-memory buffer.

    Raises:
        ValueError: if ``grid_shape`` is omitted and the belief size is not a
            perfect square.
    """
    import math
    import matplotlib.pyplot as plt

    values = agent.belief_values.detach().cpu().numpy()
    if grid_shape is None:
        side = math.isqrt(values.size)
        if side * side != values.size:
            raise ValueError(
                f'Cannot infer a square grid for {values.size} belief values; pass grid_shape'
            )
        grid_shape = (side, side)

    fig, ax = plt.subplots()
    ax.imshow(values.reshape(grid_shape))

    # Save into an in-memory PNG buffer and close the figure straight away:
    # this is called once per frame in a loop, so open figures would pile up.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)

    return Image.open(buf)

def create_gif(images, filename="belief_animation.gif", duration=100, loop=0):
    """
    Write a sequence of PIL images to disk as an animated GIF.

    Args:
        images (list): Frames in display order. The first frame performs the
            save; the remaining frames are appended after it.
        filename (str): Destination path of the GIF.
        duration (int): Display time of each frame, in milliseconds.
        loop (int): How many times the animation repeats; 0 loops forever.
    """
    if not images:
        print("No images to create a GIF.")
        return

    first_frame = images[0]
    remaining_frames = images[1:]
    first_frame.save(
        filename,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=loop,
    )
    print(f"GIF saved as {filename}")

In [14]:
# Environment deformation: stretch (0.6, 0.6), shear (0, 0). The agent starts
# from a uniform belief over a 50x50 grid of candidate stretches and updates it
# with the learned observation model.
pomdp_env = POMDPDeformedGridworld(obs_type='single')
pomdp_env.reset()
pomdp_env.set_deformation([0.6, 0.6],[0,0])

agent = POMDPAgent(pomdp_env,50, update='discrete_modelled', obs_model=obs_model)
# agent = POMDPAgent(pomdp_env,50,update='discrete_exact')

images = []
b = agent.belief_points[torch.argmax(agent.belief_values)]  # current most-likely belief point
img = belief_plot(agent)
images.append(img)

print(b)
# Keep updating until something raises (presumably the environment or its
# renderer ending the run — TODO confirm) or the user interrupts. Report the
# actual cause instead of hiding it behind a bare 'Error'.
while True:
    try:
        agent.render_act()
        img = belief_plot(agent)
        images.append(img)
        map_point = agent.belief_points[torch.argmax(agent.belief_values)]
        if torch.any(b != map_point):
            b = map_point
            print(b)
    except KeyboardInterrupt:
        print('Interrupted')
        break
    except Exception as err:
        print(f'Error: {type(err).__name__}: {err}')
        break



tensor([0.4000, 0.4000])
tensor([0.8408, 0.6204])
tensor([0.5959, 0.6204])
tensor([0.5959, 0.6082])
tensor([0.5837, 0.6082])
Error


In [15]:
# Stitch the collected belief frames into an animated GIF (100 ms per frame, looping forever).
create_gif(
    images,
    filename="belief_animation_single_obsmodel.gif",
    duration=100,
    loop=0,
)

GIF saved as belief_animation_single_obsmodel.gif


# VARIATIONAL UPDATE