
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method #940

Closed
clam004 opened this issue Dec 22, 2022 · 31 comments
Assignees
Labels
bug (Something isn't working), solved (The bug or feature request has been solved, but the issue is still opened)

Comments

@clam004

clam004 commented Dec 22, 2022

I'm using Python 3.8, PyTorch 1.12.1, on Ubuntu 20.04.4 LTS (GNU/Linux 5.15.0-1029-azure x86_64), trying to use 2 V100 GPUs with CUDA 11.6 from within my notebook via the notebook_launcher:

notebook_launcher(training_loop, args, num_processes=2)

@muellerzr
Collaborator

I need much more information than this; if possible, please provide either the contents of each cell or the notebook itself. I can't do anything with what you've provided :)

@clam004
Author

clam004 commented Dec 22, 2022

Wow, what great response time! Here are the cells:

Cell 0:

import numpy as np

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm

from transformers import GPT2Tokenizer
from transformers import TrainingArguments

from accelerate import Accelerator
from accelerate import notebook_launcher

from myalgorithms import GPT2HeadWithValueModel # from https://github.com/lvwerra/trl
from mydatasets import FineTuneDataset

Cell 1:

train_dataset = torch.load('large_data/train_dataset.pth')

Cell 2:

def training_loop(num_train_epoch, per_device_train_batch_size, gradient_accumulation_steps):
    
    default_args = {
        "output_dir": "tmp",
        "evaluation_strategy": "steps",
        "num_train_epochs": num_train_epoch,
        "log_level": "error",
        "report_to": "none",
    }
    
    model = GPT2HeadWithValueModel.from_pretrained("gpt2")

    training_args = TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=True,
        fp16=True,
        **default_args,
    )

    train_dataloader = DataLoader(train_dataset, batch_size=training_args.per_device_train_batch_size)

    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=5e-5, 
        momentum=0.1,
        weight_decay=0.1,
        nesterov=True,
    )

    accelerator = Accelerator(fp16=training_args.fp16)

    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    model.train()
    for step, batch in enumerate(tqdm(train_dataloader), start=1):

        inputs = torch.tensor(batch['input_ids'])
        inputs = inputs.to(model.device)

        outputs = model(inputs, return_dict=True)

        logits = outputs.logits

        shift_labels = inputs[..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()

        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss / training_args.gradient_accumulation_steps

        accelerator.backward(loss)

        if step % training_args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

Cell 3:

num_train_epoch, per_device_train_batch_size, gradient_accumulation_steps = 1,2,4
args = (num_train_epoch, per_device_train_batch_size, gradient_accumulation_steps)
notebook_launcher(training_loop, args, num_processes=2)

@lewtun
Member

lewtun commented Dec 23, 2022

FYI we're also seeing a similar error reported by participants in the DreamBooth Hackathon when using T4 x 2 GPUs on Kaggle.

Steps to reproduce

  1. Open the Kaggle notebook (I simplified it to the essential steps).
  2. Select the T4 x 2 GPU accelerator, install the dependencies, and restart the notebook (Kaggle has an old version of torch preinstalled).

[Screenshot: dependency installation cell, 2022-12-23]

  3. Run all remaining cells.

Here's the output from accelerate env:

- `Accelerate` version: 0.15.0
- Platform: Linux-5.15.65+-x86_64-with-debian-bullseye-sid
- Python version: 3.7.12
- Numpy version: 1.21.6
- PyTorch version (GPU?): 1.11.0 (True)
- `Accelerate` default config:
	Not found

@muellerzr
Collaborator

As I'm on vacation and can't look at this thoroughly, I'll give you both a TL;DR of how to investigate it :)

After every single cell, or at the end of it, include the following:

torch.cuda.is_initialized()

This must always be False until you call notebook_launcher(...).

If anything initializes CUDA, then whatever triggers it must be moved into the training function you pass into notebook_launcher. HTH!
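
For example, a check like this at the bottom of a cell (just a minimal diagnostic sketch):

import torch

# Run this at the end of every cell that comes before notebook_launcher(...).
# It must print False every time; True means that cell (or one of its imports)
# already initialized CUDA in the parent process, which is what breaks the forked workers.
print(torch.cuda.is_initialized())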

@muellerzr
Collaborator

Another thing to watch out for, which @lewtun pointed out to me: make sure none of the libraries you are using, @clam004, are trying to initialize CUDA later on in their library code. If you check cell by cell and find that nothing in your own code is doing it, that means a library is initializing CUDA itself, and that's not an issue we can necessarily fix.

@muellerzr muellerzr self-assigned this Dec 23, 2022
@muellerzr muellerzr added the bug Something isn't working label Dec 23, 2022
@clam004
Author

clam004 commented Dec 28, 2022

Thanks @muellerzr, I did as you said: for every cell above

args = (model, tokenizer, config)
notebook_launcher(training_loop, args, num_processes=2)

I have verified that torch.cuda.is_initialized() returns False at the end of the cell.

Unfortunately, I am still getting the same Cannot re-initialize CUDA in forked subprocess. error from the title.

File ~/path/python3.8/site-packages/accelerate/launchers.py:122, in notebook_launcher(function, args, num_processes, use_fp16, mixed_precision, use_port)
    119         launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
    121         print(f"Launching training on {num_processes} GPUs.")
--> 122         start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
    124 else:
    125     # No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
    126     use_mps_device = "false"

File ~/path/python3.8/site-packages/torch/multiprocessing/spawn.py:198, in start_processes(fn, args, nprocs, join, daemon, start_method)
    195     return context
    197 # Loop on join until it returns True or raises an exception.
--> 198 while not context.join():
    199     pass

File ~/path/python3.8/site-packages/torch/multiprocessing/spawn.py:160, in ProcessContext.join(self, timeout)
    158 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
    159 msg += original_trace
--> 160 raise ProcessRaisedException(msg, error_index, failed_process.pid)

ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/path/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/path/python3.8/site-packages/accelerate/utils/launch.py", line 97, in __call__
    self.launcher(*args)
  File "/tmp/ipykernel_49172/1813046040.py", line 15, in training_loop
    accelerator = Accelerator(mixed_precision='fp16')
  File "/path/python3.8/site-packages/accelerate/accelerator.py", line 308, in __init__
    self.state = AcceleratorState(
  File "python3.8/site-packages/accelerate/state.py", line 156, in __init__
    torch.cuda.set_device(self.device)
  File "python3.8/site-packages/torch/cuda/__init__.py", line 313, in set_device
    torch._C._cuda_setDevice(device)
  File "/path/python3.8/site-packages/torch/cuda/__init__.py", line 206, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

OK, thanks for your help and enjoy your vacation; I'll check back here in Jan 2023.

@muellerzr muellerzr added the wip Work in progress label Jan 22, 2023
@huggingface huggingface deleted a comment from github-actions bot Jan 22, 2023
@leemengtw

leemengtw commented Jan 25, 2023

Facing the same error even though assert not torch.cuda.is_initialized() passes right before notebook_launcher(..., num_processes=8).

I'm using accelerate==0.15.0 with Python 3.10; any suggestion is highly appreciated! 🙏

Additional version Info:

torch          : 1.13.1
torchvision    : 0.14.1
diffusers      : 0.11.1
accelerate     : 0.15.0
xformers       : 0.0.16.dev432+git.bc08bbc

@rohan-gangaraju

Was facing this issue due to a device variable that was being initialized in the first cell:

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Commenting out this line resolved the issue.

FYI: torch.cuda.is_initialized() returned False despite the initialization that was happening, so it's probably best to manually check for initialization code like the above.
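
A minimal sketch of that pattern (illustrative only, not my exact notebook): pick the device inside the launched function instead of in a top-level cell.

from accelerate import Accelerator, notebook_launcher

def training_loop():
    # Select the device only inside the launched worker, e.g. via the accelerator,
    # so the parent process that calls notebook_launcher never touches CUDA.
    accelerator = Accelerator()
    device = accelerator.device
    accelerator.print(f"process {accelerator.process_index} running on {device}")

notebook_launcher(training_loop, num_processes=2)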

@graldij

graldij commented Apr 17, 2023

Hi all.
I get the same error using multiprocessing in this example notebook (run locally on a SLURM server), i.e. with num_processes=2. Any suggestion is welcome!

@willlllllio

Had the same error; putting all my imports inside my training_loop function fixed it, so my guess is that some imported library did something to trigger the CUDA initialization.

@hikkilover

Same error. Try launching your script with "accelerate launch" on the CLI; do not use "notebook_launcher()" inside the script (follow the official guideline: https://github.com/huggingface/diffusers/blob/3b641eabe9876e7c48977b35331fda54ce972b4a/examples/unconditional_image_generation/README.md).
It works for me.

@Liweixin22

I'm using T4 x 2 GPUs on Kaggle and facing the same error.

@muellerzr
Collaborator

@Liweixin22 ensure that you haven't called anything that touches CUDA before running notebook_launcher. If there are still issues, please let us know which notebook you are trying to run.

@Liweixin22

Liweixin22 commented May 21, 2023

@muellerzr Here is the link to my notebook:
https://www.kaggle.com/code/liweixin23/ml2023spring-hw2

@stymbhrdwj

I am experiencing the same issue in the official diffusers training example notebook here.

I downloaded the notebook and ran it on my system with a dual-GPU setup and set num_processes=2. I had generated the accelerate config prior to this.

torch.cuda.is_initialized() returns False after each cell preceding the notebook_launcher call.

@RRaphaell

I'm facing the same issue, any updates on this?

@mingrenbuke

I'm facing the same issue, and I've encapsulated all the code within functions. The first line of code executed is notebook_launcher.

@ghadiaravi13

ghadiaravi13 commented Aug 20, 2023

Hi @muellerzr
I am able to reproduce the same error in a dummy case where I make no imports other than the accelerate library.

Cell:0

from accelerate import Accelerator

Cell:1

def dummy_func():
    accelerator = Accelerator()
    accelerator.print("Hello")

Cell:2

from accelerate import notebook_launcher
notebook_launcher(dummy_func,num_processes=2)

RuntimeError: CUDA has been initialized before the notebook_launcher could create a forked subprocess. This likely stems from an outside import causing issues once the notebook_launcher() is called. Please review your imports and test them when running the notebook_launcher() to identify which one is problematic.

Also, I tried looking at the GPU memory before and after executing Cell:0 using nvidia-smi.
Before:
GPU:0 - 0MB
GPU:1 - 0MB

After executing Cell:0
GPU:0 - 2MB
GPU:1 - 2MB

Could this mean that the import itself is leading to CUDA initialization? If so, how do we handle it?
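
One way I could check whether the import alone is the culprit (a minimal diagnostic sketch, not part of the repro above):

import torch
assert not torch.cuda.is_initialized()  # fresh kernel: CUDA must not be up yet

from accelerate import Accelerator      # the suspect import from Cell:0

# True here would mean the import itself initialized CUDA in the parent process,
# which is exactly what later breaks the forked workers in notebook_launcher.
print(torch.cuda.is_initialized())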

@muellerzr
Collaborator

Try installing from git via pip install git+https://github.com/huggingface/accelerate. I believe this is fixed on main

@Sewens

Sewens commented Aug 21, 2023

Try installing from git via pip install git+https://github.com/huggingface/accelerate. I believe this is fixed on main

Works for me. I uninstalled the old version, accelerate-0.21.0, and reinstalled the package directly from GitHub, getting version accelerate-0.22.0.dev0.
Then it just works!

For those whose default HTTPS access to GitHub is blocked, use the following command to install over SSH instead:

pip install git+ssh://git@github.com/huggingface/accelerate.git

@ghadiaravi13

Works for me too, thanks a lot @muellerzr for the prompt response!

@muellerzr
Collaborator

We'll have a release out this week with the non-breaking version, thanks @Sewens @ghadiaravi13!

@muellerzr muellerzr added solved The bug or feature request has been solved, but the issue is still opened and removed wip Work in progress labels Aug 21, 2023
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@infamix

infamix commented Sep 20, 2023

I am still getting the same error on Kaggle T4s. It doesn't matter whether I load the model outside the training function or inside; the outcome is the same.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@muellerzr
Collaborator

@infamix I need a copy of your code to help.

@csaroff

csaroff commented Nov 7, 2023

@muellerzr the following code is breaking for me (first cell of the notebook):

from accelerate import notebook_launcher

def test_nb_launcher():
    from fastai.test_utils import synth_learner
    import fastai.distributed # Updated

    learn = synth_learner()
    with learn.distrib_ctx(in_notebook=True): # Updated
        learn.fit(3)

notebook_launcher(test_nb_launcher, num_processes=2)

Running accelerate v0.24.1

import accelerate
accelerate.__version__ # '0.24.1'

I'm also running this on a single node w/ 2x T4 GPUs. I'm guessing all the T4 reports are a coincidence because everyone's just trying to run on the two cheapest cloud GPUs they can get. 🤷

Edit: I forgot to include the distrib_ctx in the example. Unfortunately, it still fails with the same error even after I updated it.


github-actions bot commented Dec 1, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@csaroff

csaroff commented Dec 2, 2023

Bump @muellerzr

@muellerzr
Collaborator

@csaroff this ran just fine for me:

import os
import torch
from accelerate import notebook_launcher
from fastai.test_utils import synth_learner
import fastai.distributed

os.environ["NCCL_P2P_DISABLE"]="1"
os.environ["NCCL_IB_DISABLE"]="1"

def test_nb_launcher():
    learn = synth_learner()
    with learn.distrib_ctx(in_notebook=True): # Updated
        learn.fit(3)

notebook_launcher(test_nb_launcher, num_processes=2)

Note: I'm using 2x 4090s, which is why the OS setting is needed (this will be automatic soon).


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jan 6, 2024