In [None]:
"""
### okay so:
1. set up a way to continually play the game and save the networks
    play
    save
    train
    checkpoint
    
2. Think about network diversity

3. Think properly about the hyperparameters and search tree parameters

4. Think about how to stop other things interfering with the generation and training...

5. Read about Bayesian Q-learning and maybe get back into Bayesian methods a bit more in general
    it feels as though Bayesian learning and search trees should go well together, measuring uncertainty as well
    as value

6. build a model loader and save game loader separately

7. after the latest stuff is all tested, move away from using chrome so it can be fully shut down!

8. think carefully about a nice structure for the code, in terms of where to sit each function...

"""


# okay so first thing is to improve the saving, build a model loader, build a trainer script, build a save game ifier

# Go through and clean all code and make it NICE to work with!


In [1]:
import random
import matplotlib.pyplot as plt
import datetime
import os
from pathlib import Path

import numpy as np
import copy
from importlib import reload
import torch
import ray
import pickle

import constants
reload(constants)

import game
reload(game)

import plotting
reload(plotting)

import mcts.mcts
reload(mcts.mcts)

import mcts.networks
reload(mcts.networks)

import mcts.agent
reload(mcts.agent)

import augmentor
reload(augmentor)

import raygent
reload(raygent)

from game import Patterns
from plotting import PatternPlotter

from mcts.mcts import Tree, Node
from mcts.networks import PatternsNet
from mcts.agent import Agent

from augmentor import StateAugmentor

from raygent import AgentWorker

# One shared seed for the three global RNGs used in this notebook.
rseed = 12387623

random.seed(rseed)
torch.manual_seed(rseed)
np.random.seed(rseed)

# Use the first CUDA device when PyTorch reports one, otherwise the CPU.
my_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [2]:
# Start (or attach to) the local Ray instance. ignore_reinit_error makes the
# cell safe to re-run on a live kernel instead of raising because Ray is
# already initialised.
ray.init(ignore_reinit_error=True)

2025-07-23 16:42:39,761	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


0,1
Python version:,3.11.5
Ray version:,2.42.1
Dashboard:,http://127.0.0.1:8265


In [3]:
### load the model:
network_path = os.path.join(Path.cwd(), 'saved_networks', 'v4', '2025_07_11_15_25.pt')
my_network = PatternsNet(in_channels=47, out_channels=128)

