In [None]:
global_var = dict(
    # --- Connection: TCP link with the simulation ---
    ip_address='127.0.0.1',
    port_input=8280,            # listener whose accepted connection we read packets from
    port_output=8281,           # listener whose accepted connection we send control outputs to
    buffer_size=32,             # bytes per incoming packet; unpacked as connection_dim_input
                                # big-endian doubles ('!{n}d'), so keep it == 8 * connection_dim_input
    connection_dim_input=4,     # doubles per incoming packet: network inputs followed by the error value
    connection_dim_output=1,    # doubles per outgoing packet (the control output)
    stop_flag=[-999, -999, -999, -999],  # sentinel returned by receive_data_excpt on failure;
                                         # the main loop stops when the first value matches
    use_pid=True,               # NOTE(review): not referenced in the visible cells

    # --- Network ---
    input_dim=3,                # must equal connection_dim_input - 1 (the error is not an input)
    output_dim=3,               # used by the main loop as (Kp, Ki, Kd)
    hidden_dim=10,
    bias=False,
    learning_rate=0.001,
    model_name='model.pt',

    # --- Training ---
    batch_size=32,
    epochs=10,
)

class Color:
    """ANSI escape sequences for coloured / styled terminal output.

    Usage: f"{Color.RED}message{Color.RESET}" -- RESET clears every attribute.
    """

    # Bright foreground colours
    RED = '\x1b[91m'
    GREEN = '\x1b[92m'
    YELLOW = '\x1b[93m'
    BLUE = '\x1b[94m'
    MAGENTA = '\x1b[95m'
    CYAN = '\x1b[96m'
    WHITE = '\x1b[97m'

    # Reset and text attributes
    RESET = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'

# Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import struct
import socket
import time
from simple_pid import PID



# Architecture

In [None]:
class FeedForwardNet(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        hidden_dim: width of the hidden layer.
        bias: whether both linear layers learn an additive bias.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, bias ) -> None:
        super().__init__()

        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim

        # Parameter-free activations; only ReLU is used in forward(), tanh is currently unused.
        self.relu, self.tanh = nn.ReLU(), nn.Tanh()

        # Created in this order (layer_1 then layer_2) so random initialisation is reproducible.
        self.layer_1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.layer_2 = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        """Map x (last dim == input_dim) to an output whose last dim is output_dim."""
        hidden = self.relu(self.layer_1(x))
        return self.layer_2(hidden)

In [None]:
class ControllerNetwork(nn.Module):
    """FeedForwardNet bundled with its loss function, optimizer and checkpoint path.

    Args:
        input_dim, output_dim, hidden_dim, bias: forwarded to FeedForwardNet.
        lr: Adam learning rate.
        model_name: default checkpoint path used by save()/load() when no
            explicit name is given (previously accepted but ignored).
    """
    def __init__(self, input_dim, output_dim, hidden_dim, lr, bias, model_name ) -> None:
        super(ControllerNetwork, self).__init__()

        self.network = FeedForwardNet(
            input_dim = input_dim,
            output_dim = output_dim,
            hidden_dim = hidden_dim,
            bias = bias
        )
        self.model_name = model_name

        self.loss_fnc = nn.MSELoss()
        # Created after self.network so that self.parameters() includes its weights.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Run the wrapped FeedForwardNet on x."""
        return self.network(x)

    def save(self, name=None):
        """Save the state_dict to `name`, or to self.model_name when name is None."""
        name = self.model_name if name is None else name
        torch.save(self.state_dict(), name)
        print(f"Saved: {name}")

    def load(self, name=None):
        """Load a state_dict from `name`, or from self.model_name when name is None."""
        name = self.model_name if name is None else name
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        self.load_state_dict(torch.load(name))
        print(f"loaded: {name}")

