In [1]:
import numpy as np
from scipy import sparse
import typing

def make_data(frustrated:bool=True):
    # Set the random seed
    np.random.seed(0)
    n = 500
    p = 0.5

    # Calculate the number of non-zero entries in the upper triangle (including diagonal)
    num_nonzero = int(p * n * (n + 1) / 2)

    # Get indices for the upper triangle (including diagonal)
    upper_triangle_indices = np.triu_indices(n)
    total_upper_entries = len(upper_triangle_indices[0])

    # Randomly select indices for non-zero entries
    selected_indices = np.random.choice(total_upper_entries, size=num_nonzero, replace=False)
    i_indices = upper_triangle_indices[0][selected_indices]
    j_indices = upper_triangle_indices[1][selected_indices]

    # Generate random data between -1 and 1 for the non-zero entries
    data = np.random.uniform(-1, 1, size=num_nonzero)

    # Create a sparse matrix for the upper triangle
    w_upper = sparse.coo_matrix((data, (i_indices, j_indices)), shape=(n, n))

    # Symmetrize the matrix to make it symmetric
    w = w_upper + w_upper.T - sparse.diags(w_upper.diagonal())

    if frustrated:
        #  Apply (w > 0) - (w < 0) to define a frustrated system
        w = (w > 0).astype(int) - (w < 0).astype(int)
    else:
        # this choice defines a ferro-magnetic (easy) system
        w = ( w > 0).astype(int)

    # Remove diagonal entries by setting them to zero
    w.setdiag(0)
    return w

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import time
import jax
import csv
import hashlib
import numpy as np
import jax.numpy as jnp
import jax.random as jr
from jax import jit
from jax import Array
from functools import partial
from tqdm import tqdm


# return neighbour set
@partial(jit, static_argnums = 2)
def get_neighbour(key, state, R):

    positions = jr.choice(key, len(state), shape=(R,), replace=False)

    def flip_one(state, pos):
        return state.at[pos].set(-state[pos])

    new_state = jax.lax.fori_loop(0, R,
                                  lambda i, s: flip_one(s, positions[i]),
                                  state)

    return new_state

@partial(jit, static_argnums = 1)
def reset(key, n:int):
    """
    samples a starting state of shape (n,), where all values of the state are -1 or 1
    """
    unifs = jr.normal(key, (n,))
    state = unifs>0.5
    state = 2*state - 1
    return state

@jax.jit
def energy_calc(state, data):
    energy = - 1/2 *state.T@data@state
    return energy

@jax.jit
def std_dev(mean, data):
    deviations = (data - mean)**2
    var = jnp.mean(deviations)

    return jnp.sqrt(var)

def hash_state(state):
    "converts a state to bytes and hashes it"

    return hashlib.sha256(state.tobytes()).hexdigest()

@jax.jit
def track_min_energy_state(curr_min_e, curr_min_state, new_energies, new_states):

    batch_min_idx = jnp.argmin(new_energies)
    batch_min_e   = new_energies[batch_min_idx]
    batch_min_state = new_states[batch_min_idx]

    min_e = jnp.where(batch_min_e < curr_min_e,
                      batch_min_e, curr_min_e)
    min_state = jnp.where(batch_min_e < curr_min_e,
                          batch_min_state, curr_min_state)

    return min_e, min_state

@partial(jit, static_argnums = [2, 3])
def iterative_improvement(key, data, vec_len, R, num_iterations):
    state = reset(key, vec_len)

    def body_fun(i, carry):
        state, key = carry
        key, subkey = jr.split(key)
        current_state = state
        neighbour_state = get_neighbour(subkey, state, R)

        neighbour_energy = energy_calc(neighbour_state, data)
        current_energy   = energy_calc(current_state, data)

        state = jnp.where(neighbour_energy < current_energy,
                          neighbour_state, current_state)
        return (state, key)

    final_state, _ = jax.lax.fori_loop(0, num_iterations-1, body_fun, (state,key))
    return final_state

@partial(jit, static_argnums = [1, 2, 4, 5])
def K_num_restarts(key, K, R, data, vec_len, num_iterations):
    keys = jr.split(key, K)

    def single_run(run_key):
        state = iterative_improvement(run_key, data, vec_len, R, num_iterations)
        return energy_calc(state, data), state

    energies, states = jax.vmap(single_run)(keys)

    best_idx = jnp.argmin(energies)
    return energies[best_idx], states[best_idx]

@partial(jit, static_argnums = [1, 2, 3, 5, 6])
def N_num_runs(key, num_runs, K, R, data, vec_len, num_iterations):
    keys = jr.split(key, num_runs)

    run_fn = lambda skey: K_num_restarts(skey, K, R, data, vec_len, num_iterations)
    energies, states = jax.vmap(run_fn)(keys)

    return energies, states

