## Install All Packages Needed

In [1]:
import sys
!{sys.executable} -m pip install sockets
!{sys.executable} -m pip install numpy
!{sys.executable} -m pip install igraph
!{sys.executable} -m pip install matplotlib
!{sys.executable} -m pip install pyvis

Collecting sockets
  Using cached sockets-1.0.0-py3-none-any.whl (4.5 kB)
Installing collected packages: sockets
Successfully installed sockets-1.0.0
Collecting numpy
  Downloading numpy-1.23.3-cp310-cp310-win_amd64.whl (14.6 MB)
     ---------------------------------------- 14.6/14.6 MB 1.4 MB/s eta 0:00:00
Installing collected packages: numpy
Successfully installed numpy-1.23.3
Collecting igraph
  Downloading igraph-0.10.1-cp39-abi3-win_amd64.whl (2.9 MB)
     ---------------------------------------- 2.9/2.9 MB 694.0 kB/s eta 0:00:00
Collecting texttable>=1.6.2
  Using cached texttable-1.6.4-py2.py3-none-any.whl (10 kB)
Installing collected packages: texttable, igraph
Successfully installed igraph-0.10.1 texttable-1.6.4
Collecting matplotlib
  Downloading matplotlib-3.6.1-cp310-cp310-win_amd64.whl (7.2 MB)
     ---------------------------------------- 7.2/7.2 MB 1.5 MB/s eta 0:00:00
Collecting pillow>=6.2.0
  Using cached Pillow-9.2.0-cp310-cp310-win_amd64.whl (3.3 MB)
Collecting kiw

## Import Packages

In [None]:
import socket
import time

import numpy  as np
import igraph as ig
import matplotlib.pyplot as plt

plt.rc("text", usetex=True)

%run ./2-ImplementationFactor.ipynb
%run ./3-ImplementationPGM.ipynb

## Environment Configuration

The environment configuration includes:
- HOST, PORT: Server settings
- ENCODING_METHOD: Encoding method used for message communication (do not change)
- BUFFER_SIZE: Buffer size for the server and client (do not change)
- END, MSG_BREAK, DILIMITER_1, DILIMITER_2, MUL_SIGN, STATEMENT_ID_PREFIX: Settings for the message format (do not change)
- MAX_ITR: Maximum number of iterations run by loopy belief propagation if it does not converge

In [None]:
# Address and port the server socket binds to
HOST = '127.0.0.1'
PORT = 8080

# Text encoding applied when messages are decoded and responses are encoded
ENCODING_METHOD = 'UTF-8'

# Maximum number of bytes requested by each recv() call (2**20)
BUFFER_SIZE = 1 << 20

# Protocol tokens and separators used inside the messages
END, MSG_BREAK = 'END', 'BREAK'
DILIMITER_1, DILIMITER_2 = ',', '&'
MUL_SIGN = '*'
STATEMENT_ID_PREFIX = 'S_'

# Upper bound on the number of loopy belief propagation iterations
MAX_ITR = 20

