# **Recursion implementation. Hard coded the min-max bounds**
- A step of 0.033 (n = 30) is a feasible discretization for q = 4; q = 5 should be doable if we let it run long enough.

In [2]:
import tensorflow as tf
import itertools
import pickle
from functools import lru_cache
import matplotlib.pyplot as plt


# Number of parts each simplex axis is split into; passed as `n` to
# discretize_simplex(), so the grid step is 1 / num_samples (e.g. 10 -> 0.1 increments).
num_samples = 20
# Number of colors; also fixes the (q - 2) children fed into each F input inside S().
q = 4
# Iteration depth whose stored results the plotting and Jacobian cells load.
# NOTE(review): the driver loop below uses its own hard-coded range of r values,
# not max_iter -- confirm which depth is intended.
max_iter = 5

# Distribution strategy whose scope wraps the definition of F below.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  def F(x):
      """Map child color distributions to the parent's color distribution.

      For every batch entry the unnormalised weight of color ``j`` is the
      product over the children of ``1 - x[:, child, j]``; the weights are
      then normalised to sum to one over the colors.

      Args:
          x: array-like of shape (n, d, q) -- n lists of d distributions over
             q colors. All-zero rows act as neutral padding (factor 1).

      Returns:
          Tensor of shape (n, q) with the normalised weights.

      Note:
          If the weights of an entry sum to zero (every color has some child
          with probability 1 on it) the division yields NaN; this is left
          unchanged so callers see the same values as before.
      """
      x = tf.convert_to_tensor(x)

      # One vectorised reduction over the children axis replaces the former
      # per-color loop that rebuilt the (n, q) tensor with a scatter update
      # on every iteration.
      weights = tf.reduce_prod(1 - x, axis=1)

      # Z is the normalising constant; keepdims lets it broadcast over colors.
      Z = tf.reduce_sum(weights, axis=1, keepdims=True)

      return weights / Z


# add probability of 1 if the child below is fixed a color
def generate_colors(q):
    """Return the first ``q`` color labels as a list: ``['a', 'b', ...]``.

    Labels are single lowercase letters, so at most 26 colors are supported.

    Args:
        q: number of colors, ``0 <= q <= 26``.

    Returns:
        List of ``q`` one-character strings in alphabetical order.

    Raises:
        ValueError: if ``q`` is negative or exceeds the number of available
            labels. The previous fixed 11-letter list silently returned
            fewer than ``q`` labels in that case.
    """
    import string

    alphabet = string.ascii_lowercase
    if not 0 <= q <= len(alphabet):
        raise ValueError(
            f"q must be between 0 and {len(alphabet)}, got {q!r}"
        )
    return list(alphabet[:q])

def possible_pairings(q):
    """List every (d, color_subset) type used by the recursion.

    The result contains, in this order:
      * ``(0, subset)`` for every subset of the q colors with at least two
        elements, ordered by subset size;
      * ``(d, subset)`` for ``d = 1 .. q-2`` and every subset of size ``d + 2``.

    Each subset is a list of color labels in the order given by
    generate_colors(q).
    """
    palette = generate_colors(q)

    # d = 0: any subset with two or more colors.
    childless = [
        (0, list(chosen))
        for size in range(2, len(palette) + 1)
        for chosen in itertools.combinations(palette, size)
    ]

    # d >= 1: the subset size is tied to d (exactly d + 2 colors).
    with_children = [
        (d, list(chosen))
        for d in range(1, q - 1)
        for chosen in itertools.combinations(palette, d + 2)
    ]

    return childless + with_children

def get_sublist_indices(lst, sublist):
    """Return, for each element of ``sublist``, its first index in ``lst``.

    Args:
        lst: reference list (elements must be hashable).
        sublist: elements to look up.

    Returns:
        List with one entry per element of ``sublist``: the index of its
        first occurrence in ``lst``, or ``None`` if it does not occur.
    """
    # Build the lookup once; setdefault keeps the FIRST position of a
    # repeated value, matching list.index, without an O(len(lst)) scan
    # per queried element.
    first_position = {}
    for position, value in enumerate(lst):
        first_position.setdefault(value, position)

    return [first_position.get(element) for element in sublist]

