# Simple Bayesian

## Set up

In [1]:
%reset -s -f

In [1]:
import jax
import jax.numpy as np
from jax import random, vmap
from jax.config import config; config.update("jax_platform_name", "cpu")
from jax.scipy.special import logsumexp
import matplotlib
import matplotlib.pyplot as plt
import numpy as onp
import pandas as pd
import seaborn as sns

from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc

%matplotlib inline
plt.style.use('bmh')
plt.rcParams.update({'font.size': 16,
                     'xtick.labelsize': 14,
                     'ytick.labelsize': 14,
                     'axes.titlesize': 'large', 
                     'axes.labelsize': 'medium'})

hv.opt

ModuleNotFoundError: No module named 'matplotlib'

In [2]:
import holoviews as hv
# import 
from holoviews.operation.datashader import datashade, rasterize

hv.extension('bokeh') # important: doesn't render without it
hv.opts.defaults( hv.opts.Curve(width=800), hv.opts.Table(width= 800))

In [576]:
from bokeh.themes.theme import Theme

theme = Theme(
    json={
    'attrs' : {
        'Figure' : {
            'background_fill_color': '#2F2F2F',
            'border_fill_color': '#2F2F2F',
            'outline_line_color': '#444444',
        },
        'Grid': {
            'grid_line_dash': [6, 4],
            'grid_line_alpha': .3,
        },

        'Axis': {
            'major_label_text_color': 'white',
            'axis_label_text_color': 'white',
            'major_tick_line_color': 'white',
            'minor_tick_line_color': 'white',
            'axis_line_color': "white"
        },
        'Title': {
            'text_color': 'white'
        }
    }
})
hv.renderer('bokeh').theme = theme

In [577]:
from bokeh.themes import built_in_themes

## Test Data

### generating random data with Numpyro

> watch for use of random key in all sample statements

In [12]:
# Start from this source of randomness. We will split keys for subsequent operations.
rng = random.PRNGKey(0)
display(f"first rng {rng}")
rng_, rng = random.split(rng)
display(f"rng_  {rng_}, rng {rng}")


'first rng [0 0]'

'rng_  [4146024105  967050713], rng [2718843009 1272950319]'

In [29]:
# NB. sample takes a random key on every call
mu_r1 =  dist.Normal(0., 0.5).sample(rng, (5000,))
mu_r2 =  dist.Normal(5., 0.5).sample(rng, (5000,))
mu_r = np.concatenate([mu_r1, mu_r2])
res = dist.Normal(mu_r, 1.0).sample(rng, (1,))


### making histograms with Numpyro and holoviews

In [30]:
h = onp.histogram(mu_r,20, density=True)
hv.Histogram(h).opts(tools=['hover'])

In [31]:
h = onp.histogram(res,20, density=True)
hv.Histogram(h).opts(tools=['hover'])

In [33]:
n_samples = 1000
xs = np.linspace(0, 10, n_samples)
beta = 2.3 # random slope
noise = dist.Normal(0., 1.0).sample(rng, (n_samples,))
ys = xs * beta + noise 

In [73]:
# normalisation of data - otherwise HMC doesn't play nice.
xs = (xs - xs.mean())/ xs.std()
beta = 2.3 # random slope
noise = dist.Normal(0., 1.0).sample(rng, (n_samples,))
ys = xs * beta + noise 

In [421]:
hv.Scatter((xs,ys)).opts(bgcolor='black')

## Model definition

In [103]:
def model(xs=None, ys=None):
    b = sample('beta', dist.Normal(0., 10.))
    if (xs is not None):
        mu = xs * b
    # we keep this, even if ys is None...
    sample('obs', dist.Normal(mu, 1.0) , obs=ys)
    

## Prior Predictive 

this is generating model outputs given from a 

I want to able to generate many `obs` samples from many `beta`. How's about we try substituting a bunch of beta in there and see what happens eg. the same number of beta a xs?? or do we want to modify the model and get an outer product of sorts...

or use vmap to try multiple beta vals... and this worked. I substituted in new betas that were created with `dist.Normal(0,10).sample(rng, num_samples)` where `num_samples` is the number of `beta` samples.



