In [1]:
import numpy as np
import matplotlib.pyplot as plt 
import jax
import jax.numpy as jnp
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
                          growth_rate, growth_rate_second, dGfa, Gf)
from jax import tree
from jax import tree
from jaxpm.pm import pm_forces
from diffrax import ConstantStepSize,  SaveAt, diffeqsolve,StepTo
from jaxpm.plotting import plot_fields_single_projection
from jaxpm.painting import cic_paint , cic_paint_dx
import diffrax
#jax.config.update("jax_enable_x64", True)
import jax
jax.print_environment_info()

!nvidia-smi --query-gpu=gpu_name --format=csv,noheader
%matplotlib inline 


jax:    0.5.2
jaxlib: 0.5.1
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
device info: Quadro RTX 6000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='midway3-0284.rcc.local', release='4.18.0-305.3.1.el8.x86_64', version='#1 SMP Tue Jun 1 16:14:33 UTC 2021', machine='x86_64')


$ nvidia-smi
Mon Apr  7 22:14:41 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Quadro RTX 6000                On  | 00

  pid, fd = os.forkpty()


In [2]:
from SuperResPM.configure import Configuration

In [3]:
import matplotlib.pyplot as plt
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('jpeg')


In [4]:
from SuperResPM.diffrax_helper import FPMODE,FPMLeapFrog,symplectic_ode

In [5]:
# Cosmological parameters, in the order [Omega_m, Omega_b, h, n_s, sigma_8].
cosmo_in = [0.3175, 0.049, 0.6711, 0.9624, 0.834]
import jax_cosmo as jc
from diffrax import ConstantStepSize
from jaxpm.pm import linear_field, lpt

# Tutorial 

In this notebook, I add a FastPM-like leapfrog solver to diffrax. The solver also supports gradient calculation with the reversible adjoint method.  

The API should be similar to jaxpm with standard diffrax solver. 

Some additional details of the API are described below:

- This code declares the solver. 
>>>

        ode_fn = tree.map(
                FPMODE,
                symplectic_ode(mesh_shape, paint_absolute_pos=paint_absolute_pos>0)
        )
        solver = FPMLeapFrog(initial_t0=initial_t0, final_t1=final_time)
        

-  In addition to dx and p, we also need the initial force (or acceleration) at the starting point. You can get it with `pm_forces` from jaxpm.
   Note that the calculation differs slightly depending on whether you use absolute or relative positions. 
>>>
        if paint_absolute_pos>0:
            initforce = pm_forces(
                particles+dx,
                mesh_shape=mesh_shape,
                paint_absolute_pos=paint_absolute_pos>0,
            )* 1.5* cosmo.Omega_m
        else:
            initforce = pm_forces(
                dx,
                mesh_shape=mesh_shape,
                paint_absolute_pos=paint_absolute_pos>0,
            )* 1.5* cosmo.Omega_m        
        if paint_absolute_pos>0:
            y0= jnp.stack([particles+dx, p,initforce], axis=0)
        else:
            y0= jnp.stack([dx, p,initforce], axis=0)

- Also, since jax_cosmo is not a pure function, we have to initialize it first. 

  >>>
    cosmo = jc.Cosmology(Omega_c=cosmo_in[0]-cosmo_in[1], Omega_b=cosmo_in[1], h=cosmo_in[2], sigma8 = cosmo_in[4], n_s=cosmo_in[3],
                      Omega_k=0., w0=-1., wa=0.)
    ain=np.atleast_1d(1)
    _ =  growth_rate_second(cosmo,ain)
    _ = growth_rate(cosmo, ain)
    _ = growth_factor(cosmo, ain)
    jc.background.radial_comoving_distance(cosmo, ain)
  >>>

- Finally, we can just do the ordinary diffeqsolve. Note that the args value has two zeroes at the end. Those are for bookkeeping reasons. You can put whatever value you want. 
>>>
        res = diffeqsolve(ode_fn,solver,\
                      t0=initial_t0,\
                      t1=final_time,\
                      dt0=dt0,\
                      y0=y0,
                      args=[cosmo, cosmo._workspace, initial_t0, conf, 0, 0],\
                      saveat=SaveAt(t1=True),
                      stepsize_controller=stepsize_controller,adjoint=diffrax.ReversibleAdjoint())


### Content in this notebook. 
To test the implementation, I ran the following tests:
1. Running the code twice to test reproducibility.
2. Running the code forward and then backward to test reversibility, which is important for the gradient calculation.
3. Optimizing the initial conditions so that the final density reproduces a target image (the "jaxpm" text below). This checks that the overall gradient makes sense.
4. Checking against an N-body simulation (Quijote) to make sure that the cosmology calculation is valid.

# Test reproducibility

In [7]:
def check_reproducibility(
        nmesh,
        da,
        seed=0,
        paint_absolute_pos=0
    ):
    """Run the same leapfrog simulation twice and report how much the two results differ.

    A ``nmesh``^3 simulation is configured from ``initial_t0 = 1/64`` to
    ``final_time = 1`` with constant step ``da``, integrated twice with
    identical inputs, and the relative standard deviation of the difference
    between the two runs is printed for the first (position/displacement) and
    second (``p`` from ``lpt``, reported as "vel") components of the state.

    Parameters
    ----------
    nmesh : int
        Number of mesh cells per dimension.
    da : float
        Step size passed to ``diffeqsolve`` as ``dt0``.
    seed : int, optional
        Seed for ``jax.random.PRNGKey`` used to draw the initial linear field.
    paint_absolute_pos : int, optional
        If > 0 the state carries absolute positions (``particles + dx``);
        otherwise it carries displacements ``dx`` only.

    Returns
    -------
    tuple
        ``(ys_first_run, ys_second_run)``, the ``.ys`` of the two solves.
    """
    conf = Configuration(mesh_shape=[nmesh, nmesh, nmesh],
                         snapshots=[1],
                         BoxSize=[256, 256, 256],
                         initial_t0=1/64,
                         final_time=1,
                         density_plane_npix=nmesh,
                         dt0=da,
                         density_plane_width=100.0)
    cosmo = jc.Cosmology(Omega_c=cosmo_in[0]-cosmo_in[1], Omega_b=cosmo_in[1], h=cosmo_in[2],
                         sigma8=cosmo_in[4], n_s=cosmo_in[3],
                         Omega_k=0., w0=-1., wa=0.)

    # Evaluate the background/growth functions once up front. Per the notebook's
    # note, jax_cosmo is not purely functional; presumably this populates
    # cosmo._workspace, which is handed to the solver below.
    ain = np.atleast_1d(1)
    growth_rate_second(cosmo, ain)
    growth_rate(cosmo, ain)
    growth_factor(cosmo, ain)
    jc.background.radial_comoving_distance(cosmo, ain)

    particles = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in conf.mesh_shape]), axis=-1
    ).reshape([-1, 3])

    # Tabulated linear power spectrum, interpolated for arbitrarily shaped k arrays.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)

    def pk_fn(x):
        return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    initial_conditions = linear_field(conf.mesh_shape, conf.BoxSize, pk_fn,
                                      seed=jax.random.PRNGKey(seed))

    def model(initial_conditions, cosmo, conf, particles, paint_absolute_pos):
        """LPT initial state followed by one forward solve with FPMLeapFrog."""
        absolute = paint_absolute_pos > 0
        mesh_shape = conf.mesh_shape
        initial_t0 = conf.initial_t0
        final_time = conf.final_time

        # The first state component is either the absolute position or the displacement.
        if absolute:
            dx, p, _ = lpt(cosmo, initial_conditions, particles, a=initial_t0)
            pos0 = particles + dx
        else:
            dx, p, _ = lpt(cosmo, initial_conditions, a=initial_t0)
            pos0 = dx

        ode_fn = tree.map(
            FPMODE,
            symplectic_ode(mesh_shape, paint_absolute_pos=absolute)
        )
        solver = FPMLeapFrog(initial_t0=initial_t0, final_t1=final_time)

        # The solver state also carries the force at the starting point.
        initforce = pm_forces(
            pos0,
            mesh_shape=mesh_shape,
            paint_absolute_pos=absolute,
        ) * 1.5 * cosmo.Omega_m
        y0 = jnp.stack([pos0, p, initforce], axis=0)

        return diffeqsolve(ode_fn, solver,
                           t0=initial_t0,
                           t1=final_time,
                           dt0=conf.dt0,
                           y0=y0,
                           args=[cosmo, cosmo._workspace, initial_t0, conf, 0, 0],
                           saveat=SaveAt(t1=True),
                           stepsize_controller=ConstantStepSize(),
                           adjoint=diffrax.ReversibleAdjoint())

    p1 = model(initial_conditions, cosmo, conf, particles, paint_absolute_pos).ys
    p2 = model(initial_conditions, cosmo, conf, particles, paint_absolute_pos).ys

    # ys[0] is the single saved state; its components 0 and 1 are position and p.
    diffpos = np.std(p1[0,0].flatten()-p2[0,0].flatten())/np.std(p1[0,0].flatten())
    diffvel = np.std(p1[0,1].flatten()-p2[0,1].flatten())/np.std(p1[0,1].flatten())
    print(f"nmesh: {nmesh}, timestep: {da :3}", f'rel error in pos: {diffpos}, vel:{diffvel}')
    return p1, p2


In [8]:
# Notebook-level configuration for a 64^3 mesh. check_reproducibility and
# check_reverse build their own Configuration, so this one only feeds the
# `particles` lattice defined in the next cell.
nmesh=64
conf = Configuration(mesh_shape=[nmesh, nmesh, nmesh], 
                     snapshots= [1],
                     BoxSize=[256,256,256], 
                     initial_t0=1/64,
                     final_time = 1,
                     density_plane_npix = nmesh,
                     dt0=0.1,
                     density_plane_width= 100.0
                    )

In [8]:
# Integer lattice coordinates of the mesh, flattened to one row of 3 coordinates per particle.
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in conf.mesh_shape]),axis=-1).reshape([-1,3])

In [9]:
# Reproducibility check with absolute particle positions in the solver state.
p1, p2 = check_reproducibility(nmesh=64, da=0.01, paint_absolute_pos=1)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


nmesh: 64, timestep: 0.01 rel error in pos: 1.1569404279043738e-07, vel:1.3421469020613586e-06


In [10]:
# Reproducibility check with displacements (relative positions) in the solver state.
p3,p4 = check_reproducibility(nmesh=64, da=0.01, paint_absolute_pos=0)

nmesh: 64, timestep: 0.01 rel error in pos: 4.303087553125806e-05, vel:6.014339305693284e-05


In [11]:
# Compare the final density fields of the two modes: absolute positions are
# painted onto an empty mesh with cic_paint, displacements with cic_paint_dx.
# [0][0] selects the single saved state and its first (position) component.
mesh_shape = [64,64,64]

fields = {"ab pos" : cic_paint(jnp.zeros(mesh_shape),p1[0][0]),"no ab pos" : cic_paint_dx(p3[0][0])}
plot_fields_single_projection(fields)

<Figure size 2000x500 with 2 Axes>

In [12]:
#Test Reversibility

In [6]:
def check_reverse(
        nmesh,
        da,
        seed=0,
        paint_absolute_pos=0
    ):
    """Integrate forward then backward in time and report how well the start is recovered.

    The simulation is evolved from ``1/64`` to ``1.0`` on a fixed time grid,
    the final state is then evolved back to ``1/64`` on the reversed grid, and
    the relative standard deviation of (recovered state - initial state) is
    printed for the first (position/displacement) and second (``p``, reported
    as "vel") components. The error is normalised by the standard deviation of
    the recovered state.

    The time grid is ``np.linspace(t0, t1, 1 + ceil((t1 - t0) / da))``, so the
    steps are uniform and may be slightly smaller than ``da``.

    Parameters
    ----------
    nmesh : int
        Number of mesh cells per dimension.
    da : float
        Target step size used to build the time grid.
    seed : int, optional
        Seed for ``jax.random.PRNGKey`` used to draw the initial linear field.
    paint_absolute_pos : int, optional
        If > 0 the state carries absolute positions (``particles + dx``);
        otherwise it carries displacements ``dx`` only.

    Returns
    -------
    tuple
        ``(backward_solution, forward_solution, y0)``, the two diffrax
        solutions and the initial stacked state.
    """
    def _make_conf(t0, t1, dt):
        # Fresh list objects on every call so the two configurations share no mutable state.
        return Configuration(mesh_shape=[nmesh, nmesh, nmesh],
                             snapshots=[1],
                             BoxSize=[256, 256, 256],
                             initial_t0=t0,
                             final_time=t1,
                             density_plane_npix=nmesh,
                             dt0=dt,
                             density_plane_width=100.0)

    conf = _make_conf(1/64., 1.0, da)
    conf_inv = _make_conf(1.0, 1/64., -1*da)

    cosmo = jc.Cosmology(Omega_c=cosmo_in[0]-cosmo_in[1], Omega_b=cosmo_in[1], h=cosmo_in[2],
                         sigma8=cosmo_in[4], n_s=cosmo_in[3],
                         Omega_k=0., w0=-1., wa=0.)

    # Evaluate the background/growth functions once up front. Per the notebook's
    # note, jax_cosmo is not purely functional; presumably this populates
    # cosmo._workspace, which is handed to the solver below.
    ain = np.atleast_1d(1)
    growth_rate_second(cosmo, ain)
    growth_rate(cosmo, ain)
    growth_factor(cosmo, ain)
    jc.background.radial_comoving_distance(cosmo, ain)

    particles = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in conf.mesh_shape]), axis=-1
    ).reshape([-1, 3])

    # Tabulated linear power spectrum, interpolated for arbitrarily shaped k arrays.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)

    def pk_fn(x):
        return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    initial_conditions = linear_field(conf.mesh_shape, conf.BoxSize, pk_fn,
                                      seed=jax.random.PRNGKey(seed))

    def model(initial_conditions, cosmo, conf, conf_inv, particles, paint_absolute_pos):
        """Forward solve from initial_t0 to final_time, then backward from its end state."""
        absolute = paint_absolute_pos > 0
        mesh_shape = conf.mesh_shape
        initial_t0 = conf.initial_t0
        final_time = conf.final_time
        dt0 = conf.dt0

        # The first state component is either the absolute position or the displacement.
        if absolute:
            dx, p, _ = lpt(cosmo, initial_conditions, particles, a=initial_t0)
            pos0 = particles + dx
        else:
            dx, p, _ = lpt(cosmo, initial_conditions, a=initial_t0)
            pos0 = dx

        ode_fn = tree.map(
            FPMODE,
            symplectic_ode(mesh_shape, paint_absolute_pos=absolute)
        )
        solver = FPMLeapFrog(initial_t0=initial_t0, final_t1=final_time)
        solver_rev = FPMLeapFrog(initial_t0=final_time, final_t1=initial_t0)

        # The solver state also carries the force at the starting point.
        initforce = pm_forces(
            pos0,
            mesh_shape=mesh_shape,
            paint_absolute_pos=absolute,
        ) * 1.5 * cosmo.Omega_m
        y0 = jnp.stack([pos0, p, initforce], axis=0)

        # Both directions step through exactly the same time points (reversed),
        # so the backward pass retraces the forward grid.
        step_num = jnp.ceil((final_time - initial_t0) / dt0)
        tseris = np.linspace(initial_t0, final_time, num=1+int(step_num))
        tseris_inv = tseris[::-1]

        res = diffeqsolve(ode_fn, solver,
                          t0=initial_t0,
                          t1=final_time,
                          dt0=None,
                          y0=y0,
                          args=[cosmo, cosmo._workspace, initial_t0, conf, 0, 0],
                          saveat=SaveAt(ts=[final_time]),
                          stepsize_controller=StepTo(tseris),
                          adjoint=diffrax.ReversibleAdjoint())

        res2 = diffeqsolve(ode_fn, solver_rev,
                           t0=final_time,
                           t1=initial_t0,
                           dt0=None,
                           y0=res.ys[0],
                           args=[cosmo, cosmo._workspace, final_time, conf_inv, 0, 0],
                           saveat=SaveAt(ts=[initial_t0]),
                           stepsize_controller=StepTo(tseris_inv),
                           adjoint=diffrax.ReversibleAdjoint())

        return res2, res, y0

    p1old, p11, p2 = model(initial_conditions, cosmo, conf, conf_inv, particles, paint_absolute_pos)

    # p1 is the state recovered at initial_t0; p2 is the true initial state y0.
    p1 = p1old.ys[0]
    diffpos = np.std(p1[0].flatten()-p2[0].flatten())/np.std(p1[0].flatten())
    diffvel = np.std(p1[1].flatten()-p2[1].flatten())/np.std(p1[1].flatten())
    print(f"nmesh: {nmesh}, timestep: {da :3}", f'rel error in pos: {diffpos}, vel:{diffvel}')
    return p1old, p11, p2


In [7]:
# Reversibility check in displacement mode: `backward` is the solution integrated
# back to the start, `middle` the forward solution, `init` the initial state y0.
backward, middle, init = check_reverse(nmesh=64, da=0.01, paint_absolute_pos=0)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


nmesh: 64, timestep: 0.01 rel error in pos: 0.00819802563637495, vel:0.024429699406027794


In [15]:
# Paint the displacement component of the initial state, the state recovered by
# the backward integration, and the forward end state for a side-by-side comparison.
fields = {"init" : cic_paint_dx(init[0]), "backward" : cic_paint_dx(backward.ys[0][0]), "forward" : cic_paint_dx(middle.ys[0][0])}
plot_fields_single_projection(fields)

<Figure size 2000x500 with 3 Axes>

# Check whether I can produce jaxpm universe

In [6]:
from PIL import Image, ImageFont, ImageDraw
# Build the optimisation target: the word 'jaxpm' rendered onto an nmesh x nmesh
# greyscale image and converted into a zero-mean contrast field.
nmesh=256

text = 'jaxpm'

# NOTE(review): ptcl_spacing is not used in the cells visible here.
ptcl_spacing = 10.
ptcl_grid_shape = (nmesh, nmesh, nmesh)
mesh_shape = (nmesh,nmesh,nmesh)
im_shape = (nmesh,nmesh)
# Text anchor: the centre of the image.
xy = (int(nmesh/2), int(nmesh/2))

# White ('L' mode, value 255) canvas; draw.text below is called without a fill argument.
im = Image.new('L', im_shape[::-1], 255)
draw = ImageDraw.Draw(im)
# NOTE(review): fontsize is never read; the font below is loaded with size=80.
fontsize=20
# NOTE(review): absolute, user-specific font path; this cell fails on any machine
# that does not have this file. Consider a configurable path.
draw.text(xy, text, anchor='mm',font = ImageFont.truetype("/home/chto/code/pmwd/docs/nova/NovaRoundSlim-BookOblique.ttf",size=80)
)

# normalize the image to make the target
# Invert so that pixels at 255 (the untouched canvas) become 0.
im_tgt = 1 - jnp.asarray(im) / 255
# NOTE(review): this rescaling is undone (up to rounding) by the division by
# im_tgt.sum() on the next line.
im_tgt *= jnp.prod(jnp.array(ptcl_grid_shape)) / im_tgt.sum()
# Final target: image normalised to unit sum, minus 1 (same form as model's output).
im_tgt= im_tgt/im_tgt.sum()-1


In [7]:
# Display the target image that the optimisation should reproduce.
plt.imshow(im_tgt)

<matplotlib.image.AxesImage at 0x7f0b1e714e50>

<Figure size 640x480 with 1 Axes>

In [8]:
from jaxpm.painting import cic_paint_2d
def density_plane_fn(t, y, args):
    """Paint a 2-D density plane from the particles lying in a slab along the third axis.

    Parameters
    ----------
    t :
        Unused; kept for the ``(t, y, args)`` call signature.
    y :
        Particle positions; the last axis holds the three coordinates.
    args :
        Two-element sequence ``(cosmo, config)``. Only ``config`` is read
        (``mesh_shape``, ``BoxSize``, ``density_plane_npix``,
        ``density_plane_width``).

    Returns
    -------
    Array of shape ``(density_plane_npix, density_plane_npix)``.
    """
    _, config = args
    nx, ny, nz = config.mesh_shape
    npix = config.density_plane_npix

    # Slab thickness rescaled by mesh cells per BoxSize unit along the third axis.
    w = config.density_plane_width / config.BoxSize[2] * nz
    # Slab centre is fixed. It was previously derived from the comoving distance at t:
    # jc.background.radial_comoving_distance(cosmo, t) / config['BoxSize'][2] * nc[2]
    center = 16
    lower = center - w / 2
    upper = center + w / 2

    # Transverse coordinates: wrapped periodically with nx, then rescaled to the target grid.
    xy = jnp.mod(y[..., :2], nx) / nx * npix

    # Unit weight for particles inside the half-open slab (lower, upper], zero elsewhere.
    depth = y[..., 2]
    in_slab = (depth > lower) & (depth <= upper)
    weight = jnp.where(in_slab, 1., 0.)

    painted = cic_paint_2d(jnp.zeros([npix, npix]), xy, weight)

    # Normalise by the mesh volume that maps onto one output pixel.
    cells_per_pixel = (nx / npix) * (ny / npix) * w
    return painted / cells_per_pixel
import equinox as eqx


In [9]:
# Setup for the optimisation experiment. This rebinds the notebook-level
# nmesh / conf / cosmo used by earlier cells to a 256^3 mesh with step 0.1.
nmesh = 256
da = 0.1
seed=0 


conf = Configuration(mesh_shape=[nmesh, nmesh, nmesh], 
                 snapshots= [1],
                 BoxSize=[256,256,256], 
                 initial_t0=1/64,
                 final_time = 1,
                 density_plane_npix = nmesh,
                 dt0=da,
                 density_plane_width= 100.0
                )
cosmo = jc.Cosmology(Omega_c=cosmo_in[0]-cosmo_in[1], Omega_b=cosmo_in[1], h=cosmo_in[2], sigma8 = cosmo_in[4], n_s=cosmo_in[3],
                  Omega_k=0., w0=-1., wa=0.)
# Evaluate the growth/background functions once before `model` is jitted.
# Per the notebook's note jax_cosmo is not purely functional; presumably this
# populates cosmo._workspace, which `model` passes to the solver.
ain=np.atleast_1d(1)
_ =  growth_rate_second(cosmo,ain)
_ = growth_rate(cosmo, ain)
_ = growth_factor(cosmo, ain)
jc.background.radial_comoving_distance(cosmo, ain)

# Tabulate the linear matter power spectrum on 128 log-spaced k values and
# interpolate it for arbitrarily shaped k arrays.
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
# Create initial conditions and particle
initial_conditions = linear_field(conf.mesh_shape, conf.BoxSize, pk_fn, seed=jax.random.PRNGKey(seed))
@jax.jit
def model(initial_conditions, cosmo, conf,):
    """Evolve the initial conditions and return the normalised final density plane.

    Runs LPT at ``conf.initial_t0``, integrates the displacement-mode system to
    ``conf.final_time`` with ``FPMLeapFrog`` (constant step ``conf.dt0``,
    reversible adjoint), converts the final displacements to absolute
    positions on the lattice and paints them with ``density_plane_fn``.

    Parameters
    ----------
    initial_conditions :
        Initial linear field passed to ``lpt``.
    cosmo :
        ``jax_cosmo`` cosmology; its ``_workspace`` is forwarded to the solver.
    conf :
        Run configuration (``mesh_shape``, ``initial_t0``, ``final_time``,
        ``dt0`` and the density-plane settings).

    Returns
    -------
    The density plane divided by its sum, minus 1.
    """
    mesh_shape = conf.mesh_shape
    initial_t0 = conf.initial_t0
    final_time = conf.final_time

    dx, p, _ = lpt(cosmo, initial_conditions, a=initial_t0)
    ode_fn = tree.map(
        FPMODE,
        symplectic_ode(mesh_shape, paint_absolute_pos=False)
    )
    solver = FPMLeapFrog(initial_t0=initial_t0, final_t1=final_time)

    # The solver state also carries the force at the starting point.
    initforce = pm_forces(
        dx,
        mesh_shape=mesh_shape,
        paint_absolute_pos=False,
    ) * 1.5 * cosmo.Omega_m
    y0 = jnp.stack([dx, p, initforce], axis=0)

    res = diffeqsolve(ode_fn, solver,
                      t0=initial_t0,
                      t1=final_time,
                      dt0=conf.dt0,
                      y0=y0,
                      args=[cosmo, cosmo._workspace, initial_t0, conf, 0, 0],
                      saveat=SaveAt(t1=True),
                      stepsize_controller=ConstantStepSize(),
                      adjoint=diffrax.ReversibleAdjoint())

    # Lattice coordinates ('ij' indexing) to turn displacements into positions.
    a, b, c = jnp.meshgrid(jnp.arange(mesh_shape[0]),
                           jnp.arange(mesh_shape[1]),
                           jnp.arange(mesh_shape[2]),
                           indexing='ij')
    lattice = jnp.stack([a, b, c], axis=-1)
    positions = (lattice + res.ys[0][0]).reshape(-1, 3)

    plane = density_plane_fn(res.ts[-1], positions, [cosmo, conf])
    # jnp.sum, not np.sum: `plane` is a traced value inside jit.
    return plane / jnp.sum(plane) - 1
p1 = model(initial_conditions, cosmo, conf)


  return lax_numpy.astype(self, dtype, copy=copy, device=device)


In [10]:
# Density plane produced by the unoptimised initial conditions.
plt.imshow(p1)

<matplotlib.image.AxesImage at 0x7f0ac015f5e0>

<Figure size 640x480 with 1 Axes>

In [11]:
from tqdm import tqdm
from jax.example_libraries.optimizers import adam
from jax import value_and_grad
def obj(initial_conditions):
    """Variance of (simulated density plane - ``im_tgt``) divided by the variance of ``im_tgt``.

    Reads ``model``, ``cosmo``, ``conf`` and ``im_tgt`` from the notebook namespace.
    """
    residual = model(initial_conditions, cosmo, conf) - im_tgt
    return residual.var() / im_tgt.var()


# Jitted function returning (loss, d loss / d initial_conditions).
obj_valgrad = jax.jit(jax.value_and_grad(obj))

In [13]:
# Profile a single loss-and-gradient evaluation.
%prun obj_valgrad(initial_conditions)

 

         1869 function calls (1867 primitive calls) in 2.698 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    2.520    2.520    2.520    2.520 pxla.py:1277(__call__)
        1    0.176    0.176    0.176    0.176 pxla.py:2649(_get_layouts_from_executable)
        3    0.000    0.000    0.000    0.000 {built-in method fromkeys}
        1    0.000    0.000    0.000    0.000 pxla.py:2464(get_out_shardings_from_executable)
        1    0.000    0.000    0.001    0.001 pxla.py:2208(lower_sharding_computation)
     1032    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        1    0.000    0.000    0.176    0.176 pxla.py:2899(from_hlo)
        1    0.000    0.000    0.000    0.000 pjit.py:701(_infer_params)
        1    0.000    0.000    0.000    0.000 pxla.py:1725(_get_and_check_device_assignment)
        1    0.000    0.000    2.698    2.698 <string>:1(<module>)
      188    0.000    0.000   

In [14]:
#Check runtime and memory
#%prun obj_valgrad(initial_conditions)
#with jax.profiler.trace("tensorborard/"):
#    _,_ =obj_valgrad(initial_conditions)#.block_until_ready()


In [15]:
# First CPU device; optim() places its per-iteration parameter snapshots on it.
cpu_device = jax.devices('cpu')[0]

In [16]:
from tqdm import tqdm
from jax.example_libraries.optimizers import adam
from jax import value_and_grad
def optim(tgt, initial_conditions, cosmo, conf, iters=100, lr=0.1):
    """Fit the initial conditions with Adam so the simulated density plane matches ``tgt``.

    The loss is ``var(model(ic, cosmo, conf) - tgt) / var(tgt)``. It is built
    from the arguments passed in, so a different target, cosmology or
    configuration takes effect. Previously the notebook-level ``obj_valgrad``
    was used and these arguments were ignored.

    Parameters
    ----------
    tgt : array-like
        Target image, same shape as the output of ``model``.
    initial_conditions :
        Starting point of the optimisation.
    cosmo, conf :
        Forwarded to ``model``.
    iters : int, optional
        Number of Adam iterations.
    lr : float, optional
        Adam step size.

    Returns
    -------
    tuple
        ``(last_loss, optimised_initial_conditions, history)``. ``history``
        holds the parameters after every iteration, placed on ``cpu_device``.
        ``last_loss`` is the loss evaluated before the final update, and is
        ``None`` when ``iters == 0``.
    """
    tgt = jnp.asarray(tgt)

    def loss_fn(ic):
        dens = model(ic, cosmo, conf)
        return (dens - tgt).var() / tgt.var()

    # Compiled once per optim() call and reused for every iteration.
    loss_valgrad = jax.jit(jax.value_and_grad(loss_fn))

    init, update, get_params = adam(lr)
    state = init(initial_conditions)

    def step(i, state):
        value, grads = loss_valgrad(get_params(state))
        state = update(i, grads, state)
        jax.debug.print("{y} loss {x}", x=value, y=i)
        return value, state

    value = None  # stays None if the loop body never runs
    initall = []
    for i in tqdm(range(iters)):
        value, state = step(i, state)
        # Keep the history on the CPU device. device_put replaces the
        # deprecated jax.jit(..., device=...) wrapper that was rebuilt every
        # iteration.
        initall.append(jax.device_put(get_params(state), cpu_device))
    initial_conditions = get_params(state)
    return value, initial_conditions, initall

In [None]:
# Run 500 Adam iterations towards the target image, then display the last loss
# together with the standard deviation of the original and optimised initial conditions.
loss, initial_conditions_optim, initall = optim(im_tgt, initial_conditions, cosmo, conf, iters=500)
loss, initial_conditions.std(), initial_conditions_optim.std()

  0%|▎                                                                                                                                                                                | 1/500 [00:03<26:09,  3.14s/it]

0 loss 1.090367317199707


  0%|▋                                                                                                                                                                                | 2/500 [00:05<24:31,  2.96s/it]

1 loss 1.0633419752120972


  1%|█                                                                                                                                                                                | 3/500 [00:08<23:58,  2.89s/it]

2 loss 1.045209527015686


  1%|█▍                                                                                                                                                                               | 4/500 [00:11<23:42,  2.87s/it]

3 loss 1.0295987129211426


  1%|█▊                                                                                                                                                                               | 5/500 [00:14<23:32,  2.85s/it]

4 loss 1.0146608352661133


  1%|██                                                                                                                                                                               | 6/500 [00:17<23:24,  2.84s/it]

5 loss 1.0006753206253052


  1%|██▍                                                                                                                                                                              | 7/500 [00:20<23:19,  2.84s/it]

6 loss 0.9858233332633972


  2%|██▊                                                                                                                                                                              | 8/500 [00:22<23:13,  2.83s/it]

7 loss 0.9709124565124512


  2%|███▏                                                                                                                                                                             | 9/500 [00:25<23:09,  2.83s/it]

