In [1]:
import json
import websockets, asyncio
import threading

class WaitableQueue(asyncio.Queue):
    """FIFO queue with blocking, thread-friendly ``put``/``get``.

    The storage is inherited from ``asyncio.Queue`` but only its non-blocking
    ``put_nowait``/``get_nowait``/``empty`` primitives are used. Note that
    ``put`` and ``get`` are overridden as plain synchronous methods (they are
    NOT coroutines here), so they can be called from ordinary threads.

    ``event`` is set while at least one item is queued. A private lock keeps
    the queue contents and the event state in step when a producer and a
    consumer run on different threads.
    """

    def __init__(self):
        super().__init__()
        self.event = threading.Event()
        # Guards every "touch the queue, then update the event" sequence.
        self._event_lock = threading.Lock()

    def put(self, item):
        """Append ``item`` and wake up any thread blocked in ``get``."""
        with self._event_lock:
            super().put_nowait(item)
            self.event.set()

    def get(self,timeout=3):
        """Remove and return the oldest item, blocking until one is available.

        Args:
            timeout: seconds to wait for an item; ``None`` waits forever.

        Raises:
            TimeoutError: if no item arrived within ``timeout`` seconds.
        """
        while True:
            if not self.event.wait(timeout):
                raise TimeoutError("Environement is not responding.")
            with self._event_lock:
                if super().empty():
                    # The event was set but the item is already gone (another
                    # consumer took it, or the signal was stale). Reset and
                    # wait again; each retry waits up to ``timeout`` once more.
                    self.event.clear()
                    continue
                res = super().get_nowait()
                if super().empty():
                    self.event.clear()
                return res

# the server
class Server:
    """Websocket server that exchanges JSON messages with a single client.

    Inbound messages are pushed onto ``inQueue``; commands queued with
    ``send`` are pulled from ``outQueue`` by a dedicated sender thread and
    written to the most recently connected websocket.
    """

    def __init__(self):
        self.inQueue = WaitableQueue()   # messages received from the client
        self.outQueue = WaitableQueue()  # commands waiting to be sent
        self.debug = True                # when true, recv() prints each inbound message
        self.ws = None                   # most recently connected websocket, if any
        self.loop = None                 # event loop running main(); set when it starts

    def start(self):
        """Start the sender thread, then run the server loop (blocks the caller)."""
        threading.Thread(target=self.message_sender_loop).start()
        asyncio.run(self.main())

    async def main(self):
        """Serve on localhost:8765 until the event loop is stopped."""
        # Remember the loop that owns the websocket so that the sender thread
        # can schedule sends on it instead of spinning up loops of its own.
        self.loop = asyncio.get_running_loop()
        try:
            async with websockets.serve(self.echo, "localhost", 8765):
                await asyncio.Future()  # run forever
        except websockets.exceptions.ConnectionClosedError as e: print(e)

    async def echo(self,websocket):
        """Connection handler: forward every received message to ``recv``.

        Messages that parse as JSON are forwarded decoded; anything else is
        forwarded unchanged.
        """
        self.ws = websocket
        print('connect')
        try:
            async for message in websocket:
                try:
                    self.recv(json.loads(message))
                except json.decoder.JSONDecodeError:
                    self.recv(message)
        finally:
            # Forget the socket once this connection ends so the sender does
            # not keep writing to it (unless a newer client already replaced it).
            if self.ws is websocket:
                self.ws = None

    def recv(self,message):
        """Queue an inbound message and optionally echo it to stdout."""
        self.inQueue.put(message)
        if self.debug:
            print("recv: ",message)

    def send(self,command:str, content):
        """Queue ``{'command': command, 'content': content}`` for sending."""
        self.outQueue.put({'command':command,'content':content})

    def message_sender_loop(self):
        """Forward queued commands to the client; runs forever in its own thread.

        Each send is scheduled on the server's event loop with
        ``asyncio.run_coroutine_threadsafe``, because the websocket belongs to
        that loop and must not be driven from a different one.
        """
        while True:
            message = self.outQueue.get(None)  # None: block until a command arrives
            ws, loop = self.ws, self.loop
            if ws is None or loop is None:
                # Nothing to send to yet; drop the command instead of crashing
                # the sender thread.
                print("send skipped: no client connected")
                continue
            payload = json.dumps(message, indent=4)
            try:
                # Wait for completion so commands go out in order and any
                # failure is reported here.
                asyncio.run_coroutine_threadsafe(ws.send(payload), loop).result()
            except Exception as e:
                print("send failed: ", e)

