
Getting an error when running on a single host with 8xA100 GPU #582

@xuzhao9

Description

I am trying to run the MNIST PyTorch example on a single host with 8xA100 GPUs.

The command to reproduce:

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=baselines/adamw/pytorch/submission.py \
    --tuning_search_space=baselines/adamw/tuning_search_space.json

Error log (truncated; the top of the traceback is missing):

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_utils.py", line 699, in reraise
    raise exception
torch._dynamo.exc.TorchRuntimeError: Failed running call_module fn(*(FakeTensor(..., device='cuda:0', size=(16, 1, 28, 28)),), **{}):
Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/git/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py", line 43, in forward
    return self.net(x)
           ^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1233, in __torch_dispatch__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_ops.py", line 516, in __call__
    return self._op(*args, **kwargs or {})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1405, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1615, in dispatch
    return decomposition_table[func](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 250, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 72, in inner
    r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 1314, in addmm
    out = alpha * torch.mm(mat1, mat2)
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1405, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1713, in dispatch
    with in_kernel_invocation_manager(self):
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
           ^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1027, in in_kernel_invocation_manager
    assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: False, True
from user code:
   File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner               0
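
The traceback runs through torch.nn.parallel.parallel_apply ("Caught AssertionError in replica 0 on device 0"), which is how nn.DataParallel dispatches per-GPU replicas, so it looks like the compiled model is wrapped in DataParallel when a single process drives all 8 GPUs. A minimal sketch of that combination (hypothetical repro, not code from the repository; the net is a stand-in for the MNIST model):

    import torch
    import torch.nn as nn

    # Stand-in for the MNIST net defined in
    # algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py.
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    ).cuda()

    # Single-process multi-GPU path: DataParallel replicates the module and
    # splits each batch across all visible GPUs via parallel_apply.
    model = nn.DataParallel(net)

    # Compiling the DataParallel wrapper makes TorchDynamo trace through
    # parallel_apply with fake tensors (the FakeTensor(..., size=(16, 1, 28, 28))
    # seen in the log above).
    model = torch.compile(model)

    x = torch.randn(16, 1, 28, 28, device="cuda")
    out = model(x)  # same code path as the failing forward above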
Disabling torch.compile and re-running:

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=baselines/adamw/pytorch/submission.py \
    --tuning_search_space=baselines/adamw/tuning_search_space.json \
    --torch_compile=False

Error log:

I1117 16:55:25.182675 139816804692992 logger_utils.py:76] Creating experiment directory at /home/xzhao9/experiments/my_first_experiment/mnist_pytorch.
I1117 16:55:25.198398 139816804692992 submission_runner.py:528] Using RNG seed 3088258823
I1117 16:55:25.199095 139816804692992 submission_runner.py:537] --- Tuning run 1/1 ---
I1117 16:55:25.199179 139816804692992 submission_runner.py:542] Creating tuning directory at /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1.
I1117 16:55:25.199327 139816804692992 logger_utils.py:92] Saving hparams to /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1/hparams.json.
I1117 16:55:25.199881 139816804692992 submission_runner.py:205] Initializing dataset.
I1117 16:55:25.199955 139816804692992 submission_runner.py:212] Initializing model.
I1117 16:55:28.326041 139816804692992 submission_runner.py:245] Initializing optimizer.
I1117 16:55:28.902134 139816804692992 submission_runner.py:252] Initializing metrics bundle.
I1117 16:55:28.902250 139816804692992 submission_runner.py:270] Initializing checkpoint and logger.
I1117 16:55:28.902581 139816804692992 submission_runner.py:290] Saving meta data to /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1/meta_data_0.json.
I1117 16:55:29.364587 139816804692992 submission_runner.py:294] Saving flags to /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1/flags_0.json.
I1117 16:55:29.401283 139816804692992 submission_runner.py:304] Starting training loop.
I1117 16:55:29.408903 139816804692992 dataset_info.py:578] Load dataset info from /home/xzhao9/data/mnist/3.0.1
I1117 16:55:29.414449 139816804692992 dataset_info.py:669] Fields info.[citation, splits, supervised_keys, module_name] from disk and from code do not match. Keeping the one from code.
I1117 16:55:29.414723 139816804692992 dataset_builder.py:528] Reusing dataset mnist (/home/xzhao9/data/mnist/3.0.1)
I1117 16:55:29.490485 139816804692992 logging_logger.py:49] Constructing tf.data.Dataset mnist for split train[:50000], from /home/xzhao9/data/mnist/3.0.1
NCCL version 2.19.3+cuda11.8
Traceback (most recent call last):
  File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 673, in <module>
    app.run(main)
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 641, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 554, in score_submission_on_workload
    timing, metrics = train_once(workload, global_batch_size,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 326, in train_once
    optimizer_state, model_params, model_state = update_params(
                                                 ^^^^^^^^^^^^^^
  File "/data/users/xzhao9/git/algorithmic-efficiency/baselines/adamw/pytorch/submission.py", line 86, in update_params
    loss_dict = workload.loss_fn(
                ^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/git/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py", line 196, in loss_fn
    per_example_losses *= mask_batch
RuntimeError: The size of tensor a (16) must match the size of tensor b (2) at non-singleton dimension 0
[2023-11-17 16:56:16,168] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-11-17 16:56:16,168] torch._dynamo.utils: [INFO] Function    Runtimes (s)
[2023-11-17 16:56:16,168] torch._dynamo.utils: [INFO] ----------  --------------
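
In this second failure, 16 is the batch size seen earlier (the FakeTensor had size (16, 1, 28, 28)) and 2 = 16 / 8 matches a per-GPU shard on this 8-GPU host, so the two operands of per_example_losses *= mask_batch seem to disagree about whether the batch has already been split by DataParallel. A toy illustration of the broadcast failure itself (shapes taken from the error message, not repo code):

    import torch

    per_example_losses = torch.randn(16)  # size from "tensor a (16)"
    mask_batch = torch.ones(2)            # size from "tensor b (2)"; 16 / 8 GPUs = 2

    # An elementwise multiply needs matching (or broadcastable) sizes at dim 0,
    # so this raises the same RuntimeError as loss_fn above.
    per_example_losses *= mask_batch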

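If the repository supports a torchrun launch for multi-GPU PyTorch (one process per GPU, so the model is wrapped in DistributedDataParallel rather than DataParallel), a launch along these lines would avoid parallel_apply entirely. The torchrun flags below are standard, but whether submission_runner.py is meant to be started this way is an assumption:

    torchrun --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
        --framework=pytorch \
        --workload=mnist \
        --experiment_dir=$HOME/experiments \
        --experiment_name=my_first_experiment \
        --submission_path=baselines/adamw/pytorch/submission.py \
        --tuning_search_space=baselines/adamw/tuning_search_space.json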