In [1]:
import sympy as sym
from sympy.solvers import solve
from sympy import Symbol
from sympy.abc import x,y,z,a,b,c,d,e,f,g,h,i,j,k,l
import dill
import time
import numpy as np

In [2]:
# Helper functions to simplify final expressions
def collect_by_list(expr_in, sort_list):
    expr = sym.expand(expr_in)
    res = expr
    for s in sort_list:
        res = sym.rcollect(res, s)
    return res

def get_symbol_list(expr_in, return_counts=False):
    expr = sym.expand(expr_in)
    syms = {}
    for a in expr.atoms():
        if type(a) is sym.Symbol:
            syms[a] = expr.count(a)
    symlist = list(sorted(syms, key=syms.__getitem__, reverse=True))
    if return_counts:
        countlist = list(sorted(syms.values(), reverse=True))
        return symlist, countlist
    else:
        return symlist

def collect_auto(expr_in, print_ops=False, syms_to_prioritize=[]):
    slist = get_symbol_list(expr_in)
    sort_list = syms_to_prioritize
    for s in slist:
        if s not in syms_to_prioritize:
            sort_list.append(s)
    res = collect_by_list(expr_in, sort_list)
    if print_ops:
        print(sym.count_ops(res))
    return res

def collect_terms_in_fraction(expr_in, terms_to_prioritize=[]):
    numerator, denominator = sym.fraction(expr_in)
    num_simp = collect_auto(numerator, syms_to_prioritize=terms_to_prioritize)
    den_simp = collect_auto(denominator, syms_to_prioritize=terms_to_prioritize)
    return num_simp/den_simp

# Full System of Equations

In [3]:
dKMgO_KMgO = Symbol('dKMgO_KMgO')
dKFeO_KFeO = Symbol('dKFeO_KFeO')
dKSiO2_KSiO2 = Symbol('dKSiO2_KSiO2')
dKMgSiO3_KMgSiO3 = Symbol('dKMgSiO3_KMgSiO3')
dKFeSiO3_KFeSiO3 = Symbol('dKFeSiO3_KFeSiO3')

dM_Mg = Symbol('dM_Mg')
dM_Fe = Symbol('dM_Fe')
dM_Si = Symbol('dM_Si')
dM_O = Symbol('dM_O')
dM_c = Symbol('dM_c')

dM_MgO = Symbol('dM_MgO')
dM_FeO = Symbol('dM_FeO')
dM_SiO2 = Symbol('dM_SiO2')
dM_MgSiO3 = Symbol('dM_MgSiO3')
dM_FeSiO3 = Symbol('dM_FeSiO3')
dM_m = Symbol('dM_m')

M_Mg = Symbol('M_Mg')
M_Fe = Symbol('M_Fe')
M_Si = Symbol('M_Si')
M_O = Symbol('M_O')
M_c = Symbol('M_c')

M_MgO = Symbol('M_MgO')
M_FeO = Symbol('M_FeO')
M_SiO2 = Symbol('M_SiO2')
M_MgSiO3 = Symbol('M_MgSiO3')
M_FeSiO3 = Symbol('M_FeSiO3')
M_m = Symbol('M_m')

In [5]:
# Total Moles
# eq_mantle_total_moles = M_MgO + M_FeO + M_SiO2 + M_MgSiO3 + M_FeSiO3 - M_m
# eq_core_total_moles = M_Mg + M_Fe + M_Si + M_O - M_c
eq_mantle_dtotal_moles = dM_MgO + dM_FeO + dM_SiO2 + dM_MgSiO3 + dM_FeSiO3 - dM_m
eq_core_dtotal_moles = dM_Mg + dM_Fe + dM_Si + dM_O - dM_c

# mantle interactions
eq_K_MgSiO3 = dM_MgO/M_MgO + dM_SiO2/M_SiO2 - dM_MgSiO3/M_MgSiO3 - dM_m/M_m - dKMgSiO3_KMgSiO3
eq_K_FeSiO3 = dM_FeO/M_FeO + dM_SiO2/M_SiO2 - dM_FeSiO3/M_FeSiO3 - dM_m/M_m - dKFeSiO3_KFeSiO3

# core interactions
eq_K_MgO = dM_Mg/M_Mg + dM_O/M_O + dM_m/M_m - 2*dM_c/M_c - dM_MgO/M_MgO - dKMgO_KMgO
eq_K_FeO = dM_Fe/M_Fe + dM_O/M_O + dM_m/M_m - 2*dM_c/M_c - dM_FeO/M_FeO - dKFeO_KFeO
eq_K_SiO2 = dM_Si/M_Si + 2*dM_O/M_O + dM_m/M_m - 3*dM_c/M_c - dM_SiO2/M_SiO2 - dKSiO2_KSiO2

