In [1]:
import jax
import jax_metrics as jm
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax import random
import os
import numpy as np
import matplotlib.pyplot as plt
# Disable XLA's up-front GPU memory preallocation ('false') and use the
# 'platform' allocator, which allocates and frees device memory on demand.
# (This does not switch off any cache.)
# NOTE(review): JAX reads these variables when its backend is first initialised,
# so they must be set before the first JAX computation (safest: before `import jax`).
os.environ.update({
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
    'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
})

In [2]:
# Sanity check: natural log of 4 (= 2*ln 2 ≈ 1.386).
jnp.log(4)

Array(1.3862944, dtype=float32, weak_type=True)

In [3]:
# Per-leaf in_axes: map over dct['b'] (axis 0) and broadcast the scalar dct['a'].
dct = {'a': 0., 'b': jnp.arange(5.)}
x = 1.  # NOTE(review): not used by this cell

def foo(dct):
    """Return dct['a'] + dct['b'] for one mapped element of dct['b']."""
    total = dct['a'] + dct['b']
    return total

out = vmap(foo, in_axes=({'b': 0, 'a': None},))(dct)
print(out)

[0. 1. 2. 3. 4.]


In [4]:
@jit
def logit_class(W: jax.Array, x: jax.Array, cl: int) -> jax.Array:
    """Log-probability of non-reference class ``cl`` under a multinomial logit.

    The last row of ``W`` is the reference class (its score is pinned to 0, see
    ``logit_class_k``), so for ``cl < K-1`` with K = number of rows of ``W``::

        log p(cl | x) = W[cl] . x - log(1 + sum_{j < K-1} exp(W[j] . x))

    Args:
        W: weight matrix, one row per class; its column count must match ``x``.
        x: a single feature vector.
        cl: class index in ``[0, K-1)``; may be a traced value under ``jit``.

    Returns:
        Scalar log-probability.
    """
    # Linear scores of the non-reference classes.
    scores = jnp.dot(W[:-1, :], x)
    # log(1 + sum(exp(scores))) == logaddexp(0, logsumexp(scores)); this form
    # never overflows and never takes the log of a negative number.
    log_norm = jnp.logaddexp(0.0, jax.nn.logsumexp(scores))
    return jnp.dot(W[cl, :], x) - log_norm

@jit
def logit_class_k(W: jax.Array, x: jax.Array, cl: int) -> jax.Array:
    """Log-probability of the reference (last) class under a multinomial logit.

    With K = number of rows of ``W``::

        log p(K-1 | x) = -log(1 + sum_{j < K-1} exp(W[j] . x))

    Args:
        W: weight matrix, one row per class; its column count must match ``x``.
        x: a single feature vector.
        cl: unused; accepted so the signature matches ``logit_class``.

    Returns:
        Scalar log-probability.
    """
    scores = jnp.dot(W[:-1, :], x)
    # Stable log(1 + sum(exp(scores))): large scores do not overflow to inf.
    return -jnp.logaddexp(0.0, jax.nn.logsumexp(scores))


def logit_model(y):
    """Log-probability of one sample's observed class under the multinomial logit.

    Args:
        y: one per-sample record, as built by ``model``: a dict with keys
            'W' (weight matrix, one row per class, last row = reference class),
            'X' (this sample's feature vector),
            'y' (this sample's 0-based class index, scalar or shape (1,)),
            'number_of_classes' (K, the number of rows of W).

    Returns:
        Scalar log p(class | features).
    """
    W, x = y['W'], y['X']
    # Accept a scalar label or a length-1 row of a column vector such as Y.
    cls = jnp.reshape(y['y'], ())
    is_reference = cls == y['number_of_classes'] - 1
    # cls may be traced, so a Python `if` cannot branch on it; compute both
    # candidates and select the one matching this sample's class.
    return jnp.where(is_reference,
                     logit_class_k(W, x, cls),
                     logit_class(W, x, cls))

def model(W: jax.Array, X: jax.Array, y: np.ndarray) -> jax.Array:
    """Per-sample log-likelihoods of a multinomial logit model.

    Args:
        W: weight matrix with one row per class; the last row is the reference
            class. Its column count must match the rows of ``X``.
        X: feature matrix, one row per sample.
        y: integer class labels, one per sample, shape (N,) or (N, 1). Labels
            index rows of ``W`` (0-based) and are not range-checked.

    Returns:
        Array of shape (N,) holding log p(y_i | X_i) for each sample.
    """
    labels = jnp.asarray(y).reshape(-1)
    # The reference class is the last row of W, so the class count must come
    # from W rather than from whichever labels happen to occur in `y`.
    number_of_classes = W.shape[0]
    record = {'W': W, 'X': X, 'y': labels, 'number_of_classes': number_of_classes}
    # Map over samples only; W and the class count are shared by every sample.
    in_axes = ({'W': None, 'X': 0, 'y': 0, 'number_of_classes': None},)
    return vmap(logit_model, in_axes=in_axes)(record)

