### Definitions (run first)

In [None]:
import sympy as s
import numpy as np
import itertools
from collections import defaultdict

###################################################################################

# BASIC UTILITIES

def g_helper(i,j,k,t):
  """Return the k-th term of the series summed by ``g_tavare``.

  Term (evaluated symbolically with SymPy; ``t`` may be a SymPy expression):
      exp(-C(k, 2) * t) * (2k - 1) * (-1)**(k - j) / (j! * (k - j)! * (j + k - 1))
        * prod_{m=0}^{k-1} (j + m) * (i - m) / (i + m)
  The division by ``j + k - 1`` makes ``j = 0, k = 1`` a zero denominator.
  """
  m = s.symbols("m", integer=True)  # dummy index of the symbolic product
  return s.exp(-s.binomial(k,2)*t) * (2*k-1)*(-1)**(k-j)/(s.factorial(j)*s.factorial(k-j)*(j+k-1)) * s.product((j+m)*(i-m)/(i+m),(m,0,k-1))
def g_tavare(i,j,t):
  """Probability that ``i`` coalescent lineages are reduced to exactly ``j`` after time ``t``.

  This is Tavaré's series: the sum of ``g_helper(i, j, k, t)`` for ``k = j..i``.

  Returns 0 without evaluating the series when ``j > i`` (lineages cannot
  multiply) or ``j < 1``. At least one lineage always remains, and evaluating
  the series at ``j = 0`` would divide by zero inside ``g_helper``. The
  original ``i < 1`` case is covered by these two conditions.
  """
  if j < 1 or j > i:
    return 0
  return sum(g_helper(i, j, k, t) for k in range(j, i+1))

def decay(dict, dt):
  """Push a lineage-count distribution up a branch of length ``dt``.

  ``dict`` maps a number of lineages ``i`` to its weight. Each weight is spread
  over ``j = 1..i`` lineages in proportion to ``g_tavare(i, j, dt)``.

  An entry with ``i = 0`` contributes nothing, so ``{0: 1}`` decays to an empty
  mapping. ``mult`` treats an empty mapping as the identity.

  Returns a ``defaultdict(int)``.
  """
  decayed = defaultdict(int)
  for start_count, weight in dict.items():
    for end_count in range(1, start_count + 1):
      decayed[end_count] += weight * g_tavare(start_count, end_count, dt)
  return decayed

def mult(dict1, dict2):
  """Convolve two count distributions {count: weight}: counts add, weights multiply.

  An empty mapping acts as the identity: the other argument is returned
  unchanged (the same object). Otherwise a new ``defaultdict(int)`` is returned.
  """
  if not dict1:
    return dict2
  if not dict2:
    return dict1
  convolved = defaultdict(int)
  for (count1, weight1), (count2, weight2) in itertools.product(dict1.items(), dict2.items()):
    convolved[count1 + count2] += weight1 * weight2
  return convolved

def es(func):
  """Decorate ``func`` so its result is simplified and then expanded with SymPy.

  ``functools.wraps`` preserves the wrapped function's name and docstring and
  exposes it as ``__wrapped__``.
  """
  from functools import wraps

  @wraps(func)
  def inner_func(*args, **kwargs):
    return expand_and_simplify(func(*args, **kwargs))
  return inner_func

def expand_and_simplify(expr):
    """Return ``expr`` run through ``sympy.simplify`` and then ``sympy.expand``."""
    simplified = s.simplify(expr)
    return s.expand(simplified)

def height(lineage_probs):
  """Return ``2 - sum(p_m * 2/m)`` over a {number of lineages m: probability p_m} mapping.

  When the probabilities sum to 1 this equals ``sum(p_m * (2 - 2/m))``. That is
  presumably the expected time for the lineages to reach their common ancestor
  under Kingman's coalescent (``2 - 2/m`` for ``m`` lineages).

  A key ``m = 0`` yields SymPy's ``zoo``.
  """
  weighted_inverse = 0
  for m, prob in lineage_probs.items():
    weighted_inverse += prob * s.Rational(2, m)
  return 2 - weighted_inverse

def to_dict(func):
  """Decorate ``func`` so that its result (a mapping or iterable of pairs) is returned as a plain ``dict``.

  This is useful to turn a ``defaultdict`` result into a regular dict.
  ``functools.wraps`` preserves the wrapped function's metadata.
  """
  from functools import wraps

  @wraps(func)
  def inner_func(*args, **kwargs):
    return dict(func(*args, **kwargs))
  return inner_func

def subdict(dict1, keys):
  """Return a new dict with only ``keys`` copied from ``dict1`` (a missing key raises KeyError)."""
  selected = {}
  for wanted in keys:
    selected[wanted] = dict1[wanted]
  return selected

def intersect(dict1, dict2):
  """Restrict both mappings to the keys they have in common; return the two restricted dicts."""
  shared_keys = set(dict1) & set(dict2)
  restricted1 = subdict(dict1, shared_keys)
  restricted2 = subdict(dict2, shared_keys)
  return restricted1, restricted2

###################################################################################

# NODE CLASS

