# In [72]:
import warnings
import logging
from dataclasses import dataclass
from numbers import Number
from typing import Any
from typing import Dict
from typing import final
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import torch
from torch.distributions import Distribution
from torch.distributions import Laplace
from torch.distributions import Normal

from disent.frameworks.ae._unsupervised__ae import Ae
from disent.frameworks.vae import Vae
from disent.frameworks.vae import BetaVae

from disent.frameworks.helper.latent_distributions import LatentDistsHandler
from disent.frameworks.helper.latent_distributions import make_latent_distribution
from disent.frameworks.helper.util import detach_all

from disent.util import map_all

from dataclasses import fields
from typing import Sequence
from typing import Tuple, final

import numpy as np


from disent.frameworks.helper.reductions import loss_reduction
from disent.frameworks.helper.util import compute_ave_loss
from disent.frameworks.helper.latent_distributions import LatentDistsHandler


from disent.model.ae.base import AutoEncoder
# module-level logger used by the helpers below
log = logging.getLogger(__name__)

REQUIRED_Z_MULTIPLIER = 2
# number of observations each batch entry is expected to hold
REQUIRED_OBS = 1


#_model: AutoEncoder = make_model_fn() 
# --------------------------------------------------------------------- #
# VAE Training Step                                                     #
# --------------------------------------------------------------------- #

def _get_xs_and_targs(batch: Dict[str, Tuple[torch.Tensor, ...]], batch_idx: Optional[int] = None) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
    """Extract the inputs and the targets from a batch.

    Args:
        batch: mapping that must contain 'x_targ' and may contain 'x'.
        batch_idx: accepted so that callers passing the batch index keep
            working; it is not used.

    Returns:
        The pair `(xs, xs_targ)`. When 'x' is absent from the batch the targets
        are also returned as the inputs and a warning is issued.

    Raises:
        KeyError: if the batch has no 'x_targ' entry.
    """
    xs_targ = batch['x_targ']
    if 'x' not in batch:
        warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ')
        xs = xs_targ
    else:
        xs = batch['x']
    # check that we have the correct number of inputs, a mismatch is only
    # reported, it does not abort the step
    if (len(xs) != REQUIRED_OBS) or (len(xs_targ) != REQUIRED_OBS):
        log.warning(f'batch len(xs)={len(xs)} and len(xs_targ)={len(xs_targ)} observation count mismatch, requires: {REQUIRED_OBS}')
    # done
    return xs, xs_targ