In [None]:
class loopy_belief_propagation():
    """Loopy belief propagation over a ``factor_graph``.

    Messages are stored in a dictionary keyed by ``(sender_name, receiver_name)``.
    Every iteration computes all new messages from the messages of the previous
    iteration and only then replaces the old ones (synchronous update).

    Attributes:
        threshold: largest absolute per-entry change between two iterations that
            is still treated as "unchanged" by the convergence test.
    """

    def __init__(self, pgm):
        """Store the graph and initialise every message to all ones.

        Args:
            pgm: the ``factor_graph`` to run inference on.

        Raises:
            TypeError: if ``pgm`` is not a ``factor_graph``.
        """
        if not isinstance(pgm, factor_graph):
            raise TypeError('PGM is not a factor graph')
        if not pgm.is_connected():
            print("[Warning] graph is not connected")

        self.__t       = 0      # number of iterations performed so far
        self.__msg     = {}     # messages of the last completed iteration
        self.__msg_new = {}     # messages being computed in the current iteration
        self.__pgm     = pgm
        self.threshold = 1e-4

        # Initialization of messages: both directions of every edge start as a
        # vector of ones defined over the variable endpoint of that edge.
        graph = self.__pgm.get_graph()
        for edge in graph.es:
            start_index, end_index = edge.tuple[0], edge.tuple[1]
            start_vertex, end_vertex = graph.vs[start_index], graph.vs[end_index]
            start_name, end_name = start_vertex['name'], end_vertex['name']

            variable = end_vertex if start_vertex['is_factor'] else start_vertex
            initial_msg = factor([variable['name']], np.array([1.] * variable['rank']))

            self.__msg[(start_name, end_name)] = initial_msg
            self.__msg[(end_name, start_name)] = initial_msg

            self.__msg_new[(start_name, end_name)] = 0
            self.__msg_new[(end_name, start_name)] = 0

    # Get marginal probability of target variables
    def belief(self, v_names, num_iter):
        """Return the belief of each requested variable.

        Runs message passing up to ``num_iter`` iterations first if fewer have
        been performed so far.

        Args:
            v_names: iterable of variable names.
            num_iter: iteration budget; must not be smaller than the number of
                iterations already performed.

        Returns:
            dict mapping each variable name to the entry at index 1 of the
            normalised product of its incoming factor-to-variable messages
            (presumably the probability of state 1 of a binary variable --
            TODO confirm against the factor implementation).

        Raises:
            ValueError: if more than ``num_iter`` iterations were already run.
        """
        if self.__t > num_iter:
            raise ValueError('Invalid number of iterations. Current number: ' + str(self.__t))
        elif self.__t < num_iter:
            self.__loop(num_iter)

        margProb = {}
        for v_name in v_names:
            incoming_messages = [
                self.get_factor2variable_msg(f_name_neighbor, v_name)
                for f_name_neighbor in self.__neighbor_names(v_name)
            ]
            prob = self.__normalize_msg(joint_distribution(incoming_messages))
            margProb[v_name] = prob.get_distribution()[1]

        return margProb

    # ----------------------- Variable to factor ------------
    def get_variable2factor_msg(self, v_name, f_name):
        """Return the stored message sent from variable ``v_name`` to factor ``f_name``."""
        return self.__msg[(v_name, f_name)]

    def __compute_variable2factor_msg(self, v_name, f_name):
        """Combine the messages reaching ``v_name`` from every neighbour except ``f_name``."""
        incoming_messages = [
            self.get_factor2variable_msg(f_name_neighbor, v_name)
            for f_name_neighbor in self.__neighbor_names(v_name)
            if f_name_neighbor != f_name
        ]

        if not incoming_messages:
            # No other neighbour: send a vector of ones (left unnormalised).
            rank = self.__pgm.get_graph().vs.find(name=v_name)['rank']
            return factor([v_name], np.array([1] * rank))
        return self.__normalize_msg(joint_distribution(incoming_messages))

    # ----------------------- Factor to variable ------------
    def get_factor2variable_msg(self, f_name, v_name):
        """Return the stored message sent from factor ``f_name`` to variable ``v_name``."""
        return self.__msg[(f_name, v_name)]

    def __compute_factor2variable_msg(self, f_name, v_name):
        """Multiply the factor with the other variables' messages and marginalise them out."""
        incoming_messages = [self.__pgm.get_graph().vs.find(f_name)['factor_']]
        marginalization_variables = []
        for v_name_neighbor in self.__neighbor_names(f_name):
            if v_name_neighbor != v_name:
                incoming_messages.append(self.get_variable2factor_msg(v_name_neighbor, f_name))
                marginalization_variables.append(v_name_neighbor)
        return self.__normalize_msg(factor_marginalization(
            joint_distribution(incoming_messages),
            marginalization_variables
        ))

    # ----------------------- Other -------------------------
    def __neighbor_names(self, name):
        """Return the names of all vertices adjacent to the vertex called ``name``."""
        graph = self.__pgm.get_graph()
        return graph.vs[graph.neighbors(name)]['name']

    def __loop(self, num_iter):
        """Update messages until ``num_iter`` iterations are reached or they converge."""
        graph = self.__pgm.get_graph()
        isConverge = False
        while self.__t < num_iter and not isConverge:
            for edge in graph.es:
                start_index, end_index = edge.tuple[0], edge.tuple[1]
                start_name, end_name   = graph.vs[start_index]['name'], graph.vs[end_index]['name']

                if graph.vs[start_index]['is_factor']:
                    compute_forward  = self.__compute_factor2variable_msg
                    compute_backward = self.__compute_variable2factor_msg
                else:
                    compute_forward  = self.__compute_variable2factor_msg
                    compute_backward = self.__compute_factor2variable_msg

                # A vertex whose name starts with STATEMENT_ID_PREFIX always sends
                # the fixed message [0.5, 0.5] when it is the first endpoint of the edge.
                # NOTE(review): only the first endpoint is checked, so the result
                # depends on how the edge is stored -- confirm this is intended.
                if str(start_name).startswith(STATEMENT_ID_PREFIX):
                    self.__msg_new[(start_name, end_name)] = factor([start_name], np.array([0.5, 0.5]))
                else:
                    self.__msg_new[(start_name, end_name)] = compute_forward(start_name, end_name)
                self.__msg_new[(end_name, start_name)] = compute_backward(end_name, start_name)

            # Converged only when no entry of any message moved by more than the
            # threshold. (Previously a message counted as changed only if *all*
            # of its entries moved, which could stop the iteration too early.)
            isConverge = True
            for key, new_msg in self.__msg_new.items():
                difference = np.abs(self.__msg[key].get_distribution() - new_msg.get_distribution())
                if np.any(difference > self.threshold):
                    isConverge = False
                    break

            self.__msg.update(self.__msg_new)
            self.__t += 1

    def __normalize_msg(self, message):
        """Return a copy of ``message`` whose distribution is divided by its sum.

        There is no guard against a zero sum.
        """
        distribution = message.get_distribution()
        return factor(message.get_variables(), distribution / np.sum(distribution))