In [5]:
key = random.PRNGKey(0)
keys = random.split(key, 2)
# W and the features must come from *different* subkeys: reusing keys[0] for
# both made X[:, :3] an exact copy of W.
W = random.normal(keys[0], (3, 3))
# Three features plus a constant bias column -> X has shape (3, 4).
# NOTE(review): rows of X have 4 entries but W has 3 columns, so W[c] . X[i]
# in the logit functions raises a shape error — confirm the intended W shape
# (probably (3, 4)).
X = jnp.hstack([random.normal(keys[1], (3, 3)), jnp.ones((3, 1))])
# NOTE(review): labels 1..3 look 1-based, but the logit functions use them as
# row indices of W (valid rows 0..2) — confirm.
Y = jnp.array([[1], [2], [3]])

In [10]:
# Show the (3, 3) weight matrix drawn from keys[0].
W

Array([[-2.6105583 ,  0.03385283,  1.0863333 ],
       [-1.480299  ,  1.5403248 ,  1.062516  ],
       [ 0.54174834,  0.0170228 ,  0.2722685 ]], dtype=float32)

In [11]:
# Row index of the largest entry in each column of W.
t = W.argmax(axis=0)

In [15]:
# One row index per column of the (3, 3) W, hence shape (3,).
t.shape

(3,)

In [14]:
# Hand-check of (W @ t)[0] using the printed W: t = argmax(W, axis=0) = [2, 1, 0],
# so row 0 gives 2*W[0,0] + 1*W[0,1] + 0*W[0,2] (values copied from the output above).
2*-2.6105583+0.03385283

-5.18726377

In [12]:
# Matrix-vector product of W with the integer index vector t (result is float32).
jnp.matmul(W, t)

Array([-5.1872635, -1.4202732,  1.1005195], dtype=float32)

In [None]:
# Class label of the second sample (row 1 of the column vector Y), as a Python int.
int(Y[1, 0])

In [None]:
# FIXME: `logits` is not defined in any visible cell, so this raises NameError on a
# fresh kernel. It presumably is a per-class list of log-probability functions
# (logit_class / logit_class_k) — TODO confirm and define it before use.
# NOTE(review): the whole X matrix is passed here, while the logit functions appear
# to expect a single feature row.
logits[Y.tolist()[1][0]](W, X, Y.tolist()[1][0])

In [None]:
# FIXME: `logits` is not defined in any visible cell (NameError on a fresh kernel) —
# TODO define it before use or delete this scratch cell.
logits[1](W, X, 1)

In [None]:
# FIXME: `logits` is not defined in any visible cell (NameError on a fresh kernel) —
# TODO define it before use or delete this scratch cell.
logits[2](W, X, 0)

In [44]:
# A single (1, 3) row vector. The shape is passed positionally: the `newshape=`
# keyword of jnp.reshape is deprecated (and removed in newer JAX releases).
t = jnp.reshape(jnp.array([0.1, 1, 1]), (1, 3))

In [45]:
# t is now a single row: (1, 3).
t.shape

(1, 3)

In [46]:
# Stack five copies of the row vector t -> shape (5, 3).
s = t.repeat(5, axis=0)

In [47]:
# Summing over rows collapses axis 0, leaving one value per column.
s.sum(axis=0).shape

(3,)

In [48]:
# Elementwise reciprocal of t (0.1 -> 10).
1/t

Array([[10.,  1.,  1.]], dtype=float32)

In [49]:
# Five stacked copies of t.
s

Array([[0.1, 1. , 1. ],
       [0.1, 1. , 1. ],
       [0.1, 1. , 1. ],
       [0.1, 1. , 1. ],
       [0.1, 1. , 1. ]], dtype=float32)

In [50]:
# Elementwise square; returns a new array (JAX arrays are immutable).
s*s

Array([[0.01, 1.  , 1.  ],
       [0.01, 1.  , 1.  ],
       [0.01, 1.  , 1.  ],
       [0.01, 1.  , 1.  ],
       [0.01, 1.  , 1.  ]], dtype=float32)

In [51]:
# s is unchanged by the previous cell: s*s produced a new array.
s

