## Notebook controller for updating the state of the simulation


In [16]:
import time
import socket
import pickle

import numpy as np
import jax.numpy as jnp
from flax import serialization

from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.utils.network import SERVER

In [17]:
PORT = 5050  # port of the server that connect_client connects to
ADDR = (SERVER, PORT)  # (host, port) passed to socket.connect in connect_client
DATA_SIZE = 40000  # max bytes read for the pickled example state during the handshake
EVAL_TIME = 10  # NOTE(review): not referenced in the visible cells — confirm it is still needed

# RGB triples (floats in [0, 1]) for the single-letter colour keys accepted by
# change_state_agent_color / set_color.
color_map = dict(
    r=(1.0, 0.0, 0.0),
    g=(0.0, 1.0, 0.0),
    b=(0.0, 0.0, 1.0),
)

In [18]:
# Connect to an already-running simulation server and perform the handshake.

def connect_client():
    """Open a TCP connection to the server at ``ADDR`` and perform the handshake.

    Handshake, as implemented here:
      1. read a greeting from the server (its content is not used),
      2. read a pickled example state, used later as the template for
         ``serialization.from_bytes``,
      3. reply with the client role ``"NOTEBOOK"``.

    Returns:
        ``(client, state_example, state_bytes_size)``: the connected socket,
        the example state, and the length in bytes of its flax serialization,
        which the other helpers use as the size of a state message.

    Raises:
        OSError: if connecting or any send/recv fails. The socket is closed
            before the error propagates, so a failed handshake does not leak it.
    """
    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        client.connect(ADDR)
        print(f"Connected to {ADDR}")

        # Greeting message: read only to consume it from the stream.
        client.recv(1024)
        # NOTE(review): pickle.loads can execute arbitrary code sent by the
        # peer; only connect to a trusted server. The protocol also has no
        # length prefix, so a single recv() could return a partial pickle (or
        # the greeting and the pickle could arrive together). That needs a
        # server-side framing change to fix.
        state_example = pickle.loads(client.recv(DATA_SIZE))
        state_bytes_size = len(serialization.to_bytes(state_example))
        # sendall: send() may transmit only part of the buffer.
        client.sendall("NOTEBOOK".encode())
    except BaseException:
        client.close()
        raise
    # Presumably gives the server time to register the client before the
    # first command arrives — TODO confirm this delay is still required.
    time.sleep(1)

    return client, state_example, state_bytes_size

In [19]:
def close_client(client):
    """Send ``CLOSE_CONNECTION`` to the server, then close the local socket.

    The socket is closed even if sending fails. It must not be used after
    this call; call ``connect_client`` again to start a new session.
    """
    try:
        client.sendall("CLOSE_CONNECTION".encode())
    finally:
        client.close()

def get_state(client, state_example, state_bytes_size):
    """Request the current simulation state from the server and deserialize it.

    Args:
        client: socket returned by ``connect_client``.
        state_example: template state passed to ``serialization.from_bytes``.
        state_bytes_size: exact length in bytes of a serialized state message.

    Returns:
        The deserialized state, with the same structure as ``state_example``.

    Raises:
        ConnectionError: if the server closes the connection before the full
            state has been received.
    """
    client.sendall("GET_STATE".encode())
    # TCP is a byte stream: one recv() may return only part of the message,
    # so keep reading until the expected number of bytes has arrived.
    buffer = bytearray()
    while len(buffer) < state_bytes_size:
        chunk = client.recv(state_bytes_size - len(buffer))
        if not chunk:
            raise ConnectionError(
                f"connection closed after {len(buffer)} of "
                f"{state_bytes_size} state bytes"
            )
        buffer += chunk
    return serialization.from_bytes(state_example, bytes(buffer))

def change_state_agent_color(state, idx, color):
    """Return a new state in which agent ``idx`` has colour ``color``.

    ``color`` must be a key of ``color_map``; otherwise ``KeyError`` is raised.
    The input state is not modified. Its ``colors`` are copied into a fresh
    NumPy array, written there, and attached with ``state.replace``.
    """
    recoloured = np.array(state.colors, copy=True)
    recoloured[idx] = color_map[color]
    return state.replace(colors=recoloured)

