# In [None]:
# from multiprocessing import Pool
from itertools import product
from sympy import symbols
import numpy as np
import time 
from pysat.solvers import Solver
import random
from pysat.solvers import Glucose3

###  FUNCTIONS ---------------------------------------------

def combinatorial_line(word, alphabet, var=None):
    """
    Expand a variable word into the points of its combinatorial line.

    Every position of ``word`` equal to the wildcard ``var`` is replaced by
    the same letter ``a``, once for each letter of ``alphabet``; all other
    positions are copied unchanged.

    Args:
        word: sequence mixing fixed letters and the wildcard value.
        alphabet: iterable of letters substituted for the wildcard.
        var: the wildcard value. When omitted, the module-level ``x`` (the
            sympy symbol created in the ``__main__`` block) is used, which is
            what this function always relied on before.

    Returns:
        list of lists: one point (a list of letters) per letter of
        ``alphabet``, in alphabet order.
    """
    if var is None:
        # NOTE: ``x`` only exists once the __main__ block has run; pass
        # ``var`` explicitly when calling this from imported code.
        var = x
    return [[a if val == var else val for val in word] for a in alphabet]

def create_var_map(tuples):
    """
    Number a sequence of hashable items (typically cell tuples) as SAT variables.

    Input: iterable of tuples (each tuple can have arbitrary length).
    Output: dict mapping each tuple -> unique integer id, assigned 1, 2, 3, ...
    in iteration order (ids start at 1 because SAT literals cannot be 0).
    A repeated tuple keeps the id of its last occurrence.
    """
    var_ids = {}
    for var_id, key in enumerate(tuples, start=1):
        var_ids[key] = var_id
    return var_ids

def tuple_to_litmap(cl, map, lit_map):
    """
    Translate the cells of one combinatorial line into solver literals.

    Two lookups per cell: ``map`` (cell tuple -> global variable id), then
    ``lit_map`` (global id -> solver-local literal). All global ids are
    resolved before any solver literal is looked up.

    The parameter name ``map`` shadows the builtin; it is kept so existing
    keyword callers keep working.
    """
    global_ids = [map[tuple(cell)] for cell in cl]
    solver_lits = [lit_map[gid] for gid in global_ids]
    return solver_lits


def split_into_chunks(lst, num_chunks):
    """
    Cut ``lst`` into ``num_chunks`` consecutive slices whose lengths differ by
    at most one, with the longer slices at the end.

    The widths (len + i) // num_chunks for i = 0..num_chunks-1 always sum to
    len(lst) (Hermite's identity), so every element lands in exactly one
    chunk and the original order is preserved.
    """
    total = len(lst)
    chunks = []
    offset = 0
    for i in range(num_chunks):
        width = (total + i) // num_chunks
        chunks.append(lst[offset:offset + width])
        offset += width
    return chunks

def compatible(set1, set2):
    """
    Return True when no literal of ``set1`` appears negated in ``set2``,
    i.e. the two partial assignments never disagree on a variable.
    """
    for lit in set1:
        if -lit in set2:
            return False
    return True

def multi_compatible(list_of_lists):  # for the quick lower bound method
    """
    Return True when every pair of partial assignments in ``list_of_lists``
    is compatible: no literal of one assignment appears negated in another.
    Contradictions inside a single assignment are not checked.
    """
    as_sets = list(map(set, list_of_lists))
    for idx, first in enumerate(as_sets):
        for second in as_sets[idx + 1:]:
            if any(-lit in second for lit in first):
                return False
    return True

def split_into_random_chunks(lst, num_chunks):
    """
    Shuffle a copy of ``lst`` with the global ``random`` state, then cut it
    into ``num_chunks`` consecutive pieces whose sizes differ by at most one
    (larger pieces last). ``lst`` itself is left untouched.
    """
    pool = lst[:]  # copy so the caller's list keeps its order
    random.shuffle(pool)

    size = len(pool)
    pieces = []
    cursor = 0
    for k in range(num_chunks):
        step = (size + k) // num_chunks
        pieces.append(pool[cursor:cursor + step])
        cursor += step
    return pieces