def do_training_step(batch, batch_idx):
    """Run the VAE forward pass on one batch and return the combined loss and logs.

    Args:
        batch: mapping with an 'x_targ' entry and optionally an 'x' entry, as
            consumed by `_get_xs_and_targs`.
        batch_idx: index of the batch; not used by this function.

    Returns:
        A `(loss, logs)` tuple. `loss` is the sum of the reconstruction,
        augmentation and regularization losses that are not disabled in `cfg`.
        `logs` merges the log dicts of every stage and adds the entries
        'recon_loss', 'reg_loss' and 'aug_loss'.

    NOTE(review): `decode_partial` and `compute_ave_recon_loss` are not defined
    in the visible part of this file -- confirm they are in scope before use.
    """
    # the helper only needs the batch itself
    xs, xs_targ = _get_xs_and_targs(batch)

    # FORWARD
    # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
    # latent distribution parameterizations
    ds_posterior, ds_prior = map_all(encode_dists, xs, collect_returned=True)

    # The two distribution hooks below are deliberately left disabled:
    #   # [HOOK] disable learnt scale values
    #   ds_posterior, ds_prior = _hook_intercept_ds_disable_scale(ds_posterior, ds_prior)
    #   # [HOOK] intercept latent parameterizations
    #   ds_posterior, ds_prior, logs_intercept_ds = hook_intercept_ds(ds_posterior, ds_prior)
    # With the intercept hook disabled there are no logs from it, but the
    # name must still exist for the returned dict.
    logs_intercept_ds = {}

    # sample from dists
    zs_sampled = tuple(d.rsample() for d in ds_posterior)
    # reconstruct without the final activation
    xs_partial_recon = map_all(decode_partial, detach_all(zs_sampled, if_=cfg.disable_decoder))
    # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #

    # LOSS
    # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
    # compute all the recon losses
    recon_loss, logs_recon = compute_ave_recon_loss(xs_partial_recon, xs_targ)
    # compute all the regularization losses
    reg_loss, logs_reg = compute_ave_reg_loss(ds_posterior, ds_prior, zs_sampled)
    # [HOOK] augment loss
    aug_loss, logs_aug = hook_compute_ave_aug_loss(
        ds_posterior=ds_posterior,
        ds_prior=ds_prior,
        zs_sampled=zs_sampled,
        xs_partial_recon=xs_partial_recon,
        xs_targ=xs_targ,
    )
    # compute combined loss, skipping every term that is disabled in the config
    loss = 0
    if not cfg.disable_rec_loss:
        loss += recon_loss
    if not cfg.disable_aug_loss:
        loss += aug_loss
    if not cfg.disable_reg_loss:
        loss += reg_loss
    # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #

    # return values
    return loss, {
        **logs_intercept_ds,
        **logs_recon,
        **logs_reg,
        **logs_aug,
        'recon_loss': recon_loss,
        'reg_loss': reg_loss,
        'aug_loss': aug_loss,
    }

# --------------------------------------------------------------------- #
# Delete AE Hooks                                                       #
# --------------------------------------------------------------------- #


def hook_ae_intercept_zs(zs: Sequence[torch.Tensor]) -> Tuple[Sequence[torch.Tensor], Dict[str, Any]]:
    """AE hook that is disabled here: calling it always fails.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = 'This function should never be used or overridden by VAE methods!'
    raise NotImplementedError(message)  # pragma: no cover


def hook_ae_compute_ave_aug_loss(zs: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
    """AE augment-loss hook that is disabled here: calling it always fails.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = 'This function should never be used or overridden by VAE methods!'
    raise NotImplementedError(message)  # pragma: no cover

# --------------------------------------------------------------------- #
# Private Hooks                                                         #
# --------------------------------------------------------------------- #

def _hook_intercept_ds_disable_scale(ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]):
    """Optionally overwrite the scale of every posterior with a constant.

    When `cfg.disable_posterior_scale` is None nothing is changed. Otherwise
    the `scale` of each posterior is replaced, in place, by a tensor of the
    same shape filled with that value. The priors are never modified.

    Returns:
        The `(ds_posterior, ds_prior)` pair that was passed in.
    """
    fixed_scale = cfg.disable_posterior_scale
    # nothing to disable, hand everything back untouched
    if fixed_scale is None:
        return ds_posterior, ds_prior
    for posterior in ds_posterior:
        # only these distribution types are accepted here
        assert isinstance(posterior, (Normal, Laplace))
        posterior.scale = torch.full_like(posterior.scale, fill_value=fixed_scale)
    # return modified values
    return ds_posterior, ds_prior

# --------------------------------------------------------------------- #
# Overrideable Hooks                                                    #
# --------------------------------------------------------------------- #

def hook_intercept_ds(ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]) -> Tuple[Sequence[Distribution], Sequence[Distribution], Dict[str, Any]]:
    """Default intercept hook: pass the distributions through unchanged.

    Returns:
        The given posteriors, the given priors, and an empty dict of logs.
    """
    logs: Dict[str, Any] = {}
    return ds_posterior, ds_prior, logs

def hook_compute_ave_aug_loss(ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution], zs_sampled: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
    """Default augment-loss hook: contributes no loss and no logs.

    Returns:
        The pair `(0, {})`, whatever the arguments are.
    """
    no_loss, no_logs = 0, {}
    return no_loss, no_logs

