In [1]:
import numpy as np
import sympy
import gurobipy as gp
from gurobipy import GRB
import time

In [2]:
def PrimeBranching(x, name = 'Naive', arg = None):
    '''
    *** Make sure that x.lb >= arg[1] (k the remainder) ***
    Find the element of x which is furthest from its closest admissible prime number.

    name = 'Naive' : every prime is admissible.
    name = 'Modulo': with arg = [n, k], the first n entries of x must be primes
                     of the form 12*m + k and the remaining entries primes of
                     the form 6*m + (k % 6).
    Any other name leaves all distances at -1, so index 0 is returned.

    Input argument: solution from the model m; m.x
    arg = [n,k] n: number of leading entries checked modulo 12, k: residue modulo 12

    Output: index of the branching variable (as returned by np.argmax)

    Raises: ValueError if name = 'Modulo' and k is not one of 1, 5, 7, 11
            (for any other k the prime search could never terminate), or, from
            sympy.prevprime, if no admissible prime lies below an entry of x.
    '''
    if name == 'Naive':
        dist = [_prime_gap(val) for val in x]
    elif name == 'Modulo':
        n,k = arg
        # Only these residue classes modulo 12 contain primes beyond 2 and 3;
        # searching any other class would loop forever.
        if k not in (1, 5, 7, 11):
            raise ValueError("residue modulo 12 must be 1, 5, 7 or 11, got {0}".format(k))
        k0 = k%6
        dist = [_prime_gap(val, 12, k) if ind < n else _prime_gap(val, 6, k0)
                for ind, val in enumerate(x)]
    else:
        dist = [-1 for _ in x]
    br_ind = np.argmax(dist)
    return br_ind

def _prime_gap(val, modulus = 1, residue = 0):
    '''
    Distance from val to the closest prime p with p % modulus == residue.

    With the defaults (modulus = 1, residue = 0) every prime qualifies.
    '''
    round_val = int(round(val,0))
    if sympy.isprime(round_val) and (round_val%modulus == residue):
        return abs(val - round_val)
    # nextprime / prevprime are strict, so feeding the last candidate back in
    # walks through the primes one at a time.
    next_prime = sympy.nextprime(val)
    while(next_prime%modulus != residue):
        next_prime = sympy.nextprime(next_prime)
    prev_prime = sympy.prevprime(val)
    while(prev_prime%modulus != residue):
        prev_prime = sympy.prevprime(prev_prime)
    return min(val - prev_prime, next_prime - val)

In [3]:
def Branch_val(x, br_ind, name = 'Naive', arg = None):
    '''
    Compute the two branching values for the variable at position br_ind.

    name = 'Naive' : the primes directly below and above x[br_ind].
    name = 'Modulo': with arg = [n, k], the closest primes below and above
                     x[br_ind] that are congruent to k modulo 12 when
                     br_ind < n, and to k % 6 modulo 6 otherwise.

    Output: branch_down, branch_up := branching values
    '''
    if name == 'Naive':
        pivot = x[br_ind]
        branch_down = sympy.prevprime(pivot)
        branch_up = sympy.nextprime(pivot)
    elif name == 'Modulo':
        n_lead, residue = arg
        reduced = residue%6
        # Leading variables are matched modulo 12, the remaining ones modulo 6.
        if br_ind < n_lead:
            modulus, target = 12, residue
        else:
            modulus, target = 6, reduced
        pivot = x[br_ind]
        # Walk downwards until a prime in the wanted residue class is met.
        branch_down = sympy.prevprime(pivot)
        while branch_down%modulus != target:
            branch_down = sympy.prevprime(branch_down - 1)
        # Walk upwards likewise.
        branch_up = sympy.nextprime(pivot)
        while branch_up%modulus != target:
            branch_up = sympy.nextprime(branch_up + 1)
    return branch_down, branch_up