Array([[0.1, 1. , 1. ],
       [0.1, 1. , 1. ],
       [0.1, 1. , 1. ],
       [0.1, 1. , 1. ],
       [0.1, 1. , 1. ]], dtype=float32)

In [52]:
# Append the reciprocal row 1/t under five copies of t -> shape (6, 3).
# Rebuilt from `t` rather than appending to the current `s`, so re-running this
# cell does not keep stacking extra rows onto `s`.
# NOTE(review): assumes `t` is still the (1, 3) row vector defined a few cells up.
s = jnp.vstack([jnp.repeat(t, repeats=5, axis=0), 1/t])

In [53]:
# s now has the extra 1/t row at the bottom: shape (6, 3).
s

Array([[ 0.1,  1. ,  1. ],
       [ 0.1,  1. ,  1. ],
       [ 0.1,  1. ,  1. ],
       [ 0.1,  1. ,  1. ],
       [ 0.1,  1. ,  1. ],
       [10. ,  1. ,  1. ]], dtype=float32)

In [55]:
# Elementwise 1 / (1 + exp(s)), i.e. the logistic function evaluated at -s.
1.0/(1.0 + jnp.exp(s))

Array([[4.7502080e-01, 2.6894143e-01, 2.6894143e-01],
       [4.7502080e-01, 2.6894143e-01, 2.6894143e-01],
       [4.7502080e-01, 2.6894143e-01, 2.6894143e-01],
       [4.7502080e-01, 2.6894143e-01, 2.6894143e-01],
       [4.7502080e-01, 2.6894143e-01, 2.6894143e-01],
       [4.5397868e-05, 2.6894143e-01, 2.6894143e-01]], dtype=float32)

In [58]:
# One-hot encode 5 labels over 3 classes, then transpose to (classes, samples).
t = jax.nn.one_hot(jnp.array([0, 1, 2, 1, 1]), 3).T

In [59]:
# Rows are classes 0..2, columns are the 5 samples.
t

Array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 1., 1.],
       [0., 0., 1., 0., 0.]], dtype=float32)

In [60]:
# A (1, 10) row of ones.
jnp.ones((1, 10))

Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

In [16]:
# Scratch value used to try out in-place division.
alpha = 2

In [17]:
# Halve alpha (true division, so the int 2 becomes the float 1.0).
# NOTE: not idempotent — every re-run of this cell halves alpha again.
alpha = alpha / 2

In [18]:
# 1.0 after a single run of the halving cell.
alpha

1.0

In [19]:
# Fresh (2, 10) standard-normal sample.
# NOTE(review): this rebinds X, shadowing the (3, 4) feature matrix used by the
# logit-model cells; re-running those after this cell will see this X instead.
key = random.PRNGKey(0)
keys = random.split(key, num=1)
X = random.normal(keys[0], shape=(2, 10))

In [20]:
# Show the (2, 10) standard-normal sample.
X

Array([[-0.74996024,  0.5132945 , -2.1002128 , -1.3719002 ,  1.6364152 ,
        -2.6171365 , -0.251998  ,  0.1812941 ,  0.5050053 , -0.02773665],
       [-0.41538277, -1.6634675 ,  1.721781  ,  0.8013188 ,  0.44873542,
         0.631671  ,  0.49231628,  0.03253433,  2.4587312 , -0.19775471]],      dtype=float32)

In [21]:
# The (2, 1) column of 5s broadcasts across all 10 columns, so this equals X + 5.
X + 5*jnp.ones((2, 1))

Array([[4.2500396, 5.5132947, 2.8997872, 3.6281   , 6.6364155, 2.3828635,
        4.748002 , 5.181294 , 5.5050054, 4.9722633],
       [4.584617 , 3.3365326, 6.721781 , 5.8013186, 5.448735 , 5.631671 ,
        5.4923162, 5.032534 , 7.458731 , 4.802245 ]], dtype=float32)

In [22]:
# X is unchanged: the broadcast addition produced a new array.
X

Array([[-0.74996024,  0.5132945 , -2.1002128 , -1.3719002 ,  1.6364152 ,
        -2.6171365 , -0.251998  ,  0.1812941 ,  0.5050053 , -0.02773665],
       [-0.41538277, -1.6634675 ,  1.721781  ,  0.8013188 ,  0.44873542,
         0.631671  ,  0.49231628,  0.03253433,  2.4587312 , -0.19775471]],      dtype=float32)

In [23]:
# The (2, 1) column that was broadcast against X.
5*jnp.ones((2, 1))

Array([[5.],
       [5.]], dtype=float32)