def threshold(mean, std_dev, eps):
    return std_dev < jnp.abs(mean)*eps

def main():
    # save = False
    save = True
    vec_len = 500
    num_iterations = 100
    num_runs = 20
    # neighbourhood
    R = 1

    # rng key
    key = jr.PRNGKey(35)

    # a) for both ferromagnetic and frustrated, find number of restarts K needed to obtain reproducible results
    max_K = 5000
    increment_K = 500
    eps = 0.01 # 1%

    # ferro-magnetic
    data = jnp.array(make_data(frustrated=False).toarray())

    print(f"Starting a) for ferro-magnetic system, with N = {num_runs} runs, max K = {max_K}, and K increment = {increment_K}")
    for K in tqdm(range(50, max_K, increment_K)):
        key, subkey = jr.split(key)
        start = time.time()
        energies, states = N_num_runs(subkey, num_runs, K, R, data, vec_len, num_iterations)
        end = time.time()
        runtime = end-start
        runtime /= num_runs

        mean_E = jnp.mean(energies)
        SD_E   = std_dev(mean_E, energies)

        if threshold(mean_E, SD_E, eps):
            print(f"Reproduced state at K={K}")
            print(f"Energy of the final solution was: {mean_E}")
            print(f"Standard Deviation of the final solution was: {SD_E}")

    # frustrated
    data = jnp.array(make_data().toarray())

    print(f"Starting a) for frustrated system, with N = {num_runs} runs, max K = {max_K}, and K increment = {increment_K}")
    for K in tqdm(range(50, max_K, increment_K)):
        key, subkey = jr.split(key)
        start = time.time()
        energies, states = N_num_runs(subkey, num_runs, K, R, data, vec_len, num_iterations)
        end = time.time()
        runtime = end-start
        runtime /= num_runs

        mean_E = jnp.mean(energies)
        SD_E   = std_dev(mean_E, energies)

        if threshold(mean_E, SD_E, eps):
            print(f"Reproduced state at K={K}")
            print(f"Energy of the final solution was: {mean_E}")
            print(f"Standard Deviation of the final solution was: {SD_E}")

    # b) for frustrated
    results = []

    # make frustrated data
    data = jnp.array(np.loadtxt("w500"))

    Ks = [20, 100, 200, 500, 1000, 2000, 4000]
    # Ks = [20, 100, 200, 500]
    # Ks = [20]

    min_e = 0
    min_state = 0

    for K in Ks:
        key, subkey = jr.split(key)
        start = time.time()
        energies, states = N_num_runs(subkey, num_runs, K, R, data, vec_len, num_iterations)
        end = time.time()
        runtime = end-start
        runtime /= num_runs

        mean_E = jnp.mean(energies)
        SD_E   = std_dev(mean_E, energies)

        results.append([K, runtime, mean_E, SD_E])

        min_e, min_state = track_min_energy_state(min_e, min_state,
                                                  energies, states)

    if save:
        file_name = "/content/drive/MyDrive/iter_improvement.csv"
        with open(file_name, "w", newline="") as csvfile:
            csv_writer = csv.writer(csvfile)

            csv_writer.writerows(results)

    jax.debug.print("min_e: {e}", e=min_e)
    return results

if __name__ == "__main__":
    results = main()
    print(results)

Starting a) for ferro-magnetic system, with N = 20 runs, max K = 5000, and K increment = 500


100%|██████████| 10/10 [1:51:26<00:00, 668.68s/it] 


Starting a) for frustrated system, with N = 20 runs, max K = 5000, and K increment = 500


100%|██████████| 10/10 [1:50:33<00:00, 663.38s/it] 


min_e: -2346.0
[[20, 0.22509949207305907, Array(-1672.9, dtype=float32), Array(105.0742, dtype=float32)], [100, 0.2656669855117798, Array(-1838.8, dtype=float32), Array(86.942276, dtype=float32)], [200, 0.22720940113067628, Array(-1884.8, dtype=float32), Array(73.37411, dtype=float32)], [500, 0.3400750160217285, Array(-1976.2001, dtype=float32), Array(84.76533, dtype=float32)], [1000, 0.2188255786895752, Array(-1985., dtype=float32), Array(46.52956, dtype=float32)], [2000, 0.2293645977973938, Array(-2053.6, dtype=float32), Array(78.8685, dtype=float32)], [4000, 0.34446561336517334, Array(-2123.8, dtype=float32), Array(65.97242, dtype=float32)]]
