In [1]:
import jax
import math
import optax
import numpy as np
import flax.linen as nn
from jax import nn as jnn
from jax import numpy as jnp
import jax.tree_util as tree_util   
from functools import partial, lru_cache
from dataclasses import dataclass, KW_ONLY
from jax import random, jit, vmap, grad, value_and_grad, lax
from typing import NamedTuple, Dict, List, Tuple, Any, Callable, Optional, Union

from flax.linen import MultiHeadAttention, make_causal_mask, Embed
from flax.linen.initializers import xavier_uniform, zeros


In [2]:
class PositionalEncoder(nn.Module):
    d_model : int         	# Hidden dimensionality of the input.	
    max_len : int = 16  	# Maximum length of a sequence to expect.

    def setup(self):
        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = np.zeros((self.max_len, self.d_model))
        position = np.arange(0, self.max_len, dtype=np.float32)[:,None]
        div_term = np.exp(np.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model))
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        pe = pe[None]
        self.pe = jax.device_put(pe)

    def __call__(self, x):
        x = x + self.pe[:, :x.shape[1]]
        return x



In [3]:
class NQS(nn.Module):
    seq_len: int
    num_layers: int
    dropout_prob: float
    
    _: KW_ONLY
	# embedding parameters
    embed_dim: int = 32
    num_embeddings: int	= 2
    

    def setup(self):
        self.layers = [
            MultiHeadAttention(
                num_heads=8,
				kernel_init=xavier_uniform(),
                dropout_rate=self.dropout_prob,
			) 
            	for _ in range(self.num_layers)
        ]
        self.embedding = Embed(num_embeddings=self.num_embeddings, features=self.embed_dim)
        self.pos_encoder = PositionalEncoder(d_model=self.embed_dim, max_len=self.seq_len)
        
		# set the dimension of the amplitude head to the num_embeddings
        self.amplitude_head = nn.Dense(features=self.num_embeddings)

    def __call__(self, x, mask=None, train=True):
        # Setup mask
        mask = make_causal_mask(x) if mask is None else mask

        # Step: Positional encoding + input embedding
        # x = jnp.squeeze(x, axis=-1)
        x = self.embedding(x)
        x = self.pos_encoder(x)
        
        # Step: Get logits from the transformer encoder
        for layer in self.layers:
            x = layer(x, mask=mask, deterministic=not train)
        
		# Step: Apply Log-softmax to get probabilities
        x = self.amplitude_head(x)
        x = nn.activation.log_softmax(x, axis=-1)
        
		# Step: Return the probabilities
        return x

    def get_attention_maps(self, x, mask=None, train=True):
        # A function to return the attention maps within the model for a single application used for visualization purpose later
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask)
            attention_maps.append(attn_map)
            x = l(x, mask=mask, train=train)
        return attention_maps
    
    @staticmethod
    def _generate_mask(sz):
        mask = jnp.triu(jnp.ones((sz, sz)), k=1).T
        mask = jnp.where(mask == 0, -jnp.inf, 0.0)
        # mask = jnp.triu(jnp.ones((sz, sz)), k=1)
        return mask

In [4]:
@lru_cache(maxsize=128, typed=False)
def get_all_interactions_jax(n: int) -> tuple:
    """
    Get all to all interactions from a n by n lattice using the euclidean distances.
    Assume a unit distance (1) between nearest neighbours

    Parameters
    ---
    n: integer representing a side of the square

    Output
    ---
    tuple[unique_pairs, multipliers]
    """

    # Create a grid of coordinates
    x, y = jnp.meshgrid(jnp.arange(n), jnp.arange(n))
    coordinates = jnp.stack([x.flatten(), y.flatten()], axis=1)

    # Calculate distances between all unique pairs
    num_points = coordinates.shape[0]
    distances = jnp.sqrt(
        jnp.sum((coordinates[:, None, :] - coordinates[None, :, :]) ** 2, axis=-1)
    )

    # Mask to select only unique pairs
    mask = jnp.triu(jnp.ones((num_points, num_points), dtype=bool), k=1)

    # Extract unique pairs, distances, and calculate multipliers
    unique_pairs = jnp.argwhere(mask)
    unique_distances = distances[mask]
    multipliers = 1 / unique_distances ** 6

    return unique_pairs, multipliers

