In [2]:
# basics
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time
import pickle

# multiprocessing
import torch.multiprocessing as mp
from functools import partial

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
import torch_directml

# custom
from scripts.classes import Botzee
from scripts.functions import play_botzee, model_pick_dice, model_pick_score, reinforce_by_turn

In [None]:
# Load the saved constructor arguments for the base Botzee model.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load pickles that come from a trusted source.
with open('models/botzee_base_args.pkl', 'rb') as f:
    botzee_base_args = pickle.load(f)

# Build a Botzee instance, forwarding each loaded entry as the keyword
# argument of the same name.
botzee = Botzee(
    **{
        arg_name: botzee_base_args[arg_name]
        for arg_name in (
            'input_sizes',
            'lstm_sizes',
            'dice_output_size',
            'score_output_size',
            'masks',
        )
    }
)

In [1]:
import gymnasium as gym

In [3]:
class YahtzeeEnv(gym.Env):
    """Toy dice environment: five six-sided dice with a per-die re-roll action.

    Observation: integer array of five die faces, each in 1..6.
    Action: five binary flags; a 1 re-rolls the die at that position.
    Reward: sum of the five faces after the re-roll has been applied.

    ``step`` returns the four-tuple ``(obs, reward, done, info)`` and ``reset``
    returns a bare observation. These shapes are kept on purpose so code that
    already unpacks them keeps working; they differ from Gymnasium's standard
    five-tuple / ``(obs, info)`` forms.
    """

    # Gymnasium reads the supported render modes from the 'render_modes' key.
    metadata = {'render_modes': ['human']}

    NUM_DICE = 5   # dice tracked in the state
    NUM_FACES = 6  # faces per die, numbered 1..NUM_FACES

    def __init__(self):
        """Declare the observation/action spaces; the state stays unset until ``reset``."""
        super().__init__()

        # Die faces run 1..6. MultiDiscrete is 0-based by default, so shift it
        # with `start` (needs a Gymnasium release that supports it, 0.29+) so
        # that every state the environment produces lies inside the space.
        self.observation_space = gym.spaces.MultiDiscrete(
            [self.NUM_FACES] * self.NUM_DICE,
            start=[1] * self.NUM_DICE,
        )

        # One binary flag per die: 1 = re-roll, 0 = keep.
        self.action_space = gym.spaces.MultiBinary(self.NUM_DICE)

        # Current dice; populated by reset().
        self.state = None

    def step(self, action):
        """Apply one re-roll action.

        Args:
            action: length-5 sequence of 0/1 flags; dice flagged 1 are re-rolled.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is a copy of the dice
            after the re-roll, ``reward`` is their sum, ``done`` is always
            ``False`` (no termination condition is defined yet) and ``info``
            is an empty dict.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self.state is None:
            raise RuntimeError("YahtzeeEnv.step() called before reset()")

        self._take_action(action)

        done = False  # placeholder: no termination condition defined yet
        reward = self._get_reward()
        info = {}

        # Return a copy: the internal array is mutated in place by later steps,
        # which would otherwise silently rewrite observations callers kept.
        return self.state.copy(), reward, done, info

    def reset(self, *, seed=None, options=None):
        """Roll all five dice and return the new observation.

        Args:
            seed: optional seed for the environment's random generator.
            options: accepted for Gymnasium signature compatibility; unused.

        Returns:
            A copy of the freshly rolled dice (faces 1..6).
        """
        # Seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        self.state = self._roll(self.NUM_DICE)
        return self.state.copy()

    def render(self, mode='human', close=False):
        """Print the current dice to stdout (``mode`` and ``close`` are unused)."""
        print(f"Current state: {self.state}")

    def _roll(self, count):
        """Return ``count`` fresh die faces (1..6) from the env's own generator."""
        return self.np_random.integers(
            1, self.NUM_FACES + 1, size=count, dtype=np.int64
        )

    def _take_action(self, action):
        """Re-roll, in place, every die whose action flag equals 1."""
        reroll_mask = np.asarray(action) == 1
        self.state[reroll_mask] = self._roll(int(reroll_mask.sum()))

    def _get_reward(self):
        """Reward is the sum of the current die faces."""
        return np.sum(self.state)

In [None]:
class BotzeeRoll(nn.Module):
    """Small feed-forward network: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the last dimension of the input tensor.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in the same order they are applied in forward().
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Map ``x`` through the hidden layer and return the raw fc2 output."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [4]:
# Instantiate the toy dice environment; its state stays None until reset() is called.
env = YahtzeeEnv()

In [6]:
# Reset the environment to get a starting observation, then display it
obs = env.reset()
print(f"Initial Observation: {obs}")

Initial Observation: [3 4 5 5 5]


In [10]:
# Print the environment's current state to stdout.
env.render()

Current state: [3 5 1 4 3]


In [9]:
# Take up to ten randomly sampled actions, printing each transition.
for _ in range(10):
    action = env.action_space.sample()  # random action drawn from the action space
    obs, reward, done, info = env.step(action)
    print(f'Observation: {obs}\nReward: {reward}\nDone: {done}\nInfo: {info}\n\n')

    # Stop early if the environment reports that it is done.
    if done:
        break

Observation: [4 2 6 4 2]
Reward: 18
Done: False
Info: {}


Observation: [5 2 6 4 2]
Reward: 19
Done: False
Info: {}


Observation: [5 2 6 4 5]
Reward: 22
Done: False
Info: {}


Observation: [3 6 5 1 3]
Reward: 18
Done: False
Info: {}


Observation: [3 6 5 1 3]
Reward: 18
Done: False
Info: {}


Observation: [3 6 3 1 2]
Reward: 15
Done: False
Info: {}


Observation: [3 2 6 1 6]
Reward: 18
Done: False
Info: {}


Observation: [3 5 6 1 4]
Reward: 19
Done: False
Info: {}


Observation: [3 5 6 3 4]
Reward: 21
Done: False
Info: {}


Observation: [3 5 1 4 3]
Reward: 16
Done: False
Info: {}