In [4]:
def testInteger(x):
    '''
    Test if the given list of number are all integers
    
    Input argument: x (list of numbers)
    
    Output: Boolean, if all elements of x are integer or not
    '''
    test = np.round(x, 6) - np.round(x, 0)
    test = np.abs(test)
    tolerance = np.ones(len(test)) * -1e-6
    if all(test < tolerance) or all(test == 0):
        return True
    else:
        return False

def testPrime(x, name = 'Naive', arg = None):
    '''
    Naive: test if the given list of number are all prime numbers 
    Modulo: test if the given list of number are all prime numbers of the form 12m+k for the first n terms and 6m+k0 else
    
    Input argument: x (list of numbers)
    
    Output: Boolean, if all elements of x are prime or not
    '''
    
    # we have to first test if it is an integer solution
    if testInteger(x):
        test = np.round(x, 6)
        s = 0
        for val in test:
            if not sympy.isprime(int(val)):
                return False
        return True
    else:
        return False

In [5]:
def check_prime_sol(x):
    '''
    Check if the obtained solution satisfies the conditions: every entry of x
    is prime and the average of every pair of entries is prime as well.

    Entries may be ints or integer-valued floats (e.g. solver output).
    An average that is not an integer is rejected instead of being truncated.

    Output: 1 if x is a valid solution, 0 otherwise (the first offending value
            is printed)
    '''
    def _as_integer(value):
        # The nearest int if value is integral up to 1e-6, otherwise None.
        nearest = int(round(value))
        return nearest if abs(value - nearest) <= 1e-6 else None

    n = len(x)
    for i in range(n):
        prime_candidate = _as_integer(x[i])
        if prime_candidate is None or not sympy.isprime(prime_candidate):
            print(x[i])
            return 0
        for j in range(i+1,n):
            mean = (x[i]+x[j])/2
            mean_candidate = _as_integer(mean)
            if mean_candidate is None:
                print(mean)
                return 0
            if not sympy.isprime(mean_candidate):
                print(mean_candidate)
                return 0
    return 1

In [6]:
def fixing_strategy(name = 'Naive', x_bar = [], p =[], l=[]):
    '''
    Decide which positions of the next (one element longer) solution are fixed
    to values taken from the previous solution.

    Input arguments:
        name : 'Naive' (fix nothing), 'SelectAll', 'ExcludeOne', 'ExcludeTwo'
               or 'SelectHalf'; any other name fixes nothing
        x_bar: solution from previous iteration
        p = [p_1,p_2,...] unassigned (free) positions, in increasing order
        l = [l_1, l_2,...] positions of x_bar to be excluded

    Every position in range(len(x_bar)+1) that is not free is fixed; the
    remaining values of x_bar are handed out in order, i.e. position i receives
    the value at index i - (number of free positions below i).
    'SelectAll' reads only p[0] and ignores l, 'ExcludeOne' reads p[0:2],
    'ExcludeTwo' reads p[0:3] and 'SelectHalf' reads all of p.

    Output: (pos, val): list of indices to be fixed and values

    Raises: IndexError if more positions have to be fixed than values remain.
    '''
    pos = []
    val = []
    # name -> (how many leading entries of p are honoured, whether l is applied)
    strategies = {
        'SelectAll': (1, False),
        'ExcludeOne': (2, True),
        'ExcludeTwo': (3, True),
        'SelectHalf': (None, True),
    }
    if len(x_bar) == 0 or name not in strategies:
        return (pos,val)

    n_free, apply_exclusion = strategies[name]
    free = sorted(p[:n_free])
    if apply_exclusion:
        source = [x for i,x in enumerate(x_bar) if i not in l]
    else:
        source = x_bar

    for i in range(len(x_bar)+1):
        if i in free:
            continue
        # Each free position below i shifts the source index back by one.
        shift = sum(1 for q in free if q < i)
        pos.append(i)
        val.append(source[i-shift])
    return (pos,val)

