In [None]:
import sympy as sp

# Non-commutative operator symbols, one (operator, dagger) pair per mode a, b, c.
a, a_dagger = sp.symbols(r"a a^\dagger", commutative = False)
b, b_dagger = sp.symbols(r"b b^\dagger", commutative = False)
c, c_dagger = sp.symbols(r"c c^\dagger", commutative = False)
# Real commuting scalars. `z` is used by the helpers below as a placeholder
# that is later substituted with 0.
r, z, q = sp.symbols(r"r z q", commutative = True, real = True)
# Commuting scalars with no reality assumption; al/be/ga are compared against
# Ket values in apply_a/apply_b/apply_c.
al, be, ga, mu, nu  = sp.symbols(r"alpha beta gamma mu nu", commutative = True)
th, ph, xi = sp.symbols(r"theta phi xi", commutative = True, real = True)
et, T, t = sp.symbols(r"eta T t", real = True, positive = True)
# NOTE: the Python name `gap` is bound to the symbol printed as "g".
gap, n, m, l = sp.symbols(r"g n m l")

# All operator symbols.
sym = [a, a_dagger, b, b_dagger, c, c_dagger]
# Operator lists in the order c, b, a; commutative_ordering_iter relies on this
# order when deciding which pairs of operators to swap.
anihilation = [c, b, a]
creation = [c_dagger,b_dagger,a_dagger]

class Ket(sp.Symbol):
    """Non-commutative symbol standing for a ket state.

    The symbol name is the LaTeX form of the ket built from ``str(name)``.

    Attributes set on the instance:
        value: the ``name`` argument exactly as passed (not stringified).
        amplitude: the ``amp`` argument (defaults to 1).

    Extra keyword arguments are accepted but ignored.

    NOTE(review): SymPy may cache Symbol instances by class/name/assumptions;
    if so, two kets with the same printed name share ``value`` and
    ``amplitude`` (the latest construction wins) -- confirm before relying on
    per-instance amplitudes.
    """

    def __new__(cls, name, amp = 1, **kwargs):
        # Construct through ``cls`` rather than a hard-coded ``Ket`` so that a
        # subclass gets an instance of its own type.
        obj = sp.Symbol.__new__(cls, r"\left|%s\right>" % str(name), commutative = False)
        obj.value = name
        obj.amplitude = amp
        return obj
    
class Bra(sp.Symbol):
    """Non-commutative symbol standing for a bra state.

    The symbol name is the LaTeX form of the bra built from ``str(name)``.

    Attributes set on the instance:
        value: the ``name`` argument exactly as passed (not stringified).
        amplitude: the ``amp`` argument (defaults to 1).

    Extra keyword arguments are accepted but ignored.

    NOTE(review): SymPy may cache Symbol instances by class/name/assumptions;
    if so, two bras with the same printed name share ``value`` and
    ``amplitude`` (the latest construction wins) -- confirm before relying on
    per-instance amplitudes.
    """

    def __new__(cls, name, amp=1,  **kwargs):
        # Construct through ``cls`` rather than a hard-coded ``Bra`` so that a
        # subclass gets an instance of its own type.
        obj = sp.Symbol.__new__(cls, r"\left<%s\right|" % str(name), commutative = False)
        obj.value = name
        obj.amplitude = amp
        return obj

def apply_a(expr):
    """Perform one rewriting pass of the mode-a operators on each term's ket.

    ``expr`` must be exactly an ``sp.Add``. For every term, the last factor
    is inspected:

    * not a ``Ket``: the term is kept unchanged;
    * ``Ket`` whose value equals ``al``: ``a_dagger`` is substituted with
      ``conjugate(al)``; only if that changes nothing, ``a`` is substituted
      with ``al``;
    * ``Ket`` with any other non-zero value ``n``: ``a*ket`` is replaced by
      ``sqrt(n)*Ket(n - 1)``; if that changed nothing, or the resulting
      trailing ket has value 0, ``a_dagger*ket'`` is replaced by
      ``sqrt(n' + 1)*Ket(n' + 1)`` for the current trailing ket ``ket'``;
    * ``Ket`` with value 0: both ``a`` and ``a_dagger`` are replaced by the
      placeholder ``z``.

    Returns the sum of the rewritten terms.
    """
    assert type(expr) == sp.Add
    rewritten_terms = []
    for term in expr.args:
        state = term.args[-1]
        if not isinstance(state, Ket):
            rewritten_terms.append(term)
            continue

        if state.value == al:
            # The dagger substitution takes priority; `a` is only handled
            # once no `a_dagger` is left to replace.
            result = term.subs(a_dagger, sp.conjugate(al))
            if result == term:
                result = term.subs(a, al)
        elif state.value != 0:
            result = term.subs(a * state, sp.sqrt(state.value) * Ket(state.value - 1))
            if result == term or result.args[-1].value == 0:
                trailing = result.args[-1]
                result = result.subs(
                    a_dagger * trailing,
                    sp.sqrt(trailing.value + 1) * Ket(trailing.value + 1),
                )
        else:
            result = term.subs(a, z).subs(a_dagger, z)
        rewritten_terms.append(result)
    return sp.Add(*rewritten_terms)

