In [None]:
import socketio
import urllib
import requests
from time import sleep
from IPython.display import display

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

import torchvision
from torchvision.io import decode_png

In [None]:
# Torch device string on which the models in this notebook are placed.
device = "cpu"

In [None]:
# Root URL of the simulation server (REST endpoints and socket.io).
BASE_URL = "http://localhost:3000"

In [None]:
class SimNs(socketio.ClientNamespace):
    '''Socket.io namespace handler that only logs connection changes.'''

    def on_disconnect(self):
        '''Print a notice when the namespace is disconnected.'''
        print('Disconnected')

    def on_connect(self):
        '''Print a notice when the namespace is connected.'''
        print('Connected')

class Environment:
    '''Client-side handle for one remote simulation.

    Uses REST calls to create and destroy simulations, and a socket.io
    connection to receive rendered frames and send actions back.
    '''

    def __init__(self, sim_id:str=None, base_url=BASE_URL):
        '''
        Parameters:
            sim_id: Id of an existing simulation to attach to. When omitted,
                call `create()` first; the other socket methods need the
                namespace that `create()` registers.
            base_url: Root URL of the simulation server.
        '''
        self.base_url = base_url
        self.sim_id = sim_id
        self._is_stop = False
        self._created_sims = []
        self._on_stop = None

        self._sock = socketio.Client()
        # Expose the socket client's `wait` directly on the environment.
        self.wait = self._sock.wait

        if sim_id:
            self._register_namespace()

        self._episode_ct = 0
        self._max_episodes = None

    def _register_namespace(self):
        '''Register the `/sim-<sim_id>` namespace on the socket client.'''
        self._ns = SimNs(f'/sim-{self.sim_id}')
        self._sock.register_namespace(self._ns)

    @staticmethod
    def _decode_state(data_uri):
        '''Decode a PNG data URI into an image tensor.

        Returns the first three channels of the decoded image.
        '''
        # Local import: a bare `import urllib` does not guarantee that the
        # `urllib.request` submodule has been loaded.
        from urllib.request import urlopen

        # The context manager closes the response once the bytes are read.
        with urlopen(data_uri) as res:
            raw_image:bytes = res.read()
        tensor = torch.frombuffer(raw_image, dtype=torch.uint8)
        # Ignore the last (alpha) channel; per the original author it is
        # always full.
        return decode_png(tensor)[:3]

    def create(self):
        '''Create new simulation'''
        response = requests.post(f'{self.base_url}/sims').json()
        self.sim_id = response['id']
        # Remembered so that destroy() can delete it later.
        self._created_sims.append(self.sim_id)
        self._register_namespace()

    def connect(self, close_on_stop=True):
        '''Connect to socket.io server

        Parameters:
            close_on_stop: When true, install a stop callback that closes
                the connection. This replaces any callback set earlier
                through `on_stop()`.
        '''
        self._sock.on('connect', lambda: print(f'Connected to {self._ns.namespace}'))
        self._sock.connect(self.base_url)

        if close_on_stop:
            self._on_stop = lambda: self.close()

    def init(self, max_episodes=None):
        '''Send initialization signal to simulation server.
        In other words, init simulation-loop.

        Parameters:
            max_episodes: Number of handled frames after which the loop is
                stopped. `None` means no limit.
        '''
        self._is_stop = False
        # Restart the count so the limit applies to this run only.
        self._episode_ct = 0
        self._max_episodes = max_episodes
        self._sock.emit('sim:init', namespace=self._ns.namespace)

    def close(self):
        '''Close socket.io connection'''
        self._sock.disconnect()

    def on_state(self, on_state):
        '''Attach simulation-loop

        Parameters:
            on_state: Function that takes `state`, `reward`, `env` and return `action`.
                `state` is the decoded frame tensor, `reward` is passed
                through as received, and `action` must be an iterable of
                three numbers sent as `x`, `y` and `down`.

        '''

        def on_render(data_uri, reward):
            '''Thin wrapper to decode data_uri and send encoded actions.'''
            try:
                # Convert data_uri to state
                state = self._decode_state(data_uri)

                # Calls user-defined on_state function
                actions:torch.Tensor = on_state(state, reward, self)

                # Increment episode count; a limit of None never triggers
                # (comparing an int with None would raise TypeError).
                self._episode_ct += 1
                if (self._max_episodes is not None
                        and self._episode_ct >= self._max_episodes):
                    self._is_stop = True

                # Intercept the loop and stop it
                if self._is_stop:
                    self.stop()
                    return None

                # Here goes the loop
                action_dict = dict(zip(('x', 'y', 'down'), map(float, actions)))
                self._sock.emit('sim:action', action_dict, self._ns.namespace)

            except Exception as e:
                print(f'An exception occured: {e}')
                # No further action is sent, so the loop is over. Fire the
                # stop callback unless a stop was already requested, so a
                # failure does not leave the connection open.
                already_stopped = self._is_stop
                self._is_stop = True
                if not already_stopped:
                    try:
                        self.stop()
                    except Exception as stop_error:
                        print(f'An exception occured: {stop_error}')
                return None

        self._sock.on('sim:render', on_render, self._ns.namespace)

    def on_stop(self, on_stop):
        '''Attach custom simulation-loop stop callback'''
        self._on_stop = on_stop

    def stop(self, stop=True):
        '''Set stop flag to stop simulation-loop

        The stop callback, if any, is invoked whatever the value of `stop`.
        '''
        self._is_stop = stop
        if self._on_stop:
            self._on_stop()

    def destroy(self):
        '''Destroy any sim created by this environment'''
        for sim_id in self._created_sims:
            requests.delete(f'{self.base_url}/sims/{sim_id}')
        self._created_sims = []

