
Incorrect input shape when reading .png files #25

Closed
ruslanmustafin opened this issue Feb 19, 2024 · 1 comment · Fixed by #26


@ruslanmustafin

Running run_vision_chat.sh with a .png image fails with the following error:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 250, in main
    output = sampler(prompts, FLAGS.max_n_frames)[0]
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 228, in __call__
    batch = self.construct_input(prompts, max_n_frames)
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 123, in construct_input
    vision = self._read_process_vision(prompt['input_path'], max_n_frames)
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 102, in _read_process_vision
    enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 53, in encode
    return self._encode(pixel_values)
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 35, in fn
    return self.model.apply(
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 122, in encode
    hidden_states = self.encoder(pixel_values)
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 155, in __call__
    hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 429, in __call__
    kernel = self.param('kernel', self.kernel_init, kernel_shape,
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (3, 3, 3, 128) but got shape (3, 3, 4, 128) instead for parameter "kernel" in "/encoder/Conv_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

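This happens when the input image has more than 3 channels, e.g. an RGBA PNG that carries a transparency (alpha) channel. The first convolution of the VQGAN encoder (/encoder/Conv_0) has a kernel built for 3 input (RGB) channels, shape (3, 3, 3, 128), so a 4-channel input produces the (3, 3, 4, 128) shape mismatch reported above.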
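One way to avoid the mismatch is to normalize every image to 3 channels before it reaches the encoder. The sketch below is a hypothetical helper: `to_rgb` is not part of LWM, and it is not necessarily what the fix in #26 does. It assumes images arrive as uint8 NumPy arrays in (H, W) or (H, W, C) layout and composites any alpha channel over a white background. With Pillow, `Image.open(path).convert("RGB")` is a common one-line alternative, though it simply drops the alpha channel instead of compositing it.

```python
import numpy as np


def to_rgb(image: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) uint8 array from grayscale, RGB, or RGBA input.

    Hypothetical helper (not part of LWM). Assumes uint8 pixel values in
    [0, 255]; an alpha channel is composited over a white background.
    """
    if image.ndim == 2:
        # Grayscale (H, W): replicate the single channel three times.
        return np.repeat(image[..., None], 3, axis=-1).astype(np.uint8)

    channels = image.shape[-1]
    if channels == 1:
        # Grayscale with an explicit channel axis (H, W, 1).
        return np.repeat(image, 3, axis=-1).astype(np.uint8)
    if channels == 3:
        # Already RGB: what the encoder's first Conv expects.
        return image.astype(np.uint8)
    if channels == 4:
        # RGBA: blend RGB over white using alpha in [0, 1].
        rgb = image[..., :3].astype(np.float32)
        alpha = image[..., 3:4].astype(np.float32) / 255.0
        white = np.full_like(rgb, 255.0)
        out = rgb * alpha + white * (1.0 - alpha)
        return np.clip(np.round(out), 0, 255).astype(np.uint8)

    raise ValueError(f"unsupported channel count: {channels}")


if __name__ == "__main__":
    # A 2x2 fully opaque red RGBA image, like a PNG with transparency.
    rgba = np.zeros((2, 2, 4), dtype=np.uint8)
    rgba[..., 0] = 255  # red channel
    rgba[..., 3] = 255  # alpha: fully opaque
    print(rgba.shape, "->", to_rgb(rgba).shape)
```

Compositing over white (rather than slicing off channel 3) keeps fully transparent regions from turning into whatever RGB values happen to be stored under them, which are often black.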

@wilson1yan
Contributor

Thanks!
