In [None]:
import jax
import jax.numpy as jnp

In [None]:
# Enable 64-bit floats so closed-form and autodiff results can be compared at
# double precision. `jax.config.update` is used directly: importing
# `config` from the `jax.config` module is deprecated and no longer works in
# recent JAX releases.
jax.config.update("jax_enable_x64", True)

### Problem 1

Compute the gradients with respect to $U \in R^{n×k}$ and $V \in R^{k×n},\; k<n$

$J(U, V) = \|UV - Y\|^2_F + \frac{\lambda}{2}(\|U\|^2_F + \|V\|^2_F)$

a) *Let's consider $\frac{dJ(U,V)}{dV}$*

$\frac{\partial}{\partial V}\|V\|^2_F=\frac{\partial}{\partial V}\sum_{i=1}^k\sum_{j=1}^n v_{ij}^2 = 2V \;→$


$\frac{d}{dV}\frac{1}{2}\lambda(\|U\|^2_F + \|V\|^2_F) = \lambda V$;


$\|UV - Y\|^2_F = \sum_{i=1}^n\sum_{j=1}^n\Big(\sum_{l=1}^k u_{il}v_{lj}-y_{ij}\Big)^2$

$\frac{\partial\|UV - Y\|^2_F}{\partial v_{ij}} = \sum_{p=1}^n 2u_{pi}\Big(\sum_{l=1}^k u_{pl}v_{lj}-y_{pj}\Big) = \sum_{p=1}^n 2u_{pi}(UV-Y)_{pj} = 2\big(U^T(UV-Y)\big)_{ij} → $

$\frac{d\|UV - Y\|^2_F}{dV} = 2U^T(UV-Y)$

Thus $\frac{dJ(U, V)}{dV} = 2U^T(UV-Y) + \lambda V$

b) *Let's consider $\frac{dJ(U,V)}{dU}$*

$\frac{\partial}{\partial U}\|U\|^2_F=\frac{\partial}{\partial U}\sum_{i=1}^n\sum_{j=1}^k u_{ij}^2 = 2U \;→$

$\frac{d}{dU}\frac{1}{2}\lambda(\|U\|^2_F + \|V\|^2_F) = \lambda U$;

$\|UV - Y\|^2_F = \sum_{i=1}^n\sum_{j=1}^n\Big(\sum_{l=1}^k u_{il}v_{lj}-y_{ij}\Big)^2$

$\frac{d\|UV - Y\|^2_F}{dU} = 2(UV-Y)V^T $ (the derivation is analogous to the one for $V$ above) →

$\frac{dJ(U, V)}{dU} = 2(UV-Y)V^T + \lambda U$




In [None]:
@jax.jit
def f_1(U, V, Y, lambd):
    """Objective J(U, V) = ||UV - Y||_F^2 + (lambd / 2) * (||U||_F^2 + ||V||_F^2).

    Squared Frobenius norms are computed as sums of squares, not as
    ``jnp.linalg.norm(...)**2``. The value is the same. However, the gradient
    of ``norm`` is x / ||x||, which evaluates to NaN at x = 0 (for example a
    zero-initialised U or an exact fit UV == Y), and autodiff would propagate
    that NaN into jax.grad.

    Args:
        U: left factor (n x k in the problem statement).
        V: right factor (k x n in the problem statement).
        Y: target matrix; must have the shape of U @ V.
        lambd: regularisation weight.

    Returns:
        Scalar objective value.
    """
    residual = U @ V - Y
    data_term = jnp.sum(jnp.square(residual))
    reg_term = jnp.sum(jnp.square(U)) + jnp.sum(jnp.square(V))
    return data_term + lambd * 0.5 * reg_term

In [None]:
def custom_grad_f1_dv(U, V, Y, lambd):
    """Closed-form gradient of J with respect to V: 2 U^T (UV - Y) + lambd * V."""
    residual = U @ V - Y
    fit_grad = 2. * (U.T @ residual)
    ridge_grad = lambd * V
    return fit_grad + ridge_grad


def custom_grad_f1_du(U, V, Y, lambd):
    """Closed-form gradient of J with respect to U: 2 (UV - Y) V^T + lambd * U."""
    scaled_residual = 2. * (U @ V - Y)
    fit_grad = scaled_residual @ V.T
    return fit_grad + lambd * U

In [None]:
n = 3000
k = 1000
# Draw each random quantity from its own subkey. Reusing PRNGKey(0) for all of
# them makes U and V (both n*k elements) identical once flattened, i.e.
# V == U.reshape(k, n), which weakens the gradient check.
key_u, key_v, key_y, key_lambd = jax.random.split(jax.random.PRNGKey(0), 4)
U = jax.random.normal(key_u, (n, k))
V = jax.random.normal(key_v, (k, n))
Y = jax.random.normal(key_y, (n, n))
# A ridge-penalty weight is meant to be non-negative.
lambd = jnp.abs(jax.random.normal(key_lambd))



In [None]:
# A single jax.grad call with argnums=(0, 1) returns both partial gradients.
# This avoids running the forward and backward pass of f_1 twice.
gradf1_du, gradf1_dv = jax.grad(f_1, argnums=(0, 1))(U, V, Y, lambd)

In [None]:
# Frobenius distance between each closed-form gradient and its autodiff counterpart.
for label, custom_grad, autodiff_grad in (
    ("dU", custom_grad_f1_du, gradf1_du),
    ("dV", custom_grad_f1_dv, gradf1_dv),
):
    print(f"Check correctness of found {label}: {jnp.linalg.norm(custom_grad(U, V, Y, lambd) - autodiff_grad)}")

Check correctness of found dU: 1.7008181002102575e-11
Check correctness of found dV: 1.7650432557857606e-11


In [None]:
print("Compare speed")
print("Analytical gradient")
%timeit custom_grad_f1_du(U, V, Y, lambd)
print("Grad function")
%timeit jax.grad(f_1, argnums = 0)(U, V, Y, lambd).block_until_ready()
jit_gradf = jax.jit(jax.grad(f_1, argnums = 0))
print("Jitted grad function")
%timeit jit_gradf(U, V, Y, lambd).block_until_ready()

Compare speed
Analytical gradient
2.36 s ± 300 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Grad function
2.18 s ± 17.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jitted grad function
2.14 s ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Problem 2

#### Part 1

Compute the Jacobian matrix of the following function (softmax)

$f: \mathbb{R}^n\to\mathbb{R}^n, \; f_j(w) = \frac {e^{w_j}} {\sum_{k=1}^{n} e^{w_k}}$

Let's consider two cases:

First, we compute the derivatives on the main diagonal of the Jacobian, $\frac{\partial f_x}{\partial w_x}$

Using the quotient rule, we get  
$\frac{\partial f_x}{\partial w_x}(w) = \frac{\partial}{\partial w_x}\frac{e^{w_x}}{\sum_{k=1}^{n} e^{w_k}} $ =
$\frac{e^{w_x}\cdot[\sum_{k=1}^ne^{w_k}] - e^{2w_x}}{(\sum_{k=1}^{n}e^{w_k})^2}$

$=\frac{e^{w_x}(e^{w_1}+\dots+e^{w_n}) - e^{2w_x}}{(\sum_{k=1}^ne^{w_k})^2}=
\frac{e^{w_x}\sum_{k=1, k\neq x}^ne^{w_k}}{(\sum_{k=1}^ne^{w_k})^2} = f_x(1-f_x)$



Secondly, let's find $\frac{df_x}{dw_y}$, where $x\neq y$

$\frac{\partial}{\partial w_y}\frac{e^{w_x}}{\sum_{k=1}^{n} e^{w_k}}=e^{w_x}\,\frac{\partial}{\partial w_y}\frac{1}{\sum_{k=1}^{n} e^{w_k}}$ (the numerator does not depend on $w_y$)

Let $u=\sum_{k=1}^ne^{w_k}$; then $\frac{d}{du}\frac{1}{u}=-\frac{1}{u^2}$ and $\frac{\partial u}{\partial w_y}=e^{w_y}$.

Using the chain rule, $\frac{\partial}{\partial w_y}\frac{1}{u}=\frac{d}{du}\Big(\frac{1}{u}\Big)\cdot\frac{\partial u}{\partial w_y}=-\frac{e^{w_y}}{u^2}$, so



$\frac{\partial f_x}{\partial w_y}=-\frac{\frac{\partial}{\partial w_y}\sum_{k=1}^ne^{w_k}}{(\sum_{k=1}^ne^{w_k})^2}\cdot e^{w_x} = -\frac{e^{w_x+w_y}}{(\sum_{k=1}^ne^{w_k})^2} = -f_xf_y$

In [None]:
@jax.jit
def f_2(w):
    summ = jnp.sum(jnp.exp(w))
    return jnp.divide(jnp.exp(w), summ)

In [None]:
def d_fxwx(w):
    """Diagonal of the softmax Jacobian.

    Entry x is df_x/dw_x = e^{w_x} (S - e^{w_x}) / S^2 with S = sum_k e^{w_k},
    which equals f_x (1 - f_x).
    The expression is invariant to shifting all of w by a constant c, because
    numerator and denominator both scale by e^{2c}. w is therefore shifted by
    its maximum first, so that exp cannot overflow for large inputs.
    """
    e_w = jnp.exp(w - jnp.max(w, initial=-jnp.inf))
    summ = jnp.sum(e_w)
    numerator = e_w * (summ - e_w)
    return jnp.divide(numerator, jnp.power(summ, 2))


def d_fxwy(w, x):
    '''Off-diagonal softmax Jacobian entries that involve index x.

    Entry y (y != x) is df_x/dw_y = -e^{w_x + w_y} / S^2 with S = sum_k e^{w_k},
    which equals -f_x f_y. The expression is symmetric in x and y, so the
    returned 1-D vector serves as both row x and column x of the off-diagonal
    part. Entry x is set to 0; the diagonal is computed separately by d_fxwx.
    w is shifted by its maximum m before exponentiating. Numerator and
    denominator both scale by e^{-2m}, so the value is unchanged, but every
    exponential is at most 1 and cannot overflow.
    '''
    e_w = jnp.exp(w - jnp.max(w, initial=-jnp.inf))
    summ = jnp.sum(e_w)
    dy = -1. * jnp.divide(e_w[x] * e_w, jnp.power(summ, 2))
    dy = dy.at[x].set(0.)
    return dy

In [None]:
def jacobi_f_2(w):
    """Analytical softmax Jacobian assembled from the derived helpers.

    Row x is d_fxwy(w, x) (off-diagonal entries, zero at position x). The
    diagonal d_fxwx(w) is added on top. The rows are produced with jax.vmap
    over x. The previous version grew the matrix with jnp.concatenate inside a
    Python loop, which copied the partial result on every iteration (O(n^3)
    copying overall).
    """
    diag = jnp.diag(d_fxwx(w))
    row_indices = jnp.arange(w.shape[0])
    off_diag = jax.vmap(lambda x: d_fxwy(w, x))(row_indices)
    return off_diag + diag

In [None]:
# Length of the softmax input and a reproducible random test vector.
n = 100
w = jax.random.normal(key=jax.random.PRNGKey(0), shape=(n,))

In [None]:
jac_fvec = jax.jacobian(f_2)(w)
print(f"Check correctness: {jnp.linalg.norm(jacobi_f_2(w) - jac_fvec)}")
print("Compare speed")
print("Analytical jacobi")
%timeit jacobi_f_2(w)
print("Ready Jacobi function")
%timeit jax.jacobian(f_2)(w).block_until_ready()
jit_jac = jax.jit(jax.jacobian(f_2))
print("Jitted jacobi function")
%timeit jit_jac(w).block_until_ready()

Check correctness: 1.706997365363207e-17
Compare speed
Analytical jacobi
260 ms ± 7.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Ready Jacobi function
3.92 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted jacobi function
21.6 µs ± 812 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


#### Part 2

Compute the gradient of the following functions with respect to matrix X

a) $f(X) = \sum_{i=1}^{n} \lambda_i(X)$



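$\sum_{i=1}^{n}\lambda_i(X) = \operatorname{tr}(X)$ (this follows from the Jordan normal form $X = PJP^{-1}$: $\operatorname{tr}(X) = \operatorname{tr}(J)$, and the diagonal of $J$ holds the eigenvalues counted with algebraic multiplicity)  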

And $\frac{d}{dX} \operatorname{tr}(X) = I$, so $\nabla f(X) = I$

*Code:*

In [None]:
@jax.jit
def f_eig_sum(X):
    return jnp.trace(X)

In [None]:
# Size of the random square test matrix used for the trace gradient.
n = 1000
x = jax.random.normal(key=jax.random.PRNGKey(0), shape=(n, n))

In [None]:
def custom_grad_eigsum(y):
    """Analytical gradient of tr(X): the identity matrix.

    Written as a def rather than an assigned lambda. The result takes its
    dtype from the input, so the comparison with jax.grad is like-for-like,
    and its shape is y.shape[0] x y.shape[1]. That matches the gradient of
    jnp.trace on a rectangular input and is the usual identity for square X.
    """
    return jnp.eye(y.shape[0], y.shape[1], dtype=y.dtype)
grad_f_sum = jax.grad(f_eig_sum)(x)
print(f"Check correctness, distance between answers: {jnp.linalg.norm(custom_grad_eigsum(x) - grad_f_sum)}")
print("Compare speed")
print("Analytical gradient")
%timeit custom_grad_eigsum(x)
print("Grad function")
%timeit jax.grad(f_eig_sum, argnums = 0)(x).block_until_ready()
jit_gradf = jax.jit(jax.grad(f_eig_sum, argnums = 0))
print("Jitted grad function")
%timeit jit_gradf(x).block_until_ready()

Check correctness, distance between answers: 0.0
Compare speed
Analytical gradient
3.9 ms ± 169 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Grad function
9.59 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted grad function
809 µs ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


b) $f(X) = \prod_{i=1}^{n} \lambda_i(X)$  

$\prod_{i=1}^{n} \lambda_i(X) = det(X)$

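Expanding along column $j$ (Laplace expansion), $\det(X) = \sum_{i=1}^{n}(-1)^{i+j}X_{ij}M_{ij}$, where the minor $M_{ij}$ does not depend on $X_{ij}$.  
Therefore $\frac{\partial \det(X)}{\partial X_{ij}} = (-1)^{i+j}M_{ij} = C_{ij}$, the $(i, j)$ cofactor of $X$,  
so $\nabla f(X) = C$, the matrix of cofactors of $X$.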

In [None]:
@jax.jit
def f_eig_mult(X):
    return jnp.linalg.det(X)

We can compute the cofactor matrix as $C = \det(X)\,(X^{-1})^T$ (assuming that $X$ is invertible)

In [None]:
def custom_grad_eigmult(X):
    """Analytical gradient of det(X): the cofactor matrix, computed as X^{-T} * det(X).

    Only valid for invertible X, because it goes through the explicit inverse.
    """
    inverse_transposed = jnp.transpose(jnp.linalg.inv(X))
    return inverse_transposed * jnp.linalg.det(X)

In [None]:
# Size of the random square test matrix used for the determinant gradient.
n = 15
x = jax.random.normal(key=jax.random.PRNGKey(0), shape=(n, n))

In [None]:
grad_f_mult = jax.grad(f_eig_mult)(x)
# Frobenius distance between the cofactor-based gradient and the autodiff one.
distance = jnp.linalg.norm(custom_grad_eigmult(x) - grad_f_mult)
print(f"Check correctness, distance between answers: {distance}")

Check correctness, distance between answers: 7.535048156619044e-09


In [None]:
grad_f_mult = jax.grad(f_eig_mult)(x)
print(f"Check correctness, distance between answers: {jnp.linalg.norm(custom_grad_eigmult(x) - grad_f_mult)}")
print("Compare speed")
print("Analytical gradient")
%timeit custom_grad_eigmult(x)
print("Grad function")
%timeit jax.grad(f_eig_mult, argnums = 0)(x).block_until_ready()
jit_gradf = jax.jit(jax.grad(f_eig_mult, argnums = 0))
print("Jitted grad function")
%timeit jit_gradf(x).block_until_ready()

Check correctness, distance between answers: 7.535048156619044e-09
Compare speed
Analytical gradient
424 µs ± 16.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Grad function
2.48 ms ± 92.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted grad function
21.1 µs ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