# HPC FUNCTION
def process_item(item):
    """
    Placeholder worker intended for a future ``pool.map`` run.

    Given an indexable with at least three elements, return
    (item[0] squared, item[1] doubled, item[2] minus itself).
    Your processing logic goes here.
    """
    squared = item[0] * item[0]
    doubled = item[1] + item[1]
    cancelled = item[2] - item[2]
    return squared, doubled, cancelled

def combinatorial_lines(n, m):
    """
    Generate every combinatorial line of the cube [n]^m.

    A variable word is a length-``m`` word over 1..n in which a non-empty
    set of positions holds the wildcard ``x``. ``x`` is read as a
    module-level name (the sympy symbol created in the ``__main__`` block).
    Each variable word is expanded into one line of ``n`` points by
    ``combinatorial_line``, giving (n + 1)**m - n**m lines in total.

    Args:
        n: alphabet size (letters are 1..n).
        m: word length / cube dimension.

    Returns:
        list of lines; each line is a list of ``n`` points, each point a list
        of ``m`` letters.
    """
    print(f"Generating combinatorial lines and boolean literals for {m} = HJ({n},2)")
    alphabet = list(range(1, n + 1))

    # One config per non-empty wildcard mask. Wildcard positions hold ``x``;
    # free positions hold the placeholder 0, which never collides with a
    # letter because letters start at 1.
    configs = [
        [x if bit else 0 for bit in mask]
        for mask in product([0, 1], repeat=m)
        if any(mask)  # skip the all-blank mask
    ]

    # Fill the free positions with every combination of letters.
    variable_words = []
    for config in configs:
        zero_positions = [i for i, val in enumerate(config) if val == 0]
        for fill in product(alphabet, repeat=len(zero_positions)):
            word = list(config)
            for pos, val in zip(zero_positions, fill):
                word[pos] = val
            variable_words.append(word)

    return [combinatorial_line(vw, alphabet) for vw in variable_words]

def lines_to_literals(lines, n=None, m=None):
    """
    Convert combinatorial lines into the global SAT variable ids of their cells.

    Cells of [n]^m are numbered 1..n**m in ``itertools.product`` order through
    ``create_var_map``.

    Args:
        lines: iterable of combinatorial lines; each line is a sequence of
            cells and each cell a sequence of ``m`` letters from 1..n.
        n: alphabet size. If omitted it is inferred as the largest letter
            found in ``lines``. This assumes the lines come from
            ``combinatorial_lines``, where every line runs its wildcard
            through the whole alphabet, so letter ``n`` always occurs.
        m: word length. If omitted it is taken from the first cell.

    Returns:
        Flat list of the variable ids of every cell of every line, in line
        order. Ids shared between lines appear once per occurrence. An empty
        input yields ``[]``.
    """
    cells = [tuple(cell) for line in lines for cell in line]
    if not cells:
        return []
    if m is None:
        m = len(cells[0])
    if n is None:
        n = max(max(cell) for cell in cells)

    alphabet = range(1, n + 1)
    var_map = create_var_map(product(alphabet, repeat=m))  # cell -> global id
    # Collect ids for *all* lines (previously only the last line survived).
    return [var_map[cell] for cell in cells]


