<a href="https://colab.research.google.com/github/jcandane/RCF/blob/main/gpjax_play.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
try:
    import gpjax
except:
    !pip install gpjax==0.8.2
    import gpjax as gpx

import plotly.graph_objects as go

Collecting gpjax==0.8.2
  Downloading gpjax-0.8.2-py3-none-any.whl (111 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.6/111.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.17.0,>=0.16.2 (from gpjax==0.8.2)
  Downloading beartype-0.16.4-py3-none-any.whl (819 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.1/819.1 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cola-ml<0.0.6,>=0.0.5 (from gpjax==0.8.2)
  Downloading cola_ml-0.0.5-py3-none-any.whl (68 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.3/68.3 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Collecting jaxopt<0.9.0,>=0.8.3 (from gpjax==0.8.2)
  Downloading jaxopt-0.8.3-py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.3/172.3 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxtyping<0.3.0,>=0.2.15 (from gpjax==0.8.2)
  Downloading jaxtyping-0.2.28-py3-none

  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)


In [2]:
from jax import config
config.update("jax_enable_x64", True)

import jax
import gpjax as gpx

class RCF():
    """ built: 3/19/2024
    this an object of a Random-Contionus-Function (RCF), with-respect-to a gpJAX kernel
    RCF : IN -> OUT
    we define a prior, and then sample to form a posterior.
    """

    def __init__(self, Domain, MO:int=1, N:int=17, seed:int=777,
                 IN_noise=None, OUT_noise=None,
                 kernel=gpx.kernels.RBF() ):
        """ initialize RCF object
        GIVEN >
             Domain : 2d-jax.Array (domain of input points, shape=(n-dimenions, 2))
                  N : int (number of points)
                 MO : int (Multiple-Output Dimension)
             **seed : int (opinonal, integer to define JAX PRNGKey random-key)
           **kernel : (opinonal, defaults to gpJAX's RBF kernel)
         **IN_noise : 1d-jax.Array (opinonal)
        **OUT_noise : 1d-jax.Array (opinonal)
        GET >
            None
        """
        self.dtype  = jax.numpy.float64
        self.IN     = Domain.astype(self.dtype) ### 2d-jax.Array
        self.N      = N      ### number of defining points
        self.MO     = MO     ### int (dimension of OUT)
        self.kernel = kernel
        self.seed   = seed

        self.key    = jax.random.PRNGKey(self.seed) ### define random sampling key
        ### define anisotropic i.i.d white-noise
        if IN_noise is None:
            self.IN_noise=jax.numpy.zeros(self.IN.shape[0], dtype=self.dtype)
        else:
            self.IN_noise = IN_noise
        if OUT_noise is None:
            self.OUT_noise=jax.numpy.zeros(self.MO, dtype=self.dtype)
        else:
            self.OUT_noise = OUT_noise

        ### find a series of random defining points, keep looping until we find a stable configuration of initial-points
        c_i        = jax.numpy.diff(self.IN, axis=1).reshape(-1)
        self.R_ix  = c_i[None,:]*jax.random.uniform(self.key, (N, self.IN.shape[0]), dtype=self.dtype)
        self.R_ix += self.IN[:,0][None,:]

        Σ_ij      = self.kernel.gram(self.R_ix).A
        self.L_ij = jax.numpy.linalg.cholesky(Σ_ij)
        if jax.numpy.sum( jax.numpy.isnan(self.L_ij).astype( jax.numpy.int32 ) )==0:
            None
        else: ### if cholesky-factorization fails... add random diagonal
            self.L_ij = jax.numpy.linalg.cholesky( Σ_ij + jax.numpy.diag( jax.random.uniform( self.key, (self.N, ) , dtype=self.dtype) ) ) ## not immutable
        ###

        Σ_i   = jax.numpy.diag(Σ_ij)
        D_iX  = jax.numpy.zeros(self.N, dtype=self.dtype)[:,None]*jax.numpy.ones(self.MO, dtype=self.dtype)[None,:]
        D_iX += (Σ_i[:,None]*jax.numpy.ones(self.MO, dtype=self.dtype)[None,:])
        D_iX *= jax.random.normal( self.key, (self.N,self.MO) , dtype=self.dtype) # Affine-transformation on jax.random.normal
        ## correlate D_iX using the Cholesky-factor, yielding random/correlated normal-samples
        self.S_iX = jax.scipy.linalg.cho_solve((self.L_ij, True), (self.L_ij @ D_iX))

    def __call__(self, D_ax):
        """ evaluate for arbitrary values/points in OUT given points in IN.
        GIVEN >
              self
              D_ax : 2d-jax.Array (D_ax ∈ IN)
        GET   >
              D_aX : 2d-jax.Array (D_aX ∈ OUT, note captial 'X')
        """
        D_ax += self.IN_noise*jax.random.normal(self.key, D_ax.shape, dtype=self.dtype)
        D_aX  = self.kernel.cross_covariance(D_ax, self.R_ix) @ self.S_iX
        D_aX += self.OUT_noise*jax.random.normal(self.key, D_aX.shape, dtype=self.dtype)
        return D_aX

In [4]:
dr_x   = jnp.array([3.5,3.5])
domain = jnp.array([[0.0,20.0],[0.0,20.0]])
f = RCF( domain, N=18, seed=86, kernel=gpx.kernels.Matern32() )
f = RCF( domain, N=15, seed=235) #, kernel=gpx.kernels.Matern32() )

R_ax = jnp.stack(jnp.meshgrid(*[ jnp.linspace(domain[i,0], domain[i,1], 30) for i in range(len(domain)) ]), axis=-1)
R_ax = R_ax.reshape((jnp.prod( jnp.asarray(R_ax.shape[:-1]) ), R_ax.shape[-1]))
D_ay = f( R_ax )

#### the plot
fig = go.Figure(data=[go.Scatter3d(x=R_ax[:,0], y=R_ax[:,1], z=D_ay[:,0], mode='markers'),
                      go.Scatter3d(x=(f.R_ix)[:,0], y=(f.R_ix)[:,1], z=(f(f.R_ix))[:,0], mode='markers')])
fig.show()

In [5]:
Domain = jnp.array([[0,10.],[-3,4.]], dtype=jnp.float64) #torch.tensor([[0,10.],[-3,4.],[-8,-2]]) ### numpy.2darray


f = RCF(Domain, N=18, seed=48) ## problems seed=1287

R_ax = jnp.stack(jnp.meshgrid(*[ jnp.linspace(Domain[i,0], Domain[i,1], 30) for i in range(len(Domain)) ]), axis=-1)
R_ax = R_ax.reshape((jnp.prod( jnp.asarray(R_ax.shape[:-1]) ), R_ax.shape[-1]))
D_ay = f( R_ax )

#### the plot
fig = go.Figure(data=[go.Scatter3d(x=R_ax[:,0], y=R_ax[:,1], z=D_ay[:,0], mode='markers'),
                      go.Scatter3d(x=(f.R_ix)[:,0], y=(f.R_ix)[:,1], z=(f(f.R_ix))[:,0], mode='markers')])
fig.show()

In [7]:
domain = jnp.array([[0.0,1.e-3],[0., 1.e-3]])
f = RCF( domain, N=18, seed=86, kernel=gpx.kernels.Matern32() )
f = RCF( domain, N=70, seed=235) #, kernel=gpx.kernels.Matern32() )

R_ax = jnp.stack(jnp.meshgrid(*[ jnp.linspace(domain[i,0], domain[i,1], 20) for i in range(len(domain)) ]), axis=-1)
R_ax = R_ax.reshape((jnp.prod( jnp.asarray(R_ax.shape[:-1]) ), R_ax.shape[-1]))
D_ay = f( R_ax )

#### the plot
fig = go.Figure(data=[go.Scatter3d(x=R_ax[:,0], y=R_ax[:,1], z=D_ay[:,0], mode='markers'),
                      go.Scatter3d(x=(f.R_ix)[:,0], y=(f.R_ix)[:,1], z=(f(f.R_ix))[:,0], mode='markers')])
fig.show()

print(jnp.sum(D_ay[:,0]))

-917.232656289191


In [None]:
type(domain)

jaxlib.xla_extension.ArrayImpl

## cpu information

In [None]:
!lscpu

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  2
  On-line CPU(s) list:   0,1
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Xeon(R) CPU @ 2.20GHz
    CPU family:          6
    Model:               79
    Thread(s) per core:  2
    Core(s) per socket:  1
    Socket(s):           1
    Stepping:            0
    BogoMIPS:            4400.39
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clf
                         lush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_
                         good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fm
                         a cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hyp
                         ervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd i