Skip to content

Commit

Permalink
Add GPU support. (#1)
Browse files Browse the repository at this point in the history
It would be cool to be able to leverage GPU to compute pitch much faster. Currently this does not work because of incorrect device placement for new tensors. However, with some simple changes the existing implementation is GPU-compatible. The speedup is 50x on my machine for a batch of 32, 10 second waves.
  • Loading branch information
Piotr Dabkowski committed Aug 23, 2022
1 parent 5676b97 commit 1c0f508
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchyin/yin.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def estimate(
return torch.where(
tau > 0,
sample_rate / (tau + tau_min + 1).type(signal.dtype),
torch.tensor(0).type(signal.dtype),
torch.tensor(0, device=tau.device).type(signal.dtype),
)


Expand All @@ -103,8 +103,8 @@ def _diff(frames: torch.Tensor, tau_max: int) -> torch.Tensor:
# cumulative mean normalized difference function (equation 8)
return (
diff[..., 1:]
* torch.arange(1, diff.shape[-1])
/ np.maximum(diff[..., 1:].cumsum(-1), 1e-5)
* torch.arange(1, diff.shape[-1], device=diff.device)
/ torch.maximum(diff[..., 1:].cumsum(-1), torch.tensor(1e-5, device=diff.device))
)


Expand All @@ -113,7 +113,7 @@ def _search(cmdf: torch.Tensor, tau_max: int, threshold: float) -> torch.Tensor:
# if none are below threshold (argmax=0), this is a non-periodic frame
first_below = (cmdf < threshold).int().argmax(-1, keepdim=True)
first_below = torch.where(first_below > 0, first_below, tau_max)
beyond_threshold = torch.arange(cmdf.shape[-1]) >= first_below
beyond_threshold = torch.arange(cmdf.shape[-1], device=cmdf.device) >= first_below

# mask all periods with upward sloping cmdf to find the local minimum
increasing_slope = torch.nn.functional.pad(cmdf.diff() >= 0.0, [0, 1], value=1)
Expand Down

0 comments on commit 1c0f508

Please sign in to comment.