Using Accelerate with TPU Pod VM like v3-32 #471

Closed
Flamentt opened this issue Jun 24, 2022 · 6 comments · Fixed by #1049
Assignees
Labels
bug Something isn't working feature request Request for a new feature to be added to Accelerate TPU Bug or feature on TPU platforms

Comments

@Flamentt

Flamentt commented Jun 24, 2022

Hi, thank you for the great library.

I have just installed accelerate on a TPU VM v3-32, but when I set the number of TPU cores to 32 with accelerate config and run accelerate test, it throws an error:

ValueError: The number of devices must be either 1 or 8, got 32 instead

So it seems Accelerate does not yet support training on a TPU Pod VM. Can you please add this feature to Accelerate?

By the way, I ran into another problem, too. If I use accelerate==0.9 with a TPU VM v2-alpha, accelerate test runs successfully. But if I use accelerate==0.10 with v2-alpha, tpu-vm-pt-1.11, or tpu-vm-pt-1.10, accelerate test never finishes; it just runs forever.

And when I run

accelerate launch run_clm_no_trainer.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --model_name_or_path gpt2 \
    --output_dir /tmp/test-clm

it throws some errors (even with accelerate==0.9 on a TPU VM v2-alpha).

06/24/2022 18:10:16 - INFO - run_clm_no_trainer - ***** Running training *****
06/24/2022 18:10:16 - INFO - run_clm_no_trainer -   Num examples = 2318
06/24/2022 18:10:16 - INFO - run_clm_no_trainer -   Num Epochs = 3
06/24/2022 18:10:16 - INFO - run_clm_no_trainer -   Instantaneous batch size per device = 8
06/24/2022 18:10:16 - INFO - run_clm_no_trainer -   Total train batch size (w. parallel, distributed & accumulation) = 64
06/24/2022 18:10:16 - INFO - run_clm_no_trainer -   Gradient Accumulation steps = 1
06/24/2022 18:10:16 - INFO - run_clm_no_trainer -   Total optimization steps = 111
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 16.44ba/s]
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 16.31ba/s]
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 16.66ba/s]
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 16.12ba/s]
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 15.94ba/s]
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 15.75ba/s]
Grouping texts in chunks of 1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:02<00:00, 14.59ba/s]
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 17.02ba/s]
Grouping texts in chunks of 1024:  50%|███████████████████████████████████████████████████                                                   | 2/4 [00:00<00:00, 14.53ba/s]2022-06-24 18:10:19.812027: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:19.812100: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 17.28ba/s]
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 16.89ba/s]
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 15.94ba/s]
Grouping texts in chunks of 1024:  50%|███████████████████████████████████████████████████                                                   | 2/4 [00:00<00:00, 14.34ba/s]2022-06-24 18:10:20.217092: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:20.217159: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-06-24 18:10:20.223097: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:20.223158: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-06-24 18:10:20.231867: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:20.231934: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 16.53ba/s]
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 16.28ba/s]
Grouping texts in chunks of 1024: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 14.42ba/s]
2022-06-24 18:10:20.468890: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:20.468975: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-06-24 18:10:20.474551: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:20.474636: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-06-24 18:10:20.509402: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-06-24 18:10:20.509462: E tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
  1%|█▏                                                                                                                                    | 1/111 [00:06<12:12,  6.66s/it]2022-06-24 18:11:19.419635: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)
https://symbolize.stripped_domain/r/?trace=7f147ec0c18b,7f147ec0c20f,7f13cd4ff64f,7f13c833ec97,7f13c8333b01,7f13c835429e,7f13c8353e0b,7f13c4f6793d,7f13c98422a8,7f13ccff5580,7f13ccff7943,7f13cd4d0f71,7f13cd4d07a0,7f13cd4ba32b,7f147ebac608&map=c5ea6dcea9ec73900e238cf37efee14d75fd7749:7f13c06a5000-7f13d0013e28 
*** SIGABRT received by PID 26683 (TID 28667) on cpu 14 from PID 26683; stack trace: ***
PC: @     0x7f147ec0c18b  (unknown)  raise
    @     0x7f120bb881e0        976  (unknown)
    @     0x7f147ec0c210       3968  (unknown)
    @     0x7f13cd4ff650         16  tensorflow::internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f13c833ec98        592  tensorflow::tpu::TpuProgramGroup::Initialize()
    @     0x7f13c8333b02       1360  tensorflow::tpu::TpuCompilationCacheExternal::InitializeEntry()
    @     0x7f13c835429f        800  tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsentHelper()
    @     0x7f13c8353e0c        128  tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsent()
    @     0x7f13c4f6793e        944  tensorflow::XRTCompileOp::Compute()
    @     0x7f13c98422a9        432  tensorflow::XlaDevice::Compute()
    @     0x7f13ccff5581       2080  tensorflow::(anonymous namespace)::ExecutorState<>::Process()
    @     0x7f13ccff7944         48  std::_Function_handler<>::_M_invoke()
    @     0x7f13cd4d0f72        128  Eigen::ThreadPoolTempl<>::WorkerLoop()
    @     0x7f13cd4d07a1         48  tensorflow::thread::EigenEnvironment::CreateThread()::{lambda()#1}::operator()()
    @     0x7f13cd4ba32c         80  tensorflow::(anonymous namespace)::PThread::ThreadFn()
    @     0x7f147ebac609  (unknown)  start_thread
https://symbolize.stripped_domain/r/?trace=7f147ec0c18b,7f120bb881df,7f147ec0c20f,7f13cd4ff64f,7f13c833ec97,7f13c8333b01,7f13c835429e,7f13c8353e0b,7f13c4f6793d,7f13c98422a8,7f13ccff5580,7f13ccff7943,7f13cd4d0f71,7f13cd4d07a0,7f13cd4ba32b,7f147ebac608&map=c5ea6dcea9ec73900e238cf37efee14d75fd7749:7f13c06a5000-7f13d0013e28,ca1b7ab241ee28147b3d590cadb5dc1b:7f11fee89000-7f120bebbb20 
E0624 18:11:19.687595   28667 coredump_hook.cc:292] RAW: Remote crash data gathering hook invoked.
E0624 18:11:19.687634   28667 coredump_hook.cc:384] RAW: Skipping coredump since rlimit was 0 at process start.
E0624 18:11:19.687656   28667 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0624 18:11:19.687666   28667 coredump_hook.cc:447] RAW: Sending fingerprint to remote end.
E0624 18:11:19.687679   28667 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0624 18:11:19.687727   28667 coredump_hook.cc:451] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0624 18:11:19.687735   28667 coredump_hook.cc:525] RAW: Discarding core.
E0624 18:11:19.966672   28667 process_state.cc:771] RAW: Raising signal 6 with default behavior

Can you please tell me which TPU VM version you usually use with Accelerate?

Thank you!

@muellerzr
Collaborator

Thanks for this report @huunguyen10, I'll look into this further.

As for how we run tests, we use Colab's v2 VM.

Re: your ValueError, can you provide the full stack trace for me to look at? I think I know what the problem is, but that would be much appreciated!

I'll look into the issue on v2-alpha; it may be a torch issue. We'll also see about setting up a v3-32 instance to test with.

@muellerzr muellerzr added bug Something isn't working TPU Bug or feature on TPU platforms labels Jun 24, 2022
@Flamentt
Author

Flamentt commented Jun 24, 2022

Thank you @muellerzr!

Here is the error I encountered:

nguyen@t1v-n-1b19a50e-w-0:~$ accelerate config
In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 0
Which type of machine are you using? ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU): 3
What is the name of the function in your script that should be launched in all parallel scripts? [main]: main
How many TPU cores should be used for distributed training? [1]:32
nguyen@t1v-n-1b19a50e-w-0:~$ accelerate test

Running:  accelerate-launch --config_file=None /usr/local/lib/python3.8/dist-packages/accelerate/test_utils/test_script.py
stderr: Traceback (most recent call last):
stderr:   File "/usr/local/bin/accelerate-launch", line 8, in <module>
stderr:     sys.exit(main())
stderr:   File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 574, in main
stderr:     launch_command(args)
stderr:   File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 564, in launch_command
stderr:     tpu_launcher(args)
stderr:   File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 394, in tpu_launcher
stderr:     xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
stderr:   File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 384, in spawn
stderr:     pf_cfg = _pre_fork_setup(nprocs)
stderr:   File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 199, in _pre_fork_setup
stderr:     raise ValueError(
stderr: ValueError: The number of devices must be either 1 or 8, got 32 instead
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/accelerate_cli.py", line 43, in main
    args.func(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/test.py", line 52, in test_command
    result = execute_subprocess_async(cmd, env=os.environ.copy())
  File "/usr/local/lib/python3.8/dist-packages/accelerate/test_utils/testing.py", line 276, in execute_subprocess_async
    raise RuntimeError(
RuntimeError: 'accelerate-launch --config_file=None /usr/local/lib/python3.8/dist-packages/accelerate/test_utils/test_script.py' failed with returncode 1

The combined stderr from workers follows:
Traceback (most recent call last):
  File "/usr/local/bin/accelerate-launch", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 574, in main
    launch_command(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 564, in launch_command
    tpu_launcher(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 394, in tpu_launcher
    xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 384, in spawn
    pf_cfg = _pre_fork_setup(nprocs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 199, in _pre_fork_setup
    raise ValueError(
ValueError: The number of devices must be either 1 or 8, got 32 instead

I used a TPU VM v2-alpha, and the above error happened with both accelerate 0.9 and 0.10.
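
For context, the ValueError comes from torch_xla's per-host process model: xmp.spawn only ever launches processes for the TPU cores attached to a single host (1 or 8 on v2/v3), so a v3-32 slice (4 hosts with 8 cores each) cannot be driven with nprocs=32 from one process. A pod launcher such as xla_dist has to start an 8-process group on every host instead. Below is a minimal sketch of that per-host model, not the eventual Accelerate fix; the function name and print statement are illustrative only.

# Minimal sketch of the per-host launch model on a TPU pod slice.
# Each host runs this program (e.g. via xla_dist) and spawns 8 local processes;
# global rank and world size are resolved across hosts by torch_xla itself.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # `index` is the local ordinal on this host (0-7); xm.get_ordinal() and
    # xm.xrt_world_size() give the global rank and world size across the pod.
    device = xm.xla_device()
    print(f"local {index}, global {xm.get_ordinal()}/{xm.xrt_world_size()}, device {device}")

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)  # 8 per host, never 32 for the whole slice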

@sumanthd17

Were you able to get past this issue? @huunguyen10

@Ontopic

Ontopic commented Aug 4, 2022

Would love to know what the follow-up on this is as well. Also see sumanthd17's issue.

@huggingface huggingface deleted a comment from github-actions bot Aug 30, 2022
@muellerzr muellerzr added the feature request Request for a new feature to be added to Accelerate label Aug 30, 2022
@muellerzr
Collaborator

We're going to keep this issue and the linked issue below open regarding TPU pods; see Sylvain's and my last note there for more information on what's currently happening and where we're at with it: #501 (comment)

@muellerzr
Collaborator

This has now been introduced in #1049. Please follow the new accelerate config command to set this up. Below are some directions:

  1. Install accelerate via `pip install git+https://github.com/huggingface/accelerate` (and ensure each node has this installed as well)
  2. Very important: either torch_xla needs to be installed from git, or you need to run `wget https://raw.githubusercontent.com/pytorch/xla/master/torch_xla/distributed/xla_dist.py -O /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py` on the host node only; that should be all that is needed, I believe. If not, use the tpu-config option or add it to the startup command (we rely on that refactor of xla_dist to launch)
  3. Run accelerate config on the host node and configure it accordingly
  4. Depending on how the system is set up, it may require sudo pip install. If so, answer True when accelerate config asks about this, and run sudo accelerate config instead. (I hit some permission issues; this has been my workaround for now)
  5. Download the script you wish to run into /usr/share/some_script
  6. Run accelerate launch /usr/share/some_script.py

The example script I use is located here:
https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/bbdfff6868cbf61fcc0dcff8b76fe64b06fe43ab/xla_script.py
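
For reference, a script of this kind is essentially a standard Accelerate training loop; here is a minimal sketch under that assumption (the model, data, and hyperparameters below are placeholders, not the contents of the gist):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

def main():
    # Accelerator() picks up the pod/TPU settings written by `accelerate config`.
    accelerator = Accelerator()

    # Placeholder model and synthetic data, purely for illustration.
    model = torch.nn.Linear(128, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 2, (1024,)))
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # prepare() moves everything to each process's XLA device and shards the
    # dataloader across all processes in the pod.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    model.train()
    for epoch in range(3):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(inputs), labels)
            accelerator.backward(loss)
            optimizer.step()
        accelerator.print(f"epoch {epoch} finished")

if __name__ == "__main__":
    main()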

We have also introduced a tpu-config command which will run commands across the pod workers, so instead of relying on a startup script to install everything you could run:
accelerate tpu-config --command "sudo wget https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/bbdfff6868cbf61fcc0dcff8b76fe64b06fe43ab/xla_script.py -O /usr/share/xla_script.py"
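
The same tpu-config command can also cover the "ensure each node has this installed" part of step 1; for example, something along these lines should work, assuming pip and sudo are available on each worker:
accelerate tpu-config --command "sudo pip install git+https://github.com/huggingface/accelerate"
After that, running accelerate launch /usr/share/xla_script.py from the host node starts the run across the whole pod.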
