# **General graph code**

- At face value, this seems to be the same as the tree case.
- Not all of the optimizations made for the tree case have been ported here yet,
- so look at the tree-case code first.

In [None]:
import tensorflow as tf
import itertools
import pickle
from functools import lru_cache
import matplotlib.pyplot as plt



num_samples = 5  # number of equal parts the unit mass is split into when discretizing a simplex (num = 10 ---> 0.1 increments)
q = 5  # number of colors
max_iter = 5  # largest recursion depth r for which S is computed (and pickled)

# NOTE(review): in the visible code this strategy is only used as a scope around the definition of F.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # NOTE(review): F is only *defined* inside this scope; its ops execute
    # wherever it is later called (e.g. from S), which is outside the scope.
    def F(x):
        """Map child color distributions to the normalized parent distribution.

        Args:
            x: float32 tensor (or nested list) of shape (n, d, q): n
                configurations, each holding d distributions over q colors.

        Returns:
            float32 tensor of shape (n, q). Row i is proportional to
            prod_k (1 - x[i, k, j]) for each color j and sums to 1.

        A row becomes NaN only when Z == 0, i.e. every color has some child
        putting probability exactly 1 on it, which requires at least q
        children.
        """
        x = tf.convert_to_tensor(x, dtype=tf.float32)
        # Unnormalized weight of every color at once: product over the d
        # children of (1 - x[:, k, j]). This replaces a per-color Python loop
        # that scatter-updated (and copied) the whole tensor for each color.
        p_x = tf.reduce_prod(1.0 - x, axis=1)          # (n, q)
        Z = tf.reduce_sum(p_x, axis=1, keepdims=True)  # (n, 1)
        return p_x / Z


# add probability of 1 if the child below is fixed a color
def generate_colors(q):
    """Return the first ``q`` color labels: ``['a', 'b', 'c', ...]``.

    Labels are drawn from the lowercase alphabet, so exactly ``q`` labels
    are returned for any 0 <= q <= 26 (a fixed 11-letter table would
    silently return too few colors for q > 11).

    Raises:
        ValueError: if ``q`` is negative or greater than 26.
    """
    if not 0 <= q <= 26:
        raise ValueError(f"q must be between 0 and 26, got {q}")
    return [chr(ord('a') + i) for i in range(q)]

def possible_pairings(q):
    """List every (d, colors) type used by the recursion on q colors.

    The result is ordered as follows (d ranges up to q - 2):
      1. (0, subset) for every subset of at least 2 colors, by increasing
         size and lexicographically within a size;
      2. (d, subset) for d = 1 .. q - 2, with subsets of exactly d + 2 colors.
    Each subset is returned as a fresh list.
    """
    palette = generate_colors(q)

    zero_d = [
        (0, list(group))
        for size in range(2, len(palette) + 1)
        for group in itertools.combinations(palette, size)
    ]
    positive_d = [
        (depth, list(group))
        for depth in range(1, q - 1)
        for group in itertools.combinations(palette, depth + 2)
    ]
    return zero_d + positive_d

def get_sublist_indices(lst, sublist):
    """Map each element of ``sublist`` to its first position in ``lst``.

    Elements that do not occur in ``lst`` map to ``None``.
    """
    # First occurrence wins, mirroring list.index semantics.
    first_seen = {}
    for position, item in enumerate(lst):
        first_seen.setdefault(item, position)
    return [first_seen.get(element) for element in sublist]

def uniform(q, L):
    """Return the uniform distribution over the colors in ``L`` as a (1, q) float32 tensor.

    Each color of ``L`` found in ``generate_colors(q)`` gets mass 1/len(L),
    and every other entry is 0. Colors of ``L`` outside that palette are
    skipped, in which case the row sums to less than 1. An empty ``L`` raises
    ZeroDivisionError.
    """
    palette = generate_colors(q)
    hits = get_sublist_indices(palette, L)
    weight = 1.0 / len(L)

    row = [0.0] * q
    for idx in hits:
        if idx is not None:
            row[idx] = weight
    return tf.constant([row], dtype=tf.float32)