# Start the server on a separate thread so this cell does not block:
# Server.start() calls asyncio.run(), which only returns when the event loop ends.
# (threading is already imported at the top of this cell; the repeat is harmless.)
import threading
server = Server()
t=threading.Thread(target=server.start)
t.start()

# the interface to the server
class WSManager:
    """Holds a reference to a ``Server`` together with a debug flag.

    Only the constructor is defined; it stores its argument and turns
    debugging off.
    """

    def __init__(self,server:Server):
        self.server, self.debug = server, False

#server.send("action",{"voltage":[1,0,0,0,100,200,100,100]})

In [2]:
import numpy as np
import torch
def flatten(list_of_lists):
    """Return a flat list of the leaves of an arbitrarily nested iterable.

    Any element exposing ``__iter__`` (lists, tuples, numpy arrays, ...) is
    expanded in place, depth first and left to right; everything else is kept
    as a leaf. Strings and bytes are treated as leaves, because expanding them
    never terminates (each character is itself an iterable string).

    The traversal is iterative, so neither the length nor the nesting depth
    of the input is limited by Python's recursion limit.

    Args:
        list_of_lists: the (possibly nested) iterable to flatten.

    Returns:
        A new list holding the leaves in their original order.
    """
    flat = []
    # Stack of partially consumed iterators; the innermost one is on top.
    pending = [iter(list_of_lists)]
    while pending:
        for item in pending[-1]:
            if hasattr(item, '__iter__') and not isinstance(item, (str, bytes)):
                # Descend into the nested iterable; the outer iterator keeps
                # its position and is resumed once the inner one is exhausted.
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            pending.pop()
    return flat
def decomposeCosSin(angle):
    """Return ``[cos(angle), sin(angle)]`` computed with numpy.

    ``angle`` is passed straight to ``np.cos``/``np.sin``, so a scalar gives
    two scalars and a sequence gives two arrays.
    """
    cosine, sine = np.cos(angle), np.sin(angle)
    return [cosine, sine]
def processFeature(state:dict,targetPos):
    """Turn a raw state dict into a flat feature list.

    Features, in order: base-link x/y offset from ``targetPos``, cos/sin of
    the base-link orientation, base-link x/y velocity, base-link angular
    velocity, cos/sin of the wheel-base orientation(s), and the wheel
    speed(s).

    Args:
        state: dict with keys 'baseLinkPos' and 'baseLinkVelocity' (each with
            'x' and 'y'), 'baseLinkOrientation', 'baseLinkAngularVelocity',
            'wheelBaseOrientation' and 'wheelSpeed'.
        targetPos: indexable whose first two entries are the target x and y.

    Returns:
        The nested entries above, flattened into a single list.
    """
    nested = [
        state['baseLinkPos']['x'] - targetPos[0],
        state['baseLinkPos']['y'] - targetPos[1],
        decomposeCosSin(state['baseLinkOrientation']),
        state['baseLinkVelocity']['x'],
        state['baseLinkVelocity']['y'],
        state['baseLinkAngularVelocity'],
        decomposeCosSin(state['wheelBaseOrientation']),
        state['wheelSpeed'],
    ]
    return flatten(nested)

