# Ising model simulation with the Metropolis Monte Carlo method

In [None]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

In [None]:
# Initialize the JAX PRNG key. Only `jax` is imported, so the random
# module must be reached as `jax.random`.
key = jax.random.key(183361063)

In [None]:
def random_spin(key, size=()):
    '''
    Return a fresh PRNG key and an array of random spins.

    Parameters
    ----------
    key : JAX PRNG key
        Split internally; use the returned key for subsequent draws.
    size : tuple of int
        Shape of the returned spin array (default: scalar).

    Returns
    -------
    (key, spins) : the new PRNG key and a float array of shape `size`
        whose entries are +1 or -1, each with probability 1/2.
    '''
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, size)
    # Map U[0, 1) to -1 below 0.5 and +1 otherwise. jnp.sign(u - 0.5) would
    # produce 0 (an invalid spin) whenever u is exactly 0.5.
    return key, jnp.where(u < 0.5, -1.0, 1.0)

In [None]:
def delta_E(S, i, j, J=1, h=0):
    '''
    Calculate the energy change caused by flipping spin (i, j).

    With H = -J * sum_<ab> s_a s_b - h * sum_a s_a over nearest neighbours,
    flipping s -> -s changes the energy by 2 * s * (J * sum_nb + h).
    The lattice uses periodic (toroidal) boundary conditions: the rows wrap
    around, and each row wraps around onto itself.

    Parameters
    ----------
    S : 2-D array of spins, shape (Nr, Nc)
    i, j : row and column index of the spin to flip (int or JAX int scalar)
    J : coupling constant
    h : external field

    Returns
    -------
    The energy difference E_after - E_before as a scalar.
    '''
    Nr, Nc = S.shape
    s = S[i, j]

    # Wrap row and column indices separately. Wrapping a flattened index
    # instead would link the last column to the *next* row (helical
    # boundaries) rather than to the first column of the same row.
    nb_sum = (S[(i - 1) % Nr, j] + S[(i + 1) % Nr, j]
              + S[i, (j - 1) % Nc] + S[i, (j + 1) % Nc])

    return 2 * s * (J * nb_sum + h)

In [None]:
def mcmc_step(key, S, beta=0.01, J=1, h=0):
    '''
    Perform a single-spin Metropolis update of the lattice S.

    A site (i, j) is chosen uniformly at random. Its spin is flipped with
    probability min(1, exp(-beta * dE)), where dE is the energy change of the
    flip as returned by `delta_E`. Flips that lower the energy (dE < 0) or
    leave it unchanged (dE == 0) are therefore always accepted.

    Parameters
    ----------
    key : JAX PRNG key
        Consumed by this call; use the returned key for the next step.
    S : 2-D array of spins (expected to hold +1/-1 values)
    beta : inverse temperature used in the acceptance factor exp(-beta * dE)
    J : coupling constant, forwarded to `delta_E`
    h : external field, forwarded to `delta_E`

    Returns
    -------
    (key, S) : a fresh PRNG key and the updated lattice.
    '''
    # Split once into independent streams: one carried forward to the caller,
    # and one each for the row, column and acceptance draws. Sharing a key
    # between the two randint calls would make i == j on square lattices.
    key, key_i, key_j, key_acc = jax.random.split(key, 4)
    i = jax.random.randint(key_i, shape=(), minval=0, maxval=S.shape[0])
    j = jax.random.randint(key_j, shape=(), minval=0, maxval=S.shape[1])

    # Energy difference on flipping spin (i, j)
    dE = delta_E(S, i, j, J, h)

    # Metropolis criterion: accept if u < exp(-beta * dE), u ~ U[0, 1).
    # For dE <= 0 the right-hand side is >= 1, so the flip is always taken.
    accept = jax.random.uniform(key_acc) < jnp.exp(-beta * dE)

    # jnp.where instead of a Python `if` avoids branching on a JAX value,
    # so the step can be wrapped in jax.jit.
    S = S.at[i, j].set(jnp.where(accept, -S[i, j], S[i, j]))

    return key, S

In [None]:
num_steps = 1000
# Random initial configuration of +1/-1 spins on a 25x25 lattice.
key, S = random_spin(key, size=(25, 25))

# Thread the PRNG key through every step so no key is reused.
for _ in range(num_steps):
    key, S = mcmc_step(key, S)