# Models

In [5]:
import torch

## 1. MixWaveUNet

## 2. Differentiable Mixing Console (DMC)

In [30]:
from automix.models.dmc import PostProcessor, Mixer, ShortChunkCNN_Res
from automix.utils import restore_from_0to1

Create the main modules of the system.

In [3]:
# sample rate shared by the mixer and the encoder
sample_rate = 44100

# mixing console that renders the stereo mix from tracks + parameters
mixer = Mixer(sample_rate)

# track encoder; its weights are restored from the checkpoint file
encoder = ShortChunkCNN_Res(
    sample_rate,
    ckpt_path="../checkpoints/encoder.ckpt",
)

# maps a fused (track + context) embedding of size 2 * d_embed
# to the mixer's num_params mixing parameters
post_processor = PostProcessor(
    mixer.num_params,
    encoder.d_embed * 2,
)



Loaded weights from ../checkpoints/encoder.ckpt


The `forward()` function describes how to generate a mix using the subsystems.

In [6]:
def forward(self, x: torch.Tensor, track_mask: torch.Tensor = None):
    """Given a set of tracks, analyze them with a shared encoder, predict a set of mixing parameters,
    and use these parameters to generate a stereo mixture of the inputs.

    Args:
        x (torch.Tensor): Input tracks with shape (bs, num_tracks, seq_len)
        track_mask (torch.Tensor, optional): Mask specifying inactive tracks with shape (bs, num_tracks).
            True / nonzero entries mark inactive tracks; they are excluded from the
            context embedding. If None, every track is treated as active.

    Returns:
        y (torch.Tensor): Final stereo mixture with shape (bs, 2, seq_len)
        p (torch.Tensor): Estimated (denormalized) mixing parameters with shape (bs, num_tracks, num_params)
    """
    bs, num_tracks, seq_len = x.size()

    # move tracks to the batch dimension to fully parallelize embedding computation
    # (reshape instead of view so non-contiguous inputs are accepted as well)
    x = x.reshape(bs * num_tracks, seq_len)

    # generate single embedding for each track
    e = self.encoder(x)
    e = e.reshape(bs, num_tracks, -1)  # (bs, num_tracks, d_embed)

    # generate the "context" embedding: the mean over the *active* tracks only
    if track_mask is None:
        c = e.mean(dim=1, keepdim=True)  # (bs, 1, d_embed)
    else:
        inactive = track_mask.to(device=e.device, dtype=torch.bool)
        inactive = inactive.reshape(bs, num_tracks, 1)
        # clamp so an item whose tracks are all inactive yields a zero context
        # instead of a division by zero (NaN)
        num_active = (~inactive).sum(dim=1, keepdim=True).clamp(min=1).to(e.dtype)
        # masked_fill (rather than multiplying by 0) keeps NaN/inf embeddings of
        # inactive tracks from leaking into the sum
        c = e.masked_fill(inactive, 0.0).sum(dim=1, keepdim=True) / num_active
    c = c.expand(-1, num_tracks, -1)  # (bs, num_tracks, d_embed), no copy needed

    # fuse the track embs and context embs
    ec = torch.cat((e, c), dim=-1)  # (bs, num_tracks, d_embed*2)

    # estimate mixing parameters for each track (in parallel)
    p = self.post_processor(ec)  # (bs, num_tracks, num_params)

    # generate the stereo mix
    x = x.reshape(bs, num_tracks, seq_len)  # move tracks back from batch dim
    y, p = self.mixer(x, p)  # (bs, 2, seq_len) # and denormalized params

    return y, p

In [18]:
# seed the global RNG so the random input below is identical on every
# Restart & Run All
torch.manual_seed(0)

batch_size = 2
num_tracks = 8
num_samples = 131072

# random noise stands in for real audio tracks: (batch_size, num_tracks, num_samples)
x = torch.randn(batch_size, num_tracks, num_samples)
bs, num_tracks, seq_len = x.size()

### 1. Generating embeddings

In [22]:
# fold the track axis into the batch axis so the encoder sees every track
# of every batch item in a single pass
x = x.view(bs * num_tracks, -1)
print(f"We get {bs}x{num_tracks} items in first dim: {x.shape}")

# one embedding per track, unfolded back to (bs, num_tracks, d_embed)
e = encoder(x).view(bs, num_tracks, -1)
print(f"We get {num_tracks} embeddings of size {encoder.d_embed}: {e.shape}")

