In [2]:
import numpy as np
import itertools

def evaluate(n: int) -> int:
  """Scores the construction: how many vectors `solve(n)` puts in the cap set.

  Args:
    n: Dimension of the cap set problem.

  Returns:
    The number of rows of the array returned by `solve(n)`.
  """
  return len(solve(n))


def solve(n: int) -> np.ndarray:
  """Returns a large cap set in `n` dimensions, built greedily.

  Every vector in {0, 1, 2}^n is scored once with `priority`. Vectors are then
  taken in decreasing priority order. Whenever a vector is taken, every vector
  that would complete a line (three points summing to 0 mod 3) with it and an
  already-chosen vector is invalidated.

  Args:
    n: Dimension of the space.

  Returns:
    An int32 array of shape [C, n] holding the chosen vectors in the order in
    which they were added.
  """
  all_vectors = np.array(list(itertools.product((0, 1, 2), repeat=n)), dtype=np.int32)

  # Powers in decreasing order for compatibility with `itertools.product`, so
  # that the relationship `i = all_vectors[i] @ powers` holds for all `i`.
  powers = 3 ** np.arange(n - 1, -1, -1)

  # Precompute all priorities. Force a float dtype so that `-inf` can be stored
  # below even if `priority` returns integers (an int array would raise
  # OverflowError on `-np.inf` assignment).
  priorities = np.array(
      [priority(tuple(vector), n) for vector in all_vectors], dtype=np.float64)

  # Preallocate room for every vector so that adding one is O(1), instead of
  # re-copying the whole cap set with `np.concatenate` on each step.
  capset = np.empty(shape=all_vectors.shape, dtype=np.int32)
  size = 0
  while True:
    # Take a vector with maximum priority. If even the maximum is `-inf`, every
    # vector is already chosen or invalidated, so we are done.
    max_index = np.argmax(priorities)
    if priorities[max_index] == -np.inf:
      break
    vector = all_vectors[max_index]  # [n]
    # For each chosen c, the third point completing a line with c and `vector`
    # is (-c - vector) mod 3; map those points to indices and invalidate them.
    blocking = ((-capset[:size] - vector) % 3) @ powers  # [size]
    priorities[blocking] = -np.inf
    priorities[max_index] = -np.inf
    capset[size] = vector
    size += 1

  # Copy so the result does not keep the full-size buffer alive.
  return capset[:size].copy()



def priority(el: tuple[int, ...], n: int) -> float:
  """Returns the priority with which we want to add `el` to the cap set.

  Improved version of `priority_v1`.

  Args:
    el: A vector (tuple) of length `n` whose entries are 0, 1 or 2.
    n: The dimension of the problem; used to scale the bonus for 1s.

  Returns:
    `len(el)` divided by a penalty, plus a bonus `counts[1] / n` for having
    more 1s. The penalty is the sum of three terms:
      * the squared differences between consecutive counts of 0s, 1s and 2s,
      * the squared difference between the counts of 2s and 0s,
      * the number of repeated entries, `len(el) - len(set(el))`.
    Returns `inf` when the penalty is zero (e.g. `el` is a permutation of
    (0, 1, 2)). This is the value the division would produce, but without a
    divide-by-zero RuntimeWarning.
  """
  # Number of 0s, 1s and 2s in the vector `el`.
  counts = np.bincount(el, minlength=3)

  # Penalty terms; smaller means a more balanced vector.
  imbalance = np.sum(np.square(np.diff(counts)))
  twos_vs_zeros = np.square(counts[2] - counts[0])
  repeats = len(el) - len(set(el))
  denominator = imbalance + twos_vs_zeros + repeats

  if denominator == 0:
    return np.inf
  return len(el) / denominator + counts[1] / n
# # Alternative priority function for the Salem-Spencer set problem (unused here).
# def priority(k: int, n: int) -> float:
#   """Returns the priority with which we want to add `k` to the Salem-Spencer set.
#   n is the number of candidate integers, and k is the integer whose priority we determine.
#   """
#   """Improved version of `priority_v0`."""
#   freq = sum(1 for i in range(1, n + 1) if i % k == 0)  # Number of multiples of k in [1, n]
#   mid = n // 2 + 1  # Middle of the range [1, n]
#   return freq / (mid - abs(k - mid))  # Prioritize integers near the middle that have more multiples

In [4]:
# Build the greedy cap set in 10 dimensions and report how many vectors it holds.
evaluate_capset = solve(10)
print("len capset:", evaluate_capset.shape[0])


len capset: 1537
