Evaluating
=========

Loads weights in the network and plays against random and simple bots

In [36]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
from nets.basic import ConnNet 
from kaggle_environments import make
from kaggle_environments.envs.connectx.connectx import random_agent, negamax_agent
from decider.mcts import MCTS
import torch 

In [42]:
net = ConnNet(cols=5, rows=4)
weights = torch.load('points/test4x5_too/500_step.pth')
net.load_state_dict(weights)
net.eval()
pass

In [43]:
env = make("connectx", debug=True, configuration={'rows':4,'columns': 5,'inarow':3})
tret = MCTS(model=net, env=env)
res = env.render(mode="ansi")
print(res)

+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+



In [45]:
class SmartPlayer:
    def __init__(self, weights, env):
        self.net = ConnNet(cols=5, rows=4)
        self.net.load_state_dict(weights)
        self.net.eval()
        
    def __call__(self, obs, config):
        state = self.obs_to_state(obs)
        temp_env = make("connectx", debug=False, configuration=config, state = state)
        with torch.no_grad():
            res = self.net.get_probs(obs['board'], 1)
            print(res)
        monte_carlo = MCTS(self.net, temp_env)
        for i in range(3):
            monte_carlo.playout()
        print(monte_carlo.get_move_probs())
        return max(monte_carlo.get_move_probs())
    
    @staticmethod
    def obs_to_state(obs):
        return {
            'action': 1,
            'reward': 0,
            'observation': obs,
            'status': 'ACTIVE'
        } 
        
sps = SmartPlayer(weights, env)
env.run([sps,negamax_agent])
res = env.render(mode="ansi")
print(res)

