### **Diagrammatic Rayleigh-Schrödinger Perturbation Theory Applied to the 1D Infinite Square Well with Midpoint Repulsive Dirac Delta Perturbation**
### By Carlo van Maaren and Joeri van Limpt

In [None]:
##################################################################
# IMPORTS
##################################################################
import itertools
from typing import List, Tuple, Union, Any, Set
import numpy as np
import sympy as sp
import re
from IPython.display import display, Math
from sympy.parsing.mathematica import parse_mathematica
from wolframclient.evaluation import WolframLanguageSession # Requires installation, see https://reference.wolfram.com/language/workflow/EvaluateAWolframLanguageExpressionFromPython.html
from wolframclient.language import wl, wlexpr # Requires installation, see https://reference.wolfram.com/language/workflow/EvaluateAWolframLanguageExpressionFromPython.html
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from winsound import Beep

##################################################################
# FUNCTIONS
##################################################################

'''GENERAL FUNCTIONS TO PRODUCE THE RAYLEIGH-SCHRÖDINGER PERTURBATION THEORY ENERGY CORRECTIONS USING THE DIAGRAMMATIC APPROACH DISCUSSED IN SZABO AND OSTLUNDS'S MODERN QUANTUM CHEMISTRY'''

def generate_unlabeled_diagrams(order: int) -> List[List[Tuple[int, int]]]:
    """Generate all unique unlabeled Rayleigh-Schrödinger diagrams of a given order.

    Following Sec. 6.2 of "Modern Quantum Chemistry" by Szabo and Ostlund, a diagram of
    order N has N dots joined into a single closed loop. Dots are numbered 0..N-1 (instead
    of 1..N) so that they can be used directly as indices. Every directed line is stored as
    a pair ``(start, end)``, and a diagram is the sorted list of its lines.

    Rotating a cyclic permutation yields the same set of lines. Only permutations that start
    at dot 0 are therefore considered, i.e. (N-1)! instead of N! permutations.

    Args:
        order: Perturbation order N (number of dots).

    Returns:
        The unique diagrams in sorted (deterministic) order. ``order == 1`` gives
        ``[[(0, 0)]]``: the single line connects the dot to itself. ``order < 1`` gives
        ``[[]]``, as in the original implementation.
    """
    if order < 1:
        return [[]]  # Degenerate order: one empty diagram (unchanged behaviour)

    unique_diagrams = set()
    for tail in itertools.permutations(range(1, order)):
        cycle = (0,) + tail
        # Pair each dot with the next one in the cycle (wrapping around) to obtain the directed lines.
        # Sorting the list of pairs gives a canonical, hashable key for the diagram.
        unique_diagrams.add(tuple(sorted(zip(cycle, cycle[1:] + cycle[:1]))))

    return [list(diagram) for diagram in sorted(unique_diagrams)]

def determine_number_of_hole_lines(diagram: List[Union[Tuple[int, int], Tuple[int, int, int]]]) -> int:
    """Count the hole lines (down-pointing arrows) of a diagram.

    A line ``(start, end)`` or ``(start, end, label)`` is a hole line when ``start < end``.
    A one-line (first-order) diagram always counts as exactly one hole line, by definition
    (p. 331 of "Modern Quantum Chemistry" by Szabo and Ostlund).
    """
    if len(diagram) != 1:
        downward_lines = [line for line in diagram if line[0] < line[1]]
        return len(downward_lines)
    return 1

def determine_number_of_particle_lines(diagram: List[Union[Tuple[int, int], Tuple[int, int, int]]]) -> int:
    """Count the particle lines (up-pointing arrows) of a diagram.

    A line ``(start, end)`` or ``(start, end, label)`` is a particle line when ``start > end``.
    A one-line (first-order) diagram has no particle lines: its only line is a hole line by
    definition (p. 331 of "Modern Quantum Chemistry" by Szabo and Ostlund).
    """
    if len(diagram) == 1:
        return 0
    upward_lines = filter(lambda line: line[0] > line[1], diagram)
    return sum(1 for _ in upward_lines)

