Training RuntimeError: Too many open files. Communication with the workers is no longer possible #3953

Open · veer5551 opened this issue Feb 14, 2022 · 5 comments

Instructions to Reproduce the Issue (multi-GPU training with a validation hook and a best-checkpointer hook):

  1. Added the LossEvalHook and BestCheckpointer hooks to the training.
     Ref: https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b

Code:
1.a lossEvalHook.py

from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
import torch
import time
import datetime
import logging
import numpy as np
from utils import *


class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        # Copying inference_on_dataset from evaluator.py
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0
            start_compute_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation  done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
            loss_batch = self._get_loss(inputs)
            losses.append(loss_batch)
        mean_loss = np.mean(losses)
        self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        # How loss is calculated on train_loop
        metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        total_losses_reduced = sum(loss for loss in metrics_dict.values())
        return total_losses_reduced

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        self.trainer.storage.put_scalars(timetest=12)
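
One variant worth trying (untested): build the validation loader only for the duration of each evaluation, so its worker processes and open file descriptors can be reclaimed in between. Here data_loader_factory is a hypothetical zero-argument callable, e.g. a lambda wrapping build_detection_test_loader, rather than an already-built loader:

from lossEvalHook import LossEvalHook


class LazyLossEvalHook(LossEvalHook):
    """Sketch: build the validation loader per evaluation instead of keeping it alive."""

    def __init__(self, eval_period, model, data_loader_factory):
        # data_loader_factory: hypothetical zero-argument callable returning a fresh
        # loader, e.g.
        #   lambda: build_detection_test_loader(cfg, cfg.DATASETS.TEST[0], DatasetMapper(cfg, True))
        super().__init__(eval_period, model, data_loader=None)
        self._data_loader_factory = data_loader_factory

    def _do_loss_eval(self):
        self._data_loader = self._data_loader_factory()  # workers are spawned here
        try:
            return super()._do_loss_eval()
        finally:
            # Drop the reference so the loader and its worker processes can be
            # garbage-collected between evaluations (whether the open file
            # descriptors are released promptly depends on DataLoader/GC behavior).
            self._data_loader = None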

1.b myTrainer.py

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1" # Select the GPU number for training. Only for Single GPU training
from detectron2.utils.logger import setup_logger

setup_logger()
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetMapper, build_detection_test_loader, DatasetCatalog
from detectron2.engine import DefaultPredictor, DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
    inference_context,
    inference_on_dataset
)
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.structures import BoxMode
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.visualizer import Visualizer
import detectron2.utils.comm as comm
from detectron2.utils.logger import log_every_n_seconds
from detectron2.engine.hooks import HookBase, BestCheckpointer
from detectron2.config import CfgNode
from lossEvalHook import LossEvalHook   # ref:
import pickle
import json

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from utils import *


# ============================== FUNCTIONS START ====================================================
def oi_data_dicts(img_dir):
    json_file = os.path.join(img_dir, "data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns):
        v['file_name'] = os.path.join(img_dir, v['file_name'] + ".bmp")
        v['height'] = int(v['height'])
        v['width'] = int(v['width'])

        new_obj_list = []
        for annotation in v['annotations']:
            obj = {
                'bbox': annotation['bbox'],
                'bbox_mode': BoxMode.XYXY_ABS,
                'segmentation': annotation['segmentation'],
                'category_id': annotation['category_id']

            }
            new_obj_list.append(obj)

        v['annotations'] = new_obj_list
        dataset_dicts.append(v)
    return dataset_dicts


class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_hooks(self):
        print("68 - BUILD HOOK")
        hooks = super().build_hooks()

        hooks.insert(-1, LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                DatasetMapper(self.cfg, True)
            )
        ))

        print("69 - BUILD BEST CHECKPOINTERHOOK")
        if self.cfg.SOLVER.BEST_CHECKPOINTER and comm.is_main_process():
            hooks.append(BestCheckpointer(
                self.cfg.TEST.EVAL_PERIOD,
                self.checkpointer,
                self.cfg.SOLVER.BEST_CHECKPOINTER.METRIC,
                mode=self.cfg.SOLVER.BEST_CHECKPOINTER.MODE
                ))

        # swap the order of PeriodicWriter and ValidationLoss;
        # the code hangs with num GPUs > 1 if this line is removed
        hooks = hooks[:-2] + hooks[-2:][::-1]
        return hooks


def get_config(
        config_file_path,
        checkpoint_url,
        train_dataset_name,
        test_dataset_name,
        num_classes,
        device,
        output_dir):
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_file_path))
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(checkpoint_url)  # Let training initialize from model zoo

    cfg.DATASETS.TRAIN = (train_dataset_name,)
    cfg.DATASETS.TEST = (test_dataset_name,)
    cfg.INPUT.MIN_SIZE_TRAIN = (1080)
    cfg.INPUT.MAX_SIZE_TRAIN = (1920)

    cfg.INPUT.MIN_SIZE_TEST = (1080)
    cfg.INPUT.MAX_SIZE_TEST = (1920)

    cfg.DATALOADER.NUM_WORKERS = 8

    cfg.SOLVER.IMS_PER_BATCH = 6
    # one_epoch = int(total_files / self._cfg.SOLVER.IMS_PER_BATCH)
    # self._cfg.SOLVER.MAX_ITER = int(one_epoch * epochs)
    cfg.SOLVER.BASE_LR = 0.0001  # pick a good LR
    cfg.SOLVER.MAX_ITER = 61000  
    cfg.SOLVER.STEPS = []  # do not decay learning rate
    cfg.TEST.EVAL_PERIOD = 500  # run validation every 500 iterations

    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512  # faster, and good enough for this toy dataset (default: 512)

    # only has one class (balloon).
    # (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
    # NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrectly use num_classes+1 here.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

    cfg.MODEL.DEVICE = device
    cfg.OUTPUT_DIR = output_dir

    # Best checkpointer hook
    cfg.SOLVER.BEST_CHECKPOINTER = CfgNode({"ENABLED": False})
    cfg.SOLVER.BEST_CHECKPOINTER.METRIC = "segm/AP50"
    cfg.SOLVER.BEST_CHECKPOINTER.MODE = "max"

    return cfg


# ============================== FUNCTIONS END====================================================


def main():
    # Step 1: ------------------------------- Setup the Variables ---------------------------------------------
    config_file_path = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'
    checkpoint_url = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'
    output_dir = './model_zoo/mask_rcnn_X_101_32x8d_FPN_3x__1e4_512_2gpu'
    num_classes = 1
    device = "cuda"  # else "cpu"

    training_data_meta_path = "Temp_datastaging_train_v3"

    # Setup the Train Data
    train_dataset_name = "v1_training"
    train_dataset_path = "training"
    train_annotations_path = os.path.join(train_dataset_path, "data.json")

    # Setup the Test Data
    test_dataset_name = "v1_validation"
    test_dataset_path = "validation"
    test_annotations_path = os.path.join(test_dataset_path, "data.json")

    cfg_save_path = os.path.join(output_dir, 'instance_seg.pickle')

    # Step 2: ----------------------------- Register the Dataset with detectron2 ----------------------------------
    print("2. Registering the Dataset ...")
    d = "Training"
    DatasetCatalog.register(train_dataset_name,
                            lambda d=d: oi_data_dicts(os.path.join(training_data_meta_path, train_dataset_path)))
    MetadataCatalog.get(train_dataset_name).set(thing_classes=["Vehicle"])

    d = "validation"
    DatasetCatalog.register(test_dataset_name,
                            lambda d=d: oi_data_dicts(os.path.join(training_data_meta_path, test_dataset_path)))
    MetadataCatalog.get(test_dataset_name).set(thing_classes=["Vehicle"])

    oi_metadata = MetadataCatalog.get(train_dataset_name + "_" + train_dataset_path)
    print("Dataset Registered..")

    # Step 3: ------------------------------- Get the config file ---------------------------------------------
    print("3. Creating the config file...")
    cfg = get_config(
        config_file_path=config_file_path,
        checkpoint_url=checkpoint_url,
        train_dataset_name=train_dataset_name,
        test_dataset_name=test_dataset_name,
        num_classes=num_classes,
        device=device,
        output_dir=output_dir
    )
    print("Config File Created with following parameters \n", cfg)

    # Step 4: ------------------------------- Make the Output directory ---------------------------------------------
    print("4. Creating the Output Directories")
    if os.path.exists(cfg.OUTPUT_DIR):
        print("Model Output Directory Already exists. Please give a different output directory")
        # print("Exiting the program...")
        # return
    else:
        os.makedirs(cfg.OUTPUT_DIR,exist_ok=True)

    # Step 5: -------------------------- Dump the config file as picke file in output directory --------------------
    print(f"5. Dumping the config file as Pickle file in the Output Directory {cfg.OUTPUT_DIR}/{cfg_save_path}")
    with open(cfg_save_path, 'wb') as f:
        pickle.dump(cfg, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Step 6: -------------------------------- Start the Training  -------------------------------------------------
    print("6. Initializing the Training ...")
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=False)
    # trainer.train()

    return trainer.train()


if __name__ == '__main__':
    # main()
    launch(main, num_gpus_per_machine=2, dist_url="auto")

  2. What exact command you run: python mytrainer.py

  3. Full logs or other relevant observations:

ERROR [02/12 06:57:59 d2.engine.train_loop]: Exception during training:
Traceback (most recent call last):
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 150, in train
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 180, in after_step
  File "/data/detectron2/detectron2_training_scripts/lossEvalHook.py", line 70, in after_step
  File "/data/detectron2/detectron2_training_scripts/lossEvalHook.py", line 28, in _do_loss_eval
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1023, in _try_get_data
RuntimeError: Too many open files. Communication with the workers is no longer possible. Please increase the limit using `ulimit -n` in the shell or change the sharing strategy by calling `torch.multiprocessing.set_sharing_strategy('file_system')` at the beginning of your code
[02/12 06:57:59 d2.engine.hooks]: Overall training speed: 27497 iterations in 13:14:20 (1.7333 s / it)
[02/12 06:57:59 d2.engine.hooks]: Total training time: 1 day, 9:51:47 (20:37:27 on hooks)
Traceback (most recent call last):
  File "multi-gpu_training_detectron2.py", line 261, in <module>
    launch(main, num_gpus_per_machine=2, dist_url="auto")
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/launch.py", line 79, in launch
    daemon=False,
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/launch.py", line 126, in _distributed_worker
  File "/data/detectron2/detectron2_training_scripts/multi-gpu_training_detectron2.py", line 253, in main
    return trainer.train()
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/defaults.py", line 484, in train
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 150, in train
  File "/data/detectron2/aaa/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 180, in after_step
  File "/data/detectron2/detectron2_training_scripts/lossEvalHook.py", line 70, in after_step
  File "/data/detectron2/detectron2_training_scripts/lossEvalHook.py", line 28, in _do_loss_eval
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
  File "/data/detectron2/aaa/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1023, in _try_get_data
RuntimeError: Too many open files. Communication with the workers is no longer possible. Please increase the limit using `ulimit -n` in the shell or change the sharing strategy by calling `torch.multiprocessing.set_sharing_strategy('file_system')` at the beginning of your code

Expected behavior: the training runs to completion.

Observations:

  • The ETA for this training was shown as around 1 day 5 hours (excluding validation time).
  • Each execution of the validation hook took ~30 minutes (~8k images).
  • After almost a day of training, the trainer raised the error above.
  • After some digging, I observed the following:
  1. The data loader workers (cfg.DATALOADER.NUM_WORKERS) are per GPU, i.e. setting it to 4 with 2 GPUs gives 8 worker processes (4 per GPU).
  2. These workers are initialized separately (newly) for the validation hook.
  3. After the validation hook finishes, I suspect these workers aren't exiting gracefully and are still present (locked?) among the processes.
  4. Searching for this error, the common suggestion online is to add these lines to the code:

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

I added them, but with no luck!
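
For reference, the other remedy the error message suggests (`ulimit -n`) can also be applied programmatically at the top of the training script by raising the open-file soft limit. A minimal sketch; 65536 is an arbitrary example target, and the soft limit cannot be raised above the hard limit without elevated privileges:

import resource

# Raise the per-process open-file soft limit (programmatic equivalent of `ulimit -n`).
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
new_soft = min(65536, hard)  # arbitrary example target, capped at the hard limit
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
print(f"RLIMIT_NOFILE soft limit: {soft} -> {new_soft} (hard limit: {hard})")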

Below are snapshots of the processes during training:

[screenshots of the running worker processes]

Environment:

Output of the detectron2 environment collection script:

----------------------  ----------------------------------------------------------------------------
sys.platform            linux
Python                  3.6.9 (default, Jan 26 2021, 15:33:00) [GCC 8.4.0]
numpy                   1.19.5
detectron2              0.6 @/data/detectron2/aaa/lib/python3.6/site-packages/detectron2
Compiler                GCC 7.3
CUDA compiler           CUDA 11.1
detectron2 arch flags   3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6
DETECTRON2_ENV_MODULE   <not set>
PyTorch                 1.9.0+cu111 @/data/detectron2/aaa/lib/python3.6/site-packages/torch
PyTorch debug build     False
GPU available           Yes
GPU 0,1,2,3,4,5,6,7     A100-PCIE-40GB (arch=8.0)
Driver version          460.91.03
CUDA_HOME               /usr/local/cuda
Pillow                  8.4.0
torchvision             0.10.0+cu111 @/data/detectron2/aaa/lib/python3.6/site-packages/torchvision
torchvision arch flags  3.5, 5.0, 6.0, 7.0, 7.5, 8.0, 8.6
fvcore                  0.1.5.post20220119
iopath                  0.1.9
cv2                     4.5.5
----------------------  ----------------------------------------------------------------------------
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  - CuDNN 8.0.5
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

Testing NCCL connectivity ... this should not hang.
NCCL succeeded.

Any thoughts on the above behavior and how we can handle it?

Thanks a lot for the Amazing Work! :)

@veer5551 (Author):

Hi @ppwwyyxx, any chance you could look into this? It is blocking me from training models.

Thanks!

@zensenlon:

Hello, I'm just curious whether the LossEvalHook works on multiple GPUs. In my runs it hangs after calculating the validation loss.

@veer5551 (Author):

I used this fix, @zensenlon.
It works: the best checkpoint is saved after the evaluation and training resumes, but the threads initialized for the evaluation and best-checkpointer hooks are not released and stay locked.

@ShreyasSkandanS:

> I used this fix, @zensenlon. It works: the best checkpoint is saved after the evaluation and training resumes, but the threads initialized for the evaluation and best-checkpointer hooks are not released and stay locked.

I used the fix mentioned in that post, but I can no longer see my validation loss in TensorBoard. Does your implementation log the validation loss correctly?

@geotsl commented Apr 25, 2023:

@ShreyasSkandanS, did you manage to make it work in the end? I mean the issue of the validation loss vanishing from TensorBoard.