8 loss 0.9578410983085632


  2%|███▌                                                                                                                                                                            | 10/500 [00:28<23:06,  2.83s/it]

9 loss 0.946280300617218


  2%|███▊                                                                                                                                                                            | 11/500 [00:31<23:03,  2.83s/it]

10 loss 0.9352635145187378


  2%|████▏                                                                                                                                                                           | 12/500 [00:34<22:59,  2.83s/it]

11 loss 0.9243565201759338


  3%|████▌                                                                                                                                                                           | 13/500 [00:37<22:56,  2.83s/it]

12 loss 0.9125955700874329


  3%|████▉                                                                                                                                                                           | 14/500 [00:39<22:53,  2.83s/it]

13 loss 0.9005900025367737


  3%|█████▎                                                                                                                                                                          | 15/500 [00:42<22:51,  2.83s/it]

14 loss 0.8898316025733948


  3%|█████▋                                                                                                                                                                          | 16/500 [00:45<22:47,  2.83s/it]

15 loss 0.8802164793014526


  3%|█████▉                                                                                                                                                                          | 17/500 [00:48<22:43,  2.82s/it]

16 loss 0.870062530040741


  4%|██████▎                                                                                                                                                                         | 18/500 [00:51<22:41,  2.82s/it]

17 loss 0.8604196310043335


  4%|██████▋                                                                                                                                                                         | 19/500 [00:53<22:37,  2.82s/it]

18 loss 0.8506972193717957


  4%|███████                                                                                                                                                                         | 20/500 [00:56<22:34,  2.82s/it]

19 loss 0.8402960896492004


  4%|███████▍                                                                                                                                                                        | 21/500 [00:59<22:32,  2.82s/it]

20 loss 0.8302951455116272


  4%|███████▋                                                                                                                                                                        | 22/500 [01:02<22:30,  2.82s/it]

21 loss 0.8211338520050049


  5%|████████                                                                                                                                                                        | 23/500 [01:05<22:26,  2.82s/it]

22 loss 0.8126480579376221


  5%|████████▍                                                                                                                                                                       | 24/500 [01:08<22:24,  2.82s/it]

23 loss 0.8047407269477844


  5%|████████▊                                                                                                                                                                       | 25/500 [01:10<22:23,  2.83s/it]

24 loss 0.7968543171882629


  5%|█████████▏                                                                                                                                                                      | 26/500 [01:13<22:20,  2.83s/it]

25 loss 0.789260983467102


  5%|█████████▌                                                                                                                                                                      | 27/500 [01:16<22:16,  2.83s/it]

26 loss 0.7826585173606873


  6%|█████████▊                                                                                                                                                                      | 28/500 [01:19<22:13,  2.83s/it]

27 loss 0.7761842012405396


  6%|██████████▏                                                                                                                                                                     | 29/500 [01:22<22:10,  2.83s/it]

28 loss 0.7695777416229248


  6%|██████████▌                                                                                                                                                                     | 30/500 [01:25<22:08,  2.83s/it]

29 loss 0.763029932975769


  6%|██████████▉                                                                                                                                                                     | 31/500 [01:27<22:05,  2.83s/it]

30 loss 0.7565602660179138


  6%|███████████▎                                                                                                                                                                    | 32/500 [01:30<22:03,  2.83s/it]

31 loss 0.7504663467407227


  7%|███████████▌                                                                                                                                                                    | 33/500 [01:33<21:59,  2.83s/it]

32 loss 0.7446191906929016


  7%|███████████▉                                                                                                                                                                    | 34/500 [01:36<21:57,  2.83s/it]

33 loss 0.7388617396354675


  7%|████████████▎                                                                                                                                                                   | 35/500 [01:39<21:53,  2.82s/it]

34 loss 0.7330236434936523


  7%|████████████▋                                                                                                                                                                   | 36/500 [01:42<21:50,  2.82s/it]

35 loss 0.7270017266273499


  7%|█████████████                                                                                                                                                                   | 37/500 [01:44<21:48,  2.83s/it]

36 loss 0.721198558807373


  8%|█████████████▍                                                                                                                                                                  | 38/500 [01:47<21:46,  2.83s/it]

37 loss 0.7159629464149475


  8%|█████████████▋                                                                                                                                                                  | 39/500 [01:50<21:43,  2.83s/it]

38 loss 0.7104490995407104


  8%|██████████████                                                                                                                                                                  | 40/500 [01:53<21:40,  2.83s/it]

39 loss 0.7046354413032532


  8%|██████████████▍                                                                                                                                                                 | 41/500 [01:56<21:37,  2.83s/it]

40 loss 0.6992689967155457


  8%|██████████████▊                                                                                                                                                                 | 42/500 [01:58<21:35,  2.83s/it]

41 loss 0.6939848065376282


  9%|███████████████▏                                                                                                                                                                | 43/500 [02:01<21:32,  2.83s/it]

42 loss 0.6889059543609619


  9%|███████████████▍                                                                                                                                                                | 44/500 [02:04<21:29,  2.83s/it]

43 loss 0.6837449669837952


  9%|███████████████▊                                                                                                                                                                | 45/500 [02:07<21:28,  2.83s/it]

44 loss 0.6788484454154968


  9%|████████████████▏                                                                                                                                                               | 46/500 [02:10<21:26,  2.83s/it]

45 loss 0.6737843751907349


  9%|████████████████▌                                                                                                                                                               | 47/500 [02:13<21:22,  2.83s/it]

46 loss 0.6685346961021423


 10%|████████████████▉                                                                                                                                                               | 48/500 [02:15<21:19,  2.83s/it]

47 loss 0.663254976272583


 10%|█████████████████▏                                                                                                                                                              | 49/500 [02:18<21:17,  2.83s/it]

48 loss 0.658395528793335


 10%|█████████████████▌                                                                                                                                                              | 50/500 [02:21<21:15,  2.83s/it]

49 loss 0.6536766886711121


 10%|█████████████████▉                                                                                                                                                              | 51/500 [02:24<21:11,  2.83s/it]

50 loss 0.6487855315208435


 10%|██████████████████▎                                                                                                                                                             | 52/500 [02:27<21:08,  2.83s/it]

51 loss 0.6439735889434814


 11%|██████████████████▋                                                                                                                                                             | 53/500 [02:30<21:06,  2.83s/it]

52 loss 0.6388628482818604


 11%|███████████████████                                                                                                                                                             | 54/500 [02:32<21:04,  2.83s/it]

53 loss 0.6338284015655518


 11%|███████████████████▎                                                                                                                                                            | 55/500 [02:35<21:00,  2.83s/it]

54 loss 0.6288179159164429


 11%|███████████████████▋                                                                                                                                                            | 56/500 [02:38<20:58,  2.83s/it]

55 loss 0.6236178874969482


 11%|████████████████████                                                                                                                                                            | 57/500 [02:41<20:54,  2.83s/it]

56 loss 0.6186352968215942


 12%|████████████████████▍                                                                                                                                                           | 58/500 [02:44<20:51,  2.83s/it]

57 loss 0.6138918399810791


 12%|████████████████████▊                                                                                                                                                           | 59/500 [02:47<20:48,  2.83s/it]

58 loss 0.6093277335166931


 12%|█████████████████████                                                                                                                                                           | 60/500 [02:49<20:46,  2.83s/it]

59 loss 0.6048154234886169


 12%|█████████████████████▍                                                                                                                                                          | 61/500 [02:52<20:44,  2.84s/it]

60 loss 0.6001509428024292


 12%|█████████████████████▊                                                                                                                                                          | 62/500 [02:55<20:42,  2.84s/it]

61 loss 0.5956290364265442


 13%|██████████████████████▏                                                                                                                                                         | 63/500 [02:58<20:39,  2.84s/it]

62 loss 0.5915685892105103


 13%|██████████████████████▌                                                                                                                                                         | 64/500 [03:01<20:36,  2.84s/it]

63 loss 0.5878387093544006


 13%|██████████████████████▉                                                                                                                                                         | 65/500 [03:04<20:33,  2.84s/it]

64 loss 0.5839881300926208


 13%|███████████████████████▏                                                                                                                                                        | 66/500 [03:07<20:30,  2.84s/it]

65 loss 0.5802189707756042


 13%|███████████████████████▌                                                                                                                                                        | 67/500 [03:09<20:28,  2.84s/it]

66 loss 0.5764337778091431


 14%|███████████████████████▉                                                                                                                                                        | 68/500 [03:12<20:25,  2.84s/it]

67 loss 0.5727868676185608


 14%|████████████████████████▎                                                                                                                                                       | 69/500 [03:15<20:22,  2.84s/it]

68 loss 0.5695176720619202


 14%|████████████████████████▋                                                                                                                                                       | 70/500 [03:18<20:20,  2.84s/it]

69 loss 0.5664553642272949


 14%|████████████████████████▉                                                                                                                                                       | 71/500 [03:21<20:17,  2.84s/it]

70 loss 0.5635437965393066


 14%|█████████████████████████▎                                                                                                                                                      | 72/500 [03:24<20:15,  2.84s/it]

71 loss 0.560016930103302


 15%|█████████████████████████▋                                                                                                                                                      | 73/500 [03:26<20:12,  2.84s/it]

72 loss 0.5558573603630066


 15%|██████████████████████████                                                                                                                                                      | 74/500 [03:29<20:09,  2.84s/it]

73 loss 0.5527053475379944


 15%|██████████████████████████▍                                                                                                                                                     | 75/500 [03:32<20:05,  2.84s/it]

74 loss 0.5500802397727966


 15%|██████████████████████████▊                                                                                                                                                     | 76/500 [03:35<20:03,  2.84s/it]

75 loss 0.5474867820739746


 15%|███████████████████████████                                                                                                                                                     | 77/500 [03:38<20:01,  2.84s/it]

76 loss 0.5445327758789062


 16%|███████████████████████████▍                                                                                                                                                    | 78/500 [03:41<20:00,  2.84s/it]

77 loss 0.5407780408859253


 16%|███████████████████████████▊                                                                                                                                                    | 79/500 [03:43<19:57,  2.84s/it]

78 loss 0.537581205368042


 16%|████████████████████████████▏                                                                                                                                                   | 80/500 [03:46<19:54,  2.84s/it]

79 loss 0.5351647734642029


 16%|████████████████████████████▌                                                                                                                                                   | 81/500 [03:49<19:50,  2.84s/it]