def label_diagram(diagram: List[Tuple[int, int]], n: int = 1) -> List[Union[Tuple[int, int, int], Tuple[int, int, str]]]:    
    """Attach labels to the lines of an unlabeled diagram.

    Hole lines (``start < end``) are labelled ``n``, the unperturbed level being corrected.
    Each particle line (``start > end``) gets its own summation index ``p_1, p_2, ...``, a
    positive integer SymPy symbol. Indices are assigned in the order in which the lines
    appear in ``diagram``; see Sec. 6.2.2 (p. 335) of "Modern Quantum Chemistry" by Szabo and
    Ostlund.

    Args:
        diagram: Unlabeled diagram, a list of ``(start, end)`` pairs.
        n: Index of the hole line (the energy level under consideration).

    Returns:
        A list of ``(start, end, label)`` triples. The first-order diagram ``[(0, 0)]``
        becomes ``[(0, 0, n)]``, because its single line is a hole line by definition
        (p. 331 / p. 333 of Szabo and Ostlund).
    """
    if len(diagram) == 1:  # First-order correction
        dot1, dot2 = diagram[0]
        return [(dot1, dot2, n)]  # Hole line labelled by n (previously hard-coded to 1)

    # Generalizing to all states means every particle line gets an independent summation index.
    # The assumption keyword must be spelled `positive`. SymPy silently accepts unknown keywords,
    # and symbols with different assumptions are distinct objects. The labels must be the very same
    # symbols that are used as summation variables when the diagram expression is summed.
    num_particle_lines = determine_number_of_particle_lines(diagram)
    particle_line_labels = [sp.Symbol(f"p_{i}", integer=True, positive=True) for i in range(1, num_particle_lines + 1)]

    labeled_diagram = []
    particle_line_idx = 0
    for dot1, dot2 in diagram:
        if dot1 > dot2:  # Decreasing pair -> up-pointing arrow -> particle line
            label = particle_line_labels[particle_line_idx]
            particle_line_idx += 1
        else:  # Increasing pair -> down-pointing arrow -> hole line
            label = n

        labeled_diagram.append((dot1, dot2, label))

    return labeled_diagram

def generate_diagrams(order: int, n: int = 1) -> Tuple[List[List[Tuple[int, int, int]]], List[List[Tuple[int, int]]]]:
    """Build the diagrams of one perturbation order, both with and without labels.

    Args:
        order: Perturbation order (number of dots per diagram).
        n: Label of the hole lines (the energy level under consideration).

    Returns:
        A pair ``(labeled_diagrams, unlabeled_diagrams)``. The i-th labeled diagram is the
        labelled version of the i-th unlabeled one.
    """
    unlabeled = generate_unlabeled_diagrams(order)
    labeled = list(map(lambda diagram: label_diagram(diagram, n), unlabeled))
    return labeled, unlabeled

def determine_dot_in_out_labels(labeled_diagram: List[Tuple[Tuple[int, int], int]], n: int = 1) -> np.ndarray[Tuple[object, object]]:
    """Collect, per dot, the labels of its incoming and outgoing lines.

    Args:
        labeled_diagram: Lines ``(start, end, label)``. The number of lines equals the
            number of dots.
        n: Hole-line label used for the first-order diagram.

    Returns:
        An object array of shape ``(num_dots, 2)``. Row ``d`` holds ``[incoming_label,
        outgoing_label]`` for dot ``d``. For a one-line (first-order) diagram the single row is
        ``[n, n]`` (p. 333 of "Modern Quantum Chemistry" by Szabo and Ostlund).
    """
    num_dots = len(labeled_diagram)
    in_out = np.empty((num_dots, 2), dtype=object)

    if num_dots != 1:
        for start, end, label in labeled_diagram:
            in_out[end, 0] = label    # The line arrives at `end` ...
            in_out[start, 1] = label  # ... and leaves `start`
    else:
        in_out[0] = (n, n)

    return in_out

def calculate_numerator(labeled_diagram: List[Tuple[int, int, int]], n: int = 1) -> sp.Expr:
    """Return the numerator of a diagram: one matrix element per dot, multiplied together.

    Rule 1: every dot contributes ⟨ψ_in|V|ψ_out⟩. It is written as the indexed symbol
    ``V[in, out]``, where ``in``/``out`` are the labels of the lines entering and leaving
    the dot.
    """
    V = sp.IndexedBase("V")
    dot_labels = determine_dot_in_out_labels(labeled_diagram, n)

    factors = []
    for dot in range(len(labeled_diagram)):
        incoming, outgoing = dot_labels[dot]
        factors.append(V[incoming, outgoing])

    return sp.Mul(*factors)

