In [None]:
# ------------------------------------------------------------------------
#
# TITLE - jeans_perturbations.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstring:
'''Test the application of the Jeans equation to mock data drawn from
distribution functions (DFs) under the influence of various perturbations:
a phase-space offset, a sudden or smooth increase in the host mass, and
embedding in an accreted (and optionally dissolving) satellite potential.
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy
from astropy import units as apu

## Matplotlib
import matplotlib
from matplotlib import pyplot as plt

## Galpy
from galpy import orbit, potential, df

sys.path.insert(0,'../../src/')
from tng_dfs import util as putil
from tng_dfs import plot as pplot
from tng_dfs import kinematics as pkin

### Notebook setup
%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# Keywords: load the project configuration (data path, galpy ro/vo/zo
# scales and little h) used throughout the notebook
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','RO','VO','ZO','LITTLE_H']
data_dir,ro,vo,zo,h = putil.parse_config_dict(cdict,keywords)

# Figure directory (trailing slash required: filenames are appended with +)
fig_dir = './fig/jeans_perturbation_tests/'
os.makedirs(fig_dir,exist_ok=True)
# If False, each figure is closed after saving so it is not displayed inline
show_figs = False

### First prepare some DFs and draw samples

### Now calculate the Jeans quantity

The spherical Jeans equation, whose left-hand side we call the Jeans quantity $J$ (zero for a system in equilibrium):
$\frac{\mathrm{d} (\nu\,\overline{v^2_r})}{\mathrm{d} r} +\,\nu\,
\left(\frac{\mathrm{d} \Phi}{\mathrm{d} r}+
\frac{2\overline{v_r^2}-\overline{v_\theta^2}-\overline{v_\phi^2}}{r}\right)= 0$

Each term of this equation has units of

$J \equiv [\ell]^{-4}[v]^{2}$

Therefore, when not working in physical units, one way to normalize $J$ is to divide it by

$\mathrm{ro}^{-4}\mathrm{vo}^{2}$

In [None]:
def offset_orbits_cartesian(orbs,vec):
    '''offset_orbits_cartesian:
    
    Offset orbits by a constant vector in galactocentric cartesian coordinates
    
    Args:
        orbs (Orbits) - Orbits object containing orbits to offset. Its ro and
            vo are reused for the returned Orbits object.
        vec (array-like) - 6-vector of [x,y,z,vx,vy,vz] to offset orbits by.
            Each element may be an astropy Quantity (converted to kpc or
            km/s) or a plain number, interpreted as kpc for the positions
            and km/s for the velocities.
    
    Returns:
        orbs_offset (Orbits) - new Orbits object containing offset orbits

    Raises:
        ValueError - if vec does not have exactly 6 elements
    '''
    if len(vec) != 6:
        raise ValueError('vec must have 6 elements: [x,y,z,vx,vy,vz]')
    ro = orbs._ro
    vo = orbs._vo
    pos_unit = apu.kpc
    vel_unit = apu.km/apu.s

    # Attach units to the offsets. apu.Quantity converts Quantity inputs to
    # the target unit and tags plain numbers with it, so non-zero plain
    # floats (as promised by the docstring) no longer raise a
    # UnitConversionError when added to the coordinates below.
    dx,dy,dz = [apu.Quantity(v,pos_unit) for v in vec[:3]]
    dvx,dvy,dvz = [apu.Quantity(v,vel_unit) for v in vec[3:]]

    # Offset cartesian positions and velocities
    xs = orbs.x(use_physical=True).to(pos_unit) + dx
    ys = orbs.y(use_physical=True).to(pos_unit) + dy
    zs = orbs.z(use_physical=True).to(pos_unit) + dz
    vxs = orbs.vx(use_physical=True).to(vel_unit) + dvx
    vys = orbs.vy(use_physical=True).to(vel_unit) + dvy
    vzs = orbs.vz(use_physical=True).to(vel_unit) + dvz

    # Convert back to galpy's cylindrical phase-space coordinates
    Rs = (np.sqrt(xs**2+ys**2)).to(pos_unit)
    phis = (np.arctan2(ys,xs)).to(apu.rad)
    vRs = ((xs*vxs+ys*vys)/Rs).to(vel_unit)
    vTs = ((xs*vys-ys*vxs)/Rs).to(vel_unit)

    # galpy vxvv ordering [R,vR,vT,z,vz,phi] in natural units
    # (positions / ro, velocities / vo)
    vxvv = np.vstack((Rs.value/ro,
                      vRs.value/vo,
                      vTs.value/vo,
                      zs.value/ro,
                      vzs.value/vo,
                      phis.value)).T
    
    return orbit.Orbit(vxvv=vxvv,ro=ro,vo=vo)

def plot_6d_cartesian_phase_space(orbs,xrange=(-20,20),vrange=(-200,200)):
    '''plot_6d_cartesian_phase_space:

    Scatter plot the six 2D projections of cartesian phase space: x-y, x-z
    and y-z positions (top row) and vx-vy, vx-vz and vy-vz velocities
    (bottom row), with dashed lines marking zero on each axis.

    Args:
        orbs (Orbits) - Orbits object to plot
        xrange (sequence) - [min,max] axis limits in kpc for position panels
        vrange (sequence) - [min,max] axis limits in km/s for velocity panels

    Returns:
        fig (Figure) - the figure
        axes (np.ndarray) - flattened array of the 6 Axes
    '''
    pos_unit = apu.kpc
    vel_unit = apu.km/apu.s

    # Cartesian coordinates as plain arrays in kpc and km/s
    coords = {
        'x': orbs.x(use_physical=True).to(pos_unit).value,
        'y': orbs.y(use_physical=True).to(pos_unit).value,
        'z': orbs.z(use_physical=True).to(pos_unit).value,
        'vx': orbs.vx(use_physical=True).to(vel_unit).value,
        'vy': orbs.vy(use_physical=True).to(vel_unit).value,
        'vz': orbs.vz(use_physical=True).to(vel_unit).value,
        }
    labels = {
        'x': r'$x$ [kpc]',
        'y': r'$y$ [kpc]',
        'z': r'$z$ [kpc]',
        'vx': r'$v_x$ [km/s]',
        'vy': r'$v_y$ [km/s]',
        'vz': r'$v_z$ [km/s]',
        }
    # (horizontal key, vertical key, axis limits) in panel order
    panels = [('x','y',xrange), ('x','z',xrange), ('y','z',xrange),
              ('vx','vy',vrange), ('vx','vz',vrange), ('vy','vz',vrange)]

    s = 1
    alpha = 0.1
    fontsize = 10

    fig,axes = plt.subplots(2,3,figsize=(6,4))
    axes = axes.flatten()

    for ax,(hkey,vkey,lims) in zip(axes,panels):
        ax.scatter(coords[hkey],coords[vkey],s=s,alpha=alpha)
        ax.set_xlabel(labels[hkey],fontsize=fontsize)
        ax.set_ylabel(labels[vkey],fontsize=fontsize)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_aspect('equal')
        ax.axhline(0,color='k',ls='--')
        ax.axvline(0,color='k',ls='--')
        ax.tick_params(labelsize=fontsize)

    fig.tight_layout()
    return fig,axes

### Start with a Hernquist embedded in another Hernquist

In [None]:
# Sample tracers with a Hernquist density (a=5 kpc) from an Eddington DF in
# a more extended Hernquist potential (a=20 kpc)
n_orbs = int(1e4)

pot = potential.HernquistPotential(amp=1e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
denspot = potential.HernquistPotential(amp=1.,a=5*apu.kpc,ro=ro,vo=vo)
edf = df.eddingtondf(pot,denspot=denspot,ro=ro,vo=vo)
orbs = edf.sample(n=n_orbs,return_orbit=True)

In [None]:
# Create the offset vector: shift positions by roff along x and velocities
# by half the circular velocity at roff along y
roff = 10*apu.kpc
voff = potential.vcirc(pot,roff)
offset = [roff,0,0,0,voff/2.,0]
orbs_offset = offset_orbits_cartesian(orbs,offset)

In [None]:
# Phase-space projections of the initial (not offset) sample
fig,axs = plot_6d_cartesian_phase_space(orbs,xrange=[-20,20],vrange=[-200,200])
fig.savefig(fig_dir+'test_1.0_no_offset_xy.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Same projections after the offset; the wider velocity range accommodates
# the vy offset
fig,axs = plot_6d_cartesian_phase_space(orbs_offset,xrange=[-20,20],vrange=[-300,300])
fig.savefig(fig_dir+'test_1.1_offset_xy.png',dpi=300)
if not show_figs:
    plt.close()

### Now integrate some orbits

In [None]:
# Dynamical time of pot at its scale radius; integrate both samples for
# 5 t_dyn using nt snapshots
tdyn = potential.tdyn(pot,pot.a)
tmax = 5*tdyn
nt = 50
ts = np.linspace(0,tmax.value,nt)*tmax.unit

# Integrate the orbits
orbs.integrate(ts,pot)
orbs_offset.integrate(ts,pot)

In [None]:
# Show the time evolution of the unperturbed orbits: first and last snapshot
tplot = [ts[0],ts[-1]]
tplot_title = [r'$t=0$',r'$t=5t_{dyn}$']
fig_names = ['test_1.2_no_offset_td0_xy.png','test_1.3_no_offset_td5_xy.png']

for i in range(len(tplot)):
    fig,axs = plot_6d_cartesian_phase_space(orbs(tplot[i]),xrange=[-50,50],
        vrange=[-300,300])
    fig.suptitle('unperturbed orbits '+tplot_title[i],fontsize=12)
    fig.savefig(fig_dir+fig_names[i],dpi=300)
    if not show_figs:
        plt.close()

In [None]:
# Show the time evolution of the perturbed orbits. ts spans 5 t_dyn, so the
# indices nt//10 and nt//5 are approximately t_dyn/2 and t_dyn.
tplot = [ts[0],ts[nt//10],ts[nt//5],ts[-1]]
tplot_title = [r'$t=0$',r'$t=t_{dyn}/2$',r'$t=t_{dyn}$',r'$t=5t_{dyn}$']
fig_names = ['test_1.4_offset_td0_xy.png',
             'test_1.5_offset_td0.5_xy.png',
             'test_1.6_offset_td1_xy.png',
             'test_1.7_offset_td5_xy.png']

for i in range(len(tplot)):
    fig,axs = plot_6d_cartesian_phase_space(orbs_offset(tplot[i]),xrange=[-50,50],
        vrange=[-300,300])
    fig.suptitle('perturbed orbits '+tplot_title[i],fontsize=12)
    fig.savefig(fig_dir+fig_names[i],dpi=300)
    if not show_figs:
        plt.close()

### Plot the Jeans quantities and diagnostics for a series of times

In [None]:
# Jeans analysis of the unperturbed orbits at every snapshot
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

# Js[i] holds J in each of the nbins radial bins at snapshot i; qs holds the
# per-bin terms returned by calculate_spherical_jeans, indexed
# [quantity,snapshot,bin]
Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(orbs(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

In [None]:
# Jeans diagnostics for the unperturbed orbits over the first and the last
# nt//5 snapshots (~first and last dynamical time of the 5 t_dyn run)
n_win = nt//5
windows = [
    (slice(0,n_win), r'unperturbed $t: 0 \rightarrow t_{dyn}$',
     'test_1.8_no_offset_td0-td1_jeans.png'),
    (slice(nt-n_win,None), r'unperturbed $t: 4t_{dyn} \rightarrow 5t_{dyn}$',
     'test_1.9_no_offset_td4-td5_jeans.png'),
    ]
for win,title,fname in windows:
    fig,axs = pplot.plot_jeans_diagnostics(Js[win],rs,qs[:,win,:],
        adf=edf,r_range=samp_r_range)
    fig.suptitle(title,fontsize=12)
    fig.tight_layout()
    fig.savefig(fig_dir+fname,dpi=300)
    if not show_figs:
        plt.close()

In [None]:
# Weight |J| in each radial bin by r^2 and average, giving one number per
# snapshot
weights = rs**2
Js_weighted = np.zeros(nt)
for i,J_snap in enumerate(Js):
    Js_weighted[i] = np.sum(np.abs(J_snap*weights))/np.sum(weights)

fig,ax = plt.subplots()
ax.plot(ts/tdyn,Js_weighted)
ax.set_xlabel(r'$t/t_{dyn}$')
ax.set_ylabel(r'weighted $\langle J \rangle$')
fig.suptitle('unperturbed orbits')
fig.savefig(fig_dir+'test_1.10_no_offset_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Jeans analysis of the offset (perturbed) orbits at every snapshot
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

# Js indexed [snapshot,bin]; qs indexed [quantity,snapshot,bin]
Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(orbs_offset(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

In [None]:
# Jeans diagnostics for the offset orbits over the first and the last nt//5
# snapshots (~first and last dynamical time of the 5 t_dyn run)
n_win = nt//5
windows = [
    (slice(0,n_win), r'perturbed orbits $t: 0 \rightarrow t_{dyn}$',
     'test_1.11_offset_td0-td1_jeans.png'),
    (slice(nt-n_win,None), r'perturbed orbits $t: 4t_{dyn} \rightarrow 5t_{dyn}$',
     'test_1.12_offset_td4-td5_jeans.png'),
    ]
for win,title,fname in windows:
    fig,axs = pplot.plot_jeans_diagnostics(Js[win],rs,qs[:,win,:],
        adf=edf,r_range=samp_r_range)
    fig.suptitle(title,fontsize=12)
    fig.tight_layout()
    fig.savefig(fig_dir+fname,dpi=300)
    if not show_figs:
        plt.close()

In [None]:
# r^2-weighted mean |J| per snapshot for the offset (perturbed) orbits
Js_weighted = np.zeros(nt)
for i in range(nt):
    Js_weighted[i] = np.sum(np.abs(Js[i,:]*rs**2))/np.sum(rs**2)

fig = plt.figure()
ax = fig.add_subplot(111)

ax.plot(ts/tdyn,Js_weighted)
ax.set_xlabel(r'$t/t_{dyn}$')
ax.set_ylabel(r'weighted $\langle J \rangle$')
# Title it like its unperturbed counterpart (test_1.10) so the saved
# figures can be told apart
fig.suptitle('perturbed orbits')
fig.savefig(fig_dir+'test_1.13_offset_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()

## Try a different perturbation - sample orbits in equilibrium, then suddenly increase the mass of the potential

In [None]:
# Equilibrium sample in pot; pot_perturb has the same scale radius but 4x
# the amplitude and is switched on suddenly in a later cell
n_orbs = int(1e4)
pot = potential.HernquistPotential(amp=1e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
pot_perturb = potential.HernquistPotential(amp=4e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
denspot = potential.HernquistPotential(amp=1.,a=5*apu.kpc,ro=ro,vo=vo)
edf = df.eddingtondf(pot,denspot=denspot,ro=ro,vo=vo)
orbs = edf.sample(n=n_orbs,return_orbit=True)

# Dynamical times at the scale radius before and after the mass increase
print(potential.tdyn(pot,pot.a))
print(potential.tdyn(pot_perturb,pot_perturb.a))

In [None]:
# Evolve the equilibrium orbits for 2 t_dyn (of pot) in the original potential
tdyn = potential.tdyn(pot,pot.a)
tmax = 2*tdyn
nt = 20
ts = np.linspace(0,tmax.value,nt)*tmax.unit

# Integrate the orbits
orbs.integrate(ts,pot)
# Sudden perturbation: take the final snapshot and integrate it from t=0 in
# the more massive potential
orbs_perturb = orbs(ts[-1])
orbs_perturb.integrate(ts,pot_perturb)

In [None]:
# Jeans analysis before (orbs, in pot) and after (orbs_perturb, in
# pot_perturb) the sudden mass increase
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(orbs(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

Js_perturb = np.zeros((nt,nbins))
qs_perturb = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

# The perturbed orbits are integrated in pot_perturb, so their Jeans
# equation must be evaluated with that potential; using the original pot
# would test equilibrium in a potential the orbits no longer feel.
for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(orbs_perturb(ts[i]), pot_perturb, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js_perturb[i,:] = _J
    for j in range(len(_qs)):
        qs_perturb[j,i,:] = _qs[j]

In [None]:
# Jeans diagnostics before and after the sudden mass increase.
# NOTE(review): the perturbed titles appear to count time in units of
# pot_perturb's dynamical time (shorter, since its mass is 4x larger); each
# half of the perturbed run spans ~1 t_dyn of the original pot.
windows = [
    (Js[0:nt//2], qs[:,0:nt//2,:],
     r'unperturbed $t: 0 \rightarrow t_{dyn}$',
     'test_2.0_unperturbed_td0-td1_jeans.png'),
    (Js_perturb[0:nt//2], qs_perturb[:,0:nt//2,:],
     r'perturbed $t: 0 \rightarrow 2t_{dyn}$',
     'test_2.1_perturbed_td0-td2_jeans.png'),
    (Js_perturb[(nt//2):], qs_perturb[:,(nt//2):,:],
     r'perturbed $t: 2t_{dyn} \rightarrow 4t_{dyn}$',
     'test_2.2_perturbed_td2-td4_jeans.png'),
    ]
for J_win,q_win,title,fname in windows:
    fig,axs = pplot.plot_jeans_diagnostics(J_win,rs,q_win,
        adf=edf,r_range=samp_r_range)
    fig.suptitle(title,fontsize=12)
    fig.tight_layout()
    fig.savefig(fig_dir+fname,dpi=300)
    if not show_figs:
        plt.close()

In [None]:
# r^2-weighted mean |J| per snapshot before and after the mass increase
Js_weighted = np.zeros(nt)
for i in range(nt):
    Js_weighted[i] = np.sum(np.abs(Js[i,:]*rs**2))/np.sum(rs**2)
Js_weighted_perturb = np.zeros(nt)
for i in range(nt):
    Js_weighted_perturb[i] = np.sum(np.abs(Js_perturb[i,:]*rs**2))/np.sum(rs**2)

fig = plt.figure()
ax = fig.add_subplot(111)

ax.plot(ts/tdyn,Js_weighted,color='DodgerBlue')
# The perturbed run starts where the unperturbed one ended, hence the
# offset (ts/tdyn)[-1]. NOTE(review): the factor 2 appears to rescale to
# pot_perturb's shorter dynamical time while the offset is in the original
# t_dyn, so the axis mixes time units - confirm intended.
ax.plot(2*ts/tdyn+(ts/tdyn)[-1],Js_weighted_perturb,color='DarkOrange')
ax.set_xlabel(r'$t/t_{dyn}$')
ax.set_ylabel(r'weighted $\langle J \rangle$')
# Save into fig_dir with the other figures (was the working directory)
fig.savefig(fig_dir+'test_2.3_perturbed_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()

### Do the same experiment but now use the Dehnen Wrapper to make the transformation happen smoothly

In [None]:
# Equilibrium sample in pot. pot_perturb is the final host (4x amplitude);
# pot_perturb_diff (3x amplitude) is the extra component grown in smoothly
# on top of pot so that pot + pot_perturb_diff matches pot_perturb.
n_orbs = int(1e4)
pot = potential.HernquistPotential(amp=1e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
tdyn_init = potential.tdyn(pot, pot.a)
pot_perturb = potential.HernquistPotential(amp=4e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
pot_perturb_diff = potential.HernquistPotential(amp=3e12*apu.M_sun, a=20*apu.kpc, ro=ro, vo=vo)
denspot = potential.HernquistPotential(amp=1.,a=10*apu.kpc,ro=ro,vo=vo)
edf = df.eddingtondf(pot,denspot=denspot,ro=ro,vo=vo)
orbs = edf.sample(n=n_orbs,return_orbit=True)

# Dynamical times at the scale radius before and after the growth
print(potential.tdyn(pot,pot.a))
print(potential.tdyn(pot_perturb,pot_perturb.a))

In [None]:
# Create the steadily evolving potential: pot_perturb_diff is switched on
# starting at tform and ramps up over the following tsteady (decay=False
# means it grows rather than decays)
tform = 1*tdyn_init
tsteady = 5*tdyn_init
pot_grow = potential.DehnenSmoothWrapperPotential(pot=pot_perturb_diff, 
    tform=tform, tsteady=tsteady, decay=False)

In [None]:
def dehnen_amp(t,tform,tsteady,decay=False):
    '''dehnen_amp:

    Amplitude of the Dehnen smoothing polynomial used by galpy's
    DehnenSmoothWrapperPotential, as a function of time.

    Args:
        t (float or Quantity) - time at which to evaluate the amplitude
        tform (float or Quantity) - time at which the change starts
        tsteady (float or Quantity) - duration after tform over which the
            amplitude changes; t, tform and tsteady must share units
        decay (bool) - if True return 1-amp (a decaying perturbation)

    Returns:
        amp (float or Quantity) - amplitude, 0 before tform and 1 after
            tform+tsteady (reversed if decay=True)

    Raises:
        ValueError - if t cannot be ordered relative to tform and
            tform+tsteady (e.g. t is NaN)
    '''
    if t < tform:
        xi = -1
    elif tform <= t <= (tform+tsteady):
        # Map [tform, tform+tsteady] linearly onto [-1, 1]
        xi = 2*((t-tform)/tsteady)-1
    elif t > (tform+tsteady):
        xi = 1
    else:
        # Only reachable when comparisons are all False (NaN); previously this
        # printed a message and then crashed with UnboundLocalError on xi.
        raise ValueError('dehnen_amp: cannot compare t={} with tform={} and '
                         'tform+tsteady (NaN?)'.format(t,tform))
    amp = 3*xi**5/16 - 5*xi**3/8 + 15*xi/16 + 0.5
    return 1-amp if decay else amp

# Visualize the smooth growth factor of pot_grow over 6 initial dynamical times
fig = plt.figure()
ax = fig.add_subplot(111)

ts_plot = np.linspace(0.,1.,num=100)*6*tdyn_init.value
amp = np.array([dehnen_amp(t,tform.value,tsteady.value,decay=False) for t in ts_plot])

ax.plot(ts_plot/tdyn_init.value, amp)
ax.set_xlabel('time [initial dynamical]')
ax.set_ylabel('Amp')
# Save into fig_dir with the other figures (was the working directory)
fig.savefig(fig_dir+'test_3.0_dehnen_amp.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Integrate for 6 initial dynamical times in the base potential plus the
# smoothly growing component. Take the unit from tdyn_init itself rather
# than assuming its .value is in Gyr.
nt = 600
ts = np.linspace(0., 6*tdyn_init.value, num=nt)*tdyn_init.unit
orbs.integrate(ts, [pot,pot_grow])

In [None]:
# Jeans analysis at every snapshot of the smooth-growth run.
# NOTE(review): only the base `pot` is passed, although the orbits are
# integrated in [pot,pot_grow]; J therefore tests equilibrium in the initial
# potential rather than the grown one - confirm this is intended.
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(orbs(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

In [None]:
# Jeans diagnostics in three windows of the 6 t_dyn run: before growth,
# early growth, and late growth
windows = [
    (slice(0,nt//6), r'unperturbed $t: 0 \rightarrow t_{dyn}$',
     'test_3.1_unperturbed_td0-td1_jeans.png'),
    (slice(nt//6,(3*nt)//6), r'perturbation beginning $t: t_{dyn} \rightarrow 3t_{dyn}$',
     'test_3.2_perturbed_td1-td3_jeans.png'),
    (slice((4*nt)//6,None), r'perturbation ending $t: 4t_{dyn} \rightarrow 6t_{dyn}$',
     'test_3.3_perturbed_td4-td6_jeans.png'),
    ]
for win,title,fname in windows:
    fig,axs = pplot.plot_jeans_diagnostics(Js[win],rs,qs[:,win,:],
        adf=edf,r_range=samp_r_range)
    fig.suptitle(title,fontsize=12)
    fig.tight_layout()
    fig.savefig(fig_dir+fname,dpi=300)
    if not show_figs:
        plt.close()

In [None]:
# r^2-weighted mean |J| per snapshot (bottom panel) with the growth factor
# of pot_grow (top panel)
Js_weighted = np.zeros(nt)
for i in range(nt):
    Js_weighted[i] = np.sum(np.abs(Js[i,:]*rs**2))/np.sum(rs**2)

fig = plt.figure(figsize=(6,4))
gs = matplotlib.gridspec.GridSpec(nrows=4, ncols=1, figure=fig)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1:])

# t, tform and tsteady are all Quantities here, so the comparisons inside
# dehnen_amp are unit-aware
amp = np.array([dehnen_amp(t,tform,tsteady,decay=False) for t in ts])
ax1.plot(ts/tdyn_init, amp, color='DodgerBlue')
ax1.set_ylabel('amp')
ax1.tick_params(labelbottom=False)

ax2.plot(ts/tdyn_init,Js_weighted,color='DodgerBlue')
ax2.set_xlabel(r'$t/t_{dyn}$')
ax2.set_ylabel(r'weighted $\langle J \rangle$')
fig.savefig(fig_dir+'test_3.4_perturbed_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()

### Repeat that experiment but evolve the system for much longer to see if weighted J will decrease

In [None]:
# Same setup as the smooth-growth experiment: equilibrium sample in pot,
# with pot_perturb_diff to be grown on top to reach pot_perturb
n_orbs = int(1e4)
pot = potential.HernquistPotential(amp=1e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
tdyn_init = potential.tdyn(pot, pot.a)
pot_perturb = potential.HernquistPotential(amp=4e12*apu.M_sun,a=20*apu.kpc,ro=ro,vo=vo)
pot_perturb_diff = potential.HernquistPotential(amp=3e12*apu.M_sun, a=20*apu.kpc, ro=ro, vo=vo)
denspot = potential.HernquistPotential(amp=1.,a=10*apu.kpc,ro=ro,vo=vo)
edf = df.eddingtondf(pot,denspot=denspot,ro=ro,vo=vo)
orbs = edf.sample(n=n_orbs,return_orbit=True)

# Dynamical times at the scale radius before and after the growth
print(potential.tdyn(pot,pot.a))
print(potential.tdyn(pot_perturb,pot_perturb.a))

In [None]:
# Create the steadily evolving potential: growth starts at tform and
# completes tsteady later; afterwards the potential is static
tform = 1*tdyn_init
tsteady = 5*tdyn_init
pot_grow = potential.DehnenSmoothWrapperPotential(pot=pot_perturb_diff, 
    tform=tform, tsteady=tsteady, decay=False)

In [None]:
# Integrate for 100 initial dynamical times to follow the long-term
# relaxation. Take the unit from tdyn_init rather than assuming Gyr.
nt = 1000
ts = np.linspace(0., 100*tdyn_init.value, num=nt)*tdyn_init.unit
orbs.integrate(ts, [pot,pot_grow])

In [None]:
# Jeans analysis at every snapshot of the long (100 t_dyn) run.
# NOTE(review): only the base `pot` is passed, although the orbits are
# integrated in [pot,pot_grow]; J therefore tests equilibrium in the initial
# potential rather than the grown one - confirm this is intended.
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(orbs(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

In [None]:
# Sanity check: expected shape is (nt, nbins)
Js.shape

In [None]:
# Jeans diagnostics for four windows of the 100 t_dyn run (nt//100 snapshots
# ~ 1 t_dyn): before growth, early growth, end of growth, and late time
fig,axs = pplot.plot_jeans_diagnostics(Js[0:nt//100],rs,qs[:,0:nt//100,:],
    adf=edf,r_range=samp_r_range)
fig.suptitle(r'unperturbed $t: 0 \rightarrow t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_4.0_unperturbed_td0-td1_jeans.png',dpi=300)
if not show_figs:
    plt.close()

fig,axs = pplot.plot_jeans_diagnostics(Js[nt//100:(3*nt)//100],rs,
    qs[:,nt//100:(3*nt)//100,:],adf=edf,r_range=samp_r_range)
fig.suptitle(r'perturbation beginning $t: t_{dyn} \rightarrow 3t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_4.1_perturbed_td1-td3_jeans.png',dpi=300)
if not show_figs:
    plt.close()

fig,axs = pplot.plot_jeans_diagnostics(Js[(4*nt)//100:(6*nt)//100],rs,
    qs[:,(4*nt)//100:(6*nt)//100,:],adf=edf,r_range=samp_r_range)
fig.suptitle(r'perturbation ending $t: 4t_{dyn} \rightarrow 6t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_4.2_perturbed_td4-td6_jeans.png',dpi=300)
if not show_figs:
    plt.close()

# The last 10% of snapshots covers 90 -> 100 t_dyn (title previously said 20 -> 25)
fig,axs = pplot.plot_jeans_diagnostics(Js[(90*nt)//100:],rs,
    qs[:,(90*nt)//100:,:],adf=edf,r_range=samp_r_range)
fig.suptitle(r'perturbation late time $t: 90t_{dyn} \rightarrow 100t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_4.3_perturbed_td90-td100_jeans.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# r^2-weighted mean |J| per snapshot (bottom panel) with the growth factor
# of pot_grow (top panel) over the 100 t_dyn run
Js_weighted = np.zeros(nt)
for i in range(nt):
    Js_weighted[i] = np.sum(np.abs(Js[i,:]*rs**2))/np.sum(rs**2)

fig = plt.figure(figsize=(6,4))
gs = matplotlib.gridspec.GridSpec(nrows=4, ncols=1, figure=fig)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1:])

amp = np.array([dehnen_amp(t,tform,tsteady,decay=False) for t in ts])
ax1.plot(ts/tdyn_init, amp, color='DodgerBlue')
ax1.set_ylabel('amp')
ax1.tick_params(labelbottom=False)

ax2.plot(ts/tdyn_init,Js_weighted,color='DodgerBlue')
ax2.set_xlabel(r'$t/t_{dyn}$')
ax2.set_ylabel(r'weighted $\langle J \rangle$')
# Display is controlled by show_figs like every other cell (the stray
# fig.show() that was here is removed)
fig.savefig(fig_dir+'test_4.4_perturbed_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()

## Now do a complicated experiment where you embed an equilibrium Hernquist model in a moving object potential which is being accreted

In [None]:
## First, create a moving-object potential whose orbit experiences dynamical friction

# Host potential. NOTE(review): amp=2*mass, presumably because galpy's
# HernquistPotential amp is twice the total mass - confirm.
pot_mass = 1e12*apu.M_sun
pot_scale = 20*apu.kpc
pot = potential.HernquistPotential(amp=2*pot_mass, a=pot_scale, ro=ro,vo=vo) # host potential

# Satellite potential; (1+sqrt(2))*a is the Hernquist half-mass radius,
# passed to the dynamical friction force below
sat_mass = 1e11*apu.M_sun
sat_scale = 10*apu.kpc
sat_rhm = (1+np.sqrt(2))*sat_scale
satpot = potential.HernquistPotential(amp=2*sat_mass, a=sat_scale, ro=ro, vo=vo)

# Satellite density potential: the tracer profile sampled with an Eddington
# DF in the satellite potential
sat_dens_scale = 5*apu.kpc
satdenspot = potential.HernquistPotential(amp=1., a=sat_dens_scale, ro=ro, vo=vo)
satdf = df.eddingtondf(pot=satpot, denspot=satdenspot, ro=ro, vo=vo)

# The dynamical friction force on the satellite, using the host as the density
dynfricpot = potential.ChandrasekharDynamicalFrictionForce(GMs=sat_mass, rhm=sat_rhm, dens=pot, ro=ro, vo=vo)

In [None]:
# Start the satellite at twice the host scale radius with the circular
# velocity there (vR=0, vT=vcirc, in the plane) and integrate it for
# 3 t_dyn with dynamical friction
R_start = 2*pot_scale
vcirc_start = potential.vcirc(pot, R_start)
tdyn_start = potential.tdyn(pot, R_start)
vxvv = [R_start.value/ro, 0., vcirc_start.value/vo, 0., 0., 0.]
o_dynfric = orbit.Orbit(vxvv, ro=ro, vo=vo)

# Take the time unit from tdyn_start itself rather than assuming Gyr
nt = 1000
ts = np.linspace(0., 3*tdyn_start.value, num=nt)*tdyn_start.unit
o_dynfric.integrate(ts,[pot,dynfricpot])

# Satellite potential that follows the integrated orbit
satmop = potential.MovingObjectPotential(o_dynfric, pot=satpot, ro=ro, vo=vo)

In [None]:
# x-y track of the satellite's orbit under dynamical friction
fig = plt.figure()

ax = fig.add_subplot(111)
ax.plot(o_dynfric.x(ts), o_dynfric.y(ts))

ax.set_xlabel('X [kpc]')
ax.set_ylabel('Y [kpc]')
ax.set_xlim(-45,45)
ax.set_ylim(-45,45)
fig.savefig(fig_dir+'test_5.0_dynamical_friction_orbit_xy.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Galactocentric radius of the satellite vs time (in units of tdyn_start)
fig = plt.figure()

ax = fig.add_subplot(111)
ax.plot(ts/tdyn_start, o_dynfric.r(ts))

ax.set_xlabel('t [dynamical]')
ax.set_ylabel('r [kpc]')
ax.set_xlim(0,3)
ax.set_ylim(0,45)
fig.savefig(fig_dir+'test_5.1_dynamical_friction_orbit_rs.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Create and offset the satellite orbits: move them to the satellite's
# initial position (x=R_start) and give them its circular velocity along y
satorbs = satdf.sample(n=10000)
offset_vec = [R_start,0.,0.,0.,vcirc_start,0.]
satorbs = offset_orbits_cartesian(satorbs, offset_vec)

# Integrate the particles in the moving satellite potential plus the host.
# Take the time unit from tdyn_start itself rather than assuming Gyr.
nt = 600
ts = np.linspace(0., 3*tdyn_start.value, num=nt)*tdyn_start.unit
satorbs.integrate(ts, [satmop,pot])

In [None]:
# Snapshots of the satellite particles (black) in the x-y plane, with the
# satellite's orbit (blue line) and its current position (x marker)
fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=3).flatten()
n_axs = 6

# nt=600 snapshots over 3 t_dyn, so every 100 indices is ~0.5 t_dyn
ts_indx = [0, 100, 200, 300, 400, 500]
ts_strs = [r'$t=0$', r'$t=0.5t_{dyn}$', r'$t=t_{dyn}$', r'$t=1.5t_{dyn}$', r'$t=2t_{dyn}$', r'$t=2.5t_{dyn}$']

for i in range(n_axs):
    ax = axs[i]
    _t = ts[ts_indx[i]]
    
    ax.scatter(satorbs.x(_t), satorbs.y(_t), alpha=0.1, s=1., zorder=2, 
               color='Black')
    ax.plot(o_dynfric.x(ts), o_dynfric.y(ts), alpha=0.5, linewidth=1., 
            color='DodgerBlue', zorder=3)
    ax.scatter(o_dynfric.x(_t), o_dynfric.y(_t), s=16, marker='x', 
               color='DodgerBlue', zorder=4)
    
    ax.set_xlabel('X [kpc]')
    ax.set_ylabel('Y [kpc]')
    ax.set_xlim(-45,45)
    ax.set_ylim(-45,45)
    
    ax.set_title(ts_strs[i])

fig.tight_layout()
fig.savefig(fig_dir+'test_5.2_satellite_orbit_xy.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Normalized histograms of the satellite particles' galactocentric radius at
# the same six snapshots
fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=3).flatten()
n_axs = 6

ts_indx = [0, 100, 200, 300, 400, 500]
ts_strs = [r'$t=0$', r'$t=0.5t_{dyn}$', r'$t=t_{dyn}$', r'$t=1.5t_{dyn}$', r'$t=2t_{dyn}$', r'$t=2.5t_{dyn}$']

for i in range(n_axs):
    ax = axs[i]
    _t = ts[ts_indx[i]]
    
    ax.hist(satorbs.r(_t).value, bins=25, range=(0,100), density=True, 
        histtype='step')
    
    ax.set_xlabel('r [kpc]')
    ax.set_ylabel('p(r)')
    ax.set_xlim(0,100)
    ax.set_ylim(0,0.075)
    
    ax.set_title(ts_strs[i])

fig.tight_layout()
fig.savefig(fig_dir+'test_5.3_satellite_orbit_rs_hist.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Radial number-density profile of the satellite particles (about the
# origin) at six snapshots, compared with fiducial power laws
fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=3).flatten()
n_axs = 6

ts_indx = [0, 100, 200, 300, 400, 500]
ts_strs = [r'$t=0$', r'$t=0.5t_{dyn}$', r'$t=t_{dyn}$', r'$t=1.5t_{dyn}$', r'$t=2t_{dyn}$', r'$t=2.5t_{dyn}$']

# Fiducial power-law potentials, built once instead of once per panel
alphas = [0.5,1.5,2.5,3.5,4.5,]
fpots = [potential.PowerSphericalPotential(alpha=alpha) for alpha in alphas]

for i in range(n_axs):
    ax = axs[i]
    _t = ts[ts_indx[i]]
    
    hist,bin_edges = np.histogram(satorbs.r(_t).value, bins=25, range=(0,100))
    bin_cents = (bin_edges[1:]+bin_edges[:-1])/2
    # Spherical shell volume 4/3 pi (r_out^3 - r_in^3); the inner edge was
    # previously not cubed
    shell_vol = (4*np.pi/3)*(bin_edges[1:]**3-bin_edges[:-1]**3)
    dens = hist/shell_vol
    
    ax.plot(bin_cents, dens, color='DodgerBlue')
    
    # Comparison with fiducial power laws, normalized to the innermost bin.
    # A separate name is used so the Jeans bin radii `rs` are not overwritten.
    rs_plot = np.linspace(bin_cents[0],100,num=101)*apu.kpc
    for fpot in fpots:
        fdens = potential.evaluateDensities(fpot,rs_plot,0)
        ax.plot(rs_plot.value, fdens*(dens[0]/fdens[0]), color='Red', linestyle='dashed', 
                linewidth=0.5, alpha=0.5)
    
    ax.set_xlabel('r [kpc]')
    ax.set_ylabel('number density')
    ax.set_xlim(1,100)
    ax.set_yscale('log')
    ax.set_xscale('log')
    
    ax.set_title(ts_strs[i])

fig.tight_layout()
fig.savefig(fig_dir+'test_5.4_satellite_orbit_dens.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Jeans analysis of the satellite particles at every snapshot.
# NOTE(review): J is computed about the origin using only the host `pot`,
# while the particles are integrated in [satmop,pot]; presumably this is
# deliberate (a host-centric view of the accreted debris) - confirm.
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(satorbs(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

In [None]:
# Jeans diagnostics for the satellite particles in thirds of the 3 t_dyn run;
# no analytic DF is overplotted (adf=None)
windows = [
    (slice(0,nt//3), r'$t: 0 \rightarrow t_{dyn}$', 'test_5.5_t0-t1_Jeans.png'),
    (slice(nt//3,(2*nt)//3), r'$t: t_{dyn} \rightarrow 2t_{dyn}$', 'test_5.6_t1-t2_Jeans.png'),
    (slice((2*nt)//3,None), r'$t: 2t_{dyn} \rightarrow 3t_{dyn}$', 'test_5.7_t2-t3_Jeans.png'),
    ]
for win,title,fname in windows:
    fig,axs = pplot.plot_jeans_diagnostics(Js[win],rs,qs[:,win,:],
        adf=None,r_range=samp_r_range)
    fig.suptitle(title,fontsize=12)
    fig.tight_layout()
    fig.savefig(fig_dir+fname,dpi=300)
    if not show_figs:
        plt.close()


In [None]:
# r^2-weighted mean |J| per snapshot for the satellite particles
Js_weighted = np.zeros(nt)
for i in range(nt):
    Js_weighted[i] = np.sum(np.abs(Js[i,:]*rs**2))/np.sum(rs**2)

# Single panel: there is no growth factor to show for this experiment, so the
# commented-out amplitude panel that was left here is removed
fig = plt.figure()
ax2 = fig.add_subplot(111)

ax2.plot(ts/tdyn_start,Js_weighted,color='DodgerBlue')
ax2.set_xlabel(r'$t/t_{dyn}$')
ax2.set_ylabel(r'weighted $\langle J \rangle$')
fig.savefig(fig_dir+'test_5.8_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()

### Build on the experiment by having the potential dissolve with time

In [None]:
# Create the dissolving satellite potential: with decay=True the moving
# satellite's amplitude decreases smoothly from tform over tsteady (here
# over the whole 3 t_dyn run)
tform = 0*tdyn_start
tsteady = 3.*tdyn_start
satmop_dissolve = potential.DehnenSmoothWrapperPotential(pot=satmop, tform=tform, tsteady=tsteady, decay=True)

# Create and offset the satellite orbits (a fresh sample with the same
# initial offset as the non-dissolving experiment)
satorbs_dissolve = satdf.sample(n=10000)
offset_vec = [R_start,0.,0.,0.,vcirc_start,0.]
satorbs_dissolve = offset_orbits_cartesian(satorbs_dissolve, offset_vec)

In [None]:
# Decaying satellite amplitude (left axis) and the satellite's
# galactocentric radius (right axis) vs time. Relies on `ts` spanning
# 0 -> 3 t_dyn from the satellite integration above.
fig = plt.figure()

ax1 = fig.add_subplot(111)
ax2 = ax1.twinx()

amp = np.array([dehnen_amp(t,tform,tsteady,decay=True) for t in ts])
ax1.plot(ts/tdyn_start, amp, color='DodgerBlue')
# Color-code each y-axis to its line (both were previously blue)
ax1.set_ylabel('amp', color='DodgerBlue')
ax1.set_xlabel('t [dynamical]')

ax2.plot(ts/tdyn_start, o_dynfric.r(ts), color='DarkOrange')
ax2.set_ylabel('r [kpc]', color='DarkOrange')
ax2.set_xlim(0,3)
ax2.set_ylim(0,45)
fig.savefig(fig_dir+'test_6.0_dynamical_friction_orbit_xy.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Integrate the particles for 3 t_dyn in the dissolving satellite potential
# plus the host. Take the unit from tdyn_start rather than assuming Gyr.
nt = 600
ts = np.linspace(0., 3*tdyn_start.value, num=nt)*tdyn_start.unit
satorbs_dissolve.integrate(ts, [satmop_dissolve,pot])

In [None]:
# Snapshots of the dissolving-satellite particles (black) in the x-y plane,
# with the satellite's orbit (blue line) and current position (x marker)
fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=3).flatten()
n_axs = 6

# nt=600 snapshots over 3 t_dyn, so every 100 indices is ~0.5 t_dyn
ts_indx = [0, 100, 200, 300, 400, 500]
ts_strs = [r'$t=0$', r'$t=0.5t_{dyn}$', r'$t=t_{dyn}$', r'$t=1.5t_{dyn}$', r'$t=2t_{dyn}$', r'$t=2.5t_{dyn}$']

for i in range(n_axs):
    ax = axs[i]
    _t = ts[ts_indx[i]]
    
    ax.scatter(satorbs_dissolve.x(_t), satorbs_dissolve.y(_t), alpha=0.1, s=1., zorder=2, color='Black')
    ax.plot(o_dynfric.x(ts), o_dynfric.y(ts), alpha=0.5, linewidth=1., color='DodgerBlue', zorder=3)
    ax.scatter(o_dynfric.x(_t), o_dynfric.y(_t), marker='x', s=16, color='DodgerBlue', zorder=4)
    
    ax.set_xlabel('X [kpc]')
    ax.set_ylabel('Y [kpc]')
    ax.set_xlim(-45,45)
    ax.set_ylim(-45,45)
    
    ax.set_title(ts_strs[i])

fig.tight_layout()
fig.savefig(fig_dir+'test_6.1_dynamical_friction_orbit_xy.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Normalized histograms of the dissolving-satellite particles'
# galactocentric radius at the same six snapshots
fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=3).flatten()
n_axs = 6

ts_indx = [0, 100, 200, 300, 400, 500]
ts_strs = [r'$t=0$', r'$t=0.5t_{dyn}$', r'$t=t_{dyn}$', r'$t=1.5t_{dyn}$', r'$t=2t_{dyn}$', r'$t=2.5t_{dyn}$']

for i in range(n_axs):
    ax = axs[i]
    _t = ts[ts_indx[i]]
    
    ax.hist(satorbs_dissolve.r(_t).value, bins=25, range=(0,100), density=True, 
        histtype='step')
    
    ax.set_xlabel('r [kpc]')
    ax.set_ylabel('p(r)')
    ax.set_xlim(0,100)
    ax.set_ylim(0,0.075)
    
    ax.set_title(ts_strs[i])

fig.tight_layout()
fig.savefig(fig_dir+'test_6.2_satellite_orbit_rs_hist.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Radial number-density profile of the dissolving-satellite particles
# (about the origin) at six snapshots, compared with fiducial power laws
fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=3).flatten()
n_axs = 6

ts_indx = [0, 100, 200, 300, 400, 500]
ts_strs = [r'$t=0$', r'$t=0.5t_{dyn}$', r'$t=t_{dyn}$', r'$t=1.5t_{dyn}$', r'$t=2t_{dyn}$', r'$t=2.5t_{dyn}$']

# Fiducial power-law potentials, built once instead of once per panel
alphas = [0.5,1.5,2.5,3.5,4.5,]
fpots = [potential.PowerSphericalPotential(alpha=alpha) for alpha in alphas]

for i in range(n_axs):
    ax = axs[i]
    _t = ts[ts_indx[i]]
    
    hist,bin_edges = np.histogram(satorbs_dissolve.r(_t).value, bins=25, range=(0,100))
    bin_cents = (bin_edges[1:]+bin_edges[:-1])/2
    # Spherical shell volume 4/3 pi (r_out^3 - r_in^3); the inner edge was
    # previously not cubed
    shell_vol = (4*np.pi/3)*(bin_edges[1:]**3-bin_edges[:-1]**3)
    dens = hist/shell_vol
    
    ax.plot(bin_cents, dens, color='DodgerBlue')
    
    # Comparison with fiducial power laws, normalized to the innermost bin.
    # A separate name is used so the Jeans bin radii `rs` are not overwritten.
    rs_plot = np.linspace(bin_cents[0],100,num=101)*apu.kpc
    for fpot in fpots:
        fdens = potential.evaluateDensities(fpot,rs_plot,0)
        ax.plot(rs_plot.value, fdens*(dens[0]/fdens[0]), color='Red', linestyle='dashed', 
                linewidth=0.5, alpha=0.5)
    
    ax.set_xlabel('r [kpc]')
    ax.set_ylabel('number density')
    ax.set_xlim(1,100)
    ax.set_yscale('log')
    ax.set_xscale('log')
    
    ax.set_title(ts_strs[i])

fig.tight_layout()
fig.savefig(fig_dir+'test_6.3_satellite_orbit_dens.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# Jeans analysis of the dissolving-satellite particles at every snapshot.
# NOTE(review): as for the non-dissolving satellite, J is computed about the
# origin with only the host `pot`, while the particles are integrated in
# [satmop_dissolve,pot] - confirm this is intended.
nbins = 10
bin_r_range = [1,50] # The range of radii to use for the bins
samp_r_range = [bin_r_range[0]/2,bin_r_range[1]*4] # Sampling range
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

Js = np.zeros((nt,nbins))
qs = np.zeros((7,nt,nbins)) # dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs

for i in range(nt):
    _J,rs,_qs = pkin.calculate_spherical_jeans(satorbs_dissolve(ts[i]), pot, 
        r_range=bin_r_range, n_bin=nbins, norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)
    Js[i,:] = _J
    for j in range(len(_qs)):
        qs[j,i,:] = _qs[j]

In [None]:
# Jeans diagnostics for the dissolving satellite in thirds of the 3 t_dyn
# run (nt=600 snapshots), mirroring the non-dissolving satellite plots
# (test_5.5-5.7). The windows, titles and filenames were previously copied
# from the 6 t_dyn growth experiment. adf=None because the global `edf` is
# the DF of an earlier, unrelated Hernquist experiment, not of these particles.
fig,axs = pplot.plot_jeans_diagnostics(Js[0:nt//3],rs,qs[:,0:nt//3,:],
    adf=None,r_range=samp_r_range)
fig.suptitle(r'dissolving $t: 0 \rightarrow t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_6.4_t0-t1_Jeans.png',dpi=300)
if not show_figs:
    plt.close()

fig,axs = pplot.plot_jeans_diagnostics(Js[nt//3:(2*nt)//3],rs,
    qs[:,nt//3:(2*nt)//3,:],adf=None,r_range=samp_r_range)
fig.suptitle(r'dissolving $t: t_{dyn} \rightarrow 2t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_6.5_t1-t2_Jeans.png',dpi=300)
if not show_figs:
    plt.close()

fig,axs = pplot.plot_jeans_diagnostics(Js[(2*nt)//3:],rs,
    qs[:,(2*nt)//3:,:],adf=None,r_range=samp_r_range)
fig.suptitle(r'dissolving $t: 2t_{dyn} \rightarrow 3t_{dyn}$',fontsize=12)
fig.tight_layout()
fig.savefig(fig_dir+'test_6.6_t2-t3_Jeans.png',dpi=300)
if not show_figs:
    plt.close()

In [None]:
# r^2-weighted mean |J| per snapshot (bottom panel) with the satellite's
# decaying amplitude (top panel)
Js_weighted = np.zeros(nt)
for i in range(nt):
    Js_weighted[i] = np.sum(np.abs(Js[i,:]*rs**2))/np.sum(rs**2)

fig = plt.figure(figsize=(6,4))
gs = matplotlib.gridspec.GridSpec(nrows=4, ncols=1, figure=fig)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1:])

# Time is in units of this experiment's dynamical time (tdyn_start), not
# tdyn_init from the earlier host-growth experiments
amp = np.array([dehnen_amp(t,tform,tsteady,decay=True) for t in ts])
ax1.plot(ts/tdyn_start, amp, color='DodgerBlue')
ax1.set_ylabel('amp')
ax1.tick_params(labelbottom=False)

ax2.plot(ts/tdyn_start,Js_weighted,color='DodgerBlue')
ax2.set_xlabel(r'$t/t_{dyn}$')
ax2.set_ylabel(r'weighted $\langle J \rangle$')

fig.savefig(fig_dir+'test_6.7_weighted_J.png',dpi=300)
if not show_figs:
    plt.close()