class Node:
  """Node of a rooted phylogenetic tree (species tree or candidate topology).

  Attributes:
    label: set of tip labels subtended by this node.
    parent: parent Node, or None for the root.
    children: list of child Nodes (empty for a tip).
    time: node time. Tips default to 0.0; ``assign_branch_lengths`` sets an
      internal node's time to its parent's time minus its branch length.
    length: length of the branch above this node.

  Trees are normally built bottom-up with ``+`` (see ``__add__``).
  """

  def __init__(self, label=None, parent=None, children=None, time=0.0, length=0.0):
    self.label = label
    self.parent = parent
    self.children = children or []
    self.time = time
    self.length = length

    # Adopt the children. This overwrites any parent a child already had, so a
    # node reused in several trees keeps only the parent assigned last.
    for child in self.children:
      child.parent = self

  def __iter__(self):
    """Yield this node, then each child's subtree (children in reverse order), pre-order."""
    yield self
    for child_node in reversed(self.children):
      yield from child_node

  def __add__(self, other):
    """Return a new node whose children are ``self`` and ``other`` (re-parenting both)."""
    return Node(
      label = self.label | other.label,
      children = [self, other]
    )

  def __str__(self):
    """Newick-like string: tip labels concatenated; internal nodes wrap their children in parentheses."""
    if self.is_tip():
      return f"{''.join(map(str,self.label))}"
    return f"({''.join(str(child_node) for child_node in self.children)})"

  def str_label(self) -> str:
    """Return the tip labels as strings, sorted lexicographically and concatenated."""
    return ''.join(sorted(map(str,self.label)))

  def parent_iter(self, include_root=True):
    """Yield this node and then its ancestors, bottom-up.

    If ``include_root`` is False, stop before the root. The node itself is
    always yielded.
    """
    yield self
    if isinstance(self.parent, Node) and (include_root or not self.parent.is_root()):
      yield from self.parent.parent_iter(include_root=include_root)

  def where(self, function):
    """Yield the nodes of this subtree (in ``__iter__`` order) for which ``function`` is true."""
    for node in self:
      if function(node):
        yield node

  def where_not(self, function):
    """Yield the nodes of this subtree (in ``__iter__`` order) for which ``function`` is false."""
    for node in self:
      if not function(node):
        yield node

  def is_tip(self):
    return len(self.children) == 0

  def is_root(self):
    # identity test is the idiomatic None check (== would dispatch to __eq__)
    return self.parent is None

  def get_root(self):
    """Follow parent links up to the node without a parent."""
    node = self
    while not node.is_root():
      node = node.parent
    return node

  def internal(self):
    """Yield the nodes of this subtree that are neither tips nor the root."""
    return self.where_not(lambda node: node.is_tip() or node.is_root())

  def tips_subtended(self):
    return len(self.label)

  # for binary trees only
  def tripartition(self):
    """Return [left tips, right tips, all other tips] sorted by smallest label.

    The "other tips" are taken relative to the root found via ``get_root``.
    Returns None for tips and for the root. Only the first two children are used.
    """
    if self.is_tip() or self.is_root():
      return None
    left_label, right_label = self.children[0].label, self.children[1].label
    root = self.get_root()
    other_label = root.label - (left_label | right_label)
    return sorted([
      left_label,
      right_label,
      other_label
    ], key = min)

  def tripartitions(self):
    """Return the tripartition of every internal non-root node, in iteration order."""
    return [
      node.tripartition() for node in self.internal()
    ]

  def assign_symbolic_times(self, *times) -> None:
    """Set the times of internal non-root nodes, pairing them in pre-order with ``times`` reversed."""
    internal_nodes = self.internal()
    for node, time in zip(internal_nodes, reversed(times)):
      node.time = time

  def assign_branch_lengths(self, *lengths) -> None:
    """Set the root time to ``sum(lengths)`` and assign branch lengths to internal non-root nodes.

    The nodes are visited in pre-order and receive ``lengths`` in reverse
    order. Each node's time becomes its parent's time minus its length.
    """
    self.time = sum(lengths)
    internal_nodes = self.internal()
    for node, length in zip(internal_nodes, reversed(lengths)):
      node.length = length
      node.time = node.parent.time - length

  def lineage_probs(self, tips=None):
    """Distribution {number of lineages: probability} of ``tips`` lineages present at this node.

    ``tips`` defaults to every tip below this node. A tip contributes
    ``{1: 1}`` if it is in ``tips`` and ``{0: 1}`` otherwise. For an internal
    node, each child's distribution is passed through ``decay`` over the branch
    ``self.time - child.time``, and the results are combined with ``mult``.
    """
    if tips is None: tips = self.label
    tips = set(tips)
    if self.is_tip() and self.label <= tips: return {1: 1}
    elif self.is_tip(): return {0: 1}
    result = {0: 1}
    for child_node in self.children:
      result = mult(
        result,
        decay(
          child_node.lineage_probs(tips),
          self.time - child_node.time
        )
    )
    return result

  def mrca(self, tips):
    """Return the deepest node whose label contains all of ``tips``.

    Returns ``self`` when no node qualifies (e.g. some tips are absent).
    """
    minimum, tips = self, set(tips)
    # pre-order visits the qualifying nodes (one root-to-MRCA path) top-down
    for node in self:
      if tips <= node.label:
        minimum = node
    return minimum

  # for binary trees only
  def parsimony_cost(self, derived_tips, ancestral_state=0):
    """Minimum number of state changes of a binary character in this subtree.

    Tips in ``derived_tips`` have state 1 and the other tips state 0; this node
    is fixed to ``ancestral_state``. A tip whose state disagrees costs
    ``np.inf``. Each child either keeps the parent's state for free or switches
    at a cost of 1. Only the first two children are considered.
    """
    derived_tips = self.label & set(derived_tips)
    if self.is_tip():
      return np.inf if ancestral_state != len(derived_tips) else 0
    left, right = self.children[0], self.children[1]
    left_cost = min(
      left.parsimony_cost(derived_tips, ancestral_state=ancestral_state),
      1 + left.parsimony_cost(derived_tips, ancestral_state=1-ancestral_state),
    )
    right_cost = min(
      right.parsimony_cost(derived_tips, ancestral_state=ancestral_state),
      1 + right.parsimony_cost(derived_tips, ancestral_state=1-ancestral_state),
    )

    return left_cost + right_cost

  def rooted_parsimony_cost(self, derived_tips):
    """Parsimony cost with ancestral state 0 at the root; a derived root state costs one extra change."""
    return min(
      self.parsimony_cost(derived_tips, ancestral_state=0),
      1 + self.parsimony_cost(derived_tips, ancestral_state=1)
    )

  def unrooted_parsimony_cost(self, derived_tips):
    """Parsimony cost with either root state allowed for free."""
    return min(
      self.parsimony_cost(derived_tips, ancestral_state=0),
      self.parsimony_cost(derived_tips, ancestral_state=1)
    )
        
