In [1]:
import os
from warnings import filterwarnings

import torch

from clipppy.commands.nre.validate import MultiNREValidator
# noinspection PyUnresolvedReferences
from clipppy.patches import torch_numpy


# Make CUDA float tensors the default type for newly created tensors.
torch.set_default_tensor_type(torch.cuda.FloatTensor)


# Warnings to suppress: each entry holds the keyword arguments of one 'ignore' filter.
# The filters are registered in the order listed.
for _ignored in (
    {'message': 'indexing past lexsort depth may impact performance.'},
    {'module': 'torch.nn.modules.lazy'},
    {'message': 'Named tensors'},
    {'module': 'pytorch_lightning.trainer.data_loading',
     'message': 'The dataloader, train_dataloader, does not have many workers'},
):
    filterwarnings('ignore', **_ignored)

In [None]:
from libsimplesn import SimpleSN

# Dataset handle and the configuration object derived from 'simplesn.yaml'.
simplesn = SimpleSN(
    survey='pantheon-g10',
    datatype='photoz',
    N=100_000,
    suffix=0,
    version=0,
)
config = simplesn.config('simplesn.yaml', gen=True)
defs = config.kwargs['defs']
nre = config.lightning_nre

# When True, only the joint cosmology group is inferred; otherwise every
# remaining parameter of the NRE is added as its own group.
ONLYCOSMO = True

COSMOGROUP = ('Om0', 'Ode0')
LATENT_PARAMS = ('M0',)

groups = [COSMOGROUP]
if not ONLYCOSMO:
    groups += [name for name in nre.param_names if name not in COSMOGROUP]

#### Enact constraints

In [None]:
# Hand-picked (lower, upper) bounds per parameter, keyed by data-set size.
# Each list entry is one zoom level; level 0 is empty (no restriction).
ZOOMS_MANUAL = {
    100_000: [
        # level 0
        dict(),
        # level 1
        dict(
            Om0=(0, 0.75),
            Ode0=(0, 1.5),
            sigma_z=(0.03, 0.05),
            mean_M0=(-20, -19),
            sigma_res=(0, 0.6),
            mean_x1=(-0.5, 0.5),
            log10_R_x1=(-5, 0.2),
            mean_c=(-0.05, 0.05),
            log10_R_c=(-5, -0.8),
        ),
        # level 2
        dict(
            Om0=(0.1, 0.5),
            Ode0=(0, 1.1),
            sigma_z=(0.037, 0.043),
            mean_M0=(-19.7, -19.2),
            sigma_res=(0, 0.35),
            mean_x1=(-0.2, 0.2),
            log10_R_x1=(-0.2, 0.2),
            mean_c=(-0.01, 0.02),
            log10_R_c=(-1.2, -0.8),
        ),
        # level 3 (first level that also bounds alpha and beta)
        dict(
            Om0=(0.2, 0.5),
            Ode0=(0.4, 1),
            alpha=(0.08, 0.2),
            beta=(2.7, 3.5),
            sigma_z=(0.038, 0.042),
            mean_M0=(-19.55, -19.4),
            sigma_res=(0, 0.2),
            mean_x1=(-0.05, 0.05),
            log10_R_x1=(-0.05, 0.05),
            mean_c=(-0.005, 0.005),
            log10_R_c=(-1.05, -0.95),
        ),
        # level 4
        # (disabled alternative: Om0=(0.22, 0.35), Ode0=(0.5, 0.8))
        dict(
            Om0=(0.22, 0.38),
            Ode0=(0.5, 0.9),
            alpha=(0.12, 0.155),
            beta=(2.9, 3.2),
            sigma_z=(0.0385, 0.041),
            mean_M0=(-19.525, -19.46),
            sigma_res=(0.05, 0.15),
            mean_x1=(-0.02, 0.03),
            log10_R_x1=(-0.01, 0.01),
            mean_c=(-0.002, 0.003),
            log10_R_c=(-1.01, -0.99),
        ),
    ],
}

# Bounds per zoom level as held by the SimpleSN object.
# NOTE: BOUNDS is the very object returned by `simplesn.zoom_bounds`, not a copy.
BOUNDS = simplesn.zoom_bounds

# Zoom level to train at (the latest available level would be len(BOUNDS) - 1).
ZOOM_LEVEL = 5

# Overwrite the dataset's parameter ranges with the bounds of the selected level.
_level_bounds = BOUNDS[ZOOM_LEVEL]
nre.dataset_config.kwargs['ranges'].update(_level_bounds)

#### Prepare plotting