def solve_chunk(chunk, n=None, m=None):
    """
    Enumerate all 2-colourings of the chunk's cells with no monochromatic line.

    Every cell is a Boolean variable (True/False are the two colours). Each
    combinatorial line contributes two clauses: one forbids the line being
    all False, the other forbids it being all True. The clause set is
    unchanged by flipping every variable. So if the cell with global id 1
    occurs in the chunk, it is forced True to drop colour-swapped duplicates.

    Args:
        chunk: list of combinatorial lines (each a list of cells, each cell a
            sequence of ``m`` letters from 1..n).
        n: alphabet size used for the global cell numbering. Inferred as the
            largest letter in ``chunk`` when omitted.
        m: word length. Inferred from the first cell when omitted.

    Returns:
        solutions (list of sets): each set holds signed global cell ids
            (+id = True, -id = False) for the cells in this chunk.
        solution_count (int): ``len(solutions)``.
        lit_map (dict): global cell id -> solver-local variable (1..k).
    """
    cells = [tuple(cell) for line in chunk for cell in line]
    if cells:
        if m is None:
            m = len(cells[0])
        if n is None:
            n = max(max(cell) for cell in cells)
        var_map = create_var_map(product(range(1, n + 1), repeat=m))  # cell -> global id
    else:
        var_map = {}

    # Renumber the global ids touched by this chunk densely as 1..k, sorted
    # so the numbering is deterministic.
    literals_in_chunk = {var_map[cell] for cell in cells}
    lit_map = {orig: i + 1 for i, orig in enumerate(sorted(literals_in_chunk))}
    rev_map = {v: k for k, v in lit_map.items()}

    solutions = []
    with Solver(name='glucose4') as solver:  # releases the native solver on exit
        for cl in chunk:
            literals = tuple_to_litmap(cl, var_map, lit_map)
            solver.add_clause(literals)                    # line not all False
            solver.add_clause([-lit for lit in literals])  # line not all True

        if 1 in lit_map:
            # Symmetry blocking on global cell 1, translated to its solver var.
            solver.add_clause([lit_map[1]])

        for model in solver.enum_models():
            solutions.append({
                rev_map[abs(lit)] if lit > 0 else -rev_map[abs(lit)]
                for lit in model if abs(lit) in rev_map
            })

    solution_count = len(solutions)
    print(f"Found {solution_count} solutions in this chunk.")
    return solutions, solution_count, lit_map

def generate_chunks(combinatorial_lines, num_chunks=None):  # not random
    """
    Split the lines into contiguous, near-equal chunks and print each size.

    Args:
        combinatorial_lines: list of combinatorial lines. Inside this body the
            parameter shadows the module-level function of the same name.
        num_chunks: number of chunks. Defaults to the module-level ``ch_num``
            set in the ``__main__`` block, which is what was always used.

    Returns:
        list of chunks (each a list of lines), preserving the original order.
    """
    if num_chunks is None:
        num_chunks = ch_num
    chunk_list = list(split_into_chunks(combinatorial_lines, num_chunks))
    for i, chunk in enumerate(chunk_list):
        print(f"Chunk {i} has {len(chunk)} lines")
    return chunk_list

def chunk_solutions(chunk_list):  # gets smaller solutions
    """
    Solve every chunk independently and print the total wall-clock time.

    Args:
        chunk_list: list of chunks, each a list of combinatorial lines.

    Returns:
        list with one entry per chunk: the ``solutions`` list that
        ``solve_chunk`` returned for that chunk.
    """
    start_time = time.perf_counter()  # monotonic clock meant for intervals
    sols_per_chunk = [solve_chunk(chunk)[0] for chunk in chunk_list]
    elapsed_time = time.perf_counter() - start_time
    print(f"Elapsed time: {elapsed_time:.5f} seconds")
    return sols_per_chunk
###  SCRIPT ---------------------------------------------





###  EXECUTION ---------------------------------------------

if __name__ == '__main__':
    # Problem parameters: 2-colourings of [n]^m, i.e. we're looking at HJ(n,2).
    n = 4       # n is the length of the alphabet
    m = 2       # m is the potential Hales-Jewett number
    c = 2       # number of colours; not used, just a reminder that c = 2
    ch_num = 4  # how many chunks the lines are split into

    # Wildcard for variable words. It must be a module-level name because
    # combinatorial_line / combinatorial_lines refer to the global ``x``.
    x = symbols('x')

    lines = combinatorial_lines(n, m)
    chunk_list = generate_chunks(lines)
    sols = chunk_solutions(chunk_list)

    # Parallel skeleton for later (the multiprocessing import at the top of
    # the file is commented out as well):
    # pool = Pool()
    # items = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
    # results = pool.map(process_item, items)  # execute the loop in parallel
    # pool.close()
    # pool.join()
    # print(results)