def calculate_denominator(labeled_diagram: List[Tuple[int, int, int]], order: int, n: int = 1) -> sp.Expr:
    """Return the denominator of a diagram (Rule 2).

    For every pair of adjacent dots ``(k, k+1)``, an imaginary horizontal line is drawn between
    them. A factor ``Σ E0[hole] - Σ E0[particle]`` is taken over the lines that cross it.
    Hole lines are those labelled ``n``; all other labels are particle lines. A first-order
    diagram has no adjacent pairs, so the denominator is 1.
    """
    E0 = sp.IndexedBase("E^{(0)}")
    denominator_factor = 1

    for upper_dot in range(order - 1):
        lower_dot = upper_dot + 1
        # A line crosses the separator when one endpoint lies at or above `upper_dot` and the other at or below `lower_dot`
        crossing = [lbl for start, end, lbl in labeled_diagram if min(start, end) < lower_dot <= max(start, end)]

        holes = sum(E0[lbl] for lbl in crossing if lbl == n)
        particles = sum(E0[lbl] for lbl in crossing if lbl != n)
        denominator_factor *= holes - particles

    return denominator_factor

def calculate_general_diagram_expression(labeled_diagram: List[Tuple[int, int, int]], order: int, n: int = 1) -> sp.Expr:
    """Return the algebraic expression of a single labelled diagram (before summation).

    Combines the three rules of Szabo and Ostlund (Sec. 6.2):
      1. numerator: product of ⟨ψ_in|V|ψ_out⟩ over the dots;
      2. denominator: product over adjacent dot pairs of Σ E0(hole) − Σ E0(particle);
      3. sign: (−1)^(h + l), with h the number of hole lines and l the number of closed loops.
         l is always 1 for the diagrams considered here (p. 332).
    """
    num_closed_loops = 1
    num_hole_lines = determine_number_of_hole_lines(labeled_diagram)
    sign = -1 if (num_hole_lines + num_closed_loops) % 2 else 1

    numerator = calculate_numerator(labeled_diagram, n)
    denominator = calculate_denominator(labeled_diagram, order, n)
    return sign * numerator / denominator

def calculate_complete_expression(labeled_diagrams: List[List[Tuple[int, int, int]]], order: int, n: int = 1, start: int = 1, end: Union[int, sp.Expr] = sp.oo) -> sp.Expr:
    """Return the full energy correction of one order as a sum over all its diagrams.

    Every particle index of a diagram runs over ``start..end`` excluding the hole level
    ``n``. This is the "m, n, ... ≠ 0" restriction of Szabo and Ostlund, with ``n`` playing
    the role of their ground state. Nested sums are built by wrapping one index at a time.

    Args:
        labeled_diagrams: Output of ``label_diagram`` for every diagram of this order.
        order: Perturbation order.
        n: Hole-line label (energy level under consideration).
        start, end: Summation range of the particle indices.

    Returns:
        The (unevaluated) SymPy expression of the correction.
    """
    if len(labeled_diagrams[0]) == 1:  # First order: no particle lines, hence nothing to sum over
        return calculate_general_diagram_expression(labeled_diagrams[0], order, n)

    def sum_excluding_hole_level(summand: sp.Expr, index: sp.Symbol) -> sp.Expr:
        # Sum over start..n-1 and n+1..end. A piece is multiplied by 0 (dropped) when n coincides with
        # that bound, so the range would be empty.
        return (sp.Sum(summand, (index, start, n - 1)) * int(start != n)
                + sp.Sum(summand, (index, n + 1, end)) * int(end != n))

    output = 0
    for labeled_diagram in labeled_diagrams:
        # Use the particle-line label symbols stored in the diagram itself rather than re-creating
        # them by name. SymPy symbols with the same name but different assumptions are distinct,
        # so re-created symbols may not be the variables that actually occur in the summand.
        particle_labels = [label for dot1, dot2, label in labeled_diagram if dot1 > dot2]

        total_sum = calculate_general_diagram_expression(labeled_diagram, order, n)
        for label in particle_labels:  # p_1 innermost, then p_2, ... (SymPy denests the nested Sums)
            total_sum = sum_excluding_hole_level(total_sum, label)

        output += total_sum

    return output

