In [None]:
from encodec import EncodecModel 
from encodec.quantization import ResidualVectorQuantizer
from encodec.utils import convert_audio
from mobilenetv3.mobilenetv3 import hswish, hsigmoid, SeModule, Block

import torchaudio
import torch 
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import init

import IPython
from datasets import load_dataset, DatasetDict

from sklearn.model_selection import train_test_split

# Default compute device for the notebook: the GPU when PyTorch can see a CUDA
# device, otherwise the CPU.
# NOTE(review): only CUDA is probed here, so other accelerators fall back to 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prepare the Data

GTZAN is a good classification dataset for development. It consists of audio/label pairs (a music clip and its genre), and it is a fairly easy task on which to reach mid-to-high 90s accuracy given MFCCs or waveforms. This section sets up the dataset. The only reason to run it is if you need to generate encodings at a higher bitrate for further development; the encodings for the 1.5 kbps target bandwidth are already prepared in this repo.

In [None]:
# Load GTZAN through `datasets.load_dataset`; the result is indexed by split name below.
gtzan = load_dataset("marsyas/gtzan")
# The 24 kHz EnCodec model; its `sample_rate` and `channels` attributes are used
# as the resampling / channel targets when the audio is pre-processed.
encoder = EncodecModel.encodec_model_24khz()

def pre_process_gtzan(gtzan: DatasetDict, target_sr, target_channels,
                      clip_seconds=10, split='train') -> "tuple[list, list]":
    """
    Pre-load the data and process it to the correct sample rate and mono/stereo.

    Every clip is resampled / re-channelled with `convert_audio` and then forced
    to exactly `clip_seconds` of audio so the clips can later be concatenated
    into batches: longer clips are truncated from the start, shorter clips are
    zero-padded at the end.

    Args:
        gtzan: dataset dict whose rows expose a 'file' path and a 'genre' target.
        target_sr: sample rate (Hz) to convert every clip to.
        target_channels: number of channels to convert every clip to.
        clip_seconds: length of each returned clip in seconds (default 10).
        split: name of the split to read (default 'train').

    Returns:
        (data, targets): `data` is a list of tensors with a leading batch
        dimension of 1 added by `unsqueeze(0)`; `targets` is the list of the
        rows' 'genre' values in the same order.
    """
    clip_len = int(target_sr * clip_seconds)

    data, targets = [], []
    for x in tqdm(gtzan[split]):
        audio, sr = torchaudio.load(x['file'])
        audio = convert_audio(audio, sr, target_sr, target_channels)

        n_samples = audio.shape[-1]
        if n_samples < clip_len:
            # `narrow` would raise on a short clip; right-pad with silence instead
            audio = torch.nn.functional.pad(audio, (0, clip_len - n_samples))
        else:
            audio = audio.narrow(-1, 0, clip_len)  # keep only the first clip_seconds

        data.append(audio.unsqueeze(0))
        targets.append(x['genre'])

    return data, targets

In [None]:
# Resample every clip to the encoder's native sample rate and channel count.
data, targets = pre_process_gtzan(
    gtzan,
    target_sr=encoder.sample_rate,
    target_channels=encoder.channels,
)

In [None]:
def encode_data(data, encoder, batch_size=8, device=None):
    """
    Encode pre-processed clips into EnCodec codebook indices.

    Args:
        data: list of tensors, each with a leading batch dimension of 1, all of
            the same shape so they can be concatenated along dim 0.
        encoder: model exposing `encode(batch)`, which returns a sequence of
            frames whose first element is the code tensor.
        batch_size: number of clips encoded per forward pass.
        device: device to run the encoder on. When given, BOTH the encoder and
            each batch are moved there (the encoder is moved in place). When
            None, nothing is moved.

    Returns:
        A single tensor of codes for all clips, concatenated along dim 0.
        The result lives on the CPU so it does not pin accelerator memory and
        can be saved and reloaded on any machine.
    """
    print("Pre-encoding training data")

    if device is not None:
        # the batches are moved to `device`, so the encoder has to live there too
        encoder = encoder.to(device)

    encodings = []
    with torch.no_grad():
        for i in tqdm(range(0, len(data), batch_size)):
            batch = torch.cat(data[i:i+batch_size], dim=0)
            if device is not None:
                batch = batch.to(device)
            encoded_frames = encoder.encode(batch)

            # stitch the per-frame codes back together along the time axis
            codes = torch.cat([e[0] for e in encoded_frames], dim=-1)
            encodings.append(codes.cpu())

    return torch.cat(encodings, dim=0)