def uniform(q, L):
    """Return the uniform distribution over the colors in ``L`` as a (1, q) tensor.

    Every color of ``L`` that is one of the q colors gets probability
    ``1 / len(L)``; all other entries are zero. The result is float32.
    """
    palette = generate_colors(q)
    positions = get_sublist_indices(palette, L)
    share = 1.0 / len(L)

    # Fill a plain Python row first, then convert to a tensor once.
    weights = [0.0] * q
    for position in positions:
        if position is None:
            # Color of L that is not among the q colors: contributes nothing.
            continue
        weights[position] = share

    return tf.reshape(tf.constant(weights, dtype=tf.float32), (1, q))

def generate_combos(delta, q):
    """Return every assignment of a type from possible_pairings(q) to ``delta`` children.

    The result is a list of ``delta``-tuples, i.e. the ``delta``-fold
    Cartesian power of the pairing list.
    """
    child_types = possible_pairings(q)
    return [*itertools.product(child_types, repeat=delta)]

def discretize_simplex(q, subset, n, mass=1, dim=1):
    """Enumerate the grid points of the simplex supported on ``subset``.

    The total ``mass`` is split into steps of ``1 / n`` and distributed over
    the colors of ``subset`` in every possible way. Points are produced in
    lexicographic order of their coordinates.

    The enumeration works on integer step counts and divides by ``n`` only at
    the end. Building the grid from float ranges and repeated float
    subtraction let rounding drift add or drop grid points and leave tiny
    non-zero residues in coordinates that should be exactly 0.

    Args:
        q: number of colors; fixes the width of the embedded points.
        subset: list of color labels (a subset of generate_colors(q)).
        n: number of parts the unit mass is split into (grid step 1 / n).
        mass: total mass to distribute; rounded to the nearest multiple of
            ``1 / n``.
        dim: 1-based coordinate at which the enumeration starts. The default
            of 1 enumerates all ``len(subset)`` coordinates.

    Returns:
        If ``dim == 1``: float32 tensor of shape (num_points, q) whose columns
        for colors outside ``subset`` are zero.
        Otherwise: list of coordinate lists (Python floats) for coordinates
        ``dim .. len(subset)``, not embedded.

    Raises:
        ValueError: if ``n < 1`` or ``dim`` exceeds ``len(subset)``.
    """
    d = len(subset)
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    if dim > d:
        raise ValueError(
            f"dim ({dim}) cannot exceed the number of colors in subset ({d})"
        )

    full_list = generate_colors(q)
    total_units = int(round(mass * n))

    @lru_cache(maxsize=None)
    def generate_points(units, level):
        # The last coordinate takes whatever is left over.
        if level == d:
            return ((units,),)
        return tuple(
            (taken,) + rest
            for taken in range(units + 1)
            for rest in generate_points(units - taken, level + 1)
        )

    # Convert step counts to coordinates only once, so every grid value is the
    # correctly rounded k / n and zeros are exact.
    points = [
        [count / n for count in combo]
        for combo in generate_points(total_units, dim)
    ]

    if dim != 1:
        return points

    # Embed the points in q-dimensional space with zeros for colors outside
    # subset. Rows are assembled in Python and converted once, instead of one
    # tensor scatter update (and full tensor copy) per entry.
    columns = [full_list.index(item) for item in subset]
    rows = []
    for point in points:
        row = [0.0] * q
        for column, value in zip(columns, point):
            row[column] = value
        rows.append(row)

    return tf.reshape(tf.convert_to_tensor(rows, dtype=tf.float32), [-1, q])


