# In [2]:
import numpy as np
from numba import jit

# In [None]:

@jit(fastmath=True, nopython=True, cache=True)
def get_possible_transformations(S):
    """
    Return a numpy array of all possible transformations for a given square S.

    A transformation is a length-4 increment X that is added to the edge
    values S of a square. The edge order is top, right, bottom, left, as
    labelled by the single-edge moves below.

    A move from the fixed candidate list is kept exactly when S + X stays
    non-negative in every component. Every candidate's inverse -X is also
    either a candidate or one of the always-applicable moves, and
    (S + X) + (-X) = S >= 0. So any move offered here can be undone from the
    state it produces. That property matters because the count of offered
    moves is used as the proposal correction M / M_prime by the acceptance
    computation.

    Args:
        S: Length-4 array of edge values. Assumed to be non-negative integer
            counts; TODO confirm against the caller.

    Returns:
        An integer array of shape (K, 4). It holds the feasible candidate
        moves, in the candidate order below, followed by the five
        always-applicable additive moves.
    """
    transformations = np.array([
        [-1, -1, -1, -1],  # uniform
        [-2, 0, 0, 0],  # single top
        [0, -2, 0, 0],  # single right
        [0, 0, -2, 0],  # single bottom
        [0, 0, 0, -2],  # single left
        [1, -1, 1, -1], [-1, 1, -1, 1],  # swap opposite
        [1, 1, -1, -1], [-1, -1, 1, 1],  # swap adjacent
        [-1, 1, 1, -1], [1, -1, -1, 1]   # swap diagonal
    ])

    num_moves = transformations.shape[0]
    mask = np.ones(num_moves, dtype=np.bool_)
    for k in range(num_moves):
        for i in range(4):
            # A move is illegal as soon as it would drive any edge below zero.
            # Checking each edge individually also covers the last swap row,
            # which a fixed slice of the mask previously missed.
            if S[i] + transformations[k, i] < 0:
                mask[k] = False
                break

    # Purely additive moves can never make an edge negative.
    always_applicable = np.array([
        [1, 1, 1, 1], [2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]
    ])

    return np.vstack((transformations[mask], always_applicable))

@jit(fastmath=True, nopython=True, cache=True)
def get_local_time(grid, num_colors, x, y):
    """
    Calculate the local time at position (x, y), summed over all colors.

    Args:
        grid: 4-D array indexed as grid[color, x, y, direction]. Only
            direction 0 and 1 are read here. Presumably these hold link
            occupation numbers; confirm with the grid's producer.
        num_colors: Number of colors to sum over (colors 0 .. num_colors-1).
        x: The x-coordinate (second axis of grid).
        y: The y-coordinate (third axis of grid).

    Returns:
        The sum over colors of grid[c, x, y, 0] + grid[c, x, y, 1]
        + grid[c, x, y + 1, 1] + grid[c, x + 1, y, 0], floor-divided by 2.
        The halving is applied once to the colour total, not per colour.
    """
    accumulated = 0
    for color in range(num_colors):
        accumulated += (
            grid[color, x, y, 0]
            + grid[color, x, y, 1]
            + grid[color, x, y + 1, 1]
            + grid[color, x + 1, y, 0]
        )
    return accumulated // 2

@jit(fastmath=True, nopython=True, cache=True)
def get_local_time_i(grid, c, x, y):
    """
    Calculate the local time at position (x, y) for a single color c.

    Args:
        grid: 4-D array indexed as grid[color, x, y, direction]. Only
            direction 0 and 1 are read here.
        c: The color whose local time is wanted.
        x: The x-coordinate (second axis of grid).
        y: The y-coordinate (third axis of grid).

    Returns:
        (grid[c, x, y, 0] + grid[c, x, y, 1] + grid[c, x, y + 1, 1]
        + grid[c, x + 1, y, 0]) floor-divided by 2.
    """
    edge_total = (
        grid[c, x, y, 0]
        + grid[c, x, y, 1]
        + grid[c, x, y + 1, 1]
        + grid[c, x + 1, y, 0]
    )
    return edge_total // 2
                       