# This takes about 5 minutes to run on a M1 Macbook Pro, a couple of minutes for a GPU
# A fresh encoder is created here so this cell does not depend on the state of the
# one built alongside the dataset.
encoder = EncodecModel.encodec_model_24khz()
# NOTE(review): this bandwidth should stay in sync with the "-3.0" suffix of the
# file name used when the encodings are saved and re-loaded further down.
encoder.set_target_bandwidth(3.0)
encodings = encode_data(data, encoder, batch_size=8, device=device)

In [None]:
# Sanity check: shape of the code tensor of the first clip (dim 0 of `encodings`
# indexes clips). Presumably (n_codebooks, n_frames) — confirm against the output.
print(encodings[0].shape)

In [None]:
# Bundle the codes with their genre targets and persist them in a single file,
# so the expensive encoding step above does not have to be repeated.
dataset = dict(data=encodings, targets=targets)
torch.save(dataset, "gtzan_encodings-3.0.data")

# Load and split the data

In [None]:
class GTZANDataset(Dataset):
    """
    Map-style dataset pairing pre-computed encodings with their labels.

    Items are moved to `device` lazily, one at a time, inside `__getitem__`,
    so the backing storage can stay wherever it was loaded.

    Args:
        data: indexable collection of tensors (e.g. one tensor whose first
            dimension indexes examples).
        labels: sequence of labels, one per example; each is wrapped in a tensor.
        device: device every returned item is moved to. Defaults to 'cuda' when
            available, else 'cpu'.

    Raises:
        ValueError: if `data` and `labels` do not have the same length.
    """

    def __init__(self, data, labels, device=None):
        super().__init__()
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, got {len(data)} and {len(labels)}"
            )

        self.data = data
        self.labels = [torch.tensor(x) for x in labels]

        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

    def __len__(self):
        """Number of examples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the `(data, target)` pair at `index`, both on `self.device`."""
        data = self.data[index].to(self.device)
        target = self.labels[index].to(self.device)

        return data, target

def split_data(data, batch_size=32, random_seed=42, device=None, valid_size=0.1, test_size=0.05, shuffle=True):
    """
    Split the encoded dataset into train / validation / test `DataLoader`s.

    The test split is carved out first; the validation split is then taken from
    what remains, so `valid_size` is a fraction of the non-test data rather than
    of the whole dataset.

    Args:
        data: mapping with a 'data' entry (the examples) and a 'targets' entry.
        batch_size: batch size used by all three loaders.
        random_seed: seed passed to both `train_test_split` calls.
        device: forwarded to `GTZANDataset`.
        valid_size: fraction of the non-test data held out for validation.
        test_size: fraction of all data held out for testing.
        shuffle: whether the TRAINING loader reshuffles every epoch. The
            validation and test loaders are never shuffled, so evaluation
            batches come out in a stable order.

    Returns:
        (train_loader, valid_loader, test_loader)
    """
    x = data['data']
    y = data['targets']
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=random_seed)
    x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=valid_size, random_state=random_seed)

    train = GTZANDataset(x_train, y_train, device=device)
    valid = GTZANDataset(x_valid, y_valid, device=device)
    test = GTZANDataset(x_test, y_test, device=device)

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle)
    valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader

# Load the saved encodings onto the CPU regardless of the device they were
# saved from (a file written on a GPU box would otherwise fail to load on a
# CPU-only machine). GTZANDataset moves individual items to the target device.
data = torch.load("./gtzan_encodings-3.0.data", map_location="cpu")
train, valid, test = split_data(data, batch_size=128)

## Decoding Example

