In [None]:
import sbibm
import torch

from gatsbi.optimize import Base as Opt
from gatsbi.task_utils.benchmarks import make_generator, make_discriminator

In [None]:
# Reproducibility: seed torch's global RNG before anything stochastic runs
# (network construction in the next cell, simulator draws during training),
# so Restart & Run All gives the same results.
# NOTE(review): assumes gatsbi/sbibm draw their randomness from torch's global
# RNG — confirm; seed numpy/random too if either library uses them.
SEED = 42
torch.manual_seed(SEED)

# Make task callables: fetch the sbibm benchmark task by name, then take its
# prior and simulator; both are handed to the optimizer later.
task_name = "two_moons"
task = sbibm.get_task(task_name)
prior = task.get_prior()
simulator = task.get_simulator()


In [None]:
# Make generator and discriminator networks.
#
# Generator: input width task.dim_data, four hidden layers of 128 units,
# output width task.dim_parameters, LeakyReLU(0.1) activations.
# Noise is added at layer 5 with the settings in `add_noise_kwargs`.
# NOTE(review): exact meaning of lat_dim / heteroscedastic / seq_impwts is
# defined by gatsbi's make_generator — confirm there.
gen = make_generator(
    seq_impwts="impwts",
    add_noise_kwargs={"lat_dim": 2, "output_dim": 128, "heteroscedastic": True},
    add_noise_layer=5,
    gen_nonlin="leaky_relu",
    gen_nonlin_kwargs={"negative_slope": 0.1},
    gen_units=[task.dim_data, *[128] * 4, task.dim_parameters],
)

# Discriminator: input width task.dim_data + task.dim_parameters (presumably a
# concatenated data/parameter pair), five hidden layers of 2048 units, and a
# single output unit; LeakyReLU(0.1) activations.
dis = make_discriminator(
    dis_units=[task.dim_data + task.dim_parameters, *[2048] * 5, 1],
    dis_nonlin="leaky_relu",
    dis_nonlin_kwargs={"negative_slope": 0.1},
)


In [None]:
# Set training hyperparameters.
# These keys are consumed by gatsbi's optimizer; the per-key notes below are
# inferred from the names only — NOTE(review): confirm in gatsbi.optimize.Base.
training_opts = dict(
    gen_iter=1,            # generator updates per step (presumably)
    dis_iter=1,            # discriminator updates per step (presumably)
    max_norm_gen=0.1,      # gradient-norm clip, generator (presumably)
    max_norm_dis=0.1,      # gradient-norm clip, discriminator (presumably)
    num_simulations=100,
    sample_seed=None,
    hold_out=10,
    batch_size=10,
    log_dataloader=False,
    stop_thresh=1e-3,
)

# One argument list per network, presumably [learning_rate, [beta1, beta2]]
# for an Adam-style optimizer — confirm. Kept as two separate list objects.
gen_optim_args = [1e-4, [0.9, 0.99]]
dis_optim_args = [1e-4, [0.9, 0.99]]
loss = "cross_entropy"

In [None]:
# Make optimizer: hand gatsbi's base optimizer the two networks, the task's
# prior and simulator, and the hyperparameters defined earlier.
# optim_args is passed as [generator args, discriminator args]; presumably
# matched positionally to the two networks — confirm in gatsbi.
opt = Opt(
    generator=gen,
    discriminator=dis,
    prior=prior,
    simulator=simulator,
    optim_args=[gen_optim_args, dis_optim_args],
    loss=loss,
    training_opts=training_opts,
)

In [None]:
# Train.
# NOTE(review): the positional argument to `opt.train` is presumably the
# number of training epochs — confirm against gatsbi.optimize.Base.train.
NUM_EPOCHS = 10
opt.train(NUM_EPOCHS)