# CROSSFAIDER - First Iteration

In [1]:
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor
import torch
import numpy as np
from scipy.io.wavfile import write as write_wav
import torchaudio.transforms as T

  from .autonotebook import tqdm as notebook_tqdm


## Load Facebook's EnCodec model

In [2]:
# Pretrained 24 kHz EnCodec checkpoint from the Hugging Face Hub. The processor
# provides the sampling rate (processor.sampling_rate) and the input
# preprocessing used by the feature-extraction cell below.
ENCODEC_CHECKPOINT = "facebook/encodec_24khz"
model = EncodecModel.from_pretrained(ENCODEC_CHECKPOINT)
processor = AutoProcessor.from_pretrained(ENCODEC_CHECKPOINT)

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


## Extract EnCodec features from the end of track A and the beginning of track B

In [3]:
import torchaudio

def extract_encodec_features(audio_path, duration=5, last=True, bandwidth=None):
    """Encode the first or last `duration` seconds of an audio file with EnCodec.

    Relies on the notebook-level `model` and `processor` loaded earlier.

    Args:
        audio_path: Path to an audio file readable by `torchaudio.load`.
        duration: Excerpt length in seconds; must be positive. If the file is
            shorter than this, the whole file is used.
        last: If True, take the last `duration` seconds; otherwise the first.
        bandwidth: Target bandwidth forwarded to `model.encode`. None keeps
            `model.encode`'s own default (the previous behaviour). Use the same
            value for both tracks, since `interpolate_encodec_features` requires
            their code tensors to have identical shapes.

    Returns:
        Tuple `(audio_codes, audio_scales, padding_mask)`: the first two come
        from `model.encode`, the mask from the processor.

    Raises:
        ValueError: If `duration` is not positive (the slice would be empty).
    """
    if duration <= 0:
        raise ValueError(f"duration must be positive, got {duration}")

    waveform, sample_rate = torchaudio.load(audio_path)

    # Downmix to mono by averaging channels; the processor is fed a 1-D signal below.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Keep `duration` seconds from the end (last=True) or the start of the file.
    total_samples = waveform.shape[1]
    samples_to_keep = int(duration * sample_rate)
    start_sample = max(0, total_samples - samples_to_keep) if last else 0
    end_sample = min(start_sample + samples_to_keep, total_samples)
    waveform = waveform[:, start_sample:end_sample]

    # Match the sampling rate the processor/model were configured with.
    if sample_rate != processor.sampling_rate:
        waveform = T.Resample(orig_freq=sample_rate, new_freq=processor.sampling_rate)(waveform)

    # Drop the channel dimension: the processor is given just the sample values.
    waveform = waveform.squeeze(0)

    inputs = processor(raw_audio=waveform, sampling_rate=processor.sampling_rate, return_tensors="pt")

    # audio_codes: quantized tokens representing the audio in compressed form.
    # audio_scales: per-chunk loudness scales (may be None for some checkpoints).
    # padding_mask: marks which input samples are real audio vs. padding.
    with torch.no_grad():
        encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=bandwidth)

    return encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"]

## Interpolate between A and B features

Each interpolation step produces a new version of the audio with a slightly different mix of Track A and Track B. Taken in order, these versions form a gradual transformation from A → B.

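- At alpha = 0.0, the audio is 100% Track A.
- At alpha = 1.0, the audio is 100% Track B.
- Intermediate values (e.g., alpha = 0.25, 0.5, 0.75 with `steps=5`) produce blends: each EnCodec token is taken from Track B with probability alpha and from Track A otherwise.
- Each version is saved as a separate .wav file, and their concatenation is saved as one full-transition file.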

In [4]:
def interpolate_encodec_features(features_A, features_B, steps=10, seed=None):
    """Build `steps` EnCodec feature sets that move from track A to track B.

    Discrete codes are mixed token by token: at weight ``alpha`` each token is
    taken from B with probability ``alpha`` and from A otherwise. Scales are
    linearly interpolated when they are tensors or equal-length sequences of
    tensors; anything else (e.g. ``None`` entries) switches hard from A to B
    once ``alpha >= 0.5``.

    Args:
        features_A: ``(audio_codes, audio_scales, padding_mask)`` for track A.
        features_B: Same structure for track B.
        steps: Number of weights, evenly spaced over [0, 1] inclusive.
        seed: Optional int seed for the token-selection RNG. ``None`` draws from
            torch's global RNG (previous behaviour); an int makes the mix
            reproducible.

    Returns:
        List of ``(codes, scales, padding_mask)`` tuples, one per weight. The
        padding mask is always A's (codes have matching shapes, see below).

    Raises:
        ValueError: If A's and B's code tensors differ in shape.
    """
    audio_codes_A, audio_scales_A, padding_mask_A = features_A
    audio_codes_B, audio_scales_B, padding_mask_B = features_B

    if audio_codes_A.shape != audio_codes_B.shape:
        raise ValueError(f"Shape mismatch: A codes {audio_codes_A.shape}, B codes {audio_codes_B.shape}")

    generator = None
    if seed is not None:
        generator = torch.Generator(device=audio_codes_A.device)
        generator.manual_seed(seed)

    interpolations = []
    for alpha in np.linspace(0, 1, steps):
        alpha = float(alpha)
        # torch.rand is in [0, 1): alpha=0 keeps every A token, alpha=1 takes every B token.
        selection_mask = torch.rand(audio_codes_A.shape, generator=generator, device=audio_codes_A.device) < alpha
        interpolated_code = torch.where(selection_mask, audio_codes_B, audio_codes_A)
        interpolated_scale = _interpolate_scales(audio_scales_A, audio_scales_B, alpha)
        interpolations.append((interpolated_code, interpolated_scale, padding_mask_A))

    return interpolations


def _interpolate_scales(scales_A, scales_B, alpha):
    """Lerp tensor scales (or equal-length sequences of tensors); otherwise hard-switch at alpha 0.5."""
    if isinstance(scales_A, torch.Tensor) and isinstance(scales_B, torch.Tensor):
        return (1 - alpha) * scales_A + alpha * scales_B
    if (
        isinstance(scales_A, (list, tuple))
        and isinstance(scales_B, (list, tuple))
        and len(scales_A) == len(scales_B)
        and all(isinstance(s, torch.Tensor) for s in (*scales_A, *scales_B))
    ):
        return [(1 - alpha) * a + alpha * b for a, b in zip(scales_A, scales_B)]
    # Non-numeric scales (e.g. [None]) cannot be blended: take A's below 0.5, B's from 0.5 on.
    return scales_A if alpha < 0.5 else scales_B

## Decode EnCodec features into audio waveforms

In [5]:
def decode_encodec_features(interpolated_features):
    """Decode EnCodec feature tuples back into waveforms with the global `model`.

    Args:
        interpolated_features: Iterable of ``(audio_codes, audio_scales, padding_mask)``.

    Returns:
        List of NumPy arrays, one decoded waveform per input tuple, with
        singleton dimensions squeezed away.
    """
    decoded = []
    for audio_codes, audio_scales, padding_mask in interpolated_features:
        with torch.no_grad():
            output = model.decode(
                audio_codes=audio_codes.long(),  # ensure integer (int64) codebook indices
                audio_scales=audio_scales,
                padding_mask=padding_mask,
            )
            waveform = output.audio_values.squeeze().cpu().numpy()
        decoded.append(waveform)
    return decoded

In [6]:
def save_audio_files(interpolated_audio, sample_rate=24000, base_filename="interpolation", save_full_transition=True):
    """Write each waveform, and optionally their concatenation, as float32 WAV files.

    Every segment is peak-normalised on its own. The full transition is built by
    concatenating the original (un-normalised) segments and then peak-normalised
    as a whole. All-zero signals are written unchanged.

    Args:
        interpolated_audio: Sequence of 1-D NumPy waveforms.
        sample_rate: Sample rate stored in the WAV headers.
        base_filename: Prefix for output paths (``{base}_{i}.wav``,
            ``{base}_full_transition.wav``).
        save_full_transition: Also write the concatenation of all segments.

    Returns:
        List of written filenames: segment files in order, then the transition file if any.
    """
    def _peak_normalize(signal):
        # Scale so the largest absolute sample is 1.0; leave silence untouched.
        peak = np.abs(signal).max()
        return signal / peak if peak > 0.0 else signal

    written = []

    for index, segment in enumerate(interpolated_audio):
        segment_path = f"{base_filename}_{index}.wav"
        write_wav(segment_path, sample_rate, _peak_normalize(segment).astype(np.float32))
        written.append(segment_path)
        print(f"Saved {segment_path}")

    if not (save_full_transition and len(interpolated_audio) > 0):
        return written

    transition = _peak_normalize(np.concatenate(interpolated_audio, axis=0))
    transition_path = f"{base_filename}_full_transition.wav"
    write_wav(transition_path, sample_rate, transition.astype(np.float32))
    written.append(transition_path)
    print(f"Saved full transition as {transition_path}")

    return written

In [7]:
# Example usage: blend the last 10 s of track A into the first 10 s of track B.
# Seed torch's global RNG so the random token mix (and therefore the saved WAVs)
# is identical on every Restart & Run All.
torch.manual_seed(0)

features_A = extract_encodec_features("track_A.mp3", duration=10)
features_B = extract_encodec_features("track_B.mp3", duration=10, last=False)
interpolated_features = interpolate_encodec_features(features_A, features_B, steps=5)
interpolated_audio = decode_encodec_features(interpolated_features)
# Write at the decoder's own sampling rate rather than relying on the function default.
audio_files = save_audio_files(interpolated_audio, sample_rate=model.config.sampling_rate)

Saved interpolation_0.wav
Saved interpolation_1.wav
Saved interpolation_2.wav
Saved interpolation_3.wav
Saved interpolation_4.wav
Saved full transition as interpolation_full_transition.wav