# map_location="cpu" lets a checkpoint that was saved from a CUDA device be
# read on a machine without a GPU; load_state_dict then copies the values into
# the network's existing parameters wherever they live.
my_network.load_state_dict(
    torch.load(network_path, map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [6]:
### training loop: how long to keep up the training iterations
### (1 validation, 15*30k games):
NUM_IT = 1

### agent parameters:
NUM_TREES = 50
TARGET_GAMES = 50

### exploration schedule, to control number of searches with depth of tree:
SCHEDULE = [
    # explore randomly for the first moves:
    (0, 0),
    # limited search to avoid worst moves:
    (1, 25),
    # deeper search for deeper games:
    (10, 500),
]

### number of moves from terminal to save with:
SAVE_DEPTH = 5

### breadth-search parameters to favor narrower, deeper trees even early on:
# 6 best moves according to the puct scores:
TOPN = 6
# 4 additional random moves to favor exploration:
RANDM = 4

### how the next move is chosen. If Temp is not None, sampled from visit count distribution
SELECTION_TEMP = 1.0

### everything handed to each agent, gathered in one dict:
agent_config = dict(
    # agent_id="agent",
    network=my_network,
    device=my_device,
    num_trees=NUM_TREES,
    target_games=TARGET_GAMES,
    selection_temperature=SELECTION_TEMP,
    restrict_topn=TOPN,
    restrict_randm=RANDM,
    save_depth=SAVE_DEPTH,
    explore_steps_schedule=SCHEDULE,
    debug=False,
)

### number of agents running as separate python processes at once:
NUM_RAYGENTS = 3


In [7]:
### start the ray processes:
def run_raygents(agent_config, num_raygents=None):
    """Start several AgentWorker actors, run them in parallel and wait for all.

    Every worker receives its own shallow copy of ``agent_config`` with a
    unique ``"agent_id"`` (``raygent_0``, ``raygent_1``, ...), so the caller's
    dict is left untouched.

    Args:
        agent_config: dict of agent settings passed to ``AgentWorker.remote``.
        num_raygents: how many workers to start. When omitted, the
            notebook-level ``NUM_RAYGENTS`` is used.

    Returns:
        A list with one object ref per worker's ``run`` call, in worker order.
        The function only returns after ``ray.get`` has waited on all of them,
        so the caller can resolve the refs with ``ray.get`` to obtain each
        worker's result.
    """
    if num_raygents is None:
        num_raygents = NUM_RAYGENTS

    futures = []
    for worker_idx in range(num_raygents):
        worker_config = agent_config.copy()
        worker_config["agent_id"] = f"raygent_{worker_idx}"
        worker = AgentWorker.remote(worker_config)
        futures.append(worker.run.remote())

    # Block until every worker has finished. The fetched values are not kept
    # here; callers resolve the returned refs themselves.
    ray.get(futures)
    print("All agents completed!")
    return futures



In [8]:
%prun run_raygents(agent_config)

[36m(AgentWorker pid=63532)[0m Generating initial games:
[36m(AgentWorker pid=63532)[0m Evaluating tensor states...
[36m(AgentWorker pid=63532)[0m Provisioning inference to root nodes...
[36m(AgentWorker pid=63532)[0m 1 games have been completed!
[36m(AgentWorker pid=63532)[0m 2 games have been completed!
[36m(AgentWorker pid=63532)[0m 3 games have been completed!
[36m(AgentWorker pid=63532)[0m 4 games have been completed!
[36m(AgentWorker pid=63532)[0m 5 games have been completed!
[36m(AgentWorker pid=63532)[0m 6 games have been completed!
[36m(AgentWorker pid=63532)[0m 7 games have been completed!
[36m(AgentWorker pid=63532)[0m 8 games have been completed!
[36m(AgentWorker pid=63532)[0m 9 games have been completed!
[36m(AgentWorker pid=62028)[0m Generating initial games:[32m [repeated 2x across cluster][0m
[36m(AgentWorker pid=62028)[0m Evaluating tensor states...[32m [repeated 2x across cluster][0m
[36m(AgentWorker pid=62028)[0m Provisioning inferen

         153762 function calls (150771 primitive calls) in 79.224 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1   77.075   77.075   77.075   77.075 {method 'get_objects' of 'ray._raylet.CoreWorker' objects}
    627/6    1.518    0.002    1.895    0.316 {function Pickler.dump at 0x000001B4B5AFAAC0}
      621    0.224    0.000    0.224    0.000 {method '_write_file' of 'torch._C.StorageBase' objects}
        3    0.158    0.053    2.113    0.704 {method 'create_actor' of 'ray._raylet.CoreWorker' objects}
        3    0.058    0.019    1.954    0.651 serialization.py:471(_serialize_to_pickle5)
      621    0.018    0.000    0.034    0.000 _tensor.py:257(_reduce_ex_internal)
     7662    0.016    0.000    0.019    0.000 __init__.py:1000(__getitem__)
     5634    0.012    0.000    0.018    0.000 copyreg.py:113(_slotnames)
     4347    0.010    0.000    0.023    0.000 serialization.py:880(persistent_id)
      621    0

[36m(AgentWorker pid=62028)[0m 50 games have been completed!


In [None]:
# Length of the first entry of all_completed_games (presumably the first
# worker's result); relies on all_completed_games being bound at notebook scope
# by an earlier cell.
len(all_completed_games[0])

In [None]:
# Add up, for every entry of all_completed_games, the lengths of its elements
# at positions 0, -1 and 1 (in that order).
total_tdata = 0

for _games in all_completed_games:
    total_tdata += sum(len(_games[_pos]) for _pos in (0, -1, 1))

total_tdata

In [None]:
# Show the agent configuration, one "key: value" pair per line.
# (`f"{**agent_config}"` is not valid Python: dict unpacking cannot appear
# inside an f-string replacement field.)
print("\n".join(f"{key}: {value}" for key, value in agent_config.items()))

In [None]:
# Scratch check: fall back to 0 when a value is None (or otherwise falsy).
rstr1 = None
rstr2 = 1

my_rstr1 = rstr1 if rstr1 else 0
my_rstr2 = rstr2 if rstr2 else 0

(my_rstr1, my_rstr2)

In [None]:
# Scratch check: two independent arrays stacked into a single tensor.
jim = np.random.random((3, 4, 5))

# A separate copy, so edits to one array do not show up in the other.
james = jim.copy()

james[0] = james[0] * 1000
jim[1] = jim[1] - 1000

jim, james

# Stacking adds a leading axis of size 2 in front of the (3, 4, 5) shape.
torch.from_numpy(np.stack((jim, james))).shape

In [None]:
# Shared input for the two timing candidates below.
my_vals = np.random.rand(1000)


def f1(n):
    """Running totals of ``my_vals * 10`` from last to first, via a ``[::-1]`` slice.

    ``n`` is accepted but not used.
    """
    running_totals = []
    total = 0
    for value in my_vals[::-1]:
        total = total + value * 10
        running_totals.append(total)
    return running_totals


def f2(n):
    """Running totals of ``my_vals * 10`` from last to first, via ``reversed()``.

    ``n`` is accepted but not used.
    """
    running_totals = []
    total = 0
    for value in reversed(my_vals):
        total = total + value * 10
        running_totals.append(total)
    return running_totals


In [None]:
# Display res1. NOTE(review): among the visible cells, res1 is only assigned by
# the `%prun res1 = f1(10000)` cell further down, so that cell has to run first.
res1

In [None]:
# Profile f2 (the reversed() variant) and keep its result in res2; f2 does not
# use its argument.
%prun res2 = f2(100)

In [None]:
# Profile f1 (the [::-1] slice variant) and keep its result in res1; f1 does
# not use its argument.
%prun res1 = f1(10000)

In [None]:
# Scratch check: use a one-entry default schedule when none is supplied.
schedule = None
bill = schedule if schedule else [(0, 100)]

In [None]:
# Display the schedule chosen in the previous cell.
bill
