In [1]:
from __future__ import division
import sys
# NOTE(review): Python does not expand "~" in sys.path entries, so this adds the
# literal, cwd-relative path "~/.local/lib/python3.6/site-packages". Unless a
# directory literally named "~" exists in the working directory, it has no effect.
# If the user site-packages really must be searched first, wrap the path in
# os.path.expanduser(...). Note that the hard-coded "python3.6" only matches that
# interpreter version, and putting it ahead of the stdlib can shadow modules.
sys.path.insert(0, "~/.local/lib/python3.6/site-packages")

import sympy
# Star import: the notebook uses sympy names (symbols, solve, Eq, Pow, ...) unqualified.
from sympy import *

def express(a, b, name):
    """Name the quantity ``a`` with a fresh symbol and solve for ``b`` in terms of it.

    Creates ``symbols(name)``, solves ``a - <that symbol> = 0`` for ``b`` and
    returns the pair ``(symbol, solution)``. Asserts that the solution is unique.
    """
    new_sym = symbols(name)
    candidates = solve(a - new_sym, b)
    assert len(candidates) == 1
    return new_sym, candidates[0]

In [2]:
from sympy.physics.quantum import *

In [3]:
# https://stackoverflow.com/q/59523322/1137334

from sympy.core.operations import AssocOp

def apply_ccr(expr, ccr, reverse=False):
    """Rewrite ``expr`` with a commutation relation until no rewrite applies.

    ``ccr`` must contain a ``Commutator(A, B)``. The relation is solved for that
    commutator to obtain ``value = [A, B]``.

    Inside every product of ``expr``, each adjacent out-of-order pair is replaced
    using ``B*A = A*B - value``. This covers ``B*A`` and also ``B**k*A``,
    ``B*A**m`` and ``B**k*A**m`` for numeric exponents ``>= 1``. The product is
    then re-expanded and rewritten again.

    ``A`` and ``B`` are taken from the commutator exactly as it is stored in
    ``ccr``. NOTE(review): sympy may reorder a commutator's arguments when it is
    constructed, so confirm which ordering results for your operators.

    Parameters
    ----------
    expr : Basic
        Expression to rewrite. It should normally be expanded beforehand.
        Products are only searched for adjacent factors, and the factors of a
        product are not themselves descended into.
    ccr : Eq or Basic
        The commutation relation. A bare expression ``e`` is read as ``Eq(e, 0)``.
    reverse : bool, optional
        Swap the roles of ``A`` and ``B``, i.e. rewrite ``A*B`` as
        ``B*A + value`` instead.

    Returns
    -------
    Basic
        The rewritten expression.

    Raises
    ------
    TypeError
        If ``expr`` or ``ccr`` is not a sympy expression.
    ValueError
        If ``ccr`` contains no commutator, or if solving it for the commutator
        gives no solution or more than one.
    """
    if not isinstance(expr, Basic):
        raise TypeError("The expression to simplify is not a sympy expression.")

    if not isinstance(ccr, Eq):
        if isinstance(ccr, Basic):
            ccr = Eq(ccr, 0)
        else:
            raise TypeError("The canonical commutation relation is not a sympy expression.")

    # The first commutator met in a preorder walk is the one solved for.
    comm = next((node for node in preorder_traversal(ccr)
                 if isinstance(node, Commutator)), None)
    if comm is None:
        raise ValueError("The canonical commutation relation does not include a commutator.")

    solutions = solve(ccr, comm)
    # Distinguish "no solution" from "several solutions"; the old message
    # claimed "more solutions" in both cases.
    if not solutions:
        raise ValueError("The canonical commutation relation has no solution for the commutator.")
    if len(solutions) > 1:
        raise ValueError("There are more solutions to the canonical commutation relation.")
    value = solutions[0]

    A, B = comm.args[0], comm.args[1]
    if reverse:
        # [B, A] = -[A, B]
        A, B = B, A
        value = -value

    # [A, B] = A*B - B*A = value, hence B*A = A*B - value. This is loop-invariant.
    ordered = A*B - value

    def is_expandable_pow_of(base, term):
        """True if ``term`` is ``base**k`` for a numeric exponent ``k >= 1``."""
        return (isinstance(term, Pow)
                and base == term.args[0]
                and isinstance(term.args[1], Number)
                and term.args[1] >= 1)

    def peel_trailing_b(x):
        """Factors left after taking one trailing ``B`` off ``x``; None if ``x`` is neither ``B`` nor ``B**k``."""
        if B == x:
            return []
        if is_expandable_pow_of(B, x):
            return [Pow(B, x.args[1] - 1)]
        return None

    def peel_leading_a(y):
        """Factors left after taking one leading ``A`` off ``y``; None if ``y`` is neither ``A`` nor ``A**m``."""
        if A == y:
            return []
        if is_expandable_pow_of(A, y):
            return [Pow(A, y.args[1] - 1)]
        return None

    def walk_tree(node):
        """Return ``node`` with every out-of-order ``B*A`` pair rewritten."""
        if isinstance(node, Number):
            return node

        # Leaves and other non-container nodes are copied unchanged.
        if not isinstance(node, (AssocOp, Function)):
            return node.copy()

        # Other associative operations (e.g. sums) and functions: rewrite each argument.
        if not isinstance(node, Mul):
            return node.func(*(walk_tree(arg) for arg in node.args))

        # Product: rewrite the first out-of-order adjacent pair, then start over
        # on the re-expanded result.
        factors = list(node.args)
        for i in range(len(factors) - 1):
            before = peel_trailing_b(factors[i])
            if before is None:
                continue
            after = peel_leading_a(factors[i + 1])
            if after is None:
                continue
            new_factors = factors[:i] + before + [ordered] + after + factors[i + 2:]
            return walk_tree(Mul(*new_factors).expand())

        return node.copy()

    return walk_tree(expr)
   