80 loss 0.5325737595558167


 16%|████████████████████████████▊                                                                                                                                                   | 82/500 [03:52<19:49,  2.84s/it]

81 loss 0.5301016569137573


 17%|█████████████████████████████▏                                                                                                                                                  | 83/500 [03:55<19:46,  2.85s/it]

82 loss 0.5272460579872131


 17%|█████████████████████████████▌                                                                                                                                                  | 84/500 [03:58<19:44,  2.85s/it]

83 loss 0.5236768126487732


 17%|█████████████████████████████▉                                                                                                                                                  | 85/500 [04:01<19:41,  2.85s/it]

84 loss 0.5206641554832458


 17%|██████████████████████████████▎                                                                                                                                                 | 86/500 [04:03<19:39,  2.85s/it]

85 loss 0.5185787081718445


 17%|██████████████████████████████▌                                                                                                                                                 | 87/500 [04:06<19:35,  2.85s/it]

86 loss 0.5163140892982483


 18%|██████████████████████████████▉                                                                                                                                                 | 88/500 [04:09<19:33,  2.85s/it]

87 loss 0.5140685439109802


 18%|███████████████████████████████▎                                                                                                                                                | 89/500 [04:12<19:30,  2.85s/it]

88 loss 0.5120940804481506


 18%|███████████████████████████████▋                                                                                                                                                | 90/500 [04:15<19:27,  2.85s/it]

89 loss 0.5098937749862671


 18%|████████████████████████████████                                                                                                                                                | 91/500 [04:18<19:24,  2.85s/it]

90 loss 0.5073598027229309


 18%|████████████████████████████████▍                                                                                                                                               | 92/500 [04:20<19:21,  2.85s/it]

91 loss 0.5047523379325867


 19%|████████████████████████████████▋                                                                                                                                               | 93/500 [04:23<19:19,  2.85s/it]

92 loss 0.502526581287384


 19%|█████████████████████████████████                                                                                                                                               | 94/500 [04:26<19:17,  2.85s/it]

93 loss 0.5004652738571167


 19%|█████████████████████████████████▍                                                                                                                                              | 95/500 [04:29<19:13,  2.85s/it]

94 loss 0.4979912042617798


 19%|█████████████████████████████████▊                                                                                                                                              | 96/500 [04:32<19:11,  2.85s/it]

95 loss 0.49580109119415283


 19%|██████████████████████████████████▏                                                                                                                                             | 97/500 [04:35<19:08,  2.85s/it]

96 loss 0.4938182830810547


 20%|██████████████████████████████████▍                                                                                                                                             | 98/500 [04:38<19:06,  2.85s/it]

97 loss 0.4913618564605713


 20%|██████████████████████████████████▊                                                                                                                                             | 99/500 [04:40<19:02,  2.85s/it]

98 loss 0.4892450273036957


 20%|███████████████████████████████████                                                                                                                                            | 100/500 [04:43<19:00,  2.85s/it]

99 loss 0.4871169924736023


 20%|███████████████████████████████████▎                                                                                                                                           | 101/500 [04:46<18:57,  2.85s/it]

100 loss 0.4858599901199341


 20%|███████████████████████████████████▋                                                                                                                                           | 102/500 [04:49<18:54,  2.85s/it]

101 loss 0.4866996705532074


 21%|████████████████████████████████████                                                                                                                                           | 103/500 [04:52<18:50,  2.85s/it]

102 loss 0.48740509152412415


 21%|████████████████████████████████████▍                                                                                                                                          | 104/500 [04:55<18:48,  2.85s/it]

103 loss 0.4834144413471222


 21%|████████████████████████████████████▊                                                                                                                                          | 105/500 [04:58<18:46,  2.85s/it]

104 loss 0.47783467173576355


 21%|█████████████████████████████████████                                                                                                                                          | 106/500 [05:00<18:43,  2.85s/it]

105 loss 0.4770389795303345


 21%|█████████████████████████████████████▍                                                                                                                                         | 107/500 [05:03<18:39,  2.85s/it]

106 loss 0.4779355525970459


 22%|█████████████████████████████████████▊                                                                                                                                         | 108/500 [05:06<18:38,  2.85s/it]

107 loss 0.4740087687969208


 22%|██████████████████████████████████████▏                                                                                                                                        | 109/500 [05:09<18:34,  2.85s/it]

108 loss 0.47069433331489563


 22%|██████████████████████████████████████▌                                                                                                                                        | 110/500 [05:12<18:32,  2.85s/it]

109 loss 0.4708717167377472


 22%|██████████████████████████████████████▊                                                                                                                                        | 111/500 [05:15<18:29,  2.85s/it]

110 loss 0.4692503809928894


 22%|███████████████████████████████████████▏                                                                                                                                       | 112/500 [05:17<18:27,  2.85s/it]

111 loss 0.4657304584980011


 23%|███████████████████████████████████████▌                                                                                                                                       | 113/500 [05:20<18:25,  2.86s/it]

112 loss 0.464556485414505


 23%|███████████████████████████████████████▉                                                                                                                                       | 114/500 [05:23<18:23,  2.86s/it]

113 loss 0.4634731113910675


 23%|████████████████████████████████████████▎                                                                                                                                      | 115/500 [05:26<18:20,  2.86s/it]

114 loss 0.4605885446071625


 23%|████████████████████████████████████████▌                                                                                                                                      | 116/500 [05:29<18:21,  2.87s/it]

115 loss 0.4588755667209625


 23%|████████████████████████████████████████▉                                                                                                                                      | 117/500 [05:32<18:18,  2.87s/it]

116 loss 0.4578353464603424


 24%|█████████████████████████████████████████▎                                                                                                                                     | 118/500 [05:35<18:15,  2.87s/it]

117 loss 0.4554426074028015


 24%|█████████████████████████████████████████▋                                                                                                                                     | 119/500 [05:38<18:14,  2.87s/it]

118 loss 0.4541645050048828


 24%|██████████████████████████████████████████                                                                                                                                     | 120/500 [05:40<18:13,  2.88s/it]

119 loss 0.4537060856819153


 24%|██████████████████████████████████████████▎                                                                                                                                    | 121/500 [05:43<18:13,  2.88s/it]

120 loss 0.4516734480857849


 24%|██████████████████████████████████████████▋                                                                                                                                    | 122/500 [05:46<18:11,  2.89s/it]

121 loss 0.4501376152038574


 25%|███████████████████████████████████████████                                                                                                                                    | 123/500 [05:49<18:08,  2.89s/it]

122 loss 0.4499657452106476


 25%|███████████████████████████████████████████▍                                                                                                                                   | 124/500 [05:52<18:07,  2.89s/it]

123 loss 0.4483771324157715


 25%|███████████████████████████████████████████▊                                                                                                                                   | 125/500 [05:55<18:04,  2.89s/it]

124 loss 0.44600939750671387


 25%|████████████████████████████████████████████                                                                                                                                   | 126/500 [05:58<18:01,  2.89s/it]

125 loss 0.44464734196662903


 25%|████████████████████████████████████████████▍                                                                                                                                  | 127/500 [06:01<17:59,  2.90s/it]

126 loss 0.44380807876586914


 26%|████████████████████████████████████████████▊                                                                                                                                  | 128/500 [06:04<17:57,  2.90s/it]

127 loss 0.44237518310546875


 26%|█████████████████████████████████████████████▏                                                                                                                                 | 129/500 [06:07<17:55,  2.90s/it]

128 loss 0.44050362706184387


 26%|█████████████████████████████████████████████▌                                                                                                                                 | 130/500 [06:09<17:51,  2.90s/it]

129 loss 0.4393376111984253


 26%|█████████████████████████████████████████████▊                                                                                                                                 | 131/500 [06:12<17:48,  2.90s/it]

130 loss 0.43809276819229126


 26%|██████████████████████████████████████████████▏                                                                                                                                | 132/500 [06:15<17:45,  2.89s/it]

131 loss 0.43656110763549805


 27%|██████████████████████████████████████████████▌                                                                                                                                | 133/500 [06:18<17:43,  2.90s/it]

132 loss 0.43467190861701965


 27%|██████████████████████████████████████████████▉                                                                                                                                | 134/500 [06:21<17:41,  2.90s/it]

133 loss 0.4333142340183258


 27%|███████████████████████████████████████████████▎                                                                                                                               | 135/500 [06:24<17:37,  2.90s/it]

134 loss 0.43214571475982666


 27%|███████████████████████████████████████████████▌                                                                                                                               | 136/500 [06:27<17:34,  2.90s/it]

135 loss 0.43086549639701843


 27%|███████████████████████████████████████████████▉                                                                                                                               | 137/500 [06:30<17:33,  2.90s/it]

136 loss 0.4293747544288635


 28%|████████████████████████████████████████████████▎                                                                                                                              | 138/500 [06:33<17:33,  2.91s/it]

137 loss 0.42803019285202026


 28%|████████████████████████████████████████████████▋                                                                                                                              | 139/500 [06:36<17:28,  2.90s/it]

138 loss 0.4269711375236511


 28%|█████████████████████████████████████████████████                                                                                                                              | 140/500 [06:38<17:24,  2.90s/it]

139 loss 0.42624005675315857


 28%|█████████████████████████████████████████████████▎                                                                                                                             | 141/500 [06:41<17:20,  2.90s/it]

140 loss 0.4255235195159912


 28%|█████████████████████████████████████████████████▋                                                                                                                             | 142/500 [06:44<17:17,  2.90s/it]

141 loss 0.4247898459434509


 29%|██████████████████████████████████████████████████                                                                                                                             | 143/500 [06:47<17:16,  2.90s/it]

142 loss 0.42424216866493225


 29%|██████████████████████████████████████████████████▍                                                                                                                            | 144/500 [06:50<17:14,  2.91s/it]

143 loss 0.42428046464920044


 29%|██████████████████████████████████████████████████▊                                                                                                                            | 145/500 [06:53<17:10,  2.90s/it]

144 loss 0.4251537024974823


 29%|███████████████████████████████████████████████████                                                                                                                            | 146/500 [06:56<17:07,  2.90s/it]

145 loss 0.425407737493515


 29%|███████████████████████████████████████████████████▍                                                                                                                           | 147/500 [06:59<17:04,  2.90s/it]

146 loss 0.421876460313797


 30%|███████████████████████████████████████████████████▊                                                                                                                           | 148/500 [07:02<17:03,  2.91s/it]

147 loss 0.4174056649208069


 30%|████████████████████████████████████████████████████▏                                                                                                                          | 149/500 [07:05<17:01,  2.91s/it]