'''FUNCTIONS FOR EVALUATING THE GENERAL EXPANSIONS FOR A SPECIFIC PERTURBED SYSTEM'''

def evaluate_expression_for_problem(expression: sp.Expr, V_ij: sp.Expr, E0_i: sp.Expr) -> sp.Expr:
    """Insert the problem-specific formulas for the matrix elements and unperturbed energies.

    Each ``V[a, b]`` in ``expression`` is replaced by ``V_ij`` with the symbols ``i`` and
    ``j`` set to ``a`` and ``b``. Each ``E^{(0)}[a]`` is replaced by ``E0_i`` with ``i`` set
    to ``a``. ``xreplace`` is used instead of ``subs`` because, per the original authors,
    ``subs`` did not update the terms inside the sums correctly.
    """
    i_symbol, j_symbol = sp.Symbol("i"), sp.Symbol("j")
    indexed_by_base = sp.sift(list(expression.atoms(sp.Indexed)), lambda indexed: indexed.base)

    matrix_elements = {}
    for element in indexed_by_base[sp.IndexedBase("V")]:
        matrix_elements[element] = V_ij.subs({i_symbol: element.indices[0], j_symbol: element.indices[1]})
    expression = expression.xreplace(matrix_elements)

    energies = {}
    for energy in indexed_by_base[sp.IndexedBase("E^{(0)}")]:
        energies[energy] = E0_i.subs({i_symbol: energy.indices[0]})
    return expression.xreplace(energies)

'''FUNCTIONS FOR EVALUATING THE CASE STUDY OF THE MIDPOINT DIRAC DELTA PERTURBED INFINITE SQUARE WELL WHICH REQUIRED EXTRA PROCESSING'''

def convert_subscript_to_mathematica(sum_idx_str: str) -> str:
    """Convert subscripted names such as ``p_1``, ``p_12`` or ``p_{12}`` to ``Subscript[p, 12]``.

    Mathematica does not understand the underscore subscript notation used by SymPy/LaTeX.
    Both the base name and the subscript may span several letters or digits. The former
    single-character pattern turned ``p_12`` into ``Subscript[p, 1]2``.
    """
    def to_subscript(match: re.Match) -> str:
        base, sub = match.group(1), match.group(2)
        if sub.startswith('{'):
            sub = sub[1:-1]  # Strip the LaTeX grouping braces
        return f'Subscript[{base}, {sub}]'

    # [^\W_] is any word character except '_', so the underscore cleanly separates base and subscript
    return re.sub(r'([^\W_]+)_(\{.*?\}|[^\W_]+)', to_subscript, sum_idx_str)

def convert_mathematica_subscript_to_latex(expr_str: str) -> str:
    """Convert Mathematica ``Subscript[p, x]`` into the LaTeX-style ``p_{x}``.

    Whitespace inside the brackets and around the comma is tolerated, so ``Subscript[p, 1]``
    and ``Subscript[p,1]`` are both converted to ``p_{1}``.
    """
    return re.sub(r'Subscript\[\s*(\w+)\s*,\s*(\w+)\s*\]', r'\1_{\2}', expr_str)

def convert_mathematica_rational(mathematica_str: str) -> str:
    """Replace every ``Rational[a, b]`` by ``(a/b)``.

    The fraction is parenthesised so that it keeps its meaning next to infix operators. For
    example, ``x^Rational[1, 2]`` becomes ``x^(1/2)`` and not ``x^1/2`` (= x/2). Whitespace
    around the integers and the comma is tolerated.
    """
    def replace_rational(match: re.Match) -> str:
        numerator, denominator = map(int, match.groups())
        return f"({numerator}/{denominator})"

    return re.sub(r'Rational\[\s*(-?\d+)\s*,\s*(-?\d+)\s*\]', replace_rational, mathematica_str)