The stored representations are indices into the quantizer's codebooks, so we must recover the codes before we can train on them. The following snippet is an example of how to do that. We could probably make this a little cleaner by not initializing the entire EnCodec model, but this is the simplest way to do so. After initialization, we can send just the quantizer to the device and retrieve the codes as part of the `forward` call of any network. This allows the dataset stored on the device to remain rather small.

One small caveat is that the quantizer expects a shape of `(n_residuals, batch_size, frames)`, so we need to transpose the input to get the right output from the decoder.

In [None]:
encoder = EncodecModel.encodec_model_24khz()
encoder.set_target_bandwidth(1.5)
quantizer = encoder.quantizer  # left on the CPU in this example
data, targets = next(iter(train))
# The loader yields tensors on the dataset's device (CUDA when available), but the
# quantizer above is on the CPU, so bring the batch back before decoding.
# The transpose puts the residual/codebook axis first, as the quantizer expects.
data = data.to('cpu').transpose(0, 1)
quantized = quantizer.decode(data)
print(data.shape, quantized.shape)
quantized = quantized.unsqueeze(1) # add channel dimension

# Build the Model

The next few blocks tinker with the model size and the individual Block sizes to make them work with our data. Since we are targeting 10 s of audio, the input to a block of the MobileNet should be `(batch_size, 128, n_frames)`, whereas the input of the standard MobileNetV3 is `(batch_size, 224, 224, 3)`. The first thought is to just use a learnable projection to put the input into the dimensionality expected by the base model and ignore the three channels on the first block.

In [None]:
# Learnable projection from the single-channel decoded codes to a 3-channel map.
# With a (128, 750) input the transposed conv yields
#   H: (128 - 1) * 2 - 2 * 16  + 2 = 224
#   W: (750 - 1) * 1 - 2 * 264 + 3 = 224
proj = nn.ConvTranspose2d(
    in_channels=1,
    out_channels=3,
    kernel_size=(2, 3),
    stride=(2, 1),
    padding=(16, 264),
    bias=False,
)
projected = proj(quantized)
print(projected.shape)

From here, we should be able to use the MobileNet as is. We'll add the quantizer as the first step of the forward pass and make sure to freeze it so we don't end up backpropagating into it.

In [None]:
# Alternative idea: squeeze the time axis with a linear layer (750 -> 128), then
# let a large-kernel conv produce the 3-channel map. Checked on random input.
rand = torch.randn(1, 1, 128, 750)
projection = nn.Sequential(nn.Linear(750, 128), nn.ReLU())
conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=3,
    kernel_size=21,
    stride=1,
    padding=2,
    bias=False,
)

out = conv1(projection(rand))
out.shape

In [None]:
# Reference: what a stride-2 3x3 conv does to a standard 224x224 RGB-sized input.
rand2 = torch.randn(1, 3, 224, 224)
conv2 = nn.Conv2d(
    in_channels=3,
    out_channels=3,
    kernel_size=3,
    stride=2,
    padding=1,
    bias=False,
)
out = conv2(rand2)
out.shape