In [None]:
from clipppy.utils.messengers import CollectSitesMessenger
from typing import Iterable
from clipppy.commands.nre import ClipppyDataset
from collections import defaultdict
from clipppy.utils.plotting.nre import MultiNREPlotter


def get_priors(param_names: Iterable[str], dataset: ClipppyDataset):
    """Collect the distribution attached to each named sample site.

    Runs ``dataset.get_trace()`` once while a ``CollectSitesMessenger`` for
    ``param_names`` is active, then extracts the ``'fn'`` entry of every
    collected site.

    :param param_names: names of the sites to collect.
    :param dataset: dataset whose ``get_trace()`` is executed.
    :return: dict mapping each collected site name to its ``site['fn']``.
    """
    collector = CollectSitesMessenger(*param_names)
    with collector as trace:
        dataset.get_trace()

    collected = {}
    for name, site in trace.items():
        collected[name] = site['fn']
    return collected


# Prior distribution of every NRE parameter, taken from the underlying dataset.
priors = get_priors(nre.param_names, nre.dataset.dataset)

# (lower, upper) bound of each prior's support.
ranges = {
    name: (dist.support.lower_bound, dist.support.upper_bound)
    for name, dist in priors.items()
}


def nrepper(ngrid=256, ngrid_cosmo=32):
    """Create a ``MultiNREPlotter`` for the module-level groups, priors and ranges.

    :param ngrid: grid size used for every parameter without an explicit entry.
    :param ngrid_cosmo: grid size used for ``'Om0'`` and ``'Ode0'``.
    :return: the configured ``MultiNREPlotter``.
    """
    # Any parameter not listed explicitly falls back to `ngrid`.
    grid_sizes = defaultdict(lambda: ngrid)
    grid_sizes['Om0'] = ngrid_cosmo
    grid_sizes['Ode0'] = ngrid_cosmo

    return MultiNREPlotter(
        groups=groups,
        grid_sizes=grid_sizes,
        priors=priors,
        ranges=ranges,
        labels=defs['labels'],
    )

# Show the support bounds of each prior as the cell output.
ranges

#### Define network

In [None]:
from clipppy.commands.lightning.config.schedulers import StepLR
from clipppy.commands.lightning.hyper import Scheduler
from torch import nn
from clipppy.commands.lightning import hyper as h
from clipppy.utils.nn import Movedim, USequential
from clipppy.commands.nre import IUWhiteningTail, MultiNRETail, UWhiteningTail, WhiteningHead


# Largest batch processed at once; halved for data sets with more than 50 000 objects.
MAX_BATCH = 32 if simplesn.N > 50_000 else 64

# Run label: which parameter groups are inferred and at which zoom level.
STRUCTURE_NAME = f"{'onlycosmo' if ONLYCOSMO else 'all'}-{ZOOM_LEVEL}"

# Learning-rate label; 'step' enables the StepLR scheduler (alternative label: 'lowlr').
LR_NAME = 'step'

STEP_SIZE = 5000
MAX_STEPS = 20000


# Hyperparameters of the head: a `pre` network plus a `summary` stage (dropout + network).
_head_hparams = h.BaseHParams(
    pre=h.MLP(3, 128, 32),
    summary=h.BaseHParams(dropout=0.1, net=h.MLP(3, 256)),
)

# Hyperparameters of the tails: cosmology group, other groups and latent parameters.
# `cosmo` and `other` use the same settings but are deliberately separate objects.
# NOTE(review): the tail construction further down reads `.xhead` from `cosmo` and `other`,
# which is not passed here -- presumably h.Tail supplies a default; confirm.
_tail_hparams = h.BaseHParams(
    cosmo=h.Tail(thead=h.MLP(3, 256), net=h.OMLP(3, 256)),
    other=h.Tail(thead=h.MLP(3, 256), net=h.OMLP(3, 256)),
    latent=h.BaseHParams(
        ihead=h.MLP(3, 128, 16),
        net=h.OMLP(3, 128),
        subsample=None,
    ),
)
_structure_hparams = h.Structure(head=_head_hparams, tail=_tail_hparams)

# A StepLR scheduler (learning rate halved every STEP_SIZE steps) is only used for LR_NAME == 'step'.
if LR_NAME == 'step':
    _scheduler_hparams = Scheduler(StepLR, step_size=STEP_SIZE, gamma=0.5)
else:
    _scheduler_hparams = None

# Positional Training arguments: presumably learning rate and batch size
# (they are read back as 'training/lr' and 'training/batch_size') -- confirm.
hp = h.Hyperparams(
    _structure_hparams,
    h.Training(1e-4, 64, _scheduler_hparams),
)