# Expose apply_ccr as a method on every sympy object: expr.apply_ccr(ccr, reverse=False).
Basic.apply_ccr = lambda self, ccr, reverse=False: apply_ccr(self, ccr, reverse=reverse)


In [4]:
# https://stackoverflow.com/q/59524925/1137334

from sympy.core.operations import AssocOp

def apply_operator(expr, eqns):
    """Apply operator-on-ket rules throughout ``expr`` until no rule applies.

    Each rule is an equation ``Op*Ket(s) = rhs``, where ``Op`` is an ``Operator``
    (or the ``Dagger`` of one) and ``s`` is a plain ``Symbol``.

    Inside every product of ``expr``, an adjacent pair ``Op*ket`` is replaced by
    the rule's right-hand side. A pair ``Op**k*ket`` with numeric ``k >= 1`` is
    replaced by ``Op**(k-1)*rhs``. The product is then re-expanded and rewritten
    again. Rules are tried in the given order, and pairs are scanned left to right.

    A rule is *generic* when ``s`` occurs inside some ket of its right-hand side.
    A generic rule matches any ket whose tree contains ``s``, and ``s`` is
    substituted by that ket's first label. A non-generic rule only matches a ket
    whose first label is exactly ``s``.

    NOTE: rewriting recurses until nothing matches. A rule whose result
    re-creates a matching pair will therefore recurse without bound.

    Parameters
    ----------
    expr : Basic
        Expression to rewrite; normally expanded beforehand. The factors of a
        product are not themselves descended into.
    eqns : Eq, or list/tuple of Eq
        The rules. One level of nesting is flattened, so ``(eq1, eq2)``,
        ``[eq1, eq2]`` and ``([eq1, eq2],)`` are all accepted.

    Returns
    -------
    Basic
        The rewritten expression.

    Raises
    ------
    TypeError
        If ``expr`` is not a sympy expression or a rule is not an ``Eq``.
    ValueError
        If a rule's left-hand side is not ``operator * Ket(symbol)``.
    """
    if not isinstance(expr, Basic):
        raise TypeError("The expression to simplify is not a sympy expression.")

    if not isinstance(eqns, (list, tuple)):
        eqns = (eqns,)

    # Flatten one level. The method form ``expr.apply_operator([up, down])``
    # arrives here as ``([up, down],)`` and previously raised TypeError.
    flat_eqns = []
    for item in eqns:
        if isinstance(item, (list, tuple)):
            flat_eqns.extend(item)
        else:
            flat_eqns.append(item)

    class Rule(object):
        """A parsed rule ``operator * Ket(ket_symbol) = result``."""

        def __init__(self, operator, ket_symbol, result, generic):
            self.operator = operator
            self.ket_symbol = ket_symbol
            self.result = result
            # True when ket_symbol occurs in a ket of ``result``; the rule then
            # applies to any ket containing ket_symbol, via substitution.
            self.generic = generic

    def contains(node, target):
        """True if ``target`` occurs anywhere in the expression tree of ``node``."""
        return any(sub == target for sub in preorder_traversal(node))

    def is_operator(op):
        """True for an ``Operator`` or the ``Dagger`` of an ``Operator``."""
        return isinstance(op, Operator) or (
            isinstance(op, Dagger) and isinstance(op.args[0], Operator))

    rules = []
    for eqn in flat_eqns:
        if not isinstance(eqn, Eq):
            raise TypeError("One of the equations is not a valid sympy equation.")

        lhs, rhs = eqn.lhs, eqn.rhs
        if not (isinstance(lhs, Mul)
                and len(lhs.args) == 2
                and is_operator(lhs.args[0])
                and isinstance(lhs.args[1], Ket)):
            raise ValueError("The left-hand side has to be an operator applied to a ket.")

        operator, ket = lhs.args
        ket_symbol = ket.args[0]
        if not isinstance(ket_symbol, Symbol):
            raise ValueError("The left-hand ket has to contain a simple symbol.")

        generic = any(isinstance(node, Ket) and contains(node, ket_symbol)
                      for node in preorder_traversal(rhs))
        rules.append(Rule(operator, ket_symbol, rhs, generic))

    def is_expandable_pow_of(base, term):
        """True if ``term`` is ``base**k`` for a numeric exponent ``k >= 1``."""
        return (isinstance(term, Pow)
                and base == term.args[0]
                and isinstance(term.args[1], Number)
                and term.args[1] >= 1)

    def is_ket_of_rule(ket, rule):
        """True if ``rule`` applies to ``ket`` (see the generic/non-generic note above)."""
        if not isinstance(ket, Ket):
            return False
        if rule.generic:
            return contains(ket, rule.ket_symbol)
        return ket.args[0] == rule.ket_symbol

    def instantiate(rule, ket):
        """The rule's right-hand side for ``ket``, substituting its label if generic."""
        if rule.generic:
            return rule.result.subs(rule.ket_symbol, ket.args[0])
        return rule.result

    def walk_tree(node):
        """Return ``node`` with every applicable operator/ket pair rewritten."""
        if isinstance(node, Number):
            return node

        # Leaves and other non-container nodes are copied unchanged.
        if not isinstance(node, (AssocOp, Function)):
            return node.copy()

        # Other associative operations (e.g. sums) and functions: rewrite each argument.
        if not isinstance(node, Mul):
            return node.func(*(walk_tree(arg) for arg in node.args))

        # Product: apply the first matching rule to the first matching pair,
        # then start over on the re-expanded result.
        factors = list(node.args)
        for rule in rules:
            op = rule.operator
            for i in range(len(factors) - 1):
                x, y = factors[i], factors[i + 1]
                if not is_ket_of_rule(y, rule):
                    continue
                if op == x:
                    replacement = [instantiate(rule, y)]
                elif is_expandable_pow_of(op, x):
                    replacement = [Pow(op, x.args[1] - 1), instantiate(rule, y)]
                else:
                    continue
                new_factors = factors[:i] + replacement + factors[i + 2:]
                return walk_tree(Mul(*new_factors).expand())

        return node.copy()

    return walk_tree(expr)
   

