In [1]:
# Vision Transformer (ViT-B/16) for MFCC / spectrogram audio classification

In [1]:
import librosa
from torch import nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import os
import soundfile
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, random_split
from PIL import Image
# Seed torch's global RNG so runs are reproducible.
torch.manual_seed(0)
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
TRAIN_DIR = './data/train/binary_classification/yes_no/'

In [2]:
# Dataset that reads class-labelled .wav files and converts each one to MFCC features:
class MFCCDataset(Dataset):
    """Audio dataset that turns each file into an MFCC feature matrix.

    ``directory`` must contain one sub-directory per class, each holding that
    class's audio files (``directory/<class_name>/<file>``). Files directly
    inside ``directory`` are ignored.

    Attributes:
        directory: Root directory passed to the constructor.
        classes: Sorted class (sub-directory) names.
        class_to_num: Maps class name -> integer id (position in ``classes``).
        num_to_class: Inverse of ``class_to_num``.
        paths: Path of every audio file, grouped by class.
        labels: Class name for each entry of ``paths`` (same order).
    """

    def __init__(self, directory):
        self.directory = directory
        # Sorted so class ids are stable across machines (os.listdir order is
        # arbitrary); non-directories such as .DS_Store are skipped.
        self.classes = sorted(
            entry for entry in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, entry))
        )
        self.class_to_num = {cl: i for i, cl in enumerate(self.classes)}
        self.num_to_class = {i: cl for i, cl in enumerate(self.classes)}
        paths, labels = [], []
        for cl in self.classes:
            # os.path.join works whether or not `directory` ends with a separator.
            class_dir = os.path.join(directory, cl)
            for name in sorted(os.listdir(class_dir)):
                paths.append(os.path.join(class_dir, name))
                # Store the label now rather than re-parsing it from the path later.
                # Splitting on '/' then '\\' only recovered the class name for
                # Windows-style joined paths.
                labels.append(cl)
        self.paths = paths
        self.labels = labels

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load one audio file and return ``(mfcc, label_tensor)``.

        Args:
            index: Position in ``self.paths``.

        Returns:
            The ``librosa.feature.mfcc`` output for the file's samples, and a
            0-d tensor holding the file's class id from ``class_to_num``.
        """
        audio_sample_path = self.paths[index]
        label = self.labels[index]

        # always_2d=True yields (frames, channels). Transposing before ravel()
        # flattens channel-major, the layout the torchaudio-based version produced.
        signal, sr = soundfile.read(audio_sample_path, dtype='float32', always_2d=True)
        signal = signal.T.ravel()
        mfcc = librosa.feature.mfcc(y=signal, sr=sr)

        # The label stays on the CPU; the training loop moves each batch to `device`.
        label_tensor = torch.tensor(self.class_to_num[label])
        return mfcc, label_tensor
        

In [2]:
# Spectrogram images, one sub-folder per class, resized to 224x224 and
# converted to tensors before being fed to the ViT below.
spectrogram_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = datasets.ImageFolder(root='./data/spectrograms', transform=spectrogram_transform)
print(dataset)

Dataset ImageFolder
    Number of datapoints: 64721
    Root location: ./data/spectrograms
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
               ToTensor()
           )


In [4]:
from collections import Counter
from torch.utils.data import Subset


def stratified_split_indices(targets, train_ratio=0.8, seed=0):
    """Split sample indices into train/validation lists, stratified by class.

    Each class's indices are shuffled with a seeded generator. The first
    ``int(train_ratio * class_size)`` go to training and the rest to validation,
    so both splits keep the original class proportions.

    Args:
        targets: Class index of every sample, in dataset order.
        train_ratio: Fraction of each class assigned to the training split.
        seed: Seed for the shuffling generator, making the split reproducible.

    Returns:
        ``(train_indices, val_indices)`` as lists of ints.
    """
    indices_by_class = {}
    for idx, label in enumerate(targets):
        indices_by_class.setdefault(int(label), []).append(idx)

    generator = torch.Generator().manual_seed(seed)
    train_indices, val_indices = [], []
    for label in sorted(indices_by_class):
        class_indices = indices_by_class[label]
        order = torch.randperm(len(class_indices), generator=generator).tolist()
        shuffled = [class_indices[i] for i in order]
        n_train = int(train_ratio * len(class_indices))
        train_indices.extend(shuffled[:n_train])
        val_indices.extend(shuffled[n_train:])
    return train_indices, val_indices


train_ratio = 0.8  # Adjust as needed
# ImageFolder already stores each sample's class index in `.targets`. Iterating the
# dataset itself would load and resize every image just to read its label.
class_counts = Counter(dataset.targets)
train_indices, val_indices = stratified_split_indices(dataset.targets, train_ratio, seed=0)
assert len(train_indices) + len(val_indices) == len(dataset)

# Shuffling happens in stratified_split_indices. A DataLoader's `.dataset` is
# the original, unshuffled dataset, so it cannot be used for this.
train_dataset = Subset(dataset, train_indices)
validation_dataset = Subset(dataset, val_indices)

In [3]:
# Sanity check: number of samples in the training and validation splits.
tuple(map(len, (train_dataset, validation_dataset)))

NameError: name 'train_dataset' is not defined

In [10]:
# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=64, shuffle=False)

In [11]:
from torchvision.models import vit_b_16

In [12]:
# Train from scratch: `weights=None` replaces the deprecated `pretrained=False`.
model = vit_b_16(weights=None)
# Replace the classification head with one logit per dataset class. The new head
# reuses the existing head's input width instead of hard-coding 768.
model.heads.head = nn.Linear(
    in_features=model.heads.head.in_features,
    out_features=len(dataset.classes),
)
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.to(device)
num_epochs = 1  # Adjust as needed
for epoch in range(num_epochs):
    # ---- training ----
    model.train()  # Set the model to training mode
    total_loss = 0.0

    for batch_idx, (images, labels) in enumerate(train_dataloader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()  # Zero out gradients
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()  # Backpropagate gradients
        optimizer.step()  # Update model parameters

        total_loss += loss.item()

        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(train_dataloader)}] Loss: {loss.item():.4f}", end='\r')

    average_loss = total_loss / len(train_dataloader)

    # ---- validation ----
    # Switch to eval mode once, and disable autograd so no graph is kept in memory.
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for images, labels in validation_dataloader:
            # Tensor.to() is not in-place, so the moved tensor must be used directly.
            logits = model(images.to(device))
            # Predicted class = index of the largest logit. Compare that index,
            # not the raw logits, with the labels.
            predictions = logits.argmax(dim=1).cpu()
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)
    # Python ints give true division (in-place `/=` on an int64 tensor raises).
    acc = correct / seen if seen else 0.0
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {average_loss:.4f}, val_accuracy: {acc}")

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch [1/1] Batch [1/809] Loss: 3.5553
Epoch [1/1] Batch [11/809] Loss: 4.0157
Epoch [1/1] Batch [21/809] Loss: 3.7888
Epoch [1/1] Batch [31/809] Loss: 3.5109
Epoch [1/1] Batch [41/809] Loss: 3.4825
Epoch [1/1] Batch [51/809] Loss: 3.5076
Epoch [1/1] Batch [61/809] Loss: 3.6070
Epoch [1/1] Batch [71/809] Loss: 3.5664
Epoch [1/1] Batch [81/809] Loss: 3.4762
Epoch [1/1] Batch [91/809] Loss: 3.4384
Epoch [1/1] Batch [101/809] Loss: 3.5003
Epoch [1/1] Batch [111/809] Loss: 3.4293
Epoch [1/1] Batch [121/809] Loss: 3.4282
Epoch [1/1] Batch [131/809] Loss: 3.4775


KeyboardInterrupt: 