def generate_combos(delta, q):
    """Return every ordered assignment of a type from ``possible_pairings(q)``
    to each of ``delta`` children, as a list of ``delta``-tuples."""
    child_types = possible_pairings(q)
    assignments = itertools.product(child_types, repeat=delta)
    return list(assignments)

@lru_cache(maxsize=None)
def _lattice_compositions(units, parts):
    """Return all ordered ways to split ``units`` into ``parts`` non-negative ints.

    Tuples are listed lexicographically (first entry ascending), which is the
    order in which discretize_simplex enumerates its points.
    """
    if parts == 1:
        return ((units,),)
    return tuple(
        (head,) + tail
        for head in range(units + 1)
        for tail in _lattice_compositions(units - head, parts - 1)
    )


def discretize_simplex(q, subset, n, mass=1, dim=1):
    """Discretize the simplex of total ``mass`` supported on the colors in ``subset``.

    Every coordinate except the last is a multiple of 1/n; the last takes the
    remaining mass, so each point sums to ``mass``. Points are enumerated on
    an integer lattice (units of 1/n) rather than by stepping a float range,
    which can drift and add or drop grid values or leave slightly negative
    remainders.

    Args:
        q: total number of colors (size of the ambient palette).
        subset: colors from ``generate_colors(q)`` that may carry mass.
        n: number of equal parts the unit mass is split into.
        mass: total mass of each point (default 1).
        dim: 1-based index of the first coordinate to generate; each point
            has ``len(subset) - dim + 1`` coordinates.

    Returns:
        If ``dim == 1``: a float32 tensor of shape (num_points, q) with each
        point written into the columns of ``subset`` and zeros elsewhere.
        Otherwise: a list of lists holding the raw coordinates.

    Raises:
        ValueError: if ``dim > len(subset)`` (no coordinates to generate).
    """
    d = len(subset)
    full_list = generate_colors(q)
    parts = d - dim + 1
    if parts < 1:
        raise ValueError(
            f"dim={dim} leaves no coordinates for a subset of size {d}")

    # Number of whole 1/n steps that fit in `mass`; the epsilon absorbs float
    # error when mass * n is meant to be an integer (e.g. 1 * 5).
    units = int((mass * n + 1e-9) // 1)

    points = []
    for combo in _lattice_compositions(units, parts):
        prefix = combo[:-1]
        # The last coordinate absorbs whatever mass is left.
        points.append([k / n for k in prefix] + [mass - sum(prefix) / n])

    if dim != 1:
        return points

    # Embed the points in q-dimensional space, zero outside `subset`.
    columns = [full_list.index(item) for item in subset]
    rows = []
    for point in points:
        row = [0.0] * q
        for col, value in zip(columns, point):
            row[col] = value
        rows.append(row)
    # Build the tensor once (no per-entry scatter copies); the reshape keeps
    # the (0, q) shape when there are no points.
    return tf.reshape(tf.constant(rows, dtype=tf.float32), [-1, q])

def find_nearest_vectors(vectors, probs):
    """Snap each row of ``probs`` to its closest row of ``vectors`` (Euclidean).

    Args:
        vectors: candidate points, shape (num_vectors, q).
        probs: points to project, shape (num_probs, q).

    Returns:
        numpy array of shape (num_probs, q) holding, for each row of
        ``probs``, the row of ``vectors`` nearest to it.
    """
    # (num_vectors, 1, q) - (num_probs, q) broadcasts to (num_vectors, num_probs, q).
    gaps = tf.expand_dims(vectors, axis=1) - probs
    dist_table = tf.norm(gaps, axis=2)           # (num_vectors, num_probs)
    winner_rows = tf.argmin(dist_table, axis=0)  # best candidate index per prob
    snapped = tf.gather(vectors, winner_rows)
    return snapped.numpy()


def _load_level(q, r, types):
    """Unpickle the stored S set of each (d, colors) type in ``types`` at depth ``r``.

    Files are named ``S_{q}_{d}_{colors}_{r}_{num_samples}.pkl``, matching the
    driver loop. The list is returned in the order of ``types``.
    """
    level = []
    for dlt, lst in types:
        col = "".join(lst)
        name = f"S_{q}_{dlt}_{col}_{r}_{num_samples}.pkl"
        # NOTE: pickle.load can execute arbitrary code; only load files this
        # notebook produced itself.
        with open(name, 'rb') as file:
            level.append(pickle.load(file))
    return level


def S(q, delta, L, r):
    """Compute the discretized set S(delta, L, r) of reachable color distributions.

    Args:
        q: number of colors.
        delta: number of non-trivial children; configurations are padded up
            to q - 2 children with all-zero rows.
        L: list of colors the vertex may take.
        r: recursion depth; depth r reads the pickled sets of depth r - 1.

    Returns:
        A float32 tensor of shape (m, q):
        * delta == 0: the single uniform distribution on L;
        * r == 1: every point of the discretized simplex on L;
        * otherwise: the distinct nearest-grid-point projections of F over all
          child configurations built from the depth r - 1 sets.
    """
    if delta == 0:
        return uniform(q, L)

    # Only built when it is actually used (delta > 0).
    discretized = discretize_simplex(q, L, num_samples)
    if r == 1:
        return discretized

    diff = (q - 2) - delta
    zros = [[0 for _ in range(q)]]  # a single all-zero "child" used as padding

    types = possible_pairings(q)
    # Load each depth-(r-1) set once; the same tensors serve both for the
    # union below and as the first-child sets in the loop.
    previous = _load_level(q, r - 1, types)
    filtered, _ = tf.raw_ops.UniqueV2(x=tf.concat(previous, axis=0), axis=[0])

    # NOTE(review): F spreads mass over all q colors, not only those in L;
    # colors outside L are removed only by the nearest-point projection onto
    # the simplex over L. Confirm this is intended.
    projected = [tf.zeros([0, q], dtype=tf.float32)]
    for first_child in previous:
        # First child from this type's set, the other delta - 1 children from
        # the union of all sets, plus `diff` zero rows so every configuration
        # has q - 2 children.
        sets = [first_child] + [filtered] * (delta - 1) + [zros] * diff
        # shape (n, d, q) once converted
        distributions = [list(dis) for dis in itertools.product(*sets)]
        distributions_array = tf.convert_to_tensor(distributions, dtype=tf.float32)

        probs = F(distributions_array)  # float32, shape (n, q)
        projected.append(find_nearest_vectors(discretized, probs))

    final_probs, _ = tf.raw_ops.UniqueV2(x=tf.concat(projected, axis=0), axis=[0])
    return final_probs

# Computing S for different r, accessing stored files.
# Depth r reads the pickles written for depth r - 1 (see S), so depths must be
# computed in increasing order, and every type at depth r - 1 must be done first.
for r in range(1, max_iter + 1):
    for tup in possible_pairings(q):  # evaluate S for all value combinations
        delta, L = tup
        colors = "".join(L)
        # This name must match the one S builds when loading the previous depth.
        name = f"S_{q}_{delta}_{colors}_{r}_{num_samples}.pkl"

        lst = S(q, delta, L, r)  # a TF tensor of shape (m, q)

        with open(name, 'wb') as file:
            pickle.dump(lst, file)

        print(name + " is done")

# Code for visualizing the first two coordinates
import matplotlib.pyplot as plt  # repeated so this cell can also run on its own

# Use the same number of colors as the driver loop so the files read here are
# the S_{q}_* pickles it actually wrote. The type (2, abcd) needs q >= 4.
q_ = q
delt = 2
lst = ['a','b','c','d']

cols = "".join(lst)

# comparing to the original discretization at r = 1
name_orig = f"S_{q_}_{delt}_{cols}_{1}_{num_samples}.pkl"
with open(name_orig, 'rb') as file:
    S_orig = pickle.load(file)

# ... and to the set obtained after max_iter levels
name = f"S_{q_}_{delt}_{cols}_{max_iter}_{num_samples}.pkl"
with open(name, 'rb') as file:
    S_r = pickle.load(file)


# Extract the first two coordinates (p1, p2) of every point.
# The loop variable is `p` so it does not shadow the color count `q`.
p1_orig = [p[0] for p in S_orig]
p2_orig = [p[1] for p in S_orig]

p1_values = [p[0] for p in S_r]
p2_values = [p[1] for p in S_r]

# Create the plot
plt.figure(figsize=(10,10))
plt.scatter(p1_orig, p2_orig, c='green', label='p1 (r=1), p2 (r=1)', s=12)
plt.scatter(p1_values, p2_values, c='red', label=f'p1 (r={max_iter}), p2 (r={max_iter})', s=15)

plt.xlabel('p1')
plt.ylabel('p2')
plt.title(f'Probability Vectors [p1, p2] for S({delt}, {lst}, {max_iter}), initially using {num_samples} points to discretize')
plt.legend()
plt.grid(True)

# Set axis limits to show the whole simplex
plt.xlim(-0.1, 1.1)
plt.ylim(-.1, 1.1)

# Show plot
plt.show()


S_5_0_ab_1_5.pkl is done
S_5_0_ac_1_5.pkl is done
S_5_0_ad_1_5.pkl is done
S_5_0_ae_1_5.pkl is done
S_5_0_bc_1_5.pkl is done
S_5_0_bd_1_5.pkl is done
S_5_0_be_1_5.pkl is done
S_5_0_cd_1_5.pkl is done
S_5_0_ce_1_5.pkl is done
S_5_0_de_1_5.pkl is done
S_5_0_abc_1_5.pkl is done
S_5_0_abd_1_5.pkl is done
S_5_0_abe_1_5.pkl is done
S_5_0_acd_1_5.pkl is done
S_5_0_ace_1_5.pkl is done
S_5_0_ade_1_5.pkl is done
S_5_0_bcd_1_5.pkl is done
S_5_0_bce_1_5.pkl is done
S_5_0_bde_1_5.pkl is done
S_5_0_cde_1_5.pkl is done
S_5_0_abcd_1_5.pkl is done
S_5_0_abce_1_5.pkl is done
S_5_0_abde_1_5.pkl is done
S_5_0_acde_1_5.pkl is done
S_5_0_bcde_1_5.pkl is done
S_5_0_abcde_1_5.pkl is done
S_5_1_abc_1_5.pkl is done
S_5_1_abd_1_5.pkl is done
S_5_1_abe_1_5.pkl is done
S_5_1_acd_1_5.pkl is done
S_5_1_ace_1_5.pkl is done
S_5_1_ade_1_5.pkl is done
S_5_1_bcd_1_5.pkl is done
S_5_1_bce_1_5.pkl is done
S_5_1_bde_1_5.pkl is done
S_5_1_cde_1_5.pkl is done
S_5_2_abcd_1_5.pkl is done
S_5_2_abce_1_5.pkl is done
S_5_2_abde_1_

In [None]:
# Jacobian calculations



In [None]:
# Load the pickled S set of every type at depth r = 3, print each set's shape,
# then deduplicate the union of all their rows (the same union S builds).
pairs = possible_pairings(q)
temps = []

for dlt, lst in pairs:
  col = "".join(lst)
  name = f"S_{q}_{dlt}_{col}_{3}_{num_samples}.pkl"

  with open(name, 'rb') as file:
    temp = pickle.load(file)
  temps.append(temp)

  print(temp.shape)

total = tf.concat(temps, axis=0)
# UniqueV2 along axis 0 keeps each distinct row once; `inds` maps each row of
# `total` to its row in `filtered`.
filtered, inds = tf.raw_ops.UniqueV2(x=total, axis=[0])

filtered.shape  # last expression: shown as the cell's output

(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(1, 4)
(6, 4)
(6, 4)
(6, 4)
(6, 4)
(56, 4)


TensorShape([91, 4])