Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hackable_diffusion/kdiff/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from kauldron import kd

################################################################################
# MARK: Type aliases
# MARK: Type Aliases
################################################################################

Array = hd_typing.Array
Expand Down
2 changes: 1 addition & 1 deletion hackable_diffusion/kdiff/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


################################################################################
# MARK: Type aliases
# MARK: Type Aliases
################################################################################

Array = hd_typing.Array
Expand Down
2 changes: 1 addition & 1 deletion hackable_diffusion/lib/architecture/arch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SkipConnectionMethod = arch_typing.SkipConnectionMethod

################################################################################
# MARK: Reusable NN components
# MARK: Reusable NN Components
################################################################################

kernel_init = nn.initializers.lecun_normal()
Expand Down
6 changes: 3 additions & 3 deletions hackable_diffusion/lib/architecture/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


################################################################################
# MARK: Type aliases
# MARK: Type Aliases
################################################################################

Float = hd_typing.Float
Expand All @@ -45,7 +45,7 @@
MASK_LOGITS_VALUE = -1e9

################################################################################
# MARK: Attention utilities
# MARK: Attention Utilities
################################################################################


Expand Down Expand Up @@ -166,7 +166,7 @@ def _dot_product_attention(


################################################################################
# MARK: Multi-head attention
# MARK: Multi-Head Attention
################################################################################


Expand Down
8 changes: 4 additions & 4 deletions hackable_diffusion/lib/architecture/conditioning_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod

################################################################################
# MARK: Base classes
# MARK: Base Classes
################################################################################


Expand Down Expand Up @@ -81,7 +81,7 @@ def __call__(


################################################################################
# MARK: Time embedders
# MARK: Time Embedders
################################################################################


Expand Down Expand Up @@ -162,7 +162,7 @@ def __call__(self, time: hd_typing.TimeArray) -> hd_typing.TimeArray:


################################################################################
# MARK: Conditioning embedders
# MARK: Conditioning Embedders
################################################################################


Expand Down Expand Up @@ -338,7 +338,7 @@ def __call__(


################################################################################
# MARK: Process and combine time and conditioning signals
# MARK: Process and Combine Time and Conditioning Signals
################################################################################


Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/architecture/sequence_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


################################################################################
# MARK: Type aliases
# MARK: Type Aliases
################################################################################

Float = hd_typing.Float
Expand All @@ -36,7 +36,7 @@
RoPEPositionType = arch_typing.RoPEPositionType

################################################################################
# MARK: Sequence embedding modules
# MARK: Sequence Embedding Modules
################################################################################


Expand Down
2 changes: 1 addition & 1 deletion hackable_diffusion/lib/architecture/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import kauldron.ktyping as kt

################################################################################
# MARK: Common types
# MARK: Type Aliases
################################################################################

DType = hd_typing.DType
Expand Down
8 changes: 4 additions & 4 deletions hackable_diffusion/lib/architecture/unet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import kauldron.ktyping as kt

################################################################################
# MARK: Common types and aliases
# MARK: Type Aliases
################################################################################

DType = hd_typing.DType
Expand All @@ -49,7 +49,7 @@
DownsampleOutput = Float["batch height/2 width/2 output_channels"]

################################################################################
# MARK: Input and Output blocks
# MARK: Input and Output Blocks
################################################################################


Expand Down Expand Up @@ -121,7 +121,7 @@ def __call__(self, x: BaseInput) -> BaseOutput:


################################################################################
# MARK: Residual block with optional resampling
# MARK: Residual Block With Optional Resampling
################################################################################


Expand Down Expand Up @@ -218,7 +218,7 @@ def __call__(


################################################################################
# MARK: Attention residual block
# MARK: Attention Residual Block
################################################################################


Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/corruption/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class SamplingPrecisionMode(enum.StrEnum):


################################################################################
# MARK: Projection functions
# MARK: Projection Functions
################################################################################


Expand Down Expand Up @@ -333,7 +333,7 @@ def get_schedule_info(self, time: TimeArray) -> dict[str, TimeArray]:
return self.schedule.evaluate(time)

##############################################################################
# MARK: Factory methods
# MARK: Factory Methods
##############################################################################

@classmethod
Expand Down
12 changes: 6 additions & 6 deletions hackable_diffusion/lib/corruption/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _get_alpha_sigma_and_der(
################################################################################

################################################################################
# MARK: convert from x0
# MARK: Convert From x0
################################################################################


Expand All @@ -296,7 +296,7 @@ def x0_to_v(x0, xt, alpha, sigma, alpha_der, sigma_der):


################################################################################
# MARK: convert from epsilon
# MARK: Convert From Epsilon
################################################################################


Expand All @@ -322,7 +322,7 @@ def epsilon_to_v(epsilon, xt, alpha, sigma, alpha_der, sigma_der):


################################################################################
# MARK: convert from score
# MARK: Convert From Score
################################################################################


Expand Down Expand Up @@ -354,7 +354,7 @@ def score_to_v(score, xt, alpha, sigma, alpha_der, sigma_der):


################################################################################
# MARK: convert from velocity
# MARK: Convert From Velocity
################################################################################


Expand Down Expand Up @@ -396,7 +396,7 @@ def velocity_to_v(velocity, xt, alpha, sigma, alpha_der, sigma_der):


################################################################################
# MARK: convert from v
# MARK: Convert From v
################################################################################


Expand Down Expand Up @@ -443,7 +443,7 @@ def v_to_velocity(v, xt, alpha, sigma, alpha_der, sigma_der):


################################################################################
# MARK: helpers
# MARK: Helpers
################################################################################


Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/corruption/simplicial.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class SamplingPrecisionMode(enum.StrEnum):


################################################################################
# MARK: Post-corruption functions
# MARK: Post-Corruption Functions
################################################################################


Expand Down Expand Up @@ -356,7 +356,7 @@ def get_schedule_info(self, time: TimeArray) -> dict[str, TimeArray]:
return self.schedule.evaluate(time)

##############################################################################
# MARK: Factory methods
# MARK: Factory Methods
##############################################################################

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/inference/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __call__(


################################################################################
# MARK: Helper functions
# MARK: Helper Functions
################################################################################


Expand Down Expand Up @@ -129,7 +129,7 @@ def _from_x0(


################################################################################
# MARK: Projection functions
# MARK: Projection Functions
################################################################################


Expand Down
6 changes: 3 additions & 3 deletions hackable_diffusion/lib/manifolds.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
EPSILON = 1e-9

################################################################################
# MARK: Utility functions
# MARK: Utility Functions
################################################################################


Expand Down Expand Up @@ -94,7 +94,7 @@ def transpose(x: DataArray) -> DataArray:


################################################################################
# MARK: Base class
# MARK: Base Class
################################################################################


Expand Down Expand Up @@ -166,7 +166,7 @@ def velocity(


################################################################################
# MARK: Common manifold methods
# MARK: Common Manifold Methods
################################################################################


Expand Down
2 changes: 1 addition & 1 deletion hackable_diffusion/lib/sampling/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __call__(


################################################################################
# MARK: Helper functions
# MARK: Helper Functions
################################################################################


Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/sampling/sampling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
SamplerStep = base.SamplerStep

################################################################################
# MARK: Helper functions
# MARK: Helper Functions
################################################################################

dummy_inference_fn = lambda xt, conditioning, time: {'x0': xt}
Expand Down Expand Up @@ -76,7 +76,7 @@ def finalize(self, prediction, current_step, next_step_info):

class DiffusionSamplingTest(parameterized.TestCase):

# MARK: Test for helper functions
# MARK: Test for Helper Functions

def setUp(self):
super().setUp()
Expand Down
6 changes: 5 additions & 1 deletion hackable_diffusion/lib/sampling/time_scheduling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

class UniformTimeScheduleTest(absltest.TestCase):

def test_all_step_infos(self):
# MARK: UniformTimeSchedule Tests

def test_uniform_all_step_infos(self):
time_schedule = time_scheduling.UniformTimeSchedule(
span=utils.SafeSpan(safety_epsilon=0.1)
)
Expand Down Expand Up @@ -91,6 +93,8 @@ def test_all_step_infos_without_safety_epsilon(self):
expected,
)

# MARK: EDMTimeSchedule Tests


class EDMTimeScheduleTest(parameterized.TestCase):

Expand Down
Loading