# Expose apply_operator as a method: expr.apply_operator(rule1, rule2, ...);
# the positional rules are collected into a tuple and passed through.
Basic.apply_operator = (
    lambda self, *eqns: apply_operator(self, eqns)
)


In [5]:
# Lowering operator a, its adjoint a† (raising), and the number operator N = a†a.
a = Operator("a")
ad = Dagger(a)
N = ad * a

# Commutation relation [a, a†] = 1.
ccr = Eq(Commutator(a, ad), 1)

# Non-negative integer label n and the actions of a and a† on |n>.
n = Symbol('n', integer=True, negative=False)
down = Eq(a * Ket(n), sqrt(n) * Ket(n - 1))
up = Eq(ad * Ket(n), sqrt(n + 1) * Ket(n + 1))

In [6]:
# Expand (a† + a)^4 and reorder it with the commutation relation;
# in the displayed result every a† stands to the left of every a.
expr1 = ((ad + a)**4).expand().apply_ccr(ccr)
expr1

3 + 12*Dagger(a)*a + 4*Dagger(a)*a**3 + 6*Dagger(a)**2 + 6*Dagger(a)**2*a**2 + 4*Dagger(a)**3*a + Dagger(a)**4 + 6*a**2 + a**4

In [7]:
# A closed form of (a† + a)^4 in terms of a, a† and N, checked against expr1.
expr2 = (
    ad**4 + a**4 + 2*ad**2 + 2*a**2 + 4*N**2 + 4*N
    + 2*ad**2*N + 2*N*ad**2 + 2*a**2*N + 2*N*a**2
    + ad**2*a**2 + a**2*ad**2 + 1
)
expr2 = expr2.expand().apply_ccr(ccr)
assert (expr1 - expr2).simplify() == 0