def apply_all_a(expr):
    """Apply ``apply_a`` repeatedly until the expression stops changing."""
    while True:
        updated = apply_a(expr)
        if updated == expr:
            return expr
        expr = updated

def apply_b(expr):
    """Perform one rewriting pass of the mode-b operators on each term's ket.

    ``expr`` must be exactly an ``sp.Add``. For every term, the last factor
    is inspected:

    * not a ``Ket``: the term is kept unchanged;
    * ``Ket`` whose value equals ``be``: ``b_dagger`` is substituted with
      ``conjugate(be)``; only if that changes nothing, ``b`` is substituted
      with ``be``;
    * ``Ket`` with any other non-zero value ``n``: ``b*ket`` is replaced by
      ``sqrt(n)*Ket(n - 1)``; if that changed nothing, or the resulting
      trailing ket has value 0, ``b_dagger*ket'`` is replaced by
      ``sqrt(n' + 1)*Ket(n' + 1)`` for the current trailing ket ``ket'``;
    * ``Ket`` with value 0: both ``b`` and ``b_dagger`` are replaced by the
      placeholder ``z``.

    Returns the sum of the rewritten terms.
    """
    assert type(expr) == sp.Add
    rewritten_terms = []
    for term in expr.args:
        state = term.args[-1]
        if not isinstance(state, Ket):
            rewritten_terms.append(term)
            continue

        if state.value == be:
            # The dagger substitution takes priority; `b` is only handled
            # once no `b_dagger` is left to replace.
            result = term.subs(b_dagger, sp.conjugate(be))
            if result == term:
                result = term.subs(b, be)
        elif state.value != 0:
            result = term.subs(b * state, sp.sqrt(state.value) * Ket(state.value - 1))
            if result == term or result.args[-1].value == 0:
                trailing = result.args[-1]
                result = result.subs(
                    b_dagger * trailing,
                    sp.sqrt(trailing.value + 1) * Ket(trailing.value + 1),
                )
        else:
            result = term.subs(b, z).subs(b_dagger, z)
        rewritten_terms.append(result)
    return sp.Add(*rewritten_terms)

def apply_all_b(expr):
    """Apply ``apply_b`` repeatedly until the expression stops changing."""
    while True:
        updated = apply_b(expr)
        if updated == expr:
            return expr
        expr = updated

def apply_c(expr):
    """Perform one rewriting pass of the mode-c operators on each term's ket.

    ``expr`` must be exactly an ``sp.Add``. For every term, the last factor
    is inspected:

    * not a ``Ket``: the term is kept unchanged;
    * ``Ket`` whose value equals ``ga``: ``c_dagger`` is substituted with
      ``conjugate(ga)``; only if that changes nothing, ``c`` is substituted
      with ``ga``;
    * ``Ket`` with any other non-zero value ``n``: ``c*ket`` is replaced by
      ``sqrt(n)*Ket(n - 1)``; if that changed nothing, or the resulting
      trailing ket has value 0, ``c_dagger*ket'`` is replaced by
      ``sqrt(n' + 1)*Ket(n' + 1)`` for the current trailing ket ``ket'``;
    * ``Ket`` with value 0: both ``c`` and ``c_dagger`` are replaced by the
      placeholder ``z``.

    Returns the sum of the rewritten terms.
    """
    assert type(expr) == sp.Add
    rewritten_terms = []
    for term in expr.args:
        state = term.args[-1]
        if not isinstance(state, Ket):
            rewritten_terms.append(term)
            continue

        if state.value == ga:
            # The dagger substitution takes priority; `c` is only handled
            # once no `c_dagger` is left to replace.
            result = term.subs(c_dagger, sp.conjugate(ga))
            if result == term:
                result = term.subs(c, ga)
        elif state.value != 0:
            result = term.subs(c * state, sp.sqrt(state.value) * Ket(state.value - 1))
            if result == term or result.args[-1].value == 0:
                trailing = result.args[-1]
                result = result.subs(
                    c_dagger * trailing,
                    sp.sqrt(trailing.value + 1) * Ket(trailing.value + 1),
                )
        else:
            result = term.subs(c, z).subs(c_dagger, z)
        rewritten_terms.append(result)
    return sp.Add(*rewritten_terms)

