
Mistral loss instability #26498

Closed
1 of 4 tasks
teknium1 opened this issue Sep 29, 2023 · 71 comments

Comments

@teknium1

teknium1 commented Sep 29, 2023

System Info

Hello, I've been working with dhokas, who finetuned Mistral's official instruct model. I have been trying to finetune Mistral on several datasets over dozens of ablations. Training this model with transformers shows severe loss instability that never seems to appear in his training runs, which do not use HF Trainer.

I am opening this so we can get to the bottom of it. Here are some of my runs using axolotl on several datasets.

With the Hermes 2.0 dataset (unpublished):
https://wandb.ai/teknium1/hermes2.0-mistral-7b?workspace=user-teknium1

With the Teknium/GPT4-LLM-CLEANED dataset:
https://wandb.ai/teknium1/gpt4llm-mistral-7b

With a 5-sequence run to check that the loss goes to 0 (i.e., that memorization is occurring):
https://wandb.ai/teknium1/5seq-mistral-7b?workspace=user-teknium1

With the OpenHermes dataset teknium1/openhermes:
https://wandb.ai/teknium1/hermes-mistral-7b

As can be seen, the loss curves across all these ablations are unreliable, and training generally produces bad results no matter which hyperparams are changed.

The Mistral dev who worked with me trained Mistral on GPT4-LLM-CLEANED and got this result:
[image]

@younesbelkada @muellerz

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Train Mistral on any of the above datasets with Mistral's own finetuning hyperparams, as reported in Mistral's Discord, and observe that the loss fails to converge.

Expected behavior

A smooth or downward trajectory for the loss.

@teknium1
Author

teknium1 commented Sep 29, 2023

I have tried:

  • learning rates of 2e-5, 1e-5, 8e-6, 6e-6 and 4e-6
  • with and without flash attention / xformers / neither
  • with and without packing
  • weight decay of 0.1 and 0.01
  • long, medium and short warmups (between 0.01% and 80% of total steps as warmup steps)
  • the Hermes 2.0, Hermes 1.0 (which has been trained on LLaMA fine on several occasions) and GPT4LLM datasets
  • FSDP, and DeepSpeed ZeRO-2 and ZeRO-3
  • with and without group-by-length
  • updated Adam beta and epsilon (adam_beta2: 0.95, adam_epsilon: 0.00001)
  • with and without max_grad_norm: 1.0

I've basically run out of hyperparams to tune, several of these on fresh venvs.
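For anyone trying to reproduce this, the sweep above can be enumerated as a grid. This is a minimal sketch: the keys are named after Hugging Face `TrainingArguments` fields (`learning_rate`, `weight_decay`, `warmup_ratio`, `max_grad_norm`, `group_by_length`), and the three warmup ratios are illustrative picks from the 0.01% to 80% range, not the exact values used in these runs.

```python
import itertools

# Hypothetical ablation grid mirroring the settings listed above.
GRID = {
    "learning_rate": [2e-5, 1e-5, 8e-6, 6e-6, 4e-6],
    "weight_decay": [0.1, 0.01],
    "warmup_ratio": [0.0001, 0.03, 0.8],  # 0.01%, 3%, 80% (illustrative)
    "max_grad_norm": [1.0, 0.0],  # 0.0 is meant here as "no clipping" (assumption)
    "group_by_length": [True, False],
}


def expand_grid(grid):
    """Yield one config dict per combination of the grid values."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


runs = list(expand_grid(GRID))
print(len(runs))  # 5 * 2 * 3 * 2 * 2 = 120
print(runs[0])
```

Each dict could then be fed to whichever trainer config is in use; the point is only that even this reduced grid is 120 runs, which is why isolating a single variable (like the FSDP comparison later in the thread) tends to be more informative.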

@Ki6an
Contributor

Ki6an commented Sep 30, 2023

I have also come across an irregular loss curve when finetuning Mistral 7B:
[image: unusual_loss]

@teknium1
Author

For reference, some of my loss charts:
[image]
[image]
[image]

@akjindal53244

I am facing the same issue: the loss goes up while finetuning on the Dolly-15k dataset.

@adarshxs

Same for me with the garage-bAInd/Open-Platypus dataset, though mine was extremely weird:
[image]

@adamlin120

Continued pre-training on a Chinese/Mandarin corpus:
[image: loss curves of the continued pre-training runs]

Optimizer: AdamW
LR: 2.5e-5
Warmup: 4%
Batch size: 2
Seq len: 1024
Used the flash attention from the PR

@adarshxs


Are you using any specific library for continued pre-training?

@adamlin120

adamlin120 commented Sep 30, 2023


I am using SFTTrainer from trl. Note that both runs failed: the orange one does not converge, and the green one dropped to loss=0.0, but the model actually produced garbage.

@adarshxs

adarshxs commented Sep 30, 2023


[image]
Same with fine-tuning. The output is pure garbage, even with all the standard hyperparams I used for fine-tuning LLaMA.

@sparverius

With Teknium/GPT4-LLM-CLEANED dataset https://wandb.ai/teknium1/gpt4llm-mistral-7b

With a 5-sequences run to ensure loss goes to 0 (that memorization is occurring): https://wandb.ai/teknium1/5seq-mistral-7b?workspace=user-teknium1

@teknium1 these both 404 😞

@teknium1
Author


Sorry, my projects default to private; I've made them public.

@bdytx5

bdytx5 commented Sep 30, 2023

How did you load your model?

@teknium1
Author


With transformers? Or do you mean precision?

@bdytx5

bdytx5 commented Sep 30, 2023


I was just wondering if you used one of the HuggingFace AutoModel classes or if you loaded it using the Mistral reference implementation.

@teknium1
Author


MistralForCausalLM

@bdytx5

bdytx5 commented Sep 30, 2023


I see. I guess one idea to sanity check could be to load the model using the reference implementation and ensure it behaves similarly to the HuggingFace version.

@teknium1
Author


Do you mean outside of Hugging Face / HF Trainer? The Mistral dev did do this: we get totally different training results when he trains on the same dataset with the same hyperparams without HF Trainer.

@bdytx5

bdytx5 commented Sep 30, 2023


Yeah, I mean just making sure both models behave similarly for a single forward/backward pass on the same data, without the trainer. If they are the same, then my guess is that it narrows the problem down to the Trainer.
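The check described above boils down to feeding identical token ids to both implementations and comparing what comes out. Below is a minimal, framework-agnostic sketch of the comparison step only, assuming both outputs have already been extracted as NumPy arrays of the same shape; `compare_outputs` is a hypothetical helper, not part of either codebase.

```python
import numpy as np


def compare_outputs(ref, hf, atol=1e-2, rtol=1e-2):
    """Summarise how far two same-shaped output arrays are from each other.

    Both arrays are upcast to float64 first, so the comparison itself cannot
    overflow even if the inputs came from float16/bfloat16 tensors.
    """
    ref = np.asarray(ref, dtype=np.float64)
    hf = np.asarray(hf, dtype=np.float64)
    if ref.shape != hf.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {hf.shape}")
    diff = ref - hf
    return {
        "max_abs_diff": float(np.max(np.abs(diff))),
        # Relative L2 error: ||ref - hf|| / ||ref||
        "rel_l2": float(np.linalg.norm(diff) / np.linalg.norm(ref)),
        "allclose": bool(np.allclose(hf, ref, atol=atol, rtol=rtol)),
        # For logits: fraction of positions where both pick the same top entry
        "argmax_agree": float(np.mean(ref.argmax(-1) == hf.argmax(-1))),
    }


# Toy usage with made-up numbers
ref = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
hf = ref.copy()
hf[0, 0] = 10.0
print(compare_outputs(ref, hf))
```

The tolerances are placeholders; half-precision forward passes will rarely match to more than two or three significant digits, so a relative error or top-1 agreement is usually more telling than exact equality.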

@teknium1
Author

Indeed, they are not the same. They are actually completely inverse lol

@bdytx5

bdytx5 commented Sep 30, 2023


interesting.

@Undi95

Undi95 commented Sep 30, 2023

[image]

Training on the Pippa-ShareGPT dataset from Hugging Face, the loss is high.
https://wandb.ai/undis95/pippa-sharegpt-13b-qlora?workspace=user-undis95
I trained on other datasets too, but I don't have screenshots of the loss or the wandb.ai data, since I only just learned all this.
The data and datasets can be seen at the source; the original datasets are always linked:

https://huggingface.co/Undi95/Mistral-pippa-sharegpt-7b-qlora
https://huggingface.co/Undi95/Mistral-7B-smoll_pippa-lora
https://huggingface.co/Undi95/Mistral-7B-roleplay_alpaca-lora

The results are not what I expected, and I can't find a way to train properly.

@bdytx5

bdytx5 commented Oct 2, 2023

I made a script that compares the last hidden state embeddings of both implementations:

Sampled values from Mistral embedding: [[-1.635 0.4966 -1.647 ]
[ 0.1438 0.2181 0.0925 ]
[ 0.2527 0.8457 0.8496 ]
[ 0.1675 0.07324 1.037 ]
[ 0.881 -0.614 0.1123 ]]
Sampled values from Hugging Face embedding: [[-1.7 0.5347 -1.733 ]
[ 1.075 1.69 0.7036]
[ 1.983 6.86 6.73 ]
[ 1.353 0.615 8.5 ]
[ 9.23 -6.65 1.188 ]]
Embedding difference (L2 norm): inf

see comparison script at https://github.com/bdytx5/mistral7B_finetune/blob/main/train/dev/cmp_models.py

Also, you will have to add the following method to the Transformer class of the reference implementation:

# Method for the Transformer class in the reference implementation's model.py.
# Assumes torch, typing.List and RotatingBufferCache are already imported there.
def get_last_hidden_state(
    self,
    input_ids: torch.Tensor,
    cache: RotatingBufferCache,
    seqlens: List[int],
) -> torch.Tensor:
    assert len(seqlens) <= self.args.max_batch_size, f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
    assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0])

    input_metadata = cache.get_input_metadata(seqlens)
    h = self.tok_embeddings(input_ids)
    freqs_cis = self.freqs_cis[input_metadata.positions]

    for layer_id, layer in enumerate(self.layers):
        h = layer(h, freqs_cis, cache.get_view(layer_id, input_metadata))

    cache.update_seqlens(seqlens)

    # Hidden states after the last decoder layer, before the final self.norm
    # and the output layer. Hugging Face's last_hidden_state is taken after
    # the final RMSNorm, so apply self.norm(h) before comparing against it.
    return h
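One possible reason the script reports an L2 norm of `inf` even though every sampled value is small: if the difference tensor is reduced in float16, the sum of squares over a full hidden-state tensor can exceed float16's maximum (65504) and overflow. This is an assumption about how the comparison script computes the norm, not something verified against it; upcasting before the reduction rules it out. A minimal NumPy illustration with made-up values:

```python
import numpy as np

# Hypothetical per-element differences, stored in float16 as half-precision
# hidden states would be.
diff = np.full(10_000, 5.0, dtype=np.float16)

# Sum of squares = 10_000 * 25 = 250_000, above float16's maximum (65504),
# so a norm reduced in float16 overflows to inf.
with np.errstate(over="ignore"):
    norm_fp16 = np.sqrt(np.sum(diff * diff))

# The same norm computed after upcasting to float32 is finite.
norm_fp32 = np.linalg.norm(diff.astype(np.float32))

print(norm_fp16, norm_fp32)
```

If the norm comes out finite after upcasting, the remaining gap between the two sets of sampled values is the signal worth chasing.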

@teknium1
Author

teknium1 commented Oct 2, 2023


So is this the cause of the loss issues, or just a cleaner, more correct implementation?

@bdytx5

bdytx5 commented Oct 3, 2023


It's definitely possible that a difference in initial weights is causing the strange training behavior. I might try taking the official weights and converting them with their script, to make sure the weights on Hugging Face are the same as the official weights.

One thing I have noticed is that the config class for the model defaults to "rms_norm_eps": 1e-06, whereas the config used on the Hugging Face Hub uses 1e-05. I'm not sure whether this matters, but I might try converting the weights to make sure they were originally converted with the right config. You can find the default config here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/configuration_mistral.py
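For context on whether the epsilon could matter: RMSNorm divides by sqrt(mean(x²) + eps) and then scales by a learned weight, so eps only changes the output noticeably when the mean square of the activations is small compared with eps. A minimal pure-Python sketch with made-up input vectors, comparing 1e-6 and 1e-5 (the weight is fixed to 1 for simplicity):

```python
import math


def rms_norm(x, weight, eps):
    """RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def max_abs_gap(x, eps_a=1e-6, eps_b=1e-5):
    """Largest elementwise difference between the two eps settings."""
    w = [1.0] * len(x)
    a = rms_norm(x, w, eps_a)
    b = rms_norm(x, w, eps_b)
    return max(abs(p - q) for p, q in zip(a, b))


typical = [0.5, -1.2, 2.0, 0.3]  # mean(x^2) ~ 1.4, so eps is negligible
tiny = [1e-3, -2e-3, 1e-3, 0.0]  # mean(x^2) ~ 1.5e-6, same order as eps

print(max_abs_gap(typical))  # on the order of 1e-6
print(max_abs_gap(tiny))     # on the order of 1e-1
```

So a 1e-6 vs 1e-5 mismatch would mostly show up for near-zero activations; whether that happens often enough in Mistral's residual stream to explain the training behaviour is an open question here.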

@bdytx5

bdytx5 commented Oct 3, 2023

To follow up, Tek: after looking a little closer at these final-layer embeddings:

Sampled values from Mistral embedding: [[-1.635 0.4966 -1.647 2.324 -0.1011 ]
[ 0.1438 0.2181 0.0925 -1.136 0.2788 ]
[ 0.2527 0.8457 0.8496 -0.4353 -0.3838 ]
[ 0.1675 0.07324 1.037 -1.225 0.158 ]
[ 0.881 -0.614 0.1123 -1.201 0.2915 ]]
Sampled values from Hugging Face embedding: [[-1.706 0.593 -2.016 2.396 -0.05334]
[ 2.277 0.762 0.0974 -8.88 3.088 ]
[ 2.75 5.703 6.695 -4.22 -2.928 ]
[ 1.782 -0.5884 8.914 -9.2 1.583 ]
[ 7.8 -5.42 1.145 -9.29 4.605 ]]
Embedding difference (L2 norm): inf

The huggingface outputs seem pretty high in comparison to the official ones which does seem suspicious...

@younesbelkada
Contributor

younesbelkada commented Oct 3, 2023

Hi @teknium1 @bdytx5

Reading through the thread and the options you have tried, I first suspected that the issue might come from the new window causal mask.
On my end, I have tried to fine-tune mistral-7b using QLoRA with 2 different approaches:

1- Using vanilla causal mask
2- Using the window attention mask
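The two approaches differ only in which (query, key) pairs are allowed to interact. A minimal pure-Python sketch, assuming the window covers the `window` most recent positions including the current one; the exact off-by-one convention is an assumption and may differ from the transformers implementation. Once the sequence is longer than the window, as with a 512-token context and a 256-token window, the two masks diverge.

```python
def sliding_window_causal_mask(seq_len, window):
    """mask[i][j] is True if query position i may attend to key position j.

    Causal: j <= i. Sliding window: only the `window` most recent positions
    (including i itself) are visible, i.e. i - j < window (assumed convention).
    """
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_window_causal_mask(seq_len=6, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

This prints `x.....`, `xx....`, `xxx...`, `.xxx..`, `..xxx.`, `...xxx`: the first three rows are identical to a vanilla causal mask, and only later rows start dropping the oldest keys.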

I fine-tuned the 7B with QLoRA using this script, with model_id changed to Mistral 7B and packing enabled: https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da. I used a context length of 512 and a sliding window size of 256 to make sure the sliding window mask behaves correctly. Here is the behaviour of the losses:

[screenshot: training losses, 2023-10-03]

Although the loss does not converge as "nicely" as the ideal loss curve you shared, the model manages to produce generations that are coherent with the Guanaco dataset:

# input: ### Human: Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.### Assistant:

>>> '### Human: Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.### Assistant: Monopsony is a market structure where there is only one buyer of a good or service. In the context of the labour market, a monopsony occurs when there is only one employer in a particular industry or region. This can happen for a variety of reasons, such as government regulation, natural monopolies, or the existence of a single large firm that dominates the market.\n\nThe concept of monopsony in the labour market has gained increasing attention in recent years'

Model weights here: https://huggingface.co/ybelkada/mistral-7b-guanaco

What @bdytx5 said makes sense: there might be some differences between the original model's logits and ours. Indeed, the HF version uses 1e-5: https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json#L16 whereas Mistral uses 1e-6: https://github.com/mistralai/mistral-src/blob/main/mistral/model.py#L129

@teknium1 can you try running a training with this version of the model instead: https://huggingface.co/mistralai/Mistral-7B-v0.1/discussions/35 (just pass revision="refs/pr/35" when calling from_pretrained)

@danieldk
Member

danieldk commented Oct 3, 2023

Reading through the thread and the options you have tried I suspected that the issue might come from the new window causal mask

I haven't looked into much detail yet, but the mask seems to unconditionally attend to cached key/values. Shouldn't the sliding window apply to cached key/values as well?

mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

(In the case of generating a batch of single tokens at a time, there is also https://github.com/huggingface/transformers/blob/ae9a344cce52ff244f721425f660b55ebc522b88/src/transformers/models/mistral/modeling_mistral.py#L795C30-L795C30, which skips applying the window to the k/v cache.)
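To make this concrete: prepending an all-zeros block to an additive mask (where 0 means "allowed") lets every cached key through, whereas a window that also covers the cache hides keys more than `window` positions old. A minimal pure-Python sketch under the same assumed window convention as above; `mask_with_cache` is a hypothetical helper that models the two behaviours, not the transformers code.

```python
def mask_with_cache(tgt_len, past_len, window, window_applies_to_cache):
    """Boolean mask for tgt_len new queries over past_len cached + tgt_len new keys.

    New query i sits at absolute position past_len + i. If
    window_applies_to_cache is False, every cached key is visible, which
    models prepending a zeros ("allowed") block to an additive mask.
    """
    rows = []
    for i in range(tgt_len):
        q_pos = past_len + i
        row = []
        for k_pos in range(past_len + tgt_len):
            if k_pos < past_len and not window_applies_to_cache:
                row.append(True)
            else:
                causal = k_pos <= q_pos
                in_window = q_pos - k_pos < window
                row.append(causal and in_window)
        rows.append(row)
    return rows


# One new token after 5 cached tokens, with a window of 3
unconditional = mask_with_cache(1, 5, 3, window_applies_to_cache=False)
windowed = mask_with_cache(1, 5, 3, window_applies_to_cache=True)
print(unconditional[0])  # all six keys visible
print(windowed[0])       # only the 3 most recent keys visible
```

The two only disagree once the cache is longer than the window, which would be consistent with training runs at short context lengths being unaffected.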

@teknium1
Author

teknium1 commented Oct 3, 2023


Next time I try a full finetune, I will. I actually did succeed at training on airoboros' dataset over Mistral 7B with a QLoRA, which leads me to one of two conclusions:

Either one (or more) of the datasets for Hermes 2.0 is malformed, or QLoRA is the only way to get the reliable training and good loss curves that I want at the moment. I will try the revision on my next full finetune.

@teknium1
Author

teknium1 commented Oct 4, 2023

On a side note about Mistral, @younesbelkada,

When I run inference with Mistral 7B on a 4090 with just a 2k max sequence length, it uses more than 24 GB of VRAM: it hits 23.3 GB of VRAM used, then starts offloading to CPU.

[image]

The code I run to make this happen:

import time

import torch
from transformers import LlamaTokenizer, MistralForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("./collectivecognition-run6")
model = MistralForCausalLM.from_pretrained(
    "./collectivecognition-run6",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

benchmarks = [
    "Hello, tell me about the history of the United States",
    "Roleplay as a scientist, who just discovered artificial general intelligence. What do you think about this discovery? What possibilities are there now?",
]

for index, obj in enumerate(benchmarks, start=1):
    start_time = time.time()  # Start timing
    prompt = f"USER:\n{obj}\n\nASSISTANT:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    # Generate up to 2048 new tokens
    generated_ids = model.generate(input_ids, max_new_tokens=2048, temperature=None)
    # Decode only the newly generated tokens (note: the kwarg is clean_up_tokenization_spaces)
    response = tokenizer.decode(
        generated_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print(f"Response {index}: {response}")

    elapsed_time = time.time() - start_time  # Time taken for this iteration
    print(f"Time taken for Response {index}: {elapsed_time:.4f} seconds")
    print(f"tokens total: {len(tokenizer.encode(response))}")

@younesbelkada
Contributor

@teknium1
I believe this is because the vanilla implementation we currently have in transformers does not support cache slicing as the original repository does.
To benefit from a fixed-size cache and memory-efficient generation, you can use the Flash Attention 2 version of the model:

import time

import torch
from transformers import LlamaTokenizer, MistralForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("./collectivecognition-run6")
model = MistralForCausalLM.from_pretrained(
    "./collectivecognition-run6",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Flash Attention 2 path (requires the flash-attn package); newer
    # transformers releases spell this attn_implementation="flash_attention_2"
    use_flash_attention_2=True,
)

benchmarks = [
    "Hello, tell me about the history of the United States",
    "Roleplay as a scientist, who just discovered artificial general intelligence. What do you think about this discovery? What possibilities are there now?",
]

for index, obj in enumerate(benchmarks, start=1):
    start_time = time.time()  # Start timing
    prompt = f"USER:\n{obj}\n\nASSISTANT:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    # Generate up to 2048 new tokens
    generated_ids = model.generate(input_ids, max_new_tokens=2048, temperature=None)
    # Decode only the newly generated tokens
    response = tokenizer.decode(
        generated_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print(f"Response {index}: {response}")

    elapsed_time = time.time() - start_time  # Time taken for this iteration
    print(f"Time taken for Response {index}: {elapsed_time:.4f} seconds")
    print(f"tokens total: {len(tokenizer.encode(response))}")

Check the results of my benchmark here: #26464 (comment)

@vince62s

Did you see my other comment above wrt inference and rotating cache?

@younesbelkada
Contributor

younesbelkada commented Oct 10, 2023

I have the impression that in HF you implemented only the sliding windows attention by playing only on the attention mask and ONLY at training time, which means that at inference, the full length is taken into account, am I correct ?

If you use the vanilla HF attention, yes, that is the case: we did not implement the rotating buffer cache mechanism, as it requires a significant refactor.

However, for FA-2 models we tried to mimic the rotating buffer cache mechanism, restricted to the case where padding_side=left, by shifting the cache and slicing out the previous tokens when generating the next token. See my benchmarks here for more details: #26464 (comment)
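Both strategies retain the same entries (the last `window` positions); they differ in memory layout. Here is a toy pure-Python sketch where each cache entry is just its token position. Real caches hold per-layer key/value tensors, which this deliberately ignores, and the function names are hypothetical.

```python
def rotating_buffer(positions, window):
    """Rotating buffer: position p overwrites slot p % window in place."""
    slots = [None] * window
    for p in positions:
        slots[p % window] = p
    return slots


def shift_and_slice(positions, window):
    """Append each new entry, then keep only the last `window` entries."""
    cache = []
    for p in positions:
        cache.append(p)
        cache = cache[-window:]
    return cache


window = 4
rot = rotating_buffer(range(10), window)     # [8, 9, 6, 7]: slot order is rotated
sliced = shift_and_slice(range(10), window)  # [6, 7, 8, 9]: chronological order
assert sorted(rot) == sliced
```

The rotating buffer never moves memory but needs position bookkeeping to know which slot is oldest, whereas shift-and-slice keeps entries in order at the cost of copying. That copying and ordering is presumably why restricting the approach to left padding keeps the batched case tractable.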

@vince62s

OK, I get it, here: https://github.com/huggingface/transformers/pull/26464/files#diff-fa1653b47666859672060712644a8c40b2e61eb1b79c06a21f9b94569217ed43R372-R393
Anyway, it requires hardware that can support seqlen > 4096...

@younesbelkada
Contributor

younesbelkada commented Oct 10, 2023

yes exactly

Anyway it requires some hardware to support seqlen > 4096 ....

No, you can scale to very large sequence lengths, as the cache will always hold 4096 tokens, similar to the rotating buffer cache from the original Mistral repository.

Per my understanding (cc @timlacroix, please correct me if I am wrong), since we always use absolute positional embeddings, the model is able to keep the whole context even if we go beyond 4096 tokens.
If one feeds a very large context (>4096) to the model directly on the first iteration, you will indeed need enough compute, but since the FA module will use sliding window attention, it should be quite memory efficient. Slicing the cache afterwards is not a problem, since the model has already computed attention scores over the entire context on the first iteration, so the information is not lost.
Batched generation is slightly more complex, since we don't follow exactly the same procedure as Mistral's rotating buffer cache: we slice out the first tokens of the cache after the first iteration. But with BS=1 you should get pretty decent performance; if you have hardware that supports FlashAttention 2, I believe you can generate a very large number of tokens without any major issue.
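For a sense of scale, the key/value cache size is simple arithmetic once the cache length is fixed. This is a back-of-envelope sketch. The config numbers are assumptions copied from the public Mistral-7B-v0.1 config (32 layers, 8 key/value heads via grouped-query attention, head dim 4096 / 32 = 128), bf16 storage is assumed, and weights, activations and framework overhead are ignored.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """Bytes for the K and V caches: 2 tensors per layer, each of shape
    (batch_size, num_kv_heads, seq_len, head_dim)."""
    return (2 * num_layers * batch_size * num_kv_heads
            * seq_len * head_dim * bytes_per_elem)


# Assumed values from the public Mistral-7B-v0.1 config.json
cfg = dict(num_layers=32, num_kv_heads=8, head_dim=128)

capped = kv_cache_bytes(**cfg, seq_len=4096)     # cache capped at the window
uncapped = kv_cache_bytes(**cfg, seq_len=32768)  # cache grows with the sequence
print(capped / 2**20, "MiB")    # 512.0 MiB
print(uncapped / 2**30, "GiB")  # 4.0 GiB
```

The cache grows linearly with its length, so capping it at the window bounds this term regardless of how many tokens are generated.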

@vince62s

Hmm, the cache size is not the only limiting factor. You still need to forward the full sequence through the model, and Flash Attention 2 still runs over the full length, even if the mechanism makes it linear in length (rather than quadratic).

@younesbelkada
Contributor

But that's true in any case, right? For the first forward pass, if you pass a large context, you'll need to compute the attention scores over all tokens.
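Both points can be quantified by counting how many query/key scores the first forward pass computes. Below is a minimal pure-Python sketch that counts only scores and ignores constant factors and kernel details. With full causal attention the count grows quadratically. With a sliding window it grows linearly once the sequence is longer than the window, but every prompt token still has to be processed.

```python
def attended_pairs(seq_len, window=None):
    """Number of (query, key) scores under a causal mask,
    optionally restricted to a sliding window of `window` keys."""
    if window is None:
        return seq_len * (seq_len + 1) // 2
    return sum(min(i + 1, window) for i in range(seq_len))


for n in (4096, 8192, 16384):
    full = attended_pairs(n)
    windowed = attended_pairs(n, window=4096)
    print(n, full, windowed, round(full / windowed, 2))
```

Up to the window length the two counts are identical; beyond it, each extra token adds `window` scores with the sliding window versus a growing number without it.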

@teknium1
Author

teknium1 commented Oct 13, 2023

@younesbelkada @bdytx5 @vince62s @arthurmensch

Okay, an update on the issue.

[image: loss curves, DeepSpeed ZeRO-2 vs FSDP]

The image above compares DeepSpeed ZeRO-2 and FSDP, with all other hyperparams the same; ZeRO-2 is the run with the more stable trajectory. I believe I tested with ZeRO-3 in the past and saw the same U-shaped pattern as in the FSDP run, but I am not sure at the moment.

At the moment I don't know whether this is caused by axolotl's interaction with FSDP or by something in transformers/accelerate/who knows what. But this seems like an important development in figuring out what's going on. I'm not sure how much you can look into it, but I figured I'd post the info here in case it isn't axolotl's code.

edit: never mind...
[image]

However, it still looks far better than my loss curves on runs with much lower LRs than the one above (which uses 2.5e-5):
[image]

@teknium1
Author

OK, I did a new, longer run with DeepSpeed ZeRO-2 vs FSDP, all else the same:
[image: loss curves, DeepSpeed ZeRO-2 vs FSDP]

Something about FSDP is making it converge more slowly (technically, the loss is not moving downward at all; it drifts very, very slightly upward), with LR 4e-6.

@teknium1
Author

ZeRO-3 and ZeRO-2 seem fine; only FSDP shows the problem. I will reference the issue in the axolotl and PyTorch repos.

[image]

@nps798

nps798 commented Oct 15, 2023

[image]

For me, using transformers' Trainer with a custom dataset, a batch size of 2 and gradient accumulation of 6, the training loss drops to 0.0 after a certain point and the eval loss becomes NaN.
I am using torch_dtype=torch.float16.

I've seen someone suggest changing float16 to bfloat16?

@younesbelkada
Contributor

hi @nps798
Yes, I think using bfloat16 is preferable, to be on the safe side. Also, something strange I have noticed: if you train with padding tokens, make sure to set padding_side="right": https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da?permalink_comment_id=4636728#gistcomment-4636728
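The main reason bfloat16 is the safer choice is range. float16 overflows above 65504, while bfloat16 keeps float32's 8-bit exponent (maximum around 3.4e38) at the cost of precision. Below is a stdlib-only sketch. `to_fp16` uses Python's native half-precision `struct` format, while `to_bf16` is a hypothetical helper that emulates bfloat16 by truncating the low 16 bits of a float32, which approximates, but is not identical to, how frameworks usually convert (round to nearest even).

```python
import struct


def to_fp16(x):
    """Round-trip x through IEEE half precision; returns +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x):
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


for value in (1.0, 65504.0, 70000.0, 1e6):
    print(value, to_fp16(value), to_bf16(value))
```

A single activation or gradient slightly above 65504 becomes inf in float16 and typically turns into NaN a few operations later. In bfloat16 the same value is just rounded (70000.0 becomes 69632.0 here).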

@nps798

nps798 commented Oct 16, 2023


Thanks for your reply. I'll give it a try soon.

BTW, I have just encountered another issue with my previous setup (float16, left padding, QLoRA).
I've checked my input data around those batches (yes, I print out every batch at each step), and there is nothing weird or special.
I checked all the model's parameters with the following code:

import torch

# Scan every parameter of the (already loaded) model for NaN/Inf values
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f'NaN value detected in model weights: {name}')
    if torch.isinf(param).any():
        print(f'Infinity value detected in model weights: {name}')

Nothing was printed.

So, correct me if I am wrong: the following NaN is not coming from a problematic dataset.
It is related to some weights of the model being too small or too big, so the NaN would be produced with any dataset. And is it impossible to detect this beforehand?

input[0] has nans
output has nans

Detected inf/nan during batch_number=54681
Last 21 forward frames:
abs min abs max metadata
base_model.model.model.layers.30.mlp.gate_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 1.85e+02 input[0]
5.96e-08 3.02e+01 output
base_model.model.model.layers.30.mlp.act_fn SiLUActivation
5.96e-08 3.02e+01 input[0]
0.00e+00 2.39e+01 output
base_model.model.model.layers.30.mlp.up_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 1.85e+02 input[0]
5.96e-08 2.36e+01 output
base_model.model.model.layers.30.mlp.down_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 3.70e+02 input[0]
0.00e+00 1.38e+02 output
base_model.model.model.layers.30.mlp MistralMLP
0.00e+00 1.85e+02 input[0]
0.00e+00 1.38e+02 output
base_model.model.model.layers.30 MistralDecoderLayer
0.00e+00 3.05e+02 input[0]
0.00e+00 1.67e+02 output[0]
0.00e+00 1.68e+01 output[1][0]
0.00e+00 8.28e+00 output[1][1]
base_model.model.model.layers.31.input_layernorm MistralRMSNorm
8.36e-01 8.75e+00 weight
0.00e+00 1.67e+02 input[0]
0.00e+00 9.58e+01 output
base_model.model.model.layers.31.self_attn.q_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
base_model.model.model.layers.31.self_attn.q_proj.lora_A.default Linear
9.78e-08 1.07e-01 weight
0.00e+00 1.01e+02 input[0]
2.04e-03 8.38e+01 output
base_model.model.model.layers.31.self_attn.q_proj.lora_B.default Linear
1.98e-07 8.64e-02 weight
2.04e-03 8.38e+01 input[0]
2.06e-07 2.49e+01 output
base_model.model.model.layers.31.self_attn.q_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 2.62e+01 output
base_model.model.model.layers.31.self_attn.k_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
base_model.model.model.layers.31.self_attn.k_proj.lora_A.default Linear
2.39e-07 7.29e-02 weight
0.00e+00 1.01e+02 input[0]
6.44e-05 5.60e+01 output
base_model.model.model.layers.31.self_attn.k_proj.lora_B.default Linear
3.00e-07 6.73e-02 weight
6.44e-05 5.60e+01 input[0]
4.96e-07 1.24e+01 output
base_model.model.model.layers.31.self_attn.k_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 1.85e+01 output
base_model.model.model.layers.31.self_attn.v_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
base_model.model.model.layers.31.self_attn.v_proj.lora_A.default Linear
1.05e-07 1.07e-01 weight
0.00e+00 1.01e+02 input[0]
1.04e-03 5.54e+01 output
base_model.model.model.layers.31.self_attn.v_proj.lora_B.default Linear
7.20e-07 3.79e-02 weight
1.04e-03 5.54e+01 input[0]
7.59e-07 6.53e+00 output
base_model.model.model.layers.31.self_attn.v_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 8.99e+00 output
base_model.model.model.layers.31.self_attn.rotary_emb MistralRotaryEmbedding
0.00e+00 8.99e+00 input[0]
5.15e-05 1.00e+00 output[0]
0.00e+00 1.00e+00 output[1]
base_model.model.model.layers.31.self_attn.o_proj Linear4bit
0.00e+00 2.55e+02 weight
nan nan input[0]
nan nan output
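Two details in the dump above are worth spelling out. Every projection output (q/k/v) is finite, yet the input to `o_proj`, which is the attention output, is NaN, so the NaN appears to arise inside the attention computation itself, which has no module (and therefore no frame) of its own. Also, several `abs min` entries are 5.96e-08, which equals 2^-24, the smallest positive float16 subnormal; that suggests, though the thread does not confirm it, that these activations were computed in float16. The standard-library sketch below emulates float16 rounding with `struct`'s half-precision `"e"` format and shows one generic way finite inputs can end up as NaN: a value above the float16 maximum of 65504 becomes inf, and a softmax over a row containing inf returns NaN even with the usual max-subtraction. The q/k vectors are hypothetical and not taken from the dump.

```python
import math
import struct

FP16_MAX = 65504.0               # largest finite float16 value
FP16_MIN_SUBNORMAL = 2.0 ** -24  # smallest positive float16 value, ~5.96e-08


def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value; out-of-range magnitudes become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values too large for float16; an IEEE float16
        # tensor would hold inf instead, so return that.
        return math.copysign(math.inf, x)


def fp16_dot(q, k):
    """Dot product whose result is stored in float16 (accumulated in float64 here)."""
    return to_fp16(sum(to_fp16(a) * to_fp16(b) for a, b in zip(q, k)))


def stable_softmax(scores):
    """Max-subtracted softmax; any inf score turns the whole row into NaN (inf - inf = nan)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    print(to_fp16(5.96e-08) == FP16_MIN_SUBNORMAL)  # True: 5.96e-08 rounds to 2**-24
    q = [20.0] * 128  # hypothetical query vector
    k = [30.0] * 128  # hypothetical key vector: 128 * 20 * 30 = 76800 > 65504
    scores = [fp16_dot(q, k), 1.0]
    print(scores)                  # [inf, 1.0]
    print(stable_softmax(scores))  # [nan, nan]
```

This only illustrates a possible mechanism; which intermediate actually overflowed in the run above is not visible from the dump.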

@muximus3

muximus3 commented Oct 18, 2023

image
My training loss behaves strangely: it suddenly explodes at a different step in each training run. I tried to resolve this by following the instructions for mistral-7b-instruct, setting padding_side to "right" and pad_token to eos_token, but it didn't solve the problem. I use DeepSpeed ZeRO stage 3 and bfloat16.
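For readers unfamiliar with the setting mentioned above, here is a minimal, library-free sketch of what right (or left) padding with the pad token set to the EOS token does to a batch. The token ids are hypothetical (pad/EOS id 2 is chosen only for illustration). Masking labels by attention mask rather than by comparing ids to the pad id is one possible design choice, not what any particular trainer or collator is confirmed to do; it keeps a genuine EOS in the loss even though it shares the pad id. The value -100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_batch(
    seqs: List[List[int]], pad_id: int, side: str = "right"
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    """Pad token-id sequences to a common length.

    Returns (input_ids, attention_mask, labels). Labels are masked by position
    (attention_mask == 0), not by comparing ids to pad_id, so a real EOS token
    that shares the pad id still contributes to the loss.
    """
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    max_len = max(len(s) for s in seqs)
    input_ids, attention_mask, labels = [], [], []
    for s in seqs:
        n_pad = max_len - len(s)
        pad, real = [pad_id] * n_pad, [1] * len(s)
        if side == "right":
            ids, mask = s + pad, real + [0] * n_pad
        else:
            ids, mask = pad + s, [0] * n_pad + real
        input_ids.append(ids)
        attention_mask.append(mask)
        labels.append([t if m else IGNORE_INDEX for t, m in zip(ids, mask)])
    return input_ids, attention_mask, labels


if __name__ == "__main__":
    EOS_ID = 2  # hypothetical id; pad_token = eos_token means pad_id == EOS_ID
    ids, mask, labels = pad_batch([[1, 15, 16, 2], [1, 17, 2]], pad_id=EOS_ID)
    print(ids)     # [[1, 15, 16, 2], [1, 17, 2, 2]]
    print(mask)    # [[1, 1, 1, 1], [1, 1, 1, 0]]
    print(labels)  # [[1, 15, 16, 2], [1, 17, 2, -100]]
```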

@nps798

nps798 commented Oct 18, 2023

@younesbelkada thank you.
I set the torch dtype to bf16 (while keeping the padding on the left).

I successfully ran QLoRA fine-tuning for 5 epochs without exploding loss or zero loss.

I will keep experimenting with other combinations of parameters.
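One possible reason switching the dtype to bf16 helped: bfloat16 keeps float32's 8-bit exponent, so its range extends to about 3.4e38, but it has only 8 bits of significand precision, whereas float16 overflows above 65504. The standard-library sketch below emulates both formats. The bf16 conversion (round-to-nearest-even on the top 16 bits of a float32) is our own emulation, not a library API, and the input value is hypothetical.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest float16; magnitudes too large for float16 become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 by keeping the top 16 bits of the float32 encoding.

    Uses round-to-nearest-even on the 16 dropped bits. Assumes x is finite and
    within float32 range (struct raises OverflowError otherwise).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


if __name__ == "__main__":
    value = 76800.0  # hypothetical activation just above the float16 maximum
    print(to_fp16(value))          # inf
    print(to_bf16(value))          # 76800.0 (representable in bfloat16)
    print(to_bf16(1.0 + 2 ** -8))  # 1.0: bfloat16 trades precision for range
```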

@teknium1
Author

teknium1 commented Oct 18, 2023

I can now confirm that at least two other people have this issue with FSDP. However, I still see the loss go up after the per-epoch drops in my training runs with DeepSpeed as well, which leaves me concerned, though in a better state than before, when I always got U-shaped loss curves.
image
image

@younesbelkada
Contributor

Hi everyone
Thanks a lot for the deep investigation. Recently @pacman100 managed to successfully fine-tune Llama using FSDP (from what I have understood, the issue is largely agnostic to the architecture) and shared some insights here: huggingface/accelerate#2127 (comment)
It seems the solution is not to load the model in bf16, but instead to enable mixed-precision training through TrainingArguments by passing bf16=True. cc @pacman100 in case I missed something
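One reason mixed-precision training keeps a float32 copy of the weights (roughly what loading the model in float32 and passing `bf16=True` amounts to) is that, when the weights themselves are stored in bfloat16, an update much smaller than the weight's own rounding step is rounded away. The standard-library sketch below emulates this with a plain additive update. The weight value, update size and step count are hypothetical, the bf16 rounding is our own emulation, and real optimizers such as AdamW differ in detail.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 (round-to-nearest-even on the top 16 bits of float32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]


def apply_updates(weight: float, update: float, steps: int, store) -> float:
    """Add `update` to `weight` `steps` times, re-rounding to the storage dtype after each step."""
    w = store(weight)
    for _ in range(steps):
        w = store(w + update)
    return w


if __name__ == "__main__":
    # Hypothetical numbers: weight 1.0, per-step update 1e-4, 100 steps.
    # bf16 spacing near 1.0 is 2**-7, so an update of 1e-4 is rounded away every time.
    print(apply_updates(1.0, 1e-4, 100, to_bf16))  # 1.0
    print(apply_updates(1.0, 1e-4, 100, to_fp32))  # close to 1.01
```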

@jph00

jph00 commented Nov 15, 2023

Thanks a lot for the deep investigation, recently @pacman100 managed to successfully fine-tune llama (from what I have understood the issue is quite agnostic to the architecture) using FSDP and shared some insights here

I think this was a misunderstanding, and it is actually not training successfully. However, @tmabraham did show a workaround in that thread.

@pacman100
Contributor

Hello,

I ran the experiment below to check that fine-tuning Mistral with FSDP works as expected. Here are the results:

  1. Codebase: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training
  2. Dataset: smangrul/chat-instruct-mixer
  3. Model: mistralai/Mistral-7B-v0.1
  4. Accelerate config after running accelerate config --config_file fsdp_config.yaml and answering the questionnaire:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  5. Command:
accelerate launch \
    --config_file configs/fsdp_config.yaml \
    train.py \
    --model_name "mistralai/Mistral-7B-v0.1" \
    --dataset_name "smangrul/chat-instruct-mixer" \
    --max_seq_len 4096 \
    --max_steps 5000 \
    --logging_steps 25 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --bf16 True \
    --packing True \
    --output_dir "/fsx/sourab/experiments/full-finetune-mistral-7b-fsdp-chat-asst" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataset_text_field "content" \
    --use_gradient_checkpointing False \
    --learning_rate 5e-6  \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --use_flash_attn True
  6. Training plots at the end of 1000 steps:
Screenshot 2023-11-16 at 1 59 09 PM
  7. Observations:
    a. The loss goes down as expected and training is successful.
    b. Sensitivity to learning rate: with learning rates of 5e-5 or 2e-5, training did not converge properly; 5e-6 worked best for my dataset. So when fully fine-tuning, hyperparameter tuning is important.
    c. A sequence length of 4096 with batch size 8 (per-GPU batch size 1, gradient accumulation steps 1) gives lower loss than a sequence length of 2048 with batch size 16 (per-GPU batch size 1, gradient accumulation steps 2).
  8. Library versions:
  • Output of transformers-cli env:
- `transformers` version: 4.35.2
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.11.4
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.2
- Accelerate version: 0.24.1
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.1.0.dev20230809 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
  • Output of accelerate env:
- `Accelerate` version: 0.24.1
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.11.4
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.1.0.dev20230809 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1121.82 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
	Not found
  • flash-attn: 2.3.3
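The two configurations compared in observation c process the same number of tokens per optimizer step, which the arithmetic below makes explicit (8 processes, from `num_processes: 8` in the config above). The sketch also evaluates a linear-warmup plus cosine-decay schedule for the command's `--learning_rate 5e-6`, `--warmup_ratio 0.03` and `--max_steps 5000`. It assumes the training script uses the same formula as `transformers`' `get_cosine_schedule_with_warmup`, with warmup steps computed as `ceil(warmup_ratio * max_steps)`; that has not been checked against the linked codebase.

```python
import math


def tokens_per_optimizer_step(seq_len: int, per_device_batch: int, grad_accum: int, num_processes: int) -> int:
    """Tokens seen per optimizer step, assuming packed (fully filled) sequences."""
    return seq_len * per_device_batch * grad_accum * num_processes


def cosine_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0, then half-cosine decay to 0 (single cycle)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    print(tokens_per_optimizer_step(4096, 1, 1, 8))  # 32768 (batch size 8)
    print(tokens_per_optimizer_step(2048, 1, 2, 8))  # 32768 (batch size 16)
    warmup = math.ceil(0.03 * 5000)  # 150 warmup steps
    for step in (0, 75, warmup, 2575, 5000):
        print(step, cosine_lr(step, 5e-6, warmup, 5000))
```

Since tokens per step are equal, the loss difference in observation c is not explained by a larger token budget per step; the settings differ in sequence length and in the number of (packed) sequences per step.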


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@teknium1
Author

Is this solved by the previous comment?