In [None]:
def factorLoader(factor_input):
    """Yield consecutive 4-tuples of fields from a DILIMITER_2-separated string.

    The string is split on DILIMITER_2 and the resulting fields are handed out
    four at a time, in order. An IndexError is raised when the last group has
    fewer than four fields.
    """
    fields = factor_input.split(DILIMITER_2)
    start = 0
    while start < len(fields):
        yield fields[start], fields[start + 1], fields[start + 2], fields[start + 3]
        start += 4

In [None]:
def printGraphInput(str_):
    """Print each ``name(arg1,arg2,...)`` entry of ``str_`` as its name and argument list.

    Entries are delimited by ')'; empty pieces are skipped. An entry without
    a '(' raises IndexError.
    """
    print("graph input:")
    for chunk in str_.split(')'):
        if chunk == '':
            continue
        pieces = chunk.split('(')
        print(pieces[0], pieces[1].split(','))

In [None]:
def checkDuplicateVar(str_):
    """Raise ValueError if any ``name(arg1,arg2,...)`` entry repeats an argument.

    Entries are delimited by ')'; empty pieces are skipped. Before raising,
    the offending entry name and its argument list are printed. An entry
    without a '(' raises IndexError.
    """
    for chunk in str_.split(')'):
        if chunk == '':
            continue
        pieces = chunk.split('(')
        names = pieces[1].split(',')
        if len(set(names)) == len(names):
            continue
        print("contain duplicate variables: ", pieces[0])
        print(names)
        raise ValueError('Duplicated variables: ' + pieces[0])

In [None]:
def printFactorInput(factor_input):
    """Print every four-field record contained in ``factor_input``.

    The records are obtained from ``factorLoader`` so that this debug helper
    always splits the input exactly the way the server does. Each record is
    printed on its own line, prefixed with "Node:".
    """
    print("factor input:")
    for record in factorLoader(factor_input):
        print("Node:", *record)

In [None]:
def _recvSection(conn, decoder):
    """Read chunks from ``conn`` until one decodes to exactly MSG_BREAK.

    Returns the concatenation of all chunks received before the marker, or
    None when recv() returns no data before the marker is seen.
    """
    parts = []
    while True:
        message = conn.recv(BUFFER_SIZE)
        if not message:
            return None
        message_str = decoder.decode(message)
        if message_str == MSG_BREAK:
            return "".join(parts)
        parts.append(message_str)


def recvMsg(conn):
    """Receive the graph section and the factor section of one request.

    Each section is a sequence of chunks terminated by a chunk equal to
    MSG_BREAK.

    Args:
        conn: connected socket-like object providing ``recv(bufsize)``.

    Returns:
        ``(graph_input, factor_input)`` as strings, or ``(None, None)`` when
        recv() returns no data before both sections are complete.

    NOTE(review): the marker is only recognised when it arrives as its own
    recv() chunk. TCP is a byte stream and does not preserve send boundaries,
    so this relies on how the client paces its writes -- confirm.
    """
    # Local import keeps this cell self-contained.
    import codecs

    # An incremental decoder buffers a multi-byte character that is split
    # across two recv() calls instead of raising UnicodeDecodeError.
    decoder = codecs.getincrementaldecoder(ENCODING_METHOD)()

    # Read Graph Input
    graph_input = _recvSection(conn, decoder)
    if graph_input is None:
        return None, None

    # Read Factor Input
    factor_input = _recvSection(conn, decoder)
    if factor_input is None:
        return None, None

    print("graph input size: ", len(graph_input))
    print("factor input size: ", len(factor_input))

    return graph_input, factor_input