def apply_all_c(expr):
    """Apply ``apply_c`` repeatedly until the expression stops changing."""
    while True:
        updated = apply_c(expr)
        if updated == expr:
            return expr
        expr = updated

def collapse_scalar_products(expr):
    """Replace the trailing ``bra*ket`` pair of every term with a scalar.

    The last two factors of each term in ``expr.args`` are taken as the bra
    and the ket. If their ``value`` attributes are equal, the pair is
    replaced by 1; otherwise it is replaced by the placeholder ``z``.
    Returns the sum of the rewritten terms.
    """
    def _collapse(term):
        bra, ket = term.args[-2], term.args[-1]
        overlap = 1 if bra.value == ket.value else z
        return term.subs(bra * ket, overlap)

    return sp.Add(*[_collapse(term) for term in expr.args])

def normal_ordering_a(expr):
    """Rewrite ``a*a_dagger`` as ``1 + a_dagger*a`` until nothing changes.

    The expression is expanded after every substitution pass.
    """
    while True:
        rewritten = expr.subs(a * a_dagger, 1 + a_dagger * a).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def normal_ordering_aa(expr):
    """Rewrite ``a_dagger*a_dagger*a`` as ``a_dagger*(-1 + a*a_dagger)``.

    The substitution (followed by an expand) is repeated until the
    expression no longer changes.
    """
    while True:
        rewritten = expr.subs(
            a_dagger * a_dagger * a, a_dagger * (-1 + a * a_dagger)
        ).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def normal_ordering_b(expr):
    """Rewrite ``b*b_dagger`` as ``1 + b_dagger*b`` until nothing changes.

    The expression is expanded after every substitution pass.
    """
    while True:
        rewritten = expr.subs(b * b_dagger, 1 + b_dagger * b).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def normal_ordering_bb(expr):
    """Rewrite ``b_dagger*b_dagger*b`` as ``b_dagger*(-1 + b*b_dagger)``.

    The substitution (followed by an expand) is repeated until the
    expression no longer changes.
    """
    while True:
        rewritten = expr.subs(
            b_dagger * b_dagger * b, b_dagger * (-1 + b * b_dagger)
        ).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def normal_ordering_c(expr):
    """Rewrite ``c*c_dagger`` as ``1 + c_dagger*c`` until nothing changes.

    The expression is expanded after every substitution pass.
    """
    while True:
        rewritten = expr.subs(c * c_dagger, 1 + c_dagger * c).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def normal_ordering_cc(expr):
    """Rewrite ``c_dagger*c_dagger*c`` as ``c_dagger*(-1 + c*c_dagger)``.

    The substitution (followed by an expand) is repeated until the
    expression no longer changes.
    """
    while True:
        rewritten = expr.subs(
            c_dagger * c_dagger * c, c_dagger * (-1 + c * c_dagger)
        ).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def commutative_ordering(expr):
    """Run ``commutative_ordering_iter`` until the expression is stable."""
    while True:
        reordered = commutative_ordering_iter(expr)
        if reordered == expr:
            return expr
        expr = reordered

def commutative_ordering_iter(expr):
    """Perform one sweep that swaps adjacent operators of different modes.

    Four passes are run in this order, each pairing the operator at index
    ``i`` of the first list with every operator after index ``i`` in the
    second list (both lists are ordered c, b, a):

    1. annihilation with annihilation,
    2. creation with annihilation,
    3. annihilation with creation,
    4. creation with creation.

    For each pair ``(left, right)`` the product ``left*right`` is replaced by
    ``right*left`` and the result expanded; if that changed the expression,
    ``commutative_ordering_sub`` repeats the same swap until it is exhausted.

    Fix relative to the previous version: passes 3 and 4 handed the integer
    loop index to ``commutative_ordering_sub`` instead of the operator, so the
    repeated swap never did anything there. All passes now pass the operator.

    Returns the reordered expression.
    """
    passes = (
        (anihilation, anihilation),
        (creation, anihilation),
        (anihilation, creation),
        (creation, creation),
    )
    for left_ops, right_ops in passes:
        for index, left in enumerate(left_ops):
            for right in right_ops[index + 1:]:
                swapped = expr.subs(left * right, right * left).expand()
                if swapped != expr:
                    expr = commutative_ordering_sub(swapped, left, right)
    return expr