In [3]:
from torch import nn
class Q(nn.Module):
    """MLP mapping a (state, action) pair to a single scalar output.

    Architecture: Linear(state_size + action_size, hidden) -> LeakyReLU ->
    Linear(hidden, hidden) -> LeakyReLU -> Linear(hidden, 1).
    """

    def __init__(self,state_size,action_size,hidden_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.fc = nn.Sequential(
            nn.Linear(state_size+action_size,hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size,hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size,1),
        )

    def forward(self,state,action):
        """Evaluate the network on ``state`` and ``action``.

        The two tensors are concatenated along their last (feature)
        dimension. For batched 2-D inputs this is the same as ``dim=1``, and
        it also accepts unbatched 1-D inputs.

        Args:
            state: tensor whose last dimension has size ``state_size``.
            action: tensor whose last dimension has size ``action_size``;
                leading dimensions must match those of ``state``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of 1.
        """
        return self.fc(torch.cat([state,action],dim=-1))

class Policy(nn.Module):
    """MLP mapping a state vector to an action vector.

    Architecture: Linear(state_size, hidden) -> LeakyReLU ->
    Linear(hidden, hidden) -> LeakyReLU -> Linear(hidden, action_size).
    The output layer has no activation.
    """

    def __init__(self,state_size,action_size,hidden_size):
        super().__init__()
        self.state_size, self.action_size = state_size, action_size
        widths = [state_size, hidden_size, hidden_size, action_size]
        layers = []
        for index, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            if index:
                # Activation between consecutive linear layers only.
                layers.append(nn.LeakyReLU())
            layers.append(nn.Linear(n_in, n_out))
        self.fc = nn.Sequential(*layers)

    def forward(self,state):
        """Return the network output for ``state``."""
        output = self.fc(state)
        return output

In [4]:

# Instantiate the Q network and the policy network with matching sizes.
# state_size=19 presumably equals the length of processFeature()'s output for
# four wheels (2 + 2 + 2 + 1 + 8 + 4) — TODO confirm. action_size=8 matches the
# eight-element "voltage" lists sent elsewhere in this notebook.
q = Q(state_size=19,action_size=8,hidden_size=64)
policy = Policy(state_size=19,action_size=8,hidden_size=64)
# Target [x, y]; processFeature subtracts it from the base-link position.
targetPos = [0,0]

In [47]:
# Roll out the current policy for 100 steps: wait for a state message, compute
# an action from it and send the action back as a voltage command.
server.debug = False  # silence the per-message "recv:" prints during the rollout
# The loop variable is `step`, not `t`: `t` is the server thread handle created
# in the first cell and must not be overwritten.
for step in range(100):
    # Blocks until a state arrives; WaitableQueue.get raises TimeoutError after
    # its default timeout if none is received.
    state = processFeature(server.inQueue.get()['content'],targetPos)
    print(f'step {step} start')
    action = policy(torch.tensor(state,dtype=torch.float))*5000
    # .tolist() already yields a plain, JSON-serialisable list of floats.
    server.send("action",{"voltage":action.detach().numpy().tolist()})

step 0 start
step 1 start
step 2 start
step 3 start
step 4 start
step 5 start
step 6 start
step 7 start
step 8 start
step 9 start
step 10 start
step 11 start
step 12 start
step 13 start
step 14 start
step 15 start
step 16 start
step 17 start
step 18 start
step 19 start
step 20 start
step 21 start
step 22 start
step 23 start
step 24 start
step 25 start
step 26 start
step 27 start
step 28 start
step 29 start
step 30 start
step 31 start
step 32 start
step 33 start
step 34 start
step 35 start
step 36 start
step 37 start
step 38 start
step 39 start
step 40 start
step 41 start
step 42 start
step 43 start
step 44 start
step 45 start
step 46 start
step 47 start
step 48 start
step 49 start
step 50 start
step 51 start
step 52 start
step 53 start
step 54 start
step 55 start
step 56 start
step 57 start
step 58 start
step 59 start
step 60 start
step 61 start
step 62 start
step 63 start
step 64 start
step 65 start
step 66 start
step 67 start
step 68 start
step 69 start
step 70 start
step 71 start
st

In [7]:
# Display the inbound queue's repr to inspect the messages still waiting in it.
server.inQueue

<WaitableQueue at 0x1c7f7aadf10 maxsize=0 _queue=[{'command': 'state', 'content': {'baseLinkPos': {'x': -0.155428052, 'y': -0.04696183}, 'baseLinkOrientation': 6.28281832, 'baseLinkVelocity': {'x': 0.0172630455, 'y': -0.00125027169}, 'baseLinkAngularVelocity': -0.000331740855, 'wheelBaseOrientation': [3.00465751, 3.28057742, 3.01581454, 3.11644673], 'wheelSpeed': [-0.03694538, -0.0458802022, -0.03772101, -0.04005027]}}] tasks=2>

In [45]:

# Queue a hand-written "action" command with eight voltage values for sending.
server.send("action",{"voltage":[10,10,10,10,-100,-200,-100,-100]})

In [13]:
# Evaluate the policy on the most recent `state` left over from the rollout cell.
# NOTE(review): the output is scaled by 10000 here but by 5000 in the rollout loop — confirm which is intended.
policy(torch.tensor(state,dtype=torch.float))*10000

tensor([  458.5170,  -498.8667,  1170.6143,  1094.3835, -1245.0892, -1157.3398,
          691.8254,  -910.0997], grad_fn=<MulBackward0>)