NOTE: Run the following installation cell and restart the notebook if you haven't run it already.

In [1]:
!easy_install pandas # TODO: make explicit note about this

Searching for pandas
Best match: pandas 1.3.0
Processing pandas-1.3.0-py3.9-macosx-10.9-x86_64.egg
pandas 1.3.0 is already the active version in easy-install.pth

Using /Applications/SageMath/local/lib/python3.9/site-packages/pandas-1.3.0-py3.9-macosx-10.9-x86_64.egg
Processing dependencies for pandas
Finished processing dependencies for pandas


In [2]:
from sage.combinat.gelfand_tsetlin_patterns import GelfandTsetlinPattern
import itertools
import numpy as np
import collections
import pandas as pd


In [3]:
def is_dominant(row):
    return len(row) == len(list(set(row)))


def get_all_possible_next_rows(row, include_non_strict=False): # TODO: include_non_strict
    l = list(itertools.product(*[list(range(left_el, row[i+1]-1, -1)) for i, left_el in enumerate(row[:-1])]))
    l = [el for el in l if is_dominant(el)]
    return l


def row_pair_is_legal(upper_row, lower_row):
    assert len(upper_row) > 0
    assert len(lower_row) > 0
    
    for i, left_upper_el in enumerate(upper_row[:-1]):
        right_upper_el = upper_row[i+1]
        lower_el = lower_row[i]
        if not ((lower_el <= left_upper_el) and (lower_el >= right_upper_el)):
            return False
    
    return True


def pattern_is_legal(pattern): 
    top_row_length = len(pattern[0])
    for i, row in enumerate(pattern):
        if not len(row) == top_row_length - i:
            return False
        
    return all([row_pair_is_legal(upper_row, pattern[i+1]) for i, upper_row in enumerate(pattern[:-1])])


def get_all_patterns_from_top_row(top_row):
    
    def recurse(pattern_so_far):
        assert type(pattern_so_far) == list
        lowest_row_so_far = pattern_so_far[-1]
        assert len(lowest_row_so_far) > 0

        if len(lowest_row_so_far) == 1:
            if pattern_is_legal(pattern_so_far):
                return [pattern_so_far]
            else:
                return []

        results = []
        for new_row in get_all_possible_next_rows(lowest_row_so_far):
            results += recurse(pattern_so_far + [list(new_row)])

        return results

    assert len(top_row) > 0
    assert top_row == sorted(top_row, reverse=True)
    
    if len(top_row) == 1:
        return [top_row]
                    
    return reversed(recurse([top_row]))
    

In [4]:
def get_tokuyama_formula(lam, variables, ice='gamma', verbose=False):
    n = len(lam)
    top_row = [lam[i] + (n - (i+1)) for i in range(n)]
    
    if n == 1:
        return SR.symbol(f'{variables[-1]}')^top_row[0] 

    patterns = get_all_patterns_from_top_row(top_row=top_row)

    formula = 0
    for pattern in patterns:
        g = GelfandTsetlinPattern(pattern)
        
        if not g.is_strict():
            continue

        if verbose:
            g.pp()

        term = 1

        t = SR.symbol('t')
        term *= (t+1)^(len(g.special_entries()))
        
        if ice == 'gamma':
            term *= t^len(g.boxed_entries())
        elif ice == 'delta':
            term *= t^len(g.circled_entries())
        else:
            raise Exception('') # TODO
            
        d_list = g.row_sums() + [0]
        d_exponents = [d - d_list[i+1] for i, d in enumerate(d_list[:-1])]
        if ice == 'gamma': 
            d_exponents = list(reversed(d_exponents))
        
        for i, d_exp in enumerate(d_exponents):
            term *= SR.symbol(variables[i])^(d_exp)

        if verbose:
            print(f"Term: {term}")

        formula += term

        if verbose:
            print("="*40)
            
    return formula
    
    

In [5]:
def get_num_left(lam, mu):
    assert len(lam) == len(mu) + 1
    
    return sum([1 if lam[i] == mu[i] else 0 for i in range(len(mu))])

def get_num_right(lam, mu):
    assert len(lam) == len(mu) + 1
    
    return sum([1 if lam[i+1] == mu[i] else 0 for i in range(len(mu))])

def get_num_generic(lam, mu):
    assert len(lam) == len(mu) + 1
    
    return sum([1 if lam[i] > mu[i] and lam[i+1] < mu[i] else 0 for i in range(len(mu))])    

    

In [6]:
def branching_term(lam, mu, var, param, ice='gamma'):
    term = 1

    param_symbol = SR.symbol(param)
    term *= (param_symbol+1)^get_num_generic(lam, mu)

    if ice == 'gamma':
        term *= param_symbol^get_num_left(lam, mu)
    elif ice == 'delta':
        term *= param_symbol^get_num_right(lam, mu)
    else:
        raise Exception('') #TODO

    d_exp = sum(lam) - sum(mu)
    term *= SR.symbol(var)^(d_exp)
    
    return term


