In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [2]:
import jax
import jax.numpy as jnp
import jaxopt
import sympy
# from sympy.parsing.latex import parse_latex
# from sympy2jax import sympy2jax
import matplotlib.pyplot as plt
import scipy
import scipy.constants as const
import numpy as np
import math

In [3]:
from jax.config import config
config.update("jax_enable_x64", True)

In [4]:
%matplotlib notebook

In [5]:
import scipy.special as sp

In [6]:
e = np.e
ln = np.log
ln_j = jnp.log
gamma_0 = sp.exp1
gamma_0_j = jax.scipy.special.exp1

In [7]:
# a = 1
R = jnp.float64(10)

This goes unstable around an alpha of 737 (for mirror ratios < 1000 the total value won't go above 222, it looks like). And that's just for 64 bit. 32 bit terminates around alpha of 83 and the derivative around an alpha of 39.

JAX 64 bit goes up to alpha of ~695, and ~342 for gradients

In [8]:
rhs_jax = lambda a, R: ((e ** (-a) / gamma_0_j(a) - a))*2/3 + 6 * (((e ** (-a) / gamma_0_j(a) - a)) * a ** 2 * ln_j(R) ** 2/(22.4) ** 2 * 3/2) ** (1/3)

In [9]:
rhs = lambda a, R: ((e ** (-a) - a*gamma_0(a))/gamma_0(a))*2/3 + 6 * (((e ** (-a) - a*gamma_0(a))/gamma_0(a)) * a ** 2 * ln(R) ** 2/(22.4) ** 2 * 3/2) ** (1/3)

In [10]:
rhs_jax_grad = jax.grad(rhs_jax)

In [11]:
a = np.exp(np.linspace(np.log(0.01), np.log(100), 100))

plt.figure()
plt.plot(a, rhs(a, 2))
plt.plot(a, rhs(a, 4))
plt.plot(a, rhs(a, 10))
plt.plot(a, rhs(a, 16))
plt.hlines(1, -1, 101)
plt.xscale('log')
plt.yscale('log')

<IPython.core.display.Javascript object>

# jaxopt

In [12]:
from jaxopt import Bisection

## Calculate alpha

In [13]:
def fun_to_solve(a, R_in, power_aux_in):
    return rhs_jax(a, R_in) - (1 + power_aux_in)

In [150]:
def find_alpha(R_in, power_aux_in):
    bisec = Bisection(optimality_fun=fun_to_solve, lower=0.001, upper=200.0, tol=1e-3, check_bracket=True, unroll=True)
    alpha = bisec.run(R_in=R_in, power_aux_in=power_aux_in).params
    return alpha

In [69]:
num_calcs = 100
power_aux = jnp.zeros(num_calcs)
R = jnp.linspace(2, 150, 100)

In [72]:
alpha_arr = []
for i in range(num_calcs):
    bisec = Bisection(optimality_fun=fun_to_solve, lower=0.1, upper=200.0, tol=1e-3, check_bracket=True)
    alpha = bisec.run(R_in=R[i], power_aux_in=power_aux[i]).params
    alpha_arr.append(alpha)

In [73]:
jnp.array(alpha_arr)

DeviceArray([0.91441193, 0.58803711, 0.48737946, 0.43552551, 0.40197296,
             0.37909622, 0.36079483, 0.34706879, 0.33563042, 0.32571716,
             0.31809158, 0.310466  , 0.30436554, 0.29902763, 0.29368973,
             0.28987694, 0.28606415, 0.2814888 , 0.27843857, 0.27538834,
             0.2723381 , 0.26928787, 0.2670002 , 0.26471252, 0.26242485,
             0.26013718, 0.25861206, 0.25632439, 0.25479927, 0.2525116 ,
             0.25098648, 0.24946136, 0.24793625, 0.24641113, 0.24564857,
             0.24412346, 0.24259834, 0.24183578, 0.24031067, 0.23954811,
             0.23802299, 0.23726044, 0.23573532, 0.23497276, 0.23421021,
             0.23344765, 0.23268509, 0.23115997, 0.23039742, 0.22963486,
             0.2288723 , 0.22810974, 0.22734718, 0.22658463, 0.22582207,
             0.22505951, 0.22429695, 0.22353439, 0.22353439, 0.22277184,
             0.22200928, 0.22124672, 0.22048416, 0.22048416, 0.2197216 ,
             0.21895905, 0.21819649, 0.21819649, 0.

In [None]:
y = grad * a_tmp + b

0 = grad * alpha + b
b = 

y = alpha_grad_tmp * a_tmp + power_aux_tmp + 1 - alpha_grad_tmp * alpha_tmp

In [118]:
alpha_grad_tmp

DeviceArray(0.6647264, dtype=float64, weak_type=True)

In [119]:
alpha_tmp

DeviceArray(0.37178672, dtype=float64)

In [152]:
a_range = jnp.linspace(0.01, 5, 100)
R_tmp = 10.0
power_aux_tmp = 0.0
alpha_tmp = find_alpha(R_tmp, power_aux_tmp)
dalphadP = jax.grad(find_alpha, 1)(R_tmp, power_aux_tmp)
dalphadR = jax.grad(find_alpha, 0)(R_tmp, power_aux_tmp)

fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.plot(a_range, fun_to_solve(a_range, R_tmp, power_aux_tmp), label="P_aux(Alpha)")
ax.vlines(alpha_tmp, -1, 5, color='green')
ax.hlines(0, 0, 5)
ax.plot(a_range, 1/dalphadP * (a_range - alpha_tmp), color='orange', label="1/(dAlpha / dP)")
ax.set_ylabel("Aux power in")
ax.set_xlabel("Alpha value")
ax.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7fc5126e7e50>

## From alpha calculate Ti

In [144]:
def H(a):
    return a * gamma_0_j(a) / e ** -a

In [145]:
def Ti_given_alpha(a):
    return ((e ** (-a) / gamma_0_j(a) - a))*2/3

In [29]:
Ti_given_alpha(alpha)

DeviceArray(0.36333374, dtype=float64)

In [147]:
a_tmp = np.linspace(0, 1, 100)

plt.figure()
plt.plot(a_tmp, Ti_given_alpha(a))
plt.plot(a_tmp, H(a))

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fc4b1e9c100>]

## From alpha calculate Te

In [148]:
def Te_given_alpha(a, R):
    return (((e ** (-a) / gamma_0_j(a) - a)) * a ** 2 * ln_j(R) ** 2/(22.4) ** 2 * 2/3) ** (1/3)

In [44]:
alpha_arr

DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float64)

In [50]:
Te_given_alpha(jnp.array(alpha_arr), R)

DeviceArray([0.09279669, 0.10220298, 0.10611106, 0.10891456, 0.11168403,
             0.11333567], dtype=float64)

In [32]:
(3/2 * alpha / 22.4)**(2/3)  # This is the factor in front ln(R)^(2/3)

DeviceArray(0.08528021, dtype=float64)

For power_aux = 0 and R = 10, the $T_i \sim 0.363$, and $T_e \sim 0.105 \sim 0.085 (\ln{R})^{2/3}$

Demonstration of the $T_e$ equation in Cary's spreadsheet vs solving for alpha for each mirror ratio case:

In [149]:
plt.figure()
plt.plot(R, Te_given_alpha(jnp.array(alpha_arr), R), label="Calculated Te")
plt.plot(R, 0.089 * (ln(R)/ln(10))**(2/5), label="0.089 log_10(R)^(.4)")
plt.legend()
plt.ylabel('Te/E_beam')
plt.xlabel('Mirror ratio')

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Mirror ratio')

## Given Ti calculate alpha

In [23]:
def Ti_to_solve(a):
    return ((e ** (-a) / gamma_0_j(a) - a)) - 0.6

In [24]:
a = jnp.linspace(0.01, 20, 100)

plt.figure()
plt.plot(a, Ti_to_solve(a))

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fc518d460b0>]

In [25]:
bisec_Ti = Bisection(optimality_fun=Ti_to_solve, lower=0.1, upper=200.0, tol=1e-2)
result_Ti = bisec_Ti.run().params
result_Ti

DeviceArray(0.58803711, dtype=float64)

DeviceArray(0.08474705, dtype=float64)