## Imports

In [1]:
import sympy as sp
import time

q = sp.Symbol('q')

## Util functions

In [2]:
# Input: polynomial f, variable x, polynomials u and v
# Output: f with x replaced by u / v, then 'homogenized'

def subs_frac(f, x, u, v):
    return sp.Poly(sp.Poly(f, x).transform(sp.Poly(u, x), sp.Poly(v, x)), f.gens)

In [3]:
# Input: system
# Output: list of pairs (eqs, X_vars)

def split_system(system):
    groups = []
    
    # Group equations
    used_vars = set(system.S_vars)
    for eq in system.eqs:
        eq_vars = eq.free_symbols.difference(system.S_vars)
        used_vars.update(eq_vars)
        for g_eqs, g_vars in groups:
            if not g_vars.isdisjoint(eq_vars):
                g_eqs.append(eq)
                g_vars.update(eq_vars)
                break
        else:
            groups.append(([ eq ], system.S_vars.union(eq_vars)))
    
    # Change gens for polynomials
    for g_eqs, g_vars in groups:
        g_eqs[:] = [ sp.Poly(eq, g_vars) for eq in g_eqs ]
    
    # Group for unused variables
    unused_vars = system.X_vars.difference(used_vars)
    if len(unused_vars) > 0:
        groups.append(([], system.S_vars.union(unused_vars)))
        
    return groups

In [4]:
# Input: F, G are lists of polynomials
# Output: True if isomorphic, otherwise False

def are_isomorphic(F, G):
    # Number of equations should match
    if len(F) != len(G):
        return False
    n_eqs = len(F)

    # Trivial case
    if n_eqs == 0:
        return True

    # Number of variables should match
    if len(F[0].gens) != len(G[0].gens):
        return False
    n_vars = len(F[0].gens)

    # Convert polynomials to their term representation
    F = [ f.terms() for f in F ]
    G = [ g.terms() for g in G ]

    # Now try to match equations
    match_eqs, match_vars = [ None ] * n_eqs, [ None ] * n_vars
    
    def match_eq(match_eqs, match_vars, i):
        # If this was the last equation that needed to be matched, we are done!
        if i == n_eqs:
            return True

        f_terms = F[i]
        n_terms = len(f_terms)
        for j in [ j for j in range(n_eqs) if j not in match_eqs and len(G[j]) == n_terms ]:
            # Try to match equation F[i] to G[j]
            g_terms = G[j]

            # Now try to match terms
            match_eqs[i] = j
            match_terms = [ None ] * n_terms
            result = match_term(match_eqs, match_vars, i, j, match_terms, n_terms, 0)
            if result == True:
                return True
            match_eqs[i] = None

        return False

    def match_term(match_eqs, match_vars, i, j, match_terms, n_terms, k):
        # If all terms are matched, go on to the next equation
        if k == n_terms:
            return match_eq(match_eqs, match_vars, i + 1)

        # Find match for F[i][k]
        f_term = F[i][k]
        for l in [ l for l in range(n_terms) if l not in match_terms and could_match_term(f_term, G[j][l])]:
            g_term = G[j][l]
            options = []
            if not find_options_match_term(f_term[0], g_term[0], options, match_vars, 0):
                return False

            match_terms[k] = l
            for option in options:
                result = match_term(match_eqs, option, i, j, match_terms, n_terms, k + 1)
                if result == True:
                    return True
            match_terms[k] = None
        
        return False

    def find_options_match_term(T, S, options, option, u):
        # If all u's are matched, add to options
        if u == n_vars:
            options.append(option.copy())
            return True

        # If u was already matched, or if T[u] == 0, just continue
        if option[u] != None or T[u] == 0:
            return find_options_match_term(T, S, options, option, u + 1)
        
        # Otherwise, find new matches for v
        matches_v = [ v for v in range(n_vars) if S[v] == T[u] and v not in option ]
        if not matches_v:
            return False

        for v in matches_v:
            option[u] = v
            if not find_options_match_term(T, S, options, option, u + 1):
                option[u] = None        
                return False
        
        option[u] = None
        return True

    def could_match_term(T, S):
        # Coefficients should match
        if T[1] != S[1]:
            return False
        
        # Powers should match
        k = max(T[0] + S[0]) + 1
        P = [ 0 ] * k
        for a in T[0]:
            P[a] += 1
        for a in S[0]:
            P[a] -= 1
        
        return not any(P)

    return match_eq(match_eqs, match_vars, 0)