def execute_mathematica_command(session: WolframLanguageSession, command: str, timeout_seconds: int = 60) -> Union[str, bool]:
    """Evaluate ``command`` in ``session`` under a Mathematica-side time limit.

    The command is wrapped in ``TimeConstrained``, because the timeout of ``evaluate()``
    itself is unreliable (https://github.com/WolframResearch/WolframClientForPython/issues/29).

    Returns:
        Whatever ``session.evaluate`` returns. On a timeout, ``TimeConstrained`` yields
        ``$Aborted`` rather than raising. If the evaluation raises an exception, it is printed
        and ``False`` is returned.
    """
    try:
        constrained_command = wlexpr(f'TimeConstrained[{command}, {timeout_seconds}]')
        return session.evaluate(constrained_command)
    except Exception as error:
        print(f"Error executing Mathematica command: {error}")
        return False

def interpret_mathematica_output(output: str) -> Union[sp.Expr, bool]:
    """Convert a (stringified) Wolfram Client result into a SymPy expression.

    The output is pre-processed so that ``parse_mathematica`` can handle it:
      * the ``Global``` context prefix is removed;
      * ``Subscript[p, x]`` becomes ``p_{x}``;
      * ``Rational[a, b]`` is rewritten as a plain fraction, because ``parse_mathematica``
        leaves ``Rational`` in the expression, which later fails when displaying.

    Returns:
        The parsed expression. If parsing raises, ``False`` is returned and the error plus
        the offending string are printed.
    """
    text = output if isinstance(output, str) else str(output)
    text = text.replace('Global`', '')

    if 'Subscript' in text:
        text = convert_mathematica_subscript_to_latex(text)
    if 'Rational' in text:
        text = convert_mathematica_rational(text)

    try:
        return parse_mathematica(text)
    except Exception as e:
        print(f"Error interpreting Mathematica output: {e}")
        print(f"The output attempted to be parsed: {text}")
        return False
    
def find_indices_tuple_list_combination_to_match_reference(tuple_list: List[Tuple[Any]], reference_list: List[Any]) -> Union[List[int], None]:
    """Find tuples in ``tuple_list`` whose combined elements equal ``reference_list`` as a multiset.

    Combinations are tried in order of increasing size, so the smallest matching combination
    (first in ``itertools.combinations`` order) is returned.

    Args:
        tuple_list: Candidate tuples; duplicates are allowed.
        reference_list: Elements that must be covered exactly (order is irrelevant). The
            elements must be mutually orderable, since they are compared after sorting.

    Returns:
        Positions of the matching tuples in ``tuple_list``, or ``None`` if no combination
        matches. Positions are used rather than ``list.index``, so two equal tuples at
        different positions yield two distinct indices.
    """
    # Quick rejection: the tuples together must hold at least as many elements as the reference
    if len(reference_list) > sum(map(len, tuple_list)):
        return None

    target = sorted(reference_list)  # Computed once instead of for every candidate combination
    for r in range(1, len(tuple_list) + 1):
        for index_combi in itertools.combinations(range(len(tuple_list)), r):
            flat_combi = [item for idx in index_combi for item in tuple_list[idx]]
            if sorted(flat_combi) == target:
                return list(index_combi)
    return None

def move_numbers_above_threshold_to_end(input_list: List[Union[float,int]], threshold: Union[float, int]) -> List[Union[float,int]]:
    """Stable partition: values ``< threshold`` first, then values ``>= threshold``.

    Both parts keep their original relative order. Values for which neither comparison
    holds (e.g. NaN) end up in neither part and are therefore dropped.
    """
    smaller = list(filter(lambda value: value < threshold, input_list))
    not_smaller = list(filter(lambda value: value >= threshold, input_list))
    return [*smaller, *not_smaller]

