
Optimize weighted embedding extraction with pyannote 3.1 #214

Open · Tracked by #212
juanmc2005 opened this issue Nov 16, 2023 · 9 comments
Labels: feature (New feature or request)

Comments
juanmc2005 (Owner) commented Nov 16, 2023

With pyannote 3.1, we could run only one forward pass of the audio instead of num_speakers passes when extracting embeddings with weights. This is probably at least one of the reasons why the PyTorch version of the wespeaker embedding model is that much slower.

This optimization would also reduce the latency of pyannote/embedding, so both entries would need to be re-computed in the README table.

Important: we should verify that this method is also compatible with masking (e.g. in speechbrain embeddings).
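
For reference, here's a minimal sketch of the single-pass idea (illustrative only, not pyannote's actual implementation): run the frame-level trunk once, then reduce the per-speaker work to a weighted pooling over frames. The function name and shapes below are assumptions.

import torch

def pool_weighted_embeddings(frame_features: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # frame_features: (num_frames, dim), produced by a single forward pass of the trunk
    # weights: (num_speakers, num_frames), e.g. per-speaker segmentation scores
    norm = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)  # normalize per speaker
    # weighted mean over frames for all speakers at once -> (num_speakers, dim)
    return norm @ frame_features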

@juanmc2005 juanmc2005 changed the title Optimize weighted embedding extraction to do only 1 forward pass of the audio instead of num_speakers Optimize weighted embedding extraction with pyannote 3.1 Nov 16, 2023
@juanmc2005 juanmc2005 added this to the Version 0.10 milestone Nov 16, 2023
@juanmc2005 juanmc2005 added the feature New feature or request label Nov 16, 2023
@juanmc2005 (Owner, Author)

After a quick run of the diarization pipeline with 5s latency, it seems that pyannote 3.1 is hurting performance: DER=27.3 -> DER=29.1 with the exact same models and hyper-parameters.

This is most certainly due to the new API for embedding models compatible with weighted stats pooling, so contrary to my initial idea of leaving this for v0.10, I now think this is more a compatibility issue than a latency optimization.

@juanmc2005 (Owner, Author)

cc @hbredin, in case you can think of any other changes that could have caused this

@juanmc2005 juanmc2005 modified the milestones: Version 0.10, Version 0.9 Nov 16, 2023
@juanmc2005 juanmc2005 added bug Something isn't working and removed feature New feature or request labels Nov 16, 2023
hbredin (Collaborator) commented Nov 17, 2023

> cc @hbredin, in case you can think of any other changes that could have caused this

I can't think of any. I did not notice any change in performance for offline speaker diarization when switching from hbredin/wespeaker-voxceleb-resnet34-LM (ONNX) to pyannote/wespeaker-voxceleb-resnet34-LM (PyTorch).

@juanmc2005 juanmc2005 added feature New feature or request and removed bug Something isn't working labels Nov 18, 2023
@juanmc2005 (Owner, Author)

@hbredin I narrowed it down to a change in the interpolation method for the weights:

- weights = F.interpolate(
-     weights, size=num_frames, mode="linear", align_corners=False
- )
+ weights = F.interpolate(weights, size=num_frames, mode="nearest")

Any particular reason for this change? I guess interpolating before calling the model should do the trick, but I'm curious.
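
For context, a rough sketch of that workaround, assuming I know the number of frames the model produces right before the stats pooling (resample_weights is a hypothetical helper, not part of the pyannote API):

import torch
import torch.nn.functional as F

def resample_weights(weights: torch.Tensor, num_frames: int) -> torch.Tensor:
    # weights: (batch, num_speakers, num_src_frames)
    # Resample the weights to the model's internal frame resolution with "linear" mode
    # ourselves, so the model's internal "nearest" interpolation becomes a no-op.
    return F.interpolate(weights, size=num_frames, mode="linear", align_corners=False)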

hbredin (Collaborator) commented Nov 18, 2023

MPS support: "linear" interpolation is not yet supported by the MPS backend.

@juanmc2005 (Owner, Author)

@hbredin MPS is needed for training though, right? Is there any way we can make it apply a different strategy? Because to interpolate before calling the model, I actually need to know the number of frames that the model will produce internally right before the stats pooling.

I suggest something like:

from pyannote.audio import Model

model = Model.from_pretrained("pyannote/embedding")
model.set_interpolation_method("linear")  # proposed setter, does not exist yet
embeddings = model(waveform, weights)

Inside the forward method of the model:

outputs = self.stats_pool(outputs, weights=weights, method=self.interpolation_method)

And to call the interpolation:

def forward(..., interpolation_method: str = "nearest"):
    ...
    weights = F.interpolate(weights, size=num_frames, mode=interpolation_method)
    ...

Note that the interpolation method would need to be changed through a setter, because StatsPool is (as it should be) created at instantiation time.
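
To make this concrete, here is a hypothetical sketch of the setter pattern (not pyannote's actual StatsPool): the pooling layer owns an interpolation_method attribute, and the wrapping model's set_interpolation_method would simply assign stats_pool.interpolation_method.

import torch
import torch.nn as nn
import torch.nn.functional as F

class StatsPool(nn.Module):
    # Hypothetical variant with a configurable interpolation method.
    def __init__(self, interpolation_method: str = "nearest"):
        super().__init__()
        self.interpolation_method = interpolation_method  # changed later via the model's setter

    def forward(self, features: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # features: (batch, num_frames, dim); weights: (batch, 1, num_src_frames)
        num_frames = features.shape[1]
        align = False if self.interpolation_method == "linear" else None
        weights = F.interpolate(
            weights, size=num_frames, mode=self.interpolation_method, align_corners=align
        )
        w = weights.transpose(1, 2)  # (batch, num_frames, 1)
        w = w / (w.sum(dim=1, keepdim=True) + 1e-8)
        mean = (w * features).sum(dim=1)
        std = ((w * (features - mean.unsqueeze(1)) ** 2).sum(dim=1)).sqrt()
        return torch.cat([mean, std], dim=-1)  # weighted mean + std, as in stats pooling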

Also, the same thing could be implemented for the masks in other embedding models (e.g. speechbrain and nemo). Right now I can't take advantage of this optimization because I would need to keep track of which models are compatible with it, and that rule may very easily change in the future, leading to compatibility hell.

I'm willing to open PRs for both features.

juanmc2005 (Owner, Author) commented Nov 18, 2023

Since everything is working fine except for that performance loss on AMI, I'd prefer not to wait for this fix to release v0.9. I'll add a note to the reproducibility section of the README indicating that pyannote<3.1 should be used to reproduce the exact same results, and link to this issue.

In any case, re-tuning the hyper-parameters for this new interpolation method will probably give similar results.

shenberg commented Aug 19, 2024

Hi, I'm bumping this issue to note that, as far as I can tell, the missing functionality has been implemented in the MPS backend since PyTorch 2.3.0 (the specific commit).

On my MacBook Pro M1, using PyTorch 2.4.0:

>>> import torch
>>> import torch.nn.functional as F
>>> a = torch.arange(10, device='mps')[None, None].float()  # size 1x1x10
>>> F.interpolate(a, size=20, mode='linear', align_corners=False)
tensor([[[0.0000, 0.2500, 0.7500, 1.2500, 1.7500, 2.2500, 2.7500, 3.2500,
          3.7500, 4.2500, 4.7500, 5.2500, 5.7500, 6.2500, 6.7500, 7.2500,
          7.7500, 8.2500, 8.7500, 9.0000]]], device='mps:0')
>>> F.interpolate(a, size=5, mode='linear', align_corners=False)
tensor([[[0.5000, 2.5000, 4.5000, 6.5000, 8.5000]]], device='mps:0')

So there should be no obstacle to returning to the original behavior, right?
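
One way to return to the original behavior without breaking older MPS builds would be a guarded fallback; the helper below is illustrative, not a proposed pyannote API:

import torch
import torch.nn.functional as F

def interpolate_weights(weights: torch.Tensor, num_frames: int) -> torch.Tensor:
    # Prefer "linear" interpolation and fall back to "nearest" only if the backend
    # (e.g. MPS before PyTorch 2.3) does not implement it.
    try:
        return F.interpolate(weights, size=num_frames, mode="linear", align_corners=False)
    except NotImplementedError:
        return F.interpolate(weights, size=num_frames, mode="nearest")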

@juanmc2005 (Owner, Author)

@shenberg have you checked whether using mode="linear" and mode="nearest" give the same results on AMI?
