In [3]:
from itertools import product
from sympy import symbols
import numpy as np
import time 
from pysat.solvers import Solver
import random
from pysat.solvers import Glucose3

In [22]:
# FUNCTIONS
def combinatorial_line(word, alphabet):
    lines = []
    for a in alphabet:
        new_word = list(a if val == x else val for val in word)  # tuple instead of list
        lines.append(new_word)
    return lines

def create_var_map(tuples):
    """
    Input: list of tuples (each tuple can have arbitrary length)
    Output: dictionary mapping each tuple -> unique integer literal
    """
    return {t: i+1 for i, t in enumerate(tuples)}

def tuple_to_litmap(cl, map, lit_map):
    """
    Map a tuple -> var_map ID -> lit_map value
    """
    literals = [map[tuple(c)] for c in cl]
    literals = [lit_map[l] for l in literals]
    return literals


def split_into_chunks(lst, num_chunks):
    n = len(lst)
    chunk_sizes = [(n + i) // num_chunks for i in range(num_chunks)]
    chunks = []
    start = 0
    for size in chunk_sizes:
        chunks.append(lst[start:start + size])
        start += size
    return chunks

def compatible(set1, set2):
    # check if any element in set1 has its negation in set2
    return not any(-x in set2 for x in set1)

def multi_compatible(list_of_lists):
    sets = [set(lst) for lst in list_of_lists]  # convert to sets
    for i in range(len(sets)):
        for j in range(i + 1, len(sets)):
            if any(-x in sets[j] for x in sets[i]):
                return False  # incompatible
    union_set = set().union(*sets)
    return union_set

def split_into_random_chunks(lst, num_chunks):
    # Shuffle the list first
    shuffled = lst[:]
    random.shuffle(shuffled)
    
    n = len(shuffled)
    chunk_sizes = [(n + i) // num_chunks for i in range(num_chunks)]
    
    chunks = []
    start = 0
    for size in chunk_sizes:
        chunks.append(shuffled[start:start + size])
        start += size
    return chunks

In [26]:
# SETTING VARIABLES 
x = symbols('x')
m = 2 # m is the potential Hales-Jewett number 
n = 3 # n is the length of the alphabet
c = 2 # don't actually use this, just a reminder that we set c = 2
# so, we're looking at HJ(n,2)
ch_num = 4

In [27]:
print(f"Generating combinatorial lines and boolean literals for {m} = HJ({n},2)")

alphabet = list(range(1, n+1))
configs = []
variable_words = []
cells = list(product(alphabet, repeat=m))

for mask in product([0, 1], repeat=m):
    if sum(mask) == 0:  # skip all-blank
        continue
    config = [x if bit else 0 for bit in mask]
    configs.append(config)

for c in configs:
    zero_positions = [i for i, val in enumerate(c) if val == 0]
    k = len(zero_positions)
        
    for fill in product(alphabet, repeat=k):
        word = list(c)  
        for pos, val in zip(zero_positions, fill):
            word[pos] = val
        variable_words.append(word)  

combinatorial_lines = [] # create the combinatorial lines
for vw in variable_words:
    combinatorial_lines.append(combinatorial_line(vw, alphabet))

Generating combinatorial lines and boolean literals for 2 = HJ(3,2)


In [28]:
# MAPPING FROM LINES TO BOOLEANS
map = create_var_map(cells) # define the map from line to boolean

for cl in combinatorial_lines: # for each combinatorial line, we are going to generate 2 clauses
    literals = [map[tuple(c)] for c in cl]

In [29]:
chunk_list = []

combinatorial_chunks = split_into_chunks(combinatorial_lines, ch_num)
#combinatorial_chunks = split_into_random_chunks(combinatorial_lines, ch_num)

for i, chunk in enumerate(combinatorial_chunks):
    chunk_list.append(chunk)          # store in a list for iteration
    print(f"Chunk {i} has {len(chunk)} lines")

Chunk 0 has 1 lines
Chunk 1 has 2 lines
Chunk 2 has 2 lines
Chunk 3 has 2 lines


In [30]:


def solve_chunk(chunk):
    """
    Solve a chunk using a SAT solver with optional symmetry blocking on x1.
    
    Args:
        chunk (list of lists/tuples): The chunk of constraints to solve.
        map (dict): Mapping from tuple elements to integer literals.
        tuple_to_litmap (func): Function to convert tuples to literal mapping.
        symmetry_block (bool, optional): Whether to symmetry block x1 if present.
        
    Returns:
        solutions (list of sets): List of solutions in terms of original tuples.
        solution_count (int): Number of solutions found.
        lit_map (dict): Mapping from original literals to solver literals.
    """
    solver = Solver(name='glucose4')
    literals_in_chunk = set()
    x1_in_lits = False

    # Collect literals
    for cl in chunk:
        literals = [map[tuple(c)] for c in cl]
        literals_in_chunk.update(literals)

    # Generate literal maps
    lit_map = {orig: i+1 for i, orig in enumerate(literals_in_chunk)}
    rev_map = {v: k for k, v in lit_map.items()}

    if 1 in literals_in_chunk:
        x1_in_lits = True
        #print("Performing symmetry blocking on x_1")

    #print("Literals before mapping:", literals_in_chunk)
    #print("Literals after mapping:", lit_map)

    # Add clauses
    for cl in chunk:
        literals = tuple_to_litmap(cl, map, lit_map)
        solver.add_clause(literals)             # all-blue check
        solver.add_clause([-lit for lit in literals])  # all-red check

    if x1_in_lits:
        solver.add_clause([1])

    # Enumerate solutions
    solutions = []
    for model in solver.enum_models():
        mapped_model = set(
            rev_map[abs(lit)] if lit > 0 else -rev_map[abs(lit)]
            for lit in model if abs(lit) in rev_map
        )
        solutions.append(mapped_model)

    solution_count = len(solutions)
    print(f"Found {solution_count} solutions in this chunk.")
    return solutions, solution_count, lit_map



# Trying a quick lower bound method here

In [31]:

def one_solve(chunk, blocks):
    """
    Solve a chunk using a SAT solver with optional symmetry blocking on x1.
    
    Args:
        chunk (list of lists/tuples): The chunk of constraints to solve.
        map (dict): Mapping from tuple elements to integer literals.
        tuple_to_litmap (func): Function to convert tuples to literal mapping.
        symmetry_block (bool, optional): Whether to symmetry block x1 if present.
        
    Returns:
        solutions (list of sets): List of solutions in terms of original tuples.
        solution_count (int): Number of solutions found.
        lit_map (dict): Mapping from original literals to solver literals.
    """
    solver = Solver(name='glucose4')
    literals_in_chunk = set()
    x1_in_lits = False

    # Collect literals
    for cl in chunk:
        literals = [map[tuple(c)] for c in cl]
        literals_in_chunk.update(literals)

    # Generate literal maps
    lit_map = {orig: i+1 for i, orig in enumerate(literals_in_chunk)}
    rev_map = {v: k for k, v in lit_map.items()}

    if 1 in literals_in_chunk:
        x1_in_lits = True
        #print("Performing symmetry blocking on x_1")

    #print("Literals before mapping:", literals_in_chunk)
    #print("Literals after mapping:", lit_map)

    # Add clauses
    for cl in chunk:
        literals = tuple_to_litmap(cl, map, lit_map)
        solver.add_clause(literals)             # all-blue check
        solver.add_clause([-lit for lit in literals])  # all-red check

    if x1_in_lits:
        solver.add_clause([1])

    for b in blocks:
        solver.add_clause(b)

    if solver.solve():
        model = solver.get_model()
        mapped_model = set(
        rev_map[abs(lit)] if lit > 0 else -rev_map[abs(lit)]
        for lit in model if abs(lit) in rev_map
        )
        blocks.append([-lit for lit in model])
    else:
        return -1

    return model, blocks

In [32]:
start_time = time.time()

ct = 0
blocks = [[] for _ in range(len(chunk_list))]

while solution == False and ct < 10000:
    sols = []
    if ct % 1000 == 0:
        print(f"Iteration {ct}")
    for i,chunk in enumerate(chunk_list):
        solutions, blocks[i] = one_solve(chunk, blocks[i])
        if solutions != -1:
            sols.append(solutions)    
    solution = multi_compatible(sols)
    ct += 1

end_time = time.time()
elapsed_time = end_time - start_time 
print(f"Elapsed time: {elapsed_time:.5f} seconds")  
print(f"Found solution: {solution}")

NameError: name 'solution' is not defined

# End lower bound method

In [33]:
start_time = time.time()

sols_per_chunk = [0 * i for i in range(len(chunk_list))]

for i,chunk in enumerate(chunk_list):
    solutions, count, lit_map = solve_chunk(chunk)
    sols_per_chunk[i] = solutions

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.5f} seconds")


Found 3 solutions in this chunk.
Found 36 solutions in this chunk.
Found 18 solutions in this chunk.
Found 9 solutions in this chunk.
Elapsed time: 0.01462 seconds


In [10]:

print("helo")

start_time = time.time()
ct = 0
combined_solutions = set(tuple(sorted(sol, key=abs)) for sol in sols_per_chunk[0])
for chunk_solutions in sols_per_chunk[1:]:
    next_combined = set()
    for sol_a in combined_solutions:
        for sol_b in chunk_solutions:
            if compatible(sol_a, sol_b):
                combined = tuple(sorted(set(sol_a) | set(sol_b), key=abs))
                next_combined.add(combined)
    combined_solutions = next_combined

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.5f} seconds")
print(f"There are {len(combined_solutions)} valid solutions using all subsets of solutions")