@jit(fastmath=True, nopython=True, cache=True)
def acceptance_prob_optimized(S, M, s, X, c, beta, num_colors, algo, grid):
    """
    Calculate the acceptance probability for a transformation X on a square s of color c.

    For each known move the factor A is computed from a closed-form formula.
    Every move and its inverse use mutually reciprocal formulas, so
    A(S -> S + X) * A(S + X -> S) == 1. A therefore plays the role of a
    weight ratio W(S + X) / W(S); this is an interpretation, so confirm
    against the model's derivation. A is then combined with the proposal
    correction M / M_prime.

    Args:
        S: Current edge values of square s, ordered top, right, bottom, left
            (the order of get_possible_transformations).
        M: The number of possible transformations for the current state S.
        s: Coordinates (s[0], s[1]) of the square to be transformed.
        X: The transformation (length-4 increment) to be applied to S.
        c: The color of the square to be transformed.
        beta: The parameter beta of the simulation.
        num_colors: Number of colors in the grid.
        algo: 'metropolis' selects min(1, M/M_prime * A). Any other value
            selects the heat-bath rule M*A / (M*A + M_prime).
        grid: The grid state, read through get_local_time/get_local_time_i.

    Returns:
        The acceptance probability. It is 0 when X is not one of the
        recognised moves.
    """
    S_p = S + np.array(X)
    M_prime = len(get_possible_transformations(S_p))
    A = 0.0
    num_colors_half = num_colors / 2

    x = s[0]
    y = s[1]
    # Local-time factors at the four positions used by the formulas. Suffix
    # digits mark a -1 offset in x (first digit) and y (second digit):
    #   00 -> (x, y), 01 -> (x, y-1), 10 -> (x-1, y), 11 -> (x-1, y-1).
    # Edge/position pairing used below: top -> 00,10; right -> 00,01;
    # bottom -> 11,01; left -> 10,11.
    n00 = num_colors_half + get_local_time(grid, num_colors, x, y)
    n01 = num_colors_half + get_local_time(grid, num_colors, x, y - 1)
    n10 = num_colors_half + get_local_time(grid, num_colors, x - 1, y)
    n11 = num_colors_half + get_local_time(grid, num_colors, x - 1, y - 1)
    t00 = get_local_time_i(grid, c, x, y)
    t01 = get_local_time_i(grid, c, x, y - 1)
    t10 = get_local_time_i(grid, c, x - 1, y)
    t11 = get_local_time_i(grid, c, x - 1, y - 1)

    if np.array_equal(X, [1, 1, 1, 1]):
        A = beta**4 / (16 * S_p[0] * S_p[1] * S_p[2] * S_p[3] * n00 * n01 * n10 * n11) * \
            (2 * t00 + 1) * (2 * t01 + 1) * (2 * t11 + 1) * (2 * t10 + 1)
    elif np.array_equal(X, [-1, -1, -1, -1]):
        A = (16 / (beta**4)) * S[0] * S[1] * S[2] * S[3] * \
            (n00 - 1) * (n01 - 1) * (n10 - 1) * (n11 - 1) / \
            ((2 * t00 - 1) * (2 * t01 - 1) * (2 * t11 - 1) * (2 * t10 - 1))
    # Single-edge moves. A removal divides by (2t - 1) at BOTH endpoint
    # positions, mirroring the two (2t + 1) factors of the matching addition.
    # Without the "- 1" on the second factor, forward * reverse != 1.
    elif np.array_equal(X, [2, 0, 0, 0]):
        A = beta**2 / (4 * S_p[0] * (S[0] + 1) * n00 * n10) * \
            (2 * t00 + 1) * (2 * t10 + 1)
    elif np.array_equal(X, [-2, 0, 0, 0]):
        A = 4 * S[0] * (S[0] - 1) / (beta**2) * (n00 - 1) * (n10 - 1) / \
            ((2 * t00 - 1) * (2 * t10 - 1))
    elif np.array_equal(X, [0, 2, 0, 0]):
        A = beta**2 / (4 * S_p[1] * (S[1] + 1) * n00 * n01) * \
            (2 * t00 + 1) * (2 * t01 + 1)
    elif np.array_equal(X, [0, -2, 0, 0]):
        A = 4 * S[1] * (S[1] - 1) / (beta**2) * (n00 - 1) * (n01 - 1) / \
            ((2 * t00 - 1) * (2 * t01 - 1))
    elif np.array_equal(X, [0, 0, 2, 0]):
        A = beta**2 / (4 * S_p[2] * (S[2] + 1) * n11 * n01) * \
            (2 * t11 + 1) * (2 * t01 + 1)
    elif np.array_equal(X, [0, 0, -2, 0]):
        A = 4 * S[2] * (S[2] - 1) / (beta**2) * (n11 - 1) * (n01 - 1) / \
            ((2 * t11 - 1) * (2 * t01 - 1))
    elif np.array_equal(X, [0, 0, 0, 2]):
        A = beta**2 / (4 * S_p[3] * (S[3] + 1) * n10 * n11) * \
            (2 * t11 + 1) * (2 * t10 + 1)
    elif np.array_equal(X, [0, 0, 0, -2]):
        A = 4 * S[3] * (S[3] - 1) / (beta**2) * (n10 - 1) * (n11 - 1) / \
            ((2 * t11 - 1) * (2 * t10 - 1))
    # Opposite-edge swaps leave every local-time factor unchanged.
    elif np.array_equal(X, [-1, 1, -1, 1]):
        A = S[0] * S[2] / (S_p[1] * S_p[3])
    elif np.array_equal(X, [1, -1, 1, -1]):
        A = S[1] * S[3] / (S_p[0] * S_p[2])
    # Remaining swaps move one unit of local time between the position shared
    # by the two decremented edges and the one shared by the incremented edges.
    elif np.array_equal(X, [-1, -1, 1, 1]):
        A = S[0] * S[1] / (S_p[2] * S_p[3]) * (n00 - 1) / n11 * \
            (2 * t11 + 1) / (2 * t00 - 1)
    elif np.array_equal(X, [1, 1, -1, -1]):
        A = S[2] * S[3] / (S_p[0] * S_p[1]) * (n11 - 1) / n00 * \
            (2 * t00 + 1) / (2 * t11 - 1)
    elif np.array_equal(X, [-1, 1, 1, -1]):
        A = S[0] * S[3] / (S_p[1] * S_p[2]) * (n10 - 1) / n01 * \
            (2 * t01 + 1) / (2 * t10 - 1)
    elif np.array_equal(X, [1, -1, -1, 1]):
        A = S[2] * S[1] / (S_p[3] * S_p[0]) * (n01 - 1) / n10 * \
            (2 * t10 + 1) / (2 * t01 - 1)

    if algo == 'metropolis':
        return min(1, M / M_prime * A)
    # Heat bath: M*A / (M*A + M') equals 1 / (1 + M'/(M*A)) whenever A != 0.
    # It also yields 0 instead of raising ZeroDivisionError when A == 0.
    # M_prime >= 5 (the always-applicable moves), so the denominator is
    # positive for A >= 0.
    weighted = M * A
    return weighted / (weighted + M_prime)

