Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use STFT support in coremltools 7 #2

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions convert-to-coreml
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,16 @@ def main():
# Create sample 'audio' for tracing
wav = torch.zeros(2, int(args.length * samplerate))

# Reproduce the STFT step (which we cannot convert to Core ML, unfortunately)
_, stft_mag = estimator.compute_stft(wav)

print('==> Tracing model')
traced_model = torch.jit.trace(estimator.separator, stft_mag)
out = traced_model(stft_mag)
traced_model = torch.jit.trace(estimator, wav)

print('==> Converting to Core ML')
mlmodel = ct.convert(
traced_model,
convert_to='mlprogram',
# TODO: Investigate whether we'd want to make the input shape flexible
# See https://coremltools.readme.io/docs/flexible-inputs
inputs=[ct.TensorType(shape=stft_mag.shape)]
inputs=[ct.TensorType(shape=wav.shape)]
)

output_dir: Path = args.output
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ description = "Spleeter implementation in PyTorch"
# and fail during model conversions e.g. noting that BlobWriter is not available.
requires-python = "<3.11"
dependencies = [
"coremltools >= 6.3, < 7",
"coremltools == 7.0b1",
"numpy >= 1.24, < 2",
"tensorflow >= 2.13.0rc0",
"torch >= 2.0, < 3",
Expand Down
16 changes: 12 additions & 4 deletions spleeter_pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,27 @@ def compute_stft(self, wav):

stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
center=True, return_complex=True, pad_mode='constant')

# implement torch.view_as_real(stft) manually since coremltools doesn't support it
stft = torch.stack((torch.real(stft), torch.imag(stft)), axis=-1)

# only keep freqs smaller than self.F
stft = stft[:, :self.F, :]
mag = stft.abs()
stft = stft[:, :self.F]

return torch.view_as_real(stft), mag
# implement torch.hypot manually since coremltools doesn't support it
mag = torch.sqrt(stft[..., 0] ** 2 + stft[..., 1] ** 2)

return stft, mag

def inverse_stft(self, stft):
"""Inverses stft to wave form"""

pad = self.win_length // 2 + 1 - stft.size(1)
stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
stft = torch.view_as_complex(stft)

# implement torch.view_as_complex(stft) manually since coremltools doesn't support it
stft = torch.complex(stft[..., 0], stft[..., 1])

wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
window=self.win)
return wav.detach()
Expand Down