###################################################################################

# BRANCH LENGTH UTILITIES

def isps(iterable):
  """Lazily yield the informative site patterns: every subset of size 2..n-1 as a tuple.

  Subsets are emitted by increasing size, each size in ``itertools.combinations``
  order of the input.
  """
  items = list(iterable)
  sizes = range(2, len(items))
  return itertools.chain.from_iterable(
    itertools.combinations(items, size) for size in sizes
  )

def isps_unrooted(iterable):
  """Like ``isps`` but only subsets of size 2..n-2.

  Patterns of size n-1 have a single-tip complement and are therefore dropped
  in the unrooted setting.
  """
  items = list(iterable)
  sizes = range(2, len(items) - 1)
  return itertools.chain.from_iterable(
    itertools.combinations(items, size) for size in sizes
  )

def subtending_length(tree, tips):
  """Expected gene-tree length between the coalescence of ``tips`` and the gene-tree root.

  The value is ``(tree.time + height at the root) - (mrca.time + height of the
  tips' lineages at their species MRCA)``. It is presumably
  E[gene-tree root time] - E[time at which ``tips`` coalesce].
  """
  tips_mrca = tree.mrca(tips)
  species_gap = tree.time - tips_mrca.time
  root_height = height(tree.lineage_probs())
  tips_height = height(tips_mrca.lineage_probs(tips=tips))
  return species_gap + root_height - tips_height
  
def branch_length(tree, tips):
  """Expected length of branches subtending exactly ``tips``.

  The value is obtained by inclusion-exclusion over the subtending lengths of
  every informative site pattern that contains ``tips``.
  """
  target = set(tips)
  total = 0
  for candidate in isps(tree.label):
    if not target <= set(candidate):
      continue
    sign = (s.S.NegativeOne)**(len(target) - len(candidate))
    total = total + sign * subtending_length(tree, set(candidate))
  return total

def subtending_lengths(tree):
  """Map every informative site pattern (as a sorted tuple) to its ``subtending_length``."""
  return {
    tuple(sorted(tips)): subtending_length(tree, tips)
    for tips in isps(tree.label)
  }
    
def branch_lengths(tree, subtending_lengths_dict=None):
  """Map every informative site pattern (sorted tuple) to its expected branch length.

  Each length is obtained by inclusion-exclusion over the subtending lengths of
  all patterns containing it, then simplified and expanded with SymPy.

  ``subtending_lengths_dict`` (keys: sorted tuples) is computed with
  ``subtending_lengths(tree)`` when not supplied.
  """
  if subtending_lengths_dict is None:
    subtending_lengths_dict = subtending_lengths(tree)

  # Materialise the patterns once, as sorted keys plus sets, instead of
  # regenerating and re-converting them inside the quadratic loop.
  patterns = [tuple(sorted(isp)) for isp in isps(tree.label)]
  pattern_sets = [set(pattern) for pattern in patterns]

  result = {}
  for tips, tips_set in zip(patterns, pattern_sets):
    tot = 0
    for isp, isp_set in zip(patterns, pattern_sets):
      if isp_set >= tips_set:
        tot += (s.S.NegativeOne)**(len(tips)-len(isp))*subtending_lengths_dict[isp]
    result[tips] = expand_and_simplify(tot)

  return result

# branch lengths under a star tree (i.e. coalescent process = Kingman)
def star_tree_lengths(labels):
  """Map each informative site pattern of size k (sorted tuple) to ``2 / (k * C(n, k))``, where n = len(labels)."""
  n = len(labels)
  lengths = {}
  for isp in isps(labels):
    k = len(isp)
    lengths[tuple(sorted(isp))] = s.Rational(2, k * s.binomial(n, k))
  return lengths

def rooted_parsimony_score(candidate_tree, branch_lengths):
  """Expected rooted parsimony score of ``candidate_tree``.

  The score is the sum over informative site patterns of the rooted parsimony
  cost on ``candidate_tree`` times the expected branch length of that pattern.

  ``branch_lengths`` must be keyed by sorted tuples of tip labels, the format
  produced by ``branch_lengths`` and ``star_tree_lengths``.
  """
  tot = 0
  for isp in isps(candidate_tree.label):
    # isps follows the label set's iteration order, which is not sorted in
    # general (e.g. string labels); normalise to the sorted key format
    tot += candidate_tree.rooted_parsimony_cost(isp) * branch_lengths[tuple(sorted(isp))]
  return tot

def unrooted_parsimony_score(candidate_tree, branch_lengths):
  """Expected unrooted parsimony score of ``candidate_tree``.

  The score is the sum over site patterns of size 2..n-2 of the unrooted
  parsimony cost times the expected branch length of that pattern.

  ``branch_lengths`` must be keyed by sorted tuples of tip labels, the format
  produced by ``branch_lengths`` and ``star_tree_lengths``.
  """
  tot = 0
  for isp in isps_unrooted(candidate_tree.label):
    # label-set iteration order is not sorted in general; match the sorted keys
    tot += candidate_tree.unrooted_parsimony_cost(isp) * branch_lengths[tuple(sorted(isp))]
  return tot

