In [1]:
import sympy as sp
import sympy.stats as st
import itertools as it
from sympy.solvers.solveset import linsolve, linear_eq_to_matrix
import math

In [2]:
# highest order of moments
order = 4
# number of particle species
nn = 3

# the indices of the moments, WaWbWc
a = [st.rv.RandomSymbol(f'a_{j}') for j in range(nn)]
r = [st.rv.RandomSymbol(f'r_{j}') for j in range(nn)]
# The indices of the elements of the matrix, Ωijk
i = [st.rv.RandomSymbol(f'i_{j}') for j in range(nn)]

# t_a
t = sp.Function('t')
N = sp.Function('N')

N_RV = [st.rv.RandomSymbol(f'N_{j}') for j in range(nn)]
W_RV = [st.rv.RandomSymbol(f'W_{j}') for j in range(nn)]

ω = sp.Function('ω')
# h_i(t_a)
h = sp.Function('h')
MW = sp.Function('M_W')

κ = {int(i * '1'): sp.Function(f'κ{{{int(i * '1')}}}') for i in range(1, nn + 1)}

MW_arguments = [h(ii, *[t(aa) for aa in a]) for ii in i]

In [3]:
def get_all_combinations(order, m=2):
    ret = []
    for comb in it.product(range(order + 1), repeat=m):
        if 1 <= sum(comb) <= order:
            ret.append(comb)

    return sorted(ret)

list(reversed(sorted(get_all_combinations(order, 3), key=sum)))

[(4, 0, 0),
 (3, 1, 0),
 (3, 0, 1),
 (2, 2, 0),
 (2, 1, 1),
 (2, 0, 2),
 (1, 3, 0),
 (1, 2, 1),
 (1, 1, 2),
 (1, 0, 3),
 (0, 4, 0),
 (0, 3, 1),
 (0, 2, 2),
 (0, 1, 3),
 (0, 0, 4),
 (3, 0, 0),
 (2, 1, 0),
 (2, 0, 1),
 (1, 2, 0),
 (1, 1, 1),
 (1, 0, 2),
 (0, 3, 0),
 (0, 2, 1),
 (0, 1, 2),
 (0, 0, 3),
 (2, 0, 0),
 (1, 1, 0),
 (1, 0, 1),
 (0, 2, 0),
 (0, 1, 1),
 (0, 0, 2),
 (1, 0, 0),
 (0, 1, 0),
 (0, 0, 1)]

In [None]:
def W(*indices):
    derivatives = []
    for ind in indices:
        derivatives.append(t(a[ind]))

    res = sp.Derivative(MW(*MW_arguments), *derivatives).doit()

    # do the substutions for the <NaNbNc...>
    # ∂/∂(hj) → Nj
    N_symbols = []
    substitutions_Ns = []
    for idx in it.product(reversed(range(order + 1)), repeat=nn):
        if sum(idx) > order:
            continue

        derivs_all = []

        for idx_idx, ii in enumerate(idx):
            if ii != 0:
                deriv = (h(i[idx_idx], *[t(aa) for aa in a]), ii)
                derivs_all.append((i[idx_idx], ii,  deriv))

        newterm = 1
        for idx, ii, deriv in derivs_all:
            if deriv:
                newterm *= N(idx) ** ii

        newterm = st.E(newterm)

        derivs = [d for _, ii, d in derivs_all if ii > 0]
        if len(derivs) == 0:
            continue

        N_symbols.append(newterm)
        derivative = sp.Derivative(MW(*MW_arguments), *derivs)

        substitutions_Ns.append((derivative, newterm))

    res_Nsubs = res.subs(substitutions_Ns)

    κ = {int(i * '1'): sp.Function(f'κ{int(i * '1')}') for i in range(1, order + 1)}

    substitutions_kappas = []
    for idx in it.product(reversed(range(order + 1)), repeat=nn):
        if not 0 < sum(idx) <= order:
            continue

        kappa = κ[int(sum(idx) * '1')]

        deriv_ti = []
        for ii, exp in enumerate(idx):
            deriv_ti += exp * [t(a[ii])]

        omega_indices = []
        for ii, idxidx in enumerate(idx):
            omega_indices += idxidx * [a[ii]]

        for idx2 in i:
            newterm = kappa(ω(*omega_indices, idx2))
            derivative = sp.Derivative(h(idx2, *[t(aa) for aa in a]), *deriv_ti)
            substitutions_kappas.append((derivative, newterm))

    res_kappasubs = res_Nsubs.subs(substitutions_kappas)

    res_idxsubs = res_kappasubs.subs([
        (a[ii], jj) for ii, jj in enumerate(range(nn))
    ]).subs([
        (jj, ii) for ii, jj in enumerate(i)
    ])

    return res_idxsubs.expand().collect(N_symbols)

In [None]:
meanW = sp.Matrix([W(*(sum(list((exp * [ii] for ii, exp in enumerate(exps))), []))) for exps in get_all_combinations(order, 3)])

