Interesting Problems of Accuracy & Inference Speed with run_eval_needle.sh #70

Treemann opened this issue Apr 18, 2024 · 0 comments


Hi, I prepared the pg19 dataset and ran run_eval_needle.sh with different parameter settings.
I have some questions about the experimental results and hope someone can help.

Device: NVIDIA V100/A100 GPUs
Script:

#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export llama_tokenizer_path="weights/LWM-Text-Chat-1M-Jax/tokenizer.model"
export lwm_text_checkpoint="weights/LWM-Text-Chat-1M-Jax/params"
export haystack_file="data/pg19.jsonl"
export output_file="eval_needle_1m_jax.log"

export CUDA_VISIBLE_DEVICES=0,1,2,3

# chunk size and on/off switch for blockwise (scan) attention and MLP
chunk_size=1024
ctx_len=6144
use_block=True
python3 -u scripts/eval_needle.py \
    --mesh_dim='!1,1,2,2' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(theta=50000000,max_sequence_length=1048576,scan_attention=${use_block},scan_query_chunk_size=${chunk_size},scan_key_chunk_size=${chunk_size},scan_mlp=${use_block},scan_mlp_chunk_size=${chunk_size},scan_layers=True)" \
    --load_checkpoint="params::$lwm_text_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    --output_file="$output_file" \
    --haystack_file="$haystack_file" \
    --max_tokens_per_batch=5000 \
    --context_lengths_min=${ctx_len} \
    --context_lengths_max=${ctx_len} \
    --n_context_length_intervals=1 \
    --n_document_depth_intervals=2 \
    --n_rounds=3
read
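For reference, my understanding of the four entries in mesh_dim is that they correspond to data, FSDP, tensor, and sequence axes, in that order (this is an assumption based on how the argument is passed; the axis names 'dp'/'fsdp'/'tp'/'sp' below are placeholders and may differ from the actual names in the LWM code). A minimal JAX sketch of sharding an activation over such a mesh:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed mapping of mesh_dim='1,1,2,2' onto 4 local devices:
# dp=1, fsdp=1, tp=2, sp=2 (axis names are hypothetical).
mesh_shape = (1, 1, 2, 2)
devices = np.array(jax.devices()[:4]).reshape(mesh_shape)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "tp", "sp"))

# An activation of shape (batch, seq_len, hidden).
x = np.zeros((1, 8192, 4096), dtype=np.float32)

# With sequence parallelism the sequence axis is split over 'sp'
# and the hidden axis over 'tp'; batch is split over dp/fsdp.
sharding = NamedSharding(mesh, P(("dp", "fsdp"), "sp", "tp"))
x_sharded = jax.device_put(x, sharding)
print(x_sharded.sharding)  # shows how the array is laid out across devices
```

If that mapping is right, '1,1,2,2' splits the sequence across two devices while '1,2,2,1' leaves it unsplit, which is what I mean by "sequence parallelism" in the questions below.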

Some questions about the experimental results:

  1. The model produced correct predictions with mesh_dim='1,2,2,1' and dtype=fp32 or bf16. However, when sequence parallelism was enabled with mesh_dim='1,1,2,2', the predictions were wrong.

| mesh_dim | fp32    | fp64  | fp16  | bf16    |
|----------|---------|-------|-------|---------|
| 1,2,2,1  | correct | wrong | wrong | correct |
| 1,1,2,2  | wrong   | wrong | wrong | wrong   |

On further investigation, I found that the attention outputs began to contain NaN values from some middle layer onward, and the final outputs of the network were all NaN (a small helper for localizing the first NaN is sketched after the log below).

...
attn_output shape: (array(1, dtype=int32), array(8192, dtype=int32), array(4096, dtype=int32)),
attn_output: [[[-0.00195352 -0.00402817 0.00310729 ... 0.00402906 -0.00059524
-0.00334746]
[-0.00195352 -0.00402817 0.00310729 ... 0.00402906 -0.00059524
-0.00334746]
[-0.00195352 -0.00402817 0.00310729 ... 0.00402906 -0.00059524
-0.00334746]
...
[ nan nan nan ... nan nan
nan]
[ nan nan nan ... nan nan
nan]
[ nan nan nan ... nan nan
nan]]]
...
attn_output shape: (array(1, dtype=int32), array(1, dtype=int32), array(4096, dtype=int32)), attn_output: [[[nan nan nan ... nan nan nan]]]
...
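A small helper like the following could be used to localize the first layer that produces NaNs (a sketch; report_nans and the call site are hypothetical additions, not part of the LWM code):

```python
import jax
import jax.numpy as jnp

def report_nans(name, x):
    """Count NaN entries in a tensor; jax.debug.print works inside jit/scan."""
    jax.debug.print("{name}: nan_count={n} / {total}",
                    name=name, n=jnp.sum(jnp.isnan(x)), total=x.size)
    return x

# Hypothetical call site inside the transformer block:
# attn_output = report_nans("attn_output", attn_output)
```

Enabling jax.config.update("jax_debug_nans", True) is another option for trapping the first op that produces a NaN, though it may not interact cleanly with scan/remat.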

  2. Inference was very slow with mesh_dim='1,1,2,2'. By counting the printed attention outputs, I found that FlaxLLaMABlock was invoked far more often with mesh_dim='1,1,2,2' than with mesh_dim='1,2,2,1' (a rough wall-clock timing sketch follows the table):

| mesh_dim | FlaxLLaMABlock invocations |
|----------|----------------------------|
| 1,2,2,1  | 544                        |
| 1,1,2,2  | 65568                      |
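Since the number of printed blocks may not translate directly into wall-clock cost, a rough timing comparison between the two mesh settings might be more telling; a minimal sketch (apply_fn, params, and batch are placeholders for the actual LWM forward call):

```python
import time
import jax

def time_forward(apply_fn, params, batch, n_iters=3):
    """Average wall-clock time of a jitted forward pass.

    block_until_ready forces JAX's asynchronous dispatch to finish
    before the clock stops; the first call (compilation) is excluded.
    """
    jitted = jax.jit(apply_fn)
    jax.block_until_ready(jitted(params, batch))   # warm-up / compile
    start = time.perf_counter()
    for _ in range(n_iters):
        out = jitted(params, batch)
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / n_iters

# Hypothetical usage, once per mesh_dim setting:
# print(time_forward(model.apply, params, batch))
```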

Has anyone had similar findings, and could you share some insight into these questions?