In [None]:
def printNodeCount(graph):
    """Print how many vertices are factor nodes, how many are not, and the edge count.

    A vertex counts as a factor node when its 'is_factor' attribute is truthy.
    """
    total = graph.vcount()
    factor_node_count = sum(1 for idx in range(total) if graph.vs[idx]['is_factor'])
    var_node_count = total - factor_node_count

    print("var node count:", var_node_count, "factor node count:", factor_node_count, "total:", factor_node_count + var_node_count)
    print("edges count:", graph.ecount())

In [None]:
def startServer():
    """Serve belief-propagation requests on (HOST, PORT) until told to stop.

    For every accepted connection, requests are read with ``recvMsg`` in a loop:

    * an empty or missing section sends an empty response and ends the connection;
    * graph and factor sections both equal to END send END back and stop the server;
    * otherwise the factor graph is built, its factor distributions are filled
      from the factor section, loopy belief propagation runs for at most
      MAX_ITR iterations, and the result is sent back as
      ``predID<DILIMITER_1>prob`` items joined by DILIMITER_2.

    If a factor record cannot be parsed or applied, the error and the record
    are printed and the function returns, which closes the connection and the
    listening socket without sending a response.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((HOST, PORT))
        s.listen()
        stillWorking = True
        while stillWorking:
            conn, addr = s.accept()
            with conn:
                print(f"Connected by {addr}")

                while True:
                    print("-"*20)
                    graph_input, factor_input = recvMsg(conn)

                    if not graph_input or not factor_input:
                        response = "".encode(ENCODING_METHOD)
                        conn.sendall(response)
                        break

                    # The END request must be recognised before any parsing:
                    # "END" is not a graph description, and checkDuplicateVar
                    # raises IndexError on it.
                    if graph_input == END and factor_input == END:
                        output_str = END
                        response = output_str.encode(ENCODING_METHOD)
                        print("response:", response)
                        conn.sendall(response)
                        print("Terminate server...")
                        stillWorking = False
                        break

                    checkDuplicateVar(graph_input)
                    fg = string2factor_graph(graph_input)

                    predIDs_all = set()
                    for order, constraintID, predIDs_str, probs_str in factorLoader(factor_input):
                        try:
                            predIDs = predIDs_str.split(DILIMITER_1)
                            predIDs_all.update(predIDs)

                            # One axis of size 2 per predicate.
                            shape = (2,) * len(predIDs)

                            # Each token is "<probability><MUL_SIGN><repeat count>".
                            probs = []
                            for probs_token in probs_str.split(DILIMITER_1):
                                value_str, count_str = probs_token.split(MUL_SIGN)
                                count = int(count_str)
                                prob = float(value_str)
                                probs.extend([prob] * count)

                            probs = np.array(probs).reshape(shape)

                            fg.change_factor_distribution(constraintID, factor(predIDs, probs))
                        except Exception as e:
                            # NOTE(review): this stops the whole server and the
                            # client gets no response -- confirm this is intended.
                            print(e)
                            print("constraintID:", constraintID)
                            print("predIDs_str:", predIDs_str)
                            print("probs_str:", probs_str)
                            return

                    printNodeCount(fg.get_graph())

                    lbp = loopy_belief_propagation(fg)
                    start = time.time()
                    margProb = lbp.belief(predIDs_all, MAX_ITR)
                    end = time.time()

                    print("time needed: ", end - start)
                    plot_factor_graph(fg)

                    output_str = DILIMITER_2.join(
                        predID + DILIMITER_1 + str(prob)
                        for predID, prob in margProb.items()
                    )
                    response = output_str.encode(ENCODING_METHOD)
                    conn.sendall(response)

In [None]:
# Run the server. This cell blocks until a client sends the END request or a
# factor record fails to parse (see startServer).
startServer()