In [1]:
from pickle import load
import os
import torch
from utils import SimplexEnvironment

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load MDP files that come from a trusted source.
# The context manager guarantees the file handle is closed after unpickling.
with open(os.path.join("data", "MDP_1.bin"), "rb") as mdp_file:
    MDP = load(mdp_file)

# MDP.get_transitions() must return a NumPy array (torch.from_numpy requires one);
# it is converted to a float32 tensor here.
t = torch.from_numpy(MDP.get_transitions()).to(torch.float32)

100%|██████████| 100/100 [00:00<00:00, 549.45it/s]


In [2]:
# Show the MDP attributes A, S and d on one line, separated by single spaces.
print(" ".join(str(dim) for dim in (MDP.A, MDP.S, MDP.d)))

100 1000 10


We will use a neural network to approximate the embedding functions of our MDP.

In [4]:
import torch.nn as nn
import torch.nn.functional as F

In [91]:
class Net(nn.Module):
    """Neural approximation of a low-rank transition model.

    The predicted next-state distribution for a state-action pair is
    ``phi(s, a) @ mu()``, where ``phi(s, a)`` is a softmax-normalised vector of
    length ``d`` produced by a small MLP and ``mu()`` is a ``(d, states)``
    matrix whose rows are softmax-normalised over the states. Because both
    factors are row-stochastic, every output row sums to 1.

    Parameters:
        states: Number of states; state ids must lie in ``[0, states - 1]``.
        actions: Number of actions; action ids must lie in ``[0, actions - 1]``.
        d: Dimension of the embedding produced by ``phi``.
    """

    def __init__(self, states, actions, d):
        super().__init__()
        self.states = states
        self.actions = actions
        self.d = d

        # First linear layer: receives the concatenated one-hot encoding of
        # the state-action pair.
        self.l1 = nn.Linear(self.states + self.actions, 120)
        self.l2 = nn.Linear(120, 60)
        self.embedding = nn.Linear(60, self.d)

        # Unnormalised weights of mu; normalised row-wise in `mu()`.
        self.mu_weights = nn.Parameter(torch.rand((self.d, self.states)))

    def _as_index_column(self, ids):
        """Convert a scalar id or a list-like of ids into an ``(n, 1)`` index tensor.

        Accepts Python ints, sequences, NumPy arrays and tensors (including
        0-d tensors). The result lives on the same device as the model
        parameters. Integer dtypes are cast to int64, which `F.one_hot`
        requires; floating-point ids are left untouched so that `F.one_hot`
        rejects them instead of silently truncating.
        """
        idx = torch.as_tensor(ids, device=self.mu_weights.device)
        if not (idx.is_floating_point() or idx.is_complex()):
            idx = idx.long()
        # A 0-d input counts as a single id, mirroring the scalar case.
        n = idx.shape[0] if idx.dim() > 0 else 1
        return idx.reshape(n, 1)

    def encode_input(self, s, a):
        """Build the network input for one or more state-action pairs.

        Parameters:
            s: State id or list-like of state ids in ``[0, self.states - 1]``.
            a: Action id or list-like of action ids in ``[0, self.actions - 1]``.
               If `s` and `a` are list-like, both need to have the same length.

        Returns:
            Float32 tensor of shape ``(n, 1, self.states + self.actions)``: the
            one-hot state encoding concatenated with the one-hot action encoding.

        Raises:
            AssertionError: If `s` and `a` hold a different number of ids.
            RuntimeError: If `F.one_hot` rejects the ids (out of range or
                non-integer).
        """
        s_idx = self._as_index_column(s)
        a_idx = self._as_index_column(a)
        input_len, actions_len = s_idx.shape[0], a_idx.shape[0]
        assert input_len == actions_len, f"The input lenghts do not coincide. Input States: {input_len}; Input Actions: {actions_len}"

        s_hot = F.one_hot(s_idx, self.states).to(torch.float32)
        a_hot = F.one_hot(a_idx, self.actions).to(torch.float32)
        # Concatenate the one-hot vectors along the feature axis.
        return torch.cat((s_hot, a_hot), dim=-1)

    def encode_output(self, s):
        """One-hot encode state ids as a float32 tensor of shape ``(n, 1, self.states)``.

        Parameters:
            s: State id or list-like of state ids in ``[0, self.states - 1]``.
        """
        return F.one_hot(self._as_index_column(s), self.states).to(torch.float32)

    def enconde_output(self, s):
        """Backward-compatible alias of `encode_output` (original, misspelled name)."""
        return self.encode_output(s)

    def phi(self, s, a):
        """Embed state-action pairs as probability vectors of length ``self.d``.

        Parameters:
            s: State id or list-like of state ids in ``[0, self.states - 1]``.
            a: Action id or list-like of action ids in ``[0, self.actions - 1]``.
               If `s` and `a` are list-like, both need to have the same length.

        Returns:
            Tensor of shape ``(n, 1, self.d)`` whose last axis sums to 1.
        """
        x = self.encode_input(s, a)
        x = F.leaky_relu(self.l1(x), negative_slope=0.01)
        x = F.leaky_relu(self.l2(x), negative_slope=0.01)
        # Softmax over the embedding axis, i.e. row-wise.
        return F.softmax(self.embedding(x), dim=-1)

    def mu(self):
        """Return the ``(self.d, self.states)`` matrix with each row softmax-normalised."""
        return F.softmax(self.mu_weights, dim=-1)

    def forward(self, s, a):
        """Predict the next-state distribution for each state-action pair.

        Parameters:
            s: State id or list-like of state ids in ``[0, self.states - 1]``.
            a: Action id or list-like of action ids in ``[0, self.actions - 1]``.

        Returns:
            Tensor of shape ``(n, 1, self.states)``; each row is a distribution
            over states.
        """
        x = self.phi(s, a)
        soft_mu = self.mu()

        # Broadcasting shares the same mu across the batch:
        # (batch_size, 1, d) @ (d, states) --> (batch_size, 1, states).
        return torch.matmul(x, soft_mu)

