#### Requirements

In [1]:
!pip install datasets[audio] torchaudio_augmentations

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Dataset

In [18]:
import torchaudio
from datasets import load_dataset
from torch.utils.data import Dataset


class GTZAN(Dataset):

    def __init__(self, split="train"):
        self.dataset = load_dataset("marsyas/gtzan", split=split)
        self.labels = ['blues', 'classical', 'country', 'disco', 'hiphop',
                       'jazz', 'metal', 'pop', 'reggae', 'rock']

        self.label2idx = {label: idx for idx, label in enumerate(self.labels)}
        self.n_classes = len(self.labels)

    def __getitem__(self, idx):
        file_path = self.dataset[idx]['file']
        audio, sr = torchaudio.load(file_path)
        label = self.label2idx[self.labels[self.dataset[idx]['genre']]]

        return audio, label

    def __len__(self):
        return len(self.dataset)

In [19]:
"""Wrapper for Torch Dataset class to enable contrastive training
"""
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio_augmentations import Compose
from typing import Tuple, List


class ContrastiveDataset(Dataset):
    def __init__(self, dataset: Dataset, input_shape: List[int], transform: Compose):
        self.dataset = dataset
        self.transform = transform
        self.input_shape = input_shape
        self.ignore_idx = []

    def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
        if idx in self.ignore_idx:
            return self[idx + 1]

        audio, label = self.dataset[idx]

        if audio.shape[1] < self.input_shape[1]:
            self.ignore_idx.append(idx)
            return self[idx + 1]

        if self.transform:
            audio = self.transform(audio)
        return audio, label

    def __len__(self) -> int:
        return len(self.dataset)

    def concat_clip(self, n: int, audio_length: float) -> Tensor:
        audio, _ = self.dataset[n]
        batch = torch.split(audio, audio_length, dim=1)
        batch = torch.cat(batch[:-1])
        batch = batch.unsqueeze(dim=1)

        if self.transform:
            batch = self.transform(batch)

        return batch

In [40]:
import os
from torchaudio_augmentations import(ComposeMany, RandomResizedCrop)

def get_dataset(dataset):
    train_transform = [RandomResizedCrop(n_samples=59049)]
    num_augmented_samples = 1
    if dataset == "gtzan":
        d = GTZAN()
        contrastive_dataset = ContrastiveDataset(
        d,
        input_shape=(1, 59049),
        transform=ComposeMany(train_transform, num_augmented_samples)
    )
    else:
        raise NotImplementedError("Dataset not implemented")
    return contrastive_dataset

#### Dataset tests

In [22]:
import unittest
import torch

class TestGTZAN(unittest.TestCase):
    def test_dataset(self):
        dataset = GTZAN()
        sample_idx = 0
        sample = dataset.__getitem__(sample_idx)

        # Audio waveform
        self.assertIsInstance(sample[0], torch.Tensor)

        # Label
        self.assertIsInstance(sample[1], int)

        # Audio waveform has at least 1 sample
        self.assertGreaterEqual(sample[0].shape[0], 1)

        # Label is non-negative
        self.assertGreaterEqual(sample[1], 0)
        
        # Label is less than the number of classes in the dataset
        self.assertLess(sample[1], dataset.num_classes)

In [23]:
gtzan_test = TestGTZAN()
gtzan_test.test_dataset

<bound method TestGTZAN.test_dataset of <__main__.TestGTZAN testMethod=runTest>>

### Dataloader

In [41]:
from torch.utils.data import DataLoader
dataset = get_dataset("gtzan")

dataloader = DataLoader(
    dataset,
    batch_size=48,
    num_workers=0,
    drop_last=True,
    shuffle=False,
)



#### Dataloader tests

In [42]:
class TestDataLoader(unittest.TestCase):
    
    def test_batch_size(self):
        # set up
        dataset = get_dataset("gtzan")
        input_shape = (1, 59049)
        dataloader = DataLoader(dataset, batch_size=48, num_workers=0, drop_last=True, shuffle=False)
        
        # test
        for batch_idx, (data, target) in enumerate(dataloader):
            batch_size = data.shape[0]
            self.assertEqual(batch_size, 48, f"Batch {batch_idx} - Expected batch size: 48, Actual batch size: {batch_size}")

In [43]:
# dataloader_test = TestDataLoader()
# dataloader_test.test_batch_size()



RuntimeError: ignored

## Model

In [32]:
import torch.nn as nn
import numpy as np


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def initialize(self, m):
        if isinstance(m, (nn.Conv1d)):
            # nn.init.xavier_uniform_(m.weight)
            # if m.bias is not None:
            #     nn.init.xavier_uniform_(m.bias)

            nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

In [37]:
class SampleCNN(Model):
    def __init__(self, strides):
        super(SampleCNN, self).__init__()

        self.strides = strides
        self.sequential = [
            nn.Sequential(
                nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0),
                nn.BatchNorm1d(128),
                nn.ReLU(),
            )
        ]

        self.hidden = [
            [128, 128],
            [128, 128],
            [128, 256],
            [256, 256],
            [256, 256],
            [256, 256],
            [256, 256],
            [256, 256],
            [256, 512],
        ]

        assert len(self.hidden) == len(
            self.strides
        ), "Number of hidden layers and strides are not equal"
        for stride, (h_in, h_out) in zip(self.strides, self.hidden):
            self.sequential.append(
                nn.Sequential(
                    nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1),
                    nn.BatchNorm1d(h_out),
                    nn.ReLU(),
                    nn.MaxPool1d(stride, stride=stride),
                )
            )

        # 1 x 512
        self.sequential.append(
            nn.Sequential(
                nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(512),
                nn.ReLU(),
            )
        )

        self.sequential = nn.Sequential(*self.sequential)

    def forward(self, x):
        x = x[:, 0, :]
        out = self.sequential(x)
        return x

In [34]:
from collections import OrderedDict


def load_encoder_checkpoint(checkpoint_path: str) -> OrderedDict:
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if "encoder." in k:
            new_state_dict[k.replace("encoder.", "")] = v

    return new_state_dict

In [38]:
import pickle
strides = [3, 3, 3, 3, 3, 3, 3, 3, 3]
model = SampleCNN(strides)
checkpoint_path = "clmr_checkpoint_10000.pt"
model.load_state_dict(load_encoder_checkpoint(checkpoint_path))
with open('model.pkl', 'wb') as f:
  pickle.dump(model, f)

In [39]:
model.eval()  # put the model in evaluation mode

all_representations = []
with torch.no_grad():
    for data, target in dataloader:
        # pass the data through the model to get representations
        representations = model(data)
        all_representations.append(representations)

all_representations = torch.cat(all_representations)
with open('representations.pkl', 'wb') as f:
  pickle.dump(all_representations, f)