In [7]:
def find_average_prime(n, upper_bound = 10**6, time_limit = 3000, alpha = 0.2, branch_name = 'Naive', remainder = 0, fix_name = 'Naive', prev_sol = [], free_pos = [], exclude_pos =[]):
    '''
    Find n primes such that an average of every two primes is also prime

    The model has variables x[i] (the primes, increasing in i) and y[i,j] with
    2*y[i,j] = x[i] + x[j] for i < j. Its LP relaxation is explored by a
    depth-first branch-and-bound that branches on prime values.

    Input arguments:
        n (int): number of primes to find
        upper_bound: upper bound of every variable
        time_limit: time limit in seconds of the branch-and-bound loop
        alpha: growth factor used by the bound tightening of the last variable
        branch_name: 'Naive' or 'Modulo'
        remainder: residue modulo 12 of the wanted primes; only read when
                   branch_name = 'Modulo', where it has to be 1, 5, 7 or 11
        fix_name, prev_sol, free_pos, exclude_pos: handed to fixing_strategy

    Output:
        node_count (int): number of nodes in the enumeration tree
        global_ub (float): the optimal objective value
        incumbent (list): prime number solution to the problem
        computation_time (float): computation time of the function in seconds

        (0, 0, [], 0) is returned when the problem is infeasible (this includes
        an inadmissible remainder); (1, objval, solution, 0) is returned when
        the root relaxation is already a prime solution.

    Raises: ValueError if branch_name is neither 'Naive' nor 'Modulo'.
    '''
    # Smallest prime of the form 12m + remainder for every admissible remainder
    smallest_prime = {1: 13, 5: 5, 7: 7, 11: 11}

    # x_lb / y_lb: lower bounds of the variables; gap: minimum distance of two
    # consecutive x (their average keeps half of it to either side).
    if branch_name == 'Naive':
        x_lb, y_lb, gap = 3, 5, 4
    elif branch_name == 'Modulo':
        # Checked before anything is built, so an inadmissible remainder is
        # reported as infeasible rather than failing on an undefined bound.
        if remainder not in smallest_prime:
            print("Problem is infeasible!")
            return (0,0,[],0)
        x_lb = y_lb = smallest_prime[remainder]
        gap = 12
    else:
        raise ValueError("unknown branch_name: {0}".format(branch_name))
    half_gap = gap // 2

    initial_bound = upper_bound
    start_time = time.time()

    # Create optimization problem
    #Suppressing Gurobi output
    env = gp.Env(empty=True)
    env.setParam("OutputFlag",0)
    env.start()
    m = gp.Model("Prime Optimization", env = env)

    #Variables
    x = m.addVars(n, lb = x_lb, ub = upper_bound, name = 'x')
    y = m.addVars(n, n, lb = y_lb, ub = upper_bound, name = 'y')
    # Only y[i,j] with i < j is tied to x; pin the other entries to a constant
    # so that they never show up as non-prime values.
    for i in range(n):
        for j in range(0,i+1):
            y[i,j].lb = y_lb
            y[i,j].ub = y_lb

    # Fix variables w.r.t. fixing strategy
    pos, val = fixing_strategy(name = fix_name, x_bar=prev_sol, p=free_pos, l=exclude_pos)
    for i in range(len(pos)):
        x[pos[i]].lb = val[i]
        x[pos[i]].ub = val[i]

    # Increase the bound if the last element is free
    # NOTE(review): initial_bound is a copy of upper_bound, so this condition
    # cannot hold as written and the tightening never runs - confirm the intent.
    if n in free_pos and upper_bound > initial_bound:
        x[n-1].lb = upper_bound/(1+alpha)

    # add constraints
    m.addConstrs(2*y[i,j] == x[i] + x[j] for i in range(n) for j in range(i+1, n))
    m.addConstrs(x[i+1] >= x[i] + gap for i in range(n-1))
    m.addConstrs(y[i,j] >= x[i] + half_gap for i in range(n) for j in range(i+1,n))
    m.addConstrs(y[i,j] <= x[j] - half_gap for i in range(n) for j in range(i+1,n))

    # set objective
    m.setObjective(0, GRB.MINIMIZE)

    # optimize
    m.optimize()

    #Check if feasible
    if m.status != GRB.OPTIMAL:
        print("Problem is infeasible!")
        return (0,0,[],0)
    # if the root node gives a feasible solution
    if testPrime(m.x):
        return (1, m.objval, m.x, 0)

    #Initializing global upper bound and incumbent
    global_ub = 1e12
    incumbent = []

    nodes_model = [m]

    #Node count
    node_count = 1

    #------------------------------------
    # Starting branch-and-bound iteration
    #------------------------------------
    while( len(nodes_model) != 0 and time.time()- start_time <= time_limit):

        #------------------------------------------------------
        #Depth first search: Last node added to the list is chosen
        #-------------------------------------------------------
        m = nodes_model.pop(-1)
        m.update()
        m.optimize()

        # prune the node by infeasibility or by bound
        if m.status != GRB.OPTIMAL or m.objval >= global_ub:
            continue

        # a prime solution improves the global upper bound: prune by optimality
        if testPrime(m.x):
            global_ub = m.objval
            incumbent = m.x
            continue

        # find branching variable
        br_ind = int(PrimeBranching(m.x, name = branch_name, arg = [n,remainder]))
        branch_down, branch_up = Branch_val(m.x, br_ind, name = branch_name, arg = [n,remainder])

        # subproblem 1 branch down
        left_model = m.copy()
        left_model.getVars()[br_ind].ub = branch_down
        left_model.update()

        # subproblem 2 branch up
        right_model = m.copy()
        right_model.getVars()[br_ind].lb = branch_up
        right_model.update()

        # Store the nodes; the left node goes last so that it is explored first
        nodes_model.append(right_model)
        nodes_model.append(left_model)

        node_count += 2

    end_time = time.time()
    computation_time = end_time - start_time

    return (node_count, global_ub, incumbent, computation_time)