for exps in list(reversed(sorted(get_all_combinations(order, nn), key=sum))):
    term_N = math.prod((list((N(ii) ** exp for ii, exp in enumerate(exps)))))
    term_EN = st.E(math.prod((list((N_RV[ii] ** exp for ii, exp in enumerate(exps))))))
    meanW = meanW.subs(term_N, term_EN)

WW = sp.Matrix([st.E(math.prod((list((W_RV[ii] ** exp for ii, exp in enumerate(exps)))))) for exps in get_all_combinations(order, 3)])
Ns = sp.Matrix([st.E(math.prod((list((N_RV[ii] ** exp for ii, exp in enumerate(exps)))))) for exps in get_all_combinations(order, 3)])
display(Ns)
display(WW)

(2,)
(2, 2)
(2, 2, 2)
(2, 2, 2, 2)
(1,)
(1, 2)
(1, 2, 2)
(1, 2, 2, 2)
(1, 1)
(1, 1, 2)
(1, 1, 2, 2)
(1, 1, 1)
(1, 1, 1, 2)
(1, 1, 1, 1)


In [None]:
A, b = linear_eq_to_matrix(meanW - WW, list(Ns))

In [None]:
for ii, jj in it.product(range(A.shape[0]), range(A.shape[1])):
    print(f'A[{ii}, {jj}] = {A[ii, jj]}')

A[0, 0] = κ1(ω(2, 2))
A[0, 1] = 0
A[0, 2] = 0
A[0, 3] = 0
A[0, 4] = κ1(ω(2, 1))
A[0, 5] = 0
A[0, 6] = 0
A[0, 7] = 0
A[0, 8] = 0
A[0, 9] = 0
A[0, 10] = 0
A[0, 11] = 0
A[0, 12] = 0
A[0, 13] = 0
A[0, 14] = κ1(ω(2, 0))
A[0, 15] = 0
A[0, 16] = 0
A[0, 17] = 0
A[0, 18] = 0
A[0, 19] = 0
A[0, 20] = 0
A[0, 21] = 0
A[0, 22] = 0
A[0, 23] = 0
A[0, 24] = 0
A[0, 25] = 0
A[0, 26] = 0
A[0, 27] = 0
A[0, 28] = 0
A[0, 29] = 0
A[0, 30] = 0
A[0, 31] = 0
A[0, 32] = 0
A[0, 33] = 0
A[1, 0] = κ11(ω(2, 2, 2))
A[1, 1] = κ1(ω(2, 2))**2
A[1, 2] = 0
A[1, 3] = 0
A[1, 4] = κ11(ω(2, 2, 1))
A[1, 5] = 2*κ1(ω(2, 1))*κ1(ω(2, 2))
A[1, 6] = 0
A[1, 7] = 0
A[1, 8] = κ1(ω(2, 1))**2
A[1, 9] = 0
A[1, 10] = 0
A[1, 11] = 0
A[1, 12] = 0
A[1, 13] = 0
A[1, 14] = κ11(ω(2, 2, 0))
A[1, 15] = 2*κ1(ω(2, 0))*κ1(ω(2, 2))
A[1, 16] = 0
A[1, 17] = 0
A[1, 18] = 2*κ1(ω(2, 0))*κ1(ω(2, 1))
A[1, 19] = 0
A[1, 20] = 0
A[1, 21] = 0
A[1, 22] = 0
A[1, 23] = 0
A[1, 24] = κ1(ω(2, 0))**2
A[1, 25] = 0
A[1, 26] = 0
A[1, 27] = 0
A[1, 28] = 0
A[1, 29] = 0
A[1, 

In [None]:
for elem in b:
    print(elem)

Expectation(W_2)
Expectation(W_2**2)
Expectation(W_2**3)
Expectation(W_2**4)
Expectation(W_1)
Expectation(W_1*W_2)
Expectation(W_1*W_2**2)
Expectation(W_1*W_2**3)
Expectation(W_1**2)
Expectation(W_1**2*W_2)
Expectation(W_1**2*W_2**2)
Expectation(W_1**3)
Expectation(W_1**3*W_2)
Expectation(W_1**4)
Expectation(W_0)
Expectation(W_0*W_2)
Expectation(W_0*W_2**2)
Expectation(W_0*W_2**3)
Expectation(W_0*W_1)
Expectation(W_0*W_1*W_2)
Expectation(W_0*W_1*W_2**2)
Expectation(W_0*W_1**2)
Expectation(W_0*W_1**2*W_2)
Expectation(W_0*W_1**3)
Expectation(W_0**2)
Expectation(W_0**2*W_2)
Expectation(W_0**2*W_2**2)
Expectation(W_0**2*W_1)
Expectation(W_0**2*W_1*W_2)
Expectation(W_0**2*W_1**2)
Expectation(W_0**3)
Expectation(W_0**3*W_2)
Expectation(W_0**3*W_1)
Expectation(W_0**4)


In [None]:
A

Matrix([
[            κ1(ω(2, 2)),                                                                                                                                                                           0,                                                                                                                                                                0,                                      0,             κ1(ω(2, 1)),                                                                                                                                                                                                                                                                                                                                                         0,                                                                                                                                                                                                                           