In [15]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import WhisperForAudioClassification, WhisperConfig
import torch
import evaluate
import librosa
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm

In [2]:
class KWS_dataset(torch.utils.data.Dataset):
    """Map-style keyword-spotting dataset pairing audio features with keyword labels.

    The base class is spelled out as ``torch.utils.data.Dataset`` on purpose:
    the import cell does ``from datasets import Dataset`` after importing
    PyTorch's ``Dataset``, so the bare name ``Dataset`` refers to the Hugging
    Face class and would make this a subclass of the wrong type.

    Args:
        input_data: Indexable collection of audio features, one entry per example.
        output_data: Indexable collection of keyword labels aligned with ``input_data``.

    Raises:
        ValueError: If ``input_data`` and ``output_data`` have different lengths.
    """

    def __init__(self, input_data, output_data):
        # Catch misaligned features/labels at construction time instead of as
        # an IndexError (or silently unused labels) deep inside a DataLoader.
        if len(input_data) != len(output_data):
            raise ValueError(
                "input_data and output_data must be the same length "
                f"(got {len(input_data)} and {len(output_data)})"
            )
        self.input_data = input_data
        self.output_data = output_data

    def __len__(self):
        """Return the number of examples (the length of ``input_data``)."""
        return len(self.input_data)

    def __getitem__(self, index):
        """Return the ``(audio_features, keyword)`` pair at ``index``."""
        return self.input_data[index], self.output_data[index]

In [33]:
# Pre-built train/dev/test DataLoaders, presumably pickled by a separate
# data-prep step (they are unpickled into KWS_dataset, so that class must be
# defined before this cell runs).
# weights_only=False is required to unpickle arbitrary Python objects such as
# a DataLoader (PyTorch >= 2.6 defaults to weights_only=True). This executes
# pickle code, so only load trusted, locally produced files this way.
path = "../data/"
train_dataloader = torch.load(f"{path}en_splits_30.trainloader", weights_only=False)
dev_dataloader = torch.load(f"{path}en_splits_30.devloader", weights_only=False)
test_dataloader = torch.load(f"{path}en_splits_30.testloader", weights_only=False)

In [34]:
# Report the number of examples in each split (train, dev, test), one per line.
for split_loader in (train_dataloader, dev_dataloader, test_dataloader):
    print(len(split_loader.dataset))

26411
3284
3304


In [80]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision


class EfficientNetModel(nn.Module):
    """EfficientNet-B0 backbone plus a small MLP head that outputs log-probabilities.

    The forward pass ends in ``nn.LogSoftmax`` rather than ``nn.Softmax`` so its
    output can be passed directly to ``nn.NLLLoss``, which expects
    log-probabilities. Use ``.argmax(-1)`` for predictions (unchanged by the
    log) or ``.exp()`` to recover probabilities.

    Args:
        num_classes: Number of output classes.
        backbone: Optional module mapping a ``(B, 3, H, W)`` tensor to a
            ``(B, F)`` feature tensor. Defaults to torchvision's
            ImageNet-pretrained EfficientNet-B0, whose ``(B, 1000)`` classifier
            logits are used as the features.
    """

    def __init__(self, num_classes, backbone=None):
        super().__init__()
        if backbone is None:
            # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
            # in favour of `weights=...`; kept here for compatibility.
            backbone = torchvision.models.efficientnet_b0(pretrained=True)
        self.efficient_b0_model = backbone

        # Adaptive average pooling of the backbone's feature vector down to 512
        # values (from 1000 ImageNet logits for the default backbone). Despite
        # the attribute name, this does not pool over spatial dimensions.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((None, 512))

        # Two 512-unit dense layers with ReLU activations.
        self.linear1 = nn.Linear(512, 512)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(512, 512)
        self.relu2 = nn.ReLU()
        # Penultimate 256-unit layer with SELU activation.
        self.linear3 = nn.Linear(512, 256)
        self.selu = nn.SELU()
        # Classifier followed by LogSoftmax (not Softmax): the training loop uses
        # nn.NLLLoss, which expects log-probabilities. LogSoftmax has no
        # parameters, so state_dict keys are the same as with Softmax.
        self.linear4 = nn.Linear(256, num_classes)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Map a batch of single-channel 2-D inputs to class log-probabilities.

        Args:
            x: Tensor of shape ``(B, H, W)``; presumably a batch of
                spectrogram-like audio features (TODO confirm).

        Returns:
            Tensor of shape ``(B, num_classes)`` holding log-probabilities.
        """
        # (B, H, W) -> (B, 3, H, W): replicate the single channel to match the
        # 3-channel input the ImageNet backbone was built for.
        x = x.unsqueeze(1).repeat(1, 3, 1, 1)
        x = self.efficient_b0_model(x)  # (B, F)
        # AdaptiveAvgPool2d on a 3-D tensor treats it as (C, H, W): with the
        # extra leading axis, the batch axis becomes H (kept as-is) and only
        # the feature axis is pooled from F to 512.
        x = self.global_avg_pool(x.unsqueeze(0)).squeeze(0)  # (B, 512)
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        x = self.selu(self.linear3(x))
        return self.log_softmax(self.linear4(x))


In [81]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face `evaluate` accuracy metric; batches are added to it during validation.
metric = evaluate.load("accuracy")

In [82]:
# NOTE(review): 31 output classes for the en_splits_30 data — presumably 30
# keywords plus one extra class; confirm against the label set.
NUM_CLASSES = 31
# Module.to() returns the module itself, so `model` is the moved instance.
model = EfficientNetModel(NUM_CLASSES).to(device)
# Adam with PyTorch's default hyper-parameters.
optim = torch.optim.Adam(model.parameters())
# NLLLoss expects log-probabilities, so the model must end in a LogSoftmax;
# a plain Softmax output makes this loss incorrect.
loss_fn = nn.NLLLoss()

In [None]:
import os  # local import keeps this cell self-contained (checkpoint directory)

EPOCHS = 10
# Checkpoints used to be written to ../model/whisper/, which looks like a
# copy-paste from the Whisper notebook and would overwrite those checkpoints.
CHECKPOINT_DIR = "../model/efficientnet"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

wandb.init(
    # set the wandb project where this run will be logged
    project="efficientnet",
    config={
        "architecture": "efficientnet",
        "dataset": "en_30",
        # Logged as an int and derived from the same constant the loop uses.
        "epochs": EPOCHS,
    },
)

model.float()
epochs = EPOCHS

for epoch in range(epochs):
    # ---- training ----
    model.train()
    total_loss, num_batches = 0.0, 0
    for audio, labels in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{epochs} train"):
        audio, labels = audio.to(device), labels.to(device)
        optim.zero_grad()
        outputs = model(audio)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optim.step()
        # .item() yields a plain float, so no tensor/graph is kept across steps.
        total_loss += loss.item()
        num_batches += 1
    # Epoch-mean training loss (rather than the last batch's value); guarded
    # against an empty loader.
    train_loss = total_loss / max(num_batches, 1)

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optim.state_dict(),
            "loss": train_loss,
        },
        os.path.join(CHECKPOINT_DIR, f"epoch_{epoch + 1}"),
    )

    # ---- validation ----
    model.eval()
    # Evaluation needs no gradients; without no_grad every dev batch would
    # build an autograd graph, wasting memory and time.
    with torch.no_grad():
        for audio, labels in tqdm(dev_dataloader, desc=f"epoch {epoch + 1}/{epochs} dev"):
            audio, labels = audio.to(device), labels.to(device)
            outputs = model(audio)
            metric.add_batch(predictions=outputs.argmax(-1), references=labels)

    wandb.log({"acc": metric.compute()["accuracy"], "loss": train_loss})

wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

100%|█████████████████████████████████████| 3302/3302 [1:37:16<00:00,  1.77s/it]
100%|███████████████████████████████████████████| 13/13 [01:06<00:00,  5.14s/it]
 44%|████████████████                     | 1438/3302 [48:57<1:04:38,  2.08s/it]