def compute_ave_reg_loss(ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution], zs_sampled: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
    """Compute the regularization loss via the module-level `latents_handler`.

    Returns:
        The value returned by `latents_handler.compute_ave_kl_loss(...)` and a
        logs dict holding that same value under 'kl_loss'.
    """
    # regularization term (kl divergence), delegated to the latents handler
    kl_loss = latents_handler.compute_ave_kl_loss(ds_posterior, ds_prior, zs_sampled)
    logs = dict(kl_loss=kl_loss)
    return kl_loss, logs

# --------------------------------------------------------------------- #
# VAE - Encoding - Overrides AE                                         #
# --------------------------------------------------------------------- #

def encode(x: torch.Tensor) -> torch.Tensor:
    """Return the deterministic latent representation of `x` (handy for visualisation).

    The input is encoded with the module-level `module` and the result is
    converted by `latents_handler.encoding_to_representation`.

    NOTE(review): presumably `module.encode` is expected to return the raw
    encoding here -- confirm against the object bound to `module`.
    """
    return latents_handler.encoding_to_representation(module.encode(x))

# framework configuration shared by the free functions in this cell
cfg = BetaVae.cfg(
    beta=0.003,
    loss_reduction='mean',
)
# handler built from the latent-distribution settings of the config above
latents_handler = make_latent_distribution(
    cfg.latent_distribution,
    kl_mode=cfg.kl_loss_mode,
    reduction=cfg.loss_reduction,
)


def encode_dists(x: torch.Tensor) -> Tuple[Distribution, Distribution]:
    """Return the latent distribution parametrisations that are sampled from during training.

    The input is encoded with the module-level `module` and the result is
    turned into a posterior and a prior by `latents_handler.encoding_to_dists`.

    Returns:
        The pair `(posterior, prior)`.
    """
    posterior, prior = latents_handler.encoding_to_dists(module.encode(x))
    return posterior, prior



# In [89]:
from sklearn import linear_model
from disent.dataset.groundtruth import GroundTruthDatasetTriples
from disent.dataset.groundtruth import GroundTruthDistDataset
from disent.metrics._flatness import get_device
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.data.groundtruth import XYObjectData, XYSquaresData
from disent.frameworks.vae import BetaVae
from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
from disent.transform import ToStandardisedTensor
from disent.util import colors
from disent.util import Timer

def get_str(r):
    """Format a mapping of numbers as 'key=value' pairs joined by ', '.

    Every value is rendered with the format spec `6.4f`.
    """
    parts = []
    for key, value in r.items():
        parts.append('{}={:6.4f}'.format(key, value))
    return ', '.join(parts)

def print_r(name, steps, result, clr=colors.lYLW, t: Timer = None):
    """Print one coloured result line.

    Args:
        name: label, left-aligned in a field of width 13.
        steps: step count, formatted with the spec `>04`.
        result: mapping of numbers, rendered by `get_str`.
        clr: colour prefix for the line.
        t: optional timer; when truthy its `pretty` attribute is shown in
            grey inside square brackets.
    """
    # optional timing section, empty when no (truthy) timer was given
    timing = f" {colors.GRY}[{t.pretty}]{clr}" if t else ""
    print(f'{clr}{name:<13} ({steps:>04}){timing}: {get_str(result)}{colors.RST}')

def calculate(name, steps, dataset, get_repr):
    """Stub of the metric computation: only prints `get_repr`.

    The metric evaluation itself is commented out below, so `name`, `steps`
    and `dataset` are unused at the moment and nothing is returned.
    """
    print(get_repr)
    # NOTE: disabled metric computation, kept for reference
    #global aggregate_measure_distances_along_factor
    #with Timer() as t:
    #    r = {
    #    #**metric_flatness_components(dataset, get_repr, factor_repeats=64, batch_size=64),
    #    #    **metric_flatness(dataset, get_repr, factor_repeats=64, batch_size=64),
    #    }
    #results.append((name, steps, r))
    #print_r(name, steps, r, colors.lRED, t=t)
    #print(colors.GRY, '='*100, colors.RST, sep='')
    #return r
    return None
    