In [8]:
def main_algorithm(n = 8, M = 1000, TIME_LIMIT = 3000, solve_time_limit = 300, alpha = 0.2, fixing_name = 'Naive', branching_name = 'Naive', br_arg = 0, fix_arg = {}):
    '''
    Build a solution with n primes incrementally: solve for 1, 2, ..., n primes,
    handing each solution to the next solve as prev_sol.

    Whenever a solve returns no incumbent, the upper bound M is multiplied by
    (1 + alpha) and the same size is tried again. When fixing_name is not
    'Naive', fix_arg[size] = [free positions, excluded positions] is passed on
    to the solve of that size.

    Output: (solution, elapsed seconds); the solution is [] if TIME_LIMIT was
            exceeded before n primes were found
    '''
    began = time.time()
    solution = []
    size = 1
    while size <= n:
        # Extract free_pos (p) and exclude_pos (l) from dict fix_arg
        if fixing_name == 'Naive':
            free_pos = []
            exclude_pos = []
        else:
            free_pos = fix_arg[size][0]
            exclude_pos = fix_arg[size][1]

        _, _, incumbent, _ = find_average_prime(n=size, upper_bound=M, time_limit = solve_time_limit, alpha = alpha, branch_name=branching_name, remainder=br_arg, fix_name=fixing_name, prev_sol=solution, free_pos=free_pos, exclude_pos=exclude_pos)

        if incumbent == []:
            # nothing found under the current bound: enlarge it and retry
            M = M * (1 + alpha)
        else:
            solution = incumbent[:size]
            print('Number of primes: ', size)
            print('Current Solution: ', solution)
            print('Total Time Spent: ', time.time() - began)
            if size == n:
                return (solution,time.time() - began)
            size += 1

        # Check if the overall time limit is exceeded
        if time.time() - began > TIME_LIMIT:
            print('The current upper bound is: ', M)
            return ([], time.time() - began)
    return ([], time.time() - began)