helo


KeyboardInterrupt: 

### Parallelization of intersection method

In [15]:
start_time = time.time()

combined_solutions = False
print(type(sols_per_chunk[1]))

for sol_combo in product(*sols_per_chunk):
    if multi_compatible(sol_combo):
        combined_solutions = True
        break


end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.5f} seconds")
print(f"Satisfiable: {combined_solutions}")

<class 'list'>
Elapsed time: 0.76148 seconds
Satisfiable: True


In [35]:
def solution_finder(*sols):
    for sol_combo in product(*sols):
        sol = multi_compatible(sol_combo) 
        if sol != False:
            return sol
    return False

def solution_counter(*sols):
    ct = 0
    for sol_combo in product(*sols):
        sol = multi_compatible(sol_combo) 
        if sol != False:
            ct += 1
    return ct


In [36]:
print(solution_finder(*sols_per_chunk))
print(solution_counter(*sols_per_chunk))
print(sols_per_chunk)

{1, 6, 8, 9, -7, -5, -4, -3, -2}
33


In [65]:
# creates smaller sublists which we can parallelize

def solution_sublists_counter(*sols, split):
    sol_ct = 0
    big_sublist = [None] * len(sols)
    for i, sol_list in enumerate(sols): # split all of the solution lists into smaller solution lists
        big_sublist[i] = split_into_chunks(sol_list, split)
    for combo in product(*big_sublist):
        sol_ct += solution_counter(*combo)
    return sol_ct


def solution_sublists_generator(*sols, split):
    combos = []
    big_sublist = [None] * len(sols)
    for i, sol_list in enumerate(sols): # split all of the solution lists into smaller solution lists
        big_sublist[i] = split_into_chunks(sol_list, split)
    for combo in product(*big_sublist):
        if any(not x for x in combo):
            continue
        combos.append(list(combo))
    return combos


In [57]:
print(solution_sublists_counter(*sols_per_chunk, split=4))
# YESSSSSS it works

33


In [75]:
combos = solution_sublists_generator(*sols_per_chunk, split=3)
ct = 0

for combo in combos:
    ct += solution_counter(*combo)
print(ct)

33


Now just need to use python multiprocessing to run each combo in parallel... could also send some, and if they all return nothing, send more, to guarantee that the code will break as soon as possible