In [None]:
import sys
!{sys.executable} -m pip install sockets
!{sys.executable} -m pip install numpy
!{sys.executable} -m pip install igraph
!{sys.executable} -m pip install matplotlib
!{sys.executable} -m pip install pyvis

In [None]:
import time

import numpy  as np
import igraph as ig
import matplotlib.pyplot as plt

plt.rc("text", usetex=True)

%run ../SocketServer.ipynb
%run ./2-ImplementationFactor.ipynb
%run ./3-ImplementationPGM.ipynb
%run ./customizedLBP.ipynb

In [None]:
def factorLoader(factor_input, dilimiter):
    """Split a flat delimited string into 4-field factor records.

    The string is a sequence of records laid out as
    ``order, constraintID, predIDs_str, probs_str`` repeated, with every field
    separated by ``dilimiter``.

    Args:
        factor_input: the encoded record string. An empty string yields nothing.
        dilimiter: separator between fields (and between records).

    Yields:
        ``(order, constraintID, predIDs_str, probs_str)`` tuples of strings.

    Raises:
        ValueError: if the number of fields is not a multiple of 4. This is
            checked before anything is yielded, so a caller never processes half of
            a malformed message.
    """
    if not factor_input:
        return
    tokens = factor_input.split(dilimiter)
    if len(tokens) % 4 != 0:
        raise ValueError(
            "Malformed factor input: expected a multiple of 4 fields, got %d" % len(tokens))
    for idx in range(0, len(tokens), 4):
        yield tuple(tokens[idx:idx + 4])

In [None]:
def printNodeCount(graph):
    """Print how many variable nodes, factor nodes and edges ``graph`` has.

    A vertex counts as a factor node when its ``'is_factor'`` attribute is
    truthy; every other vertex counts as a variable node.

    Args:
        graph: object exposing ``vcount()``, ``ecount()`` and an indexable ``vs``
            whose items support ``['is_factor']`` (e.g. an igraph ``Graph``).
    """
    total = graph.vcount()
    factor_nodes = sum(1 for i in range(total) if graph.vs[i]['is_factor'])
    var_nodes = total - factor_nodes

    print("var node count:", var_nodes, "factor node count:", factor_nodes, "total:", factor_nodes + var_nodes)
    print("edges count:", graph.ecount())

In [None]:
def checkDuplicateVar(str_):
    """Reject a factor-graph string in which a factor names a variable twice.

    ``str_`` is parsed as a concatenation of ``name(var1,var2,...)`` groups,
    e.g. ``"f1(a,b)f2(b,c)"``. Segments without a ``(`` argument list have no
    variables to compare, so they are skipped. Such segments include stray
    whitespace after the last ``)`` or a bare control token.

    Args:
        str_: the factor-graph description string.

    Raises:
        ValueError: naming the first factor whose variable list has duplicates.
    """
    for segment in str_.split(')'):
        name, sep, args = segment.strip().partition('(')
        if not sep:
            # No argument list -> nothing that could be duplicated.
            continue
        var_names = args.split(',')
        if len(var_names) != len(set(var_names)):
            print("contain duplicate variables: ", name)
            print(var_names)
            raise ValueError('Duplicated variables: ' + name)

