In [2]:

# If you only want to use a pretrained model, you can also directly get it
# from the CompressionModel base model class.
from audiocraft.models import CompressionModel

# Load the pretrained compression model by name.
# Pass the bare model name here: do NOT add the `//pretrained/` prefix.
model = CompressionModel.get_pretrained('facebook/encodec_32khz')

import torch
import torchaudio 
import glob
from einops import rearrange

In [3]:
# model = model.cuda().eval()

# audio_list = ["backtoblack.mp3"]
# len_signals = 120

# with torch.no_grad():

#     x_, fs = torchaudio.load(audio_list[0])
#     x_ = x_[0, :int(len_signals*fs)]
#     if fs != 32000:
#         x_ = torchaudio.transforms.Resample(fs, 32000)(x_)

#     x = x_.cuda().unsqueeze(0)
#     q, scale = model.encode(x) #int codes
#     xhat_encodec = model.decode(q, scale).cpu()

#     torchaudio.save(f".tests/prompt.wav", x.cpu().squeeze().unsqueeze(0), sample_rate=32000, channels_first=True)
#     torchaudio.save(f".tests/encodec.wav", xhat_encodec.squeeze().unsqueeze(0), sample_rate=32000, channels_first=True)

In [3]:

def _shuffle_codebooks(x, groups: int = 0):
    """Permute ``x`` along the batch axis, independently for each group of axis-1 entries.

    ``x`` must have exactly three axes ``(B, N, dim)``. Axis 1 is cut into
    ``groups`` equal, contiguous chunks and one random permutation of the B
    batch indices is drawn per chunk. All entries inside a chunk follow the
    same permutation, so they stay aligned with each other, while different
    chunks are shuffled independently.

    Args:
        x: tensor of shape ``(B, N, dim)``.
        groups: number of chunks along axis 1. ``0`` (the default) means one
            chunk per entry, i.e. every axis-1 entry gets its own permutation.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.

    Raises:
        AssertionError: if ``N`` is not a multiple of the number of groups.
    """
    batch_size, num_entries, _ = x.size()
    shuffled = torch.zeros_like(x)
    # A falsy ``groups`` falls back to one group per entry (the original behaviour).
    num_groups = groups or num_entries

    assert num_entries % num_groups == 0, f"Dimensions {num_entries} are not divisible by number of groups {num_groups}"
    for index in range(num_groups):
        start = index * num_entries // num_groups
        stop = (index + 1) * num_entries // num_groups
        order = torch.randperm(batch_size, device=x.device)
        shuffled[:, start:stop] = x[order, start:stop]
    return shuffled

class MMDLoss(torch.nn.Module):
    """MMD-style discrepancy between the input samples and a batch-shuffled copy of them.

    The input ``[B, K, D, T]`` is normalised per index of axis 1, optionally
    delayed along the last axis, flattened to ``B*T`` samples of size ``K*D``
    and compared against the output of ``_shuffle_codebooks`` applied to the
    same samples.

    Args:
        delay: if True, entry ``k`` of axis 1 is shifted right by ``k`` steps
            along the last axis before the comparison.
        kernel: one of ``"rbf"``, ``"inverse"``, ``"linear"``, ``"quadratic"``.
        scales: kernel bandwidths, summed over. Only used by ``"rbf"`` and
            ``"inverse"``.
        device: if not None, inputs are moved to this device in ``forward``.

    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """

    # Kernels evaluated on pairwise squared distances; these use ``scales``.
    _DISTANCE_KERNELS = ("rbf", "inverse")
    # Kernels evaluated directly on the inner-product (Gram) matrices.
    _GRAM_KERNELS = ("linear", "quadratic")

    def __init__(self, delay: bool = False, kernel: str = "rbf", scales=(0.1, 1, 5, 10, 20, 50), device=None):
        super().__init__()
        self._check_kernel(kernel)
        self.device = device
        self.delay = delay
        self.kernel = kernel
        # Keep a private copy so that instances never share (or mutate) the
        # default value or the caller's list.
        self.scales = list(scales)

    def _check_kernel(self, kernel):
        """Raise ValueError unless ``kernel`` is a supported kernel name."""
        supported = self._DISTANCE_KERNELS + self._GRAM_KERNELS
        if kernel not in supported:
            raise ValueError(f"Unknown kernel {kernel!r}; expected one of {supported}")

    def _kernel_fn(self, dxx: torch.Tensor, a: torch.Tensor = None):
        """Sum of the kernel evaluated on every entry of ``dxx``.

        For ``"rbf"`` / ``"inverse"``, ``dxx`` holds pairwise squared distances
        and ``a`` is the scale. For ``"linear"`` / ``"quadratic"``, ``dxx``
        holds pairwise inner products and ``a`` is ignored.
        """
        if self.kernel == "rbf":
            return torch.exp((-0.5 / a) * dxx).sum()
        if self.kernel == "inverse":
            return (a**2 / (a**2 + dxx)).sum()
        if self.kernel == "linear":
            return dxx.sum()
        if self.kernel == "quadratic":
            # Sum of squared entries (not the square of the sum), matching
            # what the loss itself accumulates for this kernel.
            return dxx.square().sum()
        self._check_kernel(self.kernel)

    def forward(self, inputs: torch.Tensor):
        """Compute the loss for ``inputs`` of shape ``[B, K, D, T]``.

        Returns:
            A 0-dim float tensor, clamped to be non-negative.

        Raises:
            ValueError: if ``self.kernel`` is unsupported, or if fewer than two
                samples remain after flattening (the estimate divides by
                ``B * (B - 1)``).
        """
        self._check_kernel(self.kernel)
        if self.device is not None:
            inputs = inputs.to(self.device)

        _, K, _, T = inputs.size()
        x = inputs.type(torch.float)
        # Standardise each index of axis 1 over all remaining axes.
        x = (x - x.mean(dim=(0, 2, 3), keepdim=True)) / torch.sqrt(x.var(dim=(0, 2, 3), keepdim=True) + 1e-8)

        # Reshaping / Delaying
        if self.delay:
            x = torch.cat(
                [torch.nn.functional.pad(x[:, delay: (delay + 1), :, : T - delay], (delay, 0)) for delay in range(K)],
                dim=1,
            )
            x = x[..., K:]  # Crop to remove zeros introduced by padding

        # Group time dimension and shuffle to sample from factorized distribution
        x = rearrange(x, 'b k d t -> (b t) k d')
        B = x.shape[0]
        if B < 2:
            raise ValueError(f"MMDLoss needs at least 2 samples after flattening, got {B}")
        y = _shuffle_codebooks(x)
        x = x.reshape(B, -1)
        y = y.reshape(B, -1)

        xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

        if self.kernel in self._DISTANCE_KERNELS:
            rx = xx.diag().unsqueeze(0).expand_as(xx)
            ry = yy.diag().unsqueeze(0).expand_as(yy)
            # Pairwise squared distances: |u|^2 + |v|^2 - 2 u.v
            within_x = rx.t() + rx - 2.0 * xx
            within_y = ry.t() + ry - 2.0 * yy
            across = rx.t() + ry - 2.0 * zz
            scales = self.scales
            # Each diagonal entry is k(u, u) = 1, so the diagonal sums to B.
            diagonal_sum = B
        else:
            within_x, within_y, across = xx, yy, zz
            scales = (None,)
            # NOTE(review): for these kernels the diagonal (self-similarity)
            # terms are not removed although the sum is still divided by
            # B * (B - 1). Kept unchanged so existing results stay comparable;
            # confirm whether this is intended.
            diagonal_sum = 0

        # Start from a tensor so that an empty ``scales`` still returns a tensor.
        out = xx.new_zeros(())
        for a in scales:
            out = out + (self._kernel_fn(within_x, a) - diagonal_sum) / (B * (B - 1))
            out = out + (self._kernel_fn(within_y, a) - diagonal_sum) / (B * (B - 1))
            out = out + (-2 / B**2) * self._kernel_fn(across, a)

        return out.clamp(min=0)

