Based on

- https://github.com/pytorch/examples/blob/87d9a1e930b5b813/reinforcement_learning/reinforce.py
- https://medium.com/@ts1829/policy-gradient-reinforcement-learning-in-pytorch-df1383ea0baf [Notebook](https://nbviewer.jupyter.org/urls/gist.githubusercontent.com/ts1829/ebbe2cf946bf36951b724818c52e36b9/raw/4da449bffe9835e201f2fb34f381fbb53568d1ca/Policy%20Gradient%20with%20Cartpole%20and%20PyTorch%20%28Medium%20Version%29.ipynb)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import count
from collections import namedtuple
from tqdm import tqdm, trange
import os
import json
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
%matplotlib inline

In [72]:
import random
from tuplestate import *
from gamestate import *
from benchmarking import random_solved_endgame
from vectorize import *
# Seed Python's global RNG so the start position below is repeatable
# (assumes random_solved_endgame draws from `random` -- TODO confirm).
random.seed(0)
# Initial game state used by the cells below.
# NOTE(review): the integer argument presumably controls how far the position
# is from being solved -- verify against benchmarking.random_solved_endgame.
klonstate = random_solved_endgame(19)
# print(to_pretty_string(klonstate))

In [3]:
# Immutable record holding the run configuration.
Args = namedtuple('Args', ['gamma', 'lr', 'seed', 'render', 'log_interval'])

args = Args(**{
    # Discount factor applied to future rewards when returns are accumulated.
    'gamma': 0.01,
    # Learning rate. NOTE: the optimizer in this notebook is created with its
    # own hard-coded lr, so this value is not what SGD actually uses.
    'lr': 5e-2,
    # The remaining fields are not read by any of the cells shown here.
    'seed': 1,
    'render': False,
    'log_interval': 20,
})

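- In features: flattened game-state vector of size $233 \times 104 = 24232$

- Out features: one score per move code, a vector of size $623$; entries for illegal moves are masked out before an action is sampled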

In [4]:
# Length of the flattened game-state vector fed to the network.
IN = 233*104
# Number of entries in the move vector scored by the network.
OUT = 623


class Policy(nn.Module):
    """Two-layer perceptron that maps a state vector to move probabilities.

    Besides the network weights, the module carries two per-episode buffers
    that the training code appends to and clears:

    * ``saved_log_probs`` -- log-probability of each sampled action
    * ``rewards`` -- reward received after each action
    """

    def __init__(self):
        super().__init__()
        self.affine1 = nn.Linear(IN, 800)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(800, OUT)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return a probability distribution over the ``OUT`` move slots.

        Args:
            x: float tensor whose last dimension has size ``IN``; either a
                single state ``(IN,)`` or a batch ``(N, IN)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``OUT``; values along that dimension sum to 1.
        """
        hidden = self.dropout(F.relu(self.affine1(x)))
        action_scores = self.affine2(hidden)
        # Normalise over the feature axis. dim=-1 equals the former dim=1 for
        # batched 2-D input and additionally supports an unbatched 1-D state.
        return F.softmax(action_scores, dim=-1)

In [16]:
def select_action(klonstate):
    """Sample a move index for ``klonstate`` from the masked policy output.

    The network's probabilities are multiplied element-wise by the legal-move
    mask and an action is drawn from the result. The log-probability of the
    drawn action is appended to ``policy.saved_log_probs`` for the later
    policy-gradient update.

    Returns:
        The sampled action index as a plain Python int.
    """
    state_vec = state_to_vec(klonstate)
    movefilter = vector_legal_moves(klonstate)

    net_input = torch.from_numpy(state_vec).float().reshape(-1).unsqueeze(0)
    legal_mask = torch.from_numpy(movefilter.astype(np.float32)).unsqueeze(0)
    masked_probs = policy(net_input) * legal_mask

    if (masked_probs != 0).any():
        # Categorical renormalises `probs` so they sum to 1.
        dist = Categorical(masked_probs)
    else:
        # Every masked probability is zero: fall back to sampling from the
        # mask itself, i.e. uniformly over the legal moves. The mask is marked
        # as requiring grad so the stored log-prob is attached to a graph.
        legal_mask.requires_grad_()
        dist = Categorical(legal_mask)

    chosen = dist.sample()
    policy.saved_log_probs.append(dist.log_prob(chosen))
    return chosen.item()

def finish_episode():
    """Run one REINFORCE update from the episode stored on ``policy``.

    Reads ``policy.rewards`` and ``policy.saved_log_probs``, computes the
    discounted return for every step (discount ``args.gamma``), minimises
    ``sum_t(-log_prob_t * return_t)`` with one ``optimizer`` step, and then
    empties both buffers.

    If either buffer is empty there is nothing to learn from; both buffers
    are cleared and no optimizer step is taken.
    """
    if not policy.saved_log_probs or not policy.rewards:
        # torch.cat() raises on an empty list, so bail out early.
        del policy.rewards[:]
        del policy.saved_log_probs[:]
        return

    # Accumulate returns back-to-front, then flip into chronological order
    # (append + reverse avoids the quadratic cost of insert(0, ...)).
    running_return = 0
    returns = []
    for r in reversed(policy.rewards):
        running_return = r + args.gamma * running_return
        returns.append(running_return)
    returns.reverse()
    returns = torch.tensor(returns)

    # zip() stops at the shorter sequence, so any surplus log-probs are ignored.
    policy_loss = [
        -log_prob * R for log_prob, R in zip(policy.saved_log_probs, returns)
    ]

    optimizer.zero_grad()
    # The per-step terms already carry the minus sign. Negating the sum again
    # (as the previous code did) makes gradient descent *lower* the
    # probability of rewarded actions.
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()

    del policy.rewards[:]
    del policy.saved_log_probs[:]
    
def step(curr_state, action_idx):
    """Play the move at ``action_idx`` and return ``(new_state, reward)``.

    Args:
        curr_state: game state the move is applied to.
        action_idx: index into ``all_moves`` selecting the move code to play.

    Returns:
        A tuple of the state after the move and the reward: 1 when the new
        state is a win or has all cards face up, otherwise 0.

    Raises:
        AssertionError: if the resulting state is not legal. Both states are
            printed first to help debugging.
    """
    move_code = all_moves[action_idx]
    new_state = play_move(curr_state, move_code)
    reward = 1 if (state_is_win(new_state) or all_cards_faceup(new_state)) else 0
    if not state_is_legal(new_state):
        print('\n got illegal state by playing move', move_code)
        print('prev state')
        print(to_pretty_string(curr_state))
        print('\nnew state')
        print(to_pretty_string(new_state))
        # Raise explicitly: a bare `assert` would be stripped under `python -O`
        # and let the illegal state leak to the caller; it also repeated the
        # legality check that has already failed.
        raise AssertionError('got illegal state')
    return new_state, reward

# Machine epsilon for float32 as a plain Python float.
# NOTE: not referenced by any of the cells shown in this notebook.
eps = float(np.finfo(np.float32).eps)

# Global network instance shared by select_action / finish_episode.
policy = Policy()
# Plain SGD over all network parameters.
# NOTE: the learning rate is hard-coded here; `args.lr` is not consulted.
optimizer = optim.SGD(params=policy.parameters(), lr=0.01)

In [17]:
# Training: 99 episodes (i_episode = 1..99), each on a freshly generated
# endgame and capped at 99 moves (t = 1..99).
for i_episode in range(1, 100):
    klonstate = random_solved_endgame(20)
    ep_reward = 0
    done = False
    for t in range(1, 100):
        action = select_action(klonstate)
        klonstate, reward = step(klonstate, action)
        # step() only hands out a positive reward on a terminal state.
        done = reward > 0
        policy.rewards.append(reward)
        ep_reward += reward
        if done:
            # Report the episode counter of *this* loop; the previous `{i}`
            # referred to a name that is not defined in this cell.
            print(f'{i_episode} done after {t:4} steps {ep_reward:.2f}')
            break
    # Update the policy from whatever was collected, solved or not.
    finish_episode()

done after    3 steps 1.00
done after   32 steps 1.00
done after    1 steps 1.00
done after    1 steps 1.00
done after   21 steps 1.00
done after    1 steps 1.00
done after    5 steps 1.00
done after    6 steps 1.00
done after   12 steps 1.00
done after    1 steps 1.00
done after   14 steps 1.00
done after    1 steps 1.00
done after    3 steps 1.00
done after    1 steps 1.00
done after    1 steps 1.00
done after    1 steps 1.00
done after   29 steps 1.00
done after   13 steps 1.00
done after   42 steps 1.00
done after    1 steps 1.00
done after    9 steps 1.00
done after   45 steps 1.00
done after    1 steps 1.00
done after    2 steps 1.00
done after   19 steps 1.00
done after    1 steps 1.00
done after   21 steps 1.00
done after   10 steps 1.00
done after    3 steps 1.00
done after   99 steps 1.00
done after   30 steps 1.00
done after   11 steps 1.00
done after    9 steps 1.00
done after   14 steps 1.00
done after    8 steps 1.00
done after    1 steps 1.00
done after   17 steps 1.00
d

In [97]:
# Evaluation: let the current policy play one fresh endgame for at most
# 500 moves and report whether it reached a rewarded (solved) state.
klonstate = random_solved_endgame(20)
print(to_pretty_string(klonstate))
print()

done = False
visited = set()
# Evaluate with dropout disabled, then put the module back into whatever
# mode it was in so a later training cell is unaffected.
was_training = policy.training
policy.eval()
try:
    # No update is performed here, so no autograd graph needs to be kept.
    with torch.no_grad():
        for i in range(500):
            visited.add(klonstate)
            action = select_action(klonstate)
#             print(f'iteration {i}  action {all_moves[action]}')
            klonstate, reward = step(klonstate, action)
            done = reward > 0
            if done:
                break
finally:
    policy.train(was_training)
    # select_action() stores one log-prob per call. Drop them, otherwise the
    # next finish_episode() would pair these stale entries with the returns
    # of an unrelated training episode.
    del policy.saved_log_probs[:]

print()
if done:
    print("Solved!")
else:
    print(f'Not solved after {i+1} iterations')

print(to_pretty_string(klonstate))

Stock: JH
Waste: QD TS QC KD 9H JS
Fnd C: AC 2C 3C 4C 5C 6C 7C 8C 9C
Fnd D: AD 2D 3D 4D 5D 6D 7D 8D 9D TD
Fnd S: AS 2S 3S 4S 5S 6S 7S 8S 9S
Fnd H: AH 2H 3H 4H 5H 6H 7H 8H
Tab 1: TH
Tab 2: KH
Tab 3: QS JD TC
Tab 4: KC
Tab 5: 
Tab 6: KS
Tab 7: qh JC


Not solved after 500 iterations
Stock: JS 9H KD
Waste: QD TS QC
Fnd C: AC 2C 3C 4C 5C 6C 7C 8C 9C
Fnd D: AD 2D 3D 4D 5D 6D 7D 8D 9D TD JD
Fnd S: AS 2S 3S 4S 5S 6S 7S 8S 9S
Fnd H: AH 2H 3H 4H 5H 6H 7H 8H
Tab 1: 
Tab 2: KH QS JH TC
Tab 3: 
Tab 4: KC
Tab 5: 
Tab 6: KS
Tab 7: qh JC TH