## Class Solver

In [5]:
class Solver:
    
    def __init__(self):
        self.dictionary = []
        self.unknowns = 0
    
    # Input: system
    # Output: class in grothendieck ring
    def compute_class(self, system):    
        print('System {} in {}'.format(system, system.X_vars))
        
        # Reduce system equations
        system.reduce()
                
        # Special cases
        if system.eqs == []:
            return q ** (len(system.X_vars) - len(system.S_vars))

        if system.eqs == [ 1 ]:
            return 0
                
        # Split system
        groups = split_system(system)
        if len(groups) != 1: 
            c = 1
            for g_eqs, g_vars in groups:
                if g_eqs:
                    print('Group {} in {}'.format([ f.expr for f in g_eqs ], g_vars))
                    s = System(g_eqs, g_vars, system.S_vars, system.factored_eqs)
                    c *= self.compute_class(s)
                else:
                    c *= q ** (len(g_vars) - len(system.S_vars))
            return c
                
        # Search in dictionary
        for entry in self.dictionary:
            if are_isomorphic(entry[0], system.eqs):
                print('Found match!')
                return entry[1]

        # Apply solving techniques
        c = self.solve_system(system)
        c = sp.expand(c)

        print('Save {} --> {}'.format([ eq.expr for eq in system.eqs ], c))

        self.dictionary.append((system.eqs.copy(), c))
        return c
    
    def new_unknown(self):
        x = sp.Symbol('X_' + str(self.unknowns))
        self.unknowns += 1
        return x
    
#   -------- SOLVING TECHNIQUE METHODS --------
    
    def check_univariate_equations(self, system, eq):
        # Check for equations with only one free variable x, then solve for x
        eq_vars = eq.free_symbols
        if len(eq_vars) != 1:
            return None

        # Variable must not be of S
        x = eq_vars.pop()
        if x in system.S_vars:
            return None

        # Simply consider all solutions and add the classes
        x_solutions = sp.solve(eq, x)
        Y_vars = system.X_vars.difference({ x, })
        c = 0
        for v in x_solutions:                
            s = System([ sp.Poly(f.subs(x, v), Y_vars) for f in system.eqs if f != eq ], Y_vars, system.S_vars, system.factored_eqs)
            c += self.compute_class(s)
            
        return c
    
    def check_linear_equations(self, system, eq):
        # Look for something of the form 'x * u + v = 0'
        for x in [ x for x in eq.free_symbols if x not in system.S_vars and eq.degree(x) == 1]:
            v = sp.Poly(eq.subs(x, 0), gens = system.X_vars)
            u = sp.Poly((eq - v) / x, gens = system.X_vars)

            c = 0

            # Case 1: u = 0, v = 0
            s = System([ f for f in system.eqs if f != eq ] + [ u, v ], system.X_vars, system.S_vars, system.factored_eqs)
            c += self.compute_class(s)

            # Case 2: [ u != 0, x = -v / u ] = [ x = -v / u ] - [ u = 0, x = -v / u ]
            Y_vars = system.X_vars.difference({ x, })
            Y_eqs = [ sp.Poly(subs_frac(f, x, -v, u), Y_vars) for f in system.eqs if f != eq ]

            s = System(Y_eqs, Y_vars, system.S_vars, system.factored_eqs)
            c += self.compute_class(s)

            s = System(Y_eqs + [ u ], Y_vars, system.S_vars, system.factored_eqs)
            c -= self.compute_class(s)

            return c
            
        return None
    
    def check_product_equations(self, system, eq):
        # Check for equations of the form 'u * v = 0'
        factors = system.factored_eqs[eq.expr]
        if len(factors) < 2:
            return None
        
        # Determine u and v
        u = factors[0]
        v = 1
        for f in factors[1:]:
            v *= f
        v = sp.Poly(v, system.X_vars)
        
        # Construct dictionary of factored eqs
        Y_factored_eqs = system.factored_eqs.copy()
        Y_factored_eqs[u.expr] = [ u ]
        Y_factored_eqs[v.expr] = factors[1:]
        
        c = 0

        # Case 1: u = 0
        s = System([ f for f in system.eqs if f != eq ] + [ u ], system.X_vars, system.S_vars, Y_factored_eqs)
        c += self.compute_class(s)

        # Case 2: [ u != 0, v = 0 ] = [ v = 0 ] - [ u = 0, v = 0]
        s = System([ f for f in system.eqs if f != eq ] + [ v ], system.X_vars, system.S_vars, Y_factored_eqs)
        c += self.compute_class(s)

        s = System([ f for f in system.eqs if f != eq ] + [ u, v ], system.X_vars, system.S_vars, Y_factored_eqs)
        c -= self.compute_class(s)

        return c
    
    def solve_system(self, system):
        system.eqs.sort(key = lambda eq : len(eq.terms()))
        
        # Chcek univariate equations
        for eq in system.eqs:            
            c = self.check_univariate_equations(system, eq)
            if c != None:
                return c
        
        # Check product equations
        for eq in system.eqs:
            c = self.check_product_equations(system, eq)
            if c != None:
                return c
        
        # Check linear equations
        for eq in system.eqs:
            c = self.check_linear_equations(system, eq)
            if c != None:
                return c

        # If all failed, create new symbol
        return self.new_unknown()