In [None]:
class BeliefPropagationServer(SocketServer):
    """Socket server that runs loopy belief propagation on client-supplied factor graphs.

    Each request is two messages received in order:

    1. ``graph_input``: a factor-graph description. It is checked by
       ``checkDuplicateVar`` and parsed by ``string2factor_graph``.
    2. ``factor_input``: records of ``order, constraintID, predIDs, probs``
       separated by ``DILIMITER_2`` (see ``factorLoader``). ``predIDs`` and
       ``probs`` are ``DILIMITER_1``-separated. Each ``probs`` item is
       ``"<probability>*<repeat count>"``, i.e. run-length encoded.

    The reply is ``predID<DILIMITER_1>marginal`` pairs joined by ``DILIMITER_2``.
    ``END``, ``DILIMITER_1`` and ``DILIMITER_2`` are not defined here; they are
    expected to come from ``SocketServer``.
    """

    def __init__(self, host, port):
        super().__init__(host, port)
        self.MUL_SIGN = "*"   # separates a probability from its repeat count
        self.MAX_ITR = 10     # iteration budget passed to myLBP.belief

    def _buildFactorTable(self, predIDs, probs_str):
        """Decode a run-length-encoded probability string into a factor table.

        Args:
            predIDs: variable IDs of the factor; each gets one axis of size 2.
            probs_str: ``DILIMITER_1``-separated ``"<prob>*<count>"`` items.

        Returns:
            numpy array of shape ``(2,) * len(predIDs)``.

        Raises:
            ValueError: on a malformed item. ``reshape`` also raises ``ValueError``
                if the decoded length is not ``2 ** len(predIDs)``.
        """
        values = []
        for item in probs_str.split(self.DILIMITER_1):
            prob_str, count_str = item.split(self.MUL_SIGN)
            values.extend([float(prob_str)] * int(count_str))
        return np.array(values).reshape((2,) * len(predIDs))

    def _formatMarginals(self, margProb):
        """Serialize ``{predID: marginal}`` as ``predID<DILIMITER_1>prob`` joined by ``DILIMITER_2``."""
        return self.DILIMITER_2.join(
            predID + self.DILIMITER_1 + str(prob) for predID, prob in margProb.items())

    def func(self):
        """Serve requests until an empty message or the ``END`` message arrives.

        For each request, the method:

        - builds the factor graph;
        - installs every received factor table;
        - runs ``myLBP`` with ``MAX_ITR``;
        - plots the graph;
        - sends back the marginals.

        A malformed factor record is reported on stdout and ends the method
        without sending a reply.
        """
        while True:
            print("-"*20)
            graph_input = self.recvMsg()
            factor_input = self.recvMsg()

            if not graph_input or not factor_input:
                self.sendMsg("")
                print("receive empty message")
                break

            # Handle END before validation: the END token is a control message,
            # not a factor-graph string, so it must not reach checkDuplicateVar.
            if graph_input == self.END and factor_input == self.END:
                self.sendMsg(self.END)
                self._stillWorking = False
                print("receive end message")
                break

            checkDuplicateVar(graph_input)

            fg = string2factor_graph(graph_input)
            predIDs_all = set()
            for _order, constraintID, predIDs_str, probs_str in factorLoader(factor_input, self.DILIMITER_2):
                try:
                    predIDs = predIDs_str.split(self.DILIMITER_1)
                    predIDs_all.update(predIDs)
                    probs = self._buildFactorTable(predIDs, probs_str)
                    fg.change_factor_distribution(constraintID, factor(predIDs, probs))
                except Exception as e:
                    print(e)
                    print("constraintID:", constraintID)
                    print("predIDs_str:", predIDs_str)
                    print("probs_str:", probs_str)  # the full encoded record that failed
                    # NOTE(review): returns without sending a reply; check how
                    # the client handles a missing response.
                    return

            printNodeCount(fg.get_graph())

            lbp = myLBP(fg)
            start = time.perf_counter()  # monotonic clock for elapsed-time measurement
            margProb = lbp.belief(predIDs_all, self.MAX_ITR)
            elapsed = time.perf_counter() - start

            print("time needed: ", elapsed)
            plot_factor_graph(fg)
            self.sendMsg(self._formatMarginals(margProb))



In [None]:
# Server setting: host/port handed to BeliefPropagationServer.
HOST, PORT = "127.0.0.1", 8080   # 127.0.0.1 = loopback interface
STATEMENT_ID_PREFIX = "S_"       # not referenced in the cells shown here — presumably used elsewhere; verify

In [None]:
# Construct the belief-propagation server from the configured address.
server = BeliefPropagationServer(host=HOST, port=PORT)

In [None]:
# Start serving. start() is not defined on BeliefPropagationServer, so it comes
# from SocketServer; presumably it accepts a client and drives func() — TODO confirm.
server.start()