148 loss 0.41589561104774475


 30%|████████████████████████████████████████████████████▌                                                                                                                          | 150/500 [07:07<16:57,  2.91s/it]

149 loss 0.41620299220085144


 30%|████████████████████████████████████████████████████▊                                                                                                                          | 151/500 [07:10<16:54,  2.91s/it]

150 loss 0.41644129157066345


 30%|█████████████████████████████████████████████████████▏                                                                                                                         | 152/500 [07:13<16:51,  2.91s/it]

151 loss 0.41459494829177856


 31%|█████████████████████████████████████████████████████▌                                                                                                                         | 153/500 [07:16<16:47,  2.90s/it]

152 loss 0.41217920184135437


 31%|█████████████████████████████████████████████████████▉                                                                                                                         | 154/500 [07:19<16:45,  2.91s/it]

153 loss 0.41047823429107666


 31%|██████████████████████████████████████████████████████▎                                                                                                                        | 155/500 [07:22<16:42,  2.91s/it]

154 loss 0.40958401560783386


 31%|██████████████████████████████████████████████████████▌                                                                                                                        | 156/500 [07:25<16:39,  2.91s/it]

155 loss 0.4094378352165222


 31%|██████████████████████████████████████████████████████▉                                                                                                                        | 157/500 [07:28<16:36,  2.91s/it]

156 loss 0.4088055193424225


 32%|███████████████████████████████████████████████████████▎                                                                                                                       | 158/500 [07:31<16:33,  2.91s/it]

157 loss 0.4079681932926178


 32%|███████████████████████████████████████████████████████▋                                                                                                                       | 159/500 [07:34<16:31,  2.91s/it]

158 loss 0.4066224694252014


 32%|████████████████████████████████████████████████████████                                                                                                                       | 160/500 [07:37<16:28,  2.91s/it]

159 loss 0.4050247371196747


 32%|████████████████████████████████████████████████████████▎                                                                                                                      | 161/500 [07:39<16:25,  2.91s/it]

160 loss 0.40370121598243713


 32%|████████████████████████████████████████████████████████▋                                                                                                                      | 162/500 [07:42<16:22,  2.91s/it]

161 loss 0.4031286835670471


 33%|█████████████████████████████████████████████████████████                                                                                                                      | 163/500 [07:45<16:20,  2.91s/it]

162 loss 0.40266576409339905


 33%|█████████████████████████████████████████████████████████▍                                                                                                                     | 164/500 [07:48<16:17,  2.91s/it]

163 loss 0.40151727199554443


 33%|█████████████████████████████████████████████████████████▊                                                                                                                     | 165/500 [07:51<16:14,  2.91s/it]

164 loss 0.40004464983940125


 33%|██████████████████████████████████████████████████████████                                                                                                                     | 166/500 [07:54<16:12,  2.91s/it]

165 loss 0.3990285396575928


 33%|██████████████████████████████████████████████████████████▍                                                                                                                    | 167/500 [07:57<16:09,  2.91s/it]

166 loss 0.39880073070526123


 34%|██████████████████████████████████████████████████████████▊                                                                                                                    | 168/500 [08:00<16:07,  2.91s/it]

167 loss 0.39877060055732727


 34%|███████████████████████████████████████████████████████████▏                                                                                                                   | 169/500 [08:03<16:06,  2.92s/it]

168 loss 0.3983645439147949


 34%|███████████████████████████████████████████████████████████▌                                                                                                                   | 170/500 [08:06<16:03,  2.92s/it]

169 loss 0.3970707654953003


 34%|███████████████████████████████████████████████████████████▊                                                                                                                   | 171/500 [08:09<15:59,  2.92s/it]

170 loss 0.395785391330719


 34%|████████████████████████████████████████████████████████████▏                                                                                                                  | 172/500 [08:12<15:57,  2.92s/it]

171 loss 0.39507025480270386


 35%|████████████████████████████████████████████████████████████▌                                                                                                                  | 173/500 [08:14<15:54,  2.92s/it]

172 loss 0.3950096070766449


 35%|████████████████████████████████████████████████████████████▉                                                                                                                  | 174/500 [08:17<15:53,  2.93s/it]

173 loss 0.395041823387146


 35%|█████████████████████████████████████████████████████████████▏                                                                                                                 | 175/500 [08:20<15:51,  2.93s/it]

174 loss 0.39450037479400635


 35%|█████████████████████████████████████████████████████████████▌                                                                                                                 | 176/500 [08:23<15:47,  2.92s/it]

175 loss 0.392690509557724


 35%|█████████████████████████████████████████████████████████████▉                                                                                                                 | 177/500 [08:26<15:43,  2.92s/it]

176 loss 0.3908011019229889


 36%|██████████████████████████████████████████████████████████████▎                                                                                                                | 178/500 [08:29<15:40,  2.92s/it]

177 loss 0.3895389139652252


 36%|██████████████████████████████████████████████████████████████▋                                                                                                                | 179/500 [08:32<15:36,  2.92s/it]

178 loss 0.3890276849269867


 36%|███████████████████████████████████████████████████████████████                                                                                                                | 180/500 [08:35<15:36,  2.93s/it]

179 loss 0.3889959156513214


 36%|███████████████████████████████████████████████████████████████▎                                                                                                               | 181/500 [08:38<15:32,  2.92s/it]

180 loss 0.3888745605945587


 36%|███████████████████████████████████████████████████████████████▋                                                                                                               | 182/500 [08:41<15:28,  2.92s/it]

181 loss 0.3883274793624878


 37%|████████████████████████████████████████████████████████████████                                                                                                               | 183/500 [08:44<15:25,  2.92s/it]

182 loss 0.3871501684188843


 37%|████████████████████████████████████████████████████████████████▍                                                                                                              | 184/500 [08:47<15:22,  2.92s/it]

183 loss 0.38596880435943604


 37%|████████████████████████████████████████████████████████████████▊                                                                                                              | 185/500 [08:50<15:21,  2.93s/it]

184 loss 0.38498982787132263


 37%|█████████████████████████████████████████████████████████████████                                                                                                              | 186/500 [08:52<15:18,  2.92s/it]

185 loss 0.38433876633644104


 37%|█████████████████████████████████████████████████████████████████▍                                                                                                             | 187/500 [08:55<15:14,  2.92s/it]

186 loss 0.38390180468559265


 38%|█████████████████████████████████████████████████████████████████▊                                                                                                             | 188/500 [08:58<15:13,  2.93s/it]

187 loss 0.3836005628108978


 38%|██████████████████████████████████████████████████████████████████▏                                                                                                            | 189/500 [09:01<15:09,  2.93s/it]

188 loss 0.383026659488678


 38%|██████████████████████████████████████████████████████████████████▌                                                                                                            | 190/500 [09:04<15:07,  2.93s/it]

189 loss 0.3824222683906555


 38%|██████████████████████████████████████████████████████████████████▊                                                                                                            | 191/500 [09:07<15:03,  2.93s/it]

190 loss 0.3819741904735565


 38%|███████████████████████████████████████████████████████████████████▏                                                                                                           | 192/500 [09:10<15:02,  2.93s/it]

191 loss 0.3819151818752289


 39%|███████████████████████████████████████████████████████████████████▌                                                                                                           | 193/500 [09:13<15:01,  2.94s/it]

192 loss 0.3820951581001282


 39%|███████████████████████████████████████████████████████████████████▉                                                                                                           | 194/500 [09:16<14:59,  2.94s/it]

193 loss 0.3813333511352539


 39%|████████████████████████████████████████████████████████████████████▎                                                                                                          | 195/500 [09:19<14:55,  2.94s/it]

194 loss 0.37998855113983154


 39%|████████████████████████████████████████████████████████████████████▌                                                                                                          | 196/500 [09:22<14:51,  2.93s/it]

195 loss 0.3781597316265106


 39%|████████████████████████████████████████████████████████████████████▉                                                                                                          | 197/500 [09:25<14:48,  2.93s/it]

196 loss 0.3771599531173706


 40%|█████████████████████████████████████████████████████████████████████▎                                                                                                         | 198/500 [09:28<14:45,  2.93s/it]

197 loss 0.3768008351325989


 40%|█████████████████████████████████████████████████████████████████████▋                                                                                                         | 199/500 [09:31<14:42,  2.93s/it]

198 loss 0.37625381350517273


 40%|██████████████████████████████████████████████████████████████████████                                                                                                         | 200/500 [09:34<14:39,  2.93s/it]

199 loss 0.37455135583877563


 40%|██████████████████████████████████████████████████████████████████████▎                                                                                                        | 201/500 [09:36<14:36,  2.93s/it]

200 loss 0.37265416979789734


 40%|██████████████████████████████████████████████████████████████████████▋                                                                                                        | 202/500 [09:39<14:33,  2.93s/it]

201 loss 0.3722960948944092


 41%|███████████████████████████████████████████████████████████████████████                                                                                                        | 203/500 [09:42<14:31,  2.94s/it]

202 loss 0.37323445081710815


 41%|███████████████████████████████████████████████████████████████████████▍                                                                                                       | 204/500 [09:45<14:29,  2.94s/it]

203 loss 0.37347841262817383


 41%|███████████████████████████████████████████████████████████████████████▊                                                                                                       | 205/500 [09:48<14:26,  2.94s/it]

204 loss 0.37128522992134094


 41%|████████████████████████████████████████████████████████████████████████                                                                                                       | 206/500 [09:51<14:22,  2.94s/it]

205 loss 0.3692815899848938


 41%|████████████████████████████████████████████████████████████████████████▍                                                                                                      | 207/500 [09:54<14:19,  2.93s/it]

206 loss 0.36930185556411743


 42%|████████████████████████████████████████████████████████████████████████▊                                                                                                      | 208/500 [09:57<14:18,  2.94s/it]

207 loss 0.36957985162734985


 42%|█████████████████████████████████████████████████████████████████████████▏                                                                                                     | 209/500 [10:00<14:14,  2.94s/it]

208 loss 0.36798086762428284


 42%|█████████████████████████████████████████████████████████████████████████▌                                                                                                     | 210/500 [10:03<14:12,  2.94s/it]

209 loss 0.3659263551235199


 42%|█████████████████████████████████████████████████████████████████████████▊                                                                                                     | 211/500 [10:06<14:08,  2.94s/it]

210 loss 0.36550062894821167


 42%|██████████████████████████████████████████████████████████████████████████▏                                                                                                    | 212/500 [10:09<14:05,  2.94s/it]

