In [1]:
import math
import random
import gymnasium as gym
import numpy as np
from stable_baselines3.common.env_checker import check_env
print("No problems")  # Sanity check: only reached if every import above succeeded.

No problems


In [2]:
class Learner_MW:
    """Multiplicative-weights (MW) learner over ``num_actions`` actions.

    One positive weight is kept per action, and ``choose_action`` samples
    action ``i`` with probability ``weights[i] / sum(weights)``. Only the
    ratios of the weights matter. ``update_weights`` therefore rescales them
    after every step to keep them in floating-point range.
    """

    def __init__(self, num_actions):
        """Create a learner with uniform weights.

        Parameters:
            - num_actions (int): Number of available actions.
        """
        self.num_actions = num_actions
        self.weights = [1] * num_actions  # Initial weights for all actions

    def _learning_rate(self, total_time):
        """Return the fixed MW step size sqrt(2 * ln(num_actions) / total_time)."""
        return math.sqrt(2 * math.log(self.num_actions) / total_time)

    def choose_action(self):
        """Sample an action index with probability proportional to its weight.

        Returns:
            - int: The sampled action, in ``range(num_actions)``.
        """
        total_weight = sum(self.weights)
        probabilities = [weight / total_weight for weight in self.weights]
        return random.choices(range(self.num_actions), probabilities)[0]

    # MW update with a fixed learning rate. This is the theoretical object we
    # study, although using len(transcript) instead of total_time gives better
    # results in practice.
    def update_weights(self, action, reward, transcript, total_time):
        """Apply one full-information MW update to ``self.weights``.

        Each weight is multiplied by ``exp(eta * reward[i])``, where ``eta``
        is the fixed learning rate for horizon ``total_time``. Without
        rescaling, the weights overflow to ``inf`` (or underflow to 0) over
        long runs. So the exponents are shifted by the best reward of the
        round, and the weights are rescaled so the largest one is 1. Both
        steps multiply every weight by the same positive factor, which leaves
        the distribution used by ``choose_action`` unchanged.

        Parameters:
            - action: The action actually played. Unused, because the update
              uses the counterfactual payoff of every action.
            - reward (sequence of float): Payoff of each action this round,
              indexed ``0 .. num_actions - 1``.
            - transcript: Unused by this fixed-rate update. The comment above
              notes ``len(transcript)`` as an alternative horizon.
            - total_time (int): Horizon used to set the learning rate.
        """
        eta = self._learning_rate(total_time)
        rewards = [reward[i] for i in range(self.num_actions)]
        best = max(rewards)
        for i, r in enumerate(rewards):
            # Update weights based on counterfactual payoff. The factor is <= 1,
            # so math.exp can never overflow.
            self.weights[i] *= math.exp(eta * (r - best))
        largest = max(self.weights)
        if largest > 0:
            # Renormalise in place so the leading weight stays at 1 and the
            # others cannot all drift down to 0.
            for i in range(self.num_actions):
                self.weights[i] /= largest

    def action_probs_cumulative(self, cumulative, total_time):
        """Return the MW distribution induced by cumulative payoffs.

        Computes ``p[i] = exp(eta * cumulative[i]) / sum_j exp(eta * cumulative[j])``
        using the fixed learning rate for horizon ``total_time``. The maximum
        payoff is subtracted before exponentiating (the log-sum-exp trick).
        The result is mathematically identical, but the shift prevents
        ``OverflowError`` for large payoffs and division by zero when every
        term would underflow.

        Parameters:
            - cumulative (sequence of float): Cumulative payoff of each action.
            - total_time (int): Horizon used to set the learning rate.

        Returns:
            - list[float]: Action probabilities that sum to 1 (up to rounding).

        Raises:
            - ValueError: If ``len(cumulative) != num_actions``.
        """
        if len(cumulative) != self.num_actions:
            raise ValueError(
                f"expected {self.num_actions} cumulative payoffs, "
                f"got {len(cumulative)}"
            )
        eta = self._learning_rate(total_time)
        top = max(cumulative)
        weights = [math.exp(eta * (c - top)) for c in cumulative]
        total_weight = sum(weights)
        return [weight / total_weight for weight in weights]

In [3]:


class SimplexSpace(gym.spaces.Space):
    """
    Defines the action space as a probability simplex over m outcomes.

    Points are length-``m`` vectors with non-negative entries that sum to 1.
    """

    # Absolute tolerance on the entry sum in ``contains``. Samples are drawn in
    # float64 and cast to float32, so their sum is rarely exactly 1.0.
    _SUM_ATOL = 1e-6

    def __init__(self, m):
        """
        Initializes the SimplexSpace.

        Parameters:
            - m (int): Number of outcomes.

        Raises:
            - ValueError: If m < 2.
        """
        # Raise instead of assert so the check survives `python -O`.
        if m < 2:
            raise ValueError("Number of outcomes 'm' must be at least 2.")
        super().__init__(shape=(m,), dtype=np.float32)
        self.m = m

    def sample(self):
        """
        Generates a random sample from the simplex space.

        Draws from a Dirichlet distribution with all concentrations equal to 1
        (uniform over the simplex) using NumPy's global RNG. The result is
        cast to the dtype declared for the space.
        """
        sample = np.random.dirichlet(np.ones(self.m))
        return sample.astype(self.dtype)

    def contains(self, x):
        """
        Checks if a given point is within the simplex space.

        A point is accepted if it has shape ``(m,)``, its entries are finite
        and non-negative, and they sum to 1 within ``_SUM_ATOL``. An exact
        ``== 1.0`` comparison would reject outputs of ``sample()`` because of
        floating-point rounding.

        Parameters:
            - x (array-like): Point to be checked.

        Returns:
            - bool: True if the point is within the space, False otherwise.
        """
        try:
            point = np.asarray(x, dtype=np.float64)
        except (TypeError, ValueError):
            return False
        if point.shape != (self.m,):
            return False
        if not np.all(np.isfinite(point)):
            return False
        if np.any(point < 0):
            return False
        return bool(abs(point.sum() - 1.0) <= self._SUM_ATOL)

    def seed(self, seed=None):
        """
        Seeds the pseudo-random number generator for this space.

        NOTE: this seeds NumPy's *global* RNG (the one ``sample`` draws from),
        so it also affects any other code that uses ``np.random``.

        Parameters:
            - seed (int or None): The seed to use. If None, a random seed will be chosen.

        Returns:
            - list[int]: The list of seeds used for seeding the PRNG.
        """
        if seed is None:
            # Explicit int64: on platforms whose default integer is 32-bit
            # (e.g. Windows with NumPy < 2), 2**32 - 1 is out of range.
            seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
        np.random.seed(seed)
        return [seed]
