In [None]:
import torch
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
from torch.hub import load_state_dict_from_url

class AudioFusion:
    """Join two tracks with a crossfade placed where they sound most alike.

    Each track is summarised as a sequence of 128-d VGGish embeddings, one per
    0.96 s of 16 kHz mono audio. The last ``window_size`` embeddings of the
    first track are compared with the first ``window_size`` of the second, the
    closest pair of segments picks the transition, and the full-quality
    waveforms are joined with a raised-cosine crossfade.
    """

    # Pretrained weights published with harritaylor/torchvggish.
    VGGISH_URL = (
        "https://github.com/harritaylor/torchvggish/releases/download/"
        "v0.1/vggish-10086976.pth"
    )
    # Front-end parameters of the reference VGGish input pipeline.
    VGGISH_SAMPLE_RATE = 16000
    STFT_WINDOW = 400  # 25 ms at 16 kHz
    STFT_HOP = 160  # 10 ms at 16 kHz
    FFT_LENGTH = 512  # smallest power of two >= STFT_WINDOW
    NUM_MEL_BINS = 64
    MEL_MIN_HZ = 125.0
    MEL_MAX_HZ = 7500.0
    LOG_OFFSET = 0.01
    EXAMPLE_FRAMES = 96  # 96 hops of 10 ms = one 0.96 s embedding
    EMBED_BATCH = 64  # patches per forward pass, bounds activation memory

    class _VGGish(torch.nn.Module):
        """VGGish network whose state-dict keys match the torchvggish weights.

        ``features.*`` holds the VGG-style conv stack and ``embeddings.*`` the
        fully connected head; input is ``(N, 1, 96, 64)`` log-mel patches and
        output is ``(N, 128)``.
        """

        def __init__(self):
            super().__init__()
            layers = []
            in_channels = 1
            for spec in (64, "M", 128, "M", 256, 256, "M", 512, 512, "M"):
                if spec == "M":
                    layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    layers += [
                        torch.nn.Conv2d(in_channels, spec, kernel_size=3, padding=1),
                        torch.nn.ReLU(inplace=True),
                    ]
                    in_channels = spec
            self.features = torch.nn.Sequential(*layers)
            # Four 2x2 pools turn a 96x64 patch into a 6x4 map of 512 channels.
            self.embeddings = torch.nn.Sequential(
                torch.nn.Linear(512 * 4 * 6, 4096),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(4096, 4096),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(4096, 128),
                torch.nn.ReLU(inplace=True),
            )

        def forward(self, x):
            x = self.features(x)  # (N, 512, 6, 4)
            # Flatten channel-last (N, 6, 4, 512), the order the pretrained
            # fully connected weights expect.
            x = x.permute(0, 2, 3, 1).contiguous()
            return self.embeddings(x.flatten(1))

    def __init__(self):
        self.vggish = None  # built lazily on first use by _load_vggish
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _load_vggish(self):
        """Download the pretrained VGGish weights and build the model on ``self.device``."""
        state_dict = load_state_dict_from_url(
            self.VGGISH_URL, progress=True, map_location="cpu"
        )
        model = self._VGGish()
        # Strict loading: any architecture/key mismatch fails loudly here.
        model.load_state_dict(state_dict)
        self.vggish = model.to(self.device).eval()

    def _preprocess_audio(self, waveform, sr):
        """Return *waveform* as a 1-D mono 16 kHz tensor on ``self.device``.

        Args:
            waveform: channels-first audio ``(channels, samples)`` or mono
                ``(samples,)``.
            sr: sample rate of *waveform* in Hz.
        """
        waveform = waveform.to(self.device)
        if waveform.dim() > 1:
            # Resampling is linear, so downmixing first is equivalent and cheaper.
            waveform = waveform.reshape(-1, waveform.shape[-1]).mean(dim=0)
        if sr != self.VGGISH_SAMPLE_RATE:
            waveform = F.resample(waveform, sr, self.VGGISH_SAMPLE_RATE)
        return waveform

    def _extract_embeddings(self, waveform, sr=16000):
        """Return VGGish embeddings ``(num_segments, 128)`` on the CPU.

        Args:
            waveform: channels-first or mono audio at sample rate *sr*.
            sr: sample rate of *waveform*; it is resampled to 16 kHz if needed.

        Raises:
            ValueError: if the audio is too short for a single 0.96 s segment.
        """
        if self.vggish is None:
            self._load_vggish()
        audio = self._preprocess_audio(waveform, sr)
        with torch.no_grad():
            specs = self._create_spectrograms(audio)
            # Chunk the patches so a long track doesn't need every activation
            # in memory at once.
            embeddings = torch.cat(
                [self.vggish(chunk) for chunk in specs.split(self.EMBED_BATCH)]
            )
        return embeddings.cpu()

    @staticmethod
    def _vggish_mel_matrix(num_mel_bins=64, num_spectrogram_bins=257,
                           sample_rate=16000, lower_hz=125.0, upper_hz=7500.0,
                           dtype=torch.float32, device=None):
        """Build the VGGish mel filterbank, shape ``(num_spectrogram_bins, num_mel_bins)``.

        Triangles are laid out in HTK mel space (``1127 * ln(1 + f / 700)``)
        between *lower_hz* and *upper_hz*, and the DC bin is zeroed, as in the
        reference VGGish ``spectrogram_to_mel_matrix``.
        """
        import math

        def hz_to_mel(hz):
            return 1127.0 * math.log1p(hz / 700.0)

        bins_hz = torch.linspace(0.0, sample_rate / 2.0, num_spectrogram_bins,
                                 dtype=torch.float64)
        bins_mel = (1127.0 * torch.log1p(bins_hz / 700.0))[:, None]
        edges = torch.linspace(hz_to_mel(lower_hz), hz_to_mel(upper_hz),
                               num_mel_bins + 2, dtype=torch.float64)
        lower, center, upper = edges[:-2], edges[1:-1], edges[2:]
        rising = (bins_mel - lower) / (center - lower)
        falling = (upper - bins_mel) / (upper - center)
        weights = torch.clamp(torch.minimum(rising, falling), min=0.0)
        weights[0, :] = 0.0
        return weights.to(dtype=dtype, device=device)

    def _create_spectrograms(self, waveform):
        """Turn 16 kHz mono audio into VGGish input patches ``(N, 1, 96, 64)``.

        The front end follows the VGGish input pipeline:
        - 25 ms periodic-Hann frames with a 10 ms hop and no padding;
        - 512-point FFT magnitude;
        - 64 mel bands spanning 125-7500 Hz;
        - ``log(mel + 0.01)``;
        - non-overlapping 96-frame (0.96 s) patches.

        Raises:
            ValueError: if the audio is shorter than one patch.
        """
        if waveform.dim() > 1:
            waveform = waveform.reshape(-1, waveform.shape[-1]).mean(dim=0)
        if not waveform.is_floating_point():
            waveform = waveform.float()
        min_samples = self.STFT_WINDOW + (self.EXAMPLE_FRAMES - 1) * self.STFT_HOP
        if waveform.shape[-1] < min_samples:
            raise ValueError(
                f"need at least {min_samples} samples at 16 kHz for one VGGish "
                f"segment, got {waveform.shape[-1]}"
            )

        frames = waveform.unfold(0, self.STFT_WINDOW, self.STFT_HOP)  # (T, 400)
        window = torch.hann_window(self.STFT_WINDOW, periodic=True,
                                   dtype=frames.dtype, device=frames.device)
        magnitude = torch.fft.rfft(frames * window, n=self.FFT_LENGTH).abs()  # (T, 257)
        mel_matrix = self._vggish_mel_matrix(
            num_mel_bins=self.NUM_MEL_BINS,
            num_spectrogram_bins=self.FFT_LENGTH // 2 + 1,
            sample_rate=self.VGGISH_SAMPLE_RATE,
            lower_hz=self.MEL_MIN_HZ,
            upper_hz=self.MEL_MAX_HZ,
            dtype=magnitude.dtype,
            device=magnitude.device,
        )
        log_mel = torch.log(magnitude @ mel_matrix + self.LOG_OFFSET)  # (T, 64)

        # unfold yields (N, 64, 96); VGGish wants frames before mel bands.
        patches = log_mel.unfold(0, self.EXAMPLE_FRAMES, self.EXAMPLE_FRAMES)
        return patches.transpose(1, 2).unsqueeze(1).contiguous()

    def _find_optimal_transition(self, emb1, emb2, window_size=30, segment_seconds=0.96):
        """Pick the most similar segment pair between track 1's tail and track 2's head.

        Args:
            emb1, emb2: per-segment embeddings, shape ``(num_segments, dim)``.
            window_size: number of segments compared at the end of *emb1* and
                the start of *emb2* (30 x 0.96 s ~= 28.8 s). Shorter inputs are
                compared in full.
            segment_seconds: duration represented by one embedding row.

        Returns:
            ``(transition_end, transition_start)`` as floats in seconds. These
            are the start times of the chosen segment in track 1 and in track 2.

        Raises:
            ValueError: if either sequence is empty or *window_size* < 1.
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        if emb1.shape[0] == 0 or emb2.shape[0] == 0:
            raise ValueError("both tracks need at least one embedding segment")
        tail_len = min(window_size, emb1.shape[0])
        head_len = min(window_size, emb2.shape[0])
        tail = emb1[-tail_len:].float()
        head = emb2[:head_len].float()

        # A warping path between two independent windows has no meaningful
        # endpoint; the transition is simply the closest pair of segments.
        distances = torch.cdist(tail, head, p=2)  # (tail_len, head_len)
        row, col = divmod(int(torch.argmin(distances)), head_len)

        transition_end = emb1.shape[0] - tail_len + row
        transition_start = col
        return transition_end * segment_seconds, transition_start * segment_seconds

    @staticmethod
    def _match_channels(song1, song2):
        """Broadcast a mono ``(1, samples)`` track to the other track's channel count."""
        c1, c2 = song1.shape[0], song2.shape[0]
        if c1 == c2:
            return song1, song2
        if c1 == 1:
            return song1.expand(c2, -1), song2
        if c2 == 1:
            return song1, song2.expand(c1, -1)
        raise ValueError(f"cannot crossfade {c1}-channel audio with {c2}-channel audio")

    def _smart_crossfade(self, song1, song2, sr, fade_point1, fade_point2, fade_duration=5.0):
        """Join two channels-first waveforms with an equal-gain raised-cosine crossfade.

        *song1* fades out over the ``fade_duration`` seconds that end at
        ``fade_point1``. *song2* fades in over the same span starting at
        ``fade_point2``. The result has three parts:
        - *song1* before its fade;
        - the overlapped fade region;
        - *song2* after its fade.

        The fade is shortened when either track lacks enough audio around its
        fade point. Fade points are clamped to the track bounds. A mono track
        is broadcast to the other's channel count.
        """
        song2 = song2.to(device=song1.device, dtype=song1.dtype)
        song1, song2 = self._match_channels(song1, song2)
        len1, len2 = song1.shape[-1], song2.shape[-1]

        end1 = min(max(int(fade_point1 * sr), 0), len1)
        start2 = min(max(int(fade_point2 * sr), 0), len2)
        # Never reach before the start of song1 or past the end of song2.
        fade_samples = max(0, min(int(fade_duration * sr), end1, len2 - start2))
        start1 = end1 - fade_samples

        seg1 = song1[:, start1:end1]
        seg2 = song2[:, start2:start2 + fade_samples]

        ramp = torch.linspace(0, 1, fade_samples, device=song1.device, dtype=song1.dtype)
        fade_out = 0.5 * (1 + torch.cos(ramp * torch.pi))
        fade_in = 0.5 * (1 - torch.cos(ramp * torch.pi))
        mixed = seg1 * fade_out + seg2 * fade_in

        return torch.cat([
            song1[:, :start1],
            mixed,
            song2[:, start2 + fade_samples:],
        ], dim=1)

    @staticmethod
    def _backend_hint(action, path):
        """Explain the usual cause of a torchaudio I/O failure for *path*."""
        return (
            f"torchaudio could not {action} {path!r}. If the chained error says "
            "no backend can handle the file, this torchaudio install has no "
            "codec for the format: install FFmpeg shared libraries so an FFmpeg "
            "backend is available (see torchaudio.list_audio_backends()), or "
            "convert the audio to WAV/FLAC."
        )

    def _load_audio(self, path):
        """Load *path* with ``torchaudio.load``, adding guidance to backend failures."""
        try:
            return torchaudio.load(path)
        except RuntimeError as err:
            raise RuntimeError(self._backend_hint("read", path)) from err

    def fuse_tracks(self, file1, file2, output_file, output_sr=None):
        """Crossfade *file1* into *file2* at their most similar moments and save it.

        Args:
            file1, file2: paths of the outgoing and incoming tracks.
            output_file: destination path passed to ``torchaudio.save``.
            output_sr: output sample rate in Hz. Defaults to *file1*'s native
                rate; only the embedding analysis runs at 16 kHz.

        Raises:
            RuntimeError: if torchaudio cannot read an input or write the output.
            ValueError: if a track is too short to embed.
        """
        wav1, sr1 = self._load_audio(file1)
        wav2, sr2 = self._load_audio(file2)

        # Embeddings use a 16 kHz mono copy; the full-quality audio is kept.
        emb1 = self._extract_embeddings(wav1, sr1)
        emb2 = self._extract_embeddings(wav2, sr2)
        t1, t2 = self._find_optimal_transition(emb1, emb2)

        target_sr = sr1 if output_sr is None else int(output_sr)
        if sr1 != target_sr:
            wav1 = F.resample(wav1, sr1, target_sr)
        if sr2 != target_sr:
            wav2 = F.resample(wav2, sr2, target_sr)

        combined = self._smart_crossfade(wav1, wav2, target_sr, t1, t2)
        try:
            torchaudio.save(output_file, combined.cpu(), target_sr)
        except RuntimeError as err:
            raise RuntimeError(self._backend_hint("write", output_file)) from err

# Usage
if __name__ == "__main__":
    # NOTE: decoding MP3 requires a torchaudio backend with MP3 support (e.g.
    # FFmpeg). Without one, loading fails with "Couldn't find appropriate
    # backend". The output is written as WAV so no MP3 encoder is needed as well.
    track_dir = "/var/mnt/ssd/Файлы/Музыка/Tracks"
    fuser = AudioFusion()
    fuser.fuse_tracks(
        f"{track_dir}/Sharks - Maze Of Affection.mp3",
        f"{track_dir}/Machinedrum - H0N3Y.mp3",
        "fusion_output.wav",
    )

RuntimeError: Couldn't find appropriate backend to handle uri /var/mnt/ssd/Файлы/Музыка/Tracks/Sharks - Maze Of Affection.mp3 and format None.