class XYOverlapData(XYSquaresData):
    """`XYSquaresData` whose default grid spacing is derived from the square size.

    When `grid_spacing` is not given it is set to `(square_size + 1) // 2`;
    all other arguments are forwarded to the parent class unchanged.
    NOTE(review): presumably the smaller spacing lets neighbouring squares
    overlap, as the class name suggests -- confirm against `XYSquaresData`.
    """

    def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3, rgb=True):
        spacing = (square_size + 1) // 2 if grid_spacing is None else grid_spacing
        super().__init__(
            square_size=square_size,
            grid_size=grid_size,
            grid_spacing=spacing,
            num_squares=num_squares,
            rgb=rgb,
        )

####################
#### train. ########
####################

# accumulator for metric results (only referenced by the disabled code in `calculate`)
results = []

# data pipeline: ground-truth data -> dataset -> shuffled batches of 32
data = XYSquaresData()
dataset = GroundTruthDistDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
)

# framework: BetaVae with a conv encoder/decoder pair and 6 latent dimensions
cfg = BetaVae.cfg(beta=0.003, loss_reduction='mean')
module = BetaVae(
    make_optimizer_fn=lambda params: Adam(params, lr=5e-4),
    make_model_fn=lambda: AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=cfg,
)

latents_handler = make_latent_distribution(
    cfg.latent_distribution,
    kl_mode=cfg.kl_loss_mode,
    reduction=cfg.loss_reduction,
)

# take a single batch and look at the distributions the framework produces for it
datum = next(iter(dataloader))
(x,) = datum['x_targ']
z_posterior, z_prior = module.encode_dists(x)
z_posterior, z_prior

# Out:
# (Normal(loc: torch.Size([32, 6]), scale: torch.Size([32, 6])),
#  Normal(loc: torch.Size([32, 6]), scale: torch.Size([32, 6])))

# In [101]:
# Draw one sample from the prior obtained from `module.encode_dists(x)`;
# as the last expression of the cell its value is displayed by the notebook.
z_prior.sample()

tensor([[-0.6505, -0.9901,  0.4339,  0.5095,  0.8773, -1.7774],
        [ 0.4489,  1.8013,  0.7695, -0.3043,  0.4200,  1.4606],
        [ 1.9153,  1.0503, -2.0272,  0.6760,  1.3001, -1.1350],
        [ 0.7543,  0.3628,  1.0625, -1.5774,  0.1372, -0.9330],
        [ 0.0245, -0.0343, -0.5464, -1.5049, -1.2338,  0.0360],
        [-0.0178,  0.0234, -0.2680,  0.0583,  0.2498,  1.1666],
        [-0.6733, -2.5286,  1.2362, -0.7957, -0.1021,  0.3970],
        [-0.0639,  0.6298,  1.4576,  0.1052, -0.3022, -0.3363],
        [ 1.4955, -0.0432, -2.7151, -0.5484,  0.1370,  0.2457],
        [-0.8257, -1.3647, -0.9488, -0.4407, -0.2710,  0.6412],
        [-1.2137,  0.3191,  0.1291, -0.8512,  0.7040,  1.0466],
        [-0.6759,  0.5157,  1.7835,  0.0582,  0.2333, -0.6896],
        [-1.0725,  1.8992,  0.5191, -0.9337,  1.7056,  0.6230],
        [ 1.5748,  2.1236,  0.2903,  0.5397, -0.8111,  0.3532],
        [ 0.5568,  0.4835, -1.0068, -1.3208,  0.6463, -1.5340],
        [ 0.2054,  0.3039, -0.0127, -0.3