def compute_result_cluster(cluster: List[str], sum_info: List[str], timeout_seconds: int = 60, reverse_switch: bool = False) -> sp.Expr:
    """Sum a cluster of summands (same nested-sum structure) with Mathematica.

    Combinations of summands are tried, grouped by combination size, until the combinations
    that Mathematica managed to sum together cover every summand of the cluster exactly. The
    partial results of that covering set are then added up.

    Args:
        cluster: Summands as (escaped) LaTeX strings.
        sum_info: Summation ranges in Mathematica syntax, e.g. ``"{Subscript[p, 1], 2, Infinity}"``.
        timeout_seconds: Time limit per Mathematica evaluation.
        reverse_switch: Try the largest combinations first. Once some combinations succeed, the
            size that would complement the largest success is moved to the front.

    Returns:
        The summed cluster as a SymPy expression.

    Raises:
        ValueError: If no covering set of combinations could be evaluated (also for an empty cluster).
    """
    failed_combinations_set = set()  # Combinations that failed, so duplicates are not recomputed
    succesful_combinations_list = []  # Successfully summed combinations (kept in lockstep with the results)
    succesful_results_list = []
    lock = threading.Lock()  # Guards the three containers above; shared with process_combination

    done = False  # Defined up front: an empty cluster previously raised UnboundLocalError below

    range_combi_sizes = list(range(1, len(cluster) + 1))
    if reverse_switch:
        range_combi_sizes.reverse()

    while range_combi_sizes:  # The list may be reordered inside the loop, so it is consumed manually
        combi_size = range_combi_sizes[0]
        combinations_list = list(itertools.combinations(cluster, combi_size))

        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(
                process_combination,
                combi,
                cluster,
                sum_info,
                failed_combinations_set,
                succesful_combinations_list,
                succesful_results_list,
                lock,
                timeout_seconds
            ) for combi in combinations_list]

            for future in as_completed(futures):
                done = future.result()
                if done:
                    break

            if done:
                # Leaving the with-block waits for every submitted future. Cancel those that have
                # not started yet, so that no further Mathematica sessions are launched needlessly.
                for future in futures:
                    future.cancel()

        if done:
            break

        range_combi_sizes.remove(combi_size)
        if reverse_switch and succesful_combinations_list:
            # Prioritise the combination size that would complement the largest successful combination
            threshold = len(cluster) - max(map(len, succesful_combinations_list))
            range_combi_sizes = move_numbers_above_threshold_to_end(range_combi_sizes, threshold)
            if threshold in range_combi_sizes:
                range_combi_sizes.insert(0, range_combi_sizes.pop(range_combi_sizes.index(threshold)))

    if not done:
        raise ValueError(f'No combinations of expressions in the cluster yielded a solution within the provided timeout ({timeout_seconds} seconds).')

    solution_indices = find_indices_tuple_list_combination_to_match_reference(succesful_combinations_list, cluster)
    return sum(succesful_results_list[i] for i in solution_indices)

def process_combination(combi: List[str], summands: List[str], sum_info: List[str], failed_combinations_set: Set[Tuple[str]], succesful_combinations_list: List[Tuple[str]],
                        succesful_results_list: List[sp.Expr], lock: threading.Lock, timeout_seconds: int = 60) -> bool:
    """Simplify and sum one combination of summands with Mathematica and record the outcome.

    Steps:
      1. ``Simplify`` the sum of the LaTeX summands in ``combi``. The result is stored in the
         Mathematica variable ``partialresult``, because the Wolfram Client abbreviates large
         outputs, which cannot be parsed back in Python.
      2. Perform the sums described by ``sum_info`` on ``partialresult`` and simplify again.

    A successful combination and its SymPy result are appended, in lockstep, to
    ``succesful_combinations_list`` and ``succesful_results_list``. A failed combination is
    added to ``failed_combinations_set`` so that it is not retried. The whole routine runs
    while holding ``lock``, so concurrent calls are serialised.

    Returns:
        True once the successful combinations together cover ``summands`` exactly, otherwise False.
    """
    with lock:
        if combi in failed_combinations_set:  # Already attempted unsuccessfully
            return False

        if combi in succesful_combinations_list:
            # Reuse the stored result. Previously, the code fell through and recomputed it when
            # this did not complete the solution.
            result = succesful_results_list[succesful_combinations_list.index(combi)]
            succesful_combinations_list.append(combi)
            succesful_results_list.append(result)
            return find_indices_tuple_list_combination_to_match_reference(succesful_combinations_list, summands) is not None

        session = WolframLanguageSession()
        try:
            mathematica_sum = '+'.join(f'ToExpression["{summand}", TeXForm]' for summand in combi)
            simplify_command = f'partialresult = Simplify[{mathematica_sum}]'
            result = execute_mathematica_command(session, simplify_command, timeout_seconds)

            # `False`/`None` signal an exception or Null. Checking identity rather than truthiness
            # keeps a legitimate result of 0.
            if result is False or result is None or any(marker in str(result) for marker in ('$Failed', '$Aborted')):
                failed_combinations_set.add(combi)
                return False

            sum_command = f'Simplify[Sum[partialresult, {", ".join(sum_info)}]]'
            result = execute_mathematica_command(session, sum_command, timeout_seconds)

            # '<<'/'>>' mark an abbreviated output, a sign that the sum has not reduced to a short closed form
            if result is False or result is None or any(marker in str(result) for marker in ('$Failed', '$Aborted', '<<', '>>')):
                failed_combinations_set.add(combi)
                return False

            result = interpret_mathematica_output(str(result))
            if result is False:  # Parsing failed; do not record False as if it were a (zero) result
                failed_combinations_set.add(combi)
                return False

            succesful_combinations_list.append(combi)
            succesful_results_list.append(result)
            return find_indices_tuple_list_combination_to_match_reference(succesful_combinations_list, summands) is not None
        finally:
            session.terminate()  # Always release the kernel, including on exceptions

