In [15]:
import numpy as np
from numba import njit

In [31]:
# Irrep label (1-based, matching the rows/columns of the product table used by
# sym_mult_table) of each spatial MO, indexed by MO number.
irreps = (1,1,4,1,3,1,4,4,1,1,3,4,2,1,3,4,1,4,1,3,2,1,1,4)

nelec = 10

# Closed-shell reference: the lowest nelec // 2 spatial MOs are doubly occupied,
# the remaining MOs are virtual.
assert nelec % 2 == 0, "closed-shell reference requires an even electron count"
assert 0 <= nelec <= 2 * len(irreps), "nelec is inconsistent with the number of MOs"

# numba wants numpy arrays! np.arange also handles nelec == 0 (empty occ),
# where indexing occ[-1] would fail.
occ = np.arange(nelec // 2)
vir = np.arange(nelec // 2, len(irreps))

# Irrep label the excitation product must match (1 is the identity element of
# the product table, i.e. the totally symmetric irrep).
sym_target = 1

In [26]:
# Direct-product table of the 4-irrep abelian point group used in this notebook:
# entry [A-1, B-1] is the 1-based irrep label of A x B. It is kept at module level
# so the jitted function below does not allocate a new 4x4 array on every call
# (it is called in the innermost counting loops). numba treats global arrays as
# compile-time constants, which is exactly what this table is.
_SYM_MULT_TABLE = np.array([[1,2,3,4],
                            [2,1,4,3],
                            [3,4,1,2],
                            [4,3,2,1]])

@njit
def sym_mult_table(irrepA, irrepB):
    """Return the irrep label of the direct product irrepA x irrepB.

    Parameters
    ----------
    irrepA, irrepB : int
        1-based irrep labels (1..4).

    Returns
    -------
    int
        1-based irrep label of the product.
    """
    # Shift 1-based labels to 0-based indices. numba does not bounds-check array
    # indexing by default, so out-of-range labels are not caught here.
    return _SYM_MULT_TABLE[irrepA-1, irrepB-1]

@njit
def count_t3A_t3D(occ, vir, irreps, sym_target=sym_target):
    """Count symmetry-allowed index tuples with the t3A (equivalently t3D) pattern.

    Counts tuples (i<j<k | a<b<c), i.e. three strictly ordered occupied
    indices and three strictly ordered virtual indices, whose total irrep
    product i x j x k x a x b x c equals ``sym_target``.

    Parameters
    ----------
    occ, vir : 1-D integer arrays
        MO indices of the occupied / virtual orbitals.
    irreps : sequence of int
        1-based irrep label of every MO, indexed by MO number.
    sym_target : int, optional
        Target irrep label. Passing it explicitly avoids numba's freezing of
        globals at compile time. The default is the module-level
        ``sym_target`` at the moment this cell is executed.

    Returns
    -------
    int
        Count for one block of the t3A/t3D pair (the notebook doubles it).
    """
    nocc = len(occ)
    nvir = len(vir)

    ctA = 0
    for i in range(nocc):
        for j in range(i+1, nocc):
            # Partial products are computed at the loop level where they
            # become fixed instead of being rebuilt in the innermost loop.
            sym_ij = sym_mult_table(irreps[occ[i]], irreps[occ[j]])
            for k in range(j+1, nocc):
                sym_hole = sym_mult_table(sym_ij, irreps[occ[k]])
                for a in range(nvir):
                    for b in range(a+1, nvir):
                        sym_ab = sym_mult_table(irreps[vir[a]], irreps[vir[b]])
                        for c in range(b+1, nvir):
                            sym_particle = sym_mult_table(sym_ab, irreps[vir[c]])
                            if sym_mult_table(sym_hole, sym_particle) == sym_target:
                                ctA += 1
    return ctA

@njit
def count_t3B_t3C(occ, vir, irreps, sym_target=sym_target):
    """Count symmetry-allowed index tuples with the t3B (equivalently t3C) pattern.

    Counts tuples (i<j, k | a<b, c), i.e. an ordered occupied pair plus one
    unrestricted occupied index and an ordered virtual pair plus one
    unrestricted virtual index, whose total irrep product equals
    ``sym_target``.

    Parameters
    ----------
    occ, vir : 1-D integer arrays
        MO indices of the occupied / virtual orbitals.
    irreps : sequence of int
        1-based irrep label of every MO, indexed by MO number.
    sym_target : int, optional
        Target irrep label. Passing it explicitly avoids numba's freezing of
        globals at compile time. The default is the module-level
        ``sym_target`` at the moment this cell is executed.

    Returns
    -------
    int
        Count for one block of the t3B/t3C pair (the notebook doubles it).
    """
    nocc = len(occ)
    nvir = len(vir)

    ctB = 0
    for i in range(nocc):
        for j in range(i+1, nocc):
            # Partial products are computed at the loop level where they
            # become fixed instead of being rebuilt in the innermost loop.
            sym_ij = sym_mult_table(irreps[occ[i]], irreps[occ[j]])
            for k in range(nocc):
                sym_hole = sym_mult_table(sym_ij, irreps[occ[k]])
                for a in range(nvir):
                    for b in range(a+1, nvir):
                        sym_ab = sym_mult_table(irreps[vir[a]], irreps[vir[b]])
                        for c in range(nvir):
                            sym_particle = sym_mult_table(sym_ab, irreps[vir[c]])
                            if sym_mult_table(sym_hole, sym_particle) == sym_target:
                                ctB += 1
    return ctB

@njit
def count_t4A_t4E(occ, vir, irreps, sym_target=sym_target):
    """Count symmetry-allowed index tuples with the t4A (equivalently t4E) pattern.

    Counts tuples (i<j<k<l | a<b<c<d), i.e. four strictly ordered occupied
    indices and four strictly ordered virtual indices, whose total irrep
    product equals ``sym_target``.

    Parameters
    ----------
    occ, vir : 1-D integer arrays
        MO indices of the occupied / virtual orbitals.
    irreps : sequence of int
        1-based irrep label of every MO, indexed by MO number.
    sym_target : int, optional
        Target irrep label. Passing it explicitly avoids numba's freezing of
        globals at compile time. The default is the module-level
        ``sym_target`` at the moment this cell is executed.

    Returns
    -------
    int
        Count for one block of the t4A/t4E pair (the notebook doubles it).
    """
    nocc = len(occ)
    nvir = len(vir)

    ctA = 0
    for i in range(nocc):
        for j in range(i+1, nocc):
            # Partial products are computed at the loop level where they
            # become fixed instead of being rebuilt in the innermost loop.
            sym_ij = sym_mult_table(irreps[occ[i]], irreps[occ[j]])
            for k in range(j+1, nocc):
                sym_ijk = sym_mult_table(sym_ij, irreps[occ[k]])
                for l in range(k+1, nocc):
                    sym_hole = sym_mult_table(sym_ijk, irreps[occ[l]])
                    for a in range(nvir):
                        for b in range(a+1, nvir):
                            sym_ab = sym_mult_table(irreps[vir[a]], irreps[vir[b]])
                            for c in range(b+1, nvir):
                                sym_abc = sym_mult_table(sym_ab, irreps[vir[c]])
                                for d in range(c+1, nvir):
                                    sym_particle = sym_mult_table(sym_abc, irreps[vir[d]])
                                    if sym_mult_table(sym_hole, sym_particle) == sym_target:
                                        ctA += 1
    return ctA

@njit
def count_t4B_t4D(occ, vir, irreps, sym_target=sym_target):
    """Count symmetry-allowed index tuples with the t4B (equivalently t4D) pattern.

    Counts tuples (i<j<k, l | a<b<c, d), i.e. three strictly ordered occupied
    indices plus one unrestricted occupied index and three strictly ordered
    virtual indices plus one unrestricted virtual index, whose total irrep
    product equals ``sym_target``.

    Parameters
    ----------
    occ, vir : 1-D integer arrays
        MO indices of the occupied / virtual orbitals.
    irreps : sequence of int
        1-based irrep label of every MO, indexed by MO number.
    sym_target : int, optional
        Target irrep label. Passing it explicitly avoids numba's freezing of
        globals at compile time. The default is the module-level
        ``sym_target`` at the moment this cell is executed.

    Returns
    -------
    int
        Count for one block of the t4B/t4D pair (the notebook doubles it).
    """
    nocc = len(occ)
    nvir = len(vir)

    ctB = 0
    for i in range(nocc):
        for j in range(i+1, nocc):
            # Partial products are computed at the loop level where they
            # become fixed instead of being rebuilt in the innermost loop.
            sym_ij = sym_mult_table(irreps[occ[i]], irreps[occ[j]])
            for k in range(j+1, nocc):
                sym_ijk = sym_mult_table(sym_ij, irreps[occ[k]])
                for l in range(nocc):
                    sym_hole = sym_mult_table(sym_ijk, irreps[occ[l]])
                    for a in range(nvir):
                        for b in range(a+1, nvir):
                            sym_ab = sym_mult_table(irreps[vir[a]], irreps[vir[b]])
                            for c in range(b+1, nvir):
                                sym_abc = sym_mult_table(sym_ab, irreps[vir[c]])
                                for d in range(nvir):
                                    sym_particle = sym_mult_table(sym_abc, irreps[vir[d]])
                                    if sym_mult_table(sym_hole, sym_particle) == sym_target:
                                        ctB += 1
    return ctB

@njit
def count_t4C(occ, vir, irreps, sym_target=sym_target):
    """Count symmetry-allowed index tuples with the t4C pattern.

    Counts tuples (i<j, k<l | a<b, c<d), i.e. two independently ordered
    occupied pairs and two independently ordered virtual pairs, whose total
    irrep product equals ``sym_target``.

    Parameters
    ----------
    occ, vir : 1-D integer arrays
        MO indices of the occupied / virtual orbitals.
    irreps : sequence of int
        1-based irrep label of every MO, indexed by MO number.
    sym_target : int, optional
        Target irrep label. Passing it explicitly avoids numba's freezing of
        globals at compile time. The default is the module-level
        ``sym_target`` at the moment this cell is executed.

    Returns
    -------
    int
        Count for the t4C block (the notebook adds it once, without doubling).
    """
    nocc = len(occ)
    nvir = len(vir)

    ctC = 0
    for i in range(nocc):
        for j in range(i+1, nocc):
            # Partial products are computed at the loop level where they
            # become fixed instead of being rebuilt in the innermost loop.
            sym_ij = sym_mult_table(irreps[occ[i]], irreps[occ[j]])
            for k in range(nocc):
                sym_ijk = sym_mult_table(sym_ij, irreps[occ[k]])
                for l in range(k+1, nocc):
                    sym_hole = sym_mult_table(sym_ijk, irreps[occ[l]])
                    for a in range(nvir):
                        for b in range(a+1, nvir):
                            sym_ab = sym_mult_table(irreps[vir[a]], irreps[vir[b]])
                            for c in range(nvir):
                                sym_abc = sym_mult_table(sym_ab, irreps[vir[c]])
                                for d in range(c+1, nvir):
                                    sym_particle = sym_mult_table(sym_abc, irreps[vir[d]])
                                    if sym_mult_table(sym_hole, sym_particle) == sym_target:
                                        ctC += 1
    return ctC

In [27]:
# Each helper counts one block of a pair with identical counts (t3A/t3D and
# t3B/t3C, per the function names), hence the common factor of 2.
num_triples = 2 * (count_t3A_t3D(occ, vir, irreps) + count_t3B_t3C(occ, vir, irreps))

In [28]:
# t4A/t4E and t4B/t4D are paired blocks with identical counts (factor 2);
# t4C has no partner and is counted once.
num_quadruples = (
    2 * (count_t4A_t4E(occ, vir, irreps) + count_t4B_t4D(occ, vir, irreps))
    + count_t4C(occ, vir, irreps)
)

In [29]:
# Total symmetry-allowed index-tuple count summed over the t3A-t3D blocks.
num_triples

86864

In [30]:
# Total symmetry-allowed index-tuple count summed over the t4A-t4E blocks.
num_quadruples

1201298