In [1]:
from paho.mqtt import client as mqtt_client
import numpy as np
import time
from datetime import datetime
import matplotlib.pyplot as plt
import pickle
import os
import json
from utils import *

import sys
import time
from IPython.display import display, clear_output
import torch
from envs import OptimizationEnv
from sac import SoftActorCritic
from rlutils import ReturnTracker, ReplayBuffer

In [2]:
broker = 'localhost'  # 'broker.hivemq.com'
port = 1883
client_id = 'control-secondary'
username = 'fwagner'
password = '1234'

In [3]:
channel = 0
buffer_size = 10000

path_models = 'models/'
path_buffer = 'data/'

In [4]:
def receive_and_respond(client, userdata, msg):
    
    try:
        if 'acknowledge' in msg.topic:
            data = json.loads(msg.payload)
            userdata['action'] = np.array([data["Action 0"], data["Action 1"]])
            print('acknowledge received and action set to {}'.format(userdata['action']))
            
        elif 'events' in msg.topic:
            clear_output(wait=True)

            # get data
            data = json.loads(msg.payload)
            print('message received: ', data)

            # calc state, reward
            new_state = np.array([data["new_state 0"], data["new_state 1"]])
            reward = float(data["reward"])
            terminated = data["terminated"]
            truncated = data["truncated"]

            # write to buffer
            if not userdata['greedy']:
                userdata['buffer'].store_transition(state = userdata['state'], 
                                              action = userdata['action'],  # should we give new state here?
                                              reward = reward, 
                                              next_state = new_state, 
                                              terminal = terminated)
                print('buffer total: ', userdata['buffer'].buffer_total)

            # update state
            userdata['state'] = new_state

            # get new action
            if buffer.buffer_total > userdata['learning_starts']:
                userdata['agent'] = SoftActorCritic.load(userdata['env'], userdata['path_models'])
                action, _ = userdata['agent'].predict(state, greedy=userdata['greedy'])
                greedy_action, greedy_likelihood = userdata['agent'].predict(state, greedy=True)
                print('greedy action is: {}, with likelihood: {}'.format(greedy_action, np.exp(greedy_likelihood)))
            else:
                action = env.action_space.sample().reshape(1,-1)
                print('Taking random action.')

            # respond
            payload_response = {
                "Action 0": float(action[0,0]),
                "Action 1": float(action[0,1]),
            }

            # plot 
            print('message with greedy={} respond: {}'.format(userdata['greedy'], payload_response))

            result = client.publish('control/channel_{}/set_control'.format(channel), json.dumps(payload_response))
            check(result)
            
        else:
            print('Message topic unknown: ', msg.topic)
        
    except KeyError as err_msg:
        print('KeyError: ', err_msg)
        pass

In [5]:
env = OptimizationEnv(reset_params=False)
state, info = env.reset(new_params=False)
action = env.action_space.sample()

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [6]:
buffer = ReplayBuffer(buffer_size=buffer_size, input_shape=(env.observation_space.shape[0],), n_actions=env.action_space.shape[0], memmap_loc=path_buffer)

In [7]:
agent = SoftActorCritic.load(env, path_models)

In [8]:
buffer.erase()

In [8]:
userdata = {'agent': agent,
            'env': env,
            'state': state,
            'action': action,
            'buffer': buffer,
            'learning_starts': 64, 
            'path_models': path_models,
            'greedy': False,
           }

client = connect_mqtt(broker, port, client_id, username, password, userdata = userdata)

In [9]:
subscribe(client, 'daq/channel_{}/events'.format(channel))
subscribe(client, 'daq/channel_{}/acknowledge'.format(channel))

In [10]:
client.on_message = receive_and_respond

In [11]:
channel_info = {"SubscribeToChannel": channel}
result = client.publish('control/metainfo', json.dumps(channel_info))
check(result)

In [12]:
userdata['greedy'] = True

In [13]:
client.loop_forever()

message received:  {'new_state 0': 0.7635462284088135, 'new_state 1': -0.9469873905181885, 'reward': 0.0010628321421555207, 'terminated': False, 'truncated': False}
greedy action is: [[ 0.7635462 -0.9469874]], with likelihood: [[26.070961]]
message with greedy=True respond: {'Action 0': 0.7635462284088135, 'Action 1': -0.9469873905181885}
acknowledge received and action set to [ 0.76354623 -0.94698739]


KeyboardInterrupt: 

In [14]:
for p in userdata['agent'].policy.parameters():
    print(p.shape)
    if len(p.shape) == 0:
        print(p.item())
    elif len(p.shape) == 1:
        print(p[0].item())
    elif len(p.shape) == 2:
        print(p[0,0].item())

torch.Size([256, 2])
-0.0024986122734844685
torch.Size([256])
-0.04798027500510216
torch.Size([256, 256])
-5.377599444878717e-38
torch.Size([256])
-0.005090019199997187
torch.Size([2, 256])
-0.019565533846616745
torch.Size([2])
0.563148021697998
torch.Size([2, 256])
0.04247375577688217
torch.Size([2])
-0.5413356423377991