In [9]:
n = 12
M = 1000
TIME_LIMIT = 3000
solve_time_limit = 600
alpha = 0.2

# Create fixing args
# Every table maps the number of primes i of a sub-problem to
# [free positions, positions of the previous solution to exclude].
fix_arg_ExcludeOne = {i: [[i-2, i-1],[i-1]] for i in range(1,n+1)}

fix_arg_ExcludeTwo = {i: [[i-3, i-2, i-1],[i-2, i-1]] for i in range(1,n+1)}

fix_arg_SelectAll = {i: [[i-1],[]] for i in range(1,n+1)}

# Upper half of the positions is free: positions i//2, ..., i-1 are unassigned
# and positions i//2, ..., i-2 of the previous solution are excluded.
# Derived from n so that the table stays valid when n is changed.
fix_arg_SelectHalf = {i: [list(range(i//2, i)), list(range(i//2, i-1))] for i in range(1,n+1)}

# The four lists below are parallel: entry j describes experiment j.
list_fixing_name = ['ExcludeOne', 'ExcludeTwo', 'SelectAll']
list_fixing_arg = [fix_arg_ExcludeOne, fix_arg_ExcludeTwo, fix_arg_SelectAll]
list_branch_name = ['Modulo', 'Modulo', 'Modulo']
list_branch_arg = [5,5,5]

In [10]:
# Run every configured experiment; results and timings are keyed by the
# experiment's position in the configuration lists.
sol, ttime = {}, {}
banner = 'The following result is to solve with branching_strategy = {0}; branching_arg = {1}; fixing_strategy = {2}; fixing_arg = {3}'
for i, _ in enumerate(list_fixing_name):
    chosen_branching = list_branch_name[i]
    chosen_residue = list_branch_arg[i]
    chosen_fixing = list_fixing_name[i]
    chosen_fix_table = list_fixing_arg[i]
    print(banner.format(chosen_branching, chosen_residue, chosen_fixing, chosen_fix_table))
    sol[i], ttime[i] = main_algorithm(n=n, M=M, TIME_LIMIT=TIME_LIMIT, solve_time_limit=solve_time_limit, alpha=alpha, fixing_name=chosen_fixing, branching_name=chosen_branching, br_arg = chosen_residue, fix_arg = chosen_fix_table)
print(sol)
print(ttime)

The following result is to solve with branching_strategy = Modulo; branching_arg = 5; fixing_strategy = ExcludeOne; fixing_arg = {1: [[-1, 0], [0]], 2: [[0, 1], [1]], 3: [[1, 2], [2]], 4: [[2, 3], [3]], 5: [[3, 4], [4]], 6: [[4, 5], [5]], 7: [[5, 6], [6]], 8: [[6, 7], [7]], 9: [[7, 8], [8]], 10: [[8, 9], [9]], 11: [[9, 10], [10]], 12: [[10, 11], [11]]}
Number of primes:  1
Current Solution:  [5.0]
Total Time Spent:  0.007733345031738281
Number of primes:  2
Current Solution:  [5.0, 977.0]
Total Time Spent:  0.008783578872680664
Number of primes:  3
Current Solution:  [5.0, 17.0, 29.0]
Total Time Spent:  0.015506505966186523
Number of primes:  4
Current Solution:  [5.0, 17.0, 29.0, 509.0]
Total Time Spent:  0.04163932800292969
Number of primes:  5
Current Solution:  [5.0, 17.0, 29.0, 89.0, 449.0]
Total Time Spent:  0.08334684371948242
Number of primes:  6
Current Solution:  [5.0, 17.0, 29.0, 89.0, 449.0, 1277.0]
Total Time Spent:  0.5005519390106201
Number of primes:  7
Current Solution

KeyboardInterrupt: 