def simplify_order_expression_using_Mathematica(expression: sp.Expr, timeout_seconds: int = 60, reverse_switch: Union[bool, None] = None) -> sp.Expr:
    """Evaluate the (infinite) sums in an energy-correction expression with Mathematica.

    SymPy cannot perform these sums, so the Wolfram Client Library is used instead.

    * Several summed terms: the summands are grouped (clustered) by their number of nested
      sums. Each cluster is then handled by ``compute_result_cluster``, which combines and
      simplifies summands before the sums are performed.
    * Exactly one summed term (second order): the whole expression is passed to Mathematica
      directly.
    * No sums: the expression is returned as is.

    In every case a final SymPy ``simplify`` is attempted.

    Args:
        expression: Expression as produced by ``evaluate_expression_for_problem``.
        timeout_seconds: Time limit per Mathematica evaluation.
        reverse_switch: Forwarded to ``compute_result_cluster``. When None, reversed order is used
            for every cluster except the ones with the fewest and the most nested sums.

    Returns:
        The summed (and, if possible, simplified) expression.

    Raises:
        ValueError: If Mathematica does not produce a usable result.
    """
    total_result = expression  # Returned as is when no Mathematica calculations are required

    terms = sp.Add.make_args(expression)  # Separate terms so the summands can be simplified before summing
    sum_count = str(terms).count('Sum')

    if sum_count > 1:
        summands = []
        num_sums = []
        corresponding_sum_info = []
        for term in terms:
            args = term.args

            prefactor = 1
            if str(args[1]).startswith('Sum'):  # Term of the form prefactor*Sum(...)
                prefactor = args[0]
                args = args[1].args  # (summand, limits_1, limits_2, ...)

            summand = prefactor*args[0]
            sum_info = list(args[1:])

            # LaTeX with escaped backslashes, to be embedded in a Mathematica string literal
            summands.append(sp.latex(summand).replace('\\', '\\\\'))
            num_sums.append(len(sum_info))
            # Summation limits in Mathematica syntax, e.g. "{Subscript[p, 1], 2, Infinity}"
            corresponding_sum_info.append(tuple(f"{{{sum_idx if '_' not in str(sum_idx) else convert_subscript_to_mathematica(str(sum_idx))}, {low}, {up if 'oo' not in str(up) else 'Infinity'}}}" for sum_idx, low, up in sum_info))

        # Cluster the summands by their number of sums (in increasing order)
        sorted_num_sums, sorted_summands = map(list, zip(*sorted(zip(num_sums, summands))))
        clustered_summands = [[cluster for _, cluster in group] for _, group in itertools.groupby(zip(sorted_num_sums, sorted_summands), key=lambda x: x[0])]
        # NOTE(review): assumes all terms with the same number of sums share identical limits, so that the
        # k-th entry (sorted by number of limits) belongs to the k-th cluster
        unique_sum_info = sorted(set(corresponding_sum_info), key=len)

        total_result = 0
        final_clustered_summands_idx = len(clustered_summands) - 1
        for c_idx, cluster in enumerate(clustered_summands):
            if reverse_switch is None:
                # Try large combinations first, except for the clusters with the fewest and the most
                # nested sums, which (per the original authors) Mathematica can solve directly
                reverse = c_idx not in (0, final_clustered_summands_idx)
            else:
                reverse = reverse_switch

            total_result += compute_result_cluster(cluster, unique_sum_info[c_idx], timeout_seconds, reverse)

    elif sum_count == 1:  # Second order: a single summed term, sent to Mathematica directly
        session = WolframLanguageSession()
        try:
            latex_simplified_sum = sp.latex(expression).replace('\\', '\\\\')
            mathematica_total_sum_command = f'Simplify[ToExpression["{latex_simplified_sum}", TeXForm]]'
            result = execute_mathematica_command(session, mathematica_total_sum_command, timeout_seconds)
        finally:
            session.terminate()  # Previously skipped when the error below was raised

        # Identity checks instead of truthiness, so that a legitimate result of 0 is accepted
        if result is False or result is None or any(x in str(result) for x in ['$Failed', '$Aborted']):
            raise ValueError(f'No combinations of expressions in the cluster yielded a solution within the provided timeout ({timeout_seconds} seconds).')

        total_result = interpret_mathematica_output(result)
        if total_result is False:  # Parsing failed; do not silently return False as the correction
            raise ValueError('The Mathematica result could not be converted into a SymPy expression.')

    # Try to simplify the result even further using SymPy
    try:
        return total_result.simplify()
    except Exception:
        return total_result
    