def commutative_ordering_sub(expr, i, j):
    """Replace ``i*j`` with ``j*i`` in ``expr`` until nothing changes."""
    while True:
        swapped = expr.subs(i * j, j * i)
        if swapped == expr:
            return expr
        expr = swapped

def complex_ordering(expr):
    """Rewrite ``al*conjugate(al)`` as ``Abs(al)**2`` until nothing changes.

    The expression is expanded after every substitution pass.
    """
    while True:
        rewritten = expr.subs(al * sp.conjugate(al), sp.Abs(al) ** 2).expand()
        if rewritten == expr:
            return expr
        expr = rewritten

def bra_ordering(expr, operator, bra):
    """Move ``bra`` to the right of ``operator`` wherever they are adjacent.

    ``bra*operator`` is replaced by ``operator*bra`` and the result expanded,
    repeatedly, until the expression no longer changes.
    """
    while True:
        moved = expr.subs(bra * operator, operator * bra).expand()
        if moved == expr:
            return expr
        expr = moved

def order(expr):
    """Run eight rounds of mode reordering followed by the ordering rewrites.

    Each round calls ``commutative_ordering_iter`` and then applies the
    ordering helpers in the sequence cc, c, bb, b, aa, a.
    """
    rewrites = (
        normal_ordering_cc,
        normal_ordering_c,
        normal_ordering_bb,
        normal_ordering_b,
        normal_ordering_aa,
        normal_ordering_a,
    )
    for _ in range(8):
        expr = commutative_ordering_iter(expr)
        for rewrite in rewrites:
            expr = rewrite(expr)
    return expr

def simphiperbolic(expr):
    """Apply hyperbolic identities in ``r`` until none of them fires.

    The rules are tried in order and the first one that changes the
    expression is taken, after which the process restarts:

    1. ``sinh(r)*cosh(r)``           -> ``sinh(2*r)/2``
    2. ``sinh(r)**2 + cosh(r)**2``   -> ``cosh(2*r)``
    3. ``-sinh(r)**2 + cosh(r)**2``  -> ``1``
    """
    rules = (
        (sp.sinh(r)*sp.cosh(r), sp.Rational(1, 2)*sp.sinh(2*r)),
        (sp.sinh(r)**2 + sp.cosh(r)**2, sp.cosh(2*r)),
        (-sp.sinh(r)**2 + sp.cosh(r)**2, 1),
    )
    while True:
        for pattern, replacement in rules:
            candidate = expr.subs(pattern, replacement)
            if candidate != expr:
                break
        else:
            # No rule changed anything: the expression is fully simplified.
            return expr
        expr = candidate

def simptrigonometric(expr):
    """Rewrite ``sin(ph/2)*cos(ph/2)`` as ``sin(ph)/2`` until nothing changes."""
    while True:
        rewritten = expr.subs(
            sp.sin(ph/2)*sp.cos(ph/2), sp.Rational(1, 2)*sp.sin(ph)
        )
        if rewritten == expr:
            return expr
        expr = rewritten