In [None]:
# Small MobileNetV3 variant that classifies directly from EnCodec code indices.
class MobileNetV3_Smol(nn.Module):
    """
    MobileNetV3-small style classifier fed with EnCodec codebook indices.

    The forward pass first turns the stored indices back into quantized
    vectors with a frozen EnCodec quantizer, projects them to a 3-channel map
    with a learnable transposed convolution, and then runs a MobileNetV3-style
    stem, bottleneck stack and classification head.

    Args:
        encodec_bw: target bandwidth passed to the EnCodec model the quantizer
            is taken from.
        num_classes: size of the output layer.
        act: activation class used in the stem, head and later bottleneck
            blocks; it must accept an `inplace` keyword.
    """

    def __init__(self, encodec_bw=1.5, num_classes=10, act=nn.Hardswish):
        super(MobileNetV3_Smol, self).__init__()
        encoder = EncodecModel.encodec_model_24khz()
        encoder.set_target_bandwidth(encodec_bw)
        self.quantizer = encoder.quantizer
        # Freeze the quantizer. Assigning `module.requires_grad = False` only sets a
        # plain attribute on the module; `requires_grad_` updates its parameters.
        self.quantizer.requires_grad_(False)

        # Learnable projection to 3 channels; a (128, 750) input maps to (224, 224).
        self.projection = nn.Sequential(
            nn.ConvTranspose2d(1, 3, kernel_size=(2, 3), stride=(2, 1), padding=(16, 264), bias=False),
            nn.BatchNorm2d(3),
            act(inplace=True)
        )
        # stem
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = act(inplace=True)

        # Bottleneck stack. Block arguments are presumably
        # (kernel, in_ch, expand_ch, out_ch, activation, use_se, stride) —
        # confirm against mobilenetv3.Block.
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU, True, 2),
            Block(3, 16, 72, 24, nn.ReLU, False, 2),
            Block(3, 24, 88, 24, nn.ReLU, False, 1),
            Block(5, 24, 96, 40, act, True, 2),
            Block(5, 40, 240, 40, act, True, 1),
            Block(5, 40, 240, 40, act, True, 1),
            Block(5, 40, 120, 48, act, True, 1),
            Block(5, 48, 144, 48, act, True, 1),
            Block(5, 48, 288, 96, act, True, 2),
            Block(5, 96, 576, 96, act, True, 1),
            Block(5, 96, 576, 96, act, True, 1),
        )

        # head: 1x1 conv, global average pool, two linear layers
        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = act(inplace=True)
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.linear3 = nn.Linear(576, 1280, bias=False)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = act(inplace=True)
        self.drop = nn.Dropout(0.2)
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        """
        Initialise every `Conv2d`, `BatchNorm2d` and `Linear` submodule.

        Conv weights use Kaiming-normal (fan_out), batch-norm layers are set to
        weight 1 / bias 0, and linear weights are drawn from N(0, 0.001^2);
        biases, where present, are zeroed. Modules of any other type are left
        untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Classify a batch of EnCodec code indices.

        Args:
            x: code indices with the batch axis first; assumed to be
                (batch, n_codebooks, frames) — TODO confirm.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # decode from the encodec representation (quantizer wants the codebook axis first)
        x = x.transpose(0, 1)
        x = self.quantizer.decode(x)

        x = x.unsqueeze(1)  # add in a channel dimension
        x = self.projection(x)

        # mobilenet stem
        x = self.hs1(self.bn1(self.conv1(x)))

        # bottleneck stack
        x = self.bneck(x)

        # classify
        x = self.hs2(self.bn2(self.conv2(x)))
        x = self.gap(x).flatten(1)
        x = self.drop(self.hs3(self.bn3(self.linear3(x))))

        return self.linear4(x)
        
# Smoke test: push one training batch through a freshly built model.
model = MobileNetV3_Smol()

x, y = next(iter(train))
# The model was never moved off the CPU, so the batch is brought to the CPU as well.
out = model(x.to('cpu'))
out.shape

In [None]:
# Large MobileNetV3 variant that classifies directly from EnCodec code indices.
class MobileNetV3_LARGE(nn.Module):
    """
    MobileNetV3-large style classifier fed with EnCodec codebook indices.

    The forward pass turns the stored indices back into quantized vectors with
    a frozen EnCodec quantizer, compresses the time axis with a linear layer
    (750 -> 128), and then runs a large-kernel stem conv, a MobileNetV3-style
    bottleneck stack and the classification head.

    Because the projection is `nn.Linear(750, 128)`, the decoded input must
    have exactly 750 frames on its last axis.

    Args:
        encodec_bw: target bandwidth passed to the EnCodec model the quantizer
            is taken from.
        num_classes: size of the output layer.
        act: activation class used in the stem, head and later bottleneck
            blocks; it must accept an `inplace` keyword.
    """

    def __init__(self, encodec_bw=1.5, num_classes=10, act=nn.Hardswish):
        super(MobileNetV3_LARGE, self).__init__()
        encoder = EncodecModel.encodec_model_24khz()
        encoder.set_target_bandwidth(encodec_bw)
        self.quantizer = encoder.quantizer
        # Freeze the quantizer. Assigning `module.requires_grad = False` only sets a
        # plain attribute on the module; `requires_grad_` updates its parameters.
        self.quantizer.requires_grad_(False)

        # squeeze the frame axis: (..., 750) -> (..., 128)
        self.projection = nn.Sequential(
            nn.Linear(750, 128),
            nn.ReLU(),
        )

        # stem: a 128x128 map becomes 112x112 (128 + 2*2 - 21 + 1)
        self.conv1 = nn.Conv2d(1, 16, kernel_size=21, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = act(inplace=True)

        # Bottleneck stack. Block arguments are presumably
        # (kernel, in_ch, expand_ch, out_ch, activation, use_se, stride) —
        # confirm against mobilenetv3.Block.
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU, False, 1),
            Block(3, 16, 64, 24, nn.ReLU, False, 2),
            Block(3, 24, 72, 24, nn.ReLU, False, 1),
            Block(5, 24, 72, 40, nn.ReLU, True, 2),
            Block(5, 40, 120, 40, nn.ReLU, True, 1),
            Block(5, 40, 120, 40, nn.ReLU, True, 1),
            Block(3, 40, 240, 80, act, False, 2),
            Block(3, 80, 200, 80, act, False, 1),
            Block(3, 80, 184, 80, act, False, 1),
            Block(3, 80, 184, 80, act, False, 1),
            Block(3, 80, 480, 112, act, True, 1),
            Block(3, 112, 672, 112, act, True, 1),
            Block(5, 112, 672, 160, act, True, 2),
            Block(5, 160, 672, 160, act, True, 1),
            Block(5, 160, 960, 160, act, True, 1),
        )

        # head: 1x1 conv, global average pool, two linear layers
        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = act(inplace=True)
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.linear3 = nn.Linear(960, 1280, bias=False)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = act(inplace=True)
        self.drop = nn.Dropout(0.2)

        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        """
        Initialise every `Conv2d`, `BatchNorm2d` and `Linear` submodule.

        Conv weights use Kaiming-normal (fan_out), batch-norm layers are set to
        weight 1 / bias 0, and linear weights are drawn from N(0, 0.001^2);
        biases, where present, are zeroed. Modules of any other type are left
        untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Classify a batch of EnCodec code indices.

        Args:
            x: code indices with the batch axis first; assumed to be
                (batch, n_codebooks, 750) — TODO confirm.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # decode from the encodec representation (quantizer wants the codebook axis first)
        x = x.transpose(0, 1)
        x = self.quantizer.decode(x)

        x = x.unsqueeze(1)  # add in a channel dimension
        x = self.projection(x)

        # mobilenet stem
        x = self.hs1(self.bn1(self.conv1(x)))

        # bottleneck stack
        x = self.bneck(x)

        # classify
        x = self.hs2(self.bn2(self.conv2(x)))
        x = self.gap(x).flatten(1)
        x = self.drop(self.hs3(self.bn3(self.linear3(x))))

        return self.linear4(x)
        
