In [1]:
import os
from warnings import filterwarnings

import torch

from clipppy.commands.nre.validate import MultiNREValidator
# noinspection PyUnresolvedReferences
from clipppy.patches import torch_numpy


torch.set_default_tensor_type(torch.cuda.FloatTensor)


filterwarnings('ignore', message='indexing past lexsort depth may impact performance.')
filterwarnings('ignore', module='torch.nn.modules.lazy')
filterwarnings('ignore', message='Named tensors')
filterwarnings('ignore', module='pytorch_lightning.trainer.data_loading',
               message='The dataloader, train_dataloader, does not have many workers')

In [2]:
from libsimplesn import SimpleSN

simplesn = SimpleSN(survey='pantheon-g10', datatype='mphotoz', N=2_000, suffix=0, version=0)
config = simplesn.config('simplesn-marginal.yaml', gen=True)
defs = config.kwargs['defs']
nre = config.lightning_nre

#### Enact constraints

In [3]:
from operator import itemgetter


nre.dataset_config.kwargs['ranges'].update({
    key: itemgetter('lower', 'upper')(val)
    for key, val in simplesn.hdi_bounds[(simplesn.datatype, simplesn.N)].to_dict().items()
})
nre.dataset_config.kwargs['ranges']

{'mean_x1': (-0.14367517968974275, 0.1620717020494652),
 'mean_c': (-0.015384074488660648, 0.014917194403638435),
 'log10_R_x1': (-0.04255821021574628, 0.05537624820292962),
 'log10_R_c': (-1.0525528591230264, -0.9466537627692363),
 'Om0': (-0.19035878579010057, 0.7069034571292946),
 'Ode0': (-0.5128887348585609, 1.736940199103507),
 'alpha': (0.07285661810403343, 0.2236113293941773),
 'beta': (2.286168990065564, 3.913619512517845),
 'mean_M0': (-19.748456623654715, -19.273431582009305),
 'sigma_res': (-0.06325484654371223, 0.29706277981544366),
 'sigma_z': (0.03315636378310852, 0.045306562727659344)}

#### Prepare plotting

In [4]:
from clipppy.utils.messengers import CollectSitesMessenger
from typing import Iterable
from clipppy.commands.nre import ClipppyDataset
from collections import defaultdict
from clipppy.utils.plotting.nre import MultiNREPlotter


def get_priors(param_names: Iterable[str], dataset: ClipppyDataset):
    with CollectSitesMessenger(*param_names) as trace:
        dataset.get_trace()
    return {name: site['fn'] for name, site in trace.items()}


priors = get_priors(nre.param_names, nre.dataset.dataset)
ranges = {key: (prior.support.lower_bound, prior.support.upper_bound) for key, prior in priors.items()}


def nrepper(ngrid=256, ngrid_cosmo=32):
    return MultiNREPlotter(
        groups=[('Om0', 'Ode0')],
        grid_sizes=defaultdict(lambda: ngrid, Om0=ngrid_cosmo, Ode0=ngrid_cosmo),
        priors=priors, ranges=ranges, labels=defs['labels']
    )

ranges

####  Define network

In [5]:
from clipppy.commands.lightning.config.schedulers import StepLR
from clipppy.commands.lightning.hyper import Scheduler
from torch import nn
from clipppy.commands.lightning import hyper as h
from clipppy.utils.nn import Movedim, USequential
from clipppy.commands.nre import MultiNRETail, UWhiteningTail, WhiteningHead


MAX_BATCH = 32 if simplesn.N > 50_000 else 64
STRUCTURE_NAME = 'onlycosmo-fc'
HPARAMS_NAME = 'step'


hp = h.Hyperparams(
    h.Structure(
        head=h.MLP(3, 128, 32),
        tail=h.BaseHParams(
            cosmo=h.Tail(
                thead=h.MLP(2, 256),
                xhead=h.BaseHParams(
                    dropout=0.995,
                    net=h.Linear(256)
                ),
                net=h.OMLP(3, 256)
            )
        )
    ),
    h.Training(1e-4, 64, Scheduler(StepLR, step_size=2000, gamma=0.5))
)

nre.head = WhiteningHead(head=USequential(
    nn.Unflatten(-1, (-1, 3)),
    hp.structure.head.make(),
    Movedim(source=-1, destination=-2)
), event_dims={'data': 2})
nre.tail = MultiNRETail(tails={
    ('Om0', 'Ode0'): UWhiteningTail(
        thead=hp.structure.tail.cosmo.thead.make(),
        xhead=nn.Sequential(nn.Flatten(-2), nn.Dropout(hp.structure.tail.cosmo.xhead.dropout), hp.structure.tail.cosmo.xhead.net.make()),
#         xhead=nn.Sequential(
#             convlayer(32, 100), nn.Flatten(-2), mlp(256)
#         ),
        net=hp.structure.tail.cosmo.net.make()),
})

nre.just_save_hyperparameters(dict(hp.collapse()))

#### Define learning

In [6]:
# Learning rate
nre.lr = nre.hparams['training/lr']

# Batch size
assert (nre.hparams['training/batch_size'] < MAX_BATCH
        or not nre.hparams['training/batch_size'] % MAX_BATCH)
memory_batch_size = min(nre.hparams['training/batch_size'], MAX_BATCH)
accumulate_grad_batches = nre.hparams['training/batch_size'] // memory_batch_size
nre.dataset_config.kwargs['batch_size'] = memory_batch_size

# Scheduler
if hp.training.scheduler:
    nre.scheduler_config = hp.training.scheduler.make()

In [7]:
from clipppy.commands.lightning.callbacks import MultiPosteriorCallback, MultiValidationCallback
from clipppy.commands.lightning.patches import LearningRateMonitor, ModelCheckpoint, TensorBoardLogger, Trainer


BASE_LOGDIR = 'lightning_logs'
CHECKPOINT_EVERY = 2000
VALIDATE_EVERY = 500
VALIDATE_SAMPLES = 1024


name = os.path.join(simplesn.basedata_prefix, STRUCTURE_NAME, HPARAMS_NAME)
validator = MultiNREValidator(VALIDATE_SAMPLES // MAX_BATCH, MAX_BATCH, nre.dataset, nrepper())

trainer = Trainer(
    gpus=1, max_epochs=-1,
    max_steps=10000,
    logger=TensorBoardLogger(BASE_LOGDIR, name),
    callbacks=[
        MultiValidationCallback(VALIDATE_EVERY, nre, validator),
        MultiPosteriorCallback(VALIDATE_EVERY, nre, validator.nrep, simplesn.data),
        LearningRateMonitor(),
        ModelCheckpoint(every_n_train_steps=CHECKPOINT_EVERY)
    ],
    accumulate_grad_batches=accumulate_grad_batches
)
trainer.fit(nre, nre.training_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training: 0it [00:00, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]