In [14]:
from encodec import EncodecModel 
from encodec.quantization import ResidualVectorQuantizer
from encodec.utils import convert_audio
from mobilenetv3.mobilenetv3 import hswish, hsigmoid, SeModule, Block

import torchaudio
import torch 
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset

import IPython

from datasets import load_dataset, DatasetDict

from sklearn.model_selection import train_test_split

# Run on the GPU whenever one is visible to torch, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Prepare the Data

GTZAN is a good classification dataset for development. It consists of (audio clip, genre label) pairs, and it is fairly easy to reach mid-to-high 90s accuracy on it from MFCCs or raw waveforms. This section sets up the dataset. You only need to run it to generate encodings at a higher bitrate for further development; encodings for the 1.5 kbps target bandwidth are already included in this repo.

In [18]:
# The pretrained EnCodec model is instantiated here so pre-processing can read its
# `sample_rate` and `channels`; the GTZAN dataset is fetched from the Hugging Face hub.
encoder = EncodecModel.encodec_model_24khz()
gtzan = load_dataset("marsyas/gtzan")

def pre_process_gtzan(gtzan: "DatasetDict", target_sr, target_channels,
                      duration_s: float = 10, split: str = "train") -> "tuple[list, list]":
    """
    Pre-load one GTZAN split and convert every clip to the sample rate / channel layout EnCodec expects.

    Each clip is passed through ``convert_audio`` and then forced to exactly ``duration_s`` seconds:
    longer clips are trimmed, shorter clips are right-padded with zeros. A fixed length matters
    because ``encode_data`` concatenates clips into batches with ``torch.cat``.

    Args:
        gtzan: dataset dict whose ``split`` rows provide ``'file'`` (audio path) and ``'genre'`` (label).
        target_sr: sample rate to convert to.
        target_channels: channel count to convert to.
        duration_s: clip length to keep, in seconds (default 10).
        split: which split of ``gtzan`` to read (default ``"train"``).

    Returns:
        ``(data, targets)``. ``data`` is a list of tensors, each with a leading dimension of size 1 and
        ``int(target_sr * duration_s)`` samples on the last axis. ``targets`` holds the matching ``'genre'`` values.

    Raises:
        ValueError: if ``target_sr * duration_s`` yields no samples.
    """
    n_samples = int(target_sr * duration_s)
    if n_samples <= 0:
        raise ValueError(f"duration_s={duration_s} at {target_sr} Hz yields no samples")

    data, targets = [], []
    for x in tqdm(gtzan[split]):
        audio, sr = torchaudio.load(x['file'])
        audio = convert_audio(audio, sr, target_sr, target_channels)
        # narrow() would raise on clips shorter than n_samples; trim, then zero-pad the tail instead.
        audio = audio[..., :n_samples]
        missing = n_samples - audio.shape[-1]
        if missing > 0:
            audio = torch.nn.functional.pad(audio, (0, missing))
        data.append(audio.unsqueeze(0))
        targets.append(x['genre'])

    return data, targets



In [19]:
data, targets = pre_process_gtzan(
    gtzan,
    target_sr=encoder.sample_rate,
    target_channels=encoder.channels,
)

100%|██████████| 999/999 [00:17<00:00, 55.60it/s]


In [20]:
def encode_data(data, encoder, batch_size=8, device=None):
    """
    Run ``encoder`` over pre-processed clips in mini-batches and collect the code indices.

    Args:
        data: list of tensors, each with a leading dimension of size 1; slices of ``batch_size``
            entries are concatenated along dim 0 to form a batch.
        encoder: model exposing ``encode(batch)``, which returns a sequence of frames whose first
            element is a code tensor. When ``device`` is given the model is moved there with
            ``.to(device)``, which moves an ``nn.Module`` in place.
        batch_size: number of clips per ``encode`` call.
        device: device to run encoding on. ``None`` leaves the model and inputs where they are.

    Returns:
        A CPU tensor with the per-batch codes concatenated along dim 0. Within each batch the
        frames are concatenated along the last axis.

    Raises:
        ValueError: if ``data`` is empty.
    """
    print("Pre-encoding training data")
    if not data:
        raise ValueError("encode_data received no clips to encode")

    if device is not None:
        # Inputs are sent to `device` below, so the weights must be there too or the first
        # forward pass fails with a device mismatch (e.g. device == 'cuda').
        encoder = encoder.to(device)

    encodings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(data), batch_size)):
            batch = torch.cat(data[start:start + batch_size], dim=0)
            if device is not None:
                batch = batch.to(device)
            encoded_frames = encoder.encode(batch)

            codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1)
            # Keep results on the CPU: avoids growing GPU memory and keeps torch.save output portable.
            encodings.append(codes.cpu())

    return torch.cat(encodings, dim=0)

# Fresh model with the target bandwidth fixed at 1.5 before encoding.
# Expect roughly 5 minutes on an M1 MacBook Pro, a couple of minutes on a GPU.
encoder = EncodecModel.encodec_model_24khz()
encoder.set_target_bandwidth(1.5)
encodings = encode_data(data=data, encoder=encoder, batch_size=8, device=device)