In [None]:
class PID_CONTROLLER():
    """Incremental (velocity-form) discrete PID controller.

    Each call computes the change of the output
        delta_u = Kp*(e[k] - e[k-1]) + Ki*e[k] + Kd*(e[k] - 2*e[k-1] + e[k-2])
    and accumulates it into `controller_output`, which is returned.

    Args:
        threshold: dead band; errors whose magnitude is <= threshold are
            treated as zero (default 0.01, the previous hard-coded value).
    """
    def __init__(self, threshold=0.01):

        self.error_1_step_back = 0   # e[k-1]
        self.error_2_step_back = 0   # e[k-2]
        self.threshold = threshold
        self.controller_output = 0   # accumulated output u[k]

    def control(self, Kp, Ki, Kd, current_error):
        """Advance one step with error `current_error` and return the new output."""

        # Symmetric dead band. The previous `current_error <= threshold` also
        # zeroed every negative error, so the controller could never push back.
        if abs(current_error) <= self.threshold:
            current_error = 0

        proportional_output = Kp * (current_error - self.error_1_step_back)
        integral_output = Ki * current_error
        differential_output = Kd * ( current_error - 2*self.error_1_step_back + self.error_2_step_back )

        delta_controller_output = proportional_output + integral_output + differential_output

        # Shift the error history for the next step.
        self.error_2_step_back = self.error_1_step_back
        self.error_1_step_back = current_error

        self.controller_output += delta_controller_output

        return self.controller_output

# Utils


In [None]:
def setup_socket(ip_address, port_input):
    """Create a TCP socket bound to (ip_address, port_input) and listening (backlog 1).

    Returns:
        The listening socket, or None if creating, binding or listening failed.
        On failure the error is printed and any partially created socket is closed.
        (Previously the failure value was the tuple (None, None), which did not
        match the single socket returned on success.)
    """
    print(f"{Color.YELLOW}ip: {ip_address}, port_input: {port_input}{Color.RESET}")

    socket_nn_input = None
    try:
        socket_nn_input = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        socket_nn_input.bind( (ip_address, port_input) )
        socket_nn_input.listen(1)
    except Exception as e:
        print(f"{Color.RED}An error occurred: {e}{Color.RESET}")
        if socket_nn_input is not None:
            socket_nn_input.close()
        return None

    print(f"{Color.BOLD}{Color.GREEN}Sockets listener created{Color.RESET}")
    return socket_nn_input
    
def accept_connection(socket_server):
    """Block until a client connects to `socket_server`.

    Returns:
        The (connection, address) pair produced by socket.accept().
    """
    accepted = socket_server.accept()
    print(f"{Color.GREEN}Connection accepted!{Color.RESET}")
    return accepted

def receive_data(connection, buffer_size, dim_input):
    """Read exactly `buffer_size` bytes from `connection` and decode them.

    The payload is unpacked as `dim_input` big-endian doubles (format
    '!{dim_input}d'), so `buffer_size` is expected to equal 8 * dim_input.

    Returns:
        list of `dim_input` floats.

    Raises:
        ConnectionError: if the peer closes the connection before a full packet
            arrives. Previously this case looped forever, because recv() keeps
            returning b''.
        struct.error: if buffer_size does not match dim_input doubles.
    """
    received = bytearray()
    while len(received) < buffer_size:
        chunk = connection.recv(buffer_size - len(received))
        if not chunk:
            raise ConnectionError(
                f"connection closed after {len(received)} of {buffer_size} bytes"
            )
        received += chunk

    return list(struct.unpack(f'!{dim_input}d', bytes(received)))

def receive_data_excpt(connection, buffer_size, dim_input, stop_flag, timeout=5.0):
    """Receive one packet of `dim_input` big-endian doubles, returning `stop_flag` on failure.

    Reads up to `buffer_size` bytes from `connection`, waiting at most `timeout`
    seconds for each recv() call, then unpacks them with format '!{dim_input}d'.
    The socket's previous timeout is restored on every exit path. Previously it
    was reset only on success, which left the 5 s timeout in place after a failure.

    Args:
        connection: connected socket to read from.
        buffer_size: number of bytes in one packet.
        dim_input: number of doubles in one packet.
        stop_flag: value returned (as-is) when receiving or decoding fails.
        timeout: per-recv() timeout in seconds (default: the former hard-coded 5.0).

    Returns:
        list of floats on success. Otherwise `stop_flag`: on a timeout, on a socket
        error, or when the peer closes before a full packet arrives (the short
        packet then fails to unpack).
    """
    previous_timeout = connection.gettimeout()
    connection.settimeout(timeout)
    data = b''

    try:
        while len(data) < buffer_size:
            more_data = connection.recv(buffer_size - len(data))
            if not more_data:
                # Peer closed the connection; the short packet fails to unpack below.
                break
            data += more_data
    except socket.timeout as e:
        print(f"\n{e}")
        print(f"{Color.RED}\nTimeout error: No data received within the timeout period{Color.RESET}")
        return stop_flag
    except Exception as e:
        print(f"{Color.RED}\nOther exception occurred: {e}{Color.RESET}")
        print(f"{Color.RED}Maybe due to some other error in the code{Color.RESET}")
        return stop_flag
    finally:
        connection.settimeout(previous_timeout)

    try:
        return list(struct.unpack(f'!{dim_input}d', data))
    except struct.error as e:
        print(f"{Color.RED}\nProblem with unpacking, error: {e}{Color.RESET}")
        print(f"{Color.RED}Received {len(data)} of {buffer_size} bytes; the peer may have closed the connection{Color.RESET}")
        return stop_flag

