In [4]:
import os
import sys

nb_dir = os.path.split(os.path.split(os.getcwd())[0])[0]
if nb_dir not in sys.path:
    sys.path.append(nb_dir)

In [32]:
from jax.lax import custom_root, custom_linear_solve
from jax.scipy.linalg import solve
import jax.numpy as np
import jax
from jax import jacrev as jacobian

In [21]:
from jaxsnn.dataset.yinyang import YinYangDataset
from jax import random


rng = random.PRNGKey(42)

trainset = YinYangDataset(rng, 100)

def load_data():
    return trainset.vals, trainset.classes

In [30]:
X_train, y_train = load_data()

def f(x):  # Objective function
    residual = np.dot(X_train, x) - y_train
    return np.sum(residual ** 2)

jax.grad(f, argnums=0)(np.ones(4))

DeviceArray([107.21371,  89.36005,  90.78631, 108.63995], dtype=float32)

In [84]:
X_train, y_train = load_data()
X_train = np.hstack((X_train, X_train, X_train))
num_features = 12

def f(x):  # Objective function
    residual = np.dot(X_train, x) - y_train
    return np.sum(residual ** 2)

# Since f is differentiable and unconstrained, the optimality
# condition F is simply the gradient of f in the first argument
F = jax.grad(f, argnums=0)

def solver(F, init_x):
    del init_x
    XX = np.dot(X_train.T, X_train)  # type: ignore
    Xy = np.dot(X_train.T, y_train)  # type: ignore
    # Finds the ridge reg solution by solving a linear system
    return np.linalg.solve(XX, Xy)


def tangent_solver(g, y):
    return np.linalg.solve(jacobian(g)(y), y)

init_x = np.zeros(num_features)
sol = custom_root(F, init_x, solver, tangent_solver)
print(f(sol))

nan


In [96]:
from jaxsnn.module.lif import LIFParameters
import jax.numpy as np
import jax


def ridge_solver(F, init_x):
    del init_x
    XX = np.dot(X_train.T, X_train)  # type: ignore
    Xy = np.dot(X_train.T, y_train)  # type: ignore
    return np.linalg.solve(XX, Xy)


def tangent_solver(g, y):
    return np.linalg.solve(jacobian(g)(y), y)

A = np.array([[-1,1], [0,-1]])
x0 = np.array([0.0, 1.0])

def v(t):
    return np.dot(jax.scipy.linalg.expm(A * t), x0)[0]

init_x = 0.1

f(1.0)
# sol = custom_root(f, init_x, ridge_solver, tangent_solver)





DeviceArray(58.97054, dtype=float32)