
can't resume training with quantized base model #9108

Open · bghira opened this issue Aug 7, 2024 · 44 comments
Labels: bug (Something isn't working)

Comments

@bghira (Contributor) commented Aug 7, 2024

Describe the bug

We have a KeyError when the state dict goes to load into the transformer.

Reproduction

import torch
from diffusers.models import FluxTransformer2DModel
from peft import LoraConfig, set_peft_model_state_dict
from diffusers import FluxPipeline
from diffusers.utils import convert_unet_state_dict_to_peft
from accelerate import Accelerator
from optimum.quanto import quantize, freeze, qint8

# Step 1: FluxTransformer2DModel Model
print("loading model")
model = FluxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-dev', subfolder='transformer')
accelerator = Accelerator()
# Step 2: Add LoRA Adapter
config = LoraConfig(r=8, lora_alpha=16, target_modules=["to_q", "to_k"], lora_dropout=0.1)
print("adding adapter (random)")
model.add_adapter(config)

# Step 3: Quantize the Model
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)

# Step 4: prepare model
print("prepare model")
model = accelerator.prepare(model)

# Step 5: Load LoRA State Dictionary
lora_path = "checkpoint/"

print("retrieve lora state dict")
lora_state_dict = FluxPipeline.lora_state_dict(lora_path)
transformer_state_dict = {
    f'{k.replace("transformer.", "")}': v
    for k, v in lora_state_dict.items()
    if k.startswith("unet.")
}
print("convert state dict")
transformer_state_dict = convert_unet_state_dict_to_peft(
    transformer_state_dict
)
incompatible_keys = set_peft_model_state_dict(
    model, transformer_state_dict, adapter_name="default"
)
print(f"unexpected keys: {incompatible_keys.unexpected_keys}")

Logs

adding adapter (random)
quantizing model
freezing model
prepare model
retrieve lora state dict
convert state dict
peft_model_state_dict: dict_keys([])
Loading with state_dict odict_keys([])
state_dict: dict_keys([])
prefix: time_text_embed.timestep_embedder.linear_1.weight.
Traceback (most recent call last):
  File "/Users/bghira/src/SimpleTuner/test_flux.py", line 41, in <module>
    )
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 354, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 92, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'

System Info

latest git main

Who can help?

@sayakpaul

bghira added the bug label on Aug 7, 2024
@bghira (Contributor, Author) commented Aug 7, 2024

  • if we remove quantize and freeze, it works, and has no unexpected keys
  • if we quantize and freeze after load, it complains about unexpected keys

@sayakpaul (Member)

I don't see any calls to load_state_dict() in the code but the trace has:

load_result = model.load_state_dict(peft_model_state_dict, strict=False)

Cc: @dacorvo @BenjaminBossan

Benjamin, I believe it would be nice to be able to support this because, with models like Flux coming, it will be crucial to support quantized models + PEFT in this way. But I think this is something that needs to be fixed in quanto, so I have added David to the thread as well.

@sayakpaul (Member)

It would also be nice to have a dummy checkpoint for lora_path.

@bghira (Contributor, Author) commented Aug 7, 2024

it's the set_peft_model_state_dict helper method that calls it.

@bghira (Contributor, Author) commented Aug 7, 2024

checkpoint-10.zip

@sayakpaul (Member)

if we remove quantize and freeze, it works, and has no unexpected keys

That means we lose the benefits of quantization.

@bghira (Contributor, Author) commented Aug 7, 2024

well that happens if you quantise after sending it to the accelerator, too. it's quite fragile currently

@sayakpaul (Member)

Oh then it seems like the net benefits are zero in terms of memory, no? Not saying that as a negative comment because I think this exploration is critically important.

@bghira (Contributor, Author) commented Aug 7, 2024

i had to reorder things so it goes to the accelerator long after everything is done to the weights

@sayakpaul (Member)

Could you also then provide the order that provides us with the benefits of quantization (when training is not being resumed)?

@bghira (Contributor, Author) commented Aug 7, 2024

that is in the reproducer

@BenjaminBossan (Member)

I'm not familiar with quanto, what type of quantization is applied here? Also, with PEFT, we have different LoRA classes depending on quantization, so e.g. we have a different linear layer for bnb than for GPTQ. Therefore, the model already needs to be quantized before adding the adapter. Not sure if this is any help here, but I wanted to mention that just in case.

@sayakpaul (Member)

@BenjaminBossan this is weight-only int8 quantization.

Also, with PEFT, we have different LoRA classes depending on quantization, so e.g. we have a different linear layer for bnb than for GPTQ.

I see. So, IIUC, this needs to be implemented at the PEFT level then? I imagine quanto becoming quite mainstream in the diffusion community especially after large releases like Flux. So, could be nice to start thinking about how we could support the quantized linear layers from quanto.

@bghira (Contributor, Author) commented Aug 7, 2024

quantizing model
freezing model
adding adapter (random)
prepare model
retrieve lora state dict
convert state dict
peft_model_state_dict: dict_keys([])
state_dict: dict_keys([])
prefix: time_text_embed.timestep_embedder.linear_1.weight.
Traceback (most recent call last):

same error when we load the adapter after freeze

@BenjaminBossan (Member)

So, IIUC, this needs to be implemented at the PEFT level then? I imagine quanto becoming quite mainstream in the diffusion community especially after large releases like Flux. So, could be nice to start thinking about how we could support the quantized linear layers from quanto.

Okay, so I checked, as I wasn't sure if quanto is using its own quantization scheme or if they wrap others. But from what I could tell, it's their own. Therefore, for this to work with PEFT, we would need to add support for quanto layers. Hopefully, this shouldn't be too difficult. I added an item to the backlog.

In the meantime, if anyone has a prototype, it is now quite easy to add support for new LoRA layer types to PEFT by using dynamic dispatch. So this can be tested quickly without the need to create a PR on PEFT.

@sayakpaul (Member)

In the meantime, if anyone has a prototype, it is now quite easy to add support for new LoRA layer types to PEFT by using dynamic dispatch. So this can be tested quickly without the need to create a PR on PEFT.

Very nice, but the problem is we don't have quanto LoRA layers yet.

@bghira (Contributor, Author) commented Aug 8, 2024

ideally the lora remains unquantised and only the base weights are

@BenjaminBossan (Member)

Very nice, but the problem is we don't have quanto LoRA layers yet.

Yes, they have yet to be implemented. What I meant is that if someone has a POC, they can quickly test it with dynamic dispatch, no need to go through a lengthy process of adjusting PEFT code.

ideally the lora remains unquantised and only the base weights are

Yes, that's always the case with quantization in PEFT: The base weights are quantized but the LoRA weights are not.

@bghira (Contributor, Author) commented Aug 8, 2024

wasn't sure exactly what was meant by quanto lora layers :D

@sayakpaul (Member)

Yes, that's always the case with quantization in PEFT: The base weights are quantized but the LoRA weights are not.

Then I am still a little unclear on what it would mean to support injecting LoRA layers into a model that is quantized with quanto (which can be detected at runtime). Could you maybe elaborate a little more? @BenjaminBossan

@bghira (Contributor, Author) commented Aug 8, 2024

we can run add_adapter on the quantised model it's just when we try loading the state dict that it really gets upset. the training of the lora also works

@BenjaminBossan (Member)

Then I am still a little unclear on what it would mean to support injecting LoRA layers into a model that is quantized with quanto (which can be detected at runtime). Could you maybe elaborate a little more?

What I mean is that if you want to add a new quantization method and have written the layer class as POC, normally the next step would be to edit PEFT to make use of that class and create a PR on PEFT. E.g. like in this PR for EETQ. With dynamic dispatch, you can instead just do this to use the new class:

config = LoraConfig(...)
custom_module_mapping = {nn.Linear: MyQuantizedLinearClass}
config._register_custom_module(custom_module_mapping)

No PEFT source code needs to be altered.

we can run add_adapter on the quantised model it's just when we try loading the state dict that it really gets upset. the training of the lora also works

Oh, this is surprising, I'll have to look into this.

@sayakpaul (Member) commented Aug 8, 2024

@BenjaminBossan this is likely because in add_adapter() we never operate at the state dict level like we do when calling set_peft_model_state_dict().

def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:

Regarding using the dispatching system, I think there might be a way. Consider the following example:

from optimum.quanto import quantize, qfloat8, freeze
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, n_dim):
        super().__init__()
        self.linear1 = nn.Linear(n_dim, 1204)
    def forward(self, x):
        return self.linear1(x)


model = Model(64)
quantize(model, weights=qfloat8)
freeze(model)

print(model)

It prints:

Model(
  (linear1): QLinear(in_features=64, out_features=1204, bias=True)
)

Then if I do this (mimicking what we are likely trying to do):

from optimum.quanto import quantize, qfloat8, freeze, requantize
from optimum.quanto.quantize import quantization_map
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, n_dim):
        super().__init__()
        self.linear1 = nn.Linear(n_dim, 1204)
    def forward(self, x):
        return self.linear1(x)


model = Model(64)

quantize(model, weights=qfloat8)
freeze(model)
print(model)

qmap = quantization_map(model)
qstate_dict = model.state_dict()

new_model = Model(64)
requantize(new_model, state_dict=qstate_dict, quantization_map=qmap)
print(new_model)

It prints:

Model(
  (linear1): QLinear(in_features=64, out_features=1204, bias=True)
)
Model(
  (linear1): QLinear(in_features=64, out_features=1204, bias=True)
)

So, it appears to me that we can probably use the dispatching system by mapping the nn.Linear to QLinear from quanto. @bghira possible to give this a try?

@BenjaminBossan (Member)

this is likely because in add_adapter() we never operate at the state dict level.

Hmm, but this still calls:

inject_adapter_in_model(adapter_config, self, adapter_name)

which starts the whole PEFT machinery of checking matching layers and replacing them with LoraLayers. I'd expect that stage to fail because PEFT should not be able to match the quanto layers. In the first example, PEFT was applied before quantization, so it makes sense that it would work. But later @bghira wrote:

we can run add_adapter on the quantised model

so that can't be the explanation. Possibly, even after quantization there are some non-quantized layers left that are matched by PEFT, thus preventing PEFT from erroring because it could not match any layers? I'll have to test but am currently not on the right machine to do so :)

@sayakpaul (Member)

Sure, thank you! There might be a way for us to use the dispatcher in the meantime, and I have updated my comment accordingly.

@BenjaminBossan (Member)

Okay, back after some testing. The reason there is no error when adding a PEFT adapter on top of a quanto model is that the quanto QLinear layer is a subclass of nn.Linear. Therefore, PEFT simply applies the normal lora.Linear layer on top of it.

I did some quick tests and it appears it's actually working out of the box -- almost. What doesn't work for now is merging the LoRA weights into the QLinear layer. I tried in vain to make it work but gave up after a few minutes of trying. It's probably not difficult, but I've never used quanto.

Anyway, it looks like support for quanto should not be that difficult to add. In fact, for simple use cases, it may already work correctly. However, I'd not recommend using it until we have run it through our test suite to ensure that we know what does and does not work. Until then, I'd recommend using one of the officially supported quantization methods in PEFT.
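
Here is a rough sketch along those lines (not the exact script I used; it substitutes a toy model and get_peft_model for the diffusers add_adapter path, so take it as illustrative only):

import torch.nn as nn
from optimum.quanto import quantize, freeze, qint8
from peft import LoraConfig, get_peft_model

# toy stand-in for the transformer
class Model(nn.Module):
    def __init__(self, n_dim):
        super().__init__()
        self.linear1 = nn.Linear(n_dim, n_dim)

    def forward(self, x):
        return self.linear1(x)

model = Model(64)

# quantize + freeze the base weights first
quantize(model, weights=qint8)
freeze(model)

# QLinear subclasses nn.Linear, so PEFT's stock lora.Linear dispatcher matches it
config = LoraConfig(r=8, lora_alpha=16, target_modules=["linear1"])
peft_model = get_peft_model(model, config)

# the base layer stays a quanto QLinear, while lora_A / lora_B are plain nn.Linear
lora_layer = peft_model.base_model.model.linear1
print(type(lora_layer.base_layer))         # quanto QLinear
print(type(lora_layer.lora_A["default"]))  # torch.nn.Linear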

@bghira (Contributor, Author) commented Aug 9, 2024

i would love to but most all of those are nvidia only, and i am working from MacOS

@BenjaminBossan (Member) commented Aug 9, 2024

i would love to but most all of those are nvidia only, and i am working from MacOS

I see, all the more reason to add quanto support soon. I'll open an issue on PEFT to call for contributions. If there is no taker, I'll work on it as soon as I have a bit of time on my hands.

Update: The issue: huggingface/peft#1997

@BenjaminBossan (Member)

I coded something up quickly: huggingface/peft#2000. It's still unfinished but in my limited testing looks like it works.

@Leommm-byte (Contributor)

Describe the bug

We have a KeyError when the state dict goes to load into the transformer. (full reproduction, logs, and system info quoted from the issue description above)

I also have this error when I try to load a LoRA weight into the quantized transformer class. I am using this method because it is much faster to load the quantized weights than to quantize at runtime. The issue now is that, when I try loading the LoRA weights, I get a similar error:

File "C:\Users\teggy\Desktop\Comfy_Creator\Gen-Server\.venv\lib\site-packages\optimum\quanto\nn\qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "C:\Users\teggy\Desktop\Comfy_Creator\Gen-Server\.venv\lib\site-packages\optimum\quanto\tensor\qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'

Here are some relevant parts of my code:

from optimum.quanto.models import QuantizedTransformersModel, QuantizedDiffusersModel
from transformers import T5EncoderModel
from diffusers import FluxTransformer2DModel

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

if lora_info is None:
    # If no LoRA info is provided, disable all LoRAs
    pipeline.unload_lora_weights()
else:
    print("Loading LoRA weights...")
    adapter_name = lora_info["adapter_name"]
    print(f"Adapter Name: {adapter_name}")
    try:
        pipeline.load_lora_weights(
            lora_info["repo_id"],
            weight_name=lora_info["weight_name"],
            adapter_name=adapter_name
        )
        print(f"LoRA adapter '{adapter_name}' loaded successfully.")
    except ValueError as e:
        if "already in use" in str(e):
            print(f"LoRA adapter '{adapter_name}' is already loaded. Using existing adapter.")
        else:
            raise e
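
For reference, the pre-quantize-then-load flow I'm describing looks roughly like this (a sketch only; it assumes optimum-quanto's QuantizedDiffusersModel.quantize / save_pretrained / from_pretrained API, and the save path is a placeholder):

from diffusers import FluxTransformer2DModel
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

# one-time: quantize the transformer and write it to disk
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
qtransformer = QuantizedFluxTransformer2DModel.quantize(transformer, weights=qfloat8)
qtransformer.save_pretrained("flux-transformer-qfloat8")

# later: load the already-quantized weights directly instead of quantizing at runtime
transformer = QuantizedFluxTransformer2DModel.from_pretrained("flux-transformer-qfloat8")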

Is there any workaround or fix for this?
Thank you!

@BenjaminBossan (Member)

@Leommm-byte The issue is that non-strict loading of state_dicts is not yet supported in quanto, see the discussion here. I proposed a fix but we have yet to hear from the maintainer if this is the way to go.
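
A stripped-down illustration of the failure mode (a sketch, assuming that a frozen quanto QLinear always pops its weight._data / weight._scale entries from the incoming state dict, as the traceback above suggests):

import torch.nn as nn
from optimum.quanto import quantize, freeze, qint8

model = nn.Sequential(nn.Linear(8, 8))
quantize(model, weights=qint8)
freeze(model)

# even with strict=False, missing quantized entries raise, because the frozen
# QLinear unconditionally pops '<prefix>weight._data' from the state dict
try:
    model.load_state_dict({}, strict=False)
except KeyError as e:
    print("KeyError:", e)  # e.g. '0.weight._data'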

@sayakpaul (Member)

Cc: @bghira for huggingface/optimum-quanto#278.

@Leommm-byte (Contributor)

@Leommm-byte The issue is that non-strict loading of state_dicts is not yet supported in quanto, see the discussion here. I proposed a fix but we have yet to hear from the maintainer if this is the way to go.

Thank you for your fix. It worked really well.

My own fix was to comment out these lines:

incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

if incompatible_keys is not None:
    # check only for unexpected keys
    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
    if unexpected_keys:
        logger.warning(
            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
            f" {unexpected_keys}. "
        )

# Offload back.
if is_model_cpu_offload:
    _pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    _pipeline.enable_sequential_cpu_offload()

which wasn't very smart in retrospect 😅. I thought it was just checking for incompatible keys, so the side effect was that the LoRA weights had no effect on the model's output.

Thank you so much once again. I'll continue using your fix till the maintainers decide to use this fix or create another. Cheers.

@bghira (Contributor, Author) commented Aug 22, 2024

you should have seen my face as i read through your comment lmao

@sayakpaul (Member)

Can someone put up a gist that shows the application of the fix (ideally following the DreamBooth LoRA Flux script we have in the repo)? This is not to deter any finetuners like SimpleTuner or AI Toolkit but to have a more self-contained and minimal example.

I can do it tomorrow once I get back to my computer.

Question to @bghira and @Leommm-byte:

  • Do you notice any improvement in the wall-clock timing when doing LoRA over the quantized model? (I wouldn’t expect it but still checking)
  • How much memory savings can we expect?

@bghira (Contributor, Author) commented Aug 22, 2024

baseline memory goes from 33GB VRAM at bf16 to ~17GB VRAM at int8 for LoRA

@sayakpaul (Member)

So, if we decrease resolution, it seems like it could be possible to do it on a free Colab?

@bghira (Contributor, Author) commented Aug 22, 2024

no, reducing resolution really only increases the speed. paradoxically i'm training on 2048px with the same vram use but it goes slower. this is on 3x 4090.

@sayakpaul (Member)

How about coupling that with 8bit Adam or does that impact convergence?

@bghira (Contributor, Author) commented Aug 22, 2024

i think DeepSpeed with a single GPU will be fine for 16GB VRAM. the weights at bf16 are just under 24G on their own. cutting that in half naively will give us just around 12G of memory for the full model weights. and then the optimiser states are like +100% more memory used, with AdamW. Lion fares better since it keeps fewer copies.

so something has to be offloaded even with stochastic bf16 + Lion, no autocast, even for LoRA at 16G VRAM. the way kohya is pulling this off is with a modified Adafactor with fused optimiser step + backward pass, and other offload tactics that don't exist in Diffusers

@bghira (Contributor, Author) commented Aug 22, 2024

8bit Adam doesn't help memory pressure much or at all, it needs autocast and fp32 gradients as i understand it. when we don't do that, it just kinda gets stuck and slowly degrades.

@Leommm-byte (Contributor)

Can someone have a gist up that shows the application of the fix (ideally following the DreamBooth LoRA Flux script we have in the repo)? This is not to deter any finetuners like SimpleTuner or AI Toolkit but to have a more self-contained and minimal example.

I can do it tomorrow once I get back to my computer.

Question to @bghira and @Leommm-byte:

  • Do you notice any improvement in the wall-clock timing when doing LoRA over the quantized model? (I wouldn’t expect it but still checking)
  • How much memory savings can we expect?

Although my case was running inference rather than training, there's a significant improvement.

My rig has an RTX 4090 Mobile (it's a laptop, so 16 GB VRAM).

Inference goes from effectively impossible (with unquantized bf16 weights) to about 2 to 2.5 minutes when loading from disk (using the quantized transformer and T5 at fp8). Once it is in memory, it's more like 20 to 30 seconds.

VRAM usage is about 15.5 GB.

@Leommm-byte (Contributor)

you should have seen my face as i read through your comment lmao

😭

@bghira (Contributor, Author) commented Aug 22, 2024

for training speed, due to the fact that it will OOM, there's really no way to know how slowly bf16 will run on a 3090 vs fp8/int8. but fp8 is slower than int8 on 3090. not sure about 4090.

and A100 goes about 1.25 seconds per step at 512x512 whether it's a rank-16+int8 LoRA or a bf16 LoRA.

fully training the non-quantised BF16 weights on 1x A100-80G SXM4 with DeepSpeed ZeRO 2 is about 10 seconds per step at 1024x1024. this is the same speed with 3x A100-80G SXM4. but then you can do nicer batch sizes. and the VRAM used per GPU dropped with more GPUs.

increasing the batch size hurts throughput a lot. but not on the H100 - that is required to keep throughput scaling with batch size. that dang dispatch layer in the hw is just too good.

MI300X i paid to test yesterday. it's still in need of a lot of work for scaling with batch sizes but you just need one and no deepspeed to fully tune the model for $3.99/hr. H100 is the same cost, but needs DeepSpeed using 58,000MiB of VRAM and about 176GB of system memory.

and i implemented the offload techniques from kohya but their utility is limited to a single GPU. we can't fuse things as deeply on multigpu setups. so, for this you'll always need multigpu quantised training or tolerate the slowdown of heavy offloading through DeepSpeed, which is optimised for multi-GPU training.
