In [14]:
import os
import random
import torch
import torchaudio
import pandas as pd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as DatasetTorch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset, Audio
import pyarrow as pa
import pyarrow.parquet as pq
import librosa
import sys
from tqdm import tqdm
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import colorsys
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, confusion_matrix
from IPython.display import Audio as ipyaudio 
import plotly.graph_objs as go
from ipywidgets import Output
from ipywidgets import VBox
from IPython.display import display
import ipywidgets as widgets

In [4]:
train_dir = '/work/tc062/tc062/manishav/huggingface_cache/datasets/speechcolab___gigaspeech/xs/0.0.0/0db31224ad43470c71b459deb2f2b40956b3a4edfde5fb313aaec69ec7b50d3c/gigaspeech-train.arrow'

# Load the full datasets
full_dataset = Dataset.from_file(train_dir)
full_dataset = full_dataset.cast_column("audio", Audio(sampling_rate=16000))

print('getting category names')

category_names = full_dataset.features['category'].names
# Create a dictionary mapping indices to category names
category_dict = {i: name for i, name in enumerate(category_names)}

print('loaded dataset')


getting category names
loaded dataset


In [5]:
def prepare_datasets(dataset):
    # Get the total number of samples
    total_samples = len(dataset)
    
    # Calculate the number of samples for training (80% of total)
    train_samples = int(0.8 * total_samples)
    
    # Create a random permutation of indices
    shuffled_indices = np.random.permutation(total_samples)
    
    # Split the shuffled indices
    train_indices = shuffled_indices[:train_samples]
    test_indices = shuffled_indices[train_samples:]
    
    # Use the select method to create train and test datasets
    train_dataset = dataset.select(train_indices)
    test_dataset = dataset.select(test_indices)
    
    # Remove audiobooks from the training set
    train_dataset = train_dataset.filter(lambda x: x['category'] != category_names.index('audiobook'))
    
    return train_dataset, test_dataset

# Prepare datasets
train_dataset, test_dataset = prepare_datasets(full_dataset)

print('split data')

Filter:   0%|          | 0/7511 [00:00<?, ? examples/s]

split data


In [6]:
class AudioUtil():
    @staticmethod
    def open(audio_file):
        sig, sr = torchaudio.load(str(audio_file))
        return (sig, sr)

    @staticmethod
    def rechannel(aud, new_channel):
        sig, sr = aud
        if sig.shape[0] == new_channel:
            return aud
        if new_channel == 1:
            sig = sig.mean(dim=0, keepdim=True)
        else:
            sig = sig.expand(new_channel, -1)
        return (sig, sr)

    @staticmethod
    def resample(aud, new_sr):
        sig, sr = aud
        if sr == new_sr:
            return aud
        num_channels = sig.shape[0]
        resig = torchaudio.transforms.Resample(sr, new_sr)(sig[:1, :])
        if num_channels > 1:
            retwo = torchaudio.transforms.Resample(sr, new_sr)(sig[1:, :])
            resig = torch.cat([resig, retwo])
        return (resig, new_sr)

    @staticmethod
    def pad_trunc(aud, max_ms):
        sig, sr = aud
        num_rows, sig_len = sig.shape
        max_len = sr // 1000 * max_ms
        if sig_len > max_len:
            sig = sig[:, :max_len]
        elif sig_len < max_len:
            pad_begin_len = random.randint(0, max_len - sig_len)
            pad_end_len = max_len - sig_len - pad_begin_len
            pad_begin = torch.zeros((num_rows, pad_begin_len))
            pad_end = torch.zeros((num_rows, pad_end_len))
            sig = torch.cat((pad_begin, sig, pad_end), 1)
        return (sig, sr)

    @staticmethod
    def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None):
        sig, sr = aud
        top_db = 80
        spec = torchaudio.transforms.MelSpectrogram(
            sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)
        spec = torchaudio.transforms.AmplitudeToDB(top_db=top_db)(spec)
        return spec.squeeze(0)  # Remove the channel dimension


