In [1]:
# Install the training requirements. Uncomment the ONE line that matches where
# the notebook is run from: the first uses paths starting with "../" (notebook in
# a sub-directory of the repo), the second uses paths starting with "./" (run from
# the repo root).

##%pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.

#%pip install --quiet -U pip -r ./requirements/requirements-train.txt ./.

In [2]:
import pickle
import jax
from hydra import compose, initialize

#export PATH="~/Documents/InstadeepGroupProject/jumanji_routing"
from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

  from neptune import new as neptune


In [12]:
import os
import sys
# Put the directory above this notebook on the import path so packages that live
# in the repository (e.g. routing_board_generation, imported below) resolve.
module_path = os.path.abspath(os.pardir)
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)
import numpy as np
import chex
import jax.numpy as jnp
import jumanji
from typing import Dict

from jumanji.environments import Connector
from jumanji.environments.routing.connector.types import Agent, Observation, State
from jumanji.types import TimeStep
from jumanji.environments.routing.connector.utils import get_position, get_target
from routing_board_generation.board_generation_methods.jax_implementation.board_generation.seed_extension import SeedExtensionBoard
from routing_board_generation.board_generation_methods.numpy_implementation.utils.utils import get_heads_and_targets
from routing_board_generation.interface.board_generator_interface import BoardGenerator, BoardName
from datetime import datetime
from copy import deepcopy

In [13]:
from routing_board_generation.board_generation_methods.numpy_implementation.utils.post_processor_utils_numpy import \
    training_board_from_solved_board, extend_wires, count_detours
from routing_board_generation.board_generation_methods.jax_implementation.utils.post_processor_utils_jax import \
    training_board_from_solved_board_jax, extend_wires_jax

In [14]:
import statistics
from routing_board_generation.benchmarking.benchmarks.empty_board_evaluation import EvaluateEmptyBoard

## Load configs

In [17]:
# Compose the Hydra config for the Connector environment with the A2C agent,
# using this project's configs (the stock ones live in ../jumanji/training/configs).
config_overrides = ["env=connector", "agent=a2c"]
with initialize(version_base=None, config_path="../agent_training/configs"):
    cfg = compose(config_name="config.yaml", overrides=config_overrides)
cfg

{'agent': 'a2c', 'seed': 0, 'logger': {'type': 'terminal', 'save_checkpoint': False, 'name': '${agent}_${env.name}'}, 'env': {'name': 'connector', 'registered_version': 'Connector-v0', 'ic_board': {'generation_type': 'offline_seed_extension', 'board_name': 'none', 'grid_size': 10, 'num_agents': 5}, 'seed_extension': {'randomness': 0, 'two_sided': True, 'extension_iterations': 1, 'extension_steps': 1e+23, 'number_of_boards': 100000}, 'network': {'transformer_num_blocks': 4, 'transformer_num_heads': 8, 'transformer_key_size': 16, 'transformer_mlp_units': [512], 'conv_n_channels': 32}, 'training': {'num_epochs': 500, 'num_learner_steps_per_epoch': 100, 'n_steps': 20, 'total_batch_size': 128}, 'evaluation': {'eval_total_batch_size': 5000, 'greedy_eval_total_batch_size': 5000}, 'a2c': {'normalize_advantage': False, 'discount_factor': 0.99, 'bootstrapping_factor': 0.95, 'l_pg': 1.0, 'l_td': 1.0, 'l_en': 0.01, 'learning_rate': 0.0002}}}

## Load a saved checkpoint

In [18]:
# Load the pickled training state of the agent trained on the default "uniform"
# boards and build a greedy policy from it.
# NOTE: pickle.load can execute arbitrary code - only open checkpoints you trust.
file = "../trained_agents/uniform/training_state"
with open(file, "rb") as checkpoint_file:
    training_state = pickle.load(checkpoint_file)

# Take a single copy of the parameters (presumably replicated across devices
# during training - see first_from_device).
params = first_from_device(training_state.params_state.params)
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
# Deterministic (greedy) policy, jit-compiled once and reused by the rollout cells.
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))


# Roll out a few episodes on Jumanji's default "uniform" board generator

In [19]:
# Roll out NUM_EPISODES greedy episodes on the environment's own (default
# "uniform") boards and report how many wires the agent connects.
env_fn = jax.jit(env.reset)  # jit-compiled env.reset
step_fn = jax.jit(env.step)  # jit-compiled env.step (much faster than un-jitted)
GRID_SIZE = 10
NUM_AGENTS = 5
NUM_EPISODES = 10  # 10000 for the full evaluation
FREEZE_FRAMES = 10  # repeated terminal frames so the GIF pauses between episodes

print(datetime.now())
states = []
key = jax.random.PRNGKey(cfg.seed)

connections = []
for episode in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = env_fn(reset_key)

    while not timestep.last():
        key, action_key = jax.random.split(key)
        # The policy expects a leading batch dimension.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        # `policy` is already jit-compiled; re-wrapping it in jax.jit on every
        # step only adds overhead.
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)

    # Freeze the terminal frame to pause the GIF.
    states.extend([state] * FREEZE_FRAMES)

    # Number of wires connected at the end of the episode.
    connections.append(timestep.extras["num_connections"])

connections = np.array(connections)
print(f"{NUM_EPISODES} Evaluations: Average connections for default Uniform board = {connections.mean()}, std={connections.std()}")
print(datetime.now())

2023-05-02 10:59:25.711194
10 Evaluations: Average connections for default Uniform board = 3.8, std=0.8717797887081347
2023-05-02 10:59:29.197112


# Roll out a few episodes on each board generator, extending the wires of every generated board

In [20]:
# Historical cell: evaluate the "uniform"-trained agent on boards from each
# generator after extending the wires of the solved boards, and count detours
# before/after extension. Wire extension is probably no longer needed, but the
# cell is kept for reference.

step_fn = jax.jit(env.step)  # Speed up env.step
BOARD_GENERATORS = [
        BoardName.BFS_BASE,
        BoardName.BFS_MIN_BENDS,
        BoardName.BFS_FIFO,
        BoardName.BFS_SHORTEST,
        BoardName.BFS_LONGEST,
        BoardName.RANDOM_WALK,
        BoardName.LSYSTEMS,
        BoardName.WFC,
        BoardName.NUMBERLINK]

GRID_SIZE = 10
NUM_AGENTS = 5
NUM_EPISODES = 1  # 10000
FREEZE_FRAMES = 10  # repeated terminal frames so the GIF pauses between episodes

print(datetime.now())
states = []
key = jax.random.PRNGKey(cfg.seed)

for BOARD_GENERATOR in BOARD_GENERATORS:
    print(BOARD_GENERATOR)
    connections = []
    detours_raw = []
    detours_extended = []
    # The generator class is the same for every episode, so look it up once.
    board_class = BoardGenerator.get_board_generator(board_enum=BOARD_GENERATOR)
    for episode in range(NUM_EPISODES):
        key, reset_key = jax.random.split(key)

        ### Copied from ic_rl_training ##########################
        # Create empty grid.
        grid = jnp.zeros((GRID_SIZE, GRID_SIZE), dtype=jnp.int32)

        board = board_class(GRID_SIZE, GRID_SIZE, NUM_AGENTS)
        # These pins are replaced by the extended-wire pins below.
        # NOTE(review): the call is kept in case the generator relies on it
        # before return_solved_board - confirm and remove.
        pins = board.return_training_board()

        ### WIRE EXTENSION #####
        board_solved = board.return_solved_board()
        detours_raw.append(count_detours(board_solved, False))
        board_extended = extend_wires(board_solved)
        detours_extended.append(count_detours(board_extended, False))
        pins = training_board_from_solved_board_jax(board_extended)
        ### END OF WIRE EXTENSION ####

        # Flat cell indices -> (rows, cols) index arrays.
        starts_flat, targets_flat = get_heads_and_targets(pins)
        starts = jnp.divmod(np.array(starts_flat), GRID_SIZE)
        targets = jnp.divmod(np.array(targets_flat), GRID_SIZE)

        agent_position_values = jax.vmap(get_position)(jnp.arange(NUM_AGENTS))
        agent_target_values = jax.vmap(get_target)(jnp.arange(NUM_AGENTS))

        # Place the agent values at starts and targets.
        grid = grid.at[starts].set(agent_position_values)
        grid = grid.at[targets].set(agent_target_values)

        # Create the agent pytree that corresponds to the grid.
        agents = jax.vmap(Agent)(
            id=jnp.arange(NUM_AGENTS),
            start=jnp.stack(starts, axis=1),
            target=jnp.stack(targets, axis=1),
            position=jnp.stack(starts, axis=1),
        )

        step_count = jnp.array(0, jnp.int32)

        # Give the state its own key: the loop `key` is split again below for
        # the action keys, so storing it here would reuse randomness.
        state = State(key=reset_key, grid=grid, step_count=step_count, agents=agents)
        # END OF CODE FROM IC_RL_TRAINING
        ########################################################################################

        ### Initial timestep (amendment 2 from Clement) ###
        action_mask = jax.vmap(env._get_action_mask, (0, None))(
            state.agents, state.grid
            )
        observation = Observation(
            grid=env._obs_from_grid(state.grid),
            action_mask=action_mask,
            step_count=state.step_count,
            )
        extras = env._get_extras(state)
        timestep = jumanji.types.restart(
            observation=observation, extras=extras, shape=(env.num_agents,)
            )

        # Let the agent try to solve the board.
        while not timestep.last():
            key, action_key = jax.random.split(key)
            # The policy expects a leading batch dimension.
            observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
            # `policy` is already jit-compiled; no need to re-wrap it every step.
            action, _ = policy(observation, action_key)
            state, timestep = step_fn(state, action.squeeze(axis=0))
            states.append(state)
        # Freeze the terminal frame to pause the GIF.
        states.extend([state] * FREEZE_FRAMES)

        # Evaluate the number of wires connected
        num_connected = timestep.extras["num_connections"]
        connections.append(num_connected)

    connections = np.array(connections)
    print(f"{NUM_EPISODES} Evaluations: Average connections for {BOARD_GENERATOR} = {connections.mean()}, std={connections.std()}")
    detours_raw = np.array(detours_raw)
    print(f"Average Detours (RAW) = {detours_raw.mean()}, (std = {detours_raw.std()}) ")
    detours_extended = np.array(detours_extended)
    print(f"Average Detours (extended) = {detours_extended.mean()}, (std = {detours_extended.std()}) ")
    print(datetime.now())

2023-05-02 10:59:29.236935
BoardName.BFS_BASE
1 Evaluations: Average connections for bfs_base = 4.0, std=0.0
Average Detours (RAW) = 1.0, (std = 0.0) 
Average Detours (extended) = 3.0, (std = 0.0) 
2023-05-02 10:59:31.373208
BoardName.BFS_MIN_BENDS
1 Evaluations: Average connections for bfs_min_bend = 4.0, std=0.0
Average Detours (RAW) = 11.0, (std = 0.0) 
Average Detours (extended) = 13.0, (std = 0.0) 
2023-05-02 10:59:31.506163
BoardName.BFS_FIFO
1 Evaluations: Average connections for bfs_fifo = 3.0, std=0.0
Average Detours (RAW) = 0.0, (std = 0.0) 
Average Detours (extended) = 7.0, (std = 0.0) 
2023-05-02 10:59:31.670605
BoardName.BFS_SHORTEST
1 Evaluations: Average connections for bfs_short = 5.0, std=0.0
Average Detours (RAW) = 1.0, (std = 0.0) 
Average Detours (extended) = 3.0, (std = 0.0) 
2023-05-02 10:59:31.845211
BoardName.BFS_LONGEST
1 Evaluations: Average connections for bfs_long = 3.0, std=0.0
Average Detours (RAW) = 1.0, (std = 0.0) 
Average Detours (extended) = 6.0, (std

# TEST ALL THE AGENTS ON ALL THE BOARDS

In [None]:
# TEST ALL THE AGENTS ON ALL THE BOARDS
# For every trained agent, roll out NUM_EPISODES greedy episodes on boards from
# each generator in BOARD_GENERATORS, report the mean number of connected wires,
# and correlate it with simple board statistics.

key = jax.random.PRNGKey(cfg.seed)
# Build the environment once: the config is the same for every agent, and the
# jitted step function, action masks and observations must all come from the
# same environment instance.
env = setup_env(cfg).unwrapped
step_fn = jax.jit(env.step)  # Speed up env.step
BOARD_GENERATORS = [
        #BoardName.BFS_BASE,
        ##BoardName.BFS_MIN_BENDS, # Don't use this. It's redundant
        ##BoardName.BFS_FIFO,      # Don't use this. It's redundant
        #BoardName.BFS_SHORTEST,
        #BoardName.BFS_LONGEST,
        #BoardName.RANDOM_WALK,
        #BoardName.RANDOM_SEED,
        BoardName.JAX_SEED_EXTENSION,
        BoardName.LSYSTEMS,
        ##BoardName.WFC,           # Don't use this.  Wave Form Collapse will not be pursued.
        #BoardName.NUMBERLINK
        ]

# Trained agents to evaluate (sub-directories of ../trained_agents).
# Other available agents: "parallel_randomwalk", "uniform".
# Kept separate from `agents`, which is rebound to the Agent pytree below.
AGENT_NAMES = ["bfs_base"]

GRID_SIZE = 10
NUM_AGENTS = 5
NUM_EPISODES = 5000  # 10000
FREEZE_FRAMES = 10  # repeated terminal frames so the GIF pauses between episodes
RECORD_STATES = True  # set False on long runs to avoid keeping every state in memory


def _correlation(x, y):
    """Return the Pearson correlation of x and y, or nan when it is undefined.

    statistics.correlation (Python >= 3.10) raises StatisticsError when an input
    is constant or has fewer than two points - e.g. an agent that connects every
    wire on every board - which would otherwise abort the whole report.
    """
    try:
        return statistics.correlation(x, y)
    except statistics.StatisticsError:
        return float("nan")


for agentname in AGENT_NAMES:
    file = "../trained_agents/" + agentname + "/training_state"
    print("\nAGENT = ", agentname)
    # NOTE: pickle.load can execute arbitrary code - only open trusted checkpoints.
    with open(file, "rb") as f:
        training_state = pickle.load(f)
    params = first_from_device(training_state.params_state.params)
    agent = setup_agent(cfg, env)
    policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

    connections_all = []
    detours_all = []
    keys_all = []
    episodes_all = []
    densities_all = []
    diversities_all = []
    print(datetime.now())
    # To sweep the random-seed extension instead, fix
    # BOARD_GENERATOR = BoardName.RANDOM_SEED, loop `for iterations in range(11, 22):`
    # and pass randomness=0.85, extension_iterations=iterations below.
    for BOARD_GENERATOR in BOARD_GENERATORS:
        states = []
        connections = []
        print(BOARD_GENERATOR)
        if BOARD_GENERATOR == BoardName.JAX_SEED_EXTENSION:
            board = SeedExtensionBoard(GRID_SIZE, GRID_SIZE, NUM_AGENTS)
            # Generation hyper-parameters are static, so the compiled function
            # is specialised to the values passed below.
            solved_board_compiled = jax.jit(
                board.return_solved_board,
                static_argnames=['extension_iterations', 'randomness', 'two_sided', 'extension_steps'],
            )
        else:
            # The generator class is the same for every episode, so look it up once.
            board_class = BoardGenerator.get_board_generator(board_enum=BOARD_GENERATOR)

        for episode in range(NUM_EPISODES):
            key, reset_key = jax.random.split(key)

            # Generate a solved board and the pins (starts/targets) to route.
            if BOARD_GENERATOR == BoardName.JAX_SEED_EXTENSION:
                key, subkey = jax.random.split(key)
                solved_board = solved_board_compiled(
                    subkey, randomness=1.0, two_sided=True,
                    extension_iterations=21, extension_steps=1e23)
                pins = training_board_from_solved_board(np.array(solved_board))
            else:
                board = board_class(GRID_SIZE, GRID_SIZE, NUM_AGENTS)
                solved_board = board.return_solved_board()
                pins = board.return_training_board()

            ### Build the Connector state (copied from ic_rl_training) ###
            grid = jnp.zeros((GRID_SIZE, GRID_SIZE), dtype=jnp.int32)
            # Flat cell indices -> (rows, cols) index arrays.
            starts_flat, targets_flat = get_heads_and_targets(pins)
            starts = jnp.divmod(np.array(starts_flat), GRID_SIZE)
            targets = jnp.divmod(np.array(targets_flat), GRID_SIZE)

            agent_position_values = jax.vmap(get_position)(jnp.arange(NUM_AGENTS))
            agent_target_values = jax.vmap(get_target)(jnp.arange(NUM_AGENTS))

            # Place the agent values at starts and targets.
            grid = grid.at[starts].set(agent_position_values)
            grid = grid.at[targets].set(agent_target_values)

            # Create the agent pytree that corresponds to the grid.
            agents = jax.vmap(Agent)(
                id=jnp.arange(NUM_AGENTS),
                start=jnp.stack(starts, axis=1),
                target=jnp.stack(targets, axis=1),
                position=jnp.stack(starts, axis=1),
            )
            step_count = jnp.array(0, jnp.int32)
            # Give the state its own key: the loop `key` is split again below for
            # the action keys, so storing it here would reuse randomness.
            state = State(key=reset_key, grid=grid, step_count=step_count, agents=agents)

            ### Initial timestep (amendment 2 from Clement) ###
            action_mask = jax.vmap(env._get_action_mask, (0, None))(
                state.agents, state.grid
                )
            observation = Observation(
                grid=env._obs_from_grid(state.grid),
                action_mask=action_mask,
                step_count=state.step_count,
                )
            extras = env._get_extras(state)
            timestep = jumanji.types.restart(
                observation=observation, extras=extras, shape=(env.num_agents,)
                )

            # Let the agent try to solve the board.
            while not timestep.last():
                key, action_key = jax.random.split(key)
                # The policy expects a leading batch dimension.
                observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
                # `policy` is already jit-compiled; no need to re-wrap it every step.
                action, _ = policy(observation, action_key)
                state, timestep = step_fn(state, action.squeeze(axis=0))
                if RECORD_STATES:
                    states.append(state)
            if RECORD_STATES:
                # Freeze the terminal frame to pause the GIF.
                states.extend([state] * FREEZE_FRAMES)

            # Record the outcome and the board statistics used for the correlations.
            num_connected = timestep.extras["num_connections"]
            connections.append(num_connected)
            connections_all.append(num_connected)
            keys_all.append(key[0])
            episodes_all.append(episode)
            # Fresh copies per consumer, as in the original, in case a helper
            # modifies its input.
            detours_all.append(count_detours(np.array(solved_board)))
            densities_all.append(np.count_nonzero(np.array(solved_board) > 0))
            board_evaluator = EvaluateEmptyBoard(np.array(solved_board))
            diversities_all.append(board_evaluator.board_statistics["heatmap_score_diversity"])

        connections = np.array(connections)
        print(f"{NUM_EPISODES} Evaluations: Average connections for {BOARD_GENERATOR} = {connections.mean()}, std={connections.std()}")
        print(datetime.now())

    connections_all = np.array(connections_all)
    # connections_all spans every generator, so report its real size.
    print(f"{len(connections_all)} Evaluations: Total average connections for {agentname} = {connections_all.mean()}, std={connections_all.std()}")
    print("CORRELATIONS")
    detours_all = np.array(detours_all)
    keys_all = np.array(keys_all)
    episodes_all = np.array(episodes_all)
    print("Correlation of connections with detours is ", _correlation(connections_all, detours_all))
    print("Correlation of connections with keys is ", _correlation(connections_all, keys_all))
    print("Correlation of connections with episodes is ", _correlation(connections_all, episodes_all))
    print("Correlation of connections with densities is ", _correlation(connections_all, densities_all))
    print("Correlation of connections with diversities is ", _correlation(connections_all, diversities_all))



AGENT =  bfs_base
2023-05-02 10:59:44.772540
BoardName.JAX_SEED_EXTENSION




Results from an earlier run: every trained agent evaluated on L-system boards (5000 episodes each)

AGENT =  uniform
2023-05-02 00:19:21.753036
BoardName.LSYSTEMS
5000 Evaluations: Average connections for lsystems_standard = 4.98, std=0.14422205101855956
2023-05-02 00:25:21.383975
5000 Evaluations: Total average connections for uniform = 4.98, std=0.14422205101855956
CORRELATIONS
Correlation of connections with detours is  -0.01935806893685983
Correlation of connections with keys is  0.005346677659771794
Correlation of connections with episodes is  -0.02031642004850264
Correlation of connections with densities is  -0.009123005946855816
Correlation of connections with diversities is  -0.028970764358208122

AGENT =  bfs_base
2023-05-02 00:25:21.807770
BoardName.LSYSTEMS
5000 Evaluations: Average connections for lsystems_standard = 4.971, std=0.1837362239733907
2023-05-02 00:31:21.863111
5000 Evaluations: Total average connections for bfs_base = 4.971, std=0.1837362239733907
CORRELATIONS
Correlation of connections with detours is  -0.029094437054042795
Correlation of connections with keys is  0.004213078077117279
Correlation of connections with episodes is  0.008364618147003282
Correlation of connections with densities is  -0.018583133729298918
Correlation of connections with diversities is  -0.028214478944084118

AGENT =  random_walk
2023-05-02 00:31:22.287055
BoardName.LSYSTEMS
5000 Evaluations: Average connections for lsystems_standard = 4.9632, std=0.20553773376195428
2023-05-02 00:37:16.381361
5000 Evaluations: Total average connections for random_walk = 4.9632, std=0.20553773376195428
CORRELATIONS
Correlation of connections with detours is  -0.03418112837273609
Correlation of connections with keys is  -0.0017943014669556208
Correlation of connections with episodes is  -0.020961467600573856
Correlation of connections with densities is  -0.06580300202526061
Correlation of connections with diversities is  -0.03744817817784907

AGENT =  numberlink
2023-05-02 00:37:16.809231
BoardName.LSYSTEMS
5000 Evaluations: Average connections for lsystems_standard = 4.9334, std=0.2663915163814343
2023-05-02 00:43:35.281696
5000 Evaluations: Total average connections for numberlink = 4.9334, std=0.2663915163814343
CORRELATIONS
Correlation of connections with detours is  -0.018604440193449458
Correlation of connections with keys is  -0.018489353858393426
Correlation of connections with episodes is  -0.008100066280397018
Correlation of connections with densities is  -0.023229531459681183
Correlation of connections with diversities is  -0.04347433663018336

AGENT =  lsystems
2023-05-02 00:43:35.727887
BoardName.LSYSTEMS
5000 Evaluations: Average connections for lsystems_standard = 4.8818, std=0.3690918042980635
2023-05-02 00:50:09.011681
5000 Evaluations: Total average connections for lsystems = 4.8818, std=0.3690918042980635
CORRELATIONS
Correlation of connections with detours is  -0.019051585907028455
Correlation of connections with keys is  0.00023288046099994157
Correlation of connections with episodes is  -0.011855168147067333
Correlation of connections with densities is  -0.0912969748602712
Correlation of connections with diversities is  -0.05610474760955783

# Save GIF

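Timing note (1000 random-walk episodes with wire extension): the run is faster WITHOUT the jax.jit() calls in the wire extension (about 4.5 min with jit on both steps vs about 1.6 min with jit on pin extraction only or on neither).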

JAX.JIT() CALL APPLIED TO BOTH THE EXTENSION AND PIN EXTRACTION
2023-04-18 23:17:35.299726
BoardName.RANDOM_WALK
1000 Evaluations: Average connections for random_walk = 4.514, std=0.6795616234014396
2023-04-18 23:22:00.377340

JAX.JIT() CALL APPLIED TO JUST THE PIN EXTRACTION
2023-04-18 23:22:53.809285
BoardName.RANDOM_WALK
1000 Evaluations: Average connections for random_walk = 4.486, std=0.6721636705446077
2023-04-18 23:24:32.408968

JAX.JIT() CALL APPLIED TO NEITHER
2023-04-18 23:25:10.222635
BoardName.RANDOM_WALK
1000 Evaluations: Average connections for random_walk = 4.484, std=0.655548625198772
2023-04-18 23:26:47.720919



In [None]:
#env.animate(states, interval=150).save("./connector.gif")

# Save PNG

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

#env.render(states[-1])
state.grid = solved_board
env.render(state)
#plt.savefig("/home/randy/Downloads/randomseed_extended_jax00.png", dpi=300)