211 loss 0.36589598655700684


 43%|██████████████████████████████████████████████████████████████████████████▌                                                                                                    | 213/500 [10:12<14:01,  2.93s/it]

212 loss 0.3654411733150482


 43%|██████████████████████████████████████████████████████████████████████████▉                                                                                                    | 214/500 [10:15<13:58,  2.93s/it]

213 loss 0.36432623863220215


 43%|███████████████████████████████████████████████████████████████████████████▎                                                                                                   | 215/500 [10:18<13:56,  2.93s/it]

214 loss 0.3633732199668884


 43%|███████████████████████████████████████████████████████████████████████████▌                                                                                                   | 216/500 [10:20<13:52,  2.93s/it]

215 loss 0.36371293663978577


 43%|███████████████████████████████████████████████████████████████████████████▉                                                                                                   | 217/500 [10:23<13:51,  2.94s/it]

216 loss 0.36440154910087585


 44%|████████████████████████████████████████████████████████████████████████████▎                                                                                                  | 218/500 [10:26<13:48,  2.94s/it]

217 loss 0.3639930486679077


 44%|████████████████████████████████████████████████████████████████████████████▋                                                                                                  | 219/500 [10:29<13:45,  2.94s/it]

218 loss 0.36302390694618225


 44%|█████████████████████████████████████████████████████████████████████████████                                                                                                  | 220/500 [10:32<13:43,  2.94s/it]

219 loss 0.3627462685108185


 44%|█████████████████████████████████████████████████████████████████████████████▎                                                                                                 | 221/500 [10:35<13:39,  2.94s/it]

220 loss 0.36298513412475586


 44%|█████████████████████████████████████████████████████████████████████████████▋                                                                                                 | 222/500 [10:38<13:38,  2.94s/it]

221 loss 0.36225825548171997


 45%|██████████████████████████████████████████████████████████████████████████████                                                                                                 | 223/500 [10:41<13:34,  2.94s/it]

222 loss 0.35998570919036865


 45%|██████████████████████████████████████████████████████████████████████████████▍                                                                                                | 224/500 [10:44<13:30,  2.94s/it]

223 loss 0.3581816554069519


 45%|██████████████████████████████████████████████████████████████████████████████▊                                                                                                | 225/500 [10:47<13:28,  2.94s/it]

224 loss 0.3577498197555542


 45%|███████████████████████████████████████████████████████████████████████████████                                                                                                | 226/500 [10:50<13:25,  2.94s/it]

225 loss 0.358236163854599


 45%|███████████████████████████████████████████████████████████████████████████████▍                                                                                               | 227/500 [10:53<13:21,  2.94s/it]

226 loss 0.3582076132297516


 46%|███████████████████████████████████████████████████████████████████████████████▊                                                                                               | 228/500 [10:56<13:18,  2.94s/it]

227 loss 0.3574725389480591


 46%|████████████████████████████████████████████████████████████████████████████████▏                                                                                              | 229/500 [10:59<13:16,  2.94s/it]

228 loss 0.35702285170555115


 46%|████████████████████████████████████████████████████████████████████████████████▌                                                                                              | 230/500 [11:02<13:12,  2.94s/it]

229 loss 0.35662642121315


 46%|████████████████████████████████████████████████████████████████████████████████▊                                                                                              | 231/500 [11:05<13:11,  2.94s/it]

230 loss 0.35581669211387634


 46%|█████████████████████████████████████████████████████████████████████████████████▏                                                                                             | 232/500 [11:07<13:07,  2.94s/it]

231 loss 0.35441890358924866


 47%|█████████████████████████████████████████████████████████████████████████████████▌                                                                                             | 233/500 [11:10<13:03,  2.94s/it]

232 loss 0.3535356819629669


 47%|█████████████████████████████████████████████████████████████████████████████████▉                                                                                             | 234/500 [11:13<13:01,  2.94s/it]

233 loss 0.353429913520813


 47%|██████████████████████████████████████████████████████████████████████████████████▎                                                                                            | 235/500 [11:16<12:59,  2.94s/it]

234 loss 0.3533918559551239


 47%|██████████████████████████████████████████████████████████████████████████████████▌                                                                                            | 236/500 [11:19<12:57,  2.94s/it]

235 loss 0.35308071970939636


 47%|██████████████████████████████████████████████████████████████████████████████████▉                                                                                            | 237/500 [11:22<12:55,  2.95s/it]

236 loss 0.35207921266555786


 48%|███████████████████████████████████████████████████████████████████████████████████▎                                                                                           | 238/500 [11:25<12:51,  2.94s/it]

237 loss 0.3511037826538086


 48%|███████████████████████████████████████████████████████████████████████████████████▋                                                                                           | 239/500 [11:28<12:49,  2.95s/it]

238 loss 0.35055604577064514


 48%|████████████████████████████████████████████████████████████████████████████████████                                                                                           | 240/500 [11:31<12:45,  2.94s/it]

239 loss 0.35024622082710266


 48%|████████████████████████████████████████████████████████████████████████████████████▎                                                                                          | 241/500 [11:34<12:41,  2.94s/it]

240 loss 0.34953105449676514


 48%|████████████████████████████████████████████████████████████████████████████████████▋                                                                                          | 242/500 [11:37<12:38,  2.94s/it]

241 loss 0.3487100601196289


 49%|█████████████████████████████████████████████████████████████████████████████████████                                                                                          | 243/500 [11:40<12:34,  2.94s/it]

242 loss 0.34845390915870667


 49%|█████████████████████████████████████████████████████████████████████████████████████▍                                                                                         | 244/500 [11:43<12:32,  2.94s/it]

243 loss 0.3485783636569977


 49%|█████████████████████████████████████████████████████████████████████████████████████▊                                                                                         | 245/500 [11:46<12:30,  2.94s/it]

244 loss 0.3490246534347534


 49%|██████████████████████████████████████████████████████████████████████████████████████                                                                                         | 246/500 [11:49<12:26,  2.94s/it]

245 loss 0.3491789698600769


 49%|██████████████████████████████████████████████████████████████████████████████████████▍                                                                                        | 247/500 [11:52<12:23,  2.94s/it]

246 loss 0.3487376570701599


 50%|██████████████████████████████████████████████████████████████████████████████████████▊                                                                                        | 248/500 [11:55<12:21,  2.94s/it]

247 loss 0.3485976457595825


 50%|███████████████████████████████████████████████████████████████████████████████████████▏                                                                                       | 249/500 [11:58<12:18,  2.94s/it]

248 loss 0.34874534606933594


 50%|███████████████████████████████████████████████████████████████████████████████████████▌                                                                                       | 250/500 [12:00<12:16,  2.95s/it]

249 loss 0.34829190373420715


In [None]:
# First GPU device in the list returned by JAX; used by the cpu2gpu helper.
gpu_device = jax.devices('gpu')[0]

In [None]:
def cpu2gpu(x):
    """Return ``x`` (an array or pytree of arrays) placed on ``gpu_device``.

    Uses ``jax.device_put``, the supported transfer API, instead of jitting an
    identity function with the experimental ``device=`` argument of ``jax.jit``.
    ``gpu_device`` is looked up at call time from the notebook namespace.
    """
    return jax.device_put(x, gpu_device)

In [None]:
# Evaluate the forward model on the starting field and on three saved iterates
# (indices 9, 99 and 499 of initall), moving each input to the GPU first.
res0, res1, res2, res3 = (
    model(cpu2gpu(field_in), cosmo, conf)
    for field_in in (initial_conditions, initall[9], initall[99], initall[499])
)


In [31]:
# 2x2 grid of model outputs, all drawn with the same colour limits and each
# annotated with its iteration label.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
vmin, vmax = -1.0, -0.9995
iterlist = [1, 10, 100, 500]
panels = (res0, res1, res2, res3)
# axes.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1].
for ax, image, iteration in zip(axes.flat, panels, iterlist):
    ax.imshow(image, vmax=vmax, vmin=vmin)
    # White text at data coordinates (40, 30).
    ax.text(40, 30, "iteration: {0}".format(iteration), c="w")

<Figure size 1000x1000 with 4 Axes>

In [30]:
# Display im_tgt with matplotlib's default colormap and colour scaling.
plt.imshow(im_tgt)

<matplotlib.image.AxesImage at 0x7f62581a3e20>

<Figure size 640x480 with 1 Axes>

## Compare with N-body simulation

In [6]:
# Load the header and particle blocks of the initial-conditions snapshot
# with readgadget.
import numpy as np
import readgadget

# input files
snapshot = '/project/chto/chto/Qujote/Fiducial/snapdir/0/ICs_chto/////ics'
ptype    = [1] #[1](CDM), [2](neutrinos) or [1,2](CDM+neutrinos)

# read header
header   = readgadget.header(snapshot)
BoxSize  = header.boxsize/1e3  #Mpc/h
Nall     = header.nall         #Total number of particles
Masses   = header.massarr*1e10 #Masses of the particles in Msun/h
Omega_m  = header.omega_m      #value of Omega_m
Omega_l  = header.omega_l      #value of Omega_l
h        = header.hubble       #value of h
redshift = header.redshift     #redshift of the snapshot
Hubble   = 100.0*np.sqrt(Omega_m*(1.0+redshift)**3+Omega_l)#Value of H(z) in km/s/(Mpc/h)
# NOTE(review): unlike disp_chto below, this block is NOT divided by 1000 here;
# the simulation cell divides init_chto by 1000 when it uses it, so the values
# are presumably still in the file's native length unit (kpc/h?) — TODO confirm.
init_chto = readgadget.read_block("/project/chto/chto/Qujote/Fiducial/snapdir/0/ics_256_init", "POS ", ptype)

# read positions, velocities and IDs of the particles
# NOTE(review): the "POS " block of this snapshot is stored as disp_chto and is
# later added to the init_chto positions, so it is presumably a displacement
# field rather than absolute positions — TODO confirm against the IC file.
disp_chto = readgadget.read_block(snapshot, "POS ", ptype)/1000#"POS " block divided by 1000 (Mpc/h if the file is in kpc/h)
vel_chto = readgadget.read_block(snapshot, "VEL ", ptype)     #peculiar velocities in km/s
ids_chto = readgadget.read_block(snapshot, "ID  ", ptype)   

In [7]:
# Cosmological parameters, in order: Omega_m, Omega_b, h, n_s, sigma_8.
cosmo_in = [
    0.3175,  # Omega_m
    0.049,   # Omega_b
    0.6711,  # h
    0.9624,  # n_s
    0.834,   # sigma_8
]