In [406]:
rng

DeviceArray([856958206, 578161011], dtype=uint32)

In [407]:
random.split(rng)

DeviceArray([[1399775720, 2158703605],
             [1021269120, 2954750475]], dtype=uint32)

In [519]:
def predict(rng, post_samples, model, *args, **kwargs):
    """
    :param rng : an array of seeds for random number generator
    :param post_samples : a dictionary (or an array of dictionaries) containing sample values for latent
        random variables
    :param model : a function containing primitive stochastic functions with names that match the dictionary 
        post_samples
    :param *args : positional args that may be used by model as input 
    :param **kwargs : kwargs that may be used by model.
    
    output : an array of the values returned by the model.
    """
    ## model gets modified to replace named distributions with generated samples
    model = substitute(seed(model, rng), post_samples)
    model_trace = trace(model).get_trace(*args, **kwargs)
    return model_trace['obs']['value']

In [561]:
# vectorize predictions via vmap
predict_fn = vmap(lambda rng, samples: 
                  predict(rng, samples, model, xs=xs))
rng, rng_ = random.split(rng)


# prior_samples are the values that we will pass via 'beta' to our vmapped fn predict_fn
prior_samples = dist.Normal(0,10).sample(rng, (num_samples,))
assert(num_samples == samples_1['beta'].shape[0])
predictions_1 = predict_fn(random.split(rng_, num_samples), {'beta': prior_samples})

print(f'predictions shape (xs.shape = 1000): {predictions_1.shape}')
mean_pred = np.mean(predictions_1, axis=0)
hpdi_pred = hpdi(predictions_1, 0.9)

predictions shape (xs.shape = 1000): (2000, 1000)


In [562]:
hv.Scatter((np.arange(2000),prior_samples),vdims='beta').opts(width=700).hist()

the red dots are the actuall data points $(xs,ys)$

the yellow line is the mean of the ys

the light mauve(?) is the 90% compatibility interval

In [578]:
hv_plot_regression(xs, mean_pred, hpdi_pred).opts(
    title='Predictions with 90% CI',
    xlabel='xs', ylabel='ys') 
# ax = plot_regression(dset.MarriageScaled.values, mean_pred, hpdi_pred)
# ax.set(xlabel='Marriage rate', ylabel='Divorce rate', title='Predictions with 90% CI');

In [579]:
hv.Scatter((np.tile(xs[::5],(1000,1)).flatten(), predictions_1[:1000,::5].flatten())).opts(
size=1, alpha=0.1,width=700, title="model predictions prior to seeing evidence")

In [None]:
### Working `get_trace` example! 

Notes:

- the model must be `seed`ed first. 
    - this initial the random generators somewhow
- then we wrap it with `trace` which returns a _handler_, which has the method `get_trace`
- calling `get_trace` with some input data `xs=xs` (the first `xs` is a keyword in the model definition), we get back an `OrderedDict` which contains all keys for every `sample` call in our model.

In [522]:
trace(seed(model, rng)).get_trace(xs=xs[:30])