In [7]:
model = model.cuda().eval()

b = 5  # number of clips stacked into the batch
audio_list = 10*["backtoblack.mp3"]
len_signals = 120  # seconds kept from the start of each clip

# One row per reported measurement:
#   (kernel, scales passed to MMDLoss or None for its default,
#    value shown after "Scale:" or None for no scale label,
#    calibration constant the loss is divided by)
# NOTE(review): the recorded outputs are all close to 1, so the calibration
# constants look like previously measured raw loss values - TODO confirm.
KERNEL_RUNS = [
    ("rbf", [0.1, 1, 5, 10, 20, 50], [0.1, 1, 5, 10, 20, 50], 0.0008367030532),  # multi-scale rbf
    ("rbf", [120], 120, 0.000265),
    ("inverse", [12], 12, 0.000165),
    ("linear", None, None, 0.000692),
    ("quadratic", None, None, 74.66),
]

with torch.no_grad():

    x = []
    for i in range(b):

        x_, fs = torchaudio.load(audio_list[i])
        # First channel only, truncated to `len_signals` seconds.
        x_ = x_[0, :int(len_signals*fs)]

        if fs != 32000:
            x_ = torchaudio.transforms.Resample(fs, 32000)(x_)

        x.append(x_)

    x = torch.stack(x).cuda().unsqueeze(1)
    q, scale = model.encode(x) #int codes
    _, unquantized = model.model.quantizer.decode(q.transpose(0, 1), return_multi=True) #vectors

    # The loss shuffles with torch.randperm, so values change from run to run;
    # uncomment to make them repeatable.
    # torch.random.manual_seed(1)

    # The loop variables deliberately avoid the name `scale`, which still holds
    # the value returned by model.encode above.
    for kernel, kernel_scales, scale_label, calibration in KERNEL_RUNS:

        loss_kwargs = {} if kernel_scales is None else {"scales": kernel_scales}
        loss = MMDLoss(delay=True, kernel=kernel, device="cuda", **loss_kwargs)
        _loss = loss(unquantized)
        _loss *= 1/calibration
        print(1/calibration)

        label = "" if scale_label is None else f"Scale: {scale_label}, "
        print(f"{label}{loss.kernel}: {_loss.item()}")

torch.Size([4, 5, 6000]) 4
1195.1671458296526
Scale: [0.1, 1, 5, 10, 20, 50], rbf: 0.9987932443618774
3773.5849056603774
Scale: 120, rbf: 1.0078104734420776
6060.606060606061
Scale: 12, inverse: 1.0096665620803833
1445.086705202312
linear: 0.9977748394012451
0.01339405304045004
quadratic: 1.0015389919281006