(tensor([1.9015e-05, 9.9976e-01, 2.1836e-04, 1.9036e-07, 1.5047e-07]), tensor([[-0.5366]]))
{0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}
(tensor([1.3312e-02, 9.5864e-02, 1.1516e-02, 8.7930e-01, 8.6383e-06]), tensor([[-0.6580]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([1.5663e-09, 4.2453e-12, 1.6608e-11, 1.0000e+00, 2.0131e-18]), tensor([[-0.9135]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([7.8885e-14, 2.2501e-20, 1.1440e-20, 1.0000e+00, 0.0000e+00]), tensor([[-0.2875]]))
{0: 0.5, 1: 0.0, 2: 0.0, 3: 0.5}
+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 1 |
+---+---+---+---+---+
| 0 | 0 | 0 | 1 | 2 |
+---+---+---+---+---+
| 0 | 0 | 0 | 2 | 1 |
+---+---+---+---+---+
| 0 | 0 | 2 | 2 | 1 |
+---+---+---+---+---+



In [17]:
env.step([1,1])

FailedPrecondition: Environment done, reset required.

In [118]:
env.state

[{'action': 1,
  'reward': 0,
  'info': {},
  'observation': {'remainingOverageTime': 60,
   'step': 2,
   'board': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0],
   'mark': 1},
  'status': 'ACTIVE'},
 {'action': 1,
  'reward': 0,
  'info': {},
  'observation': {'remainingOverageTime': 60, 'mark': 2},
  'status': 'INACTIVE'}]

In [127]:
env1 = make("connectx", debug=True, configuration={'rows':4,'columns': 5,'inarow':3},state=env.state[0])

In [128]:
# env.reset()
# env.run([sps, 'random'])
res = env1.render(mode="ansi")
print(res)

+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+
| 0 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+
| 0 | 2 | 0 | 0 | 0 |
+---+---+---+---+---+
| 0 | 1 | 0 | 0 | 0 |
+---+---+---+---+---+



In [None]:
tret

In [47]:
class Arena:
    def __init__(self):
        self.env = make("connectx", debug=True, configuration={'rows':4,'columns': 5,'inarow':3})
        
    def play_n_games(self, n:int, bot1, bot2):
        wins = 0
        draws = 0
        for i in range(n):
            self.env.reset()
            self.env.run([bot1,bot2])
            reward = self.env.state[0]['reward']
            wins+=reward
            draws+=(reward==0)
            
        
        print(f'Bot1 winrate is {wins/n} draws: {draws}')
            

In [49]:
cage = Arena()
cage.play_n_games(40, sps, negamax_agent)

(tensor([1.9015e-05, 9.9976e-01, 2.1836e-04, 1.9036e-07, 1.5047e-07]), tensor([[-0.5366]]))
{0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}
(tensor([5.6249e-04, 1.3605e-02, 9.8570e-01, 5.3445e-05, 7.5226e-05]), tensor([[-0.9415]]))
{0: 0.0, 1: 0.0, 2: 1.0, 3: 0.0, 4: 0.0}
(tensor([5.9785e-08, 2.2884e-07, 2.8252e-07, 1.0000e+00, 7.3633e-12]), tensor([[-0.8975]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([3.4095e-09, 4.3218e-08, 1.2518e-07, 1.0000e+00, 0.0000e+00]), tensor([[0.9861]]))
{0: 0.0, 1: 0.0, 2: 0.5, 3: 0.5}
(tensor([1.9015e-05, 9.9976e-01, 2.1836e-04, 1.9036e-07, 1.5047e-07]), tensor([[-0.5366]]))
{0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}
(tensor([5.6249e-04, 1.3605e-02, 9.8570e-01, 5.3445e-05, 7.5226e-05]), tensor([[-0.9415]]))
{0: 0.0, 1: 0.0, 2: 1.0, 3: 0.0, 4: 0.0}
(tensor([5.9785e-08, 2.2884e-07, 2.8252e-07, 1.0000e+00, 7.3633e-12]), tensor([[-0.8975]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([3.4095e-09, 4.3218e-08, 1.2518e-07, 1.0000e+00, 0.0000e+00]), tensor(

(tensor([1.3312e-02, 9.5864e-02, 1.1516e-02, 8.7930e-01, 8.6383e-06]), tensor([[-0.6580]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([1.5663e-09, 4.2453e-12, 1.6608e-11, 1.0000e+00, 2.0131e-18]), tensor([[-0.9135]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([1.2901e-09, 1.7699e-13, 1.0956e-10, 1.0000e+00, 0.0000e+00]), tensor([[-0.9116]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0}
(tensor([9.9494e-01, 1.2603e-09, 4.2251e-03, 8.3838e-04, 0.0000e+00]), tensor([[-0.9540]]))
{0: 0.5, 1: 0.0, 2: 0.5, 3: 0.0}
(tensor([1.9015e-05, 9.9976e-01, 2.1836e-04, 1.9036e-07, 1.5047e-07]), tensor([[-0.5366]]))
{0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}
(tensor([1.3312e-02, 9.5864e-02, 1.1516e-02, 8.7930e-01, 8.6383e-06]), tensor([[-0.6580]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([1.5663e-09, 4.2453e-12, 1.6608e-11, 1.0000e+00, 2.0131e-18]), tensor([[-0.9135]]))
{0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 0.0}
(tensor([5.6812e-09, 8.2967e-11, 3.9302e-06, 1.0000e+00, 0.0000e+00]), tensor([[0.815

(tensor([3.4095e-09, 4.3218e-08, 1.2518e-07, 1.0000e+00, 0.0000e+00]), tensor([[0.9861]]))
{0: 0.0, 1: 0.0, 2: 0.5, 3: 0.5}
(tensor([1.9015e-05, 9.9976e-01, 2.1836e-04, 1.9036e-07, 1.5047e-07]), tensor([[-0.5366]]))
{0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}
(tensor([5.6249e-04, 1.3605e-02, 9.8570e-01, 5.3445e-05, 7.5226e-05]), tensor([[-0.9415]]))
{0: 0.0, 1: 0.0, 2: 1.0, 3: 0.0, 4: 0.0}
(tensor([0.0022, 0.9515, 0.0153, 0.0227, 0.0083]), tensor([[0.9964]]))
{0: 0.0, 1: 0.5, 2: 0.0, 3: 0.5, 4: 0.0}
(tensor([6.4459e-09, 5.1834e-08, 1.1124e-05, 9.9999e-01, 0.0000e+00]), tensor([[0.9909]]))
{0: 0.0, 1: 0.0, 2: 0.5, 3: 0.5}
(tensor([1.9015e-05, 9.9976e-01, 2.1836e-04, 1.9036e-07, 1.5047e-07]), tensor([[-0.5366]]))
{0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}
(tensor([5.6249e-04, 1.3605e-02, 9.8570e-01, 5.3445e-05, 7.5226e-05]), tensor([[-0.9415]]))
{0: 0.0, 1: 0.0, 2: 1.0, 3: 0.0, 4: 0.0}
(tensor([5.9785e-08, 2.2884e-07, 2.8252e-07, 1.0000e+00, 7.3633e-12]), tensor([[-0.8975]]))
{0: 0.0, 1: 0.0,