In [None]:
# Install dependencies for the project
!pip install -r requirements.txt

In [None]:
import os

# All locations are resolved against the current working directory, so the
# notebook has to be started from the folder that contains `data/`.
_working_dir = os.getcwd()
raw_data_dir = os.path.join(_working_dir, 'data/raw')  # e.g. the content of /data/raw/EG-IPT/pickups
csv_dir = os.path.join(_working_dir, 'data/dataset')

# Setup A: only the DI recordings are used, one folder per split.
train_dir, val_dir, test_dir = (
    os.path.join(raw_data_dir, split_subdir)
    for split_subdir in ('HB-neck/DI', 'HB-bridge/DI', 'HB-couple/DI')
)

# Run identifier; it prefixes the split CSV created for this run.
name = 'run01'
csv_path = os.path.join(csv_dir, name + '_dataset_split.csv')
print(csv_path)

In [None]:
from utils import DatasetSplitter, DatasetValidator

# Run the dataset splitter with this run's paths, then run label validation on csv_path.
# For setup C, additionally pass val_ratio=0.2, val_split='test'.
# NOTE(review): test_dir is passed before val_dir — confirm this matches the
# parameter order of DatasetSplitter.split_train_validation.
DatasetSplitter.split_train_validation(
    csv_path,
    train_dir,
    test_dir,
    val_dir,
    name,
)
DatasetValidator.validate_labels(csv_path)

In [None]:
import torch

# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from utils import PrepareData

# Data-loading settings.
augment = False  # switch to True to enable offline data augmentation
batch_size = 32
target_sr = 8000  # presumably the target sample rate in Hz — confirm against PrepareData

dataPreparator = PrepareData(csv_path, device, target_sr, batch_size, augment)

# prepare() returns the three loaders followed by the class count, the class
# names and the segment length, in exactly this order.
(
    train_loader,
    test_loader,
    val_loader,
    num_classes,
    classnames,
    segment_length,
) = dataPreparator.prepare()

In [None]:
from utils import PrepareModel
from augments import AudioOnlineTransforms
from torch import nn
from torch.optim import Adam
import numpy as np

# Build the model from the values returned by the data preparation step.
modelPreparator = PrepareModel(device, num_classes, segment_length, target_sr, classnames)
model = modelPreparator.prepare()

# Cross-entropy loss, optimised with Adam (learning rate 1e-3, weight decay 1e-5).
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

# Online augmentation is used in every setup; it is applied to the training
# batches inside the training loop.
augmenter = AudioOnlineTransforms(target_sr, segment_length)

In [None]:
from tqdm import tqdm

# Train for a fixed number of epochs and keep the weights of the epoch with the
# lowest mean validation loss.

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save (fresh checkouts have no runs/).
os.makedirs("runs", exist_ok=True)
best_model_path = f"runs/{name}_best_model.pth"

# Despite its name, this holds the lowest validation loss seen so far.
max_val_loss = np.inf
max_epochs = 100

for epoch in range(max_epochs):
    # -- TRAINING --
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs}", leave=False):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        # Online augmentation is applied to training batches only.
        inputs = augmenter(inputs)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses.
    train_loss /= len(train_loader)

    # -- VALIDATION --
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    # -- LOGGING --
    # tqdm.write keeps the message from interfering with the progress bar.
    tqdm.write(f"Epoch {epoch+1}/{max_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # -- SAVE BEST --
    # Overwrite the checkpoint whenever the validation loss improves.
    if val_loss < max_val_loss:
        max_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

In [None]:
# Restore the weights of the best checkpoint written by the training loop.
# map_location remaps the stored tensors onto the currently selected device, so
# a checkpoint saved on a GPU also loads when this cell runs on a CPU-only machine.
best_state_dict = torch.load(f"runs/{name}_best_model.pth", map_location=device)
model.load_state_dict(best_state_dict)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the model.
model.eval()

In [None]:
from sklearn.metrics import accuracy_score, f1_score

# Gather predicted and true class indices over the whole test set.
all_preds, all_labels = [], []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The predicted class is the index of the largest model output along dim 1.
        preds = torch.argmax(model(inputs), dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Overall accuracy and macro-averaged F1 over the collected predictions.
acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average='macro')

print(f"Accuracy: {acc:.4f}")
print(f"Macro F1-score: {f1:.4f}")

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# Row-normalised confusion matrix of the test predictions, shown as percentages.

# Pin the label set to every known class so the matrix is always
# len(classnames) x len(classnames) and lines up with the tick labels, even
# when a class occurs neither in the test labels nor in the predictions.
# Assumes the labels are the integer indices 0..len(classnames)-1 — TODO confirm.
class_ids = np.arange(len(classnames))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# Normalise each row (true class) by its number of samples. Rows without any
# test sample stay at 0 instead of becoming NaN through a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(cm_normalized, cmap='Blues')

ax.set_xlabel("Predicted")
ax.set_ylabel("True")

ax.set_xticks(class_ids)
ax.set_yticks(class_ids)
ax.set_xticklabels(classnames)
ax.set_yticklabels(classnames)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Cells above the midpoint of the colour range are dark, so annotate them in
# white to keep the numbers readable.
contrast_threshold = (cm_normalized.max() + cm_normalized.min()) / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        cell_value = cm_normalized[i, j]
        text = f"{cell_value*100:.1f}"
        text_color = "white" if cell_value > contrast_threshold else "black"
        ax.text(j, i, text, ha="center", va="center", color=text_color)

fig.colorbar(im)
plt.tight_layout()
plt.show()

In [None]:
# Move the model to the CPU, compile it with TorchScript and store the result under runs/.
torchscript_path = f'runs/{name}_model.ts'
model = model.to('cpu')
scripted_model = torch.jit.script(model)
scripted_model.save(torchscript_path)

export_root = os.getcwd()
print(f'TorchScript file has been exported to the {export_root}/runs directory.')