diff --git a/hackable_diffusion/kdiff/core.py b/hackable_diffusion/kdiff/core.py index cd35802..3d0c17c 100644 --- a/hackable_diffusion/kdiff/core.py +++ b/hackable_diffusion/kdiff/core.py @@ -56,7 +56,7 @@ from kauldron import kd ################################################################################ -# MARK: Type aliases +# MARK: Type Aliases ################################################################################ Array = hd_typing.Array diff --git a/hackable_diffusion/kdiff/evals.py b/hackable_diffusion/kdiff/evals.py index 6944fa4..cd15572 100644 --- a/hackable_diffusion/kdiff/evals.py +++ b/hackable_diffusion/kdiff/evals.py @@ -26,7 +26,7 @@ ################################################################################ -# MARK: Type aliases +# MARK: Type Aliases ################################################################################ Array = hd_typing.Array diff --git a/hackable_diffusion/lib/architecture/arch_utils.py b/hackable_diffusion/lib/architecture/arch_utils.py index 2421cc9..be335a3 100644 --- a/hackable_diffusion/lib/architecture/arch_utils.py +++ b/hackable_diffusion/lib/architecture/arch_utils.py @@ -37,7 +37,7 @@ SkipConnectionMethod = arch_typing.SkipConnectionMethod ################################################################################ -# MARK: Reusable NN components +# MARK: Reusable NN Components ################################################################################ kernel_init = nn.initializers.lecun_normal() diff --git a/hackable_diffusion/lib/architecture/attention.py b/hackable_diffusion/lib/architecture/attention.py index ff9ef02..9171f0f 100644 --- a/hackable_diffusion/lib/architecture/attention.py +++ b/hackable_diffusion/lib/architecture/attention.py @@ -27,7 +27,7 @@ ################################################################################ -# MARK: Type aliases +# MARK: Type Aliases ################################################################################ Float = hd_typing.Float @@ -45,7 +45,7 @@ MASK_LOGITS_VALUE = -1e9 ################################################################################ -# MARK: Attention utilities +# MARK: Attention Utilities ################################################################################ @@ -166,7 +166,7 @@ def _dot_product_attention( ################################################################################ -# MARK: Multi-head attention +# MARK: Multi-Head Attention ################################################################################ diff --git a/hackable_diffusion/lib/architecture/conditioning_encoder.py b/hackable_diffusion/lib/architecture/conditioning_encoder.py index dc2d973..d368743 100644 --- a/hackable_diffusion/lib/architecture/conditioning_encoder.py +++ b/hackable_diffusion/lib/architecture/conditioning_encoder.py @@ -44,7 +44,7 @@ EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod ################################################################################ -# MARK: Base classes +# MARK: Base Classes ################################################################################ @@ -81,7 +81,7 @@ def __call__( ################################################################################ -# MARK: Time embedders +# MARK: Time Embedders ################################################################################ @@ -162,7 +162,7 @@ def __call__(self, time: hd_typing.TimeArray) -> hd_typing.TimeArray: ################################################################################ -# MARK: Conditioning embedders +# MARK: Conditioning Embedders ################################################################################ @@ -338,7 +338,7 @@ def __call__( ################################################################################ -# MARK: Process and combine time and conditioning signals +# MARK: Process and Combine Time and Conditioning Signals ################################################################################ diff --git a/hackable_diffusion/lib/architecture/sequence_embedders.py b/hackable_diffusion/lib/architecture/sequence_embedders.py index a6f8b60..f4e9ab6 100644 --- a/hackable_diffusion/lib/architecture/sequence_embedders.py +++ b/hackable_diffusion/lib/architecture/sequence_embedders.py @@ -26,7 +26,7 @@ ################################################################################ -# MARK: Type aliases +# MARK: Type Aliases ################################################################################ Float = hd_typing.Float @@ -36,7 +36,7 @@ RoPEPositionType = arch_typing.RoPEPositionType ################################################################################ -# MARK: Sequence embedding modules +# MARK: Sequence Embedding Modules ################################################################################ diff --git a/hackable_diffusion/lib/architecture/unet.py b/hackable_diffusion/lib/architecture/unet.py index d8159c5..60b4aa0 100644 --- a/hackable_diffusion/lib/architecture/unet.py +++ b/hackable_diffusion/lib/architecture/unet.py @@ -27,7 +27,7 @@ import kauldron.ktyping as kt ################################################################################ -# MARK: Common types +# MARK: Type Aliases ################################################################################ DType = hd_typing.DType diff --git a/hackable_diffusion/lib/architecture/unet_blocks.py b/hackable_diffusion/lib/architecture/unet_blocks.py index f07cac4..8232176 100644 --- a/hackable_diffusion/lib/architecture/unet_blocks.py +++ b/hackable_diffusion/lib/architecture/unet_blocks.py @@ -25,7 +25,7 @@ import kauldron.ktyping as kt ################################################################################ -# MARK: Common types and aliases +# MARK: Type Aliases ################################################################################ DType = hd_typing.DType @@ -49,7 +49,7 @@ DownsampleOutput = Float["batch height/2 width/2 output_channels"] ################################################################################ -# MARK: Input and Output blocks +# MARK: Input and Output Blocks ################################################################################ @@ -121,7 +121,7 @@ def __call__(self, x: BaseInput) -> BaseOutput: ################################################################################ -# MARK: Residual block with optional resampling +# MARK: Residual Block With Optional Resampling ################################################################################ @@ -218,7 +218,7 @@ def __call__( ################################################################################ -# MARK: Attention residual block +# MARK: Attention Residual Block ################################################################################ diff --git a/hackable_diffusion/lib/corruption/discrete.py b/hackable_diffusion/lib/corruption/discrete.py index 8d3c1ac..a1cb4d2 100644 --- a/hackable_diffusion/lib/corruption/discrete.py +++ b/hackable_diffusion/lib/corruption/discrete.py @@ -66,7 +66,7 @@ class SamplingPrecisionMode(enum.StrEnum): ################################################################################ -# MARK: Projection functions +# MARK: Projection Functions ################################################################################ @@ -333,7 +333,7 @@ def get_schedule_info(self, time: TimeArray) -> dict[str, TimeArray]: return self.schedule.evaluate(time) ############################################################################## - # MARK: Factory methods + # MARK: Factory Methods ############################################################################## @classmethod diff --git a/hackable_diffusion/lib/corruption/gaussian.py b/hackable_diffusion/lib/corruption/gaussian.py index 89c199a..3d1c8b1 100644 --- a/hackable_diffusion/lib/corruption/gaussian.py +++ b/hackable_diffusion/lib/corruption/gaussian.py @@ -270,7 +270,7 @@ def _get_alpha_sigma_and_der( ################################################################################ ################################################################################ -# MARK: convert from x0 +# MARK: Convert From x0 ################################################################################ @@ -296,7 +296,7 @@ def x0_to_v(x0, xt, alpha, sigma, alpha_der, sigma_der): ################################################################################ -# MARK: convert from epsilon +# MARK: Convert From Epsilon ################################################################################ @@ -322,7 +322,7 @@ def epsilon_to_v(epsilon, xt, alpha, sigma, alpha_der, sigma_der): ################################################################################ -# MARK: convert from score +# MARK: Convert From Score ################################################################################ @@ -354,7 +354,7 @@ def score_to_v(score, xt, alpha, sigma, alpha_der, sigma_der): ################################################################################ -# MARK: convert from velocity +# MARK: Convert From Velocity ################################################################################ @@ -396,7 +396,7 @@ def velocity_to_v(velocity, xt, alpha, sigma, alpha_der, sigma_der): ################################################################################ -# MARK: convert from v +# MARK: Convert From v ################################################################################ @@ -443,7 +443,7 @@ def v_to_velocity(v, xt, alpha, sigma, alpha_der, sigma_der): ################################################################################ -# MARK: helpers +# MARK: Helpers ################################################################################ diff --git a/hackable_diffusion/lib/corruption/simplicial.py b/hackable_diffusion/lib/corruption/simplicial.py index d532218..42ea4c2 100644 --- a/hackable_diffusion/lib/corruption/simplicial.py +++ b/hackable_diffusion/lib/corruption/simplicial.py @@ -68,7 +68,7 @@ class SamplingPrecisionMode(enum.StrEnum): ################################################################################ -# MARK: Post-corruption functions +# MARK: Post-Corruption Functions ################################################################################ @@ -356,7 +356,7 @@ def get_schedule_info(self, time: TimeArray) -> dict[str, TimeArray]: return self.schedule.evaluate(time) ############################################################################## - # MARK: Factory methods + # MARK: Factory Methods ############################################################################## @classmethod diff --git a/hackable_diffusion/lib/inference/projection.py b/hackable_diffusion/lib/inference/projection.py index 424ce79..feb5ab8 100644 --- a/hackable_diffusion/lib/inference/projection.py +++ b/hackable_diffusion/lib/inference/projection.py @@ -55,7 +55,7 @@ def __call__( ################################################################################ -# MARK: Helper functions +# MARK: Helper Functions ################################################################################ @@ -129,7 +129,7 @@ def _from_x0( ################################################################################ -# MARK: Projection functions +# MARK: Projection Functions ################################################################################ diff --git a/hackable_diffusion/lib/manifolds.py b/hackable_diffusion/lib/manifolds.py index 5a1b7ce..2c2f30d 100644 --- a/hackable_diffusion/lib/manifolds.py +++ b/hackable_diffusion/lib/manifolds.py @@ -36,7 +36,7 @@ EPSILON = 1e-9 ################################################################################ -# MARK: Utility functions +# MARK: Utility Functions ################################################################################ @@ -94,7 +94,7 @@ def transpose(x: DataArray) -> DataArray: ################################################################################ -# MARK: Base class +# MARK: Base Class ################################################################################ @@ -166,7 +166,7 @@ def velocity( ################################################################################ -# MARK: Common manifold methods +# MARK: Common Manifold Methods ################################################################################ diff --git a/hackable_diffusion/lib/sampling/sampling.py b/hackable_diffusion/lib/sampling/sampling.py index 21a2c73..a1ff942 100644 --- a/hackable_diffusion/lib/sampling/sampling.py +++ b/hackable_diffusion/lib/sampling/sampling.py @@ -68,7 +68,7 @@ def __call__( ################################################################################ -# MARK: Helper functions +# MARK: Helper Functions ################################################################################ diff --git a/hackable_diffusion/lib/sampling/sampling_test.py b/hackable_diffusion/lib/sampling/sampling_test.py index b568ace..dfbb445 100644 --- a/hackable_diffusion/lib/sampling/sampling_test.py +++ b/hackable_diffusion/lib/sampling/sampling_test.py @@ -34,7 +34,7 @@ SamplerStep = base.SamplerStep ################################################################################ -# MARK: Helper functions +# MARK: Helper Functions ################################################################################ dummy_inference_fn = lambda xt, conditioning, time: {'x0': xt} @@ -76,7 +76,7 @@ def finalize(self, prediction, current_step, next_step_info): class DiffusionSamplingTest(parameterized.TestCase): - # MARK: Test for helper functions + # MARK: Test for Helper Functions def setUp(self): super().setUp() diff --git a/hackable_diffusion/lib/sampling/time_scheduling_test.py b/hackable_diffusion/lib/sampling/time_scheduling_test.py index c19ffa8..bf3e080 100644 --- a/hackable_diffusion/lib/sampling/time_scheduling_test.py +++ b/hackable_diffusion/lib/sampling/time_scheduling_test.py @@ -30,7 +30,9 @@ class UniformTimeScheduleTest(absltest.TestCase): - def test_all_step_infos(self): + # MARK: UniformTimeSchedule Tests + + def test_uniform_all_step_infos(self): time_schedule = time_scheduling.UniformTimeSchedule( span=utils.SafeSpan(safety_epsilon=0.1) ) @@ -91,6 +93,8 @@ def test_all_step_infos_without_safety_epsilon(self): expected, ) + # MARK: EDMTimeSchedule Tests + class EDMTimeScheduleTest(parameterized.TestCase):