
Textual Inversion Training on M1 (works!) #517

Closed
tmm1 opened this issue Sep 12, 2022 · 205 comments
Labels
enhancement New feature or request

Comments

@tmm1

tmm1 commented Sep 12, 2022

WIP HERE: development...tmm1:dev-train-m1


I started experimenting with running main.py on M1 and wanted to document some immediate issues.

Looks like we need a newer pytorch-lightning for MPS. Currently using 1.6.5, but the latest is 1.7.5.

However, bumping it causes this error:

AttributeError: module 'pytorch_lightning.loggers' has no attribute 'TestTubeLogger'. Did you mean: 'NeptuneLogger'?

which is because TestTubeLogger was deprecated: Lightning-AI/pytorch-lightning#13958 (comment)
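
For anyone bumping Lightning and hitting this, here is a minimal sketch (not the fork's actual wiring) of swapping the removed TestTubeLogger for one of the loggers that still ships with Lightning 1.7, e.g. CSVLogger:

from pytorch_lightning.loggers import CSVLogger

# Hypothetical stand-in for the TestTubeLogger entry built in main.py;
# CSVLogger (or TensorBoardLogger) covers the basic metric logging without
# the removed test-tube dependency.
logger = CSVLogger(save_dir="logs", name="textual_inversion")
trainer_kwargs["logger"] = logger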

@lstein
Collaborator

lstein commented Sep 12, 2022

I started working with the training functionality last night as well and ran into problems on CUDA. The textual inversion modifications to ddpm.py seem to have adversely affected vanilla training and we'll have to do a careful comparison with the original CompVis implementation in order to isolate the conflicts.

@tmm1, have you tried main.py on M1 using any of the other (multitudinous) forks? If so, any success?

@tmm1
Author

tmm1 commented Sep 12, 2022

If there's a fork that advertises M1 training support I would be happy to try it. I have not seen one, but I have not looked much either. My understanding was that most of the M1 work was happening here.

@lstein
Collaborator

lstein commented Sep 12, 2022

@Any-Winter-4079, when you've finished the latest round of code tweaking, could you have a look at training? It seems to be messed up on M1.

@tmm1
Author

tmm1 commented Sep 13, 2022

I made some progress today and was able to get through all the setup and start training, with commands actually being sent to the mps backend: development...tmm1:dev-train-m1

Currently stuck here:

...
Global seed set to 23
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

Summoning checkpoint.
Traceback (most recent call last):
  File "/Users/tmm1/code/stable-diffusion/./main.py", line 946, in <module>
    trainer.fit(model, data)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 648, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1147, in _run
    self.strategy.setup(self)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 161, in setup
    self._share_information_to_prevent_deadlock()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 396, in _share_information_to_prevent_deadlock
    self._share_pids()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 414, in _share_pids
    pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/parallel.py", line 113, in all_gather
    return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/utilities/distributed.py", line 219, in all_gather_ddp_if_available
    return AllGatherGrad.apply(tensor, group)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/utilities/distributed.py", line 187, in forward
    torch.distributed.all_gather(gathered_tensor, tensor, group=group)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2070, in all_gather
    work = group.allgather([tensor_list], [tensor])
RuntimeError: ProcessGroupGloo::allgather: unsupported device type mps

I'm looking to see if there's a way to turn off the distributed_backend=gloo.

@tmm1
Author

tmm1 commented Sep 13, 2022

Made some more progress by changing the strategy from ddp to dp (https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html)

However, it seems the ImageLogger callback, which was using testtube, is doing something important, and switching to dp + CSVLogger is not achieving the same result.

TypeError: LatentDiffusion.on_train_batch_start() missing 1 required positional argument: 'dataloader_idx'

EDIT: Found solution in Lightning-AI/pytorch-lightning#10315
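
For reference, a minimal sketch of the signature change that thread points at, assuming ddpm.py still overrides the hook with the pre-1.7 Lightning signature (illustrative, not the exact code):

# Old LightningModule override in ddpm.py (Lightning <= 1.6 passed dataloader_idx):
#     def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
#
# Lightning 1.7 calls the hook without dataloader_idx, so the override has to
# drop the parameter (or give it a default) to match what the trainer passes:
def on_train_batch_start(self, batch, batch_idx):
    # ...existing body unchanged...
    pass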

@tmm1
Author

tmm1 commented Sep 13, 2022

It is training!

Epoch 0: 11%|█▎ | 44/404 [03:59<32:37, 5.44s/it, loss=0.0784, v_num=0, train/loss_simple_step=0.00508, train/loss_vlb_step=2.81e-5, train/loss_step=0.00508, global_step=43.00]

I recall, though, others saying training was broken on CUDA too in this fork, so I'm not sure if this is actually working or just appearing to. But a lot of the blockers are solved and we can get into the guts of the implementation now.

EDIT: I am not seeing any of the warnings mentioned on the CUDA thread (related to batch_size)

@tmm1
Author

tmm1 commented Sep 13, 2022

Died at the end.

Everything turned to nan at some point; I don't know if that's a bad sign.

I will try to make outputs optional and exit training early to see if it works.

Epoch 0: 100%|████████████████████| 404/404 [27:34<00:00,  4.10s/it, loss=nan, v_num=0, train/loss_simple_step=nan.0, train/loss_vlb_step=nan.0, train/loss_step=nan.0, global_step=399.0]


Summoning checkpoint.
Traceback (most recent call last):
  File "/Users/tmm1/code/stable-diffusion/./main.py", line 946, in <module>
    trainer.fit(model, data)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 201, in run
    self.on_advance_end()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 299, in on_advance_end
    self.trainer._call_callback_hooks("on_train_epoch_end")
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
TypeError: CUDACallback.on_train_epoch_end() missing 1 required positional argument: 'outputs'
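
Same family of breakage as the hook above: Lightning 1.7 stopped passing outputs to this callback hook. A minimal sketch of the change to main.py's CUDACallback, assuming the upstream CompVis definition:

# Before (old Callback hook signature in main.py):
#     def on_train_epoch_end(self, trainer, pl_module, outputs):
#
# Lightning 1.7 calls Callback.on_train_epoch_end(trainer, pl_module), so the
# extra parameter needs to go; the body can stay as it was.
def on_train_epoch_end(self, trainer, pl_module):
    ...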

@tmm1
Author

tmm1 commented Sep 13, 2022

Hmm, got past the last error, but there's a new one now:

Epoch 0: 100%|███████████| 404/404 [16:24<00:00,  2.44s/it, loss=0.0835, v_num=0, train/loss_simple_step=0.00323, train/loss_vlb_step=1.88e-5, train/loss_step=0.00323, global_step=399.0/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:2075: LightningDeprecationWarning: `Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. Please use `Trainer.strategy.root_device.index` instead.
  rank_zero_deprecation(
Summoning checkpoint.
Traceback (most recent call last):
  File "/Users/tmm1/code/stable-diffusion/./main.py", line 946, in <module>
    trainer.fit(model, data)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 201, in run
    self.on_advance_end()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 299, in on_advance_end
    self.trainer._call_callback_hooks("on_train_epoch_end")
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/main.py", line 558, in on_train_epoch_end
    torch.cuda.synchronize(trainer.root_gpu)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/cuda/__init__.py", line 494, in synchronize
    _lazy_init()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/cuda/__init__.py", line 211, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
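
For anyone hitting this later: a minimal sketch of making that callback device-tolerant, assuming the stock CUDACallback from main.py, so the epoch timing still runs on MPS/CPU and the CUDA-only calls are skipped:

import time
import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_info

class CUDACallback(Callback):
    # Hypothetical MPS/CPU-tolerant variant of the callback in main.py:
    # epoch timing works everywhere, memory stats are reported only on CUDA.
    def on_train_epoch_start(self, trainer, pl_module):
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
            torch.cuda.synchronize(trainer.strategy.root_device.index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        if torch.cuda.is_available():
            torch.cuda.synchronize(trainer.strategy.root_device.index)
            max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
        else:
            max_memory = 0.0
        epoch_time = time.time() - self.start_time
        rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")

This also uses trainer.strategy.root_device.index, which is what the deprecation warning above asks for instead of trainer.root_gpu.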

@tmm1
Author

tmm1 commented Sep 13, 2022

Okay, now it's able to move on to epoch 1!

Average Epoch time: 472.13 seconds
Average Peak memory 0.00MiB
Epoch 1:  16%|█▌        | 63/404 [01:23<07:33,  1.33s/it, loss=0.0875, v_num=0, train/loss_simple_step=0.204, train/loss_vlb_step=0.000971, train/loss_step=0.204, global_step=462.0, train/loss_simple_epoch=0.108, train/loss_vlb_epoch=0.00118, train/loss_epoch=0.108]

@tmm1
Author

tmm1 commented Sep 13, 2022

I have a checkpoints/embeddings_gs-1600.pt now, but when I try using it, the output images are black :(

@tmm1
Author

tmm1 commented Sep 13, 2022

I started fresh, and by epoch 2 everything turns to nan. I think that is what's causing the black images?

Epoch 2: 26%|▎| 104/404 [02:11<06:20, 1.27s/it, loss=nan, v_num=0, train/loss_simple_step=nan.0, train/loss_vlb_step=nan.0, train/loss_step=nan.

cc @magnusviri @Birch-san

@Birch-san

when I encountered black images with k-diffusion sampler, it was due to this problem (with ±Inf):
pytorch/pytorch#84364

fix was just to detach and clone the tensor:
crowsonkb/k-diffusion@3e976ef

if you're having NaN (rather than ±Inf), maybe that's unrelated.

I recommend narrowing down which line first introduces NaN; you can use this check to do so:

mycooltensor.isnan().any()
# returns a boolean
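
Building on that, a small sketch (a generic nn.Module helper, not code from this repo) that registers forward hooks on every submodule and prints each one whose output goes non-finite; the first line printed is the earliest offender, which narrows things down without hand-placing isnan checks. Gradients that only go NaN in the backward pass need autograd's anomaly detection instead, which comes up further down.

import torch

def report_nonfinite_outputs(model: torch.nn.Module):
    """Attach forward hooks that print every submodule producing NaN/Inf outputs."""
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outs:
                if torch.is_tensor(t) and not torch.isfinite(t).all():
                    print(f"non-finite output in {name} ({module.__class__.__name__})")
                    break
        return hook
    # return the handles so the hooks can be .remove()d once the culprit is found
    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]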

@tmm1
Author

tmm1 commented Sep 13, 2022

Thanks @Birch-san! I see this warning at the start of training which may be related.

/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/module.py:555: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
value = torch.tensor(value, device=self.device)

@Any-Winter-4079
Contributor

This looks interesting. I will have a look.

@tmm1
Author

tmm1 commented Sep 13, 2022

Thanks @Any-Winter-4079! You could use the ugly-sonic training samples along with the instructions in TEXTUAL_INVERSION.md

I am going to try detect_anomaly=True as recommended on Lightning-AI/pytorch-lightning#12137 (reply in thread)

@tmm1
Author

tmm1 commented Sep 13, 2022

Caught something:

Sanity Checking: 0it [00:00, ?it/s]/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:225: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Sanity Checking DataLoader 0:   0%|                                                                                         | 0/2 [00:00<?, ?it/s]/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/overrides/data_parallel.py:100: UserWarning: Could not determine on which device the inputs are. When using DataParallel (strategy='dp'), be aware that in case you are using self.device in your code, it will reference only the root device.
  rank_zero_warn(
/Users/tmm1/code/stable-diffusion/ldm/modules/embedding_manager.py:153: UserWarning: The operator 'aten::nonzero' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at  /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1659484612588/work/aten/src/ATen/mps/MPSFallback.mm:11.)
  placeholder_idx = torch.where(
/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/module.py:555: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  value = torch.tensor(value, device=self.device)
/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:225: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|                                                                                                            | 0/404 [00:00<?, ?it/s]/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:231: UserWarning: You called `self.log('global_step', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
  warning_cache.warn(
Epoch 0:   0%| | 1/404 [00:10<1:12:21, 10.77s/it, loss=0.00231, v_num=0, train/loss_simple_step=0.00231, train/loss_vlb_step=1.41e-5, train/loss_s...


/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in PermuteBackward0. Traceback of forward call that caused the error:
  File "/Users/tmm1/code/stable-diffusion/./main.py", line 947, in <module>
    trainer.fit(model, data)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1672, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
    closure_result = closure()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 132, in closure
    step_output = self._step_fn()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 407, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/dp.py", line 134, in training_step
    return self.model(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/overrides/data_parallel.py", line 65, in forward
    output = super().forward(*inputs, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 79, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/models/diffusion/ddpm.py", line 498, in training_step
    loss, loss_dict = self.shared_step(batch)
  File "/Users/tmm1/code/stable-diffusion/ldm/models/diffusion/ddpm.py", line 1253, in shared_step
    loss = self(x, c)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/models/diffusion/ddpm.py", line 1270, in forward
    return self.p_losses(x, c, t, *args, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/models/diffusion/ddpm.py", line 1475, in p_losses
    model_output = self.apply_model(x_noisy, t, cond)
  File "/Users/tmm1/code/stable-diffusion/ldm/models/diffusion/ddpm.py", line 1440, in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/models/diffusion/ddpm.py", line 2148, in forward
    out = self.diffusion_model(x, t, context=cc)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/diffusionmodules/openaimodel.py", line 811, in forward
    h = module(h, emb, context)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/diffusionmodules/openaimodel.py", line 88, in forward
    x = layer(x, context)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 346, in forward
    x = block(x, context=context)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 296, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/diffusionmodules/util.py", line 157, in checkpoint
    return func(*inputs)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 301, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 274, in forward
    r1 = self.einsum_op(q, k, v, r1)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 189, in einsum_op_v1
    r1 = einsum('b i j, b j d -> b i d', s2, v)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/functional.py", line 360, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
 (Triggered internally at  /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1659484612588/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Summoning checkpoint.
Traceback (most recent call last):
  File "/Users/tmm1/code/stable-diffusion/./main.py", line 947, in <module>
    trainer.fit(model, data)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1672, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
    closure_result = closure()
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 141, in closure
    self._backward_fn(step_output.closure_loss)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 304, in backward_fn
    self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 191, in backward
    self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 80, in backward
    model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1417, in backward
    loss.backward(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'PermuteBackward0' returned nan values in its 0th output.

@tmm1
Author

tmm1 commented Sep 13, 2022

This is all very new to me, but if I'm interpreting the output correctly it seems to suggest the gradients from einsum are nan.

Important bits:

UserWarning: Error detected in PermuteBackward0. Traceback of forward call that caused the error:
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 301, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 274, in forward
    r1 = self.einsum_op(q, k, v, r1)
  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 189, in einsum_op_v1
    r1 = einsum('b i j, b j d -> b i d', s2, v)
...
  File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1417, in backward
    loss.backward(*args, **kwargs)
RuntimeError: Function 'PermuteBackward0' returned nan values in its 0th output.

This includes the changes merged into development this morning.

@tmm1
Author

tmm1 commented Sep 13, 2022

I started working with the training functionality last night as well and ran into problems on CUDA. The textual inversion modifications to ddpm.py seem to have adversely affected vanilla training and we'll have to do a careful comparison with the original CompVis implementation in order to isolate the conflicts.

@lstein What sort of problems did you run into on CUDA? I wonder if you can try this and see if any anomalies are detected?

diff --git a/main.py b/main.py
index c45194d..57c8832 100644
--- a/main.py
+++ b/main.py
@@ -864,6 +864,7 @@ if __name__ == '__main__':
         ]
         trainer_kwargs['max_steps'] = trainer_opt.max_steps

+        trainer_opt.detect_anomaly = True
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
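
The same check can also be switched on directly in PyTorch while poking at things interactively, without going through the Trainer args:

import torch

# Equivalent to Lightning's detect_anomaly flag: record the forward trace for
# every backward pass and raise as soon as a backward function returns NaN.
# Noticeably slower, so only for debugging runs.
torch.autograd.set_detect_anomaly(True)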

@tmm1
Author

tmm1 commented Sep 13, 2022

I switched CrossAttention#forward to the original implementation, and the same anomaly is detected. So at least it does not seem related to the performance tweaks there.

I don't have a CUDA setup to test with, so maybe nan at this step is expected. This happens right away for me, whereas the loss was fine for a few hundred steps before, so there must be a different anomaly later on.

@lstein
Collaborator

lstein commented Sep 13, 2022 via email

@Birch-san

Birch-san commented Sep 13, 2022

Maybe this is another opportunity to try replacing einsum with matmul?

Birch-san/stable-diffusion@d2d533d

It's like 30% slower, but it might do something different regarding NaN?

Context:
huggingface/diffusers#452 (comment)
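
For what it's worth, the two formulations compute the same attention product; a rough sketch of the swap (shapes and names are illustrative, not the exact attention.py code):

import torch

b, i, j, d = 2, 1024, 77, 64                        # illustrative sizes
s2 = torch.softmax(torch.randn(b, i, j), dim=-1)    # attention weights
v = torch.randn(b, j, d)                            # values

r_einsum = torch.einsum('b i j, b j d -> b i d', s2, v)
r_matmul = torch.matmul(s2, v)                      # same math, different kernel path
assert torch.allclose(r_einsum, r_matmul, atol=1e-5)

Whether the matmul path behaves differently with respect to NaN gradients on MPS is the open question.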

@tmm1
Author

tmm1 commented Sep 13, 2022

Good idea!

But it just failed in the same way, so either the nan is normal or something bigger is the problem elsewhere.

  File "/Users/tmm1/code/stable-diffusion/ldm/modules/attention.py", line 255, in forward
    sim = torch.matmul(q, k.transpose(1, 2)) * self.scale
RuntimeError: Function 'TransposeBackward0' returned nan values in its 0th output.

@Any-Winter-4079
Contributor

Any-Winter-4079 commented Sep 13, 2022

If it's about replacing einsum, the other day I tried https://github.com/dgasmith/opt_einsum (just to see if it was faster). Not sure if this could be another alternative. This goes way over my head, but I'll try to mess around and see if I can get it to work, even if by trial and error.

@Any-Winter-4079
Contributor

@tmm1 Have you encountered this error? RuntimeError: Placeholder storage has not been allocated on MPS device!

@tmm1
Author

tmm1 commented Sep 13, 2022

Hm no I didn't see that one.

@Any-Winter-4079
Contributor

@tmm1 Have you encountered this error? RuntimeError: Placeholder storage has not been allocated on MPS device!

Well, for anyone who encounters this issue, it's fixed with pip install pytorch-lightning==1.7.5 (which you mentioned in the first comment). I naively tried to get by without updating my environment, but nope, it's needed.

@Birch-san

Birch-san commented Sep 14, 2022

If it's about replacing einsum, the other day I tried https://github.com/dgasmith/opt_einsum (just to see if it was faster).

wow, that's cool. yeah, it's just a drop-in replacement:
CompVis/stable-diffusion@b7357a7
I just did pip install opt_einsum and made that source change.

unfortunately I'm finding opt_einsum to be about 30x slower on MPS.

8 steps inference:
opt_einsum took 158.9 secs at 40~20s/it,
regular einsum took 10.4 secs at ~1.25s/it.
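
For anyone who wants to reproduce the comparison, the drop-in is roughly this (assuming attention.py currently imports einsum from torch):

# pip install opt_einsum
from opt_einsum import contract as einsum   # replaces: from torch import einsum

# call sites stay as they are, e.g.
# r1 = einsum('b i j, b j d -> b i d', s2, v)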

Birch-san added a commit to Birch-san/stable-diffusion that referenced this issue Sep 14, 2022
@tmm1
Author

tmm1 commented Sep 27, 2022

Yes, I used lldb to see where the seg fault was coming from.

@Any-Winter-4079
Contributor

Any-Winter-4079 commented Sep 27, 2022

Yesterday, I left it training with my original 512x512 3-sample hamburger training set. Just to compare results. And it learnt much better. So it seems the difference maker is num_vectors_per_token: 6

"a painting of * :3.6 in the style of van gogh" -s 50 -S 166153623 -W 512 -H 512 -C 7.5 -A k_lms
Screenshot 2022-09-27 at 11 40 54

min val/loss_simple_ema was 0.0179...

@Any-Winter-4079
Contributor

@Birch-san by the way, if you got it to work, there is this other version (Dreambooth) you may be interested in: XavierXiao/Dreambooth-Stable-Diffusion#4. They claim even better results than with regular TI.

Also, 256x256 sounds interesting/promising (a promised 4x speedup), but there is an issue where 2 characters sometimes appear in the sample images.

@lkewis

lkewis commented Sep 27, 2022

@Any-Winter-4079 Excellent results and information about your experiments! This is all really helpful for me, and I'll continue to test Textual Inversion (also in combination with prior training in Dreambooth, which seems to require less overall training).

As a direct comparison these were my results with the same training data using Dreambooth.
Dreambooth_Examples1_smaller

The results are far easier to stylise and don't require prompt weighting, though one thing I noticed about this particular training: it has an exact likeness in the face but rarely produces the dreadlock hairstyle:
sks_person_normanrockwell_small

That is, unless I specify 'with short dreadlocks' in the prompt. Then it produces the dreadlocks and bleached tips, but it never fully gets the exact hairstyle (depending on the artistic modifiers also used in the prompt; these examples were both Norman Rockwell for comparison):
sks_person_dreadlocks_normanrockwell_small

@Any-Winter-4079
Contributor

Any-Winter-4079 commented Sep 28, 2022

Okay, so yesterday I left it training overnight to do a mini-experiment on whether the lowest val/loss_simple_ema provides the best embedding to use (vs. other epochs that also have a low val/loss_simple_ema but better results in the val folder).

However, I've found two issues. First, after training with 50 (DDIM) steps, inference with K_LMS at 50 steps sometimes produces noise or unfinished images.
image

To fix this, one can increase the number of steps for K_LMS to 100. Alternatively, using DDIM at 50 steps also doesn't seem to produce noise. Knowing this issue, it's probably best to train on K_LMS / K_HEUN, if that is the sampler to be used later, given convergence and speed metrics.

The second issue that I've found, after recording my val/loss_simple_ema values, is that not all epochs have their results saved. In particular, embeddings_gs-2000.pt, embeddings_gs-4000.pt, embeddings_gs-6000.pt are missing, so I can't complete the experiment.
Screenshot 2022-09-28 at 23 31 56

I'll leave my val/loss_simple_ema anyway in case they are helpful for comparison. I'll update after fixing both issues.

Step val/loss_simple_ema
399 0.03151685371994972
799 0.022184547036886215
1199 0.023538151755928993
1599 0.0887226015329361
1999 0.048154544085264206
2399 0.09849347174167633
2799 0.0656338557600975
3199 0.03855472058057785
3599 0.01978435553610325
3999 0.04475581645965576
4399 0.025169644504785538
4799 0.03502315282821655
5199 0.019644401967525482
5599 0.022544164210557938
5999 0.015959467738866806
6399 0.02311003766953945
6799 0.018962502479553223

@Any-Winter-4079
Contributor

@Birch-san I've tried using your embeddings, but they don't seem to generate plush dolls.
"*" -s 20 -S 1652632745 -W 512 -H 512 -C 7.5 -A k_heun
Screenshot 2022-09-29 at 15 23 15

"*" -s 20 -S 3312359222 -W 512 -H 512 -C 7.5 -A k_heun
Screenshot 2022-09-29 at 15 23 02

I assume it's due to changes in your K_HEUN implementation, so heads up for everyone: embeddings may not be guaranteed to work cross-repo if there are changes in the samplers. I even had trouble training with one sampler and generating with another sampler (even within the same repo).

@lstein
Collaborator

lstein commented Sep 29, 2022 via email

@Birch-san

Birch-san commented Sep 29, 2022

@Birch-san I've tried using your embeddings, but they don't seem to generate plush dolls.
I assume it's due to changes in your K_HEUN implementation, so heads up for everyone, embeddings may not be guaranteed to work cross-repo, if there are changes in the samplers. I even had trouble training with one sampler and generating with another sampler (even within the same repo).

my k-diffusion integration is just for the logger used during training. log_images() is annotated with @torch.no_grad(), so I'm pretty sure it wouldn't have any impact on what's learned.

I think the explanation is simpler: my embedding is weak.
perhaps due to the strategy I used for text conditioning:
https://github.com/Birch-san/stable-diffusion/blob/3548866e020ef0ddcb6b594984c2eb36d17341bd/ldm/data/personalized.py#L187-L197
the * token is not near the start of the prompt in most cases, and we know SD pays way less attention to tokens in the middle of the prompt.

I also used * plush doll in every case, instead of just *. so perhaps its job was to say "what distinguishes me among plush dolls (fat and chibi)", and perhaps if you use it on its own without specifying "plush doll" afterward.. you just get the "fat and chibi" aspect?

it could explain why it turns everyone into pillows when I use it on waifu-diffusion:

00253 s3714475585_photo of Reimu - plush doll_

00254 s1724541131_photo of Reimu - plush doll_

00252 s3714475585_photo of Reimu - plush doll_

(this was "photo of Reimu * plush doll")

@Any-Winter-4079
Contributor

Any-Winter-4079 commented Sep 30, 2022

By the way, is training deterministic? I aborted training, started over, and I'm getting the same images in my train/val folders. Maybe it's because of our rand fix?

Same value for val/loss_simple_ema on the epoch as well (0.12891504168510437)

@toprakfirat

toprakfirat commented Oct 7, 2022

File "/opt/homebrew/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/cuda/__init__.py", line 211, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled

Hello @tmm1, you said that you faced this error and solved it. How did you fix it? Care to elaborate? I tried using different flags to use the CPU, but it didn't fix it for me. Is there another method? Maybe I should change the package versions or something? I have an Apple M1 Max.

@Any-Winter-4079 added the enhancement label Oct 9, 2022
@Any-Winter-4079
Contributor

I wonder if the model learns better/worse with e.g. waifu diffusion.

@Any-Winter-4079
Contributor

Update to say I've had my best success so far training with images of myself.
num_vectors_per_token: 6 and 50 training images (not 3-5 as suggested). Also, val/loss_simple_ema was not as good a metric as choosing the embedding with the best (closest to myself) images in the val folder.
Training was 40k steps (8 epochs of 5k steps each). Best results came in the 5th epoch.

@lstein closed this as completed Oct 24, 2022
@remixer-dec

I tried training with different repos: this one, the mac-optimized one by Birch-san, web-ui by AUTOMATIC1111 (currently broken for training), and DreamBoothMac by SujeethJinesh. I tried applying the optimizations mentioned in this thread (most of them, as I understand, are already implemented in the code), but every time the loss becomes nan after the 1st iteration. Updated to the latest pytorch-nightly. No idea where to dig next. My last bet is to upgrade macOS to Ventura in hope of changes in the Metal backend, since I'm currently on an old 12.3.1.

@Birch-san

@remixer-dec try the previous PyTorch stable, 1.12.1. there's a bug in PyTorch 1.13 stable which means autograd returns NaN gradients. I'm nearly done making a minimal repro of it. I can reproduce it using just the autoencoder's decoder. IIRC detect_anomaly said the NaN comes from NativeGroupNorm. This also breaks CLIP guidance on Mac.

@Any-Winter-4079
Contributor

Any-Winter-4079 commented Nov 2, 2022

My best Dreambooth result has been with https://colab.research.google.com/drive/1-HIbslQd7Ei_mAt25ipqSUMvbe3POm98?usp=sharing#scrollTo=CnBAZ4eje2Sl (in case it's helpful).
I could download the .ckpt and load it into this repo, which I wasn't able to do with other versions (when I tried them).

I will try out https://github.com/SujeethJinesh/DreamBoothMac. Thanks!

If you get nan, you may need:

import torch

def fix_func(orig):
    # On MPS, some random-number ops have produced bad values; generate on the
    # CPU instead and move the result to the originally requested device.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        def new_func(*args, **kw):
            device = kw.get("device", "mps")
            kw["device"] = "cpu"
            return orig(*args, **kw).to(device)
        return new_func
    return orig

# Monkey-patch the torch RNG entry points with the CPU-backed versions.
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)

And pytorch-nightly is not the best idea, yes (aside from other issues yet to be polished, it's slower for some ops like einsum).

@remixer-dec

@Birch-san thank you! It actually worked! You are a legend!
Looks like the older version of pytorch is also using more memory (or maybe that's because there are actual numbers and not nans), and DreamBooth becomes unusable on a 32GB machine, consuming 40-50 gigs of RAM with swap.

@lkewis

lkewis commented Nov 2, 2022

@Any-Winter-4079 be careful with the LastBen ones, he turned off class + regularisation to make training faster in some of them, and in our testing it caused dogs to have human eyes lol. You should be doing regularisation unless you don't care about using the model like a regular SD one.

@Birch-san

if prior preservation loss / coarse classes / regularisation are removed… is it still Dreambooth? is there any way in which it differs from regular finetuning at that point?

@Birch-san

@remixer-dec I got to the bottom of the problem of autograd returning NaN gradients on 1.13.0. Minimal repro here:
pytorch/pytorch#88331

@lkewis

lkewis commented Nov 2, 2022

@Birch-san A lot more people are going down the image + caption route now for multi-trained models, and it is basically just full fine tuning at that point, albeit in smaller quantities

@lstein
Collaborator

lstein commented Nov 2, 2022 via email

@Any-Winter-4079
Contributor

Any-Winter-4079 commented Nov 2, 2022

The notebook I shared has prior preservation and regularisation, I think, and thanks to Joe Penna and these other contributors, it also has images available to be used (instead of creating them yourself, or allowing them to be created on the spot):
Screenshot 2022-11-02 at 20 06 23

All of this as far as I know, of course. I'm not an expert :)

In any case, I tested 1-2 days ago. I'll re-test and fork TheLastBen's repo (if it still works), in case there is fear of it changing.
Update: Yep, still works.
