attend and excite tests disable determinism on the class level #3478

Merged
```diff
@@ -34,7 +34,6 @@


 torch.backends.cuda.matmul.allow_tf32 = False
-torch.use_deterministic_algorithms(False)
```
A comment from the PR author on the removed line:

> This will only get run once, when the module is loaded, and then determinism will get turned back on later.
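To make the fixture ordering concrete, here is a minimal hypothetical sketch (not taken from the PR): `unittest` calls `setUpClass` once before the first test of a class and `tearDownClass` once after its last test, so determinism is disabled only while that class's tests run and is restored for the rest of the suite.

```python
import unittest

import torch


class ExampleTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Runs once, before the first test in this class.
        torch.use_deterministic_algorithms(False)

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        # Runs once, after the last test in this class, restoring
        # determinism for whatever runs next.
        torch.use_deterministic_algorithms(True)

    def test_determinism_disabled_during_tests(self):
        self.assertFalse(torch.are_deterministic_algorithms_enabled())


if __name__ == "__main__":
    unittest.main()
```

By contrast, a bare module-level `torch.use_deterministic_algorithms(False)` executes only once at import time, so anything that re-enables determinism afterwards silently wins.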



```diff
 @skip_mps
@@ -47,6 +46,19 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
     image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

+    # Attend and excite requires being able to run a backward pass at
+    # inference time. There's no deterministic backward operator for pad
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        torch.use_deterministic_algorithms(False)
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        torch.use_deterministic_algorithms(True)
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -171,6 +183,19 @@ def test_save_load_optional_components(self):
 @require_torch_gpu
 @slow
 class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
+    # Attend and excite requires being able to run a backward pass at
+    # inference time. There's no deterministic backward operator for pad
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        torch.use_deterministic_algorithms(False)
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        torch.use_deterministic_algorithms(True)
+
     def tearDown(self):
         super().tearDown()
         gc.collect()
```
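For background on the "no deterministic backward operator for pad" comment: Attend-and-Excite computes gradients at inference time, and under `torch.use_deterministic_algorithms(True)` PyTorch raises a RuntimeError for ops whose backward kernels lack a deterministic CUDA implementation. A minimal hypothetical repro (not part of the PR; assumes a CUDA device and that the pad in question dispatches to replication padding):

```python
import torch
import torch.nn.functional as F

torch.use_deterministic_algorithms(True)

# Differentiating through replication padding on CUDA is one of the
# documented cases that throws under deterministic mode.
x = torch.randn(1, 3, 8, 8, device="cuda", requires_grad=True)
y = F.pad(x, (1, 1, 1, 1), mode="replicate")

try:
    y.sum().backward()
except RuntimeError as exc:
    # Expected along the lines of: "replication_pad2d_backward_cuda
    # does not have a deterministic implementation ..."
    print(exc)
```

Disabling determinism in `setUpClass` and restoring it in `tearDownClass` keeps this relaxation scoped to the attend-and-excite test classes only.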