# species continuity
eq_dM_Mg = dM_MgO + dM_MgSiO3 + dM_Mg
eq_dM_Fe = dM_FeO + dM_FeSiO3 + dM_Fe
eq_dM_Si = dM_SiO2 + dM_MgSiO3 + dM_FeSiO3 + dM_Si
eq_dM_O =  dM_MgO + dM_FeO + 2*dM_SiO2 + 3*dM_MgSiO3 + 3*dM_FeSiO3 + dM_O

equations = [eq_mantle_dtotal_moles, eq_core_dtotal_moles, 
            eq_K_MgSiO3, eq_K_FeSiO3,
            eq_K_MgO, eq_K_FeO, eq_K_SiO2,
            eq_dM_Fe, eq_dM_Mg, eq_dM_O, eq_dM_Si]
solve_for = [dM_c, dM_Fe, dM_FeO, dM_FeSiO3, dM_m, dM_Mg, dM_MgO, dM_MgSiO3, dM_O, dM_Si, dM_SiO2]

In [8]:
start = -time.time()
solution = solve(equations, solve_for)
time_elapsed = start+time.time()

In [61]:
print('{:.1f} hrs to compute full solution'.format(time_elapsed/60/60))

1.9 hrs to compute full solution


In [12]:
# Save unsimplified full solution straight from the computation
dill.dump(solution,open('computed_solution.m','wb'))

In [7]:
solution = dill.load(open('computed_solution.m','rb'))

# Simplify Solution

In [8]:
# Simplify the solution using the helper functions at the top
terms_to_prioritize = [dKMgO_KMgO, dKFeO_KFeO, dKSiO2_KSiO2, dKMgSiO3_KMgSiO3, dKFeSiO3_KFeSiO3]
simplified = {}
for k,v in solution.items():
    simplified[k] = collect_terms_in_fraction(v, terms_to_prioritize=terms_to_prioritize)

In [9]:
# Demonstrate how much simpler the solutions are compared to those computed by sympy
print('{:.0f} operations in full solution'.format(sym.count_ops(solution[dM_Si])))
print('{:.0f} operations in simplified solution'.format(sym.count_ops(simplified[dM_Si])))

3745 operations in full solution
1507 operations in simplified solution


In [10]:
# save simplified solution
dill.dump(simplified, open('simplified_solution.m','wb'))

# Check by plugging in random values into Variables

In [48]:
def gen_and_check_rand_values(eqns, vars_to_rand=None):
    rand_values = gen_rand_vals(vars_to_rand=vars_to_rand)
    results = eval_eqns(eqns, rand_values)
    return rand_values, results

def gen_rand_vals(vars_to_rand = None):
    if vars_to_rand is None:
        vars_to_rand = [M_Mg, M_Fe, M_Si, M_O, M_c, M_MgO, M_FeO, M_SiO2, M_MgSiO3, M_FeSiO3, M_m, dKMgO_KMgO, dKFeO_KFeO, dKSiO2_KSiO2, dKMgSiO3_KMgSiO3, dKFeSiO3_KFeSiO3]
    rand_values = {}
    for v in vars_to_rand:
        rand_values[v] = np.random.rand()
    return rand_values

def eval_eqns(eqns, rand_values_dict):
    results = {}
    for symb in eqns.keys():
        results[symb] = eq_rand_values[symb].subs(rand_values_dict)
    return results


In [53]:
rand_values0, results0 = gen_and_check_rand_values(simplified)
rand_values1, results1 = gen_and_check_rand_values(simplified)

In [54]:
rand_results = ((rand_values0, results0),
(rand_values1, results1))
dill.dump(rand_results, open('rand_results.m','wb'))

# Import Simplified Solution

In [3]:
# open simplified solution
simplified = dill.load(open('simplified_solution2.m','rb'))

In [7]:
eqns_file = open('eqns_funcs.py','w')
for k,v in simplified.items():
    eqns_file.write('\tdef '+str(k)+'_dTc(self, Moles, dKs, dMi_b):\n')
    eqns_file.write('\t\t\'\'\'compute {} given Moles, dKDs, and dMm_b/dT\'\'\'\n'.format(k))
    eqns_file.write('\t\tdM_MgO_er, dM_SiO2_er, dM_FeO_er, dM_MgSiO3_er, dM_FeSiO3_er = dMi_b\n')
    eqns_file.write('\t\tM_Mg, M_Si, M_Fe, M_O, M_c, M_MgO, M_SiO2, M_FeO, M_MgSiO3, M_FeSiO3, M_m = self.unwrap_Moles(Moles)\n')
    eqns_file.write('\t\tdKMgO_KMgO, dKSiO2_KSiO2, dKFeO_KFeO, dKMgSiO3_KMgSiO3, dKFeSiO3_KFeSiO3 = dKs\n')
    eqns_file.write('\t\treturn '+str(v)+'\n\n')
eqns_file.close()
