In [45]:
from canonical_model_jax import * 
import time

In [46]:
def phi(u): 
    return 1 / (1 + jnp.exp(-u))

def psi(x): 
    return jnp.log(x / (1 - x))

In [47]:
Ps = jnp.linspace(1000, 50000, 41)

In [51]:
class OlfactorySensing:
    def __init__(self, N=100, n=2, M=30, P=1000, sigma_0=1e-2, sigma_c=2.0): 
        self.N = N
        self.n = n
        self.M = M
        self.P = P
        self.sigma_0 = sigma_0
        self.sigma_c = sigma_c
        self.set_sigma()
        self.set_vasicek_window()
        self.W = None  # Initialize W as None; it may be set later with set_random_W

    def _tree_flatten(self):
        # Treat `W` as a dynamic value, while the rest are static
        children = (self.W,)  # W is the only dynamic value
        aux_data = {
            'N': self.N,
            'n': self.n,
            'M': self.M,
            'P': self.P,
            'sigma_0': self.sigma_0,
            'sigma_c': self.sigma_c,
            'vasicek_window': self.vasicek_window,
        }
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # Recreate an instance of OlfactorySensing from children and aux_data
        instance = cls(
            N=aux_data['N'],
            n=aux_data['n'],
            M=aux_data['M'],
            P=aux_data['P'],
            sigma_0=aux_data['sigma_0'],
            sigma_c=aux_data['sigma_c']
        )
        instance.W = children[0]
        instance.vasicek_window = aux_data['vasicek_window']
        return instance

    def set_sigma(self): 
        self.sigma = lambda x: x / (1 + x) 

    def set_vasicek_window(self): 
        self.vasicek_window = jax.lax.stop_gradient(jnp.round(jnp.sqrt(self.P) + 0.5)).astype(int)
    
    @jit 
    def draw_cs(self, key):
        # Split the key into subkeys for indices and concentrations
        subkeys = jax.random.split(key, self.P + 1)
        indices_key = subkeys[0]
        concentration_keys = subkeys[1:]

        # Generate indices for all samples (P x n)
        indices = jax.vmap(
            lambda k: jax.random.choice(k, self.N, shape=(self.n,), replace=False)
        )(jax.random.split(indices_key, self.P))

        # Generate concentrations for all samples (P x n)
        concentrations = jax.vmap(
            lambda k: jax.random.lognormal(k, sigma=self.sigma_c, shape=(self.n,))
        )(concentration_keys)

        # Initialize the full samples matrix (P x N)
        c = jnp.zeros((self.P, self.N))

        # Scatter the concentrations into the appropriate indices
        c = c.at[jnp.arange(self.P)[:, None], indices].set(concentrations)

        # Return the result transposed
        return c.T


    def set_random_W(self, key): 
        self.W = 1 / jnp.sqrt(self.N) * jax.random.normal(key, shape=(self.M, self.N))

    def compute_activity(self, W, c, key): 
        pre_activations = W @ c
        r = self.sigma(pre_activations) + self.sigma_0 * jax.random.normal(key, shape=pre_activations.shape) 
        return r

    # @jit # This might take hours to compile.  https://jax.readthedocs.io/en/latest/control-flow.html#control-flow is helpful for the following
    def compute_entropy_of_r(self, W, c, key):
        r = self.compute_activity(W, c, key)
        entropy = self.compute_sum_of_marginal_entropies(r) - self.compute_information_of_r(r)
        return entropy

    @jit 
    def compute_sum_of_marginal_entropies(self, r):
        compute_entropy_vmap = vmap(self._vasicek_entropy, in_axes=0)
        # Apply the vectorized function
        marginal_entropies = compute_entropy_vmap(r)
        # Sum the marginal entropies
        return jnp.sum(marginal_entropies)

    @jit 
    def compute_information_of_r(self, r): 
        M, P = r.shape
        G = norm.ppf((rankdata(r.T, axis=0) / (P + 1)), loc=0, scale=1) # this is just ranking the data and making it normally distributed. 
        bias_correction = 0.5 * jnp.sum(digamma((P - jnp.arange(1, M + 1) + 1) / 2) - jnp.log(P / 2)) 
        cov_matrix = jnp.cov(G, rowvar=False)
        chol_decomp = cholesky(cov_matrix)
        log_det = jnp.sum(jnp.log(jnp.diag(chol_decomp)))
        I = -(log_det - bias_correction) # remember: entropy overall is sum of marginals minus information. information is sum of marginals - entropy. Kind of stupid. 
        return I
    
    def sum_covariances(self, W, c, key): 
        r = self.compute_activity(W, c, key)
        cov_r = jnp.cov(r) 
        off_diag_mask = ~jnp.eye(cov_r.shape[0], dtype=bool)
        # Extract the off-diagonal elements
        off_diag_elements = cov_r[off_diag_mask]
        # Compute the Frobenius norm of the off-diagonal elements
        frob_norm = jnp.sqrt(jnp.sum(off_diag_elements**2))
        return frob_norm 
    
    def log_det_sigma(self, W, c, key):
        r = self.compute_activity(W, c, key)
        cov_r = jnp.cov(r)  
        chol = cholesky(cov_r) 
        log_det = jnp.sum(jnp.log(jnp.diag(chol)))
        return log_det

    
    def _pad_along_last_axis(self, X):
        first_value = X[0]
        last_value = X[-1]
        # Use `lax.full_like` to create padded arrays
        Xl = lax.full_like(x=jnp.empty((self.vasicek_window,)), fill_value=first_value)
        Xr = lax.full_like(x=jnp.empty((self.vasicek_window,)), fill_value=last_value)
        return jnp.concatenate((Xl, X, Xr))

    def _vasicek_entropy(self, X):
        n = X.shape[-1]
        X = jnp.sort(X, axis=-1)
        X = self._pad_along_last_axis(X)
        start1 = 2 * self.vasicek_window
        length = self.P
        differences = lax.dynamic_slice(X, (start1,), (length, )) - lax.dynamic_slice(X, (0,), (length,))
        logs = jnp.log(n / (2 * self.vasicek_window) * differences)
        return jnp.mean(logs, axis=-1)

# Register the custom class as a PyTree with JAX
tree_util.register_pytree_node(
    OlfactorySensing,
    OlfactorySensing._tree_flatten,
    OlfactorySensing._tree_unflatten
)


In [54]:
N, n, M, sigma_c, P = 60, 2, 30, 2.0, 50000
os = OlfactorySensing(N=N, n=n, M=M, P=P, sigma_c=sigma_c)
key = jax.random.PRNGKey(1) 
os.cs = os.draw_cs(key=key) 
W_init = jnp.clip(1 / jnp.sqrt(os.N) * jax.random.gamma(key, a=1, shape=(M, N)), min=0, max=1-(1e-10)) 
os.W = W_init 

In [56]:
Ws, ents, losses = natural_gradient_dual_space(20, W_init, os.cs, key, lambda * args: - os.log_det_sigma(*args), 1, os, phi, psi)

Step 0, Loss: 58.095767974853516
Step 2, Loss: 50.326011657714844
Step 4, Loss: 47.06768798828125
Step 6, Loss: 45.30582046508789
Step 8, Loss: 44.08979797363281
Step 10, Loss: 43.124271392822266
Step 12, Loss: 42.208187103271484
Step 14, Loss: 41.62626647949219
Step 16, Loss: 41.119747161865234
Step 18, Loss: 40.66557693481445