def send_data(connection, message, dim_output):
    """Serialize `message` as `dim_output` big-endian doubles and send all of it.

    Any failure (e.g. struct.error when len(message) != dim_output, or a broken
    socket) is printed and swallowed, so the caller is never interrupted.
    Returns None.
    """
    try:
        wire_format = f'!{dim_output}d'
        payload = struct.pack(wire_format, *message)
        connection.sendall(payload)
    except Exception as e:
        print(f"Error sending float: {e}")

def close_connections(socket_input, socket_output, receiver, sender):
    """Close both accepted connections, then both listening sockets.

    Arguments that are None are skipped, e.g. a listener that setup_socket
    failed to create.
    """
    for sock in (receiver, sender, socket_input, socket_output):
        if sock is not None:
            sock.close()

    # RESET was missing here, which left all later terminal output green.
    print(f"{Color.GREEN}Sockets closed!{Color.RESET}")



# Dataset

In [None]:
class ControllDataset(Dataset):
    """Growable in-memory dataset whose items end with their label.

    Each stored item (e.g. a 1-D tensor) is split on access into
    (everything except the last element, last element).
    """

    def __init__(self):
        self.data = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        features = sample[:-1]
        label = sample[-1]
        return features, label

    def add_data(self, new_data):
        """Append one item: its features followed by its label."""
        self.data += [new_data]

# Train

In [None]:
def train(model, dataset, epochs, batch_size, shuffle):
    """Run `epochs` passes of mini-batch optimisation over `dataset`.

    `model` must expose `optimizer` and `loss_fnc` attributes (as ControllerNetwork
    does). Each dataset item is a (states, target) pair; the collated targets get a
    trailing dimension via unsqueeze(1). The model is left in eval mode afterwards.

    FIXME(review): the loss compares a zero tensor with the targets and never uses
    `model(states)`. The model parameters therefore receive no gradient, and
    `optimizer.step()` leaves them unchanged. The intended objective (how the predicted
    PID gains should relate to the next-step error) still needs to be defined.

    Args:
        model: network with `optimizer` and `loss_fnc` attributes.
        dataset: map-style dataset yielding (states, target) pairs.
        epochs: number of passes over the data.
        batch_size: DataLoader batch size (a smaller last batch is handled).
        shuffle: whether the DataLoader reshuffles every epoch.

    Returns:
        list of float: mean batch loss for each epoch (0.0 for an empty dataset).
        Previously nothing was returned; callers that ignore the result are unaffected.
    """
    model.train()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    epoch_losses = []

    for _ in range(epochs):

        loss_sum = 0.0
        n_batches = 0

        for states, targets in dataloader:

            model.optimizer.zero_grad()

            targets = targets.unsqueeze(1)

            model(states)  # output currently unused by the loss -- see FIXME above
            # zeros_like matches the (batch, 1) target shape for any batch size. The former
            # fixed (32,) tensor broadcast to (batch, 32) and triggered a size-mismatch
            # warning; the mean loss value is the same.
            loss = model.loss_fnc(torch.zeros_like(targets), targets)

            # backward() raises when nothing in the loss requires grad (e.g. plain targets).
            if loss.requires_grad:
                loss.backward()
                model.optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        epoch_losses.append(loss_sum / n_batches if n_batches else 0.0)

    model.eval()
    return epoch_losses

# Main