def fun(x, bra_states, ket_states):
    """Evaluate the matrix element of operator expression ``x`` between states.

    ``bra_states`` and ``ket_states`` are indexable with three entries, used
    as the values of the bras/kets for modes a, b and c (indices 0, 1, 2).
    The expression is first passed through ``order``; then the modes are
    processed one after another in the order c, b, a: the expression is
    sandwiched between that mode's bra and ket, the mode's operators are
    applied to the ket, and the resulting bra*ket pairs are collapsed to
    scalars. Finally the placeholder ``z`` is substituted with 0.

    NOTE(review): ``apply_all_*`` assert that their input is an ``sp.Add``,
    so ``x`` presumably has to expand to a sum of at least two terms --
    confirm against callers.
    """
    ket_a, ket_b, ket_c = Ket(ket_states[0]), Ket(ket_states[1]), Ket(ket_states[2])
    bra_a, bra_b, bra_c = Bra(bra_states[0]), Bra(bra_states[1]), Bra(bra_states[2])
    x = order(x) 
    # --- mode c: sandwich between <c| and |c>, then apply c operators ---
    x_ket = (x * ket_c).expand()
    # NOTE(review): sp.Add with a single argument appears to be a no-op here.
    x_ket = sp.Add(x_ket)
    x_ket = (bra_c * x_ket).expand()
    x_ket = apply_all_c(x_ket)
    # Move the c-bra to the right past the remaining a/b operators so that it
    # ends up adjacent to its ket; the pass is repeated a fixed four times.
    for _ in range(4):
        x_ket = bra_ordering(x_ket, a, bra_c)
        x_ket = bra_ordering(x_ket, a_dagger, bra_c)
        x_ket = bra_ordering(x_ket, b, bra_c)
        x_ket = bra_ordering(x_ket, b_dagger, bra_c)
    x_ket = collapse_scalar_products(x_ket)

    # --- mode b: same procedure with the b states ---
    x_ket = (x_ket*ket_b).expand()
    x_ket = sp.Add(x_ket)
    x_ket = (bra_b * x_ket).expand()
    x_ket = apply_all_b(x_ket)
    # Only a operators are moved past the b-bra at this stage.
    for _ in range(4):
        x_ket = bra_ordering(x_ket, a, bra_b)
        x_ket = bra_ordering(x_ket, a_dagger, bra_b)

    x_ket = collapse_scalar_products(x_ket)

    # --- mode a: last mode, so no bra reordering is performed ---
    x_ket = (x_ket*ket_a).expand()
    x_braket = (bra_a * x_ket).expand()
    x_braket = apply_all_a(x_braket)
    expval = collapse_scalar_products(x_braket)
    # Drop every contribution that was marked with the placeholder z.
    expval = expval.subs(z, 0)
    return expval

def polishing(expr):
    """Tidy up a scalar result for display.

    Steps, in order: substitute ``z`` with 0 and merge ``al*conjugate(al)``
    into ``Abs(al)**2``; collect on ``Abs(al)**2`` and then on ``sin(ph)``;
    apply the hyperbolic identities and expand; collect on ``sinh(2*r)``;
    apply the trigonometric identity, collect on ``sin(ph)`` and apply the
    trigonometric identity once more.
    """
    tidy = complex_ordering(expr.subs(z, 0)).expand()
    tidy = sp.collect(sp.collect(tidy, sp.Abs(al)**2), sp.sin(ph))
    tidy = sp.collect(simphiperbolic(tidy).expand(), sp.sinh(2*r))
    tidy = sp.collect(simptrigonometric(tidy), sp.sin(ph))
    return simptrigonometric(tidy)

def heisenberg(x_2, states):
    """Compute the variance and expectation value of an operator expression.

    Args:
        x_2: operator expression whose expectation value is wanted.
        states: three lists (modes a, b, c); each entry is indexable with the
            state value at index 0 and its amplitude at index 1.

    Returns:
        ``(var, expval)``: both passed through ``polishing``. ``var`` is the
        accumulated value for the squared operator minus the square of the
        accumulated value for the operator itself.

    Implementation notes:
        * 1 is added to the operator (and to its square) before expanding, and
          subtracted again from the per-state result when the bra values equal
          the ket values. ``gap`` is added to every term handed to ``fun`` and
          set to 0 afterwards.
          NOTE(review): both offsets presumably exist so that ``fun`` always
          receives a sum with at least two terms -- confirm.
        * The expanded term lists do not depend on the states, so they are
          computed once up front instead of inside the innermost loop.
    """
    x_2 += 1
    x_1 = (x_2 - 1)**2 + 1

    # Loop-invariant: expand both operators once.
    squared_terms = x_1.expand().args
    linear_terms = x_2.expand().args

    squared_expval = 0
    expval = 0
    # Loop variables are named per mode/side to avoid shadowing the
    # module-level symbols l, m and n.
    for bra_a in states[0]:
        for ket_a in states[0]:
            for bra_b in states[1]:
                for ket_b in states[1]:
                    for bra_c in states[2]:
                        for ket_c in states[2]:
                            bra_states = [bra_a[0], bra_b[0], bra_c[0]]
                            ket_states = [ket_a[0], ket_b[0], ket_c[0]]

                            res1 = 0
                            for term in squared_terms:
                                res1 += fun(term + gap, bra_states, ket_states)
                                res1 = res1.subs(gap, 0)
                            res2 = 0
                            for term in linear_terms:
                                res2 += fun(term + gap, bra_states, ket_states)
                                res2 = res2.subs(gap, 0)

                            # Undo the +1 offset added to both operators above.
                            if bra_states == ket_states:
                                res2 -= 1
                                res1 -= 1

                            amp = (
                                sp.conjugate(bra_a[1]*bra_b[1]*bra_c[1])
                                * ket_a[1]*ket_b[1]*ket_c[1]
                            )
                            squared_expval += amp*res1
                            expval += amp*res2
    var = ((squared_expval) - (expval)**2).expand()
    var = polishing(var)
    expval = polishing(expval)
    return var, expval