def generate_labeled_topologies(labels):
  """Return every rooted binary topology on ``labels`` as a list of Node trees.

  ``labels`` may be any iterable (including a generator); duplicate labels are
  ignored. Subtrees are shared between the returned trees, and ``Node.__add__``
  re-points a subtree's ``parent``. Parent links inside a shared subtree may
  therefore lead into another returned tree.
  """
  # deduplicate first so the count (and len() support) does not depend on the input type
  labels = set(labels)
  n = len(labels)

  if n == 1: return [ Node(label={label}) for label in labels ]
  topologies = []

  for k in range(1, n // 2 + 1):

    left_labels = tuple(itertools.combinations(labels, k))
    # For an even split every bipartition appears twice. The first half of the
    # combinations are exactly those containing the first label, so keeping
    # them lists each bipartition once.
    max_index = len(left_labels) // 2 if 2 * k == n else len(left_labels)
    for left_label in left_labels[:max_index]:

      left_label = set(left_label)
      right_label = labels ^ left_label
      left_topologies = generate_labeled_topologies(left_label)
      right_topologies = generate_labeled_topologies(right_label)
      for left in left_topologies:
        for right in right_topologies:
          topologies.append(left+right)

  return topologies

# for rooted 5 taxa case only
def gen_inequalities_5taxa(species_tree):
  """Build one numpy predicate f(x1, x2, x3) per rooted 5-taxon candidate topology.

  Each predicate is true where the species tree's expected rooted parsimony
  score is strictly greater than the candidate's, under the species tree's
  expected branch lengths.
  """
  x1, x2, x3 = s.symbols('x1 x2 x3')

  true_lengths = branch_lengths(
    species_tree,
    subtending_lengths_dict=subtending_lengths(species_tree),
  )
  true_score = rooted_parsimony_score(species_tree, true_lengths)

  return [
    s.lambdify(
      (x1, x2, x3),
      s.Gt(true_score, rooted_parsimony_score(candidate, true_lengths)),
      "numpy",
    )
    for candidate in generate_labeled_topologies(range(5))
  ]

def has_same_unrooted_topology(tree1, tree2):
  """Return True if every tripartition of ``tree1`` also occurs in ``tree2``.

  For two binary trees on the same tips this means they share the same
  unrooted topology.
  """
  # computed once; the original recomputed it for every tripartition of tree1
  tree2_tripartitions = tree2.tripartitions()
  return all(tri in tree2_tripartitions for tri in tree1.tripartitions())
  

### Rooted 5 taxa case

#### Setup

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.cm import ScalarMappable

# 15 colours (8-bit RGB) for the 15 count bins used in the plots below
colors = [
  [19, 97, 91], 
  [46, 155, 148], 
  [88, 211, 201], 
  [104, 227, 219], 
  [135, 239, 230], 
  [38, 52, 94], 
  [78, 96, 153], 
  [116, 134, 197], 
  [145, 163, 219], 
  [171, 187, 227], 
  [32, 22, 94], 
  [71, 43, 155], 
  [106, 63, 192], 
  [137, 93, 191], 
  [162, 120, 195]
]

# Matplotlib expects RGB components as floats in [0, 1]; rescale the 8-bit values
good_cmap = ListedColormap(np.array(colors) / 255, name='good_cmap')

In [None]:
# creating the colorbar separately and manually adding it to the other figures is easiest

fig, ax = plt.subplots()

# 15 bins with edges 0, 1, ..., 15 -- one per colour of good_cmap
norm = BoundaryNorm([i for i in range(16)], 15)
cbar = fig.colorbar(
  ScalarMappable(norm=norm, cmap=good_cmap), 
  ax=ax,
  orientation = 'horizontal'
)
# one tick at the centre of each bin, labelled with the counts 1..15
cbar.set_ticks([i+0.5 for i in range(15)], labels=[i for i in range(1,16)])
cbar

In [None]:
# Tip nodes labelled 0..4. NOTE: Node.__add__ re-points each operand's parent,
# so after the three trees below are built the shared tips A..E point into the
# tree built last (sym_tree); don't rely on the tips' parent links.
A, B, C, D, E = [Node(label={label}) for label in range(5)]

# caterpillar topology ((((AB)C)D)E)
cat_tree = A + B + C + D + E

# balanced topology (((AB)C)(DE))
bal_tree = (A + B + C) + (D + E)

# pseudo-symmetric topology (((AB)(CD))E)
sym_tree = ((A + B) + (C + D)) + E

species_trees = (cat_tree, bal_tree, sym_tree)

x1, x2, x3 = s.symbols('x1 x2 x3')
# assign_branch_lengths hands out x3, x2, x1 to the internal non-root nodes in
# pre-order (children visited in reverse) and sets the root time to x1+x2+x3:
#   cat: (ABCD)=x3, (ABC)=x2, (AB)=x1
#   bal: (DE)=x3, (ABC)=x2, (AB)=x1
#   sym: (ABCD)=x3, (CD)=x2, (AB)=x1
for species_tree in species_trees:
    species_tree.assign_branch_lengths(x1, x2, x3)

# generate all possible 5-taxa topologies
five_taxa_trees = list(generate_labeled_topologies(range(5)))

In [None]:
# example: finding the probability distribution of the random number of lineages that enter the root of the caterpillar tree 
# (result: mapping {number of lineages: probability}, symbolic in x1, x2, x3)

cat_tree.lineage_probs({0,1,2,3,4})

In [None]:
# find the branch lengths for each of the three possible shapes of species trees 
# (each dict maps a sorted tuple of tips to a SymPy expression in x1, x2, x3;
#  the subtending lengths are computed once and passed to branch_lengths)

cat_subtending_lengths = subtending_lengths(cat_tree)
cat_branch_lengths = branch_lengths(cat_tree, subtending_lengths_dict=cat_subtending_lengths)

bal_subtending_lengths = subtending_lengths(bal_tree)
bal_branch_lengths = branch_lengths(bal_tree, subtending_lengths_dict=bal_subtending_lengths)

sym_subtending_lengths = subtending_lengths(sym_tree)
sym_branch_lengths = branch_lengths(sym_tree, subtending_lengths_dict=sym_subtending_lengths)

In [None]:
# printing the expected branch lengths 
# (one line per site pattern: sorted tuple of tips, then its expected length for the caterpillar tree)

for isp, length in cat_branch_lengths.items():
  print(isp, length)

In [None]:
# expected parsimony cost for the caterpillar topology, given the true tree is a caterpillar tree
# (sum over site patterns of rooted parsimony cost x expected branch length)

rooted_parsimony_score(cat_tree, branch_lengths=cat_branch_lengths)

In [None]:
# expected parsimony cost for the balanced topology, given the true tree is a caterpillar tree
# (branch lengths come from the caterpillar species tree)

rooted_parsimony_score(bal_tree, branch_lengths=cat_branch_lengths)

In [None]:
# expected parsimony cost for the pseudo-symmetric topology, given the true tree is a caterpillar tree
# (branch lengths come from the caterpillar species tree)

rooted_parsimony_score(sym_tree, branch_lengths=cat_branch_lengths)

In [None]:
# expected parsimony costs under a star tree (all internal branch lengths x_i = 0)
# the star-tree branch lengths do not depend on the candidate, so compute them once

star_lengths_5 = star_tree_lengths(range(5))
for tree in (cat_tree, bal_tree, sym_tree):
  score = rooted_parsimony_score(tree, branch_lengths=star_lengths_5)
  print(score)

In [None]:
def gen_sym_topologies_helper(a,b,c,d,e):
  """Return the three pseudo-symmetric trees with outgroup ``e``.

  The first tip ``a`` is paired with each of ``b``, ``c`` and ``d`` in turn,
  and the remaining two tips form the other cherry:
  ((ab)(cd))e, ((ac)(bd))e, ((ad)(bc))e.
  """
  splits = ((b, c, d), (c, b, d), (d, b, c))
  return [((a + mate) + (p + q)) + e for mate, p, q in splits]

# All 15 pseudo-symmetric 5-taxon topologies: for each outgroup (last argument:
# E, D, C, B, A) the helper pairs up the remaining four tips in the three
# possible ways.
sym_topologies = [
  *gen_sym_topologies_helper(A,B,C,D,E),
  *gen_sym_topologies_helper(A,B,C,E,D),
  *gen_sym_topologies_helper(A,B,D,E,C),
  *gen_sym_topologies_helper(A,E,C,D,B),
  *gen_sym_topologies_helper(B,C,D,E,A)
] 
  

In [None]:
# find expected costs for all symmetric trees, given the true tree is a caterpillar tree 
# notice the similarity between these costs!
# (each line: the candidate in Node string form, then its expected rooted parsimony score)

for sym_topology in sym_topologies:
  print(str(sym_topology), rooted_parsimony_score(sym_topology, branch_lengths=cat_branch_lengths))

In [None]:
# find expected costs for all symmetric trees, given the true tree is a balanced tree
# notice the similarity between these costs!
# (each line: the candidate in Node string form, then its expected rooted parsimony score)

for sym_topology in sym_topologies:
  print(str(sym_topology), rooted_parsimony_score(sym_topology, branch_lengths=bal_branch_lengths))

#### Parsimony anomaly zone for caterpillar tree (fixed x3)

In [None]:
# one predicate per rooted 5-taxon candidate topology; each is true where the
# candidate's expected parsimony score is strictly below the caterpillar tree's
inequalities_cat = gen_inequalities_5taxa(cat_tree)

In [None]:
resolution = 400
x1_lim = (0, 0.1)
x2_lim = (0, 0.1)

x1s = np.linspace(*x1_lim, resolution)
x2s = np.linspace(*x2_lim, resolution)
# need a very small value of x3 instead of 0 in order for numerical evaluation to work properly
x3s = np.array([1e-6,1/50,1/10,1/2])

# default 'xy' indexing: arrays of shape (len(x2s), len(x1s), len(x3s))
X, Y, Z = np.meshgrid(x1s, x2s, x3s)

# data[i, j, k] = number of candidate topologies that beat the species tree at
# (x1s[j], x2s[i], x3s[k]); zeros_like already returns a fresh array
data = np.zeros_like(X, dtype=int)
for inequality in inequalities_cat:
    data += inequality(X, Y, Z)

# hide grid points where no candidate beats the species tree
data = np.ma.masked_where(data < 1, data)

In [None]:
# 15 colour bins with edges 0..15 (matching good_cmap and the colorbar)
norm = BoundaryNorm([i for i in range(16)], 15)

kw = {
  'vmin': 0,
  'vmax': 15,
  # One level per integer, so count k fills the band (k-1, k] and gets colour
  # bin k-1 (labelled k on the colorbar). np.linspace(0, 15, 15) has a
  # spacing of 15/14, which does not line up with the integer counts.
  'levels': np.arange(16),
  'cmap': good_cmap,
  'antialiased': True,
  'algorithm': 'threaded',
  'norm': norm
}

fig, axs = plt.subplots(1, len(x3s), figsize=(5*len(x3s), 5), sharey=True)

# one panel per fixed value of x3
for idx, ax in enumerate(axs):

  ax.contourf(
    X[:, :, idx], Y[:, :, idx], data[:, :, idx],
    **kw
  )

  ax.set(
    title=f'x3 = {x3s[idx]:g}',
    xlim=x1_lim, 
    ylim=x2_lim,
    xlabel='x1',
    ylabel='x2',
    xticks=np.linspace(0,0.1,6),
    yticks=np.linspace(0,0.1,6)
  )

plt.show()

In [None]:
# verifying that candidate topology (((AB)(CD))E) is maximally anomalous for species tree topology ((((AB)C)D)E) 
# i.e. this candidate topology contributes to every anomalous region

# true where the candidate's expected rooted parsimony score is strictly below the caterpillar tree's own
inequality_max_anom = s.lambdify(
  (x1,x2,x3), 
  s.Gt(rooted_parsimony_score(cat_tree, branch_lengths=cat_branch_lengths), rooted_parsimony_score(sym_tree, branch_lengths=cat_branch_lengths)), 
  "numpy"
)

resolution = 400
x1_lim = (0, 0.1)
x2_lim = (0, 0.1)

x1s = np.linspace(*x1_lim, resolution)
x2s = np.linspace(*x2_lim, resolution)
# need a very small value of x3 instead of 0 in order for numerical evaluation to work properly
x3s = np.array([1e-6,1/50,1/10,1/2])

X, Y, Z = np.meshgrid(x1s, x2s, x3s)

# zeros_like already returns a fresh array; adding into it also broadcasts a
# scalar result (lambdify of a constant inequality) over the grid
data_max_anom = np.zeros_like(X, dtype=int)
data_max_anom += inequality_max_anom(X, Y, Z)

data_max_anom = np.ma.masked_where(data_max_anom < 1, data_max_anom)

norm = BoundaryNorm([i for i in range(16)], 15)

kw = {
  'vmin': 0,
  'vmax': 15,
  # one level per integer so each count maps to its own colour bin
  # (np.linspace(0, 15, 15) did not line up with the integer counts)
  'levels': np.arange(16),
  'cmap': good_cmap,
  'antialiased': True,
  'algorithm': 'threaded',
  'norm': norm
}

fig, axs = plt.subplots(1, len(x3s), figsize=(5*len(x3s), 5), sharey=True)

for idx, ax in enumerate(axs):

  ax.contourf(
    X[:, :, idx], Y[:, :, idx], data_max_anom[:, :, idx],
    **kw
  )

  ax.set(
    title=f'x3 = {x3s[idx]:g}',
    xlim=x1_lim, 
    ylim=x2_lim,
    xlabel='x1',
    ylabel='x2',
    xticks=np.linspace(0,0.1,6),
    yticks=np.linspace(0,0.1,6)
  )

plt.show()

#### Parsimony anomaly zone for balanced tree (fixed x3)

In [None]:
# one predicate per rooted 5-taxon candidate topology; each is true where the
# candidate's expected parsimony score is strictly below the balanced tree's
inequalities_bal = gen_inequalities_5taxa(bal_tree)

In [None]:
resolution = 400
x1_lim = (0, 0.1)
x2_lim = (0, 0.1)

x1s = np.linspace(*x1_lim, resolution)
x2s = np.linspace(*x2_lim, resolution)
# need a very small value of x3 instead of 0 in order for numerical evaluation to work properly
x3s = np.array([1e-6,1/50,1/10,1/2])

# default 'xy' indexing: arrays of shape (len(x2s), len(x1s), len(x3s))
X, Y, Z = np.meshgrid(x1s, x2s, x3s)

# data[i, j, k] = number of candidate topologies that beat the species tree at
# (x1s[j], x2s[i], x3s[k]); zeros_like already returns a fresh array
data = np.zeros_like(X, dtype=int)
for inequality in inequalities_bal:
    data += inequality(X, Y, Z)

# hide grid points where no candidate beats the species tree
data = np.ma.masked_where(data < 1, data)

In [None]:
# 15 colour bins with edges 0..15 (matching good_cmap and the colorbar)
norm = BoundaryNorm([i for i in range(16)], 15)

kw = {
  'vmin': 0,
  'vmax': 15,
  # One level per integer, so count k fills the band (k-1, k] and gets colour
  # bin k-1 (labelled k on the colorbar). np.linspace(0, 15, 15) has a
  # spacing of 15/14, which does not line up with the integer counts.
  'levels': np.arange(16),
  'cmap': good_cmap,
  'antialiased': True,
  'algorithm': 'threaded',
  'norm': norm
}

fig, axs = plt.subplots(1, len(x3s), figsize=(5*len(x3s), 5), sharey=True)

# one panel per fixed value of x3
for idx, ax in enumerate(axs):

  ax.contourf(
    X[:, :, idx], Y[:, :, idx], data[:, :, idx],
    **kw
  )

  ax.set(
    title=f'x3 = {x3s[idx]:g}',
    xlim=x1_lim, 
    ylim=x2_lim,
    xlabel='x1',
    ylabel='x2',
    xticks=np.linspace(0,0.1,6),
    yticks=np.linspace(0,0.1,6)
  )

plt.show()

In [None]:
# verifying that candidate topology (((AB)(DE))C) is maximally anomalous for species tree topology (((AB)C)(DE)) 
# i.e. this topology contributes to every anomalous region

# true where the candidate's expected rooted parsimony score is strictly below the balanced tree's own
inequality_max_anom = s.lambdify(
  (x1,x2,x3), 
  s.Gt(rooted_parsimony_score(bal_tree, branch_lengths=bal_branch_lengths), rooted_parsimony_score(((A+B)+(D+E))+C, branch_lengths=bal_branch_lengths)), 
  "numpy"
)

resolution = 400
x1_lim = (0, 0.1)
x2_lim = (0, 0.1)

x1s = np.linspace(*x1_lim, resolution)
x2s = np.linspace(*x2_lim, resolution)
# need a very small value of x3 instead of 0 in order for numerical evaluation to work properly
x3s = np.array([1e-6,1/50,1/10,1/2])

X, Y, Z = np.meshgrid(x1s, x2s, x3s)

# zeros_like already returns a fresh array; adding into it also broadcasts a
# scalar result (lambdify of a constant inequality) over the grid
data_max_anom = np.zeros_like(X, dtype=int)
data_max_anom += inequality_max_anom(X, Y, Z)

data_max_anom = np.ma.masked_where(data_max_anom < 1, data_max_anom)

norm = BoundaryNorm([i for i in range(16)], 15)

kw = {
  'vmin': 0,
  'vmax': 15,
  # one level per integer so each count maps to its own colour bin
  # (np.linspace(0, 15, 15) did not line up with the integer counts)
  'levels': np.arange(16),
  'cmap': good_cmap,
  'antialiased': True,
  'algorithm': 'threaded',
  'norm': norm
}

fig, axs = plt.subplots(1, len(x3s), figsize=(5*len(x3s), 5), sharey=True)

for idx, ax in enumerate(axs):

  ax.contourf(
    X[:, :, idx], Y[:, :, idx], data_max_anom[:, :, idx],
    **kw
  )

  ax.set(
    title=f'x3 = {x3s[idx]:g}',
    xlim=x1_lim, 
    ylim=x2_lim,
    xlabel='x1',
    ylabel='x2',
    xticks=np.linspace(0,0.1,6),
    yticks=np.linspace(0,0.1,6)
  )

plt.show()

#### Parsimony anomaly zone for symmetric tree (fixed x3)

In [None]:
# one predicate per rooted 5-taxon candidate topology; each is true where the
# candidate's expected parsimony score is strictly below the pseudo-symmetric tree's
inequalities_sym = gen_inequalities_5taxa(sym_tree)

In [None]:
resolution = 400
x1_lim = (0, 0.1)
x2_lim = (0, 0.1)

x1s = np.linspace(*x1_lim, resolution)
x2s = np.linspace(*x2_lim, resolution)
# need a very small value of x3 instead of 0 in order for numerical evaluation to work properly
x3s = np.array([1e-6,1/50,1/10,1/2])

# default 'xy' indexing: arrays of shape (len(x2s), len(x1s), len(x3s))
X, Y, Z = np.meshgrid(x1s, x2s, x3s)

# data[i, j, k] = number of candidate topologies that beat the species tree at
# (x1s[j], x2s[i], x3s[k]); zeros_like already returns a fresh array
data = np.zeros_like(X, dtype=int)
for inequality in inequalities_sym:
    data += inequality(X, Y, Z)

# hide grid points where no candidate beats the species tree
data = np.ma.masked_where(data < 1, data)

In [None]:
# 15 colour bins with edges 0..15 (matching good_cmap and the colorbar)
norm = BoundaryNorm([i for i in range(16)], 15)

kw = {
  'vmin': 0,
  'vmax': 15,
  # One level per integer, so count k fills the band (k-1, k] and gets colour
  # bin k-1 (labelled k on the colorbar). np.linspace(0, 15, 15) has a
  # spacing of 15/14, which does not line up with the integer counts.
  'levels': np.arange(16),
  'cmap': good_cmap,
  'antialiased': True,
  'algorithm': 'threaded',
  'norm': norm
}

fig, axs = plt.subplots(1, len(x3s), figsize=(5*len(x3s), 5), sharey=True)

# one panel per fixed value of x3
for idx, ax in enumerate(axs):

  ax.contourf(
    X[:, :, idx], Y[:, :, idx], data[:, :, idx],
    **kw
  )

  ax.set(
    title=f'x3 = {x3s[idx]:g}',
    xlim=x1_lim, 
    ylim=x2_lim,
    xlabel='x1',
    ylabel='x2',
    xticks=np.linspace(0,0.1,6),
    yticks=np.linspace(0,0.1,6)
  )

plt.show()

#### Parsimony anomaly zone for caterpillar tree (3D plot)

Using the pseudo-symmetric tree instead gives the same result, just as in the 2D plots above.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

resolution = 400
min_x = 1e-10
max_x = 0.1

x1_lim = (min_x, max_x)
x2_lim = (min_x, max_x)
x3_lim = (min_x, max_x)

x1s = np.linspace(*x1_lim, resolution)
x2s = np.linspace(*x2_lim, resolution)
x3s = np.linspace(*x3_lim, resolution)

# Only the three faces x1 = min, x2 = min and x3 = min are drawn, so build one
# 2D grid per face instead of full resolution**3 meshgrids (64 million entries
# per array at resolution = 400). Separate face arrays also avoid counting the
# edges shared by two faces twice.
# face x3 = min: rows follow x2, columns follow x1
X_z, Y_z = np.meshgrid(x1s, x2s)
Z_z = np.full_like(X_z, x3s[0])
# face x2 = min: rows follow x1, columns follow x3
X_y, Z_y = np.meshgrid(x1s, x3s, indexing='ij')
Y_y = np.full_like(X_y, x2s[0])
# face x1 = min: rows follow x2, columns follow x3
Y_x, Z_x = np.meshgrid(x2s, x3s, indexing='ij')
X_x = np.full_like(Y_x, x1s[0])

# number of candidate topologies that beat the caterpillar tree on each face
data_z = np.zeros_like(X_z, dtype=int)
data_y = np.zeros_like(X_y, dtype=int)
data_x = np.zeros_like(Y_x, dtype=int)
for inequality in inequalities_cat: 
    data_z += inequality(X_z, Y_z, Z_z)
    data_y += inequality(X_y, Y_y, Z_y)
    data_x += inequality(X_x, Y_x, Z_x)

data_z = np.ma.masked_where(data_z < 1, data_z)
data_y = np.ma.masked_where(data_y < 1, data_y)
data_x = np.ma.masked_where(data_x < 1, data_x)

kw = {
  'vmin': 0,
  'vmax': 15,
  # one level per integer so each count maps to its own colour of good_cmap
  # (np.linspace(0, 15, 15) did not line up with the integer counts)
  'levels': np.arange(16),
  'cmap': good_cmap,
  'antialiased': True,
  'algorithm': 'threaded',
}

fig = plt.figure(figsize=(7,7))
ax = fig.add_subplot(111, projection='3d')

_ = ax.contourf(
  X_z, Y_z, data_z,
  zdir='z', offset=0, **kw
)
_ = ax.contourf(
  X_y, data_y, Z_y,
  zdir='y', offset=0, **kw
)
_ = ax.contourf(
  data_x, Y_x, Z_x,
  zdir='x', offset=0, **kw
)

ax.set(xlim=x1_lim, ylim=x2_lim, zlim=x3_lim)
ax.set(
    xlabel='x1',
    ylabel='x2',
    zlabel='x3',
    xticks=[0.02,0.04,0.06,0.08],
    yticks=[0.02,0.04,0.06,0.08],
    zticks=[0.02,0.04,0.06,0.08],
)

ax.view_init(35, 35, 0)
ax.set_box_aspect(None, zoom=0.9)

plt.show()

### Unrooted 6 taxa case

In [None]:
A, B, C, D, E, F = [Node(label={label}) for label in range(6)]

# six rooted species-tree shapes on six tips
six_taxa_rooted_reps = [
  A+B+C+D+E+F,
  A+B+C+D+(E+F),
  ((A+B+C)+(D+E))+F,
  (A+B+C)+(D+E+F),
  ((A+B)+(C+D))+E+F,
  ((A+B)+(C+D))+(E+F)
]

x = s.symbols('x')

# every internal non-root branch (four of them) gets the same length x
for rooted_tree in six_taxa_rooted_reps:
  rooted_tree.assign_branch_lengths(x,x,x,x)

six_taxa_rooted_topologies = generate_labeled_topologies(range(6))
six_taxa_unrooted_reps = []

# keep the first rooted topology encountered for each unrooted topology
for rooted_topology in six_taxa_rooted_topologies:
  already_represented = any(
    has_same_unrooted_topology(rooted_topology, rep) for rep in six_taxa_unrooted_reps
  )
  if not already_represented:
    six_taxa_unrooted_reps.append(rooted_topology)
  

In [None]:
# expected parsimony costs for the two possible unrooted topologies under a star tree (all internal branch lengths x_i = 0) 
# the star-tree branch lengths are shared by both candidates, so compute them once

star_lengths_6 = star_tree_lengths(range(6))
caterpillar_score = unrooted_parsimony_score(A+B+C+D+E+F, branch_lengths=star_lengths_6)
three_cherry_score = unrooted_parsimony_score(((A+B)+(C+D))+(E+F), branch_lengths=star_lengths_6)
print(caterpillar_score, three_cherry_score)

In [None]:
def gen_inequalities_6taxa(species_tree):
  """Build one numpy predicate f(x) per candidate in ``six_taxa_unrooted_reps``.

  Each predicate is true where the species tree's expected unrooted parsimony
  score is strictly greater than the candidate's, under the species tree's
  expected branch lengths (a function of the single symbol ``x``).
  """
  x = s.symbols('x')

  true_lengths = branch_lengths(
    species_tree,
    subtending_lengths_dict=subtending_lengths(species_tree),
  )
  true_score = unrooted_parsimony_score(species_tree, true_lengths)

  return [
    s.lambdify(x, s.Gt(true_score, unrooted_parsimony_score(candidate, true_lengths)), "numpy")
    for candidate in six_taxa_unrooted_reps
  ]

In [None]:
# predicates for the six-taxon caterpillar species tree (six_taxa_rooted_reps[0])
# against every unrooted candidate topology
inequalities = gen_inequalities_6taxa(six_taxa_rooted_reps[0])

In [None]:
# one list of predicates per rooted species-tree shape, in six_taxa_rooted_reps order
inequalities_arr = [
  gen_inequalities_6taxa(rooted_tree) for rooted_tree in six_taxa_rooted_reps
]

In [None]:
xlim = (1e-10, 0.1)
resolution = 1000
xs = np.linspace(*xlim, resolution)

fig, ax = plt.subplots(figsize=(8,10))

# one curve per rooted species-tree shape: how many candidate unrooted
# topologies have a strictly lower expected parsimony score at each x
for idx, inequalities in enumerate(inequalities_arr): 

  data = np.zeros_like(xs, dtype=int)  # zeros_like already returns a fresh array
  for inequality in inequalities:
    data += inequality(xs)

  ax.plot(xs, data, label=str(six_taxa_rooted_reps[idx]))

ax.set(xlabel='x', ylabel='number of PAGTs')
# a single call: the earlier set_yticks([1,3,5,9,15]) was immediately overwritten
ax.set_yticks([0,1,3,5,9,15])
ax.legend()
plt.show()


In [None]:
xlim = (1e-6, 0.1)
resolution = 1000
xs = np.linspace(*xlim, resolution)

fig, ax = plt.subplots(figsize=(8,10))

# only plot the number of PAGTs for trees that have PAGTs
# (the first four species-tree shapes in six_taxa_rooted_reps)
for idx, inequalities in enumerate(inequalities_arr[:4]): 

  data = np.zeros_like(xs, dtype=int)  # zeros_like already returns a fresh array
  for inequality in inequalities:
    data += inequality(xs)

  ax.plot(xs, data, label=str(six_taxa_rooted_reps[idx]))

ax.set(xlabel='x', ylabel='number of PAGTs')
ax.set_yticks([0,1,3,5,9,15])
ax.legend()
plt.show()