In [8]:
# A second, more compact closed form of (a† + a)^4, also checked against expr1.
expr3 = (
    ad**4 + a**4 + 6*ad**2 - 2*a**2 + 6*N**2 + 6*N
    + 4*ad**2*N + 4*a**2*N + 3
)
expr3 = expr3.expand().apply_ccr(ccr)
assert (expr1 - expr3).simplify() == 0

In [9]:
# Act with the reordered (a† + a)^4 on |n> using the ladder rules.
expr4 = (
    (expr1 * Ket(n))
    .expand()
    .apply_operator(up, down)
)
expr4

4*n**(3/2)*sqrt(n - 1)*|n - 2> + sqrt(n)*sqrt(n - 3)*sqrt(n - 2)*sqrt(n - 1)*|n - 4> - 2*sqrt(n)*sqrt(n - 1)*|n - 2> + 6*n**2*|n> + 4*n*sqrt(n + 1)*sqrt(n + 2)*|n + 2> + 6*n*|n> + sqrt(n + 1)*sqrt(n + 2)*sqrt(n + 3)*sqrt(n + 4)*|n + 4> + 6*sqrt(n + 1)*sqrt(n + 2)*|n + 2> + 3*|n>

In [10]:
# Expected result of (a† + a)^4 |n>, written with factorial ratios.
expr5 = (
    sqrt(factorial(n+4)/factorial(n))*Ket(n+4)
    + (4*n+6)*sqrt(factorial(n+2)/factorial(n))*Ket(n+2)
    + (6*n**2 + 6*n + 3)*Ket(n)
    + 2*(2*n - 1)*sqrt(factorial(n)/factorial(n-2))*Ket(n-2)
    + sqrt(factorial(n)/factorial(n-4))*Ket(n-4)
)
expr5

(4*n - 2)*sqrt(1/factorial(n - 2))*sqrt(factorial(n))*|n - 2> + (4*n + 6)*sqrt(factorial(n + 2))*|n + 2>/sqrt(factorial(n)) + (6*n**2 + 6*n + 3)*|n> + sqrt(1/factorial(n - 4))*sqrt(factorial(n))*|n - 4> + sqrt(factorial(n + 4))*|n + 4>/sqrt(factorial(n))

In [11]:
# Compare after shifting n -> n + 4 (presumably so every factorial argument,
# down to n - 4, is known to be non-negative when simplifying).
assert (expr5.simplify().expand() - expr4).subs(n, n + 4).simplify() == 0

In [12]:
# Expand (a - a†)^2 and reorder it with the commutation relation.
expr6 = ((a - ad)**2).expand().apply_ccr(ccr)
expr6

-1 - 2*Dagger(a)*a + Dagger(a)**2 + a**2

In [13]:
# Act with the reordered (a - a†)^2 on |n> using the ladder rules.
expr7 = (expr6 * Ket(n)).expand().apply_operator(up, down)
expr7

sqrt(n)*sqrt(n - 1)*|n - 2> - 2*n*|n> + sqrt(n + 1)*sqrt(n + 2)*|n + 2> - |n>