In [1]:
from __future__ import annotations
import sys

sys.path.append("../../")

import math
import logging

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

# from jsonargparse.typing import PositiveInt, PositiveFloat, NonNegativeFloat

import flows.phi_four as phi_four
import flows.transforms as transforms
import flows.utils as utils
from flows.flow_hmc import *
from flows.models import MultilevelFlow
from flows.layers import GlobalRescalingLayer
from flows.distributions import Prior, FreeScalarDistribution

Tensor: TypeAlias = torch.Tensor
BoolTensor: TypeAlias = torch.BoolTensor
Module: TypeAlias = torch.nn.Module
IterableDataset: TypeAlias = torch.utils.data.IterableDataset

logging.getLogger().setLevel("WARNING")


In [2]:
# Model spec
# Each block is a plain dict with three keys: the transform class
# ("transform"), keyword settings for that transform ("transform_spec") and
# settings for its network ("net_spec").
# NOTE(review): how the consumer interprets these keys is not visible in this
# notebook -- confirm against flows.models.
ADDITIVE_BLOCK = dict(
    transform=transforms.PointwiseAdditiveTransform,
    transform_spec=dict(),
    net_spec=dict(
        hidden_shape=[4, 4],
        activation=torch.nn.Tanh(),
        final_activation=torch.nn.Identity(),
        use_bias=False,
    ),
)
AFFINE_BLOCK = dict(
    transform=transforms.PointwiseAffineTransform,
    transform_spec=dict(),
    net_spec=dict(
        hidden_shape=[4, 4, 4, 4],
        activation=torch.nn.Tanh(),
        final_activation=torch.nn.Tanh(),
        use_bias=False,
    ),
)
SPLINE_BLOCK = dict(
    transform=transforms.PointwiseRationalQuadraticSplineTransform,
    transform_spec=dict(n_segments=8, interval=(-4, 4)),
    net_spec=dict(
        hidden_shape=[4],
        activation=torch.nn.Tanh(),
        final_activation=torch.nn.Identity(),
        use_bias=True,
    ),
)


In [4]:
# Target theory
LATTICE_LENGTH = 8
BETA = 0.7
LAM = 0.5

# Layer stack handed to MultilevelFlow. Both affine entries refer to the same
# AFFINE_BLOCK dict object (and therefore the same activation module instances).
MODEL_SPEC = [
    AFFINE_BLOCK,
    AFFINE_BLOCK,
    "rescaling",
]

# Training-run sizes: number of optimisation steps and the sample shapes used
# for the training / validation priors.
N_TRAIN = 10
N_BATCH = 1000
N_BATCH_VAL = 1000

# Fix the random state before anything stochastic is created so that a
# "Restart Kernel -> Run All" reproduces the same initial weights and samples.
SEED = 0
pl.seed_everything(SEED)

model = MultilevelFlow(
    beta=BETA,
    lam=LAM,
    model_spec=MODEL_SPEC,
)


In [5]:
# Prior: an independent unit Gaussian (mean 0, std 1) for every entry of a
# LATTICE_LENGTH x LATTICE_LENGTH grid.
dist = torch.distributions.Normal(
    loc=torch.zeros((LATTICE_LENGTH, LATTICE_LENGTH)),
    scale=torch.ones((LATTICE_LENGTH, LATTICE_LENGTH)),
)
# Alternative prior left by the author (M_SQ must be defined before enabling):
# dist = FreeScalarDistribution(LATTICE_LENGTH, M_SQ)
train_dataloader = Prior(dist, sample_shape=[N_BATCH, 1])
val_dataloader = Prior(dist, sample_shape=[N_BATCH_VAL, 1])

pbar = utils.JlabProgBar()
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

# NOTE(review): val_check_interval (100) is larger than max_steps when
# N_TRAIN = 10, so periodic validation may never trigger during fit -- confirm
# this is intended for a quick smoke run.
trainer = pl.Trainer(
    accelerator="cpu",  # CPU-only run; replaces the deprecated `gpus=0` argument
    max_steps=N_TRAIN,  # total number of training steps
    val_check_interval=100,  # how often to run sampling
    limit_val_batches=1,  # one batch for each val step
    callbacks=[pbar, lr_monitor],
    enable_checkpointing=False,  # manually saving checkpoints
)

# Baseline metrics of the untrained model, reported before any optimisation.
trainer.validate(model, val_dataloader)

trainer.fit(model, train_dataloader, val_dataloader)


GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.

  | Name             | Type            | Params
-----------------------------------------------------
0 | flow             | Flow            | 2.2 K 
1 | upsampling_layer | UpsamplingLayer | 0     
-----------------------------------------------------
2.2 K     Trainable params
0         Non-trainable params
2.2 K     Total params
0.009     Total estimated model params size (MB)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       acceptance           0.01001000963151455
          loss              -11.101160049438477
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [6]:
# Rebind `val_dataloader` to a Prior with sample_shape=[1, 1] and draw one
# sample from it for the inspection cells that follow.
# NOTE(review): this overwrites the validation loader created for the trainer;
# a distinct name would avoid hidden state when cells are re-run out of order.
val_dataloader = Prior(dist, sample_shape=[1, 1])
batch = val_dataloader.sample()


In [29]:
# Round-trip check: push a detached copy of `batch` through the reverse flow,
# feed the result back through the forward flow, and keep the first element of
# each helper's return value.
batch_cp = apply_flow_to_fields(
    apply_reverse_flow_to_fields(batch.clone().detach(), model)[0],
    model,
)[0]

# Alternative left by the author, calling the flow object directly:
# batch_cp, trash = model.flow.inverse(batch)
# batch_cp, trash = model.flow(batch_cp)

# Summed squared difference between the round-tripped and the original sample.
(batch_cp - batch).pow(2).sum()

tensor(1.5671e-12, grad_fn=<SumBackward0>)

In [12]:
# Call model.flow.inverse on the sample; the first returned value rebinds
# `batch_cp`, the second is bound to `trash` and not used in this cell.
batch_cp, trash = model.flow.inverse(batch)

In [13]:
# Display the sample held in `batch`.
batch

tensor([[[[-1.4253,  0.0672, -1.5074, -1.4993, -0.5472, -2.5435, -1.5111,
           -0.1945],
          [ 0.3814,  0.5978, -0.2750, -0.4200,  0.4945, -0.5013,  0.0804,
            0.3819],
          [-0.0602, -1.3316,  0.9656,  1.5211, -1.8475, -1.4497, -1.5704,
           -0.5208],
          [-1.9197,  0.9528,  0.1118, -0.6883,  1.1194, -0.3986,  1.3732,
           -1.0705],
          [ 2.0687, -0.0036, -0.1874, -2.4672, -1.0532, -0.8711, -0.0351,
            0.0109],
          [-1.0594, -1.6358,  3.1215, -1.0083,  0.1539,  1.2647, -0.4479,
            0.0131],
          [ 1.2747, -0.8006,  1.4830, -0.6272, -1.0876,  1.8610, -1.8756,
           -0.2722],
          [ 0.6857,  1.0311,  1.3519,  0.4791,  0.3655,  0.3717,  0.6464,
            0.6266]]]])

In [54]:
# Scratch arithmetic on two hand-entered numbers.
# TODO(review): their origin is not recorded in this notebook -- name them as
# constants with an explanation, or remove this cell.
2.2460*1.0324

2.3187704