##################################################################
# MAIN
##################################################################
if __name__ == '__main__':

    # Settings
    order = 4 # Perturbation order
    n = 1 # The energy level to consider
    precision = 50 # Precision (number of significant digits) for the numerical evaluation of the final result
    timeout_seconds = 20 # Number of seconds used for timeout for Mathematica evaluations

    # Problem definition (units with hbar = m = 1 and unit delta-function strength)
    a = sp.Integer(1) # Particle in a box size
    i, j = sp.symbols("i j") # Plain symbols: evaluate_expression_for_problem substitutes sp.Symbol("i") / sp.Symbol("j")
    V_ij = 2/a*sp.sin(i*sp.pi/2)*sp.sin(j*sp.pi/2) # <psi_i|delta(x - a/2)|psi_j> = psi_i(a/2) psi_j(a/2) with psi_k(x) = sqrt(2/a) sin(k pi x / a)
    E0_i = (i*sp.pi)**2 / (2*a**2) # Unperturbed eigenenergies (i pi)^2 / (2 a^2); the former "/ (2*a)" was only correct for a = 1

    # Performing computations
    E_symbolic = E0_i.subs({i: n}) # Perturbed energy, initialized with the zeroth-order term
    latex_expressions = np.empty(order+1, dtype=object)
    latex_expressions[0] = sp.latex(E_symbolic)
    for o in range(1, order+1):
        print('\n\n----------------------------------------------------------------')
        print(f"Expression for order {o}:")
        print('----------------------------------------------------------------\n')

        # Generate diagrams for order o
        labeled_diagrams, unlabeled_diagrams = generate_diagrams(order=o, n=n)

        print('Order\tNum. diagrams\tDiagram representation')
        print(o, '\t', len(unlabeled_diagrams), '\t\t', labeled_diagrams, end='')

        print('\n\n\tLabeled diagrams:')
        for labeled_diagram in labeled_diagrams:
            print('\n\t', labeled_diagram)

        # Compute full general expression
        expression = calculate_complete_expression(labeled_diagrams, order=o, n=n)
        print("\nGeneral expression:")
        display(expression)

        # Apply this to the case study
        expression = evaluate_expression_for_problem(expression, V_ij=V_ij, E0_i=E0_i)
        expression = simplify_order_expression_using_Mathematica(expression, timeout_seconds)
        print(f'\nResult for order {o}:') # Previously printed the constant 0 instead of the order
        display(expression)
        latex_expressions[o] = sp.latex(expression)
        print(f'Latex form:\n{latex_expressions[o]}')
        E_symbolic += expression

    print('\n\n----------------------------------------------------------------')
    print(f'FINAL PERTURBED ENERGY UP TO ORDER {order}:')
    print('----------------------------------------------------------------\n')
    display(E_symbolic)
    print(f'Latex form:\n{"+".join(latex_expressions)}')
    print(f'\nNumerical value ({precision} significant digits):\n{sp.N(E_symbolic, precision)}') # Uses the `precision` setting, which was previously unused

    Beep(300, int(1.5E3)) # Make beep sound when done (winsound: Windows only)