Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adam optimizer doesn't work for some cases #91

Closed
thomasckng opened this issue Jul 12, 2024 · 2 comments
Closed

Adam optimizer doesn't work for some cases #91

thomasckng opened this issue Jul 12, 2024 · 2 comments

Comments

@thomasckng
Copy link
Collaborator

Script:

from jimgw.single_event.runManager import SingleEventPERunManager, SingleEventRun
import jax.numpy as jnp
import jax

import os
# Output directory and run label are both derived from this script's path:
# the directory part becomes `outdir`, the extension-less filename `label`.
outdir, _script_file = os.path.split(__file__)
label = os.path.splitext(_script_file)[0]

# Switch JAX to 64-bit precision before any arrays are created below.
jax.config.update("jax_enable_x64", True)

# Step size handed to the local sampler: a 15x15 diagonal matrix whose
# diagonal is 3e-3, except entries 1 and 9, which are a further factor of
# 1e-3 smaller.
# NOTE(review): indices 1 and 9 presumably correspond to q and t_c in the
# sampled-parameter ordering — confirm against the prior definition.
mass_matrix = jnp.eye(15).at[1, 1].set(1e-3).at[9, 9].set(1e-3) * 3e-3
local_sampler_arg = {"step_size": mass_matrix}

# One [lower, upper] pair per sampled dimension (15 rows), passed to the
# likelihood as "bounds".
# NOTE(review): the row order is assumed to follow the sampled-parameter
# order — TODO confirm.
bounds = jnp.array(
    [
        [10.0, 40.0],
        [0.125, 1.0],
        [0, jnp.pi],
        [0, 2*jnp.pi],
        [0.0, 1.0],
        [0, jnp.pi],
        [0, 2*jnp.pi],
        [0.0, 1.0],
        [0.0, 2000.0],
        [-0.05, 0.05],
        [0.0, 2 * jnp.pi],
        [-1.0, 1.0],
        [0.0, jnp.pi],
        [0.0, 2 * jnp.pi],
        [-1.0, 1.0],
    ]
)


# Prior specification, passed to SingleEventRun as `priors`.
# NOTE(review): the M_c prior spans [10, 80] while the first row of `bounds`
# is [10, 40] — confirm the narrower bound is intended (assumes row 0 is M_c).
prior_config = {
    "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
    "q": {"name": "MassRatio"},
    "s1": {"name": "Sphere"},
    "s2": {"name": "Sphere"},
    "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
    "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
    "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
    "cos_iota": {"name": "CosIota"},
    "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
    "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
    "sin_dec": {"name": "SinDec"},
}

# Waveform model selection and its reference frequency.
waveform_config = {"name": "RippleIMRPhenomPv2", "f_ref": 20.0}

# Sampler settings, passed as `jim_parameters`; the local-sampler step size
# defined earlier in the script is forwarded under "local_sampler_arg".
sampler_config = {
    "n_loop_training": 10,
    "n_loop_production": 10,
    "n_local_steps": 150,
    "n_global_steps": 150,
    "n_chains": 500,
    "n_epochs": 50,
    "learning_rate": 0.001,
    "max_samples": 45000,
    "momentum": 0.9,
    "batch_size": 50000,
    "use_global": True,
    "keep_quantile": 0.0,
    "train_thinning": 1,
    "output_thinning": 10,
    "local_sampler_arg": local_sampler_arg,
}

# Likelihood selection; `bounds` is the per-dimension [lower, upper] array
# defined earlier in the script.
likelihood_config = {
    "name": "HeterodynedTransientLikelihoodFD",
    "bounds": bounds,
}

# Parameters of the injected signal (used because injection=True below).
injected_signal = {
    "M_c": 28.6,
    "eta": 0.24,
    "s1_x": 0.05,
    "s1_y": -0.05,
    "s1_z": 0.05,
    "s2_x": -0.05,
    "s2_y": 0.05,
    "s2_z": 0.05,
    "d_L": 440.0,
    "t_c": 0.0,
    "phase_c": 0.0,
    "iota": 0.5,
    "psi": 0.7,
    "ra": 1.2,
    "dec": 0.3,
}

