In [None]:
import json
import websockets, asyncio
import threading

class WaitableQueue(asyncio.Queue):
    def __init__(self):
        super().__init__()
        self.event = threading.Event()

    def put(self, item):
        super().put_nowait(item)
        self.event.set()

    def get(self,timeout=3):
        if self.event.wait(timeout):
            res = super().get_nowait()
            if super().empty():
                self.event.clear()
            return res
        else:
            raise TimeoutError("Environement is not responding.")

# the server
class Server:
    def __init__(self):
        self.inQueue = WaitableQueue()
        self.outQueue = WaitableQueue()
        self.debug = True
        self.ws = None

    def start(self):
        threading.Thread(target=self.message_sender_loop).start()
        asyncio.run(self.main())

    async def main(self):
        try:
            async with websockets.serve(self.echo, "localhost", 8765):
                await asyncio.Future()  # run forever
        except websockets.exceptions.ConnectionClosedError as e: print(e)

    async def echo(self,websocket):
        self.ws = websocket
        print('connect')
        #asyncio.create_task(self.message_sender_loop())
        async for message in websocket:
            try:
                self.recv(json.loads(message))
            except json.decoder.JSONDecodeError:
                self.recv(message)

    def recv(self,message):
        self.inQueue.put(message)
        if self.debug:
            print("recv: ",message)
    
    def send(self,command:str, content):
        self.outQueue.put({'command':command,'content':content})

    def message_sender_loop(self):
        while True:
            try:
                message = self.outQueue.get(None)
                asyncio.run(self.ws.send(json.dumps(message, indent=4)))
            except websockets.exceptions.ConnectionClosedError:
                print("Connection closed")
                break
            except Exception as e:
                print(e)
                break

# start the server in a separate thread to avoid blocking
import threading
server = Server()
t=threading.Thread(target=server.start)
t.start()

# the interface to the server
class WSManager:
    def __init__(self,server:Server):
        self.debug = False
        self.server = server

#server.send("action",{"voltage":[1,0,0,0,100,200,100,100]})

Exception in thread Thread-5:
Traceback (most recent call last):
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\a931e\AppData\Local\Temp\ipykernel_18808\63983185.py", line 33, in start
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\asyncio\runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\asyncio\base_events.py", line 642, in run_until_complete
    return future.result()
  File "C:\Users\a931e\AppData\Local\Temp\ipykernel_18808\63983185.py", line 37, in main
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\server.py", line 1068, in __aenter__
    return await self
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\server.py", line 1086, in __await_impl__
    server = await sel

In [2]:
import numpy as np
import torch
def flatten(list_of_lists):
    if len(list_of_lists) == 0:
        return list(list_of_lists)
    if hasattr(list_of_lists[0], '__iter__'):
        return flatten(list_of_lists[0]) + flatten(list_of_lists[1:])
    return list(list_of_lists[:1]) + flatten(list_of_lists[1:])
def decomposeCosSin(angle):
    return [np.cos(angle), np.sin(angle)]
def processFeature(state:dict,targetPos):
    feature = []
    feature.append(state['baseLinkPos']['x']-targetPos[0].item())
    feature.append(state['baseLinkPos']['y']-targetPos[1].item())
    feature.append(decomposeCosSin(state['baseLinkOrientation']))
    feature.append(state['baseLinkVelocity']['x'])
    feature.append(state['baseLinkVelocity']['y'])
    feature.append(state['baseLinkAngularVelocity'])
    feature.append(decomposeCosSin(state['wheelBaseOrientation']))
    feature.append(state['wheelSpeed'])
    feature = flatten(feature)
    return feature