# Extra inputs for the latent tails: `zcmb` with a new trailing axis, concatenated along the
# last axis with `vars_scale_tril` whose last two axes are flattened into one.
ADDITIONAL = torch.cat(
    [
        simplesn.zcmb.unsqueeze(-1),
        simplesn.vars_scale_tril.flatten(-2),
    ],
    dim=-1,
)

# Build the parameterised sub-networks first, in the same order as they appear in the pipeline.
_pre_net = hp.structure.head.pre.make()
_summary_hp = hp.structure.head.summary
_summary_net = USequential(
    nn.Flatten(-2),
    nn.Dropout(_summary_hp.dropout),
    _summary_hp.net.make(),
)

# Head pipeline: split the last axis into (-1, 3), apply the `pre` network,
# move the last axis to second-to-last, then flatten and summarise.
nre.head = WhiteningHead(
    head=nn.Sequential(
        nn.Unflatten(-1, (-1, 3)),
        _pre_net,
        Movedim(source=-1, destination=-2),
        _summary_net,
    ),
    event_dims={'data': 2},
)
# Hyperparameter groups for the three kinds of tails.
# NOTE(review): `.xhead` is read from the cosmo/other groups although only `thead` and `net`
# are passed to h.Tail where `hp` is defined -- presumably h.Tail supplies a default; confirm.
_cosmo_hp = hp.structure.tail.cosmo
_other_hp = hp.structure.tail.other
_latent_hp = hp.structure.tail.latent

# Joint tail for the cosmology group.
# (Disabled alternative for both UWhiteningTail kinds: build xhead as
#  nn.Sequential(nn.Dropout(<hp>.xhead.dropout), <hp>.xhead.net.make()).)
_tails = {
    COSMOGROUP: UWhiteningTail(
        thead=_cosmo_hp.thead.make(),
        xhead=_cosmo_hp.xhead.make(),
        net=_cosmo_hp.net.make(),
    ),
}

# One tail for every other group.
for _group in groups:
    if _group is COSMOGROUP:
        continue
    _tails[_group] = UWhiteningTail(
        thead=_other_hp.thead.make(),
        xhead=_other_hp.xhead.make(),
        net=_other_hp.net.make(),
    )

# One tail per latent parameter, fed with the ADDITIONAL inputs.
for _group in LATENT_PARAMS:
    _tails[_group] = IUWhiteningTail(
        additional=ADDITIONAL,
        ihead=_latent_hp.ihead.make(),
        net=_latent_hp.net.make(),
        subsample=_latent_hp.subsample,
        summarize=False,
    )

nre.tail = MultiNRETail(tails=_tails)

# Store the collapsed hyperparameters on the NRE module; entries such as 'training/lr'
# are read back from `nre.hparams` when the learning setup is configured.
nre.just_save_hyperparameters(dict(hp.collapse()))
# Display a freshly collapsed copy of the hyperparameters as the cell output.
dict(hp.collapse())

#### Define learning

In [None]:
# --- Learning rate ---
nre.lr = nre.hparams['training/lr']

# --- Batch size ---
# The requested batch size must either be smaller than MAX_BATCH or a multiple of it:
# larger batches are processed in chunks of MAX_BATCH with gradient accumulation.
_requested_batch = nre.hparams['training/batch_size']
assert _requested_batch < MAX_BATCH or _requested_batch % MAX_BATCH == 0
memory_batch_size = min(_requested_batch, MAX_BATCH)
accumulate_grad_batches = _requested_batch // memory_batch_size
nre.dataset_config.kwargs['batch_size'] = memory_batch_size

# --- Scheduler ---
_scheduler_hp = hp.training.scheduler
if _scheduler_hp:
    nre.scheduler_config = _scheduler_hp.make()

In [None]:
from clipppy.commands.lightning.callbacks import MultiPosteriorCallback, MultiValidationCallback
from clipppy.commands.lightning.patches import LearningRateMonitor, ModelCheckpoint, TensorBoardLogger, Trainer


BASE_LOGDIR = 'lightning_logs'
CHECKPOINT_EVERY = 2000  # passed to ModelCheckpoint as every_n_train_steps
VALIDATE_EVERY = 500     # first argument of both validation-related callbacks
VALIDATE_SAMPLES = 1024


# Run name (used as the logger's name): <data prefix>/<structure>/<lr label>.
name = os.path.join(simplesn.basedata_prefix, STRUCTURE_NAME, LR_NAME)