In [None]:
# Input states, one list per mode (a, b, c); each entry is [value, amplitude].
# Mode a uses the value `al`, modes b and c use the value 0.
states = [[[al, 1]], [[0, 1]], [[0, 1]]]

# Stage 1: mode b is mixed with its own dagger through cosh(r)/sinh(r);
# modes a and c are unchanged.
# NOTE(review): this looks like a squeezing transformation of mode b -- confirm.
a_1 = a
b_1 = sp.cosh(r)*b-sp.sinh(r)*b_dagger
c_1 = c

a_1_dagger = a_dagger
b_1_dagger = sp.cosh(r)*b_dagger-sp.sinh(r)*b
c_1_dagger = c_dagger

# Stage 2: modes a and b are mixed with weights sqrt(T) and sqrt(1-T).
# NOTE(review): presumably a beam splitter with transmissivity T -- confirm.
a_2 = sp.sqrt(T)*a_1 + sp.sqrt(1-T)*b_1
b_2 = sp.sqrt(1-T)*a_1 - sp.sqrt(T)*b_1
c_2  = c_1

a_2_dagger = sp.sqrt(T)*a_1_dagger + sp.sqrt(1-T)*b_1_dagger
b_2_dagger = sp.sqrt(1-T)*a_1_dagger - sp.sqrt(T)*b_1_dagger
c_2_dagger = c_1_dagger

# Stage 3: mode a picks up the phase factor exp(i*ph); modes b and c are
# mixed with weights sqrt(t) and sqrt(1-t).
a_3 = a_2 * sp.exp(sp.I * ph)
b_3 = sp.sqrt(t)*b_2 + sp.sqrt(1-t)*c_2
c_3 = sp.sqrt(1-t)*b_2 - sp.sqrt(t)*c_2

a_3_dagger = a_2_dagger * sp.exp(-sp.I * ph)
b_3_dagger = sp.sqrt(t)*b_2_dagger + sp.sqrt(1-t)*c_2_dagger
c_3_dagger = sp.sqrt(1-t)*b_2_dagger - sp.sqrt(t)*c_2_dagger

# Stage 4: modes a and b are mixed with equal weights 1/sqrt(2).
a_4 = 1/sp.sqrt(2)*a_3 + 1/sp.sqrt(2)*b_3
b_4 = 1/sp.sqrt(2)*a_3 - 1/sp.sqrt(2)*b_3
c_4 = c_3

a_4_dagger = 1/sp.sqrt(2)*a_3_dagger + 1/sp.sqrt(2)*b_3_dagger
b_4_dagger = 1/sp.sqrt(2)*a_3_dagger - 1/sp.sqrt(2)*b_3_dagger
c_4_dagger = c_3_dagger

# Operator to evaluate: sum of dagger*operator products of the final a and b modes.
expr = a_4_dagger*a_4 + b_4_dagger*b_4

# heisenberg returns (variance, expectation value).
res, res2 = heisenberg(expr, states)

In [None]:
# Results of the previous cell: variance and expectation value of `expr`.
variance = res
expval = res2
# Variance divided by the squared derivative of the expectation value with
# respect to ph.
diff_expval_ph = sp.diff(expval, ph)
diff_expval_ph = diff_expval_ph**2
phase_variance = variance/diff_expval_ph
########################################
# Same quantity, taken with respect to t instead of ph.
diff_expval_t = sp.diff(expval, t)
diff_expval_t = diff_expval_t**2
t_variance = variance/diff_expval_t
########################################
# Select which quantity is reported as the solution.
solution = phase_variance
# solution = t_variance