In [None]:
class NeuralNetwork(nn.Module):
    '''Small fully connected network: Linear -> ReLU -> Linear -> Tanh.

    Each input sample is flattened first, and the final Tanh keeps every
    output component within [-1, 1].

    Parameters:
        in_features: Number of values in one flattened input sample. The
            default matches a 3x400x600 input.
        hidden_features: Width of the hidden layer.
        out_features: Number of output values per sample.
    '''

    def __init__(self, in_features=3*400*600, hidden_features=512, out_features=3):
        super().__init__()
        # Attribute names are kept so existing state dicts still load.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
            nn.Tanh()
        )

    def forward(self, x):
        '''Map a batch `x` of shape (N, ...) to a tensor of shape (N, out_features).'''
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)

In [None]:
def create_responder(model):
    '''Build a state callback that asks `model` for the next action.

    The returned function has the `(state, reward, env)` signature and
    returns a tensor of actions.
    '''
    def responder(state: torch.Tensor, reward: torch.Tensor, env: Environment):
        """Turn one frame into actions; on failure stop `env` and return ones."""
        try:
            # Rescale the frame by 1/255, then add a leading batch dimension.
            scaled = state * 1./255
            batch = torch.unsqueeze(scaled, 0)

            # Model input is the frame multiplied by uniform noise and by 10.
            noise = torch.rand(scaled.shape)
            output = model(noise * batch * 10)

            # Drop size-1 dimensions and scale the result by 20.
            return torch.squeeze(output) * 20
        except Exception as err:
            print(err)
            env.stop()
            return torch.ones((3,))

    return responder

In [None]:
# Build the network first, then move it to the configured device.
model = NeuralNetwork()
model = model.to(device)
print(model)

In [None]:
# Start without a simulation id, then ask the server for a new simulation;
# create() stores the returned id and registers its socket.io namespace.
env = Environment()
env.create()

In [None]:
# Attach the 'sim:render' handler that feeds each frame to the model.
env.on_state(create_responder(model))
# close_on_stop=True installs a stop callback that disconnects the socket.
env.connect(close_on_stop=True)
# Start the simulation loop; it is stopped after 100 handled frames.
env.init(max_episodes=100)
# Delegates to the socket client's wait(); presumably blocks until the
# connection is closed -- confirm against python-socketio.
env.wait()

In [None]:
# Delete every simulation that this environment created on the server.
env.destroy()

In [None]:
# Run two independent simulations side by side, each with its own model.
envs = [Environment() for _ in range(2)]

try:
    print('Initializing...')
    for env in envs:
        env.create()
        model = NeuralNetwork().to(device)
        env.on_state(create_responder(model))

    print('Connecting...')
    for env in envs:
        env.connect()

    print('Starting...')
    for env in envs:
        env.init(max_episodes=10)

    print('Working...')
    for env in envs:
        env.wait()
finally:
    # Always delete the server-side simulations, even when a step above
    # raises or the cell is interrupted; destroy() does nothing for an
    # environment that never created one.
    print('Cleaning...')
    for env in envs:
        env.destroy()