Pre-encoding training data


100%|██████████| 125/125 [04:40<00:00,  2.24s/it]


In [21]:
# Per-clip shape of the stored codes: (n_residuals, frames).
print(encodings.shape[1:])

torch.Size([2, 750])


In [22]:
# Exactly one label per encoded clip; catch a mismatch before it is written to disk.
assert len(encodings) == len(targets), f"{len(encodings)} encodings vs {len(targets)} targets"
dataset = {
    # Store on the CPU so the file can be loaded on machines without the device used for encoding.
    'data': encodings.cpu(),
    'targets': targets
}
torch.save(dataset, "gtzan_encodings-1.5.data")

# Load and split the data

In [23]:
class GTZANDataset(Dataset):
    """Map-style dataset pairing pre-computed code tensors with their genre labels.

    All samples and labels are moved to ``device`` once, at construction. ``__getitem__`` is then
    a plain index lookup instead of a host-to-device copy per sample on every epoch.

    Args:
        data: a tensor indexed along dim 0, or a sequence of tensors, one entry per example.
        labels: one label per example; each is wrapped with ``torch.tensor``.
        device: where to store the tensors. Defaults to ``'cuda'`` when available, else ``'cpu'``.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, device=None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Transfer once here rather than per item in __getitem__.
        if isinstance(data, torch.Tensor):
            self.data = data.to(device)
        else:
            self.data = [sample.to(device) for sample in data]
        self.labels = [torch.tensor(label).to(device) for label in labels]

        if len(self.data) != len(self.labels):
            raise ValueError(f"got {len(self.data)} samples but {len(self.labels)} labels")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index``; both already live on ``self.device``."""
        return self.data[index], self.labels[index]

def split_data(data, batch_size=32, random_seed=42, device=None, valid_size=0.1, test_size=0.05, shuffle=True):
    """
    Split saved encodings into train / validation / test DataLoaders.

    Args:
        data: dict with ``'data'`` (indexable samples) and ``'targets'`` (matching labels).
        batch_size: batch size for all three loaders.
        random_seed: seeds both ``train_test_split`` calls and the training loader's shuffle order.
        device: forwarded to ``GTZANDataset``; ``None`` uses its default.
        valid_size: validation fraction of the data that remains AFTER the test split is removed
            (so the share of the full dataset is ``(1 - test_size) * valid_size``).
        test_size: test fraction of the full dataset.
        shuffle: whether the training loader reshuffles every epoch.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. Only the training loader shuffles;
        evaluation loaders keep a fixed order.
    """
    x = data['data']
    y = data['targets']
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=random_seed)
    x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=valid_size, random_state=random_seed)

    train = GTZANDataset(x_train, y_train, device=device)
    valid = GTZANDataset(x_valid, y_valid, device=device)
    test = GTZANDataset(x_test, y_test, device=device)

    # A seeded generator makes the training batch order reproducible across runs.
    generator = torch.Generator().manual_seed(random_seed)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle, generator=generator)
    valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader

# map_location='cpu' lets the file load on CPU-only machines even if it was saved from GPU tensors;
# the datasets built by split_data then move the tensors to the training device.
data = torch.load("./gtzan_encodings-1.5.data", map_location="cpu")
train, valid, test = split_data(data, batch_size=32)

## Decoding Example

The stored representations are indices into the quantizer's codebooks, so we must look up the corresponding quantized vectors before we can train on them. The following snippet shows how to do that. It could be made a little cleaner by not initializing the entire EnCodec model, but this is the simplest approach. After initialization, we can send just the quantizer to the device and look up the vectors as part of the `forward` call of any network. Because only the integer codes are stored, the on-device copy of the dataset stays small.

One small caveat is that the quantizer expects input of shape `(n_residuals, batch_size, frames)`, while the DataLoader yields `(batch_size, n_residuals, frames)`. We therefore transpose the first two dimensions before calling `quantizer.decode`.

In [35]:
encoder = EncodecModel.encodec_model_24khz()
encoder.set_target_bandwidth(1.5)
# Only the quantizer is needed to map stored code indices back to vectors. It must sit on the
# same device as the batches the loader yields, or decode fails when device == 'cuda'.
quantizer = encoder.quantizer.to(device)
data, targets = next(iter(train))  # NOTE: rebinds the notebook-level `data` / `targets` names
data = data.transpose(0, 1)  # (batch_size, n_residuals, frames) -> (n_residuals, batch_size, frames)
quantized = quantizer.decode(data)
print(data.shape, quantized.shape)

torch.Size([2, 32, 750]) torch.Size([32, 128, 750])




# Build the Model

The next few cells experiment with the overall model size and the size of individual MobileNet `Block`s so that they fit our data. Since we target 10 s of audio, the input to a MobileNet block should have shape `(batch_size, 128, n_frames)`, i.e. the shape of the decoded output above.