## Class System

In [6]:
class System:
    # Represents a variety $X \to S$
    
    def __init__(self, eqs, X_vars, S_vars = set(), factored_eqs = {}):
        self.eqs = eqs
        self.X_vars = X_vars
        self.S_vars = S_vars
        self.factored_eqs = factored_eqs

        if any(not f.is_Poly for f in self.eqs):
            self.eqs = [ sp.Poly(f, self.X_vars) for f in self.eqs ]
    
    def __repr__(self):
        return '{ ' + (', '.join([ str(f.expr) for f in self.eqs ])) + ' }'
    
    def reduce_groebner(self):
        # Special cases
        if self.eqs == [ 0 ]:
            self.eqs = []
            return

        if len(self.eqs) <= 1:
            return

        # Compute Gröbner basis
        self.eqs = list(sp.groebner(self.eqs, self.X_vars, order = 'grevlex'))
        
    def reduce_squarefree(self):
        eqs_sqf = []
        factored_eqs_sqf = {}
        flag = False
        
        for eq in self.eqs:
            # If factored before, just copy the factorization
            eq_expr = eq.expr
            if eq_expr in self.factored_eqs:
                eqs_sqf.append(eq)
                factored_eqs_sqf[eq_expr] = self.factored_eqs[eq_expr]
                continue
            
            factors = sp.factor_list(eq, extension = True, order = 'grevlex')
            factors = [ factor[0] for factor in factors[1] ]
            
            eq_sqf = 1
            for factor in factors:
                eq_sqf *= factor
            eq_sqf = sp.Poly(eq_sqf, self.X_vars)

            eqs_sqf.append(eq_sqf)
            factored_eqs_sqf[eq_sqf.expr] = factors
        
            if not flag and eq_sqf.LM() != eq.LM():
                flag = True
        
        self.eqs = eqs_sqf 
        self.factored_eqs = factored_eqs_sqf
        return flag
        
    def reduce_direct_solve(self):
        # Look for equations which are linear in some variable
        for eq in self.eqs:            
            for x in [ x for x in eq.free_symbols.difference(self.S_vars) if eq.degree(x) == 1 ]:                
                # See if we can directly solve for x
                r = eq.subs(x, 0)
                q = (eq - r) / x
                                    
                if sp.total_degree(q) != 0:
                    continue
                    
                # Solve for x and remove x from X_vars
                x_value = -r / q
                self.X_vars = self.X_vars.copy()
                self.X_vars.remove(x)
                self.eqs = [ sp.Poly(f.subs(x, x_value), self.X_vars) for f in self.eqs if f != eq ]
                return True
        else:
            return False
    
    def reduce(self):
        while True:
            
            self.reduce_groebner()

            if self.reduce_direct_solve() and len(self.eqs) > 1:
                continue
            
            if self.reduce_squarefree():
                continue
            
            return