In [8]:
# Particles per side: rounded cube root of the number of loaded particles.
nmesh = round(len(disp_chto) ** (1 / 3))
# Cubic mesh and cubic box, one entry per axis.
mesh_shape = [nmesh] * 3
box_size = [BoxSize] * 3

In [None]:
# Evolve the loaded initial conditions from a = header.time to a = 1 with the
# FPMLeapFrog solver through diffrax.diffeqsolve; only the final state is saved.

# NOTE(review): `particles` is not referenced again in this cell.
particles = init_chto.astype(np.float64)

# Cosmology built from cosmo_in = [Omega_m, Omega_b, h, n_s, sigma_8]:
# Omega_c = Omega_m - Omega_b, with Omega_k = 0, w0 = -1, wa = 0.
cosmo = jc.Cosmology(Omega_c=cosmo_in[0]-cosmo_in[1], Omega_b=cosmo_in[1], h=cosmo_in[2], sigma8 = cosmo_in[4], n_s=cosmo_in[3],
                      Omega_k=0., w0=-1., wa=0.)
ain=np.atleast_1d(1)
# The return values of the next four calls are discarded.
# NOTE(review): presumably they are evaluated once here to fill
# cosmo._workspace (passed to the solver through `args` below) — TODO confirm.
_ =  growth_rate_second(cosmo,ain)
_ = growth_rate(cosmo, ain)
_ = growth_factor(cosmo, ain)
jc.background.radial_comoving_distance(cosmo, ain)
    

# Initial displacement, rescaled from box units to mesh-cell units
# (divide by BoxSize, multiply by nmesh).
dx = disp_chto.astype(np.float64)/BoxSize*nmesh
dt0= 0.01
# Velocities scaled by header.time/100 and converted to mesh-cell units.
# NOTE(review): presumably this maps km/s peculiar velocities onto the
# momentum convention expected by the solver — TODO confirm.
p= (vel_chto.astype(np.float64))*header.time/100/BoxSize*nmesh#/cosmo.h





# Run configuration: integrate from the snapshot's header.time to 1 with a
# fixed step dt0.
conf = Configuration(mesh_shape=[nmesh, nmesh, nmesh], 
                     snapshots= [1],
                     BoxSize=box_size, 
                     initial_t0=header.time,
                     final_time = 1,
                     density_plane_npix = nmesh,
                     dt0=dt0)



# Wrap every term returned by symplectic_ode in FPMODE.
ode_fn = tree.map(
        FPMODE,
        symplectic_ode(conf.mesh_shape, paint_absolute_pos=True)
)
solver = FPMLeapFrog(initial_t0=conf.initial_t0, final_t1=conf.final_time)
stepsize_controller = ConstantStepSize()

# PM force at the starting positions, multiplied by 1.5 * Omega_m.
# The position argument is init_chto/1000 rescaled to mesh-cell units plus dx;
# the same quantity is spelled with BoxSize/nmesh in y0 below.
initforce = pm_forces(
        init_chto/conf.BoxSize*conf.mesh_shape[0]/1000+dx,
        mesh_shape=mesh_shape,
        paint_absolute_pos=True,
    )* 1.5* cosmo.Omega_m  

# Solver state stacked along axis 0: [positions, momenta, initial force].
y0= jnp.stack([init_chto/BoxSize*nmesh/1000+dx, p,initforce], axis=0)

# SaveAt(t1=True) keeps only the state at t1; the adjoint is
# diffrax.ReversibleAdjoint.
res = diffeqsolve(ode_fn,solver,\
                  t0=conf.initial_t0,\
                  t1=conf.final_time,\
                  dt0=conf.dt0,\
                  y0=y0,
                  args=[cosmo, cosmo._workspace, conf.initial_t0, conf, 0, 0],\
                  saveat=SaveAt(t1=True),
                  stepsize_controller=stepsize_controller,adjoint=diffrax.ReversibleAdjoint())


  return lax_numpy.astype(self, dtype, copy=copy, device=device)


In [None]:
# Display the object returned by diffeqsolve.
res

In [None]:
from jaxpm.plotting import plot_fields_single_projection

# res.ys[0] is the state at the single saved time; its first entry is the
# position slot of the stacked solver state.
field = res.ys[0][0]
# Paint the positions onto an empty mesh and plot log10(painted + 1).
fields = {"a=1": jnp.log10(cic_paint(jnp.zeros(mesh_shape), field) + 1)}
plot_fields_single_projection(fields)


# Check power spectra

In [19]:
import numpy as np
import readgadget
import MAS_library as MASL

In [20]:
# Alternative input (initial conditions), kept for reference:
#snapshot = '/project/chto/chto/Qujote/Fiducial/snapdir/0/ICs/ics'

# Settings for the reference density field
snapshot = '/project/chto/chto/Qujote/Fiducial/snapdir/0/snapdir_004/snap_004'
grid = 1024      # the density field has grid^3 voxels
MAS = 'CIC'      # mass-assignment scheme: 'NGP', 'CIC', 'TSC' or 'PCS'
verbose = True   # print progress information
ptype = [1]      # [1](CDM), [2](neutrinos) or [1,2](CDM+neutrinos)

# Header quantities
header = readgadget.header(snapshot)
BoxSize = header.boxsize / 1e3   # Mpc/h
redshift = header.redshift       # redshift of the snapshot

# Particle positions deposited onto an empty float32 mesh
pos = readgadget.read_block(snapshot, "POS ", ptype) / 1e3   # positions in Mpc/h
delta = np.zeros((grid, grid, grid), dtype=np.float32)
MASL.MA(pos, delta, BoxSize, MAS, verbose=verbose)

# Convert the painted mesh to an overdensity in place: divide by the mean
# (accumulated in float64), then subtract one.
delta /= delta.mean(dtype=np.float64)
delta -= 1.0

# Quick sanity report
print('%.3f < delta < %.3f'%(delta.min(), delta.max()))
print('<delta> = %.3f'%delta.mean())
print('shape of the matrix:', delta.shape)
print('density field data type:', delta.dtype)



Using CIC mass assignment scheme
Time taken = 37.533 seconds

-1.000 < delta < 4780.165
<delta> = 0.000
shape of the matrix: (1024, 1024, 1024)
density field data type: float32


In [21]:
import Pk_library as PKL

# Positional arguments for PKL.Pk; the jaxpm comparison cell reuses them.
axis = 0
threads = 1

# Power spectrum of the reference density field. The returned object carries
# the 1D, 2D and 3D spectra as attributes.
Pk = PKL.Pk(delta, BoxSize, axis, MAS, threads, verbose)

# 1D P(k) products
k1D, Pk1D, Nmodes1D = Pk.k1D, Pk.Pk1D, Pk.Nmodes1D


Computing power spectrum of the field...
Time to complete loop = 34.68
Time taken = 57.62 seconds


In [22]:
def get_delta(pos_in):
    """Return the overdensity mesh built from the particle positions ``pos_in``.

    The positions are deposited with ``MASL.MA`` onto a float32 mesh of shape
    ``(grid, grid, grid)``. ``grid``, ``BoxSize``, ``MAS`` and ``verbose`` are
    read from the notebook namespace at call time. The mesh is then divided by
    its mean and shifted by -1 before being returned.
    """
    mesh = np.zeros((grid,) * 3, dtype=np.float32)
    MASL.MA(pos_in, mesh, BoxSize, MAS, verbose=verbose)
    # At this point the mesh holds the effective number of particles per voxel.
    # Normalise in place by the mean (accumulated in float64), then subtract one
    # to obtain the density contrast.
    mesh /= mesh.mean(dtype=np.float64)
    mesh -= 1.0
    return mesh

In [23]:
def getpos(sol, boxsize, nmesh):
    """Rescale ``sol`` by ``boxsize / nmesh`` and wrap it into the periodic box.

    ``sol`` is divided by ``nmesh``, multiplied by ``boxsize``, and reduced
    modulo ``boxsize``, so negative or overshooting values are folded back.
    """
    scaled = sol / nmesh * boxsize
    return scaled % boxsize

In [25]:
# Final jaxpm positions (position slot of the saved state), rescaled and
# wrapped into the box by getpos.
pos_jaxpm = getpos(res.ys[0][0], BoxSize, nmesh)
# Convert to a NumPy array and cast to float32 (astype makes a fresh copy).
pos_jaxpm = np.asarray(pos_jaxpm).astype('float32')

# Overdensity field and power spectrum, using the same settings as the reference.
delta_jaxpm = get_delta(pos_jaxpm)
Pk_jaxpm = PKL.Pk(delta_jaxpm, BoxSize, axis, MAS, threads, verbose)



Using CIC mass assignment scheme
Time taken = 3.972 seconds


Computing power spectrum of the field...
Time to complete loop = 34.78
Time taken = 57.59 seconds


In [26]:
# Compare the jaxpm power spectrum with the reference simulation:
# top panel shows both spectra, bottom panel shows their ratio.
clist = ['k','r','g','b', 'm']
fig, axes = plt.subplots(2, 1, figsize=(5, 10), sharex=True)
ax_pk, ax_ratio = axes

# Column 0 of the Pk attribute is plotted against k3D for both fields.
# The reference label previously read "Qujote (1024^3": the simulation name was
# misspelled and the parenthesis was never closed.
ax_pk.plot(Pk.k3D, Pk.Pk[:,0], label="Quijote (1024^3)", c=clist[0])
ax_pk.plot(Pk_jaxpm.k3D, Pk_jaxpm.Pk[:,0], label="jaxpm Leapfrog Runtime 5.61s ", c=clist[4])
ax_pk.set_xscale('log')
ax_pk.set_yscale('log')
ax_pk.set_ylabel("pk")
ax_pk.legend()

# Ratio of the two spectra, taken bin by bin (both use the same grid and box).
ax_ratio.plot(Pk_jaxpm.k3D, Pk_jaxpm.Pk[:,0]/Pk.Pk[:,0], label="jaxpm Leapfrog", c=clist[4])
ax_ratio.set_ylabel("pk/pk[high res]")
ax_ratio.set_xlabel("k (h/Mpc)")
# The former trailing plt.ylim(0.1, 1.5) repeated this limit and printed its
# return value as cell output; the semicolon suppresses that output here.
ax_ratio.set_ylim(0.1, 1.5);

(0.1, 1.5)

<Figure size 500x1000 with 2 Axes>