We get 2x8 items in first dim: torch.Size([16, 131072])
We get 8 embeddings of size 512: torch.Size([2, 8, 512])


### 2. "Context" embedding

In [23]:
# "context" embedding: average the track embeddings of each batch item, then
# tile the average so every track is paired with the same context vector
c = e.mean(dim=1, keepdim=True).repeat(1, num_tracks, 1)  # (bs, num_tracks, d_embed)

# concatenate each track embedding with the shared context along the feature axis
ec = torch.cat([e, c], dim=-1)  # (bs, num_tracks, d_embed*2)

### 3. Estimate mixing parameters

In [25]:
# estimate mixing parameters for each track (in parallel) from the fused
# track + context embeddings; generate_mix() below documents these as being
# normalized to (0,1)
p = post_processor(ec)  # (bs, num_tracks, num_params)

### 4. Generate the mix

In [32]:
def generate_mix(
    x: torch.Tensor,
    p: torch.Tensor,
    min_gain_dB: float = -48.0,
    max_gain_dB: float = 24.0,
):
    """Generate a stereo mix of stems given mixing parameters normalized to (0,1).

    Each stem is scaled by its gain and then placed in the stereo field with a
    sine/cosine pan law (pan 0 -> left only, pan 1 -> right only). The stereo
    stems are finally summed into one mix per batch item.

    Args:
        x (torch.Tensor): Batch of waveform stem tensors with shape (bs, num_tracks, seq_len).
        p (torch.Tensor): Batch of normalized mixing parameters (0,1) for each stem with shape
            (bs, num_tracks, num_params). Index 0 of the last axis is the gain and
            index 1 is the pan; any further parameters are ignored.
        min_gain_dB (float, optional): Gain in dB that the lower end of the normalized
            gain range is restored to. Default: -48.0
        max_gain_dB (float, optional): Gain in dB that the upper end of the normalized
            gain range is restored to. Default: 24.0

    Returns:
        y (torch.Tensor): Batch of stereo waveform mixes with shape (bs, 2, seq_len)
        p (torch.Tensor): Parameters actually applied, with shape (bs, num_tracks, 2):
            the denormalized gain in dB and the pan value (returned as given).
    """
    bs, num_tracks, seq_len = x.size()

    # ------------- apply gain -------------
    # denormalize the gain parameter to dB, then convert dB to a linear factor
    gain_dB = restore_from_0to1(p[..., 0], min_gain_dB, max_gain_dB)
    gain_lin = 10 ** (gain_dB / 20.0)
    x = x * gain_lin.reshape(bs, num_tracks, 1)  # (bs, num_tracks, seq_len)

    # ------------- apply panning -------------
    pan = p[..., 1]  # get pan parameter
    pan_theta = pan * torch.pi / 2
    pan_gains_lin = torch.stack(
        [torch.cos(pan_theta), torch.sin(pan_theta)],  # left gain, right gain
        dim=-1,
    )  # (bs, num_tracks, 2)

    # broadcasting expands the mono stems to stereo without materializing a
    # duplicated copy: (bs, num_tracks, 1, seq_len) * (bs, num_tracks, 2, 1)
    x = x.unsqueeze(2) * pan_gains_lin.unsqueeze(-1)  # (bs, num_tracks, 2, seq_len)

    # ----------------- apply mix -------------
    # generate a mix for each batch item by summing stereo tracks
    y = x.sum(dim=1)  # (bs, 2, seq_len)

    # report the parameters that were applied
    p = torch.cat(
        (
            gain_dB.reshape(bs, num_tracks, 1),
            pan.reshape(bs, num_tracks, 1),
        ),
        dim=-1,
    )

    return y, p

In [33]:
# generate the stereo mix
x = x.view(bs, num_tracks, -1)  # move tracks back from batch dim
y, p = generate_mix(x, p)  # (bs, 2, seq_len) # and denormalized params

In [38]:
# print the gain (dB) and pan returned by generate_mix for every track of the
# first batch item
for tidx, track_params in enumerate(p[0,...]):
    print(f"{tidx} gain dB:{track_params[0]:0.3f}  pan:{track_params[1]:0.3f}")

0 gain dB:-12.856  pan:0.501
1 gain dB:-12.848  pan:0.501
2 gain dB:-12.841  pan:0.501
3 gain dB:-12.850  pan:0.501
4 gain dB:-12.849  pan:0.501
5 gain dB:-12.857  pan:0.501
6 gain dB:-12.861  pan:0.501
7 gain dB:-12.854  pan:0.501