In [None]:
# Controller network built entirely from the configuration dictionary.
network = ControllerNetwork(
    input_dim=global_var['input_dim'],
    output_dim=global_var['output_dim'],
    hidden_dim=global_var['hidden_dim'],
    lr=global_var['learning_rate'],
    bias=global_var['bias'],
    model_name=global_var['model_name'],
)

# Placeholder gains: the main loop overwrites pid.tunings with the network's output
# on every step.
# NOTE(review): simple_pid computes its error as setpoint - input, and the main loop
# passes raw_data[-1] (named `error`) as the input. Confirm whether setpoint=1 is intended.
pid = PID(Kp=2.0, Ki=2.0, Kd=2.0, setpoint=1)

In [None]:
# Two listeners: one whose connection we read packets from, one whose connection we send outputs to.
receiver_socket = setup_socket(
    global_var['ip_address'],
    global_var['port_input']
    )

send_socket = setup_socket(
    global_var['ip_address'],
    global_var['port_output']
    )

# setup_socket() prints the error and returns a non-socket value on failure. Stop here
# with a clear message instead of an AttributeError inside accept(). Also close any
# listener that did succeed, so re-running this cell can bind its port again.
_listeners = {'input': receiver_socket, 'output': send_socket}
_failed = [name for name, sock in _listeners.items() if not isinstance(sock, socket.socket)]
if _failed:
    for sock in _listeners.values():
        if isinstance(sock, socket.socket):
            sock.close()
    raise RuntimeError(f"Could not create listening socket(s): {', '.join(_failed)}")

print(f"{Color.CYAN}Waiting someone to connect ...{Color.RESET}")
connection_receiver, addr = accept_connection(receiver_socket)
connection_sender, addr = accept_connection(send_socket)


In [11]:

# https://simple-pid.readthedocs.io/en/latest/reference.html

dataset = ControllDataset()
n_received = 0
n_sent = 0

print(f"\rreceived: {n_received}, sent: {n_sent}", end="")

# The simulation side blocks until it receives a first value from Python,
# so send one neutral control output to start the exchange.
send_data(
        connection=connection_sender,
        message=[0.0],
        dim_output=global_var['connection_dim_output']
        )

# Packet of the previous step; None until the first packet has arrived.
previous_sample = None

while True:

    raw_data = receive_data_excpt(
        connection=connection_receiver,
        buffer_size=global_var['buffer_size'],
        dim_input=global_var['connection_dim_input'],
        stop_flag=global_var['stop_flag']
        )
    n_received += 1

    # receive_data_excpt returns stop_flag on timeout/error; only the first value is compared.
    if raw_data[0] == global_var['stop_flag'][0]:
        break

    # Training sample = previous step's network inputs, with its last slot replaced
    # by the error value received at this step.
    if previous_sample is not None:
        previous_sample[-1] = raw_data[-1]
        # requires_grad kept as before: train()'s loss may be built only from these
        # tensors, in which case backward() needs them to require grad.
        dataset.add_data(torch.tensor(previous_sample, requires_grad=True))
    previous_sample = raw_data

    # Retrain on the whole dataset each time another `batch_size` samples have
    # accumulated (previously a hard-coded 32, equal to the configured batch_size).
    if len(dataset) > 0 and len(dataset) % global_var['batch_size'] == 0:
        train(
            model=network,
            dataset=dataset,
            epochs=global_var['epochs'],
            batch_size=global_var['batch_size'],
            shuffle=True
        )

    network_input = raw_data[:-1]
    error = raw_data[-1]
    # Inference only: no autograd graph is needed for the predicted gains.
    with torch.no_grad():
        pid_parameters = network(torch.tensor([network_input])).numpy()[0]
    pid.tunings = (pid_parameters[0], pid_parameters[1], pid_parameters[2])
    # NOTE(review): simple_pid treats its argument as the measured process value and
    # computes setpoint - input itself. Here it is given raw_data[-1] (named `error`);
    # confirm whether that value is a measurement or already an error.
    output = [pid(error)]

    send_data(
        connection=connection_sender,
        message=output,
        dim_output=global_var['connection_dim_output']
        )
    n_sent += 1

    print(f"\rreceived: {n_received}, sent: {n_sent}, n_data: {len(dataset)} ", end="")



received: 756, sent: 756, n_data: 755 