In [7]:
class GenreDataset(DatasetTorch):
    def __init__(self, dataset, duration=5000, sr=16000, transform=None):
        self.dataset = dataset
        self.duration = duration
        self.sr = sr
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        audio_data = item['audio']

        # Get the file path
        file_path = audio_data.get('path', '')
        
        # Assuming audio_data is a list with a single dictionary
        if isinstance(audio_data, list) and len(audio_data) > 0:
            audio_data = audio_data[0]
        
        # Now audio_data should be a dictionary
        if 'array' in audio_data:
            if isinstance(audio_data['array'], np.ndarray):
                sig = torch.from_numpy(audio_data['array']).float()
            else:
                # If it's not a numpy array, it might be a list, so convert it
                sig = torch.tensor(audio_data['array']).float()
        else:
            # If 'array' is not present, try to load from 'path'
            audio_path = audio_data.get('path')
            if audio_path:
                sig, sr = torchaudio.load(audio_path)
            else:
                raise ValueError(f"Cannot load audio data for item {idx}")

        sr = audio_data.get('sampling_rate', self.sr)

        # Ensure the signal is 2D (add channel dimension if necessary)
        if sig.dim() == 1:
            sig = sig.unsqueeze(0)

        # Resample audio to ensure uniform sampling rate
        if sr != self.sr:
            sig = torchaudio.transforms.Resample(sr, self.sr)(sig)

        label = item['category']

        aud = (sig, self.sr)  # Ensure uniform sampling rate
        aud = AudioUtil.rechannel(aud, 1)
        aud = AudioUtil.pad_trunc(aud, self.duration)
        sgram = AudioUtil.spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None)
        
        # Remove the channel dimension if it exists
        if sgram.dim() == 3:
            sgram = sgram.squeeze(0)

        if self.transform:
            sgram = self.transform(sgram)

        return sgram, torch.tensor(label, dtype=torch.long), file_path


In [8]:
# Create dataset objects
train_dataset = GenreDataset(train_dataset)
test_dataset = GenreDataset(test_dataset)

print('starting dataloader')
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, collate_fn=lambda x: tuple(zip(*x)))
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, collate_fn=lambda x: tuple(zip(*x)))
print('ending dataloader')


starting dataloader
ending dataloader


In [9]:
class AudioClassifier(nn.Module):
    def __init__(self, num_classes=29):
        super(AudioClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = None  # We'll define this in the forward pass
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # print("Input shape to model:", x.shape)
        if x.dim() == 3:
            x = x.unsqueeze(1)
        # print("Shape after potential unsqueeze:", x.shape)
        x = self.pool(F.relu(self.conv1(x)))
        # print("Shape after first conv and pool:", x.shape)
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        
        # Flatten the output
        x = x.view(x.size(0), -1)
        # print("Shape after flattening:", x.shape)
    
        # Dynamically create fc1 if it doesn't exist
        if self.fc1 is None:
            self.fc1 = nn.Linear(x.shape[1], 128).to(x.device)
        
        embeddings = F.relu(self.fc1(x))
        x = self.dropout(embeddings)
        x = self.fc2(x)
        return x, embeddings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AudioClassifier().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [10]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    train_losses = []
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in tqdm(train_loader):
            inputs, labels, _ = batch
            inputs = torch.stack(inputs)
            labels = torch.stack(labels)
            if inputs.dim() == 3:
                inputs = inputs.unsqueeze(1)
            
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        
        print(f'Epoch {epoch}/{num_epochs - 1}, Train Loss: {epoch_loss:.4f}')
    
    # Save the loss plot
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss over Epochs')
    plt.legend()
    plt.savefig(os.path.join("gs-embeddings", f"loss_plot_{num_epochs}epochs_2Kepochs_no-audiobook-training.png"))
    plt.close()

# Train the model
train_model(model, train_loader, criterion, optimizer, num_epochs=1)


100%|██████████████████████████████████████████████████████████████████████████████████████| 177/177 [00:13<00:00, 13.61it/s]


Epoch 0/0, Train Loss: 1.8707


In [11]:
# Evaluation on test dataset
model.eval()
test_embeddings = []
test_labels = []
test_predictions = []
test_file_paths = []

with torch.no_grad():
    for batch in test_loader:
        inputs, labels, paths = batch
        inputs = torch.stack(inputs)
        labels = torch.stack(labels)
        if inputs.dim() == 3:
            inputs = inputs.unsqueeze(1)
        inputs = inputs.to(device)
        outputs, emb = model(inputs)
        _, preds = torch.max(outputs, 1)
        
        test_embeddings.append(emb.cpu().numpy())
        test_labels.append(labels.cpu().numpy())
        test_predictions.append(preds.cpu().numpy())
        test_file_paths.extend(paths)

test_embeddings = np.concatenate(test_embeddings)
test_labels = np.concatenate(test_labels)
test_predictions = np.concatenate(test_predictions)


In [12]:
# Evaluation metrics
accuracy = accuracy_score(test_labels, test_predictions)
f1 = f1_score(test_labels, test_predictions, average='weighted')
precision = precision_score(test_labels, test_predictions, average='weighted')

print(f"Test Set Accuracy: {accuracy:.4f}")
print(f"Test Set F1 Score: {f1:.4f}")
print(f"Test Set Precision: {precision:.4f}")

# Confusion Matrix
cm = confusion_matrix(test_labels, test_predictions)
all_categories = list(category_dict.values())

plt.figure(figsize=(25, 20))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=all_categories,
            yticklabels=all_categories)
