Skip to content
This repository was archived by the owner on Aug 6, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions neuralcompression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@

from importlib.metadata import PackageNotFoundError, version

from ._outputs import (
DiscriminatorOutput,
HyperpriorCompressedOutput,
HyperpriorOutput,
VqVaeAutoencoderOutput,
)

try:
__version__ = version("neuralcompression")
except PackageNotFoundError:
Expand Down
8 changes: 8 additions & 0 deletions neuralcompression/_outputs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ._discriminator_output import DiscriminatorOutput
from ._hyperprior_output import HyperpriorCompressedOutput, HyperpriorOutput
from ._vqvae_autoencoder_output import VqVaeAutoencoderOutput
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ._hific_discriminator import HiFiCDiscriminator
from ._hific_encoder import HiFiCEncoder
from ._hific_generator import HiFiCGenerator
from typing import NamedTuple, Optional

from torch import Tensor


class DiscriminatorOutput(NamedTuple):
logits: Tensor
target: Optional[Tensor] = None
27 changes: 27 additions & 0 deletions neuralcompression/_outputs/_hyperprior_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, NamedTuple, Optional, Tuple, Union

from torch import Tensor


class HyperpriorOutput(NamedTuple):
image: Tensor
latent: Tensor
latent_likelihoods: Tensor
quantized_latent_likelihoods: Tensor
hyper_latent: Tensor
hyper_latent_likelihoods: Tensor
quantized_hyper_latent_likelihoods: Tensor
quantized_latent: Optional[Tensor] = None
quantized_hyper_latent: Optional[Tensor] = None


class HyperpriorCompressedOutput(NamedTuple):
latent_strings: Union[List[str], List[List[str]]]
hyper_latent_strings: List[str]
image_size: Tuple[int, int]
padded_size: Tuple[int, int]
23 changes: 23 additions & 0 deletions neuralcompression/_outputs/_vqvae_autoencoder_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence, Union

from torch import Tensor


@dataclass
class VqVaeAutoencoderOutput:
image: Optional[Tensor] = None
latent: Optional[Tensor] = None
prequantized_latent: Optional[Union[Tensor, Sequence[Tensor]]] = None
commitment_loss: Optional[Tensor] = None
embedding_loss: Optional[Tensor] = None
codebook_indices: Optional[Union[Tensor, Sequence[Tensor]]] = None
quantize_residuals: Optional[Tensor] = None
num_bytes: Optional[int] = None
quantize_distances: Optional[Tensor] = None
indices: Optional[Tensor] = None
1 change: 1 addition & 0 deletions neuralcompression/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ._multiscale_structural_similarity import multiscale_structural_similarity
from ._ndtr import ndtr
from ._optical_flow_to_color import optical_flow_to_color
from ._pad_image import pad_image_to_factor
from ._quantization_offset import quantization_offset
from ._soft_round import soft_round
from ._soft_round_conditional_mean import soft_round_conditional_mean
Expand Down
37 changes: 37 additions & 0 deletions neuralcompression/functional/_pad_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch.nn.functional as F
from torch import Tensor


def pad_image_to_factor(
image: Tensor, factor: int, mode: str = "reflect"
) -> Tuple[Tensor, Tuple[int, int]]:
"""
Pads an image if it is not divisible by factor.

For many neural autoencoders, performance suffers if the input image is not
divisible by the downsampling factor. This utility function can be used to
pad the input image with reflection padding to avoid such cases.

Args:
image: A 4-D PyTorch tensor with dimensions (B, C, H, W)
factor: A factor by which the output image should be divisible.

Returns:
The image padded so that its dimensions are disible by factor, as well
as the height and width.
"""
# pad image if it's not divisible by downsamples
_, _, height, width = image.shape
pad_height = (factor - (height % factor)) % factor
pad_width = (factor - (width % factor)) % factor
if pad_height != 0 or pad_width != 0:
image = F.pad(image, (0, pad_width, 0, pad_height), mode=mode)

return image, (height, width)
10 changes: 2 additions & 8 deletions neuralcompression/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ._analysis_transformation_2d import AnalysisTransformation2D
from ._channel_norm import ChannelNorm2D
from ._continuous_entropy import ContinuousEntropy
from ._generalized_divisive_normalization import GeneralizedDivisiveNormalization
from ._hyper_analysis_transformation_2d import HyperAnalysisTransformation2D
from ._hyper_synthesis_transformation_2d import HyperSynthesisTransformation2D
from ._non_negative_parameterization import NonNegativeParameterization
from ._rate_mse_distortion_loss import RateMSEDistortionLoss
from ._synthesis_transformation_2d import SynthesisTransformation2D
from .gdn import SimplifiedGDN, SimplifiedInverseGDN
from ._simplified_gdn import SimplifiedGDN, SimplifiedInverseGDN
73 changes: 0 additions & 73 deletions neuralcompression/layers/_analysis_transformation_2d.py

This file was deleted.

54 changes: 54 additions & 0 deletions neuralcompression/layers/_channel_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor


class ChannelNorm2D(nn.Module):
"""
Channel normalization layer.

This implements the channel normalization layer as described in the
following paper:

High-Fidelity Generative Image Compression
F. Mentzer, G. Toderici, M. Tschannen, E. Agustsson

Using this layer provides more stability to model outputs when there is a
shift in image resolutions between the training set and the test set.

Args:
input_channels: Number of channels to normalize.
epsilon: Divide-by-0 protection parameter.
affine: Whether to include affine parameters for the noramlized output.
"""

def __init__(self, input_channels: int, epsilon: float = 1e-3, affine: bool = True):
super().__init__()

if input_channels <= 1:
raise ValueError(
"ChannelNorm only valid for channel counts greater than 1."
)

self.epsilon = epsilon
self.affine = affine

if affine is True:
self.gamma = nn.Parameter(torch.ones(1, input_channels, 1, 1))
self.beta = nn.Parameter(torch.zeros(1, input_channels, 1, 1))

def forward(self, x: Tensor) -> Tensor:
mean = torch.mean(x, dim=1, keepdim=True)
variance = torch.var(x, dim=1, keepdim=True)

x_normed = (x - mean) * torch.rsqrt(variance + self.epsilon)

if self.affine is True:
x_normed = self.gamma * x_normed + self.beta

return x_normed
Loading