[bug] when dim=1 #228

@Gabibing

issue

When using ResidualVQ with feature dimension dim=1 and passing a 2D mask of shape [B, T], the mask is not handled correctly and the forward pass fails with the einx shape error in the log below.
When mask=None is used as a workaround, the output's batch dimension collapses to 1.

vq = ResidualVQ(dim=1, num_quantizers=1, codebook_size=4)

error log

vector_quantize_pytorch.py", line 1349, in forward
[rank0]:     quantize = einx.where(
                        ^^^^^^^^^^^
einx.expr.stage3.SolveValueException: Failed to solve values of expressions.
Found contradictory values {5936, 1} for equivalent expressions {'d', '5936', '1'}
Input:
    'b n = 8 742'
    'b n d = 8 742 5936'
    'b n d = 8 742 1'
    'b n d = None'
