In [9]:
from spinup.algos.pytorch.sac.sac import sac
from spinup.algos.pytorch.ppo.ppo import ppo
from spinup.algos.pytorch.ddpg.ddpg import ddpg
from spinup.algos.pytorch.td3.td3 import td3

In [2]:
import numpy as np
import gym
import pooltool as pt

In [3]:
import math

In [4]:
import torch
reg_net = torch.load('reg_net.pt')

In [5]:
n_balls = 2
labels = ['cue', '1', '2', '3', '4', '5', '6', '7', '8', '9']
table = pt.PocketTable(model_name="7_foot")
pockets = table.get_pockets()

def random_ball():
    a = np.random.rand(2)
    a[0] = (0.93)*a[0]+0.03
    a[1] = (1.92)*a[1]+0.03
    return a

def generate_balls():
    points = []
    points.append(random_ball())
    
    for _ in range(n_balls):
        close = True
        while close:
            test = random_ball()
            close = False
            for point in points:
                if np.square(point - test).sum() < 0.06:
                    close = True
                    break
        points.append(test)

    balls = {}
    
    for i in range(n_balls+1):
        balls[labels[i]] = pt.Ball(labels[i], xyz=points[i])

    return balls

def balls_to_obs(balls):
    ar = tuple(x.rvw[0,:2] for x in balls.values())
    return np.hstack(ar)


def shoot(balls, speed, vspin, hspin):

    cue = pt.Cue(cueing_ball=balls['cue'])
    cue.aim_for_best_pocket(balls['1'], pockets.values())

    cue.strike(V0=speed, b=vspin, a=hspin)
    shot = pt.System(cue=cue, table=table, balls=balls)

    try:
        shot.simulate(continuize=True)
    except:
        pass
                
    return balls

In [6]:
from gym import spaces
import numpy as np


class Pool(gym.Env):
    #metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, size=5):
        self.size = size  # The size of the square grid
        self.window_size = 512  # The size of the PyGame window

        # Observations are dictionaries with the agent's and the target's location.
        # Each location is encoded as an element of {0, ..., `size`}^2, i.e. MultiDiscrete([size, size]).
        self.observation_space = spaces.Box(low=np.tile(np.array([0.03],dtype=np.float32),n_balls*2+2),
                                            high=np.tile(np.array([0.96, 1.95],dtype=np.float32), n_balls+1),
                                            dtype=np.float32)

        # 3 dimensions corresponding to speed, vspin, hspin
        self.action_space = spaces.Box(low=np.array([-1, -1, -1],dtype=np.float32),
                                       high=np.array([1, 1, 1],dtype=np.float32), dtype=np.float32)


        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
    
    
    def reset(self, seed=None, options=None):
            
        # Generate stargint layout of balls
        balls = generate_balls()
        self.balls = balls

        observation = balls_to_obs(balls)

        return observation
    
    def step(self, action):
        
        balls = self.balls
        
        action = action*np.array([3,0.4,0.3])+np.array([3,0,0])
                
        balls = shoot(balls, action[0], action[1], action[2])
        
        on_table = [key for key, value in balls.items() if value.rvw[0,2] > 0]
        
        if 'cue' not in on_table:
            reward = -4
        elif '1' in on_table:
            reward = 0
        elif '2' in on_table:
            pred = reg_net(torch.as_tensor(np.hstack((balls['cue'].rvw[0,:2],
                                                      balls['2'].rvw[0,:2])), dtype=torch.float32))
            reward = math.exp(pred.item()*n_balls)
        else:
            reward = 0
        
        terminated = True

        observation = balls_to_obs(balls)
        
        return observation, reward, terminated, False

In [13]:
td3(Pool)

[32;1mLogging data to /tmp/experiments/1677945030/progress.txt[0m
[36;1mSaving config:
[0m
{
    "ac_kwargs":	{},
    "act_noise":	0.1,
    "actor_critic":	"MLPActorCritic",
    "batch_size":	100,
    "env_fn":	"Pool",
    "epochs":	100,
    "gamma":	0.99,
    "logger":	{
        "<spinup.utils.logx.EpochLogger object at 0x2c89f2ec0>":	{
            "epoch_dict":	{},
            "exp_name":	null,
            "first_row":	true,
            "log_current_row":	{},
            "log_headers":	[],
            "output_dir":	"/tmp/experiments/1677945030",
            "output_file":	{
                "<_io.TextIOWrapper name='/tmp/experiments/1677945030/progress.txt' mode='w' encoding='UTF-8'>":	{
                    "mode":	"w"
                }
            }
        }
    },
    "logger_kwargs":	{},
    "max_ep_len":	1000,
    "noise_clip":	0.5,
    "num_test_episodes":	10,
    "pi_lr":	0.001,
    "policy_delay":	2,
    "polyak":	0.995,
    "q_lr":	0.001,
    "replay_size":	1000000,
    "

In [None]:
6

In [None]:
sac(Pool, epochs = 100, alpha = 0)

[32;1mLogging data to /tmp/experiments/1678142861/progress.txt[0m
[36;1mSaving config:
[0m
{
    "ac_kwargs":	{},
    "actor_critic":	"MLPActorCritic",
    "alpha":	0,
    "batch_size":	100,
    "env_fn":	"Pool",
    "epochs":	100,
    "gamma":	0.99,
    "logger":	{
        "<spinup.utils.logx.EpochLogger object at 0x2c8b0ae90>":	{
            "epoch_dict":	{},
            "exp_name":	null,
            "first_row":	true,
            "log_current_row":	{},
            "log_headers":	[],
            "output_dir":	"/tmp/experiments/1678142861",
            "output_file":	{
                "<_io.TextIOWrapper name='/tmp/experiments/1678142861/progress.txt' mode='w' encoding='UTF-8'>":	{
                    "mode":	"w"
                }
            }
        }
    },
    "logger_kwargs":	{},
    "lr":	0.001,
    "max_ep_len":	1000,
    "num_test_episodes":	10,
    "polyak":	0.995,
    "replay_size":	1000000,
    "save_freq":	1,
    "seed":	0,
    "start_steps":	10000,
    "steps_per_epo

In [19]:
torch.save(ddpg_net, 'ddpg_net.pt')

In [28]:
td3

<function spinup.algos.pytorch.td3.td3.td3(env_fn, actor_critic=<class 'spinup.algos.pytorch.td3.core.MLPActorCritic'>, ac_kwargs={}, seed=0, steps_per_epoch=4000, epochs=100, replay_size=1000000, gamma=0.99, polyak=0.995, pi_lr=0.001, q_lr=0.001, batch_size=100, start_steps=10000, update_after=1000, update_every=50, act_noise=0.1, target_noise=0.2, noise_clip=0.5, policy_delay=2, num_test_episodes=10, max_ep_len=1000, logger_kwargs={}, save_freq=1)>