# Data segment description, passed as `data_parameters`.
data_config = {
    "trigger_time": 1126259462.4,
    "duration": 4,
    "post_trigger_duration": 2,
    "f_min": 20.0,
    "f_max": 1024.0,
    "tukey_alpha": 0.2,
    "f_sampling": 4096.0,
}

# Assemble the full run description from the pieces above.
run = SingleEventRun(
    seed=0,
    path="",
    detectors=["H1", "L1"],
    priors=prior_config,
    waveform_parameters=waveform_config,
    jim_parameters=sampler_config,
    likelihood_parameters=likelihood_config,
    injection=True,
    injection_parameters=injected_signal,
    data_parameters=data_config,
)

# Build the run manager from the run description, then draw samples with a
# fixed PRNG key.
run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))
samples = run_manager.jim.get_samples()

# Build output paths with os.path.join rather than '/' concatenation: when
# `outdir` is empty (e.g. __file__ is a bare relative filename), plain
# concatenation would yield "/<label>" and target the filesystem root
# instead of the current working directory.
output_prefix = os.path.join(outdir, label)
run_manager.save(output_prefix)
jnp.save(output_prefix + "_samples.npy", samples)
run_manager.jim.print_summary()

Output:

Run instance provided. Loading from instance.
Initializing detectors.
Injection mode. Need to wait until waveform model is loaded.
Injection mode. Need to wait until waveform model is loaded.
Initializing waveform.
Grabbing GWTC-2 PSD for H1
For detector H1:
The injected optimal SNR is 29.713682403670045
The injected match filter SNR is (30.11265963611454-0.563276239330949j)
Grabbing GWTC-2 PSD for L1
For detector L1:
The injected optimal SNR is 51.04914603748342
The injected match filter SNR is (50.302015972725656-0.13345233070825305j)
Initializing heterodyned likelihood..
No reference parameters are provided, finding it...
Starting the optimizer
Using Adam optimization
Warning: Optimization accessed infinite or NaN log-probabilities.
The reference parameters are {'M_c': nan, 'eta': nan, 's1_x': nan, 's1_y': nan, 's1_z': nan, 's2_x': nan, 's2_y': nan, 's2_z': nan, 'd_L': nan, 't_c': nan, 'phase_c': nan, 'iota': nan, 'psi': nan, 'ra': nan, 'dec': nan}
Constructing reference waveforms..
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/ckng/project/jim_testing/sky_location_frame/ra_dec.py", line 103, in <module>
    run_manager = SingleEventPERunManager(run=run)
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jimgw/single_event/runManager.py", line 127, in __init__
    local_likelihood = self.initialize_likelihood(local_prior)
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jimgw/single_event/runManager.py", line 185, in initialize_likelihood
    return likelihood_presets[name](
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jimgw/single_event/likelihood.py", line 298, in __init__
    f_max = jnp.max(f_valid)
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 268, in max
    return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 260, in _reduce_max
    return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 115, in _reduction
    raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
ValueError: zero-size array to reduction operation max which has no identity
@thomasckng
Copy link
Collaborator Author

Same problem with aligned spin.

from jimgw.single_event.runManager import SingleEventPERunManager, SingleEventRun
import jax.numpy as jnp
import jax

import os
# Output directory and run label are both derived from this script's path:
# the directory part becomes `outdir`, the extension-less filename `label`.
outdir, _script_file = os.path.split(__file__)
label = os.path.splitext(_script_file)[0]

# Switch JAX to 64-bit precision before any arrays are created below.
jax.config.update("jax_enable_x64", True)

# Step size handed to the local sampler: an 11x11 diagonal matrix whose
# diagonal is 3e-3, except entries 1 and 5, which are a further factor of
# 1e-3 smaller.
# NOTE(review): indices 1 and 5 presumably correspond to q and t_c in the
# sampled-parameter ordering — confirm against the prior definition.
mass_matrix = jnp.eye(11).at[1, 1].set(1e-3).at[5, 5].set(1e-3) * 3e-3
local_sampler_arg = {"step_size": mass_matrix}

# One [lower, upper] pair per sampled dimension (11 rows), passed to the
# likelihood as "bounds".
# NOTE(review): the row order is assumed to follow the sampled-parameter
# order — TODO confirm.
bounds = jnp.array(
    [
        [10.0, 40.0],
        [0.125, 1.0],
        [-1.0, 1.0],
        [-1.0, 1.0],
        [0.0, 2000.0],
        [-0.05, 0.05],
        [0.0, 2 * jnp.pi],
        [-1.0, 1.0],
        [0.0, jnp.pi],
        [0.0, 2 * jnp.pi],
        [-1.0, 1.0],
    ]
)


# Prior specification, passed to SingleEventRun as `priors`.
# NOTE(review): the M_c prior spans [10, 80] while the first row of `bounds`
# is [10, 40] — confirm the narrower bound is intended (assumes row 0 is M_c).
prior_config = {
    "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
    "q": {"name": "MassRatio"},
    "s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
    "s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
    "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
    "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
    "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
    "cos_iota": {"name": "CosIota"},
    "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
    "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
    "sin_dec": {"name": "SinDec"},
}

# Waveform model selection and its reference frequency.
waveform_config = {"name": "RippleIMRPhenomD", "f_ref": 20.0}

# Sampler settings, passed as `jim_parameters`; the local-sampler step size
# defined earlier in the script is forwarded under "local_sampler_arg".
sampler_config = {
    "n_loop_training": 10,
    "n_loop_production": 10,
    "n_local_steps": 150,
    "n_global_steps": 150,
    "n_chains": 500,
    "n_epochs": 50,
    "learning_rate": 0.001,
    "max_samples": 45000,
    "momentum": 0.9,
    "batch_size": 50000,
    "use_global": True,
    "keep_quantile": 0.0,
    "train_thinning": 1,
    "output_thinning": 10,
    "local_sampler_arg": local_sampler_arg,
}

# Likelihood selection; `bounds` is the per-dimension [lower, upper] array
# defined earlier in the script.
likelihood_config = {
    "name": "HeterodynedTransientLikelihoodFD",
    "bounds": bounds,
}

# Parameters of the injected signal (used because injection=True below).
injected_signal = {
    "M_c": 28.6,
    "eta": 0.24,
    "s1_z": 0.05,
    "s2_z": 0.05,
    "d_L": 440.0,
    "t_c": 0.0,
    "phase_c": 0.0,
    "iota": 0.5,
    "psi": 0.7,
    "ra": 1.2,
    "dec": 0.3,
}

# Data segment description, passed as `data_parameters`.
data_config = {
    "trigger_time": 1126259462.4,
    "duration": 4,
    "post_trigger_duration": 2,
    "f_min": 20.0,
    "f_max": 1024.0,
    "tukey_alpha": 0.2,
    "f_sampling": 4096.0,
}

# Assemble the full run description from the pieces above.
run = SingleEventRun(
    seed=0,
    path="",
    detectors=["H1", "L1"],
    priors=prior_config,
    waveform_parameters=waveform_config,
    jim_parameters=sampler_config,
    likelihood_parameters=likelihood_config,
    injection=True,
    injection_parameters=injected_signal,
    data_parameters=data_config,
)

# Build the run manager from the run description, then draw samples with a
# fixed PRNG key.
run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))
samples = run_manager.jim.get_samples()

# Build output paths with os.path.join rather than '/' concatenation: when
# `outdir` is empty (e.g. __file__ is a bare relative filename), plain
# concatenation would yield "/<label>" and target the filesystem root
# instead of the current working directory.
output_prefix = os.path.join(outdir, label)
run_manager.save(output_prefix)
jnp.save(output_prefix + "_samples.npy", samples)
run_manager.jim.print_summary()

@thomasckng
Copy link
Collaborator Author

Solved by #93

@thomasckng thomasckng moved this to Done in Jim-v1.0.0 Jul 22, 2024
@thomasckng thomasckng removed this from Jim-v1.0.0 Jul 23, 2024
@thomasckng thomasckng moved this to In Progress in Jim-v1.0.0 Jul 23, 2024
@thomasckng thomasckng closed this as not planned Won't fix, can't repro, duplicate, stale Jul 23, 2024
@github-project-automation github-project-automation bot moved this from In Progress to Done in Jim-v1.0.0 Jul 23, 2024
@thomasckng thomasckng removed this from Jim-v1.0.0 Jul 23, 2024
@thomasckng thomasckng moved this to Done in Jim-v1.0.0 Jul 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: Done
Development

No branches or pull requests

1 participant