def box_bound(tensor, min, max):
  """Keep the rows of ``tensor`` whose entries are all either 0 or within [min, max].

  An entry that is exactly 0 always passes, so a row may contain zeros
  alongside in-range values. Rows are returned in their original order.
  """
  # Entry-wise tests: exactly zero, or inside the closed interval.
  is_zero = tf.equal(tensor, 0)
  above_floor = tf.greater_equal(tensor, min)
  below_ceiling = tf.less_equal(tensor, max)
  entry_ok = tf.logical_or(is_zero, tf.logical_and(above_floor, below_ceiling))

  # A row survives only if every one of its entries passes.
  row_ok = tf.reduce_all(entry_ok, axis=1)
  return tf.boolean_mask(tensor, row_ok)


def find_nearest_vectors(vectors, probs):
    """Project each row of ``probs`` onto its nearest row of ``vectors``.

    Distance is the Euclidean norm of the difference. The result has one row
    per row of ``probs`` and is returned as a NumPy array.
    Assumes both arguments are 2-D with the same number of columns -- TODO confirm.
    """
    # Pair every candidate with every probability vector via broadcasting.
    differences = tf.expand_dims(vectors, axis=1) - probs
    pairwise_distance = tf.norm(differences, axis=2)

    # For each probability vector, the position of the closest candidate.
    closest = tf.argmin(pairwise_distance, axis=0)

    return tf.gather(vectors, closest).numpy()


def find_indices(full, final):
    """Return, for each row of ``final``, the index of its first exact match in ``full``.

    Rows are compared entry by entry with exact equality. A row of ``final``
    with no match in ``full`` yields index 0, because argmax over an all-zero
    column returns 0.
    """
    # (n, 1, q) against (1, m, q) broadcasts to every (full row, final row) pair.
    entrywise_equal = tf.equal(
        tf.expand_dims(full, axis=1),
        tf.expand_dims(final, axis=0),
    )

    # A pair matches only if all q entries agree -> shape (n, m).
    row_matches = tf.reduce_all(entrywise_equal, axis=-1)

    # argmax over the 0/1 column returns the first position holding a 1.
    return tf.argmax(tf.cast(row_matches, tf.int32), axis=0)


def unique_elements(matrix):
    """Return the distinct values of ``matrix`` as a 1-D tensor.

    The input is flattened first and passed to tf.unique; only the tensor of
    unique values is returned, the index tensor is discarded.
    """
    return tf.unique(tf.reshape(matrix, [-1])).y