## Tests

In [7]:
x, y, z, w = sp.symbols('x y z w')

# A test is a tuple (equations, variables, expected class)
tests = [
    ([], set(), 1),
    ([], { x }, q),
    ([ x ], { x }, 1),
    ([ x - 3 ], { x }, 1),
    ([ x, y ], { x, y }, 1),
    ([ x * y ], { x, y }, 2*q - 1),
    ([ x * y + 1 ], { x, y }, q - 1),
    ([ (x + 1) * (y + 1) ], { x, y }, 2 * q - 1),
    ([ x**2 + 1 ], { x }, 2),
    ([ x**3 + x + 1 ], { x }, 3),
    ([ x**3 + 4*x**2 + 5*x + 2 ], { x }, 2),    
    ([ (x**2 + 1) * (y**2 + 1) ], { x, y }, 4*q - 4),
    ([ x * y - z * w - 1 ], { x, y, z, w }, q**3 - q),
]

In [8]:
# Run tests
success = True
for eqs, X_vars, exp_c in tests:
    # Solve system
    s = Solver()
    system = System(eqs, X_vars)
    c = s.compute_class(system)
    
    # Check if agrees with expected value
    if sp.expand(c - exp_c) != 0:
        print('Test failed! System {}: expected {} but got {}'.format(system, exp_c, c))
        success = False
    else:
        print('Test succeeded!')

if success:
    print('\nAll tests succeeded!')
else:
    print('\nSome tests failed!')