def set_color(client, agent_idx, color, state_example, state_bytes_size):
    """Recolour one agent in the running simulation.

    Exchange, as implemented here: send ``SET_STATE``, receive the current
    serialized state, recolour the agent locally, and send the modified state
    back.

    Args:
        client: socket returned by ``connect_client``.
        agent_idx: index of the agent to recolour.
        color: key of ``color_map`` (``"r"``, ``"g"`` or ``"b"``).
        state_example: template state passed to ``serialization.from_bytes``.
        state_bytes_size: exact length in bytes of a serialized state message.

    Raises:
        KeyError: if ``color`` is not a key of ``color_map``. This is checked
            before anything is sent, so the exchange is never left half-done
            (the server would presumably be left waiting for a state).
        ConnectionError: if the server closes the connection before the full
            state has been received.
    """
    if color not in color_map:
        raise KeyError(color)

    client.sendall("SET_STATE".encode())
    # TCP is a byte stream: one recv() may return only part of the message,
    # so keep reading until the expected number of bytes has arrived.
    buffer = bytearray()
    while len(buffer) < state_bytes_size:
        chunk = client.recv(state_bytes_size - len(buffer))
        if not chunk:
            raise ConnectionError(
                f"connection closed after {len(buffer)} of "
                f"{state_bytes_size} state bytes"
            )
        buffer += chunk
    current_state = serialization.from_bytes(state_example, bytes(buffer))
    response_state = change_state_agent_color(current_state, agent_idx, color)
    # sendall: the serialized state can exceed what a single send() transmits.
    client.sendall(serialization.to_bytes(response_state))

# def add_agent(client, agent_idx):
#     client.send(f"ADD_AGENT,{agent_idx}".encode())

def _send_command(client, command):
    """Send a plain-text control command to the server.

    ``sendall`` is used because ``send`` may transmit only part of the buffer.
    """
    client.sendall(command.encode())

def pause(client):
    """Send the ``PAUSE`` command (presumably pauses the simulation loop)."""
    _send_command(client, "PAUSE")

def resume(client):
    """Send the ``RESUME`` command (presumably resumes a paused simulation)."""
    _send_command(client, "RESUME")

def stop(client):
    """Send the ``STOP`` command to the server."""
    _send_command(client, "STOP")

def start(client):
    """Send the ``START`` command to the server."""
    _send_command(client, "START")

In [20]:
# Connect once; the socket and state template are reused by the cells below.
client, state_example, state_bytes_size = connect_client()

Connected to ('localhost', 5050)


In [21]:
pause(client)  # send the PAUSE command to the server

In [22]:
resume(client)  # send the RESUME command to the server

In [55]:
# Fetch and deserialize the current state from the server.
state = get_state(client, state_example, state_bytes_size)

In [23]:
# Recolour agent 1 to blue. Change the agent index or the colour key
# ('r', 'g' or 'b', see color_map) as needed.
set_color(client, 1, 'b', state_example, state_bytes_size)

sent set state
len_received_state521
sent SimpleSimState(time=array(866, dtype=int32), grid_size=array(20, dtype=int32), alive=array([1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], dtype=float32), x_pos=array([ 1.4000001, 11.       , 18.699999 , 15.2999935,  2.3      ,
        6.6999984,  7.2      , 17.399994 ,  0.4      ,  1.0000001],
      dtype=float32), y_pos=array([ 9.600005 ,  0.       , 12.399982 ,  9.900005 ,  9.100002 ,
        5.7999964,  7.999998 ,  9.899984 ,  7.3999963, 11.500006 ],
      dtype=float32), obs=array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32), colors=array([[1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.]], dtype=float32))
sent encoded b'\x8

In [15]:
close_client(client)  # send CLOSE_CONNECTION to end the session

In [None]:
# start(client)
# stop(client)