# Smoke test: push one training batch through a freshly built model.
model = MobileNetV3_LARGE()

x, y = next(iter(train))
print(y)
# The model was never moved off the CPU, so the batch is brought to the CPU as well.
out = model(x.to('cpu'))
out.shape

In [None]:
# Build the model and move it to the training device BEFORE creating the
# optimizer, so the optimizer is constructed over the on-device parameters.
model = MobileNetV3_LARGE()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

n_epochs = 250
for i in range(n_epochs):
    # ---- train ----
    model.train()
    training_loss = 0
    for x, y in train:
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        training_loss += loss.item()  # summed over batches, not averaged

    # ---- validate ----
    model.eval()
    with torch.no_grad():
        valid_loss = 0
        total, correct = 0, 0
        for x, y in valid:
            out = model(x)
            pred = out.argmax(dim=1)
            total += len(y)
            # .item() keeps `correct` a plain int, so the accuracy prints as a
            # number instead of a (possibly on-GPU) tensor repr
            correct += (pred == y).sum().item()

            valid_loss += criterion(out, y).item()

        print(f"Epoch {i}: training_loss (total) : {training_loss} | valid_loss: {valid_loss} | accuracy: {correct / total}")

In [None]:
# Display the loaded GTZAN dataset object as the cell output.
gtzan