In [5]:
@dataclass
class VMC:
    nsamples: int
    n: int
    learning_rate: float
    num_epochs: int
    output_dim: int
    sequence_length: int
    num_hidden_units: int

    def __post_init__(self):
        self.pairs, self.multipliers = get_all_interactions_jax(self.n)
        self.Omega = 1.0
        self.delta = 1.0

    
    
    def sample(self, key, params, model) -> List[Union[float, Tuple[float, ...]]]:
        sample_key, dropout_key = random.split(key)
        samples = jnp.zeros((self.nsamples, 1), dtype=jnp.int32)
        sample_keys = jax.random.split(sample_key, self.sequence_length)
        dropout_keys = jax.random.split(dropout_key, self.sequence_length)

        for i in range(self.sequence_length):
            log_prob = model.apply({'params': params}, samples, train=True, rngs={'dropout': dropout_keys[i]})

			# take the last log-probability
            log_prob = log_prob[:, -1, :]

			# sample from the log-probabilities
            sample = random.categorical(sample_keys[i], log_prob)
            sample = jnp.expand_dims(sample, axis=1)

            samples = sample if i == 0 else jnp.concatenate([samples, sample], axis=1)

        return samples
    
    # def sample(self, key, params, model) -> List[Union[float, Tuple[float, ...]]]:
    #     sample_key, dropout_key = random.split(key)
    #     samples = jnp.zeros((self.nsamples, 1), dtype=jnp.int32)
    #     sample_keys = jax.random.split(sample_key, self.sequence_length)
    #     dropout_keys = jax.random.split(dropout_key, self.sequence_length)
        
        
    #     def step(i, state):
    #         # network_input = lax.dynamic_slice(operand=state, start_indices=(0, 0), slice_sizes=(self.nsamples, i+1))
    #         network_input, samples = state
    #         log_prob = model.apply({'params': params}, network_input, train=True, rngs={'dropout': dropout_keys[i]})
    #         log_prob = log_prob[:, -1, :]
    #         sample = random.categorical(sample_keys[i], log_prob)
    #         sample = jnp.expand_dims(sample, axis=1)
    #         samples = lax.dynamic_update_slice(operand=samples, update=sample, start_indices=(0, i))
    #         network_input = jnp.concatenate([network_input, sample], axis=1)
    #         return network_input, state
        
    #     initial_samples = jnp.zeros((self.nsamples, self.sequence_length), dtype=jnp.int32)
    #     initial_input = jnp.zeros((self.nsamples, 1), dtype=jnp.int32)
    #     samples = lax.fori_loop(0, self.sequence_length, step, (initial_input, initial_samples))
        
    #     return samples
    

    def logpsi(self, samples: List[Union[float, Tuple[float, ...]]], params, model, dropout_key) -> List[float]:
        ss = (0, self.sequence_length - 1)
        nsamples = samples.shape[0]
        data   = samples[:, ss[0]:ss[1]]
        inputs = jnp.concatenate([jnp.zeros((nsamples, 1), dtype=jnp.int32), data], axis = 1)

        log_probs = model.apply({'params': params}, inputs, train=False, rngs={'dropout': dropout_key})

        logP   = jnp.sum(jnp.multiply(log_probs, jnn.one_hot(samples, self.output_dim)), axis=2)
        logP = 0.5 * jnp.sum(logP, axis=1)
        return logP
    
    


    def get_loss(self, params, rng_key, model):
        def l2_loss(x, alpha):
            return alpha * (x ** 2).mean()

        @jit
        def all_reg():
            return sum(
                l2_loss(w, alpha=0.001) for w in tree_util.tree_leaves(params["params"])
            )
        
        s_key, d_key, d_key2 = random.split(rng_key, 3)

        samples = self.sample(s_key, params, model)
        log_psi = self.logpsi(samples, params, model, d_key)
        e_loc = self.local_energy(samples, params, model, log_psi, d_key2)
        e_o = e_loc.mean()

        # We expand the equation in the text above
        first_term = 2 * jnp.multiply(log_psi, e_loc)
        second_term = 2 * jnp.multiply(e_o, log_psi)

        # l2_reg = all_reg()

        loss = jnp.mean(first_term - second_term)
        # loss = l2_reg(params) + loss
        # loss += l2_reg
        return loss, e_loc
    

    def train(self, rng_key, params, model):
        rng_key = random.PRNGKey(0)
        
        optimizer = optax.adam(learning_rate=self.learning_rate)
        opt_state = optimizer.init(params)

        loss_fn = self.get_loss
    

        @partial(jit, static_argnums=(3,))
        def step(params, rng_key, opt_state, get_loss=loss_fn):
            rng_key, new_key = random.split(rng_key)

            value, grads = value_and_grad(get_loss, has_aux=True)(params, rng_key, model)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return new_key, params, opt_state, value

        energies = []
        for i in range(self.num_epochs):
            rng_key, params, opt_state, (loss, eloc) = step(params, rng_key, opt_state)
            energies.append(eloc)

            if i % 100 == 0:
                print(f'step {i}, loss: {loss}')

        return energies
    
    
    
    def local_energy(self, samples, params, model, log_psi, dropout_key) -> List[float]:
        output = jnp.zeros((samples.shape[0]), dtype=jnp.float32)

        def step_fn_chemical(i, state):
            s, output = state
            output += - self.delta * s[:, i]
            return s, output

        def step_fn_intr(i, state):
            samples, pairs, multipliers, output = state
            output += multipliers[i] * samples[:, pairs[i, 0]] * samples[:, pairs[i, 1]]
            return samples, pairs, multipliers, output


        def step_fn_transverse(i, state):
            s, output = state
            flipped_state = s.at[:, i].set(1 - s[:, i])
            flipped_logpsi = self.logpsi(flipped_state, params, model, dropout_key)
            output += - self.Omega * jnp.exp(flipped_logpsi - log_psi)
            return s, output


    	# Interaction Term
        _, _, _, interaction_term = lax.fori_loop(0, 120, step_fn_intr, (samples, self.pairs, self.multipliers, output))
        # Off Diagonal Term
        _, transverse_field = lax.fori_loop(0, 16, step_fn_transverse, (samples, output))
        # _, transverse_field = lax.fori_loop(0, 16, step_fn_transverse, (samples, output))
        # Occupancy Term
        _, chemical_potential = lax.fori_loop(0, 16, step_fn_chemical, (samples, output))

        # Total energy
        loc_e = transverse_field + chemical_potential + interaction_term
		
        return chemical_potential

In [6]:
class VMCConfig(NamedTuple):
    nsamples: int = 500
    n: int = 4
    learning_rate: float = 0.01
    num_epochs: int = 500
    output_dim: int = 2
    sequence_length: int = 16
    num_hidden_units: int = 64

In [7]:
input_shape = (1,16)
nqs = NQS(seq_len=16, num_layers=8, dropout_prob=0.5)
params = nqs.init(jax.random.PRNGKey(0), jnp.ones(input_shape, dtype=jnp.int32))['params']

In [8]:
config = VMCConfig()
vmc = VMC(**config._asdict())
rng_key = random.PRNGKey(0)
densities = vmc.train(rng_key, params, nqs)

step 0, loss: 1.4691606760025024
step 100, loss: -750928.9375
step 200, loss: 0.0


KeyboardInterrupt: 