plot_tracks_v2 has bug when plotting with trackgroup argument. #92

Open
chandlj opened this issue Apr 24, 2024 · 2 comments

Comments

@chandlj

chandlj commented Apr 24, 2024

I am running this notebook for RoboTAP clustering. After computing the clusters, I am running the following cell:

separation_visibility_trim = clustered['separation_visibility']
separation_tracks_trim = clustered['separation_tracks']

pointtrack_video = viz_utils.plot_tracks_v2(
    (demo_videos[demo_episode_ids[0]]).astype(np.uint8),
    separation_tracks_trim[demo_episode_ids[0]],
    1.0-separation_visibility_trim[demo_episode_ids[0]],
    trackgroup=clustered['classes']
)
media.show_video(pointtrack_video, fps=20)

However, the plot only shows about 10 points no matter how many points I track, and there are really no clusters to be found. I found that if I comment out trackgroup, then the plotting code works correctly and I can see the full range of points (although not colored with cluster ID). I can also verify that clusters are correctly computed by plotting individual frames like so:

separation_visibility_trim = clustered['separation_visibility']
separation_tracks_trim = clustered['separation_tracks']

frame = 35
plt.scatter(
  separation_tracks_trim["dummy_id"][:, frame, 0],
  separation_tracks_trim["dummy_id"][:, frame, 1],
  c=clustered["classes"],
  cmap="viridis",
)
plt.imshow(video[frame])

The code only misbehaves when trackgroup is specified. Any ideas on how to fix this?

@cdoersch
Collaborator

Now that tapir_clustering.py is fixed, I've run the colab at head and verified that the code will plot more than 20 tracks. Your snippets above look correct to me--I don't see why it wouldn't plot the full set the way that the colab does. Maybe set a breakpoint at https://github.com/google-deepmind/tapnet/blob/main/utils/viz_utils.py#L193 and check what's being passed to plt.scatter?
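
Before stepping into the plotting code, it may also help to confirm that the group labels line up with the tracks. The following is only a sketch: summarize_trackgroup is a hypothetical helper, not part of viz_utils, and it assumes tracks are laid out as [num_points, num_frames, 2] (as the scatter snippet above suggests) with one label per point. The data here is synthetic.

```python
import numpy as np


def summarize_trackgroup(tracks, trackgroup):
    """Hypothetical helper: check that group labels line up with the tracks."""
    tracks = np.asarray(tracks)
    trackgroup = np.asarray(trackgroup)
    # Assumption: tracks is [num_points, num_frames, 2].
    num_points = tracks.shape[0]
    if trackgroup.shape != (num_points,):
        raise ValueError(
            f"trackgroup has shape {trackgroup.shape}, expected ({num_points},)"
        )
    # Count how many tracks fall into each group label.
    labels, counts = np.unique(trackgroup, return_counts=True)
    return {int(label): int(count) for label, count in zip(labels, counts)}


# Synthetic stand-in data, not taken from the notebook.
rng = np.random.default_rng(0)
demo_tracks = rng.uniform(0, 256, size=(50, 10, 2))
demo_groups = np.arange(50) % 5
group_sizes = summarize_trackgroup(demo_tracks, demo_groups)
print(group_sizes)
```

In the notebook, the equivalent call would pass separation_tracks_trim[demo_episode_ids[0]] and clustered['classes']. If the per-group counts sum to far fewer points than expected, or the shape check fails, the problem is likely upstream of plt.scatter.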

@chandlj
Author

chandlj commented Apr 29, 2024

@cdoersch The most recent code that was pushed for tapir_clustering has a bug and did not work for me in the notebook. Looking at the commit here, it looks like changing len to np.prod on line 574 is causing problems. I noticed that jax.tree_map(lambda x: np.prod(x.shape), query_features) actually returns a size of 1 for the resolutions array, not 0.
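
One possible way to end up with a 1 here is if the resolutions entry is a 0-d array or has an empty shape. That is an assumption, since the thread does not say what shape it actually has. The sketch below uses stand-in arrays (not the real query_features) to show how np.prod over a shape differs from len in that case:

```python
import numpy as np

# Stand-in arrays; these are NOT the real query_features leaves.
leaves = {
    "scalar_like": np.zeros(()),   # 0-d array, shape ()
    "empty_1d": np.zeros((0,)),    # empty array, shape (0,)
    "filled": np.zeros((3, 4)),    # ordinary array, shape (3, 4)
}

# np.prod over an empty shape tuple is the empty product, which is 1.
sizes = {name: int(np.prod(arr.shape)) for name, arr in leaves.items()}
print(sizes)


def safe_len(arr):
    """len() reports the leading dimension and raises on a 0-d array."""
    try:
        return len(arr)
    except TypeError:
        return None


lengths = {name: safe_len(arr) for name, arr in leaves.items()}
print(lengths)
```

So a check that treats "size 0" as "empty" would behave differently under np.prod than under len whenever a leaf has shape ().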