System {  } in set()
Test succeeded!
System {  } in {x}
Test succeeded!
System { x } in {x}
Test succeeded!
System { x - 3 } in {x}
Test succeeded!
System { x, y } in {x, y}
System {  } in set()
Save [y] --> 1
Test succeeded!
System { x*y } in {x, y}
System { y } in {x, y}
System { x } in {x, y}
System { y, x } in {x, y}
System {  } in set()
Save [y] --> 1
Save [x*y] --> 2*q - 1
Test succeeded!
System { x*y + 1 } in {x, y}
System { y, 1 } in {x, y}
System {  } in {y}
System { y } in {y}
Save [x*y + 1] --> q - 1
Test succeeded!
System { x*y + x + y + 1 } in {x, y}
System { y + 1 } in {x, y}
System { x + 1 } in {x, y}
System { y + 1, x + 1 } in {x, y}
System {  } in set()
Save [y + 1] --> 1
Save [x*y + x + y + 1] --> 2*q - 1
Test succeeded!
System { x**2 + 1 } in {x}
System {  } in set()
System {  } in set()
Save [x**2 + 1] --> 2
Test succeeded!
System { x**3 + x + 1 } in {x}
System {  } in set()
System {  } in set()
System {  } in set()
Save [x**3 + x + 1] --> 3
Test succeeded!
System {

## Scratch

In [9]:
a, b, c, d, e, f, g, h = sp.symbols('a b c d e f g h')
system = System([ a*d + b*c - 1, e*f + g*h - 1 ], { a, b, c, d, e, f, g, h })


s = Solver()

t_start = time.time()
_ = s.compute_class(system)
t_end = time.time()

print('Time elapsed: {}'.format(t_end - t_start))

sp.factor(_)

System { a*d + b*c - 1, e*f + g*h - 1 } in {g, h, b, d, a, f, c, e}
Group [e*f + g*h - 1] in {f, g, h, e}
System { e*f + g*h - 1 } in {f, g, h, e}
System { e, g*h - 1 } in {f, g, h, e}
Group [g*h - 1] in {g, h}
System { g*h - 1 } in {g, h}
System { h, -1 } in {g, h}
System {  } in {h}
System { h } in {h}
Save [g*h - 1] --> q - 1
System {  } in {g, h, e}
System { e } in {g, h, e}
Save [e*f + g*h - 1] --> q**3 - q
Group [a*d + b*c - 1] in {a, c, d, b}
System { a*d + b*c - 1 } in {a, c, d, b}
Found match!
Time elapsed: 0.17821478843688965


q**2*(q - 1)**2*(q + 1)**2

In [13]:
a, b, c, d, e, f, g, h = sp.symbols('a b c d e f g h')
system = System([ (a + b) * (c + d) * (e + f) * (g + h) ], { a, b, c, d, e, f, g, h })


N = 1
t_start = time.time()
for i in range(N):
    s = Solver()
    _ = s.compute_class(system)
t_end = time.time()

print('\nTime elapsed per solve: {}'.format((t_end - t_start) / N))
sp.factor(_)

System { a*c*e*g + a*c*e*h + a*c*f*g + a*c*f*h + a*d*e*g + a*d*e*h + a*d*f*g + a*d*f*h + b*c*e*g + b*c*e*h + b*c*f*g + b*c*f*h + b*d*e*g + b*d*e*h + b*d*f*g + b*d*f*h } in {g, h, b, d, a, f, c, e}
System { a + b } in {g, h, b, d, a, f, c, e}
System { c*e*g + c*e*h + c*f*g + c*f*h + d*e*g + d*e*h + d*f*g + d*f*h } in {g, h, b, d, a, f, c, e}
Group [c*e*g + c*e*h + c*f*g + c*f*h + d*e*g + d*e*h + d*f*g + d*f*h] in {c, d, f, g, h, e}
System { c*e*g + c*e*h + c*f*g + c*f*h + d*e*g + d*e*h + d*f*g + d*f*h } in {c, d, f, g, h, e}
System { c + d } in {c, d, f, g, h, e}
System { e*g + e*h + f*g + f*h } in {c, d, f, g, h, e}
Group [e*g + e*h + f*g + f*h] in {f, g, h, e}
System { e*g + e*h + f*g + f*h } in {f, g, h, e}
System { e + f } in {f, g, h, e}
System { g + h } in {f, g, h, e}
System { e + f, g + h } in {f, g, h, e}
Group [g + h] in {g, h}
System { g + h } in {g, h}
Save [e*g + e*h + f*g + f*h] --> 2*q**3 - q**2
System { c + d, e*g + e*h + f*g + f*h } in {c, d, f, g, h, e}
Group [e*g + e*

q**4*(2*q - 1)*(2*q**2 - 2*q + 1)

In [11]:
# https://docs.python.org/2/library/profile.html

In [12]:
# Solve 'simple' equations first!
# Equations are factored when computing square-free parts already, so store those for as long as possible

In [14]:
a, b, c, d, e, f, g, h = sp.symbols('a b c d e f g h')
system = System([ a * b + c * d ], { a, b, c, d })


N = 100
t_start = time.time()
for i in range(N):
    s = Solver()
    _ = s.compute_class(system)
t_end = time.time()

print('\nTime elapsed per solve: {}'.format((t_end - t_start) / N))
sp.expand(_)

System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 

Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b

Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in

System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b, c, d, a}
System { b, c*d } in {b, c, d, a}
Group [c*d] in {c, d}
System { c*d } in {c, d}
System { d } in {c, d}
System { c } in {c, d}
System { d, c } in {c, d}
System {  } in set()
Save [d] --> 1
Save [c*d] --> 2*q - 1
System {  } in {c, d, b}
System { b } in {c, d, b}
Save [a*b + c*d] --> q**3 + q**2 - q
System { a*b + c*d } in {b

q**3 + q**2 - q