# **Learning and Inference in a Lattice Model of Multicomponent Condensates**

Author: Cameron Chalk

This notebook recreates the paper's experiments using new backend code with improved speed and readability. The notebook itself is terse; the paper (https://doi.org/10.4230/LIPIcs.DNA.30.1) should provide the necessary explanation. The parameters used for the results shown in the paper are included and can be tested further.

**USAGE:** Import boltzmann_liquids_cuda.py from the GitHub code folder directly into the Colab files directory (or into the same directory as this .ipynb if you are running locally). Run the Initialization section first. Then run each other section as a whole, one section at a time. The sections are not independent of one another because they share variable names; for example, running the Hopfield condensation section after the Avocado section will overwrite some of your Avocado variables.

**DATA:** If you wish to use the paper-reported $G_{i,j}$, $G_i$, or the pre-processed MNIST data, import the data files from the GitHub repository directly (as individual files, not as a folder) into the Colab files directory (or into the same directory as this .ipynb if you are running locally).

**IMPORTANT: DOCUMENTATION FOR THE MAIN MCMC FUNCTION:** Most of the notebook is standard Python/numpy/matplotlib code. The mcmc() function from boltzmann_liquids_cuda.py is essentially the only function used that is not, so for convenience I paste its documentation below. It is important to read it carefully:

    GPU-accelerated Metropolis-Hastings MCMC for the 3D lattice Boltzmann liquid model.

        This is a wrapper around several Numba-CUDA kernels. Each lattice in `lattice_batch`
        is simulated independently, with one CUDA block per lattice.
        Arguments called "*_batch", if passed as a list of arguments of length equal
        to the number of lattices given by lattice_batch,
        will be used in each lattice (i.e., if a list of g_ij matrices are given,
        then the nth lattice in the batch will use the nth g_ij matrix).
        All arguments called "*_batch" can also be passed in as singletons
        (i.e., a single lattice and a single g_ij; or, a list of lattices
        can be passed in along with a single g_ij matrix and each lattice
        will use the single g_ij matrix.)
        Within a kernel launch,
        the block runs many proposal steps without returning to the CPU. After burn-in, the
        wrapper advances the chain between samples, copies the lattice(s) back to host, and
        accumulates summary statistics (n_ij, n_i, energies).
        A save_lattices option allows the return of all sampled lattice configurations,
        instead of (by default) the single endpoint lattice configuration(s).

        Model and energy
        ----------------
        - Lattice sites store integer species IDs in [0, num_species-1].
        - Pairwise interaction energies are given by `g_ij` (symmetric).
        - Optional per-species energies are given by `g_i`.
        - The Metropolis accept probability uses:  min(1, exp(-ΔE / kT)).

        Update modes (kernel choice)
        ----------------------------
        Exactly one mode is used per call:
        - Grand-canonical (default): single-site resampling (site chooses a random new species).
        Uses a **checkerboard** update to avoid simultaneously updating interacting neighbors.
        - Canonical local swaps (canonical=True): swap species between two sites.
        Uses a **mod-4 sublattice** scheme to keep simultaneous swaps disjoint.
        - Canonical ranged swaps (canonical=True, ranged=True): long-range swaps between
        mod-4 sublattices plus a small displacement from `swap_directions`.
        - Hybrid (hybrid=True): randomly mixes GC steps and ranged-swap steps *inside the kernel*.

        Important: if `canonical=True`, it takes precedence over `hybrid=True` (hybrid is ignored).

        Temperature / annealing
        -----------------------
        Burn-in is run as a single kernel launch of `burn_in` proposals with geometric annealing:
        kT : kT_high_batch -> kT_low_batch over `burn_in` steps (if different).
        During sampling, each inter-sample advance runs at constant temperature kT_low_batch
        (the wrapper passes kT_high = kT_low for those kernel calls).

        Parameters
        ----------
        lattice_batch : array-like (int32)
            Either a single lattice of shape (H, W, L) or a batch of lattices of shape
            (B, H, W, L). All lattices are assumed to have the same shape.
            H, W, and L must each be multiples of 4.

        num_parallel_proposals : int
            A budget used to compute the inter-sample spacing:
                sample_rate = (num_parallel_proposals - burn_in) // num_samples
            The actual number of proposals performed is:
                burn_in + num_samples * sample_rate
            (This is <= num_parallel_proposals due to floor division.)

        num_samples : int
            Number of samples to collect after burn-in.

        burn_in : int
            Number of proposal steps during burn-in (annealed from kT_high to kT_low if provided).

        num_species : int
            Number of species. Lattice entries must be valid indices: 0 <= s < num_species.

        g_ij_batch : array-like (float)
            Pairwise interaction energies. Accepts shape (num_species, num_species) or
            (B, num_species, num_species). For correctness with the swap kernels, `g_ij`
            is expected to represent symmetric interactions (g_ij[i,j] == g_ij[j,i]).

        g_i_batch : array-like (float), optional
            Per-species energies. Shape (num_species,) or (B, num_species). If omitted, zeros.
            Used by GC updates (and for reported energies). In canonical modes, counts are fixed,
            so `g_i` does not affect acceptance, but it is still included in configuration energy.

        kT_batch : float or array-like (float), optional
            Convenience temperature input. If provided and kT_high_batch/kT_low_batch are not,
            the wrapper sets kT_high = kT_low = kT_batch (i.e., constant temperature).

        kT_high_batch, kT_low_batch : float or array-like (float), optional
            Burn-in annealing endpoints. Each may be scalar or length-B. If neither is provided,
            both default to 1.0.

        free_pos_batch : array-like (bool/int), optional
            Mask of free (updatable) lattice positions. True/1 = free, False/0 = clamped.
            Accepts shape (H, W, L) or (B, H, W, L). If omitted, all positions are free.

        free_species_batch : array-like (int), optional
            Mask of free species IDs, used only in GC (and GC part of hybrid):
            proposals that would change to or from a non-free species are suppressed.
            Shape (num_species,) or (B, num_species). If omitted, all species are free.

        periodic : bool
            Boundary condition for local energy evaluation and neighbor addressing.
            - In the non-ranged canonical case, False here does prevent periodic swap proposals.

        canonical : bool
            If True, use canonical swap updates instead of GC resampling.

        ranged : bool
            Only used if canonical=True. If True, use ranged-swap kernel; else local-swap kernel.

        hybrid : bool
            If True (and canonical=False), use the hybrid kernel (randomly mixes GC and ranged steps).

        neighbor_displacements : tuple/list of (dx, dy, dz)
            Defines the neighborhood used in energy calculations.
            Constraints for correctness / detailed balance with the parallel update schemes:
            - GC mode: must be von Neumann (DEFAULT_NEIGHBOR_DISPLACEMENTS). Using Moore
            or larger neighborhoods causes race conditions which (probably) break detailed balance.
            - Canonical modes: must be a subset of the Moore neighborhood (dx,dy,dz ∈ {-1,0,1})
            to prevent race conditions. In short, for an extended neighborhood, proposals must be
            changed to a mod-k sublattice scheme (it is currently hardcoded mod-4).

        swap_directions : tuple/list of (dx, dy, dz)
            Possible displacements for swap proposals in canonical kernels.
            Must be a subset of the Moore neighborhood (dx,dy,dz ∈ {-1,0,1}) to prevent
            races/overlapping swaps under the mod-4 sublattice scheme.

        seed : int, optional
            RNG seed for xoroshiro128p states (ignored if `rng_states` is provided).

        save_lattices : bool
            If True, returns a list of sampled lattice configurations (host copies).

        rng_states : numba.cuda.random.Xoroshiro128pStates, optional
            Pre-initialized RNG states. Must have length:
                THREADS_PER_BLOCK_1D * num_lattices
            (One RNG stream per thread across all blocks.)

        Returns
        -------
        If save_lattices is False:
            lattice_final, avg_n_ij, avg_n_i, energies_list

        If save_lattices is True:
            saved_lattices, avg_n_ij, avg_n_i, energies_list

        Where:
        - lattice_final:
            Final lattice configuration(s), shape (H,W,L) if B=1 else (B,H,W,L).
            (Note: `saved_lattices` entries are not squeezed even when B=1.)
        - avg_n_ij:
            Sample-mean interaction counts. Shape (num_species,num_species) if B=1 else
            (B,num_species,num_species).
        - avg_n_i:
            Sample-mean species counts. Shape (num_species,) if B=1 else (B,num_species).
        - energies_list:
            Energies for bookkeeping, shape (num_samples+1,) if B=1 else (B,num_samples+1).
            energies_list[..., 0] is computed from the **initial** configuration (pre-burn-in);
            energies during burn-in are not recorded.

        Critical reminders
        ------------------------------
        - Canonical modes require lattice dimensions H, W, L to be multiples of 4.
        (The mod-4 offset scheme can otherwise index beyond lattice bounds.)
        - Ensure `num_parallel_proposals > burn_in` and that
            (num_parallel_proposals - burn_in) // num_samples >= 1
        or else sample_rate may be 0 (yielding repeated identical samples).
        - All lattices in the batch must share the same shape.
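
The scheduling arithmetic described under `num_parallel_proposals` and `burn_in` can be checked on the CPU before launching any kernels. The sketch below uses hypothetical helpers (`proposal_schedule` and `geometric_kT` are not part of boltzmann_liquids_cuda.py). `proposal_schedule` reproduces the `sample_rate` formula quoted above and the sample positions that `plot_energies` puts on its x-axis. The per-step formula in `geometric_kT` is an assumption: the documentation says only that burn-in uses geometric annealing from kT_high to kT_low, not exactly how the kernel interpolates between them.

```python
import numpy as np

def proposal_schedule(num_parallel_proposals, burn_in, num_samples):
    """Return (sample_rate, total_proposals, sample_positions) per the mcmc() docs."""
    sample_rate = (num_parallel_proposals - burn_in) // num_samples
    if sample_rate < 1:
        raise ValueError("Need (num_parallel_proposals - burn_in) // num_samples >= 1.")
    # Floor division means the total can fall slightly short of the budget
    total = burn_in + num_samples * sample_rate
    # Proposal count at which each sample is taken (the x-values plot_energies uses)
    positions = burn_in + sample_rate * np.arange(1, num_samples + 1)
    return sample_rate, total, positions

def geometric_kT(kT_high, kT_low, burn_in):
    """ASSUMED schedule: kT_t = kT_high * (kT_low / kT_high) ** (t / (burn_in - 1))."""
    t = np.arange(burn_in)
    return kT_high * (kT_low / kT_high) ** (t / max(burn_in - 1, 1))

# Avocado-section settings: 24**4 * 4 proposals, half of them burn-in, 100 samples
rate, total, positions = proposal_schedule(24**4 * 4, 24**4 * 2, 100)
print(rate, total, positions[-1])  # prints: 6635 1327052 1327052
```

With these settings, 52 of the 1,327,104 budgeted proposals are never run, which is the floor-division effect noted under `num_parallel_proposals`.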

Unfortunately, exact recreation of the experiments (e.g., using identical RNG seeds) is not possible (or at least, not easy) due to changes in the backend code. Please alert ctchalk2@gmail.com about any discrepancies, as they may reflect errors in either the old backend or new backend code. The new backend code has been tested more thoroughly and so is more likely to be correct (although I believe both are correct, and I have yet to find any discrepancies).

Important erratum: The paper reports that all simulations employed non-periodic boundary conditions. However, the training and unclamped testing of the avocado interaction energies for the paper were done using periodic boundary conditions, as is the case in the notebook below.

##IMPORTANT:
The cell below was needed to get the Numba CUDA code working with the default Colab runtime as of 1/28/2026 (I could not find the name of that runtime version). You can test with and without running this configuration cell. My guess is that some future runtime version will work whether or not you run it. Alternatively, Numba CUDA support could be deprecated; in that case, boltzmann_liquids_cuda.py would need to be rewritten in standard C++ CUDA.

In [None]:
from numba import config
config.CUDA_ENABLE_PYNVJITLINK = 1

##Initialization (run these first, no matter which section you'd like to run)

###Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import logsumexp
from itertools import product
import time
from tqdm import tqdm
import matplotlib.colors as mplcolors
from IPython.display import clear_output
import matplotlib.patheffects as pe

from boltzmann_liquids_cuda import (
    get_n_ij, get_n_i, config_energy, mcmc
)

print("All imports successful")

###Visualization/plotting functions

In [None]:
def plot_lattice(lattice,
                 n_sp,
                 title=None,
                 save=False,
                 filename=None,
                 hide_subset=False,
                 subset_bounds=None,
                 color_dict=None):
  L = lattice.shape[0]
  arr = lattice.copy()
  if title is None:
    title = ''
  fig = plt.figure()
  ax = fig.add_subplot(111, projection='3d')
  ax.set_title(title)
  ax.set_xlim(0, L)
  ax.set_ylim(0, L)
  ax.set_zlim(0, L)
  ax.set_xticks([0, L-1])
  ax.set_yticks([0, L-1])
  ax.set_zticks([0, L-1])
  colors = np.zeros(arr.shape, dtype=object)

  # Hide and/or bound subset
  if subset_bounds:
    (x_min, x_max, y_min, y_max, z_min, z_max) = subset_bounds
    if hide_subset:
        # Empty the box so the cut-away interior is visible
        arr[x_min:x_max, y_min:y_max, z_min:z_max] = 0
    # Draw surfaces to cover tick lines inside the box
    ax.plot_surface(np.array([[x_min, x_max], [x_min, x_max]]), np.array([[y_min, y_min], [y_max, y_max]]), np.array([[z_min, z_min], [z_min, z_min]]), color=(1, 1, 1, 0.1))
    ax.plot_surface(np.array([[x_min, x_min], [x_max, x_max]]), np.array([[y_min, y_min], [y_max, y_max]]), np.array([[z_min, z_min], [z_max, z_max]]), color=(1, 1, 1, 0.5))
    # Draw dotted boundary box
    for corner in [(x_min, y_min, z_min), (x_min, y_min, z_max), (x_min, y_max, z_min), (x_min, y_max, z_max),
                    (x_max, y_min, z_min), (x_max, y_min, z_max), (x_max, y_max, z_min), (x_max, y_max, z_max)]:
        ax.plot([corner[0], corner[0]], [corner[1], corner[1]], [z_min, z_max], linestyle=':', color=(0, 0, 0, 0.5))
        ax.plot([corner[0], corner[0]], [y_min, y_max], [corner[2], corner[2]], linestyle=':', color=(0, 0, 0, 0.5))
        ax.plot([x_min, x_max], [corner[1], corner[1]], [corner[2], corner[2]], linestyle=':', color=(0, 0, 0, 0.5))
  for i in range(1, n_sp):
      colors[arr == i] = color_dict[i]
  ax.voxels(arr,
            facecolors=colors,
            lightsource=mplcolors.LightSource(azdeg=135, altdeg=0))
  if save:
    plt.savefig(filename)
  plt.show()
  plt.clf()


def plot_g_ij(arr, title=None, save=False, filename=None):
  if title is None:
    title = ''
  fig, ax = plt.subplots(1, 1)
  arr = arr.copy()
  ax.set_title(title)
  absmax = np.max(np.abs(arr))
  # Diverging colormap centered at 0 so negative and positive energies are distinguishable
  im = ax.imshow(arr,
                 norm=mplcolors.TwoSlopeNorm(0, -absmax, absmax),
                 cmap='coolwarm')
  cbar = plt.colorbar(im,
                      ax=ax,
                      orientation='vertical',
                      fraction=0.046,
                      pad=0.04)
  if save:
    plt.savefig(filename)
  plt.show()
  plt.clf()


def plot_g_i(arr, title=None, save=False, filename=None):
  if title is None:
    title = ''
  fig, ax = plt.subplots(1, 1)
  arr = arr.copy()
  ax.set_title(title)
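  absmax = np.max(np.abs(arr))
  # Draw the per-species energies as a single-row image, centered at 0
  # (the previous version drew the global g_i here instead of arr)
  im = ax.imshow([arr],
                 norm=mplcolors.TwoSlopeNorm(0, -absmax, absmax),
                 cmap='coolwarm')
  cbar = plt.colorbar(im,
                      ax=ax,
                      orientation='vertical',
                      fraction=0.046,
                      pad=0.04)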
  if save:
    plt.savefig(filename)
  plt.show()
  plt.clf()


def plot_memory_lattice(lattice,
                        n_sp,
                        title=None,
                        save=False,
                        filename=None,
                        mem_filter=False):
  L = lattice.shape[0]
  arr = lattice.copy()
  if title is None:
    title = ''
  fig = plt.figure()
  ax = fig.add_subplot(111, projection='3d')
  ax.set_xlim(0, L)
  ax.set_ylim(0, L)
  ax.set_zlim(0, L)
  ax.set_title(title)
  ax.set_xticks([])
  ax.set_xticklabels([])
  ax.set_yticks([])
  ax.set_yticklabels([])
  ax.set_zticks([])
  ax.set_zticklabels([])

  discrete_cmap = plt.get_cmap('tab20', n_sp-1)
  colors = np.zeros(arr.shape, dtype=np.dtype((np.float32, 4)))
  for i in range(1, n_sp):
    colors[arr == i] = discrete_cmap(i-1)
  # Gray for inert species in inert species examples
  colors[arr == n_sp] = (0.5, 0.5, 0.5, 1)

  if mem_filter:
        # Color species blue if it is in the mem_filter
        for i in range(1, n_sp):
            if i-1 in np.argwhere(MEMS[mem_filter] == 1).flatten():
                colors[arr == i] = (0, 0, 1, 1)

  l = mplcolors.LightSource(azdeg=135, altdeg=0)
  ax.voxels(arr,
            facecolors=colors,
            lightsource=l)

  # ax.voxels(arr,
  #           facecolors=colors)

  if save:
    plt.savefig(f'{filename}')
  else:
    plt.show()


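# Given a list of lattices, mark each site holding `species` (or, if `species`
# is None, any non-solvent species) as 1, average over the lattices, and plot
# the z-axis sum of that mean occupancy as a 2D heat map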
def plot_heatmap(lattices, title=None, species=None, cmap = 'Blues', save=False, filename=None):
    # Larger text
    plt.rcParams.update({'font.size': 20})
    if species:
        lattices = [lattice == species for lattice in lattices]
    else:
        lattices = [lattice > 0 for lattice in lattices]
    lattice = np.mean(lattices, axis=0)
    fig, ax = plt.subplots()
    if title:
        ax.set_title(title)
    heatmap_img = ax.imshow(np.rot90(lattice.sum(axis=2)), cmap=cmap, origin = 'lower')
    # Add colorbar
    cbar = plt.colorbar(heatmap_img)
    cbar.set_label('Mean z-axis sum')
    if save:
        plt.savefig(f'{filename}.svg')
    plt.show()
    # Default text size
    plt.rcParams.update({'font.size': 10})


def plot_energies(energies, num_parallel_proposals, burn_in, num_samples):
    if num_samples < 1:
        raise ValueError("num_samples must be >= 1 to plot a broken axis.")

    e = np.asarray(energies, dtype=np.float64).ravel()
    if e.shape[0] != num_samples + 1:
        raise ValueError("Expected energies to have length num_samples + 1 (initial + samples).")

    sample_rate = (num_parallel_proposals - burn_in) // num_samples
    if sample_rate <= 0:
        raise ValueError("sample_rate <= 0. Need (num_parallel_proposals - burn_in) // num_samples >= 1.")

    x = np.empty(num_samples + 1, dtype=np.int64)
    x[0] = 0
    x[1:] = burn_in + np.arange(1, num_samples + 1, dtype=np.int64) * sample_rate

    fig, (ax1, ax2) = plt.subplots(
        1, 2, sharey=True,
        gridspec_kw={"width_ratios": [1, 4], "wspace": 0.05}
    )

    # Left: single point only
    ax1.plot([x[0]], [e[0]], marker="o", linestyle="None")

    # Right: line plot over sampled region
    ax2.plot(x[1:], e[1:])  # default is a line

    left_pad = max(1, int(0.5 * sample_rate))
    ax1.set_xlim(x[0] - left_pad, x[0] + left_pad)
    ax1.set_xticks([x[0]])

    right_pad = max(1, int(0.5 * sample_rate))
    ax2.set_xlim(x[1] - right_pad, x[-1] + right_pad)

    ax1.spines["right"].set_visible(False)
    ax2.spines["left"].set_visible(False)
    ax2.yaxis.tick_right()
    ax2.tick_params(labelright=False)

    d = 0.015
    kwargs = dict(transform=ax1.transAxes, color="k", clip_on=False)
    ax1.plot((1 - d, 1 + d), (-d, +d), **kwargs)
    ax1.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)

    kwargs.update(transform=ax2.transAxes)
    ax2.plot((-d, +d), (-d, +d), **kwargs)
    ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)

    fig.supxlabel("Parallel proposals run")
    ax1.set_ylabel("Energy")

    plt.show()
    plt.clf()
    plt.close()

###Helper functions

In [None]:
# Returns the integer coordinates of all voxels whose centers
# lie within distance r of the lattice point (a,b,c), i.e., a
# discrete sphere of radius r centered at (a,b,c)
def sphere(a, b, c, r):
    coords = []
    for x in range(-r, r+1):
        for y in range(-r, r+1):
            for z in range(-r, r+1):
                if (x+0.5)**2 + (y+0.5)**2 + (z+0.5)**2 <= r**2:
                    coords.append((x+a, y+b, z+c))
    return coords


# numpy's shuffle only shuffles the first dim. of a multidim. array
# This function will shuffle all of a 3D lattice
def shuffle_lattice(lattice):
  flattened_lattice = lattice.flatten()
  np.random.shuffle(flattened_lattice)
  return flattened_lattice.reshape(lattice.shape)


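# Replaces 0s (solvent) in `lattice` with other species, at random positions,
# until each species count matches the target lattice's species count.
# It modifies `lattice` in place, reads the Avocado globals N_SP_AVO and
# LATTICE_LENGTH_AVO, never removes surplus species, and loops forever if
# there are too few solvent sites. It is slow and not well-made.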
def match_counts(lattice, target_lattice):
  species_counts = get_n_i(lattice, N_SP_AVO)
  target_species_counts = get_n_i(target_lattice, N_SP_AVO)
  for i in range(1, N_SP_AVO):
    while species_counts[i] < target_species_counts[i]:
      (x,y,z) = np.random.randint(0, LATTICE_LENGTH_AVO, 3)
      if lattice[x,y,z] == 0:
        lattice[x,y,z] = i
        species_counts[i] += 1
  return lattice
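
Because `match_counts` draws random sites until it happens to hit solvent, a vectorized alternative may help on larger lattices. The following is a minimal sketch, assuming species IDs are nonnegative integers below `num_species` and that, as in the original, only missing copies are added (surplus species are never removed). `match_counts_vectorized` is a hypothetical name; it takes `num_species` explicitly instead of reading the `N_SP_AVO` global, and it returns a copy rather than modifying its input.

```python
import numpy as np

def match_counts_vectorized(lattice, target_lattice, num_species, rng=None):
    """Return a copy of `lattice` in which randomly chosen solvent (0) sites are
    filled so each species 1..num_species-1 reaches its count in `target_lattice`."""
    rng = np.random.default_rng() if rng is None else rng
    out = np.array(lattice, copy=True)
    counts = np.bincount(out.ravel().astype(np.int64), minlength=num_species)
    target = np.bincount(np.asarray(target_lattice).ravel().astype(np.int64),
                         minlength=num_species)
    # Missing copies of each non-solvent species (surplus is left alone)
    deficit = np.maximum(target[1:] - counts[1:], 0)
    solvent = np.flatnonzero(out == 0)
    if deficit.sum() > solvent.size:
        raise ValueError("Not enough solvent sites to reach the target counts.")
    # Pick distinct solvent sites uniformly without replacement, then assign in bulk
    chosen = rng.choice(solvent, size=int(deficit.sum()), replace=False)
    np.put(out, chosen, np.repeat(np.arange(1, num_species), deficit))
    return out
```

Since the clamp cells below already reassign the result (`lattice = match_counts(lattice, avocado())`), a call such as `match_counts_vectorized(lattice, avocado(), N_SP_AVO)` could be swapped in there; unlike the original, it raises an error instead of looping forever when solvent runs out.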


##Avocado

###Globals and helper functions

In [None]:
# Four species:
# 0: solvent
# 1: avocado skin (outer)
# 2: avocado flesh (middle)
# 3: avocado pit (inner)
N_SP_AVO = 4
LATTICE_LENGTH_AVO = 24
AVOCADO_CENTER_X = 12
AVOCADO_CENTER_Y = 12
AVOCADO_CENTER_Z = 12
SKIN_RADIUS = 7
FLESH_RADIUS = 5
PIT_RADIUS = 3
COLOR_DICT_AVO = {1: '0.2', 2: 'g', 3: 'y'}

def avocado(x=AVOCADO_CENTER_X,
            y=AVOCADO_CENTER_Y,
            z=AVOCADO_CENTER_Z,
            r1=SKIN_RADIUS,
            r2=FLESH_RADIUS,
            r3=PIT_RADIUS):
    lattice = np.zeros((LATTICE_LENGTH_AVO,
                        LATTICE_LENGTH_AVO,
                        LATTICE_LENGTH_AVO),
                       dtype = np.int32)
    # skin
    for pos in sphere(x, y, z, r1):
        lattice[pos] = 1
    # flesh
    for pos in sphere(x, y, z, r2):
        lattice[pos] = 2
    # pit
    for pos in sphere(x, y, z, r3):
        lattice[pos] = 3
    return lattice

plot_lattice(avocado(), N_SP_AVO, title='Avocado', save=False, color_dict=COLOR_DICT_AVO)
plot_lattice(avocado(), N_SP_AVO, title='Avocado', save=False, color_dict=COLOR_DICT_AVO, subset_bounds=(0, 24, 0, 12, 0, 24), hide_subset=True)
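
`avocado()` paints concentric spheres from the outside in (skin, then flesh overwriting it, then pit), so each site ends up with the innermost layer that contains it. The same construction can be sketched with boolean masks; `sphere_mask` and `avocado_masks` are hypothetical helpers that apply the same voxel test as `sphere()`. One difference worth knowing: `sphere()` returns raw coordinates, so a sphere extending past the lattice edge would wrap through negative indexing (or raise an IndexError past the far edge), whereas the mask version simply clips. For the in-bounds spheres used here the results should agree.

```python
import numpy as np

def sphere_mask(shape, center, r):
    """Voxels whose centers lie within distance r of the lattice point `center`:
    the same test as sphere(), (x+0.5)**2 + (y+0.5)**2 + (z+0.5)**2 <= r**2."""
    grid = np.indices(shape)
    dist2 = sum((grid[k] - center[k] + 0.5) ** 2 for k in range(3))
    return dist2 <= r ** 2

def avocado_masks(length=24, center=(12, 12, 12), radii=(7, 5, 3)):
    lattice = np.zeros((length,) * 3, dtype=np.int32)
    # Paint outer to inner so smaller spheres overwrite larger ones: 1=skin, 2=flesh, 3=pit
    for species, r in enumerate(radii, start=1):
        lattice[sphere_mask(lattice.shape, center, r)] = species
    return lattice

# Number of sites per species (solvent, skin, flesh, pit)
print(np.bincount(avocado_masks().ravel(), minlength=4))
```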

###Training initialization

In [None]:
max_epochs = 100000
parallel_proposals_per_epoch = LATTICE_LENGTH_AVO ** 4 * 4
save_frequency = 10
learning_rate = 5000 / LATTICE_LENGTH_AVO ** 3
num_samples = 100

epoch = 0
g_ij = np.zeros((N_SP_AVO, N_SP_AVO))
norm_sums = np.zeros_like(g_ij)

folder_filename = f'avocado_{time.strftime("%Y%m%d-%H%M%S")}'
!mkdir $folder_filename
with open(f'{folder_filename}/hyperparameters.txt', 'w') as f:
  f.write(f'parallel_proposals_per_epoch: {parallel_proposals_per_epoch}\n')
  f.write(f'learning_rate: {learning_rate}\n')

###Testing function

In [None]:
def test(g_ij):
  lattice = avocado()
  lattice = shuffle_lattice(lattice)
  lattice, _, _, energies = mcmc(lattice_batch = lattice,
                                 num_parallel_proposals = parallel_proposals_per_epoch,
                                 num_samples = num_samples,
                                 burn_in = parallel_proposals_per_epoch//2,
                                 num_species = N_SP_AVO,
                                 g_ij_batch = g_ij,
                                 canonical = True,
                                 ranged = True,
                                 periodic = True)
  return lattice, energies

###Training epoch function

In [None]:
def training_epoch(g_ij, learning_rate, adagrad=False, norm_sums=None):
  dw = np.zeros_like(g_ij)

  # Wake phase
  n_ij = get_n_ij(avocado(), N_SP_AVO)
  dw -= n_ij

  # Sleep phase
  lattice = avocado()
  lattice = shuffle_lattice(lattice)
  lattice, n_ij, _, energies = mcmc(lattice_batch = lattice,
                                    num_parallel_proposals = parallel_proposals_per_epoch,
                                    num_samples = num_samples,
                                    burn_in = parallel_proposals_per_epoch//2,
                                    num_species = N_SP_AVO,
                                    g_ij_batch = g_ij,
                                    canonical = True,
                                    ranged = True,
                                    periodic = True)
  dw += n_ij

  # Update weights
  if adagrad:
    norm_sums += dw ** 2
    g_ij += learning_rate * dw / (1e-8 + np.sqrt(norm_sums))
  else:
    g_ij += learning_rate * dw
  return g_ij, norm_sums, lattice, energies
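
The update in `training_epoch` follows the usual Boltzmann-machine learning rule. Assuming the energy has the form $E(x) = \sum_{i,j} G_{i,j}\, n_{i,j}(x)$ (the $G_i$ term is constant under canonical swaps, since species counts are fixed) and $P(x) \propto e^{-E(x)}$ at the default kT = 1, the log-likelihood gradient is

$$\frac{\partial \log P(x_{\text{data}})}{\partial G_{i,j}} = \langle n_{i,j} \rangle_{\text{model}} - n_{i,j}(x_{\text{data}}),$$

i.e., the sleep-phase counts minus the wake-phase counts, which is `dw`. With `adagrad=True`, each entry's step is additionally divided by the square root of its accumulated squared `dw`. The sketch below checks the gradient by exact enumeration on a toy system small enough to list every configuration: a 1D periodic ring of 6 sites with 3 species and fixed species counts. The ring, `ring_pair_counts` (ordered nearest-neighbor pair counts), `log_likelihood`, and `exact_gradient` are illustrative assumptions, not the conventions of `get_n_ij` or `mcmc`.

```python
import itertools
import numpy as np

def ring_pair_counts(config, num_species):
    """Ordered nearest-neighbor pair counts n[a, b] on a 1D periodic ring (toy stand-in)."""
    config = np.asarray(config)
    n = np.zeros((num_species, num_species))
    for a, b in zip(config, np.roll(config, -1)):
        n[a, b] += 1
    return n

def enumerate_counts(data_config, num_species):
    # Every distinct rearrangement of the data has the same species counts (canonical ensemble)
    configs = sorted({tuple(p) for p in itertools.permutations(data_config)})
    return np.array([ring_pair_counts(c, num_species) for c in configs])

def log_likelihood(g_ij, data_config, num_species):
    counts = enumerate_counts(data_config, num_species)
    energies = np.einsum('kab,ab->k', counts, g_ij)
    m = (-energies).max()                                   # stabilized log-sum-exp
    log_z = m + np.log(np.exp(-energies - m).sum())
    e_data = np.sum(g_ij * ring_pair_counts(data_config, num_species))
    return -e_data - log_z

def exact_gradient(g_ij, data_config, num_species):
    counts = enumerate_counts(data_config, num_species)
    energies = np.einsum('kab,ab->k', counts, g_ij)
    weights = np.exp(-(energies - energies.min()))
    weights /= weights.sum()                                # Boltzmann probabilities
    model_avg = np.einsum('k,kab->ab', weights, counts)      # sleep phase: <n_ab>
    data_counts = ring_pair_counts(data_config, num_species) # wake phase: n_ab(data)
    return model_avg - data_counts

rng = np.random.default_rng(0)
g = rng.normal(size=(3, 3))
dw = exact_gradient(g, [0, 0, 1, 1, 2, 2], 3)
g_next = g + 0.01 * dw   # plain (non-AdaGrad) form of `g_ij += learning_rate * dw`
```

In the notebook, `mcmc` replaces the exact average $\langle n_{i,j} \rangle_{\text{model}}$ with a sample mean over `num_samples` configurations, so `dw` is a noisy estimate of this gradient.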

###Training

In [None]:
for _ in range(max_epochs):
  g_ij, norm_sums, lattice, energies = training_epoch(g_ij,
                                                      learning_rate,
                                                      adagrad=True,
                                                      norm_sums=norm_sums)

  if epoch % save_frequency == 0:
    lattice, energies = test(g_ij)
    plot_lattice(lattice, N_SP_AVO, title=f'Avocado, post epoch {epoch}', save=False, color_dict=COLOR_DICT_AVO)
    plot_lattice(lattice, N_SP_AVO, title=f'Avocado, post epoch {epoch}', save=False, color_dict=COLOR_DICT_AVO, subset_bounds=(0, 24, 0, 12, 0, 24), hide_subset=True)
    plot_energies(energies,
                  parallel_proposals_per_epoch,
                  parallel_proposals_per_epoch//2,
                  num_samples)
    plt.imshow(g_ij, cmap='viridis')
    plt.colorbar()
    plt.show()

  epoch = epoch + 1

###Paper-reported parameters and plotting
Running this cell overwrites any trained $G_{i,j}$ with the paper-reported values.

In [None]:
g_ij = np.load('avocado_g_ij.npy')
plot_g_ij(g_ij, title='Avocado')

####Unclamped testing

In [None]:
lattice, energies = test(g_ij)

plot_lattice(lattice, N_SP_AVO, title=f'Avocado', save=False, color_dict=COLOR_DICT_AVO)
plot_lattice(lattice, N_SP_AVO, title=f'Avocado', save=False, color_dict=COLOR_DICT_AVO, subset_bounds=(0, 24, 0, 12, 0, 24), hide_subset=True)
plot_energies(energies,
              parallel_proposals_per_epoch,
              parallel_proposals_per_epoch//2,
              num_samples)
plt.imshow(g_ij, cmap='viridis')
plt.colorbar()
plt.show()

####Radial density

In [None]:
def plot_average_radial_density(lattices, save=False, filename=None):
    # Larger text
    plt.rcParams.update({'font.size': 22})
    num_bins = 50
    radial_densities = {0: [], 1: [], 2: [], 3: []}

    # Process each lattice
    for lattice in lattices:
        lattice_shape = lattice.shape
        indices = np.argwhere(lattice == 3)
        if len(indices) == 0:
            continue  # Skip if no '3' values found
        center_of_mass = indices.mean(axis=0)
        distances = np.linalg.norm(np.indices(lattice_shape).transpose(1, 2, 3, 0) - center_of_mass, axis=3)
        max_distance = int(np.ceil(distances.max()))
        bin_edges = np.linspace(0, max_distance, num_bins)
        digitized = np.digitize(distances, bin_edges)
        binned_values = {i: lattice[digitized == i] for i in range(1, len(bin_edges))}

        for value in [0, 1, 2, 3]:
            radial_density = [np.mean(bins == value) for bins in binned_values.values()]
            radial_densities[value].append(radial_density)

    # Averaging radial densities
    averaged_radial_densities = {key: np.mean(np.vstack(values), axis=0) for key, values in radial_densities.items()}

    # Plotting
    fig, ax = plt.subplots()
    ax.set_title('Average radial density')
    colors = ['gray', 'black', 'green', 'goldenrod']  # Colors for values 0, 1, 2, 3
    for value in [0, 1, 2, 3]:
        ax.plot(bin_edges[:-1], averaged_radial_densities[value], color=colors[value], label=f'Species {value}')

    ax.set_xlabel('Distance from species 3 center')
    ax.set_ylabel('Frequency')
    ax.legend()
    if save:
        plt.savefig(f'{filename}.svg')
    plt.show()
    # Return text to default
    plt.rcParams.update({'font.size': 10})



In [None]:
lattices = []
for _ in range(100):
  lattices.append(shuffle_lattice(avocado()))
lattices = np.array(lattices, dtype=np.int32)
lattices, _, _, energies_s = mcmc(lattice_batch = lattices,
                                  num_parallel_proposals = parallel_proposals_per_epoch,
                                  num_samples = num_samples,
                                  burn_in = parallel_proposals_per_epoch//2,
                                  num_species = N_SP_AVO,
                                  g_ij_batch = g_ij,
                                  canonical = True,
                                  ranged = True)

# The function appends '.svg' itself
plot_average_radial_density(lattices, save=True, filename='average_radial_density')
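
`plot_average_radial_density` uses the plain mean of the pit coordinates as the center and measures ordinary Euclidean distances. Under periodic boundary conditions, a droplet that straddles the boundary would get a misplaced center. If that matters, one option is a per-axis circular mean for the center plus minimum-image distances, sketched below under the assumption of a rectangular periodic lattice. `periodic_center_of_mass` and `periodic_distances` are hypothetical helpers, not part of the notebook or the backend.

```python
import numpy as np

def periodic_center_of_mass(mask):
    """Per-axis circular mean of the True voxels of a 3D mask on a periodic lattice.
    Ill-defined if the voxels are spread uniformly around an axis."""
    coords = np.argwhere(mask).astype(float)
    size = np.array(mask.shape, dtype=float)
    theta = 2 * np.pi * coords / size                 # map each coordinate onto a circle
    mean_angle = np.arctan2(np.sin(theta).mean(axis=0), np.cos(theta).mean(axis=0))
    return (mean_angle % (2 * np.pi)) * size / (2 * np.pi)

def periodic_distances(shape, center):
    """Minimum-image Euclidean distance from every site to `center` on a periodic lattice."""
    size = np.array(shape, dtype=float)
    grid = np.indices(shape).transpose(1, 2, 3, 0).astype(float)
    delta = np.abs(grid - np.asarray(center, dtype=float))
    delta = np.minimum(delta, size - delta)           # wrap across the boundary if shorter
    return np.linalg.norm(delta, axis=3)

# A 2x2x2 block straddling the corner of a 24^3 periodic box
mask = np.zeros((24, 24, 24), dtype=bool)
mask[np.ix_([23, 0], [23, 0], [23, 0])] = True
print(np.argwhere(mask).mean(axis=0))   # naive mean
print(periodic_center_of_mass(mask))    # circular mean
```

Here the naive mean lands in the middle of the box (11.5 on each axis), while the circular mean gives approximately 23.5, the block's actual center across the boundary. Inside `plot_average_radial_density`, the two lines computing `center_of_mass` and `distances` could be replaced by `periodic_center_of_mass(lattice == 3)` and `periodic_distances(lattice_shape, center_of_mass)`.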

####Clamped testing

#####Surface

In [None]:
lattice = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO), dtype=np.int32)
lattice[0, :, :] = avocado()[12, :, :]

plot_lattice(lattice, N_SP_AVO, title='Surface clamp', color_dict=COLOR_DICT_AVO)

free_pos = np.ones_like(lattice)
free_pos[lattice != 0] = 0

lattice = match_counts(lattice, avocado())

plot_lattice(lattice, N_SP_AVO, title='Surface clamp initial', color_dict=COLOR_DICT_AVO)

lattice, _, _, _ = mcmc(lattice_batch = lattice,
                        num_parallel_proposals = parallel_proposals_per_epoch,
                        num_samples = num_samples,
                        burn_in = parallel_proposals_per_epoch//2,
                        num_species = N_SP_AVO,
                        g_ij_batch = g_ij,
                        free_pos_batch = free_pos,
                        canonical = True,
                        ranged = True)

plot_lattice(lattice, N_SP_AVO, title='Surface clamp final', color_dict=COLOR_DICT_AVO)

In [None]:
lattice = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO), dtype=np.int32)
lattice[0, :, :] = avocado()[12, :, :]

free_pos = np.ones_like(lattice)
free_pos[lattice != 0] = 0
free_pos_batch = [free_pos] * 100

lattices = []
for _ in range(100):
  lattices.append(match_counts(lattice.copy(), avocado()))

lattices, _, _, _ = mcmc(lattice_batch = lattices,
                         num_parallel_proposals = parallel_proposals_per_epoch,
                         num_samples = num_samples,
                         burn_in = parallel_proposals_per_epoch//2,
                         num_species = N_SP_AVO,
                         g_ij_batch = g_ij,
                         free_pos_batch = free_pos_batch,
                         canonical = True,
                         ranged = True,
                         save_lattices = True)

# Flattens lattices list returned by mcmc, which is a list of list of lattices,
# into a list of lattices.
lattices = [item for sublist in lattices for item in sublist]

plot_heatmap(lattices, title='Aggregate Heatmap\n 10,000 samples', cmap = 'hot')

#####Polymer clamp

In [None]:
lattice = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO), dtype=np.int32)
lattice[:, 12, 12] = avocado()[:, 12, 12]

plot_lattice(lattice, N_SP_AVO, title='Polymer clamp', color_dict=COLOR_DICT_AVO)

free_pos = np.ones_like(lattice)
free_pos[lattice != 0] = 0

lattice = match_counts(lattice, avocado())

plot_lattice(lattice, N_SP_AVO, title='Polymer clamp initial', color_dict=COLOR_DICT_AVO)

lattice, _, _, _ = mcmc(lattice_batch = lattice,
                        num_parallel_proposals = parallel_proposals_per_epoch,
                        num_samples = num_samples,
                        burn_in = parallel_proposals_per_epoch//2,
                        num_species = N_SP_AVO,
                        g_ij_batch = g_ij,
                        free_pos_batch = free_pos,
                        canonical = True,
                        ranged = True)

plot_lattice(lattice, N_SP_AVO, title='Polymer clamp final', color_dict=COLOR_DICT_AVO)

In [None]:
lattice = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO), dtype=np.int32)
lattice[:, 12, 12] = avocado()[:, 12, 12]

free_pos = np.ones_like(lattice)
free_pos[lattice != 0] = 0
free_pos_s = [free_pos] * 100

lattices = []
for _ in range(100):
  lattices.append(match_counts(lattice.copy(), avocado()))

lattices, _, _, _ = mcmc(lattice_batch = lattices,
                         num_parallel_proposals = parallel_proposals_per_epoch,
                         num_samples = num_samples,
                         burn_in = parallel_proposals_per_epoch//2,
                         num_species = N_SP_AVO,
                         g_ij_batch = g_ij,
                         free_pos_batch = free_pos_s,
                         canonical = True,
                         ranged = True,
                         save_lattices = True)

# Flattens the list of lists of lattices returned by mcmc into a single
# list of lattices.
lattices = [item for sublist in lattices for item in sublist]

plot_heatmap(lattices, title='Aggregate Heatmap\n 10,000 samples', cmap = 'hot')

#####Ordered vs. unordered surface

In [None]:
lattice = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO))
lattice[0, :, :] = avocado(r1=9, r2=7, r3=4)[12, :, :]

random_surface = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO))
for pos in sphere(12, 12, 12, 9):
  if pos[0] == 12:
    random_surface[pos[1], pos[2]] = np.random.randint(1, N_SP_AVO)
lattice[-1, :, :] = random_surface

plot_lattice(lattice, N_SP_AVO, title='Ordered vs. unordered surface clamp', color_dict=COLOR_DICT_AVO)

free_pos = np.ones_like(lattice)
free_pos[lattice != 0] = 0

lattice = match_counts(lattice, avocado())

plot_lattice(lattice, N_SP_AVO, title='Ordered vs. unordered surface clamp initial', color_dict=COLOR_DICT_AVO)

lattice, _, _, _ = mcmc(lattice_batch = lattice,
                        num_parallel_proposals = parallel_proposals_per_epoch,
                        num_samples = num_samples,
                        burn_in = parallel_proposals_per_epoch//2,
                        num_species = N_SP_AVO,
                        g_ij_batch = g_ij,
                        free_pos_batch = free_pos,
                        canonical = True,
                        ranged = True)

plot_lattice(lattice, N_SP_AVO, title='Ordered vs. unordered surface clamp final', color_dict=COLOR_DICT_AVO)
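The unordered surface is a disc of random non-solvent species, cut from `sphere(12, 12, 12, 9)` at `pos[0] == 12`. The same kind of surface can be built without the loop using a vectorized mask. This sketch assumes `sphere` includes every site within Euclidean distance `r` of the center, which may not match the notebook's helper exactly; `random_disc` is a hypothetical name.

```python
import numpy as np

def random_disc(length, center, radius, num_species, rng=None):
    """Sketch: a length x length plane whose sites within `radius` of `center`
    get a species drawn uniformly from 1..num_species-1; all other sites are solvent (0).
    Assumes a Euclidean <= radius boundary, which may differ from the notebook's sphere()."""
    rng = np.random.default_rng() if rng is None else rng
    ii, jj = np.indices((length, length))
    inside = (ii - center[0]) ** 2 + (jj - center[1]) ** 2 <= radius ** 2
    surface = np.zeros((length, length), dtype=np.int32)
    # integers(low, high) is half-open, matching np.random.randint(1, N_SP_AVO)
    surface[inside] = rng.integers(1, num_species, size=int(inside.sum()))
    return surface
```

For example, `random_disc(LATTICE_LENGTH_AVO, (12, 12), 9, N_SP_AVO)` would play the role of `random_surface` above.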

In [None]:
lattice = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO))
lattice[0, :, :] = avocado(r1=9, r2=7, r3=4)[12, :, :]

random_surface = np.zeros((LATTICE_LENGTH_AVO, LATTICE_LENGTH_AVO))
for pos in sphere(12, 12, 12, 9):
  if pos[0] == 12:
    random_surface[pos[1], pos[2]] = np.random.randint(1, N_SP_AVO)
lattice[-1, :, :] = random_surface

free_pos = np.ones_like(lattice)
free_pos[lattice != 0] = 0
free_pos_s = [free_pos] * 100

lattices = []
for _ in range(100):
  lattices.append(match_counts(lattice.copy(), avocado()))

lattices, _, _, _ = mcmc(lattice_batch = lattices,
                        num_parallel_proposals = parallel_proposals_per_epoch,
                        num_samples = num_samples,
                        burn_in = parallel_proposals_per_epoch//2,
                        num_species = N_SP_AVO,
                        g_ij_batch = g_ij,
                        free_pos_batch = free_pos_s,
                        canonical = True,
                        ranged = True,
                        save_lattices = True)

# Flattens the list of lists of lattices returned by mcmc into a single
# list of lattices.
lattices = [item for sublist in lattices for item in sublist]

plot_heatmap(lattices, title='Aggregate Heatmap\n 10,000 samples', cmap = 'hot')

##Hopfield condensates: surface-conditioned associative recall

###Globals and helper functions

In [None]:
N_MEMS = 2
MEM_LENGTH = 16
MEMS = np.array([np.array([0, 1, 0, 1,
                           1, 0, 0, 1,
                           1, 0, 0, 0,
                           0, 1, 1, 1]),
                 np.array([1, 0, 0, 1,
                           0, 1, 0, 1,
                           0, 1, 1, 1,
                           0, 0, 0, 1])])

# 0: solvent
# 1-16: "memory" species
N_SP_HOPFIELD = 1 + MEM_LENGTH
LATTICE_LENGTH_HOPFIELD = 12

def sphere_with_memory_composition(lattice,
                                   memory,
                                   a,
                                   b,
                                   c,
                                   r):
  for coord in sphere(a, b, c, r):
    # Skip coordinates with coord[0] < 0, so a sphere centered on the wall (a = 0)
    # becomes a hemisphere against that wall (as in the cell below)
    if coord[0] >= 0:
      lattice[coord] = np.random.choice(np.argwhere(memory == 1).flatten()) + 1
  return lattice

def surface_with_memory_composition(lattice,
                                    memory,
                                    a,
                                    b,
                                    c,
                                    r):
  for coord in sphere(a, b, c, r):
    if coord[0] == 0:
      lattice[coord] = np.random.choice(np.argwhere(memory == 1).flatten()) + 1
  return lattice
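Both helpers draw each molecule uniformly from the species a memory switches on, with memory index $k$ mapping to species $k+1$ because species 0 is solvent. Making that mapping explicit shows how much the two stored memories overlap. The block repeats `MEMS` so it runs standalone; `memory_species` and `sample_memory_species` are hypothetical names.

```python
import numpy as np

MEMS = np.array([[0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1],
                 [1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1]])

def memory_species(memory):
    # Species ids are memory indices shifted by 1, since species 0 is solvent
    return np.flatnonzero(memory) + 1

def sample_memory_species(memory, size, rng=None):
    # Uniform draw over the memory's species, as in sphere_with_memory_composition
    rng = np.random.default_rng() if rng is None else rng
    return rng.choice(memory_species(memory), size=size)

s0, s1 = memory_species(MEMS[0]), memory_species(MEMS[1])
print(s0.tolist())                      # [2, 4, 5, 8, 9, 14, 15, 16]
print(s1.tolist())                      # [1, 4, 6, 8, 10, 11, 12, 16]
print(np.intersect1d(s0, s1).tolist())  # [4, 8, 16]
```

Each memory uses 8 of the 16 species, and three of them (4, 8, 16) are shared, so a recalled condensate is distinguished by the remaining five species of each memory.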

In [None]:
lattice = sphere_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD),
                                                  dtype=np.int32),
                                         MEMS[0],
                                         0,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2-1)
plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Memory 0')
lattice = sphere_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD),
                                                  dtype=np.int32),
                                         MEMS[1],
                                         0,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2-1)
plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Memory 1')

lattice = surface_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD),
                                                  dtype=np.int32),
                                         MEMS[0],
                                         0,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2-1)
plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Memory 0 surface')

lattice = surface_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD),
                                                  dtype=np.int32),
                                         MEMS[1],
                                         0,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2-1)
plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Memory 1 surface')

###Training initialization

In [None]:
max_epochs = 100000
parallel_proposals_per_epoch = LATTICE_LENGTH_HOPFIELD ** 4 * 40
save_frequency = 10
learning_rate = 1 / LATTICE_LENGTH_HOPFIELD ** 3
num_samples = 100

epoch = 0
g_ij = np.zeros((N_SP_HOPFIELD, N_SP_HOPFIELD))
g_i = np.zeros(N_SP_HOPFIELD)
for i in range(1, N_SP_HOPFIELD):
  g_i[i] = 8

norm_sums_ij = np.zeros_like(g_ij)
norm_sums_i = np.zeros_like(g_i)

folder_filename = f'hopfield_{time.strftime("%Y%m%d-%H%M%S")}'
!mkdir $folder_filename
with open(f'{folder_filename}/hyperparameters.txt', 'w') as f:
  f.write(f'parallel_proposals_per_epoch: {parallel_proposals_per_epoch}\n')
  f.write(f'learning_rate: {learning_rate}\n')

###Testing function

In [None]:
def test_hopfield(g_ij, g_i):
  lattices = []
  surface_counts_s = []
  free_pos_batch = []
  for mem in MEMS:
    lattice = surface_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                        LATTICE_LENGTH_HOPFIELD,
                                                        LATTICE_LENGTH_HOPFIELD),
                                                       dtype=np.int32),
                                              mem,
                                              0,
                                              LATTICE_LENGTH_HOPFIELD//2,
                                              LATTICE_LENGTH_HOPFIELD//2,
                                              LATTICE_LENGTH_HOPFIELD//2-1)
    lattices.append(lattice)

    free_positions = np.ones_like(lattice)
    free_positions[np.where(lattice)] = 0
    free_pos_batch.append(free_positions)

    surface_counts = get_n_i(lattice, N_SP_HOPFIELD)
    surface_counts_s.append(surface_counts)

  lattices, _, counts_avgs_s, energies_s = mcmc(lattice_batch = lattices,
                                                num_parallel_proposals = parallel_proposals_per_epoch,
                                                num_samples = num_samples,
                                                burn_in = parallel_proposals_per_epoch//2,
                                                num_species = N_SP_HOPFIELD,
                                                g_ij_batch = g_ij,
                                                g_i_batch = g_i,
                                                free_pos_batch = free_pos_batch)

  for i in range(len(lattices)):
    plot_memory_lattice(lattices[i],
                        N_SP_HOPFIELD,
                        title=f'Memory {i}')
    plt.imshow([MEMS[i]])
    plt.colorbar()
    plt.show()
    nonsurface_counts_avgs = counts_avgs_s[i] - surface_counts_s[i]
    plt.imshow([nonsurface_counts_avgs[1:]])
    plt.colorbar()
    plt.show()
    plot_energies(energies_s[i],
                  parallel_proposals_per_epoch,
                  parallel_proposals_per_epoch//2,
                  num_samples)
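`test_hopfield` judges recall by eye, plotting each memory next to the average non-surface species counts. One hypothetical scalar summary, not used in the notebook, is the share of non-solvent, non-surface molecules whose species belongs to the target memory. A value near 1 means the bulk is made almost entirely of that memory's species; because the two memories share species 4, 8 and 16, even a wrong recall does not score 0.

```python
import numpy as np

def recall_fraction(nonsurface_counts, memory):
    """Fraction of non-solvent, non-surface molecules whose species is in `memory`.

    nonsurface_counts: length 1 + len(memory) average counts per species, solvent first
    memory:            0/1 vector; species s corresponds to memory index s - 1
    """
    counts = np.asarray(nonsurface_counts, dtype=float)[1:]  # drop solvent (species 0)
    total = counts.sum()
    if total <= 0:
        return float('nan')
    return float(counts[np.asarray(memory) == 1].sum() / total)
```

Inside the loop of `test_hopfield` this could be called as `recall_fraction(counts_avgs_s[i] - surface_counts_s[i], MEMS[i])`.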


###Training epoch function

In [None]:
def training_epoch(g_ij, g_i, learning_rate, adagrad=False, norm_sums_ij=None, norm_sums_i=None):
  dw = np.zeros_like(g_ij)
  db = np.zeros_like(g_i)

  # Wake phase
  for mem in MEMS:
    lattice = sphere_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                       LATTICE_LENGTH_HOPFIELD,
                                                       LATTICE_LENGTH_HOPFIELD),
                                                       dtype=np.int32),
                                             mem,
                                             0,
                                             LATTICE_LENGTH_HOPFIELD//2,
                                             LATTICE_LENGTH_HOPFIELD//2,
                                             LATTICE_LENGTH_HOPFIELD//2-1)
    n_ij = get_n_ij(lattice, N_SP_HOPFIELD)
    n_i = get_n_i(lattice, N_SP_HOPFIELD)
    dw -= n_ij
    db -= n_i

  # Sleep phase
  lattices = []
  free_pos_batch = []

  for mem in MEMS:
    lattice = surface_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                        LATTICE_LENGTH_HOPFIELD,
                                                        LATTICE_LENGTH_HOPFIELD),
                                                        dtype=np.int32),
                                              mem,  # clamp the surface with the same memory as the wake-phase sphere
                                              0,
                                              LATTICE_LENGTH_HOPFIELD//2,
                                              LATTICE_LENGTH_HOPFIELD//2,
                                              LATTICE_LENGTH_HOPFIELD//2-1)
    lattices.append(lattice)

    free_positions = np.ones_like(lattice)
    free_positions[np.where(lattice)] = 0
    free_pos_batch.append(free_positions)

  lattices, n_ij_avgs_s, n_i_avgs_s, energies_s = mcmc(lattice_batch = lattices,
                                                       num_parallel_proposals = parallel_proposals_per_epoch,
                                                       num_samples = num_samples,
                                                       burn_in = parallel_proposals_per_epoch//2,
                                                       num_species = N_SP_HOPFIELD,
                                                       g_ij_batch = g_ij,
                                                       g_i_batch = g_i,
                                                       free_pos_batch = free_pos_batch)

  for i in range(len(lattices)):
    dw += n_ij_avgs_s[i]
    db += n_i_avgs_s[i]

  if adagrad:
    norm_sums_ij += dw ** 2
    norm_sums_i += db ** 2
    g_ij += learning_rate * dw / np.sqrt(norm_sums_ij + 1e-8)
    g_i += learning_rate * db / np.sqrt(norm_sums_i + 1e-8)
  else:
    g_ij += learning_rate * dw
    g_i += learning_rate * db
  return g_ij, g_i, norm_sums_ij, norm_sums_i
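The update in `training_epoch` follows the familiar Boltzmann-machine learning rule: subtract the clamped (wake) statistics, add the free-running (sleep) averages, and step along the difference, $\Delta G = \eta\,(\langle n\rangle_{\text{sleep}} - \langle n\rangle_{\text{wake}})$, optionally with AdaGrad scaling. Assuming the sampler weights configurations by $e^{-E/kT}$ with $E$ linear in $G_{ij} n_{ij}$ and $G_i n_i$ (an assumption about the backend's sign convention), this is gradient ascent on the log-likelihood of the wake configurations. The standalone `contrastive_update` below restates the rule on plain arrays; its name and signature are hypothetical.

```python
import numpy as np

def contrastive_update(g, data_counts, model_counts, learning_rate,
                       adagrad=False, norm_sums=None, eps=1e-8):
    """One wake/sleep step in the style of training_epoch (sketch).

    data_counts:  per-lattice interaction counts from the clamped (wake) configurations
    model_counts: per-lattice average counts sampled by MCMC in the sleep phase
    Returns new arrays instead of updating g and norm_sums in place.
    """
    dw = np.sum(model_counts, axis=0) - np.sum(data_counts, axis=0)
    if adagrad:
        # Accumulate squared gradients and scale each entry by their running root-sum
        base = np.zeros_like(dw, dtype=float) if norm_sums is None else norm_sums
        norm_sums = base + dw ** 2
        step = learning_rate * dw / np.sqrt(norm_sums + eps)
    else:
        step = learning_rate * dw
    return g + step, norm_sums
```

Returning new arrays avoids relying on in-place mutation: `training_epoch` updates `g_i` through `+=`, so the caller's array changes even if the returned value is discarded.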

###Training

In [None]:
for _ in range(max_epochs):
  g_ij, g_i, _, _ = training_epoch(g_ij,
                                   g_i,
                                   learning_rate)

  if epoch % save_frequency == 0:
    clear_output(wait=True)
    np.save(f'{folder_filename}/g_ij_{epoch}', g_ij)
    np.save(f'{folder_filename}/g_i_{epoch}', g_i)
    test_hopfield(g_ij, g_i)
    plt.imshow(g_ij)
    plt.colorbar()
    plt.show()
    plt.imshow([g_i])
    plt.colorbar()
    plt.show()

  epoch = epoch + 1
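The loop above writes `g_ij_{epoch}.npy` and `g_i_{epoch}.npy` every `save_frequency` epochs (`np.save` appends the `.npy` extension). If a Colab session drops mid-run, a helper along the lines of the hypothetical `latest_checkpoint` below could find the newest file so training can resume. It assumes the folder holds only files written by this loop.

```python
import os
import re
import numpy as np

def latest_checkpoint(folder, prefix='g_ij'):
    """Return (epoch, array) for the highest-numbered '{prefix}_{epoch}.npy' in folder, or None."""
    # Anchored pattern, so prefix 'g_i' does not pick up 'g_ij_*' files
    pattern = re.compile(rf'^{re.escape(prefix)}_(\d+)\.npy$')
    epochs = []
    for name in os.listdir(folder):
        match = pattern.match(name)
        if match:
            epochs.append(int(match.group(1)))
    if not epochs:
        return None
    epoch = max(epochs)
    return epoch, np.load(os.path.join(folder, f'{prefix}_{epoch}.npy'))
```

For example, with `folder_filename` pointed back at the old run's folder, `epoch, g_ij = latest_checkpoint(folder_filename)` and `_, g_i = latest_checkpoint(folder_filename, prefix='g_i')`, followed by `epoch += 1`, would pick up where the saved run left off.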

###Paper-reported parameters and plotting
Running this cell overwrites the trained $G_{ij}$ and $G_i$.

In [None]:
g_ij = np.load('hopfield_g_ij.npy')
g_i = np.load('hopfield_g_i.npy')
plot_g_ij(g_ij)
plot_g_i(g_i)

####Surface recall

In [None]:
test_hopfield(g_ij, g_i)

####Partial memory recall

#####Partial memory 0

In [None]:
partial_memory = np.array([0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0])

lattice = surface_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD),
                                                  dtype = np.int32),
                                         partial_memory,
                                         0,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2-1)

surface_counts = get_n_i(lattice, N_SP_HOPFIELD)

plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Partial memory surface')

free_positions = np.ones_like(lattice)
free_positions[np.where(lattice)] = 0

lattice, _, counts_avg, energies = mcmc(lattice_batch = lattice,
                                         num_parallel_proposals = parallel_proposals_per_epoch,
                                         num_samples = num_samples,
                                         burn_in = parallel_proposals_per_epoch//2,
                                         num_species = N_SP_HOPFIELD,
                                         g_ij_batch = g_ij,
                                         g_i_batch = g_i,
                                         free_pos_batch = free_positions)

plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Partial memory recall')
plt.imshow([MEMS[0]])
plt.colorbar()
plt.show()
nonsurface_counts_avg = counts_avg - surface_counts
plt.imshow([nonsurface_counts_avg[1:]])
plt.colorbar()
plt.show()
plot_energies(energies,
              parallel_proposals_per_epoch,
              parallel_proposals_per_epoch//2,
              num_samples)

#####Partial memory 1

In [None]:
partial_memory = np.array([0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])

lattice = surface_with_memory_composition(np.zeros((LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD,
                                                   LATTICE_LENGTH_HOPFIELD),
                                                  dtype=np.int32),
                                         partial_memory,
                                         0,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2,
                                         LATTICE_LENGTH_HOPFIELD//2-1)

surface_counts = get_n_i(lattice, N_SP_HOPFIELD)

plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Partial memory surface')

free_positions = np.ones_like(lattice)
free_positions[np.where(lattice)] = 0

lattice, _, counts_avg, energies = mcmc(lattice_batch = lattice,
                                         num_parallel_proposals = parallel_proposals_per_epoch,
                                         num_samples = num_samples,
                                         burn_in = parallel_proposals_per_epoch//2,
                                         num_species = N_SP_HOPFIELD,
                                         g_ij_batch = g_ij,
                                         g_i_batch = g_i,
                                         free_pos_batch = free_positions)

plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Partial memory recall')
plt.imshow([MEMS[1]])
plt.colorbar()
plt.show()
nonsurface_counts_avg = counts_avg - surface_counts
plt.imshow([nonsurface_counts_avg[1:]])
plt.colorbar()
plt.show()
plot_energies(energies,
              parallel_proposals_per_epoch,
              parallel_proposals_per_epoch//2,
              num_samples)

####Partial polymer recall

In [None]:
polymer_positions_in_plane = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                                       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                                       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                                       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                                       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                                       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                                       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                                       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

def string_lattice(mem):
  lattice = np.zeros((LATTICE_LENGTH_HOPFIELD,
                      LATTICE_LENGTH_HOPFIELD,
                      LATTICE_LENGTH_HOPFIELD),
                     dtype=np.int32)
  for j in range(LATTICE_LENGTH_HOPFIELD):
    for k in range(LATTICE_LENGTH_HOPFIELD):
      if polymer_positions_in_plane[j, k] == 1:
        lattice[j, k, LATTICE_LENGTH_HOPFIELD//2] = np.random.choice(np.argwhere(mem==1).flatten())+1
  return lattice
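`polymer_positions_in_plane` traces a serpentine path: five horizontal runs of eight sites joined by single-site turns, 44 sites in all. A quick way to confirm that a hand-drawn clamp like this is one unbranched chain (every site has one or two occupied neighbors, there are exactly two ends, and everything is connected) is sketched below. `is_simple_chain` is a hypothetical helper, and treating up/down/left/right (von Neumann) neighbors as bonded is an assumption about the cubic lattice.

```python
import numpy as np

def neighbor_counts(mask):
    # For each occupied site of a 2D 0/1 mask, count occupied up/down/left/right neighbors
    padded = np.pad(np.asarray(mask, dtype=np.int32), 1)
    counts = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
              padded[1:-1, :-2] + padded[1:-1, 2:])
    return counts[np.asarray(mask) == 1]

def is_simple_chain(mask):
    # One open, unbranched chain: degrees in {1, 2}, exactly two ends, one connected piece
    deg = neighbor_counts(mask)
    if deg.size < 2 or np.any(deg == 0) or np.any(deg > 2) or np.count_nonzero(deg == 1) != 2:
        return False
    sites = {(int(i), int(j)) for i, j in zip(*np.nonzero(mask))}
    start = next(iter(sites))
    seen, stack = {start}, [start]
    while stack:  # depth-first flood fill over 4-neighbors
        i, j = stack.pop()
        for nb in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
            if nb in sites and nb not in seen:
                seen.add(nb)
                stack.append(nb)
    return len(seen) == len(sites)
```

Applied to `polymer_positions_in_plane`, `is_simple_chain` returns `True`; adding a stray site next to a run, or deleting a turn site, makes it return `False`.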

In [None]:
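# Reuses `partial_memory` from the most recently run partial-memory cell
# (Partial memory 1 when run in order), hence the comparison with MEMS[1] below.
lattice = string_lattice(partial_memory)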

free_positions = np.ones_like(lattice)
free_positions[np.where(lattice)] = 0

surface_counts = get_n_i(lattice, N_SP_HOPFIELD)

plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Partial polymer clamp')

lattice, _, counts_avg, energies = mcmc(lattice_batch = lattice,
                                         num_parallel_proposals = parallel_proposals_per_epoch,
                                         num_samples = num_samples,
                                         burn_in = parallel_proposals_per_epoch//2,
                                         num_species = N_SP_HOPFIELD,
                                         g_ij_batch = g_ij,
                                         g_i_batch = g_i,
                                         free_pos_batch = free_positions)

plot_memory_lattice(lattice,
                    N_SP_HOPFIELD,
                    title='Partial polymer recall')
plt.imshow([MEMS[1]])
plt.colorbar()
plt.show()
nonsurface_counts_avg = counts_avg - surface_counts
plt.imshow([nonsurface_counts_avg[1:]])
plt.colorbar()
plt.show()
plot_energies(energies,
              parallel_proposals_per_epoch,
              parallel_proposals_per_epoch//2,
              num_samples)


##Surface pattern recognition

###Globals and helper functions

In [None]:
LATTICE_LENGTH_SPR = 8

N_DATA = 2
N_CLASS = 2
N_HIDDEN = 4
N_SP_SPR = 1 + N_DATA + N_CLASS + N_HIDDEN

COLOR_DICT_SPR = {1: 'k',
                  2: 'gray',
                  3: 'r',
                  4: 'b',
                  5: 'lawngreen',
                  6: 'yellow',
                  7: 'magenta',
                  8: 'cyan'}

def checkerboard(length):
    arr = np.ones((length, length), dtype=np.int32)
    arr[1::2, ::2] = 2
    arr[::2, 1::2] = 2
    return arr


def halfnhalf(length):
    arr = np.ones((length, length), dtype=np.int32)
    arr[length//2:, :] = 2
    return arr


def shuffled_hidden(length):
    arr = np.zeros((length, length), dtype=np.int32)
    arr[:length//2, :length//2] = 5
    arr[length//2:, :length//2] = 6
    arr[:length//2, length//2:] = 7
    arr[length//2:, length//2:] = 8
    arr = arr.flatten()
    np.random.shuffle(arr)
    arr = arr.reshape((length, length))
    return arr


def training_distribution():
    lattice = np.zeros((LATTICE_LENGTH_SPR,
                        LATTICE_LENGTH_SPR,
                        LATTICE_LENGTH_SPR), dtype=np.int32)
    # Set the initial state
    lattice[0] = checkerboard(LATTICE_LENGTH_SPR)
    lattice[-1] = halfnhalf(LATTICE_LENGTH_SPR)
    lattice[2] = 3
    lattice[-3] = 4
    lattice[1] = shuffled_hidden(LATTICE_LENGTH_SPR)
    lattice[-2] = shuffled_hidden(LATTICE_LENGTH_SPR)
    return lattice

def test_lattice():
  # Same clamped surfaces and species counts as training_distribution(), but
  # every interior site (layers 1 through -2) is randomly permuted
  lattice = training_distribution()
  innards = lattice[1:-1, :, :].flatten()
  np.random.shuffle(innards)
  lattice[1:-1, :, :] = innards.reshape((LATTICE_LENGTH_SPR-2, LATTICE_LENGTH_SPR, LATTICE_LENGTH_SPR))
  return lattice
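`test_lattice` keeps the two data surfaces (layers 0 and -1) and permutes everything in between, so the class and hidden species keep their training counts but lose their positions. The same operation on a generic array, written against NumPy's `Generator` API so it can be seeded, might look like this; `shuffle_interior` is a hypothetical name.

```python
import numpy as np

def shuffle_interior(lattice, rng=None):
    """Sketch of the randomization in test_lattice: permute all sites strictly between
    the first and last layers along axis 0, leaving those two boundary layers fixed."""
    rng = np.random.default_rng() if rng is None else rng
    out = lattice.copy()
    interior_shape = out[1:-1].shape
    # permutation returns a shuffled copy, so the input lattice is never modified
    out[1:-1] = rng.permutation(out[1:-1].ravel()).reshape(interior_shape)
    return out
```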

In [None]:
plot_lattice(training_distribution(), N_SP_SPR, color_dict=COLOR_DICT_SPR)

plot_lattice(test_lattice(), N_SP_SPR, color_dict=COLOR_DICT_SPR)

###Training initialization

In [None]:
max_epochs = 100000
parallel_proposals_per_epoch = LATTICE_LENGTH_SPR ** 4 * 100
save_frequency = 10
learning_rate = 50 / LATTICE_LENGTH_SPR ** 3
num_samples = 100

epoch = 0
g_ij = np.zeros((N_SP_SPR, N_SP_SPR))
norm_sums_ij = np.zeros_like(g_ij, dtype=np.float64)

folder_filename = f'surfacePatternRecognition_{time.strftime("%Y%m%d-%H%M%S")}'
!mkdir $folder_filename
with open(f'{folder_filename}/hyperparameters.txt', 'w') as f:
  f.write(f'parallel_proposals_per_epoch: {parallel_proposals_per_epoch}\n')
  f.write(f'learning_rate: {learning_rate}\n')

###Testing function

In [None]:
def test(g_ij):
  lattice = test_lattice()

  free_pos = np.ones((LATTICE_LENGTH_SPR,
                      LATTICE_LENGTH_SPR,
                      LATTICE_LENGTH_SPR), dtype=bool)

  # Data species 1 and 2 (the two patterned surfaces) are clamped; all other sites are free
  mask = np.isin(lattice, [1, 2])
  free_pos[mask] = False

  lattice, _, _, energies = mcmc(lattice_batch = lattice,
                                 num_parallel_proposals = parallel_proposals_per_epoch,
                                 num_samples = num_samples,
                                 burn_in = parallel_proposals_per_epoch//2,
                                 num_species = N_SP_SPR,
                                 g_ij_batch = g_ij,
                                 free_pos_batch = free_pos,
                                 kT_high_batch = 10.0,
                                 kT_low_batch = 1.0,
                                 canonical = True,
                                 ranged = True)

  plot_lattice(lattice, N_SP_SPR, color_dict=COLOR_DICT_SPR)
  plot_energies(energies,
                parallel_proposals_per_epoch,
                parallel_proposals_per_epoch//2,
                num_samples)

###Training epoch function

In [None]:
def training_epoch(g_ij, learning_rate, adagrad=False, norm_sums_ij=None):
  dw = np.zeros_like(g_ij)

  # Wake phase
  lattice = training_distribution()

  free_pos = np.ones((LATTICE_LENGTH_SPR,
                      LATTICE_LENGTH_SPR,
                      LATTICE_LENGTH_SPR), dtype=bool)

  # Solvent (0), data (1-2) and class (3-4) species are clamped; only hidden species (5-8) move
  mask = np.isin(lattice, [0, 1, 2, 3, 4])
  free_pos[mask] = False

  lattice, n_ij, _, energies = mcmc(lattice_batch = lattice,
                                    num_parallel_proposals = parallel_proposals_per_epoch,
                                    num_samples = num_samples,
                                    burn_in = parallel_proposals_per_epoch//2,
                                    num_species = N_SP_SPR,
                                    g_ij_batch = g_ij,
                                    free_pos_batch = free_pos,
                                    kT_high_batch = 10.0,
                                    kT_low_batch = 1.0,
                                    canonical = True,
                                    ranged = True)

  dw -= n_ij

  # Sleep phase
  lattice = test_lattice()
  free_pos = np.ones((LATTICE_LENGTH_SPR,
                      LATTICE_LENGTH_SPR,
                      LATTICE_LENGTH_SPR), dtype=bool)

  # Only the data species 1 and 2 (the two patterned surfaces) are clamped
  mask = np.isin(lattice, [1, 2])
  free_pos[mask] = False

  lattice, n_ij, _, energies = mcmc(lattice_batch = lattice,
                                    num_parallel_proposals = parallel_proposals_per_epoch,
                                    num_samples = num_samples,
                                    burn_in = parallel_proposals_per_epoch//2,
                                    num_species = N_SP_SPR,
                                    g_ij_batch = g_ij,
                                    free_pos_batch = free_pos,
                                    kT_high_batch = 10.0,
                                    kT_low_batch = 1.0,
                                    canonical = True,
                                    ranged = True)

  dw += n_ij

  if adagrad:
    norm_sums_ij += dw ** 2
    g_ij += learning_rate * dw / np.sqrt(norm_sums_ij + 1e-8)
  else:
    g_ij += learning_rate * dw
  return g_ij, norm_sums_ij
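The wake and sleep phases of `training_epoch` differ mainly in which species are clamped. In the wake phase, solvent, data and class species (0 to 4) are frozen, so only the hidden species (5 to 8) rearrange; in the sleep phase (and in `test`) only the two data surfaces (species 1 and 2) are frozen, and the class species must find their own positions. A minimal sketch of the two masks, with the hypothetical helper `free_mask` and a toy lattice:

```python
import numpy as np

WAKE_CLAMPED = [0, 1, 2, 3, 4]   # solvent, data, class
SLEEP_CLAMPED = [1, 2]           # data surfaces only

def free_mask(lattice, clamped_species):
    # True where a site may be resampled, False where it holds a clamped species
    return ~np.isin(lattice, clamped_species)

toy = np.array([0, 1, 2, 3, 4, 5, 8])
print(free_mask(toy, WAKE_CLAMPED).tolist())   # [False, False, False, False, False, True, True]
print(free_mask(toy, SLEEP_CLAMPED).tolist())  # [True, False, False, True, True, True, True]
```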

###Training

In [None]:
for _ in range(max_epochs):
  g_ij, norm_sums_ij = training_epoch(g_ij,
                                      learning_rate,
                                      adagrad=True,
                                      norm_sums_ij=norm_sums_ij)

  if epoch % save_frequency == 0:
    clear_output(wait=True)
    print(f'Epoch: {epoch}')
    np.save(f'{folder_filename}/g_ij_{epoch}', g_ij)
    test(g_ij)
    plt.imshow(g_ij)
    plt.colorbar()
    plt.show()

  epoch = epoch + 1

###Paper-reported parameters and plotting
Running this cell overwrites the trained $G_{ij}$.

In [None]:
g_ij = np.load('spr_g_ij.npy')
plot_g_ij(g_ij)

In [None]:
test(g_ij)

####Red and blue species heatmaps

In [None]:
num_trials = 100

lattices = []

for _ in range(num_trials):
  lattice = test_lattice()
  lattices.append(lattice)

free_pos = np.ones((LATTICE_LENGTH_SPR,
                    LATTICE_LENGTH_SPR,
                    LATTICE_LENGTH_SPR), dtype=bool)

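# Data species 1 and 2 (the two patterned surfaces) are clamped. Every test
# lattice shares the same surfaces, so the mask from the last one serves for all.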
mask = np.isin(lattice, [1, 2])
free_pos[mask] = False

free_pos_s = np.array([free_pos] * num_trials)

lattices, _, _, energies = mcmc(lattice_batch = lattices,
                                num_parallel_proposals = parallel_proposals_per_epoch,
                                num_samples = num_samples,
                                burn_in = parallel_proposals_per_epoch//2,
                                num_species = N_SP_SPR,
                                g_ij_batch = g_ij,
                                free_pos_batch = free_pos_s,
                                kT_high_batch = 10.0,
                                kT_low_batch = 1.0,
                                canonical = True,
                                ranged = True)

In [None]:
red_heatmaps = []
blue_heatmaps = []
for lattice in lattices:
    # Indicator lattices: 1 where the red (species 3) or blue (species 4)
    # class species sits, 0 everywhere else
    red_lattice = (lattice == 3).astype(np.int32)
    blue_lattice = (lattice == 4).astype(np.int32)
    # Sum along axis 2 to get a 2D occupancy map for this trial
    red_heatmaps.append(np.sum(red_lattice, axis=2))
    blue_heatmaps.append(np.sum(blue_lattice, axis=2))

red_heatmaps = np.mean(red_heatmaps, axis=0)
red_heatmaps = np.transpose(red_heatmaps)
blue_heatmaps = np.mean(blue_heatmaps, axis=0)
blue_heatmaps = np.transpose(blue_heatmaps)

plt.imshow(red_heatmaps, cmap='hot')
plt.title('Red species heatmap')
plt.colorbar()
plt.show()

plt.imshow(blue_heatmaps, cmap='hot')
plt.title('Blue species heatmap')
plt.colorbar()
plt.show()
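The heatmaps above sum over axis 2 and average across trials. Because the two clamped surfaces sit at opposite ends of axis 0, a one-dimensional profile along that axis is a compact complement: it shows in which layers, on average, the red (3) and blue (4) class species end up (in `training_distribution` they occupy layers 2 and -3). A sketch with the hypothetical helper `layer_profile`:

```python
import numpy as np

def layer_profile(lattices, species, axis=0):
    """Mean number of `species` molecules in each layer along `axis` (0, 1 or 2),
    averaged over a batch of equally shaped 3D lattices."""
    stack = np.asarray(lattices)  # shape (n, L, L, L)
    # Sum over the two lattice axes that are not `axis` (offset by 1 for the batch axis)
    other = tuple(a + 1 for a in range(stack.ndim - 1) if a != axis)
    return (stack == species).sum(axis=other).mean(axis=0)
```

For example, `plt.plot(layer_profile(lattices, 3), 'r')` and `plt.plot(layer_profile(lattices, 4), 'b')` would overlay the two class species' profiles.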

##Recognition of concentration profiles mapped from MNIST digits

###Plotting functions for MNIST tests

In [None]:
def image_2d_arrays(arr, titles=None, numbers=True, figsizeparam=4):
    if titles is None:
        titles = ['']*len(arr)
    n = len(arr)
    # squeeze=False keeps axs indexable as a 1D array even when n == 1
    fig, axs = plt.subplots(1, n, figsize=(n*figsizeparam, figsizeparam), squeeze=False)
    axs = axs[0]
    #Set range for color scale to the min and max over all arrays
    all_elements = np.concatenate([ar.flatten() for ar in arr])
    vmin = np.min(all_elements)
    vmax = np.max(all_elements)
    for i in range(n):
        axs[i].set_title(titles[i])
        axs[i].imshow(arr[i], vmin=vmin, vmax=vmax, cmap='viridis')
        # Add array's float values to imshow
        ax = axs[i]
        if numbers:
            for (j,k),label in np.ndenumerate(arr[i]):
                ax.text(k,j,np.round(label,2),ha='center',va='center',color='white',path_effects=[pe.withStroke(linewidth=2, foreground="black")])
    #Hide tick marks
    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


def list_plot(arr):
    plt.plot(arr)
    plt.show()


def image_2d_array(arr, title=None, numbers=False, figsize=(12,12)):
    if title is None:
        title = ''
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.set_title(title)
    ax.imshow(arr, cmap='viridis')
    # Add array's float values to imshow
    if numbers:
        for (i,j),label in np.ndenumerate(arr):
            ax.text(j,i,np.round(label,2),ha='center',va='center',color='white', path_effects=[pe.withStroke(linewidth=2, foreground="black")])
    #If numbers is false, plot max and min as key to the right of the plot
    else:
        ax.text(1.1,0.5,np.round(np.max(arr),2),ha='center',va='center', transform=ax.transAxes)
        ax.text(1.1,0.1,np.round(np.min(arr),2),ha='center',va='center', transform=ax.transAxes)
    #Hide tick marks
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()

###Globals and helper functions

In [None]:
# 28 * 28 = 784 digit species + 10 class species, rounded up to 1000 total with 206 hidden species
# 32^3 = 32768 lattice sites, about 33 sites per species if divided equally

N_SP_DIGIT = 28*28
N_SP_CLASS = 10
N_SP_HIDDEN = 1000 - N_SP_DIGIT - N_SP_CLASS
N_SP_MNIST = N_SP_DIGIT + N_SP_CLASS + N_SP_HIDDEN
LATTICE_LENGTH_MNIST = 32
# Dimensions for plotting
DIGIT_DIMS = (28,28)
CLASS_DIMS = (1,10)
HIDDEN_DIMS = (2,103)
DIGIT_SHARE = LATTICE_LENGTH_MNIST**3 * N_SP_DIGIT/1000
CLASS_SHARE = LATTICE_LENGTH_MNIST**3 * N_SP_CLASS/1000
HIDDEN_SHARE = LATTICE_LENGTH_MNIST**3 * N_SP_HIDDEN/1000

TRAINING_STEPS = 10000
SWAPS = LATTICE_LENGTH_MNIST**3 * 64
SWAPS_PER_STEP = LATTICE_LENGTH_MNIST**3 / 4**3
STEPS = SWAPS // SWAPS_PER_STEP
burn_in_STEPS = 10000
COLLECTION_STEPS = 10000
SAMPLES_TO_COLLECT = 100
SAMPLERATE = int(COLLECTION_STEPS // SAMPLES_TO_COLLECT)
LEARNINGRATE = 0.025
BATCHSIZE = 1
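The share constants above are floats, and the counts helpers below floor them per species. A quick recomputation makes the lattice budget explicit. `lattice_budget` is a hypothetical helper whose defaults are copied from this cell. It gives 206 hidden species, 32768 sites, about 32.8 sites per species, and 32 sites for each class species (test-time split) and each hidden species. These floors, together with the digit normalization, always leave the total below 32768, so every counts helper has to pad the remainder with random species.

```python
def lattice_budget(lattice_length=32, n_total=1000, n_digit=28 * 28, n_class=10):
    """Recompute the species split and per-partition site shares of the globals cell."""
    n_hidden = n_total - n_digit - n_class
    sites = lattice_length ** 3
    shares = {
        "digit": sites * n_digit / n_total,
        "class": sites * n_class / n_total,
        "hidden": sites * n_hidden / n_total,
    }
    return {
        "n_hidden": n_hidden,
        "sites": sites,
        "shares": shares,
        "sites_per_species": sites / n_total,
        # Per-species counts, floored the same way as in digit_to_counts_test
        "class_per_species_test": int(shares["class"] // n_class),
        "hidden_per_species": int(shares["hidden"] // n_hidden),
    }


budget = lattice_budget()
print(budget["n_hidden"], budget["sites"], budget["sites_per_species"])
print(budget["class_per_species_test"], budget["hidden_per_species"])
```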

In [None]:
def random_field():
    return np.random.randint(N_SP_MNIST, size=(LATTICE_LENGTH_MNIST, LATTICE_LENGTH_MNIST, LATTICE_LENGTH_MNIST)).astype(np.int32)


def shuffle_field(field):
  shuffled_field = np.copy(field).flatten()
  np.random.shuffle(shuffled_field)
  return shuffled_field.reshape((LATTICE_LENGTH_MNIST, LATTICE_LENGTH_MNIST, LATTICE_LENGTH_MNIST))


def shuffled_input_field(v):
  return shuffle_field(np.reshape(np.repeat(np.arange(N_SP_MNIST), v), (LATTICE_LENGTH_MNIST, LATTICE_LENGTH_MNIST, LATTICE_LENGTH_MNIST)))


# Turn an MNIST digit into a vector of species counts for the lattice.
# The training version puts the entire class share on the correct class species...
def digit_to_counts_training(digit, classification):
  # Digit counts should be normalized
  digitcounts = digit * (DIGIT_SHARE // np.sum(digit))
  classcounts = np.zeros(N_SP_CLASS, dtype=np.int32)
  classcounts[classification] = int(CLASS_SHARE)
  if N_SP_HIDDEN != 0:
    hiddens = np.ones(N_SP_HIDDEN, dtype=np.int32) * int(HIDDEN_SHARE // N_SP_HIDDEN)
  else:
    hiddens = np.zeros(0, dtype=np.int32)
  counts = np.concatenate((digitcounts, classcounts, hiddens)).astype(np.int32)
  # If the total is below LATTICE_LENGTH_MNIST^3, add molecules of random species until it matches
  if np.sum(counts) < LATTICE_LENGTH_MNIST**3:
    for _ in range(LATTICE_LENGTH_MNIST**3 - np.sum(counts)):
      counts[np.random.randint(N_SP_MNIST)] += 1
  return counts


# ... while the test version has the class counts split between the classes
def digit_to_counts_test(digit):
  # Digit counts should be normalized
  digitcounts = digit * (DIGIT_SHARE // np.sum(digit))
  classcounts = np.ones(N_SP_CLASS, dtype=np.int32) * (CLASS_SHARE // N_SP_CLASS)
  if N_SP_HIDDEN != 0:
    hiddens = np.ones(N_SP_HIDDEN, dtype=np.int32) * int(HIDDEN_SHARE // N_SP_HIDDEN)
  else:
    hiddens = np.zeros(0, dtype=np.int32)
  counts = np.concatenate((digitcounts, classcounts, hiddens)).astype(np.int32)
  # If the total is below LATTICE_LENGTH_MNIST^3, add molecules of random species until it matches
  if np.sum(counts) < LATTICE_LENGTH_MNIST**3:
    for _ in range(LATTICE_LENGTH_MNIST ** 3 - np.sum(counts)):
      counts[np.random.randint(N_SP_MNIST)] += 1
  return counts


def classification_test_counts(classification):
  #Uniformly distribute digit counts
  digitcounts = np.ones(N_SP_DIGIT, dtype=np.int32) * (DIGIT_SHARE // N_SP_DIGIT)
  classcounts = np.zeros(N_SP_CLASS)
  classcounts[classification] = int(CLASS_SHARE)
  if N_SP_HIDDEN != 0:
    hiddens = np.ones(N_SP_HIDDEN, dtype=np.int32) * (HIDDEN_SHARE // N_SP_HIDDEN)
  else:
    hiddens = np.zeros(0, dtype=np.int32)
  counts = np.concatenate((digitcounts, classcounts, hiddens)).astype(np.int32)
  if np.sum(counts) < LATTICE_LENGTH_MNIST**3:
    for _ in range(LATTICE_LENGTH_MNIST**3 - np.sum(counts)):
      counts[np.random.randint(N_SP_MNIST)] += 1
  return counts
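The three counts helpers share two steps. First they pad an integer count vector up to exactly `LATTICE_LENGTH_MNIST**3` sites. Then `shuffled_input_field` places those molecules at random sites. The sketch below isolates both steps on a toy 2x2x2 lattice.

Some things here are assumptions and not part of the notebook:

* `pad_counts_to_total` and `counts_to_lattice` are hypothetical names. Neither exists in `boltzmann_liquids_cuda.py`.
* Both helpers use NumPy's `Generator` API instead of the global `np.random` state the notebook uses.

The padding draws one uniform species per missing site, which matches the loop's distribution but runs in a single vectorized call. As in the notebook, padding may land on clamped species too.

```python
import numpy as np


def pad_counts_to_total(counts, total, rng=None):
    """Add molecules of uniformly random species until counts sums to total (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    counts = np.asarray(counts, dtype=np.int64).copy()
    deficit = total - int(counts.sum())
    if deficit < 0:
        # The notebook's loop skips this case silently; the later reshape would then fail.
        raise ValueError("counts already exceed the number of lattice sites")
    draws = rng.integers(len(counts), size=deficit)  # one species id per missing site
    counts += np.bincount(draws, minlength=len(counts))
    return counts.astype(np.int32)


def counts_to_lattice(counts, lattice_length, rng=None):
    """Place exactly counts[s] molecules of species s at random sites (np.repeat + shuffle)."""
    rng = np.random.default_rng() if rng is None else rng
    counts = np.asarray(counts)
    if counts.sum() != lattice_length ** 3:
        raise ValueError("counts must sum to lattice_length**3")
    flat = np.repeat(np.arange(len(counts)), counts)  # species ids, grouped by species
    rng.shuffle(flat)                                 # random placement
    return flat.reshape((lattice_length,) * 3).astype(np.int32)


rng = np.random.default_rng(0)
counts = pad_counts_to_total([3, 0, 2], total=2 ** 3, rng=rng)  # 5 molecules + 3 random ones
lattice = counts_to_lattice(counts, lattice_length=2, rng=rng)
print(counts, np.bincount(lattice.ravel(), minlength=3))  # the two arrays agree
```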

###Training initialization

In [None]:
max_epochs = 10000
parallel_proposals_per_epoch = 20000
save_frequency = 10
learning_rate = 0.025
num_samples = 10

epoch = 0
g_ij = np.zeros((N_SP_MNIST, N_SP_MNIST))
g_i = np.zeros(N_SP_MNIST)
norm_sums_ij = np.zeros_like(g_ij, dtype=np.float64)
norm_sums_i = np.zeros_like(g_i, dtype=np.float64)

folder_filename = f'mnistRecognition_{time.strftime("%Y%m%d-%H%M%S")}'

!mkdir $folder_filename
with open(f'{folder_filename}/hyperparameters.txt', 'w') as f:
  f.write(f'parallel_proposals_per_epoch: {parallel_proposals_per_epoch}\n')
  f.write(f'learning_rate: {learning_rate}\n')
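The `!mkdir` line is IPython shell syntax, so this cell runs only inside Colab or Jupyter. If you run the code as a plain Python script, a standard-library replacement can create the timestamped folder instead. It can also record every training hyperparameter as JSON, whereas the cell above logs only two of them. `make_run_folder` is a hypothetical helper and the JSON file name is an assumption.

```python
import json
import os
import time


def make_run_folder(prefix, hyperparameters, root="."):
    """Create a timestamped run folder and write hyperparameters.json into it (hypothetical helper)."""
    folder = os.path.join(root, f'{prefix}_{time.strftime("%Y%m%d-%H%M%S")}')
    os.makedirs(folder, exist_ok=True)  # portable replacement for `!mkdir`
    with open(os.path.join(folder, "hyperparameters.json"), "w") as f:
        json.dump(hyperparameters, f, indent=2)
    return folder
```

Usage would look like `folder_filename = make_run_folder('mnistRecognition', {'parallel_proposals_per_epoch': parallel_proposals_per_epoch, 'learning_rate': learning_rate, 'num_samples': num_samples})`.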

###Import MNIST digits

In [None]:
import csv

#Import MNIST data
def import_mnist(filename):
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile)
        rows = []
        for row in reader:
            rows.append(row)
            # The first and last values of each row carry Mathematica's opening and closing brackets, respectively; strip them
            rows[-1][0] = rows[-1][0][1:]
            rows[-1][-1] = rows[-1][-1][:-1]
        #Convert to numpy array
        rows = np.array(rows)
        #Convert to float
        rows = rows.astype(np.float64)
        return rows

mnistTrain = []
mnistTest = []
for i in range(10):
    mnistTrain.append(import_mnist(f'mnistTrain{i}.csv'))
    mnistTest.append(import_mnist(f'mnistTest{i}.csv'))

####MNIST data sanity check

In [None]:
# Sanity check MNIST data
image_2d_array(np.reshape(mnistTrain[0][np.random.randint(len(mnistTrain[0]))], DIGIT_DIMS), title='Random training digit', numbers=True)
image_2d_array(np.reshape(mnistTest[9][np.random.randint(len(mnistTest[9]))], DIGIT_DIMS), title='Random test digit', numbers=True)

###Training epoch function

In [None]:
def training_epoch(g_ij, g_i, adagrad=False, norm_sums_ij=None, norm_sums_i=None):
  dw = np.zeros((N_SP_MNIST, N_SP_MNIST))
  db = np.zeros(N_SP_MNIST)

  lattices = []
  for i in range(N_SP_CLASS):
    lattices.append(shuffled_input_field(digit_to_counts_training(mnistTrain[i][np.random.randint(len(mnistTrain[i]))], i)))
  # Wake phase; only hidden species are free
  free_species = np.concatenate((np.zeros(N_SP_DIGIT, dtype=bool),
                                np.zeros(N_SP_CLASS, dtype=bool),
                                np.ones(N_SP_HIDDEN, dtype=bool)))
  lattices, n_ijs, n_is, _ = mcmc(lattice_batch = lattices,
                                  num_parallel_proposals = parallel_proposals_per_epoch,
                                  num_samples = num_samples,
                                  burn_in = parallel_proposals_per_epoch//2,
                                  num_species = N_SP_MNIST,
                                  g_ij_batch = g_ij,
                                  g_i_batch = g_i,
                                  free_species_batch = free_species,
                                  hybrid = True)

  for i in range(N_SP_CLASS):
    dw -= n_ijs[i]
    db -= n_is[i]

  # Sleep phase; all species are free
  lattices = []
  for i in range(N_SP_CLASS):
    lattices.append(random_field())
  free_species = np.concatenate((np.ones(N_SP_DIGIT, dtype=bool),
                                np.ones(N_SP_CLASS, dtype=bool),
                                np.ones(N_SP_HIDDEN, dtype=bool)))
  lattices, n_ijs, n_is, _ = mcmc(lattice_batch = lattices,
                                  num_parallel_proposals = parallel_proposals_per_epoch,
                                  num_samples = num_samples,
                                  burn_in = parallel_proposals_per_epoch//2,
                                  num_species = N_SP_MNIST,
                                  g_ij_batch = g_ij,
                                  g_i_batch = g_i,
                                  free_species_batch = free_species,
                                  hybrid = True)

  for i in range(N_SP_CLASS):
    dw += n_ijs[i]
    db += n_is[i]

  if adagrad:
    norm_sums_ij += dw ** 2
    g_ij += learning_rate * dw / np.sqrt(norm_sums_ij + 1e-8)
    norm_sums_i += db ** 2
    g_i += learning_rate * db / np.sqrt(norm_sums_i + 1e-8)
  else:
    g_ij += learning_rate * dw
    g_i += learning_rate * db

  return g_ij, g_i, norm_sums_ij, norm_sums_i
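With `adagrad=True`, `training_epoch` first sums the statistics over the ten lattices into $\Delta = \sum (n^{\text{sleep}} - n^{\text{wake}})$. It then accumulates $S \leftarrow S + \Delta^2$ elementwise and updates $G \leftarrow G + \eta\,\Delta / \sqrt{S + 10^{-8}}$.

Because of the per-entry scaling, the first nonzero update to each entry has magnitude close to $\eta$ (`learning_rate`), whatever the size of $\Delta$. Later updates shrink as $S$ grows. The sketch below reproduces just this update on toy arrays. `adagrad_update` is a hypothetical stand-alone helper, not a function from the notebook or `boltzmann_liquids_cuda.py`.

```python
import numpy as np


def adagrad_update(param, grad, accum, learning_rate, eps=1e-8):
    """One AdaGrad step in the form used by training_epoch (in-place on param and accum).

    `grad` is sleep-phase minus wake-phase statistics, so the step is added, not subtracted.
    """
    accum += grad ** 2
    param += learning_rate * grad / np.sqrt(accum + eps)
    return param, accum


g = np.zeros(3)
s = np.zeros(3)
g, s = adagrad_update(g, np.array([100.0, -0.5, 0.0]), s, learning_rate=0.025)
print(g)  # roughly +0.025 and -0.025 despite very different gradient sizes; 0 where grad is 0
```

Repeating the same gradient a second time gives steps smaller by a factor of $1/\sqrt{2}$, since $S$ has doubled.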

###Testing function

In [None]:
def digit_clamp_test(g_ij, g_i):
  lattices = []
  for i in range(N_SP_CLASS):
    lattices.append(shuffled_input_field(digit_to_counts_test(mnistTest[i][np.random.randint(len(mnistTest[i]))])))
  init_lattices = np.copy(lattices)
  free_species = np.concatenate((np.zeros(N_SP_DIGIT, dtype=bool),
                                 np.ones(N_SP_CLASS, dtype=bool),
                                 np.ones(N_SP_HIDDEN, dtype=bool)))
  lattices, n_ijs, n_is, energies = mcmc(lattice_batch = lattices,
                                         num_parallel_proposals = parallel_proposals_per_epoch,
                                         num_samples = num_samples,
                                         burn_in = parallel_proposals_per_epoch//2,
                                         num_species = N_SP_MNIST,
                                         g_ij_batch = g_ij,
                                         g_i_batch = g_i,
                                         free_species_batch = free_species,
                                         hybrid = True)
  return n_ijs, n_is, energies, init_lattices, lattices

###Training

In [None]:
for i in range(max_epochs):
  g_ij, g_i, norm_sums_ij, norm_sums_i = training_epoch(g_ij,
                                                        g_i,
                                                        adagrad=True,
                                                        norm_sums_ij=norm_sums_ij,
                                                        norm_sums_i=norm_sums_i)
  epoch = epoch + 1
  if i % save_frequency == 0:
    clear_output(wait=True)
    print(f'Epoch: {epoch}')
    np.save(f'{folder_filename}/g_ij_{i}', g_ij)
    np.save(f'{folder_filename}/g_i_{i}', g_i)
    n_ijs, n_is, energies, init_lattices, lattices = digit_clamp_test(g_ij, g_i)
    for j in range(N_SP_CLASS):
      print(f'Class: {j}')
      image_2d_array(np.reshape(n_is[j][:N_SP_DIGIT], DIGIT_DIMS), numbers=True)
      image_2d_array(np.reshape(n_is[j][N_SP_DIGIT:N_SP_DIGIT+N_SP_CLASS], CLASS_DIMS), numbers=True)
      image_2d_array(np.reshape(n_is[j][N_SP_DIGIT+N_SP_CLASS:N_SP_MNIST], HIDDEN_DIMS), numbers=True, figsize=(36,36))
      plot_energies(energies[j],
                    parallel_proposals_per_epoch,
                    parallel_proposals_per_epoch//2,
                    num_samples)
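Every `save_frequency` epochs, the loop above writes `g_ij_<i>.npy` and `g_i_<i>.npy` into `folder_filename` (`np.save` appends the `.npy` extension). To evaluate or resume from the newest snapshot, a small loader can pick the largest saved index. It should compare indices numerically, because a string sort would put `g_ij_100` before `g_ij_20`. `load_latest_checkpoint` is a hypothetical helper.

One caveat: the saved index is the loop counter `i`, not the cumulative `epoch`. `i` restarts at 0 whenever the training cell is re-run. After a shorter re-run, files with higher indices from an earlier run can remain, so the largest index is not always the most recent file.

```python
import os
import re

import numpy as np


def load_latest_checkpoint(folder):
    """Return (index, g_ij, g_i) for the largest index among g_ij_<i>.npy files in folder."""
    pattern = re.compile(r"g_ij_(\d+)\.npy$")  # does not match the g_i_<i>.npy files
    indices = [int(m.group(1)) for name in os.listdir(folder) if (m := pattern.match(name))]
    if not indices:
        raise FileNotFoundError(f"no g_ij_<i>.npy files in {folder}")
    i = max(indices)  # numeric, not lexicographic, maximum
    g_ij = np.load(os.path.join(folder, f"g_ij_{i}.npy"))
    g_i = np.load(os.path.join(folder, f"g_i_{i}.npy"))
    return i, g_ij, g_i
```

Usage would look like `i, g_ij, g_i = load_latest_checkpoint(folder_filename)`.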

###Paper-reported parameters
Running this cell overwrites any trained $G_{ij}$ and $G_i$.

In [None]:
g_ij = np.load('mnist_g_ij.npy')
plot_g_ij(g_ij)
g_i = np.load('mnist_g_i.npy')
plot_g_ij(np.reshape(g_i, (10, 100)))

In [None]:
n_ijs, n_is, energies, init_lattices, lattices = digit_clamp_test(g_ij, g_i)
for j in range(N_SP_CLASS):
  print(f'Class: {j}')
  image_2d_array(np.reshape(n_is[j][:N_SP_DIGIT], DIGIT_DIMS), numbers=True)
  image_2d_array(np.reshape(n_is[j][N_SP_DIGIT:N_SP_DIGIT+N_SP_CLASS], CLASS_DIMS), numbers=True)
  image_2d_array(np.reshape(n_is[j][N_SP_DIGIT+N_SP_CLASS:N_SP_MNIST], HIDDEN_DIMS), numbers=True, figsize=(36,36))
  plot_energies(energies[j],
                parallel_proposals_per_epoch,
                parallel_proposals_per_epoch//2,
                num_samples)

####Confusion Matrix

In [None]:
def classification_test(g_ij, g_i, num_tests):
  confusion_matrix = np.zeros((N_SP_CLASS, N_SP_CLASS))
  sums_matrix = np.zeros((N_SP_CLASS, N_SP_CLASS))
  for j in range(num_tests):
    lattices = []
    for i in range(N_SP_CLASS):
      lattices.append(shuffled_input_field(digit_to_counts_test(mnistTest[i][j])))
    free_species = np.concatenate((np.zeros(N_SP_DIGIT, dtype=bool),
                                   np.ones(N_SP_CLASS, dtype=bool),
                                   np.ones(N_SP_HIDDEN, dtype=bool)))
    lattices, n_ijs, n_is, _ = mcmc(lattice_batch = lattices,
                                    num_parallel_proposals = parallel_proposals_per_epoch,
                                    num_samples = num_samples,
                                    burn_in = parallel_proposals_per_epoch//2,
                                    num_species = N_SP_MNIST,
                                    g_ij_batch = g_ij,
                                    g_i_batch = g_i,
                                    free_species_batch = free_species,
                                    hybrid = True)
    for i in range(N_SP_CLASS):
      confusion_matrix[i, np.argmax(n_is[i][N_SP_DIGIT:N_SP_DIGIT+N_SP_CLASS])] += 1
      sums_matrix[i, :] += n_is[i][N_SP_DIGIT:N_SP_DIGIT+N_SP_CLASS]
  return confusion_matrix, sums_matrix

In [None]:
num_tests = 10

confusion_matrix, sums_matrix = classification_test(g_ij, g_i, num_tests)

accuracy = np.trace(confusion_matrix) / np.sum(confusion_matrix)
print(f'Accuracy: {accuracy}')
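For each test digit, `classification_test` adds exactly one prediction to row `i` of `confusion_matrix` (the true class). Every row therefore sums to `num_tests`, and the overall accuracy above equals the mean of the per-class accuracies.

Per-class accuracy, the diagonal divided by the row sums, shows which digits get confused. With `num_tests = 10`, each per-class value can only change in steps of 0.1. The sketch below uses hypothetical helper names on a toy 3-class matrix, with rows as true classes and columns as predictions, as in the plot.

```python
import numpy as np


def per_class_accuracy(confusion_matrix):
    """Fraction of each true class (row) that was predicted correctly."""
    cm = np.asarray(confusion_matrix, dtype=float)
    return np.diag(cm) / cm.sum(axis=1)


def row_normalize(confusion_matrix):
    """Convert counts to per-true-class rates so that each row sums to 1."""
    cm = np.asarray(confusion_matrix, dtype=float)
    return cm / cm.sum(axis=1, keepdims=True)


toy = np.array([[8, 2, 0],
                [1, 9, 0],
                [0, 5, 5]])
print(per_class_accuracy(toy))    # [0.8 0.9 0.5]
print(np.trace(toy) / toy.sum())  # overall accuracy, computed as in the cell above
```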

In [None]:
plt.imshow(confusion_matrix, cmap='gray_r')
plt.xticks(np.arange(N_SP_CLASS))
plt.yticks(np.arange(N_SP_CLASS))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.colorbar()
plt.title(f'Confusion matrix ({num_tests} test samples per class)')
plt.clim(0, num_tests)
plt.savefig(f'confusion_matrix_{num_tests}tests.svg')
plt.show()