plt.xlabel('Predicted', fontsize=14)
plt.ylabel('True', fontsize=14)
plt.title('Confusion Matrix for Test Set', fontsize=16)
plt.xticks(rotation=90, fontsize=10)
plt.yticks(rotation=0, fontsize=10)
plt.tight_layout()
plt.savefig('confusion_matrix_test_2Kepochs_no-audiobook-training_soundtest.png', dpi=300)
plt.close()

print("Confusion matrix saved as 'confusion_matrix_test_2Kepochs_no-audiobook-training_soundtest.png'")

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Test Set Accuracy: 0.3823
Test Set F1 Score: 0.2992
Test Set Precision: 0.2617
Confusion matrix saved as 'confusion_matrix_test_2Kepochs_no-audiobook-training_soundtest.png'


In [15]:
# T-SNE visualization
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(test_embeddings)

def generate_distinct_colors(n):
    HSV_tuples = [(x * 1.0 / n, 0.5, 0.5) for x in range(n)]
    RGB_tuples = map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)
    return ['rgb'+str(tuple(int(255*x) for x in rgb)) for rgb in RGB_tuples]

colors = generate_distinct_colors(len(category_dict))
color_map = {name: color for name, color in zip(category_dict.values(), colors)}

def load_audio(file_path):
    try:
        y, sr = librosa.load(file_path, duration=5)  # Load first 5 seconds
        return y, sr
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return None, None

df = pd.DataFrame({
    'x': tsne_results[:, 0],
    'y': tsne_results[:, 1],
    'label': test_labels,
    'predicted': test_predictions,
    'category': [category_dict[l] for l in test_labels],
    'predicted_category': [category_dict[p] for p in test_predictions],
    'file_path': test_file_paths
})

# Load audio data
print("Loading audio data...")
df['audio_data'] = df['file_path'].apply(load_audio)

def create_audio_widget(audio_data):
    y, sr = audio_data
    if y is not None and sr is not None:
        return ipyaudio(data=y, rate=sr)
    return None



Loading audio data...


In [None]:
# Create the plot
fig = go.Figure()

print('1')

for category in df['category'].unique():
    category_df = df[df['category'] == category]
    fig.add_trace(go.Scatter(
        x=category_df['x'],
        y=category_df['y'],
        mode='markers',
        name=category,
        marker=dict(color=color_map[category]),
        customdata=category_df[['category', 'predicted_category', 'file_path', 'audio_data']].values,
        hovertemplate=(
            "<b>Category:</b> %{customdata[0]}<br>"
            "<b>Predicted:</b> %{customdata[1]}<br>"
            "<b>File:</b> %{customdata[2]}<br>"
            "<extra></extra>"
        )
    ))

print('2')

fig.update_layout(
    title='T-SNE of Test Set Embeddings',
    legend_title_text='Categories',
    legend=dict(
        itemsizing='constant',
        title_font_family='Arial',
        font=dict(family='Arial', size=10),
        itemwidth=30
    )
)

print('3')

# Create an Output widget for the audio
output = widgets.Output()

@output.capture()
def on_click(trace, points, state):
    if points.point_inds:
        ind = points.point_inds[0]
        audio_data = trace.customdata[ind][3]
        audio_widget = create_audio_widget(audio_data)
        if audio_widget:
            display(audio_widget)

# Attach the click event to all traces
for trace in fig.data:
    trace.on_click(on_click)

print('4')
# Display the plot
display(fig)

# Display the output widget for audio
display(output)
print('5')

1
2
3
4