OrderedDict([('beta',
              {'type': 'sample',
               'name': 'beta',
               'fn': <numpyro.distributions.continuous.Normal at 0x1a41a9a208>,
               'args': (),
               'kwargs': {'random_state': DeviceArray([1399775720, 2158703605], dtype=uint32)},
               'value': DeviceArray(-2.7955894, dtype=float32),
               'is_observed': False}),
             ('obs',
              {'type': 'sample',
               'name': 'obs',
               'fn': <numpyro.distributions.continuous.Normal at 0x1a3ee85278>,
               'args': (),
               'kwargs': {'random_state': DeviceArray([4146559817, 1601669810], dtype=uint32)},
               'value': DeviceArray([5.2982235, 5.1126318, 3.120929 , 3.5799675, 5.44522953,
                            6.90079403, 4.46349525, 4.24121475, 4.23534346, 5.7496686,
                            4.24852371, 4.9650178, 4.27080584, 4.37368393, 3.4981432,
                            4.75101948, 4.15790224, 3.4

In [523]:
# here we run the model and get the trace
# we verify that if we pass in the ys=ys then these are exactyl the 'obs' values
trace2 = trace(seed(model, rng)).get_trace(xs=xs, ys=ys)
assert (trace2['obs']['value'] == ys).all()

In [524]:
obs_vals = trace(seed(model, rng)).get_trace(xs=xs)['obs']['value']

In [525]:
hv.Scatter((xs, obs_vals))

In [526]:
trace(seed(model, rng)).get_trace(xs=xs)['beta']['value']

DeviceArray(-2.7955894, dtype=float32)

In [527]:
trace(seed(model, random.split(rng)[0])).get_trace(xs=xs)['beta']['value']

DeviceArray(10.659007, dtype=float32)

In [528]:
trace(seed(model, random.split(rng)[1])).get_trace(xs=xs)['beta']['value']

DeviceArray(-2.2482624, dtype=float32)

## Posterior Distribution

In [539]:
# Start from this source of randomness. We will split keys for subsequent operations.
rng = random.PRNGKey(0)
rng_, rng = random.split(rng)

# Initialize the model.
init_params, potential_fn, constrain_fn = initialize_model(rng_, model, 
                                                           xs = xs,
                                                           ys = ys
                                                          )
num_warmup, num_samples = 1000, 2000

# Run NUTS.
samples_0 = mcmc(num_warmup, num_samples, init_params,
                 potential_fn=potential_fn, 
                 trajectory_length=10, 
                 constrain_fn=constrain_fn)

warmup: 100%|██████████| 1000/1000 [00:08<00:00, 116.52it/s, 3 steps of size 1.26e+00. acc. prob=0.79]
sample: 100%|██████████| 2000/2000 [00:01<00:00, 1025.88it/s, 1 steps of size 1.26e+00. acc. prob=0.85]




                           mean         sd       5.5%      94.5%      n_eff       Rhat
                beta       2.32       0.03       2.27       2.37     810.62       1.00


## Model evaluation 

We run our HMC magic and it spits out the right answer.

Note: initially I tried:

`sample('obs', mu , obs=ys)` but had to change it to 

`sample('obs', dist.Normal(mu, 1.0) , obs=ys)` 
    
to make it work with HMC.


In [540]:
# Start from this source of randomness. We will split keys for subsequent operations.
rng = random.PRNGKey(0)
rng_, rng = random.split(rng)

# Initialize the model.
init_params, potential_fn, constrain_fn = initialize_model(rng_, model, 
                                                           xs = xs,
                                                           ys = ys)
num_warmup, num_samples = 1000, 2000

# Run NUTS.
samples_1 = mcmc(num_warmup, num_samples, init_params,
                 potential_fn=potential_fn, 
                 trajectory_length=10, 
                 constrain_fn=constrain_fn)

warmup: 100%|██████████| 1000/1000 [00:07<00:00, 128.07it/s, 3 steps of size 1.26e+00. acc. prob=0.79]
sample: 100%|██████████| 2000/2000 [00:01<00:00, 1050.98it/s, 1 steps of size 1.26e+00. acc. prob=0.85]




                           mean         sd       5.5%      94.5%      n_eff       Rhat
                beta       2.32       0.03       2.27       2.37     810.62       1.00


## Explore Samples

In [541]:
samples_1.keys()

dict_keys(['beta'])

Note: by appending the `.hist()` function we get the same histogram as the following graph.

In [542]:
hv.Scatter((np.arange(2000),samples_1['beta']),vdims='beta').opts(
    width=700,tools=['hover']).hist()

In [543]:
h = onp.histogram(samples_1['beta'],20, density=True)
hv.Histogram(h).opts(tools=['hover'])

Actual $\beta$ was $2.3$. Interestingly, 2.3 is not within 1 standard deviation of $\hat{\beta}$. (I wonder if this would change if we normalized our data first...).

Normalized the inputs (which normalized the outputs?) and this changed the hpdi considerably. Even though `beta.mean()` did not change, the spread of the `beta` got much larger.

In [544]:
samples_1['beta'].mean()

DeviceArray(2.3180168, dtype=float32)

In [545]:
samples_1['beta'].std()

DeviceArray(0.03089946, dtype=float32)

In [546]:
hpdi(samples_1['beta'],prob=0.9)

array([2.2671132, 2.369199 ], dtype=float32)

## Predictive Posterior

observation: because I used the same seed for the `dist.Normal().sample(.)` the predictive results are _very_ close. Zoom in on the results to see.

I'm not sure if this correct way to calculate the predictive posterior. I am taking the mean of the 

In [547]:
y_p = ((samples_1['beta'] * np.expand_dims(xs,-1)).mean(axis=1) + 
dist.Normal().sample(rng, (1000,)))

In [548]:
hv.Scatter((xs,y_p)).opts(alpha=0.4) * hv.Scatter((xs,ys)).opts(
    tools=['hover'], alpha=0.4)

In [549]:
y_p2 = ((samples_1['beta'].mean() * xs) + 
dist.Normal().sample(rng, (1000,)))

In [550]:
hv.Scatter((xs,y_p2)).opts(alpha=0.4) * hv.Scatter((xs,ys)).opts(
    tools=['hover'], alpha=0.4)

In [551]:
def hv_plot_regression(x, y_mean, y_hpdi, post_mu=None, draw_lines = True):
    # Sort values for plotting by x axis
    idx = np.argsort(x)
    x_ord = x[idx]
    mean = y_mean[idx]
    hpdi = y_hpdi[:, idx]
#     y_ord = dset.y_ordScaled.values[idx]
    y_ord = ys[idx] # ..
    std_plot = (hv.Curve((x_ord, mean)).opts(tools=['hover'], height=400) * 
        hv.Points((x_ord, mean)).opts(tools=['hover'], size=4, color='orange') *
        hv.Scatter((x_ord, y_ord)).opts(color='red', size=4, tools=['hover']) * 
        hv.Area((x_ord, hpdi[1,:],hpdi[0,:]), vdims=['l','u']).opts(color='pink', alpha=0.3, tools=['hover']))
    final_plot = std_plot
    if (post_mu is not None) and draw_lines:
        lines = hv.Overlay([hv.Curve((x_ord, post_mu[i,idx])).opts(alpha=0.05, color='gray') for i in range(100)])
        dots = hv.Overlay([hv.Scatter((x_ord,post_mu[i,idx])) for i in range(50)])
        final_plot = std_plot * lines * dots
    return final_plot


In [552]:
# this is what I need to understand:
# we use substitute the model's beta=Normal(.) with post_samples (eg. samples of our beta)
# then we run trace(model).get_trace(.) to get a new set of observations ie. ys

def predict(rng, post_samples, model, *args, **kwargs):
    """
    :param rng : an array of seeds for random number generator
    :param post_samples : a dictionary (or an array of dictionaries) containing sample values for latent
        random variables
    :param model : a function containing primitive stochastic functions with names that match the dictionary 
        post_samples
    :param *args : positional args that may be used by model as input 
    :param **kwargs : kwargs that may be used by model.
    
    output : an array of the values returned by the model.
    """
    ## model gets modified to replace named distributions with generated samples
    model = substitute(seed(model, rng), post_samples)
 
    ## this code is very similar to lines 42.. in rethinking.py
    ## can I try this on the original model? 
    ## not yet... 
    model_trace = trace(model).get_trace(*args, **kwargs)
    return model_trace['obs']['value']

In [557]:
# vectorize predictions via vmap
predict_fn = vmap(lambda rng, samples: 
                  predict(rng, samples, model, xs=xs))
rng, rng_ = random.split(rng)

# samples_1 is a dict having a key 'beta' which has all of the generated values of beta
# random.split(rng_, num_samples) generates an array of random keys 

## jm: num_samples must equal the number of beta samples we're sending in 
assert(num_samples == samples_1['beta'].shape[0])
predictions_1 = predict_fn(random.split(rng_, num_samples), samples_1)

print(f'predictions shape (xs.shape = 1000): {predictions_1.shape}')
mean_pred = np.mean(predictions_1, axis=0)
hpdi_pred = hpdi(predictions_1, 0.9)

predictions shape (xs.shape = 1000): (2000, 1000)


In [558]:
# vectorize predictions via vmap
predict_fn = vmap(lambda rng, samples: 
                  predict(rng, samples, model, xs=xs))
rng, rng_ = random.split(rng)

# samples_1 is a dict having a key 'beta' which has all of the generated values of beta
# random.split(rng_, num_samples) generates an array of random keys 


## jm: num_samples must equal the number of beta samples we're sending in 
assert(num_samples == samples_1['beta'].shape[0])
nn = 20
n_samples = dict()

n_samples['beta'] = samples_1['beta'][:nn]
print(f"n_samples['beta'].shape = {n_samples['beta'].shape}")
rng_a = random.split(rng_, nn)
print(f'rng_ shape is {rng_a.shape}' )
predictions_10 = predict_fn(rng_a, n_samples)

mean_pred = np.mean(predictions_10, axis=0)
hpdi_pred = hpdi(predictions_10, 0.9)

print(f'predictions shape: {predictions_1.shape}')

n_samples['beta'].shape = (20,)
rng_ shape is (20, 2)
predictions shape: (2000, 1000)


In [559]:
hv_plot_regression(xs, mean_pred, hpdi_pred).opts(
    title='Predictions with 90% CI',
    xlabel='xs', ylabel='ys') 
# ax = plot_regression(dset.MarriageScaled.values, mean_pred, hpdi_pred)
# ax.set(xlabel='Marriage rate', ylabel='Divorce rate', title='Predictions with 90% CI');

In [560]:
hv.Scatter((np.tile(xs,(50,1)).flatten(), predictions_1[:50].flatten())).opts(
size=1, alpha=0.2,width=700)
# ax = plot_regression(dset.MarriageScaled.values, mean_pred, hpdi_pred)
# ax.set(xlabel='Marriage rate', ylabel='Divorce rate', title='Predictions with 90% CI');

### explorations

```
Init signature: substitute(fn=None, param_map=None)
Docstring:     
Given a callable `fn` and a dict `param_map` keyed by site names,
return a callable which substitutes all primitive calls in `fn` with
values from `param_map` whose key matches the site name. If the
site name is not present in `param_map`, there is no side effect.

:param fn: Python callable with NumPyro primitives.
:param dict param_map: dictionary of `numpy.ndarray` values keyed by
   site names.

**Example:**

 .. testsetup::

   from jax import random
   from numpyro.handlers import sample, seed, substitute, trace
   import numpyro.distributions as dist

.. doctest::

   >>> def model():
   ...     sample('a', dist.Normal(0., 1.))

   >>> model = seed(model, random.PRNGKey(0))
   >>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
   >>> assert exec_trace['a']['value'] == -1
File:           ~/anaconda3/envs/pytorch_exp/lib/python3.7/site-packages/numpyro/handlers.py
Type:           type
```

In [169]:
m1 = substitute(seed(model,random.split(rng_,num_samples)), samples_1)

### Working `substitute` example!

In [442]:
mj = substitute(seed(model, rng), {'beta': 3.0})
trace(mj).get_trace(xs=xs[:10])

OrderedDict([('beta',
              {'type': 'sample',
               'name': 'beta',
               'fn': <numpyro.distributions.continuous.Normal at 0x1a3bebe908>,
               'args': (),
               'kwargs': {'random_state': DeviceArray([856958206, 578161011], dtype=uint32)},
               'value': 3.0,
               'is_observed': False}),
             ('obs',
              {'type': 'sample',
               'name': 'obs',
               'fn': <numpyro.distributions.continuous.Normal at 0x1a3b6d5390>,
               'args': (),
               'kwargs': {'random_state': DeviceArray([2331400018, 1417683751], dtype=uint32)},
               'value': DeviceArray([-6.40509272, -5.93795681, -5.57958698, -4.96822357,
                            -5.31321812, -6.18614721, -5.09394693, -5.08060408,
                            -2.21431637, -7.19801807], dtype=float32),
               'is_observed': False})])

In [485]:
# IMPORTANT: must have rng.shape[0]=2, 1 <= 'beta'.shape[0] <= 2
rngs = random.split(rng, 10)[0]
print(rngs, rngs.shape)
seeded_model = seed(model, rngs)
mj = substitute(seeded_model, {'beta': np.array([3.0 ])})
trace(mj).get_trace(xs=xs[:9])

[1086397533 3062596817] (2,)


OrderedDict([('beta',
              {'type': 'sample',
               'name': 'beta',
               'fn': <numpyro.distributions.continuous.Normal at 0x1a3c9e6208>,
               'args': (),
               'kwargs': {'random_state': DeviceArray([1086397533, 3062596817], dtype=uint32)},
               'value': DeviceArray([3.], dtype=float32),
               'is_observed': False}),
             ('obs',
              {'type': 'sample',
               'name': 'obs',
               'fn': <numpyro.distributions.continuous.Normal at 0x1a3c9e6f60>,
               'args': (),
               'kwargs': {'random_state': DeviceArray([3906920350, 2408377984], dtype=uint32)},
               'value': DeviceArray([-5.49607277, -5.75218868, -5.82157898, -6.00096607,
                            -5.51485157, -6.14832115, -3.56323338, -5.7398634,
                            -4.36138201], dtype=float32),
               'is_observed': False})])

```
Init signature: trace(fn=None)
Docstring:     
Returns a handler that records the inputs and outputs at primitive calls
inside `fn`.

**Example**

.. testsetup::

   from jax import random
   import numpyro.distributions as dist
   from numpyro.handlers import sample, seed, trace
   import pprint as pp

.. doctest::

   >>> def model():
   ...     sample('a', dist.Normal(0., 1.))

   >>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
   >>> pp.pprint(exec_trace)  # doctest: +SKIP
   OrderedDict([('a',
                 {'args': (),
                  'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
                  'is_observed': False,
                  'kwargs': {'random_state': DeviceArray([0, 0], dtype=uint32)},
                  'name': 'a',
                  'type': 'sample',
                  'value': DeviceArray(-0.20584235, dtype=float32)})]
```


In [439]:
mt = trace(m1)

### `get_trace` doc string
```
Signature: mt.get_trace(*args, **kwargs)
Docstring:
Run the wrapped callable and return the recorded trace.

:param `*args`: arguments to the callable.
:param `**kwargs`: keyword arguments to the callable.
:return: `OrderedDict` containing the execution trace.
File:      ~/anaconda3/envs/pytorch_exp/lib/python3.7/site-packages/numpyro/handlers.py
Type:      method
```

In [None]:
m2 = seed(model,rng)

In [440]:
mgt = mt.get_trace(xs=xs)

ValueError: Incompatible shapes for broadcasting: ((2000,), (1000,))

In [171]:
mgt = mt.get_trace(xs=xs)

ValueError: too many values to unpack (expected 2)

In [151]:
m1.param_map

{'beta': DeviceArray([2.29566765, 2.30147052, 2.29055333, ..., 2.29949427,
              2.29946637, 2.29907155], dtype=float32)}

In [155]:
m1.fn, m1.fn.fn

(<numpyro.handlers.seed at 0x1a38eb8e10>,
 <function __main__.model(xs=None, ys=None)>)

-------------

In [131]:
mean_pred.shape

(1000,)

In [132]:
hpdi_pred.shape

(2, 1000)

In [None]:
se

## vmap, tile, broadcast

In [300]:
xv = np.array([1.,2., 3,])
yv = np.array([4., 5.])

In [314]:
vv = lambda x, y: np.vdot(x, y)  #  ([a], [a]) -> []
vvv = lambda x, y: x + y

In [315]:
# (0, None) 'means...' 0 axis means first arg is the row of shape b, and None means all of it (I think)
# but how does the "out" parameter do?
mv = vmap(vv, (0, None), 0)      #  ([a,b], [b]) -> [a]

mvv = vmap(vvv, (0, None), 0)      #  ([a,b], [b]) -> [a]

In [317]:
vvv(xv,xv)

DeviceArray([2., 4., 6.], dtype=float32)

In [316]:
mvv(xv,xv)

DeviceArray([[2., 3., 4.],
             [3., 4., 5.],
             [4., 5., 6.]], dtype=float32)

In [344]:
# xv = [1 2 3]
# iterate over rows x all -> 
print(vmap(vvv, (0, None), 1)(xv, xv **2 ))
xv16 = xv.tile((1,2))
xv23 = xv.tile((2,1)).cumsum(axis=0)

print(xv.tile((1,2)), xv.tile((1,2)).shape, xv**2)
print(xv.tile((2,1)), xv.tile((2,1)).shape, xv**2)
# iterate over columns x all -> 
print("\n(1,None),0)\n",vmap(vvv, (1, None), 0)(xv16, xv **2 ))
print("\n(1,None),1)\n",vmap(vvv, (1, None), 1)(xv16, xv **2 ))
print("\n(0,None),0)\n",vmap(vvv, (0, None), 0)(xv23, xv **2 ))
print("\n(0,None),0)\n",vmap(vvv, (0, None), 1)(xv23, xv **2 ))
print("\n(1,0),0)\n",vmap(vvv, (1, 0), 0)(xv23, xv **2 ))
# print("\n(0,0),0)\n",vmap(vvv, (0, None), 1)(xv23, xv **2 ))
 

[[ 2.  3.  4.]
 [ 5.  6.  7.]
 [10. 11. 12.]]
[[1. 2. 3. 1. 2. 3.]] (1, 6) [1. 4. 9.]
[[1. 2. 3.]
 [1. 2. 3.]] (2, 3) [1. 4. 9.]

(1,None),0)
 [[ 2.  5. 10.]
 [ 3.  6. 11.]
 [ 4.  7. 12.]
 [ 2.  5. 10.]
 [ 3.  6. 11.]
 [ 4.  7. 12.]]

(1,None),1)
 [[ 2.  3.  4.  2.  3.  4.]
 [ 5.  6.  7.  5.  6.  7.]
 [10. 11. 12. 10. 11. 12.]]

(0,None),0)
 [[ 2.  6. 12.]
 [ 3.  8. 15.]]

(0,None),0)
 [[ 2.  3.]
 [ 6.  8.]
 [12. 15.]]

(1,0),0)
 [[ 2.  3.]
 [ 6.  8.]
 [12. 15.]]


In [363]:
xv16 = xv.tile((1,2))
xv23 = xv.tile((2,1)).cumsum(axis=0)

def vmap_exp(a1, a2, in_shape, out_shape):
    print("\n a1 ", a1, a1.shape)
    print("\n a2 ", a2, a2.shape)
    print(f"\n {in_shape},{out_shape}\n",vmap(vvv, in_shape, out_shape)(a1,a2 ))
    
# vmap_exp(xv23, (xv**2).tile((2,1)), (1,1),1)
vmap_exp(xv23, (xv**2).tile((2,1)).T, (1,0),1)

 


 a1  [[1. 2. 3.]
 [2. 4. 6.]] (2, 3)

 a2  [[1. 1.]
 [4. 4.]
 [9. 9.]] (3, 2)

 (1, 0),1
 [[ 2.  6. 12.]
 [ 3.  8. 15.]]


In [303]:
# (None, 1) : take all of first arg, and axis 1 or COLUMN of 2nd arg. (this is working)
mm = vmap(mv, (None, 1), 1)      #  ([a,b], [b,c]) -> [a,c]
mm0 = vmap(mv, (None, 1), 0)      #  ([a,b], [b,c]) -> [a,c]
mmm = vmap(mv, (None, 1), 0)      #  ([a,b], [b,c]) -> [a,c]


In [304]:
dv = vmap(vv)

In [305]:
_a = xv.broadcast((2,1))
_a, _a.shape

(DeviceArray([[[1., 2., 3.]],
 
              [[1., 2., 3.]]], dtype=float32), (2, 1, 3))

In [306]:
_a = xv.broadcast((1,2))
_a, _a.shape

(DeviceArray([[[1., 2., 3.],
               [1., 2., 3.]]], dtype=float32), (1, 2, 3))

In [307]:
# tile: each position gives how many copies is the positions axis
# m.tile((a,b,c)) gives a cop
xv, xv.tile((1,2)), xv.tile((2,1)), xv.tile((2,1)).tile((3,1))

(DeviceArray([1., 2., 3.], dtype=float32),
 DeviceArray([[1., 2., 3., 1., 2., 3.]], dtype=float32),
 DeviceArray([[1., 2., 3.],
              [1., 2., 3.]], dtype=float32),
 DeviceArray([[1., 2., 3.],
              [1., 2., 3.],
              [1., 2., 3.],
              [1., 2., 3.],
              [1., 2., 3.],
              [1., 2., 3.]], dtype=float32))

In [308]:
dv(xv,xv)

DeviceArray([1., 4., 9.], dtype=float32)

In [309]:
dv(xv22, xv22)

DeviceArray([26., 26.], dtype=float32)

In [310]:
dv(xv22.T, xv22.T)

DeviceArray([ 8., 18.,  8., 18.], dtype=float32)

In [311]:
xv22 = xv.tile((2,2)); xv22, xv22.shape

(DeviceArray([[1., 2., 3., 1., 2., 3.],
              [1., 2., 3., 1., 2., 3.]], dtype=float32), (2, 6))

In [290]:
yv.tile((2,1)).shape, yv.tile((4,1)).cumsum(axis=0)

((2, 2), DeviceArray([[ 4.,  5.],
              [ 8., 10.],
              [12., 15.],
              [16., 20.]], dtype=float32))

In [312]:
rm = onp.random.rand(4,2)
rm, xv22

(array([[0.33052856, 0.46554089],
        [0.750993  , 0.46288669],
        [0.71516137, 0.89367461],
        [0.1472789 , 0.40977241]]), DeviceArray([[1., 2., 3., 1., 2., 3.],
              [1., 2., 3., 1., 2., 3.]], dtype=float32))

In [313]:
mm(xv22.cumsum(axis=0), rm) 

TypeError: Incompatible shapes for dot: got (2, 6) and (4, 2).

In [299]:
mm0(xv22.cumsum(axis=0), rm) 

DeviceArray([[2.71959376, 5.43918753],
             [4.0226717, 8.0453434]], dtype=float32)

In [283]:
mm0(xv22.tile((2,1)), xv22.T **2)

DeviceArray([[70., 70., 70., 70.],
             [70., 70., 70., 70.]], dtype=float32)

In [236]:
xv21 = xv.tile((2,1)); xv21

DeviceArray([[2., 3.],
             [2., 3.]], dtype=float32)

In [250]:
xv22.shape

(2, 4)

In [238]:
# (2,4) x (4,2) -> (2,2)
mm(xv22, xv22.T)

DeviceArray([[26., 26.],
             [26., 26.]], dtype=float32)

In [249]:
# (4,2) x (2,4) -> (4,4)
mm(xv22.T, xv22)

DeviceArray([[ 8., 12.,  8., 12.],
             [12., 18., 12., 18.],
             [ 8., 12.,  8., 12.],
             [12., 18., 12., 18.]], dtype=float32)

In [247]:
# (4,2) x (2,) -> (4,)
mv(xv22.T, xv)

DeviceArray([10., 15., 10., 15.], dtype=float32)

In [252]:
#  Incompatible shapes for dot: got (2, 4) and (2,).
mv(xv22, xv)

TypeError: Incompatible shapes for dot: got (2, 4) and (2,).

In [251]:
#  Incompatible shapes for dot: got (2, 4) and (2,).
mv(xv22, xv.tile(2))

DeviceArray([26., 26.], dtype=float32)

In [245]:
xv.shape

(2,)

In [246]:

xv

DeviceArray([2., 3.], dtype=float32)

In [None]:
onp.r