In [127]:
from torch import nn
class Q(nn.Module):
    def __init__(self,state_size,action_size,hidden_size):
        super(Q, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.fc = nn.Sequential(
        nn.Linear(state_size+action_size,hidden_size),
        nn.LeakyReLU(),
        nn.Linear(hidden_size,hidden_size),
        nn.LeakyReLU(),
        nn.Linear(hidden_size,1)
        )

    def forward(self,state,action):
        return self.fc(torch.cat([state,action],dim=1))

class Policy(nn.Module):
    def __init__(self,state_size,action_size,hidden_size):
        super(Policy, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.fc = nn.Sequential(
        nn.Linear(state_size,hidden_size),
        nn.LeakyReLU(),
        nn.Linear(hidden_size,hidden_size),
        nn.LeakyReLU(),
        nn.Linear(hidden_size,action_size)
        )

    def forward(self,state):
        return self.fc(state)

import random, ou
class Environment:
    def __init__(self,ws_server : Server,device = 'cpu'):
        self.ws = ws_server
        self.replayBuffer = []
        self.t = 0
        self.t_episode = 0
        self.device = device
        self.prevState = None
        self.prevAction = None
        self.pos = None
        self.targetPos = None
        self.ouNoise = ou.ND_OUNoise(8,0, 0.1, 0.2, 0, 100)
        self.noiseIntensity = 0.5
        self.targetRelPos = torch.tensor([0.,-3.])

    def restartEpisode(self):
        self.targetPos = self.pos + self.targetRelPos
        self.t_episode = 0
        self.prevState = None
        self.ws.send("new target",{"pos":{'x':self.targetPos[0].item(),'y':0, 'z':self.targetPos[1].item()}})

    def calculateReward(self,pos,targetPos):
        return -torch.dist(pos,targetPos)

    def terminateCondition(self,pos,targetPos):
        return torch.dist(pos,targetPos)<0.5 or torch.dist(pos,targetPos)>10

    def getPos(self,state):
        return torch.tensor([state['baseLinkPos']['x'],state['baseLinkPos']['y']],dtype=torch.float32)

    def update(self, policy: torch.nn.Module):
        raw_state = None
        while not server.inQueue.empty():
            message = server.inQueue.get()
            if message['command'] == 'state':
                raw_state = message['content']
        if raw_state:
            # If the environment returns a state, the step is finnished.
            self.pos = self.getPos(raw_state)
            if self.t_episode > 100 or self.t == 0 or self.terminateCondition(self.pos,self.targetPos):
                self.restartEpisode()
            state = torch.tensor(processFeature(raw_state,self.targetPos),dtype=torch.float32).to(self.device)
                
            # Add the experience to the replay buffer.
            if self.t_episode > 0: # Skip the first step.
                reward = self.calculateReward(self.pos,self.targetPos)- self.calculateReward(self.prevPos,self.targetPos)#-(torch.max(torch.zeros_like(self.prevAction),(torch.abs(self.prevAction)-2000))**2).mean()*0.001
                self.replayBuffer.append((state,self.prevAction,reward,self.prevState))
                if len(self.replayBuffer) > 5000:
                    self.replayBuffer.pop(random.randint(0,len(self.replayBuffer)-1))
            
            # Give the new action to enable the environment to continue on the next step.
            with torch.no_grad():
                policy.eval()
                action = policy(state).detach().cpu()
                action += self.ouNoise.__next__()*self.noiseIntensity
                #action[5]=action[6]=action[7]=action[4]
                action = torch.clamp(action,-4000,4000)
            self.ws.send("action",{"voltage":list(action.detach().numpy().tolist())})

            
            self.t+=1
            self.t_episode += 1
            self.prevState = state
            self.prevAction = action
            self.prevPos = self.getPos(raw_state)

    def sampleExperience(self,batch_size):
        ns,a,r,s = zip(*random.sample(self.replayBuffer,batch_size))
        return torch.stack(ns),torch.stack(a),torch.stack(r),torch.stack(s)
#env = Environment(server,device)


In [128]:
device = 'cuda'
env = Environment(server,device)
tau = 0.001
gamma = 0.9
batch_size = 128
q = Q(state_size=19,action_size=8,hidden_size=512)
q_target = Q(state_size=19,action_size=8,hidden_size=512)
policy = Policy(state_size=19,action_size=8,hidden_size=512)
policy_target = Policy(state_size=19,action_size=8,hidden_size=512)

In [129]:
q = Q(state_size=19,action_size=8,hidden_size=512)

In [130]:
optimQ = torch.optim.Adam(q.parameters(),lr=0.001)
optimPolicy = torch.optim.Adam(policy.parameters(),lr=0.0001)

In [138]:
from torch.nn import functional as F
def soft_update_target(target:nn.Module, source:nn.Module,tau):
    for t, s in zip(target.parameters(), source.parameters()):
        t.data.copy_(
            (1. - tau) * t.data + tau * s.data)

server.debug = False
q.train()
policy.train()
q_target.eval()
policy_target.eval()

q.to(device)
policy.to(device)
q_target.to(device)
policy_target.to(device)

policy_loss = torch.tensor(torch.nan)

import time
# Fill the replay buffer with random experiences.
while len(env.replayBuffer) < batch_size+1:
    env.update(policy)
    time.sleep(0.02)

# Training.
for t in range(100000000):
    env.update(policy)

    new_state, action, reward, old_state = env.sampleExperience(batch_size)
    new_state = new_state.to(device)
    old_state = old_state.to(device)
    action = action.to(device)
    reward = reward.to(device)

    q_target.eval()
    policy_target.eval()
    
    with torch.no_grad():
        action_ = policy_target(new_state)
        torch.clamp(action_,-4000,4000)
        new_value = q_target(new_state,action_).detach()
        target_value = reward.unsqueeze(1) #+ gamma*new_value
    
    #target_value = reward.unsqueeze(1)
    q.train()
    policy.train()

    # Update the Q network.
    q_loss = F.mse_loss(q(old_state,action),target_value)
    optimQ.zero_grad()
    q_loss.backward()
    optimQ.step()
    
    if q_loss.item()<50:
        # Update the policy network.
        q.eval()
        action = policy(old_state)
        voltage_penalty = (torch.max(torch.zeros_like(action),(torch.abs(action)-2000))**2).mean()*1
        policy_loss = -q(old_state,action).mean() + voltage_penalty
        optimPolicy.zero_grad()
        policy_loss.backward()
        optimPolicy.step()
    
    # Update the target networks.
    soft_update_target(q_target,q,tau)
    #soft_update_target(policy_target,policy,tau)
    
    if t % 1000 == 0:
        print(f"q loss: {q_loss.item()}, policy loss: {policy_loss.item()}")
    


Connection closed
q loss: 0.0011696324218064547, policy loss: -0.5803424119949341
connect
q loss: 0.002539626555517316, policy loss: -0.4132908582687378
q loss: 0.0036600064486265182, policy loss: -5.776938438415527
q loss: 0.0032082637771964073, policy loss: -6.259049892425537
q loss: 0.00295313261449337, policy loss: -6.366705417633057
q loss: 0.0024401224218308926, policy loss: -8.23759651184082
q loss: 0.0021614395081996918, policy loss: -7.441329002380371
q loss: 0.00507876044139266, policy loss: -8.41982650756836


connection handler failed
Traceback (most recent call last):
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\protocol.py", line 944, in transfer_data
    message = await self.read_message()
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\protocol.py", line 1013, in read_message
    frame = await self.read_data_frame(max_size=self.max_size)
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\protocol.py", line 1089, in read_data_frame
    frame = await self.read_frame(max_size)
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\protocol.py", line 1144, in read_frame
    frame = await Frame.read(
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\site-packages\websockets\legacy\framing.py", line 70, in read
    data = await reader(2)
  File "c:\Users\a931e\Anaconda3\envs\nn\lib\asyncio\streams.py", line 721, in readexactly
    raise exceptions.IncompleteReadError(incomplete, n)
asyncio.

connect
q loss: 0.0018345353892073035, policy loss: -7.121649265289307
q loss: 0.005829949863255024, policy loss: -1.8156616687774658
q loss: 0.0025570630095899105, policy loss: -6.7298970222473145
q loss: 0.00744173489511013, policy loss: -5.754608631134033
q loss: 0.0029427488334476948, policy loss: -5.638504505157471
q loss: 0.0016132415039464831, policy loss: -6.957573890686035
q loss: 0.0015990696847438812, policy loss: -6.50506591796875
q loss: 0.0071762800216674805, policy loss: -0.7182192802429199
q loss: 0.0026863086968660355, policy loss: -2.2355542182922363
q loss: 0.00300349248573184, policy loss: -2.094534397125244
q loss: 0.0021662903018295765, policy loss: -2.0946590900421143


KeyboardInterrupt: 

In [136]:
env.noiseIntensity = 0.1
env.targetRelPos = torch.tensor([0,3],dtype=torch.float32)

In [106]:
new_state, action, reward, old_state = env.sampleExperience(batch_size)
new_state = new_state.to(device)
old_state = old_state.to(device)
action = action.to(device)
reward = reward.to(device)
print(q(old_state,action)[0])
action = policy(old_state)
print(q(old_state,action)[0])
print(reward[0])

tensor([969.1476], device='cuda:0', grad_fn=<SelectBackward0>)
tensor([2144.1477], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(-0.0388, device='cuda:0')


In [144]:
server.ws.send({"voltage":list(action.detach().cpu().numpy().tolist())})

<coroutine object WebSocketCommonProtocol.send at 0x0000014D8CB5C640>

In [97]:
new_state, action, reward, old_state = env.sampleExperience(batch_size)
reward


tensor([-0.2141, -0.2820, -0.2855, -0.3078, -0.1516, -0.0655, -0.3429, -0.2886,
        -0.3183, -0.1440, -0.1253, -0.3222, -0.2555, -0.3395, -0.1388, -0.1747,
        -0.2437, -0.2255,  0.0972, -0.3279, -0.3285, -0.2891, -0.1524, -0.3033,
        -0.3226, -0.1426, -0.2665, -0.1397, -0.2500, -0.3043, -0.3331, -0.1957,
        -0.0817, -0.1915, -0.2449,  0.0132, -0.2816, -0.2662, -0.2573, -0.2298,
        -0.1589, -0.2701, -0.3098, -0.0501, -0.0601, -0.3006, -0.2713, -0.0870,
        -0.2514, -0.3083, -0.1783, -0.2688, -0.2306, -0.2980, -0.1285,  0.0586,
        -0.3285, -0.2772, -0.1705, -0.1688, -0.2943, -0.1653, -0.2915, -0.3086,
        -0.2993, -0.1145, -0.0723, -0.3109, -0.1668, -0.3018, -0.1963, -0.2858,
        -0.3372, -0.2997, -0.3229, -0.3214, -0.2191, -0.3148, -0.3162, -0.2828,
         0.1556,  0.0275, -0.2177, -0.3105, -0.2920, -0.2034, -0.2636, -0.1015,
        -0.3027,  0.1306, -0.2835, -0.3080,  0.0337, -0.1594, -0.2613, -0.2097,
        -0.3227, -0.1678, -0.2023, -0.33

In [16]:
server.outQueue

<WaitableQueue at 0x131ce69a580 maxsize=0 _queue=[{'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}, {'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}, {'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}, {'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}, {'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}, {'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}, {'command': 'action', 'content': {'voltage': [0, 0, 0, 0, 0, 0, 0, 0]}}] tasks=8>