Skip to content

Commit

Permalink
added dimension and derivative type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethsible committed May 15, 2024
1 parent eb64882 commit 638c225
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 155 deletions.
48 changes: 17 additions & 31 deletions nrpylatex/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, parser):

def generate(self, LHS, RHS, impsum=True):
# perform implied summation on indexed expression
LHS_RHS, dimension = self.expand_summation(LHS, RHS, impsum)
LHS_RHS, dimension, suffix = self.expand_summation(LHS, RHS, impsum)
if self._property['debug']:
lineno = '[%d]' % self._property['debug']
print('%s Python' % (len(lineno) * ' '))
Expand All @@ -37,7 +37,7 @@ def generate(self, LHS, RHS, impsum=True):
except IndexError:
raise GeneratorError('index out of range; change loop/summation range')

return global_env, dimension
return global_env, dimension, suffix

def expand_summation(self, LHS, RHS, impsum=True):
tree, indexing = ExprTree(LHS), []
Expand Down Expand Up @@ -75,10 +75,11 @@ def expand_summation(self, LHS, RHS, impsum=True):
subexpr = subtree.expr
if subexpr.func == Function('Tensor'):
symbol = str(subexpr.args[0])
dimension = self._namespace[symbol].dimension
for index in subexpr.args[1:]:
if str(index) in self._property['index']:
dimension = self._property['index'][str(index)]
else:
dimension = self._namespace[symbol].dimension
if str(index) in index_range and dimension != index_range[str(index)]:
raise GeneratorError('inconsistent loop/summation range for index \'%s\'' % index)
index_range[str(index)] = dimension
Expand All @@ -94,10 +95,11 @@ def expand_summation(self, LHS, RHS, impsum=True):
argument = subexpr.args[0]
derivative = 'diff(' + srepr(argument)
symbol = str(argument.args[0])
dimension = self._namespace[symbol].dimension
for index, order in subexpr.args[1:]:
if str(index) in self._property['index']:
dimension = self._property['index'][str(index)]
else:
dimension = self._namespace[symbol].dimension
if str(index) in index_range and dimension != index_range[str(index)]:
raise GeneratorError('inconsistent loop/summation range for index \'%s\'' % index)
index_range[str(index)] = dimension
Expand Down Expand Up @@ -160,11 +162,16 @@ def expand_summation(self, LHS, RHS, impsum=True):
dimension_LHS = index_range[index]

# shift tensor indexing forward whenever dimension > upper bound
# and infer derivative suffix of LHS tensor from RHS tensors
suffix_LHS = None
for subtree in tree.preorder():
subexpr = subtree.expr
if subexpr.func == Function('Tensor'):
symbol = str(subexpr.args[0])
dimension = self._namespace[symbol].dimension
suffix = self._namespace[symbol].suffix
if suffix is not None:
suffix_LHS = self._property['suffix']
tensor = IndexedSymbol(subexpr, dimension)
indexing = IndexedSymbol.indexing(subexpr)
for index in subexpr.args[1:]:
Expand All @@ -177,7 +184,7 @@ def expand_summation(self, LHS, RHS, impsum=True):
indexing[i] = ('%s + %s' % (idx, shift), pos)
equation[-1] = equation[-1].replace(tensor.array_format(subexpr), tensor.array_format(indexing))

return ' = '.join(equation), dimension_LHS
return ' = '.join(equation), dimension_LHS, suffix_LHS

@staticmethod
def separate_indexing(indexing, symbol_LHS, impsum=True):
Expand Down Expand Up @@ -213,39 +220,21 @@ def generate_metric(symbol, dimension, suffix):
r'\epsilon_{' + ' '.join('j_' + str(i) for i in range(1, 1 + dimension)) + '} '
det_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}^{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(1, 1 + dimension))
inv_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}^{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(2, 1 + dimension))
if suffix:
latex_config += r"% declare {symbol}det {inv_symbol} --dim {dimension} --suffix {suffix}" \
.format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('U', 'D'), dimension=dimension)
else:
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
latex_config += r"""
\mathrm{{{symbol}det}} = \frac{{1}}{{({dimension})({factorial})}} {det_latex} \\
\mathrm{{{symbol}}}_{{i_1 j_1}} = \frac{{1}}{{{factorial}}} \mathrm{{{symbol}det}}^{{{{-1}}}} ({inv_latex}) \\""" \
.format(symbol=symbol[:-2], inv_symbol=symbol.replace('U', 'D'), dimension=dimension,
factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
# latex_config += '\n' + r"% assign {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
# if suffix:
# latex_config += '\n' + r"% assign {symbol}det {inv_symbol} --suffix {suffix}" \
# .format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('U', 'D'))
.format(symbol=symbol[:-2], dimension=dimension, factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
else:
prefix = r'\epsilon^{' + ' '.join('i_' + str(i) for i in range(1, 1 + dimension)) + '} ' + \
r'\epsilon^{' + ' '.join('j_' + str(i) for i in range(1, 1 + dimension)) + '} '
det_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}_{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(1, 1 + dimension))
inv_latex = prefix + ' '.join(r'\mathrm{{{symbol}}}_{{i_{n} j_{n}}}'.format(symbol=symbol[:-2], n=i) for i in range(2, 1 + dimension))
if suffix:
latex_config += r"% declare {symbol}det {inv_symbol} --dim {dimension} --suffix {suffix}" \
.format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('D', 'U'), dimension=dimension)
else:
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
latex_config += r"% declare {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
latex_config += r"""
\mathrm{{{symbol}det}} = \frac{{1}}{{({dimension})({factorial})}} {det_latex} \\
\mathrm{{{symbol}}}^{{i_1 j_1}} = \frac{{1}}{{{factorial}}} \mathrm{{{symbol}det}}^{{{{-1}}}} ({inv_latex}) \\""" \
.format(symbol=symbol[:-2], inv_symbol=symbol.replace('D', 'U'), dimension=dimension,
factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
# latex_config += '\n' + r"% assign {symbol}det --dim {dimension}".format(symbol=symbol[:-2], dimension=dimension)
# if suffix:
# latex_config += '\n' + r"% assign {symbol}det {inv_symbol} --suffix {suffix}" \
# .format(suffix=suffix, symbol=symbol[:-2], inv_symbol=symbol.replace('D', 'U'))
.format(symbol=symbol[:-2], dimension=dimension, factorial=math.factorial(dimension - 1), det_latex=det_latex, inv_latex=inv_latex)
return latex_config

@staticmethod
Expand Down Expand Up @@ -288,10 +277,7 @@ def generate_covdrv(function, covdrv_index, symbol=None, diacritic=None, dimensi
RHS += '^{%s}_{%s %s} (%s)' % (index, bound_index, covdrv_index, latex)
else:
RHS += '^{%s}_{%s %s} (%s)' % (bound_index, index, covdrv_index, latex)
config = ('% declare ' + symbol + ' --dim %d --suffix dD\n' % dimension) if symbol else ''
return config + LHS + ' = ' + RHS
# config = (' % assign ' + symbol + ' --suffix dD\n') if symbol else ''
# return LHS + ' = ' + RHS + config
return LHS + ' = ' + RHS

@staticmethod
def generate_liedrv(function, vector, weight=None):
Expand Down
Loading

0 comments on commit 638c225

Please sign in to comment.