# The first argument is VALIDATE_SAMPLES split into chunks of MAX_BATCH
# (presumably a number of batches -- confirm against MultiNREValidator).
validator = MultiNREValidator(
    VALIDATE_SAMPLES // MAX_BATCH,
    MAX_BATCH,
    nre.dataset,
    nrepper(),
)

_logger = TensorBoardLogger(BASE_LOGDIR, name)
_callbacks = [
    MultiValidationCallback(VALIDATE_EVERY, nre, validator),
    MultiPosteriorCallback(VALIDATE_EVERY, nre, validator.nrep, simplesn.data),
    LearningRateMonitor(),
    ModelCheckpoint(every_n_train_steps=CHECKPOINT_EVERY, save_top_k=1),
]

# No epoch limit (max_epochs=-1); training is bounded by MAX_STEPS instead.
trainer = Trainer(
    gpus=1,
    max_epochs=-1,
    max_steps=MAX_STEPS,
    logger=_logger,
    callbacks=_callbacks,
    accumulate_grad_batches=accumulate_grad_batches,
)
trainer.fit(nre, nre.training_loader)

#### Get next bounds

In [None]:
# Threshold handed to get_bounds_from_post.
THRESH = 1e-4

# Move the trained module to the GPU and switch it to evaluation mode.
nre.cuda()
nre.eval()

# Posteriors for the observed data, restricted to the NRE's observation names.
_observed = {key: simplesn.data[key] for key in nre.obs_names}
posts = validator.nrep.post(_observed, nre.head, nre.tail)

# New bounds per parameter, converted to tuples of plain Python floats.
_raw_bounds = validator.nrep.get_bounds_from_post(posts, thresh=THRESH)
bounds = {
    key: tuple(float(edge) for edge in val)
    for key, val in _raw_bounds.items()
}
bounds

In [None]:
# Print, per parameter, the width of the new interval relative to the width of the
# interval at the current zoom level.
for key, val in bounds.items():
    new_width = val[1] - val[0]
    previous = BOUNDS[ZOOM_LEVEL][key]
    old_width = previous[1] - previous[0]
    print(key, new_width / old_width)

In [None]:
# Record the newly derived bounds as the next zoom level.
# The append is guarded because this statement mutates BOUNDS in place and occurs in more
# than one cell of the notebook: without the guard, re-running it (or Run All) would store
# the same level twice.
if not BOUNDS or BOUNDS[-1] != bounds:
    BOUNDS.append(bounds)
BOUNDS

In [None]:
# Hand the extended list of zoom levels back to the SimpleSN object
# (presumably this persists them for later runs -- confirm in libsimplesn).
simplesn.zoom_bounds = BOUNDS

In [10]:
# Record the newly derived bounds as the next zoom level.
# The append is guarded because this statement mutates BOUNDS in place and occurs in more
# than one cell of the notebook: without the guard, re-running it (or Run All) would store
# the same level twice.
if not BOUNDS or BOUNDS[-1] != bounds:
    BOUNDS.append(bounds)
BOUNDS

[{},
 {'Om0': (0.03125, 0.90625),
  'Ode0': (0.03125, 1.96875),
  'sigma_z': (0.033259764313697815, 0.044919922947883606),
  'mean_M0': (-20.408203125, -18.416015625),
  'sigma_res': (0.026365235447883606, 0.9980488419532776),
  'alpha': (0.001953125, 0.998046875),
  'beta': (0.0078125, 3.9921875),
  'mean_x1': (-0.19921875, 0.19921875),
  'log10_R_x1': (-4.986328125, 0.208984375),
  'mean_c': (-0.02460937574505806, 0.03867187723517418),
  'log10_R_c': (-4.986328125, -0.611328125)},
 {'Om0': (0.044921875, 0.673828125),
  'Ode0': (0.0615234375, 1.2724609375),
  'sigma_z': (0.03656195476651192, 0.04325743764638901),
  'mean_M0': (-19.696151733398438, -19.252578735351562),
  'sigma_res': (0.02826305478811264, 0.38884878158569336),
  'alpha': (0.00389862060546875, 0.8054428100585938),
  'beta': (0.031158447265625, 3.984405517578125),
  'mean_x1': (-0.0521392822265625, 0.1190643310546875),
  'log10_R_x1': (-0.1664581298828125, 0.0973663330078125),
  'mean_c': (-0.00792388990521431, 0.009626

In [11]:
# Hand the extended list of zoom levels back to the SimpleSN object
# (presumably this persists them for later runs -- confirm in libsimplesn).
simplesn.zoom_bounds = BOUNDS