def swf(lam, variables, parameters, ice='gamma', verbose=False):
    assert len(lam) == len(variables)
    
    n = len(lam)
    rho = [n - (i+1) for i in range(n)]
    top_row = [lam[i] + r for i, r in enumerate(rho)]
    
    if len(variables) == 1:
        return SR.symbol(variables[0])^lam[0]

    total = 0
    next_rows = get_all_possible_next_rows(top_row)
    for next_row in next_rows:        
        if verbose:
            print(f"Top row: {top_row}")
            print(f"Next row: {next_row}")

        if ice == 'gamma':
            new_variables = variables[:-1]
            new_parameters = parameters[:-1]
        elif ice == 'delta':
            new_variables = variables[1:]
            new_parameters = parameters[1:]
        
        branch = swf(np.array(next_row) - np.array(rho[1:]), new_variables, new_parameters, ice=ice, verbose=verbose)
        
        if ice == 'gamma':
            branching_var = variables[-1]
            branching_param = parameters[-1]
        elif ice == 'delta':
            branching_var = variables[0]
            branching_param = parameters[0]
        term = branching_term(top_row, next_row, branching_var, branching_param, ice)
        
        total += (term*branch)

        if verbose:
            print(f"Term: {term}")
            print(f"Branch: {branch.simplify().factor()}")

        if verbose:
            print("="*40)
            
    return total
    

In [7]:
def get_table(lam):
    
    n = len(lam)
    rho = [n - (i+1) for i in range(n)]
    top_row = [lam[i] + r for i, r in enumerate(rho)]
    
    variables = [f'z_{i}' for i in range(1, n+1, 1)]
    parameters = [f't_{i}' for i in range(1, n+1, 1)]
    
    next_rows = get_all_possible_next_rows(top_row)
    d = collections.defaultdict(list)
    for mu in next_rows:        
        shifted_mu = np.array(mu) - np.array(rho[1:])
        contribution_gamma = branching_term(top_row, shifted_mu, f'z_{n}', f't_{n}', ice='gamma')
        contribution_delta = branching_term(top_row, shifted_mu, 'z_1', 't_1', ice='delta')
        
        branch_gamma = swf(mu, variables[:-1], parameters[:-1], ice='gamma')
        branch_delta = swf(mu, variables[1:], parameters[1:], ice='delta')
        
        d['mu'].append(mu)
        d['gamma_contribution'].append(contribution_gamma)
        d['gamma_branch'].append(branch_gamma)
        d['delta_contribution'].append(contribution_delta)
        d['delta_branch'].append(branch_delta)
        
    df = pd.DataFrame(data=d)
    df.set_index('mu')
    
    return df


In [8]:
lam = [1,0,0]

variables = [f'z_{i}' for i in range(1, len(lam)+1, 1)]
parameters = ['t' for i in range(len(lam))]

print("Notice that the following are equal:\n")
print("Tokuyama formula:\n\t",  get_tokuyama_formula(lam, variables).simplify().factor())
print()
print("Spherical Whittaker function for the general linear group over a non-archimedian local field:\n\t", swf(lam, variables, parameters, ice='gamma').simplify().factor())

Notice that the following are equal:

Tokuyama formula:
	 (t*z_1 + z_2)*(t*z_1 + z_3)*(t*z_2 + z_3)*(z_1 + z_2 + z_3)

Spherical Whittaker function for the general linear group over a non-archimedian local field:
	 (t*z_1 + z_2)*(t*z_1 + z_3)*(t*z_2 + z_3)*(z_1 + z_2 + z_3)


In [9]:
lam = [1,0,0]

variables = [f'z_{i}' for i in range(1, len(lam)+1, 1)]
parameters = [f't_{i}' for i in range(1, len(lam)+1, 1)]

print("Notice that the following are equal:\n")
print("Gamma:\n\t", swf(lam, variables, parameters, ice='gamma').simplify().factor())
print()
print("Delta:\n\t", swf(lam, variables, parameters, ice='delta').simplify().factor())

Notice that the following are equal:

Gamma:
	 (t_2*z_1 + z_2)*(t_3*z_1 + z_3)*(t_3*z_2 + z_3)*(z_1 + z_2 + z_3)

Delta:
	 (t_1*z_1 + z_2)*(t_1*z_1 + z_3)*(t_2*z_2 + z_3)*(z_1 + z_2 + z_3)


In [10]:
df = get_table(lam=[1,0,0])

print(f'Given lambda={lam}, we get the following table:')
df

Given lambda=[1, 0, 0], we get the following table:


Unnamed: 0,mu,gamma_contribution,gamma_branch,delta_contribution,delta_branch
0,"(3, 1)",(t_3 + 1)*t_3*z_3,t_2*z_1^4*z_2 + (t_2 + 1)*z_1^3*z_2^2 + (t_2 +...,(t_1 + 1)*z_1,t_2*z_2^4*z_3 + (t_2 + 1)*z_2^3*z_3^2 + (t_2 +...
1,"(3, 0)",(t_3 + 1)*z_3^2,t_2*z_1^4 + (t_2 + 1)*z_1^3*z_2 + (t_2 + 1)*z_...,(t_1 + 1)*t_1*z_1^2,t_2*z_2^4 + (t_2 + 1)*z_2^3*z_3 + (t_2 + 1)*z_...
2,"(2, 1)",t_3*z_3^2,t_2*z_1^3*z_2 + (t_2 + 1)*z_1^2*z_2^2 + z_1*z_2^3,t_1*z_1^2,t_2*z_2^3*z_3 + (t_2 + 1)*z_2^2*z_3^2 + z_2*z_3^3
3,"(2, 0)",z_3^3,t_2*z_1^3 + (t_2 + 1)*z_1^2*z_2 + (t_2 + 1)*z_...,t_1^2*z_1^3,t_2*z_2^3 + (t_2 + 1)*z_2^2*z_3 + (t_2 + 1)*z_...
4,"(1, 0)",z_3^4,t_2*z_1^2 + (t_2 + 1)*z_1*z_2 + z_2^2,t_1*z_1^4,t_2*z_2^2 + (t_2 + 1)*z_2*z_3 + z_3^2