def S(q, delta, L, r):
    """Compute the set of reachable distributions for type ``(delta, L)`` at depth ``r``.

    Base cases:
      * ``delta == 0``: the uniform distribution over ``L``.
      * ``r == 1``: the discretised simplex over ``L`` restricted to the
        hard-coded box [0.15, 0.45].

    Otherwise, for every assignment of types to the ``delta`` children, the
    children's depth ``r - 1`` sets are read from
    ``S_{q}_{d}_{colors}_{r-1}_{num_samples}.pkl``, every combination of
    child distributions is pushed through ``F``, and the outputs are
    projected onto the box-bounded grid. Duplicate projections are removed.

    Uses the module-level ``num_samples`` for the grid resolution and for the
    file names.

    Args:
        q: number of colors.
        delta: number of children (expected to be at most ``q - 2``).
        L: list of color labels the distribution is supported on.
        r: recursion depth.

    Returns:
        Tuple ``(final_probs, F_inputs)``:
          * ``final_probs``: (m, q) tensor of distinct distributions.
          * ``F_inputs``: (m, q-2, q) tensor with, for each row of
            ``final_probs``, the first input found that projects onto it.
            Empty (0 rows) in both base cases.
    """
    # Placeholder for the base cases, which have no recorded F inputs.
    # q-2 is the number of distributions fed up (can be zeros sometimes).
    no_inputs = tf.zeros((0, q-2, q))

    if delta == 0:
        return uniform(q, L), no_inputs

    # Only build the grid when it is needed: the delta == 0 case above
    # returns without using it.
    discretized = discretize_simplex(q, L, num_samples)
    subset = box_bound(discretized, 0.15, 0.45)  # hard coded the min-max based on analytics

    if r == 1:
        return subset, no_inputs

    combos = generate_combos(delta, q)  # possible type combinations over the delta children
    diff = (q-2) - delta
    zros = [[0 for i in range(q)]]  # all-zero padding child (neutral factor in F)

    # Each stored set is needed by many combos; read every file only once.
    previous = {}

    # Collect per-combo results and concatenate once at the end. Growing a
    # tensor with tf.concat inside the loop recopies everything accumulated
    # so far on every iteration. The leading empty tensors fix the shape and
    # dtype of the result even if no combo contributes.
    projection_chunks = [tf.zeros([0, q], dtype=tf.float32)]
    input_chunks = [no_inputs]

    for combo in combos:
        sets = []  # one list of candidate distributions per child

        for tup in combo:
            col = "".join(tup[1])
            name = f"S_{q}_{tup[0]}_{col}_{r-1}_{num_samples}.pkl"

            if name not in previous:
                # NOTE: pickle must only be used on trusted files; these are
                # expected to be the ones written by this notebook.
                with open(name, 'rb') as file:
                    previous[name] = pickle.load(file)

            sets.append(previous[name])

        # Pad up to q-2 children so every F input has the same shape.
        sets.extend([zros] * diff)

        # Every combination of one distribution per child -> shape (n, q-2, q).
        distributions = [list(dis) for dis in itertools.product(*sets)]
        distributions_array = tf.convert_to_tensor(distributions, dtype=tf.float32)

        probs = F(distributions_array)
        projection_chunks.append(find_nearest_vectors(subset, probs))
        input_chunks.append(distributions_array)

    final_probabilities = tf.concat(projection_chunks, axis=0)
    F_inpts = tf.concat(input_chunks, axis=0)  # axis must be a scalar, not [0]

    final_probs, idx = tf.raw_ops.UniqueV2(x=final_probabilities, axis=[0])

    # Index of the first occurrence of each distinct distribution in the
    # redundant set, used to pick one matching F input per distribution.
    indices = find_indices(final_probabilities, final_probs)
    unique_indices = unique_elements(indices)

    F_inpts_final = tf.gather(F_inpts, unique_indices)

    return final_probs, F_inpts_final


# Compute S for each depth r and every type, storing results on disk.
# S() reads the depth r-1 files, so those must already exist.
for r in range(6,11):
    for delta, L in possible_pairings(q):  # every (delta, color subset) type
        colors = "".join(L)
        name = f"S_{q}_{delta}_{colors}_{r}_{num_samples}.pkl"
        input_name = f"S_{q}_{delta}_{colors}_{r}_{num_samples}_inputs.pkl"

        lst, inpts = S(q, delta, L, r)

        # Distributions go to `name`, the F inputs producing them to `input_name`.
        for path, payload in ((name, lst), (input_name, inpts)):
            with open(path, 'wb') as file:
                pickle.dump(payload, file)

        print(name + " is done")

# Code for visualizing the first two coordinates
import matplotlib.pyplot as plt

# Type whose stored sets are plotted.
q_, delt = 4, 2
lst = ['a','b','c','d']
cols = "".join(lst)

# Reference set: the original discretization at r = 1.
name_orig = f"S_{q_}_{delt}_{cols}_{1}_{num_samples}.pkl"
with open(name_orig, 'rb') as file:
    S_orig = pickle.load(file)

# Set after max_iter iterations.
name = f"S_{q_}_{delt}_{cols}_{max_iter}_{num_samples}.pkl"
with open(name, 'rb') as file:
    S_r = pickle.load(file)

# First two coordinates of every stored vector.
p1_orig = [row[0] for row in S_orig]
p2_orig = [row[1] for row in S_orig]
p1_values = [row[0] for row in S_r]
p2_values = [row[1] for row in S_r]

# Draw both sets on one square figure.
plt.figure(figsize=(10,10))
series = (
    (p1_orig, p2_orig, 'green', 'p1 (r=1), p2 (r=1)'),
    (p1_values, p2_values, 'red', f'p1 (r={max_iter}), p2 (r={max_iter})'),
)
for xs, ys, colour, caption in series:
    plt.scatter(xs, ys, c=colour, label=caption, s=9)