# Keyword arguments make the (states, actions, d) mapping explicit.
net = Net(states=MDP.S, actions=MDP.A, d=MDP.d)

In [79]:
# Toy check of the state-action encoding with 5 states and 5 actions.
s, a = [1], [1]
input_len = len(s) if hasattr(s, '__len__') else 1

# Each id list becomes an (input_len, 1, 5) float32 one-hot tensor.
s_hot = F.one_hot(torch.tensor(s).reshape(input_len, 1), num_classes=5).float()
a_hot = F.one_hot(torch.tensor(a).reshape(input_len, 1), num_classes=5).float()

# Join the two encodings along the last (feature) axis -> (input_len, 1, 10).
x = torch.cat([s_hot, a_hot], dim=2)

print(x.shape)
print(x)





torch.Size([1, 1, 10])
tensor([[[0., 1., 0., 0., 0., 0., 1., 0., 0., 0.]]])


In [89]:
# Query the model on a batch of two (state, action) pairs.
s = [30, 5]
a = [10, 3]

# Call the module itself instead of .forward() so that nn.Module's __call__
# machinery (registered hooks) is honoured.
y = net(s, a)
print(y.size())

torch.Size([2, 1, 1000])


In [90]:
# Sum over the state axis (dim 2) of y.
y.sum(dim=2)

tensor([[1.0000],
        [1.0000]], grad_fn=<SumBackward1>)

In [20]:
# Standalone check of softmax applied over the last axis of a random 2x3 tensor.
m = nn.Softmax(dim=-1)

# Named `logits` rather than `input` so the builtin input() is not shadowed.
logits = torch.randn(2, 3)
print(logits)
output = m(logits)
print(output)

tensor([[-0.7659, -0.6957, -2.9609],
        [-0.0528,  1.4539,  1.0764]])
tensor([[0.4579, 0.4911, 0.0510],
        [0.1162, 0.5243, 0.3595]])


In [21]:
# Sum the softmax output over its last axis.
output.sum(dim=-1)

tensor([1.0000, 1.0000])

In [27]:
# Shape of the one-hot encoding of a single id with 10 classes.
F.one_hot(torch.tensor([1]), num_classes=10).shape

torch.Size([1, 10])