In [1]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import librosa
from sklearn.model_selection import train_test_split

from transformers import AutoModelForAudioClassification, AutoFeatureExtractor

from src.const import AUDIO_PATH, MAIN_LABELS, BATCH_SIZE, VALIDATION_SPLIT, SEED
from src.preprocess import load_and_preprocess, transform_to_data_loader, get_dl_for_pretrained
from src.preprocess_utils import normalize_ds

from torchinfo import summary
from torchaudio.transforms import MFCC


# summary(bin_bilstm_model, (1, 63, 128), device="cuda")

In [2]:
def _run_epoch(model, criterion, loader, model_type, using_pretrained, device,
               optimizer=None, desc=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    The pass is a training pass when ``optimizer`` is given (backward, gradient
    clipping and an optimiser step per batch) and a pure forward pass otherwise.
    The caller is responsible for ``model.train()`` / ``model.eval()`` and for
    wrapping evaluation in ``torch.no_grad()``.

    Both returned values are averaged over the samples actually iterated, and
    are ``nan`` when the loader yields nothing.
    """
    training = optimizer is not None
    # BCEWithLogitsLoss consumes raw logits, so the 0.5 decision threshold only
    # makes sense after a sigmoid. Any other criterion keeps the raw outputs.
    binary_on_logits = isinstance(criterion, torch.nn.BCEWithLogitsLoss)

    loss_sum = 0.0
    correct = 0
    seen = 0
    batches = tqdm(loader, desc) if desc is not None else loader

    for batch_X, batch_y in batches:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)

        if training:
            model.zero_grad()

        outputs = model(batch_X)
        if using_pretrained:
            # Pretrained models are expected to return an object with `.logits`.
            outputs = outputs.logits

        if model_type == "bin":
            # Binary targets are compared column-wise against (batch, 1) outputs.
            targets = batch_y.unsqueeze(1)
            loss = criterion(outputs, targets)
        else:
            targets = batch_y
            loss = criterion(outputs, batch_y.long())

        if training:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        batch_size = batch_y.size(0)
        # The loss is multiplied by the batch size so that the final division by
        # `seen` weights every sample equally, even with a short last batch.
        loss_sum += loss.item() * batch_size

        if model_type == "bin":
            scores = torch.sigmoid(outputs) if binary_on_logits else outputs
            predicted = (scores > 0.5).float()
        else:
            predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == targets).sum().item()
        seen += batch_size

    if seen == 0:
        # Nothing was iterated; report "undefined" instead of dividing by zero.
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def train_model(
        model, 
        criterion, 
        optimizer, 
        train_loader, 
        val_loader, 
        model_type, 
        epoch_count,
        using_pretrained=False,
        early_stopping=False,
    ):
    """Train ``model`` for up to ``epoch_count`` epochs, validating after each one.

    Args:
        model: module to train; batches are moved to the device of its parameters.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimiser stepped once per training batch.
        train_loader: iterable of ``(batch_X, batch_y)`` training batches.
        val_loader: iterable of ``(batch_X, batch_y)`` validation batches.
        model_type: ``"bin"`` compares thresholded outputs with
            ``batch_y.unsqueeze(1)``; any other value is treated as multi-class
            (``argmax`` over dim 1, targets cast to ``long`` for the loss).
        epoch_count: maximum number of epochs.
        using_pretrained: when True, ``model(batch).logits`` is used as the output.
        early_stopping: when True, stop once the validation loss has failed to
            improve for 5 consecutive epochs.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean losses, one entry per
        completed epoch.
    """
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience = 5
    no_improve_count = 0

    # Follow the model's own placement instead of depending on a notebook-level
    # `device` variable that may not exist yet (or may disagree with the model).
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    for epoch in range(epoch_count):
        model.train()
        train_loss, train_acc = _run_epoch(
            model, criterion, train_loader, model_type, using_pretrained, device,
            optimizer=optimizer, desc=f"Epoch {epoch+1}",
        )
        train_losses.append(train_loss)

        model.eval()
        with torch.no_grad():
            val_loss, val_acc = _run_epoch(
                model, criterion, val_loader, model_type, using_pretrained, device,
            )
        val_losses.append(val_loss)

        print(f'Epoch {epoch+1}/{epoch_count}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}')

        if early_stopping:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improve_count = 0
            else:
                no_improve_count += 1

            if no_improve_count >= patience:
                print('Early stopping')
                break

    return train_losses, val_losses

In [3]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch can
# see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# Load the saved arrays from the working directory: samples into X, labels into y.
X = np.load("augmented_data.npy")
y = np.load("augmented_labels.npy")

In [13]:
# MFCC feature extraction, done in fixed-size chunks so the whole dataset never
# has to go through the transform in a single call.
# NOTE(review): this cell's execution count (13) is later than the training
# cell's (12), so the logged training run may have used different MFCC
# settings. Confirm with a Restart & Run All.
CHUNK_SIZE = 5000  # number of samples passed to the transform per call

transform = MFCC(sample_rate=16000, n_mfcc=20, melkwargs={"n_fft": 400, "hop_length": 40, "n_mels": 20, "center": False})
X_tensor = torch.from_numpy(X)

# Chunk boundaries: 0, CHUNK_SIZE, 2*CHUNK_SIZE, ..., len(X)
indices = list(range(0, X_tensor.shape[0], CHUNK_SIZE))
indices.append(X_tensor.shape[0])

if X_tensor.shape[0] == 0:
    raise ValueError("X is empty: there is nothing to transform")

# Convert every chunk to NumPy and concatenate once. This keeps the result an
# ndarray even when there is only one chunk, and avoids re-copying the growing
# array on every iteration.
with torch.no_grad():
    mfcc_chunks = [
        transform(X_tensor[start:stop]).numpy()
        for start, stop in zip(indices[:-1], indices[1:])
    ]
X_transformed = np.concatenate(mfcc_chunks, axis=0)
del mfcc_chunks

In [7]:
# Main task with transformation
#
# Keep only the samples whose label is not 10, then hold out 25% of them.
# NOTE(review): presumably label 10 marks the class that is not part of the
# main task. Confirm against MAIN_LABELS.
main_mask = y != 10
X_transformed_main = X_transformed[main_mask]
y_main = y[main_mask]

X_train, X_test, y_train, y_test = train_test_split(
    X_transformed_main,
    y_main,
    random_state=42,
    test_size=0.25,
)

# The filtered copies are no longer needed once the split exists.
del main_mask, X_transformed_main, y_main

In [None]:
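# Binary task with transformation (disabled alternative to the main-task split)
# NOTE: create_binary_labels is not among the imports at the top of this
# notebook; import it before re-enabling this cell.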

# X_transformed_bin = X_transformed
# y_bin = create_binary_labels(y)

# X_train, X_test, y_train, y_test = train_test_split(
#     X_transformed_bin, y_bin, test_size=0.25, random_state=222
# )

# del X_transformed_bin, y_bin

In [None]:
# Main task without transformation

# X_main = X[y != 10]
# y_main = y[y != 10]

# X_train, X_test, y_train, y_test = train_test_split(
#     X_main, y_main, test_size=0.25, random_state=42
# )

# X_train = np.expand_dims(X_train, 1)
# X_test = np.expand_dims(X_test, 1)

# del X_main, y_main

In [9]:
# Wrap each split in a data loader and release the raw arrays as soon as the
# corresponding loader has been built.
train_dl = transform_to_data_loader(X_train, y_train, device=device)
del X_train, y_train

# The held-out "test" split is what the training loop uses for validation.
val_dl = transform_to_data_loader(X_test, y_test, device=device)
del X_test, y_test

In [10]:
class main_Transformer(nn.Module):
    """Transformer-encoder classifier that predicts from the last sequence position.

    The input is passed through a stack of ``nn.TransformerEncoderLayer`` blocks
    (``batch_first=True``), the encoding at the final position is taken, and a
    BatchNorm -> Linear -> Dropout -> ReLU -> BatchNorm -> Linear head maps it
    to ``num_class`` raw logits.

    Args:
        d_model: feature size of each sequence element (last input dimension).
        n_head: number of attention heads; ``d_model`` must be divisible by it.
        num_layers: number of stacked encoder layers.
        num_class: number of output logits.
    """

    def __init__(self, d_model, n_head, num_layers, num_class):
        super().__init__()
        # No `.to(device)` here: the module is built on the default device and
        # the caller moves the whole model at once with `model.to(...)`.
        # nn.TransformerEncoder works on its own copies of this layer. The
        # attribute is kept so existing state dicts still load.
        self.trans_enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, batch_first=True, activation="relu")
        self.transformer_encoder = nn.TransformerEncoder(self.trans_enc_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(d_model, 64)
        self.fc2 = nn.Linear(64, num_class)
        self.bc1 = nn.BatchNorm1d(d_model)
        self.bc2 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        # Not used by forward(); kept so the parameter set and state-dict keys
        # stay the same as before.
        self.conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_class)``.

        ``x`` is indexed as ``[:, -1, :]`` after encoding, so it must be a
        3-D ``(batch, seq, d_model)`` tensor.
        """
        # Summarise the sequence by the encoder output at its last position.
        out = self.transformer_encoder(x)[:, -1, :]
        out = self.bc1(out)
        out = self.fc1(out)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.bc2(out)
        # No dropout after the final layer: zeroing logits at random would
        # corrupt the values the loss is computed on.
        return self.fc2(out)

In [11]:
# Hyper-parameters of the multi-class model.
# d_model must be divisible by n_head (98 / 7 = 14 features per head).
d_model, n_head = 98, 7
num_layers, num_class = 2, 10

# Build the network on the selected device, with Adam at its default settings
# and cross-entropy on the raw logits.
model = main_Transformer(
    d_model=d_model,
    n_head=n_head,
    num_layers=num_layers,
    num_class=num_class,
).to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()



In [12]:
# Train the multi-class ("main") model for up to 50 epochs. Early stopping and
# the pretrained-model path keep their defaults (both off).
train_losses, val_losses = train_model(
    model, criterion, optimizer, train_dl, val_dl,
    model_type="main",
    epoch_count=50,
)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1: 100%|██████████| 223/223 [00:04<00:00, 51.27it/s]


Epoch 1/50, Train Loss: 2.3499, Train Acc: 0.10, Val Loss: 2.3106, Val Acc: 0.10


Epoch 2: 100%|██████████| 223/223 [00:03<00:00, 65.84it/s]


Epoch 2/50, Train Loss: 2.3221, Train Acc: 0.10, Val Loss: 2.3107, Val Acc: 0.10


Epoch 3: 100%|██████████| 223/223 [00:03<00:00, 64.81it/s]


Epoch 3/50, Train Loss: 2.3146, Train Acc: 0.10, Val Loss: 2.3095, Val Acc: 0.10


Epoch 4: 100%|██████████| 223/223 [00:03<00:00, 66.19it/s]


Epoch 4/50, Train Loss: 2.3156, Train Acc: 0.10, Val Loss: 2.3089, Val Acc: 0.10


Epoch 5: 100%|██████████| 223/223 [00:03<00:00, 65.18it/s]


Epoch 5/50, Train Loss: 2.3106, Train Acc: 0.10, Val Loss: 2.3049, Val Acc: 0.10


Epoch 6: 100%|██████████| 223/223 [00:03<00:00, 63.79it/s]


Epoch 6/50, Train Loss: 2.3090, Train Acc: 0.10, Val Loss: 2.3046, Val Acc: 0.11


Epoch 7: 100%|██████████| 223/223 [00:03<00:00, 64.24it/s]


Epoch 7/50, Train Loss: 2.3082, Train Acc: 0.10, Val Loss: 2.3076, Val Acc: 0.10


Epoch 8: 100%|██████████| 223/223 [00:03<00:00, 65.50it/s]


Epoch 8/50, Train Loss: 2.3076, Train Acc: 0.10, Val Loss: 2.3054, Val Acc: 0.10


Epoch 9: 100%|██████████| 223/223 [00:03<00:00, 64.22it/s]


Epoch 9/50, Train Loss: 2.3059, Train Acc: 0.10, Val Loss: 2.3044, Val Acc: 0.10


Epoch 10: 100%|██████████| 223/223 [00:03<00:00, 65.21it/s]


Epoch 10/50, Train Loss: 2.3067, Train Acc: 0.10, Val Loss: 2.3048, Val Acc: 0.10


Epoch 11: 100%|██████████| 223/223 [00:03<00:00, 64.88it/s]


Epoch 11/50, Train Loss: 2.3072, Train Acc: 0.10, Val Loss: 2.3042, Val Acc: 0.10


Epoch 12: 100%|██████████| 223/223 [00:03<00:00, 64.27it/s]


Epoch 12/50, Train Loss: 2.3062, Train Acc: 0.10, Val Loss: 2.3072, Val Acc: 0.10


Epoch 13: 100%|██████████| 223/223 [00:03<00:00, 63.08it/s]


Epoch 13/50, Train Loss: 2.3063, Train Acc: 0.10, Val Loss: 2.3045, Val Acc: 0.10


Epoch 14: 100%|██████████| 223/223 [00:03<00:00, 60.19it/s]


Epoch 14/50, Train Loss: 2.3056, Train Acc: 0.10, Val Loss: 2.3084, Val Acc: 0.09


Epoch 15: 100%|██████████| 223/223 [00:03<00:00, 60.04it/s]


Epoch 15/50, Train Loss: 2.3059, Train Acc: 0.10, Val Loss: 2.3029, Val Acc: 0.10


Epoch 16: 100%|██████████| 223/223 [00:03<00:00, 63.51it/s]


Epoch 16/50, Train Loss: 2.3060, Train Acc: 0.10, Val Loss: 2.3057, Val Acc: 0.10


Epoch 17: 100%|██████████| 223/223 [00:03<00:00, 63.39it/s]


Epoch 17/50, Train Loss: 2.3057, Train Acc: 0.10, Val Loss: 2.3043, Val Acc: 0.10


Epoch 18: 100%|██████████| 223/223 [00:03<00:00, 62.62it/s]


Epoch 18/50, Train Loss: 2.3050, Train Acc: 0.10, Val Loss: 2.3047, Val Acc: 0.10


Epoch 19: 100%|██████████| 223/223 [00:03<00:00, 63.73it/s]


Epoch 19/50, Train Loss: 2.3057, Train Acc: 0.10, Val Loss: 2.3035, Val Acc: 0.10


Epoch 20: 100%|██████████| 223/223 [00:03<00:00, 62.09it/s]


Epoch 20/50, Train Loss: 2.3048, Train Acc: 0.10, Val Loss: 2.3040, Val Acc: 0.10


Epoch 21: 100%|██████████| 223/223 [00:03<00:00, 62.29it/s]


Epoch 21/50, Train Loss: 2.3052, Train Acc: 0.10, Val Loss: 2.3034, Val Acc: 0.10


Epoch 22: 100%|██████████| 223/223 [00:03<00:00, 61.92it/s]


Epoch 22/50, Train Loss: 2.3048, Train Acc: 0.10, Val Loss: 2.3039, Val Acc: 0.11


Epoch 23: 100%|██████████| 223/223 [00:03<00:00, 57.63it/s]


Epoch 23/50, Train Loss: 2.3050, Train Acc: 0.10, Val Loss: 2.3028, Val Acc: 0.11


Epoch 24: 100%|██████████| 223/223 [00:03<00:00, 63.04it/s]


Epoch 24/50, Train Loss: 2.3043, Train Acc: 0.10, Val Loss: 2.3028, Val Acc: 0.10


Epoch 25: 100%|██████████| 223/223 [00:03<00:00, 63.13it/s]


Epoch 25/50, Train Loss: 2.3044, Train Acc: 0.10, Val Loss: 2.3033, Val Acc: 0.10


Epoch 26: 100%|██████████| 223/223 [00:03<00:00, 64.31it/s]


Epoch 26/50, Train Loss: 2.3040, Train Acc: 0.10, Val Loss: 2.3028, Val Acc: 0.09


Epoch 27: 100%|██████████| 223/223 [00:03<00:00, 60.67it/s]


Epoch 27/50, Train Loss: 2.3044, Train Acc: 0.10, Val Loss: 2.3029, Val Acc: 0.09


Epoch 28: 100%|██████████| 223/223 [00:03<00:00, 62.26it/s]


Epoch 28/50, Train Loss: 2.3045, Train Acc: 0.10, Val Loss: 2.3040, Val Acc: 0.10


Epoch 29: 100%|██████████| 223/223 [00:03<00:00, 61.18it/s]


Epoch 29/50, Train Loss: 2.3042, Train Acc: 0.10, Val Loss: 2.3033, Val Acc: 0.10


Epoch 30: 100%|██████████| 223/223 [00:03<00:00, 58.21it/s]


Epoch 30/50, Train Loss: 2.3039, Train Acc: 0.10, Val Loss: 2.3031, Val Acc: 0.10


Epoch 31: 100%|██████████| 223/223 [00:03<00:00, 57.85it/s]


Epoch 31/50, Train Loss: 2.3038, Train Acc: 0.10, Val Loss: 2.3039, Val Acc: 0.10


Epoch 32: 100%|██████████| 223/223 [00:03<00:00, 60.92it/s]


Epoch 32/50, Train Loss: 2.3039, Train Acc: 0.10, Val Loss: 2.3047, Val Acc: 0.10


Epoch 33: 100%|██████████| 223/223 [00:03<00:00, 59.56it/s]


Epoch 33/50, Train Loss: 2.3041, Train Acc: 0.10, Val Loss: 2.3031, Val Acc: 0.09


Epoch 34:  78%|███████▊  | 175/223 [00:03<00:00, 57.90it/s]


KeyboardInterrupt: 