plt.title(f'Probability Vectors [p1, p2] for S({delt}, {lst}, {max_iter}), initially using {num_samples} points to discretize')
plt.xlabel('p1')
plt.ylabel('p2')
plt.legend()
plt.grid(True)

# Axis limits wide enough to show the whole simplex.
plt.xlim(-0.1, 1.1)
plt.ylim(-.1, 1.1)

plt.show()


# F(F_inputs) yields an array that should project onto the output final_prob


KeyboardInterrupt: 

# **Jacobian calculations**
- L1 norm ~ 0.75
- ** norm ~ 0.78
- very very crude implementation but functionally fine

In [1]:
import numpy as np

# Load the stored set and the F inputs that produced it.
# Requires q_, delt, cols, max_iter, num_samples, pickle and F from the cells above.
name = f"S_{q_}_{delt}_{cols}_{max_iter}_{num_samples}.pkl"
with open(name, 'rb') as file:
    S_r = pickle.load(file)


input_name = f"S_{q_}_{delt}_{cols}_{max_iter}_{num_samples}_inputs.pkl"
with open(input_name, 'rb') as file:
    F_inpts = pickle.load(file)


# I wrote the jacobian code to be compatible with lists before
# so i convert the probabilities to list type for simplicity
pr_s = F(F_inpts).numpy().tolist()
# note we could also use the projection S_r directly instead of
# feeding the F_inputs back into F and using those for pr_s


max1 = 0
maxstar = 0

# One F input is a (children, colors) block. Take that shape from the data
# instead of hard-coding (2, 4), which is only right for q = 4.
inputs_np = np.asarray(F_inpts)
block_shape = inputs_np.shape[1:]

# Random probe directions used to estimate the ** norm from below.
random_vectors = [
    np.random.normal(0, 1, block_shape).tolist() for _ in range(1000)
]

# Probe quantities do not depend on k, so compute them once:
#   probe_flat  -- each probe flattened child by child, matching the column
#                  order of the concatenated Jacobian;
#   probe_scale -- largest Euclidean norm among the probe's per-child rows.
probes = np.array(random_vectors)
probe_flat = probes.reshape(len(random_vectors), -1)
probe_scale = np.linalg.norm(probes, axis=2).max(axis=1)

for k, probabilities in enumerate(pr_s):
    distribution = np.array(probabilities)

    # Derivative of F w.r.t. child i's distribution p_i:
    #   (p p^T - diag(p)) @ diag(1 - p_i)^+
    # pinv maps entries with 1 - p_i == 0 to 0 instead of dividing by zero.
    base = np.outer(distribution, distribution) - np.diag(distribution)
    jacs = [base @ np.linalg.pinv(np.diag(1 - p_i)) for p_i in inputs_np[k]]

    # Full Jacobian: per-child blocks side by side, shape (q, children * q).
    jacobian = np.concatenate(jacs, axis=1)
    print("shap: ", jacobian.shape)

    # Induced 1-norm (largest absolute column sum).
    max1 = max(max1, np.linalg.norm(jacobian, 1))

    # ||J a||_2 / max_i ||a_i||_2 for all probes at once.
    ratios = np.linalg.norm(probe_flat @ jacobian.T, axis=1) / probe_scale
    maxstar = max(maxstar, ratios.max())


print("L1 norm: ", max1)
print("** norm: ", maxstar)



NameError: name 'q_' is not defined

# **Bugs/next things to implement**

- The only bug I can see is that, under the new box bounds, we are occasionally left with some points on the axes. I think that is a result of how I've defined the min-max condition in box_bound(), which always lets entries equal to 0 pass.
- The Jacobian norm calculation is also done very crudely. I'll need to double-check it at some point.
- Other optimizations include exploiting permutation symmetry to speed things up (e.g. once you've calculated the result for abc, you know it for abd, acd and bcd).
- I also convert once between list and tensor type in the middle of S to make it easier to construct the possible combinations.
- Next steps: general graphs, and faster computation for higher-order trees. A closed-form solution for the box bounds will help.