In [None]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification
from datasets import load_dataset
import torch

from datasets import Audio

feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# Ask for a concrete split: without `split=` load_dataset returns a DatasetDict,
# which has neither `.features["audio"]` nor integer row indexing.
dataset = load_dataset("marsyas/gtzan", split="train")

# The feature extractor only accepts audio at its own sampling rate, so have the
# dataset resample clips on access instead of passing the files' native rate.
sampling_rate = feature_extractor.sampling_rate
dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

# audio file is decoded on the fly
inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_ids = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_ids]
# a bare expression in the middle of a cell is not displayed, so print it
print(predicted_label)

# compute loss against an arbitrary target: the label with id 0 in the model config
target_label = model.config.id2label[0]
inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
loss = model(**inputs).loss
round(loss.item(), 2)

In [None]:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="test_trainer")

feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# The checkpoint ships with an AudioSet classification head; GTZAN has its own,
# much smaller label set. Re-create the head with one output per genre and let
# `from_pretrained` skip the checkpoint weights whose shapes no longer match.
# Assumes the 'genre' column is a ClassLabel feature exposing `.names` — TODO confirm.
genre_names = gtzan["train"].features["genre"].names
model = ASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=len(genre_names),
    id2label={idx: name for idx, name in enumerate(genre_names)},
    label2id={name: idx for idx, name in enumerate(genre_names)},
    problem_type="single_label_classification",
    ignore_mismatched_sizes=True,
)

import numpy as np
import evaluate

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """
    Accuracy callback in the form the `Trainer` expects.

    `eval_pred` unpacks into `(logits, labels)`; the predicted class is the
    arg-max over the last logits axis, and the result is whatever the loaded
    accuracy metric returns for those predictions against the references.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)


In [None]:
from datasets import Audio


def _prepare_ast_batch(batch):
    """Turn a batch of raw GTZAN rows into the inputs the AST model consumes."""
    waveforms = [clip["array"] for clip in batch["audio"]]
    features = feature_extractor(
        waveforms,
        sampling_rate=feature_extractor.sampling_rate,
        return_tensors="np",
    )
    return {"input_values": features["input_values"], "labels": batch["genre"]}


# The Trainer needs a single split of model-ready examples, not the raw
# DatasetDict: resample the audio to the extractor's rate, convert it to
# `input_values`, expose the genre id as `labels`, and hold out 10% so that
# `compute_metrics` has something to be evaluated on.
gtzan_ast = (
    gtzan["train"]
    .cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
    .map(
        _prepare_ast_batch,
        batched=True,
        batch_size=8,
        remove_columns=gtzan["train"].column_names,
    )
    .train_test_split(test_size=0.1, seed=42)
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=gtzan_ast["train"],
    eval_dataset=gtzan_ast["test"],
    compute_metrics=compute_metrics,
)

In [None]:
import os

# Weights & Biases configuration, read by the Trainer's W&B integration.
# NOTE(review): the integration may read these when the Trainer is constructed —
# if checkpoints are not being logged, set them before building `trainer`.
os.environ["WANDB_PROJECT"] = "AST Finetune"
# "checkpoint" is the value that uploads every saved checkpoint; the previous
# value "initial" is not a recognised setting.
os.environ["WANDB_LOG_MODEL"] = "checkpoint" # log all model checkpoints

trainer.train()

In [None]:
from hear21passt.base import get_basic_model, get_model_passt

# Build the PaSST wrapper in "logits" mode, then swap its network for one built
# with a different number of output classes.
model = get_basic_model(mode="logits")
# NOTE(review): n_classes=50 does not match the default num_classes=10 used by the
# MobileNet classifiers in this notebook — confirm which label set is intended.
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476",  n_classes=50)