In [None]:
import traceback
import urllib
from time import sleep

import requests
import socketio
from IPython.display import display

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

import torchvision
from torchvision.io import decode_png

from infant import Environment

In [None]:
# Torch device the policy network weights are placed on (via `.to(device)`).
device = "cpu"

In [None]:
# Base URL of the local simulation server.
# NOTE(review): not referenced by the cells in this section — confirm it is still needed.
BASE_URL = "http://localhost:3000"

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected policy: flattened input -> ReLU hidden layer -> Tanh actions.

    The defaults reproduce the original architecture: a 3*400*600-element
    input, a 512-unit hidden layer and 3 outputs. Because of the final Tanh,
    every output lies in [-1, 1].

    Note: with the defaults, the first Linear layer alone holds
    3*400*600*512 = 368,640,000 weights.
    """

    def __init__(self, in_features: int = 3 * 400 * 600, hidden: int = 512, n_actions: int = 3):
        super().__init__()
        self.flatten = nn.Flatten()
        # Attribute names are kept unchanged so existing state_dict keys still load.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a batch ``x`` to actions of shape ``(batch, n_actions)``.

        ``x`` is flattened from dim 1 onward, so each sample must contain
        exactly ``in_features`` elements.
        """
        x = self.flatten(x)
        # Tanh-squashed outputs, not raw logits.
        actions = self.linear_relu_stack(x)
        return actions

In [None]:
def create_responder(model, action_scale: float = 20.0, noise_scale: float = 10.0):
    """Wrap ``model`` in a ``responder(state, reward, env)`` callback.

    Args:
        model: callable mapping a batched float tensor (batch dim added here)
            to a batch of actions, e.g. ``NeuralNetwork``.
        action_scale: factor applied to the model output (default 20, the
            original hard-coded value).
        noise_scale: factor applied to the noise-perturbed input (default 10,
            the original hard-coded value).

    Returns:
        The responder callable, suitable for passing to ``env.on_state``.
    """
    # Resolve the model's device once so inputs follow the weights
    # (e.g. when `device` is set to a GPU). Non-Module callables are left alone.
    params = getattr(model, 'parameters', None)
    first_param = next(iter(params()), None) if callable(params) else None
    model_device = first_param.device if first_param is not None else None

    def responder(state: torch.Tensor, reward: torch.Tensor, env: 'Environment'):
        """Get state and return an action tensor.

        ``reward`` is accepted but not used. On any exception the full
        traceback is printed, ``env.stop()`` is called and a fallback action
        of ``torch.ones((3,))`` is returned.
        """
        try:
            # Scale to [0, 1] — assumes 0-255 pixel values (TODO confirm).
            image = state * 1./255
            images = torch.unsqueeze(image, dim=0)  # add a batch dimension
            # Multiply by uniform noise; presumably exploration for an untrained
            # policy. NOTE(review): confirm this perturbation is intended.
            X = torch.rand(image.shape) * images * noise_scale
            if model_device is not None:
                X = X.to(model_device)

            # Inference only: avoid building an autograd graph on every step.
            with torch.no_grad():
                actions = model(X)

            # Drop only the batch dimension, then scale the actions.
            actions = actions.squeeze(0) * action_scale
            return actions
        except Exception:
            # Print the full traceback rather than just str(e), which can be
            # uninformative (e.g. a bare key for KeyError).
            traceback.print_exc()
            env.stop()
            return torch.ones((3,))

    return responder

In [None]:
# Build the policy network, then place its weights on the configured device
# (Module.to returns the module itself).
model = NeuralNetwork()
model = model.to(device)
print(model)

In [None]:
# Instantiate a single environment and run its create() step before a policy
# is attached. NOTE(review): what create() provisions is defined in `infant` — confirm.
env = Environment()
env.create()

In [None]:
MAX_EPISODES = 100  # episode budget for the single-environment run

# Attach the policy as the state callback, connect, start the episodes and
# wait on the environment. `close_on_stop=True` presumably closes the
# connection once the environment stops — confirm against `infant`.
env.on_state(create_responder(model))
env.connect(close_on_stop=True)
env.init(max_episodes=MAX_EPISODES)
env.wait()

In [None]:
# Tear down the single environment created earlier in this notebook.
env.destroy()

In [None]:
# Run several environments side by side, each driven by its own freshly
# initialised policy network.
N_ENVS = 2
N_EPISODES = 10

envs = [Environment() for _ in range(N_ENVS)]
created_envs = []  # only environments whose create() succeeded are destroyed

try:
    print('Initializing...')
    # Distinct loop name: do not clobber the `env` / `model` globals from earlier cells.
    for parallel_env in envs:
        parallel_env.create()
        created_envs.append(parallel_env)
        parallel_env.on_state(create_responder(NeuralNetwork().to(device)))

    print('Connecting...')
    for parallel_env in envs:
        # NOTE(review): unlike the single-env run, close_on_stop is not passed
        # here — confirm whether that difference is intentional.
        parallel_env.connect()

    print('Starting...')
    for parallel_env in envs:
        parallel_env.init(max_episodes=N_EPISODES)

    print('Working...')
    for parallel_env in envs:
        parallel_env.wait()
finally:
    # Release environments even if a step above raised or the cell was interrupted.
    print('Cleaning...')
    for parallel_env in created_envs:
        parallel_env.destroy()