Description
I am trying to run the MNIST PyTorch example on a single host with 8x A100 GPUs.
The command to reproduce:
python3 submission_runner.py --framework=pytorch --workload=mnist --experiment_dir=$HOME/experiments --experiment_name=my_first_experiment --submission_path=baselines/adamw/pytorch/submission.py --tuning_search_space=baselines/adamw/tuning_search_space.json
Error log (truncated at the top):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
output.reraise()
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_utils.py", line 699, in reraise
raise exception
torch._dynamo.exc.TorchRuntimeError: Failed running call_module fn(*(FakeTensor(..., device='cuda:0', size=(16, 1, 28, 28)),), **{}):
Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/xzhao9/git/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py", line 43, in forward
return self.net(x)
^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1233, in __torch_dispatch__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_ops.py", line 516, in __call__
return self._op(*args, **kwargs or {})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1405, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1615, in dispatch
return decomposition_table[func](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 250, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 72, in inner
r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 1314, in addmm
out = alpha * torch.mm(mat1, mat2)
^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1405, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1713, in dispatch
with in_kernel_invocation_manager(self):
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/contextlib.py", line 137, in __enter__
return next(self.gen)
^^^^^^^^^^^^^^
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1027, in in_kernel_invocation_manager
assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: False, True
from user code:
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] Function Runtimes (s)
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] ------------------------------- --------------
[2023-11-17 16:50:50,833] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner 0
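For reference, the in-code fallback the error message points to would look like this (a sketch; the two config lines are copied from the error message above, but where to place them, e.g. before the training loop in submission_runner.py, is my assumption):

# Eager fallback suggested by the TorchDynamo error above.
# NOTE: placement is an assumption; the two lines themselves are
# quoted from the error message.
import torch._dynamo
torch._dynamo.config.suppress_errors = True  # fall back to eager on compile errors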
Instead of suppressing errors in code, I retried with torch.compile disabled via the flag:
python3 submission_runner.py --framework=pytorch --workload=mnist --experiment_dir=$HOME/experiments --experiment_name=my_first_experiment --submission_path=baselines/adamw/pytorch/submission.py --tuning_search_space=baselines/adamw/tuning_search_space.json --torch_compile=False
Error log:
I1117 16:55:25.182675 139816804692992 logger_utils.py:76] Creating experiment directory at /home/xzhao9/experiments/my_first_experiment/mnist_pytorch.
I1117 16:55:25.198398 139816804692992 submission_runner.py:528] Using RNG seed 3088258823
I1117 16:55:25.199095 139816804692992 submission_runner.py:537] --- Tuning run 1/1 ---
I1117 16:55:25.199179 139816804692992 submission_runner.py:542] Creating tuning directory at /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1.
I1117 16:55:25.199327 139816804692992 logger_utils.py:92] Saving hparams to /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1/hparams.json.
I1117 16:55:25.199881 139816804692992 submission_runner.py:205] Initializing dataset.
I1117 16:55:25.199955 139816804692992 submission_runner.py:212] Initializing model.
I1117 16:55:28.326041 139816804692992 submission_runner.py:245] Initializing optimizer.
I1117 16:55:28.902134 139816804692992 submission_runner.py:252] Initializing metrics bundle.
I1117 16:55:28.902250 139816804692992 submission_runner.py:270] Initializing checkpoint and logger.
I1117 16:55:28.902581 139816804692992 submission_runner.py:290] Saving meta data to /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1/meta_data_0.json.
I1117 16:55:29.364587 139816804692992 submission_runner.py:294] Saving flags to /home/xzhao9/experiments/my_first_experiment/mnist_pytorch/trial_1/flags_0.json.
I1117 16:55:29.401283 139816804692992 submission_runner.py:304] Starting training loop.
I1117 16:55:29.408903 139816804692992 dataset_info.py:578] Load dataset info from /home/xzhao9/data/mnist/3.0.1
I1117 16:55:29.414449 139816804692992 dataset_info.py:669] Fields info.[citation, splits, supervised_keys, module_name] from disk and from code do not match. Keeping the one from code.
I1117 16:55:29.414723 139816804692992 dataset_builder.py:528] Reusing dataset mnist (/home/xzhao9/data/mnist/3.0.1)
I1117 16:55:29.490485 139816804692992 logging_logger.py:49] Constructing tf.data.Dataset mnist for split train[:50000], from /home/xzhao9/data/mnist/3.0.1
NCCL version 2.19.3+cuda11.8
Traceback (most recent call last):
File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 673, in <module>
app.run(main)
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/xzhao9/.conda/envs/jax/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 641, in main
score = score_submission_on_workload(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 554, in score_submission_on_workload
timing, metrics = train_once(workload, global_batch_size,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/xzhao9/git/algorithmic-efficiency/submission_runner.py", line 326, in train_once
optimizer_state, model_params, model_state = update_params(
^^^^^^^^^^^^^^
File "/data/users/xzhao9/git/algorithmic-efficiency/baselines/adamw/pytorch/submission.py", line 86, in update_params
loss_dict = workload.loss_fn(
^^^^^^^^^^^^^^^^^
File "/data/users/xzhao9/git/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py", line 196, in loss_fn
per_example_losses *= mask_batch
RuntimeError: The size of tensor a (16) must match the size of tensor b (2) at non-singleton dimension 0
[2023-11-17 16:56:16,168] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-11-17 16:56:16,168] torch._dynamo.utils: [INFO] Function Runtimes (s)
[2023-11-17 16:56:16,168] torch._dynamo.utils: [INFO] ---------- --------------
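If it helps, the broadcasting failure can be reproduced in isolation with the sizes from the log. The cause attribution is my guess: 16 / 8 GPUs = 2 suggests one tensor keeps the global batch dimension while the other arrives shard-sized from the replicas.

import torch

# Sizes taken from the RuntimeError above; the per-device vs. global
# batch interpretation is a guess, not confirmed from the code.
per_example_losses = torch.randn(16)  # full global batch
mask_batch = torch.ones(2)            # shard-sized (16 examples / 8 GPUs)
per_example_losses *= mask_batch      # RuntimeError: size of tensor a (16) must match tensor b (2) at dim 0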