In [None]:
import torch
from torch.utils.data import DataLoader
from model import MaqamCNN
from dataset import MaqamDataset

# Build the train/validation splits from a single shared split fraction.
# Presumably MaqamDataset derives its split from test_size, so both calls must
# agree on it — confirm against dataset.py.
split_fraction = 0.2
train_dataset = MaqamDataset(mode='train', test_size=split_fraction)
val_dataset = MaqamDataset(mode='val', test_size=split_fraction)

# Wrap each split in a DataLoader: training batches are reshuffled every epoch,
# validation order stays fixed so per-epoch metrics are comparable.
batch_size = 16
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Initialize model and define loss function and optimizer.
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All.
# NOTE(review): if MaqamDataset splits with a different RNG (e.g. numpy/sklearn),
# that split is not covered by this seed — confirm in dataset.py.
SEED = 42
torch.manual_seed(SEED)

model = MaqamCNN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model for a specified number of epochs
num_epochs = 50

for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()
    # Accumulate a sample-weighted loss sum so the reported train loss is the
    # epoch mean, not the loss of whichever minibatch happened to come last.
    train_loss_sum = 0.0
    train_samples = 0
    for inputs, targets, mfcc in train_loader:
        optimizer.zero_grad()

        outputs = model(inputs, mfcc)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # CrossEntropyLoss defaults to reduction='mean'; multiply back by the
        # batch size to get a per-sample sum that is correct for a short last batch.
        train_loss_sum += loss.item() * len(targets)
        train_samples += len(targets)

    train_loss = train_loss_sum / train_samples if train_samples else float('nan')

    # --- Validation pass ---
    model.eval()
    val_loss_sum = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets, mfcc in val_loader:
            outputs = model(inputs, mfcc)
            val_loss_sum += criterion(outputs, targets).item() * len(targets)

            predicted_labels = outputs.argmax(dim=1)
            total_correct += (predicted_labels == targets).sum().item()
            total_samples += len(targets)

    # Normalise by the samples actually seen; report NaN instead of raising
    # ZeroDivisionError if the validation split is empty.
    val_loss = val_loss_sum / total_samples if total_samples else float('nan')
    val_acc = total_correct / total_samples if total_samples else float('nan')

    print(f'Epoch {epoch + 1:02d}: train_loss={train_loss:.5f}, val_loss={val_loss:.5f}, val_acc={val_acc:.5f}')

# Save the trained model
torch.save(model.state_dict(), 'maqam_cnn.pth')

In [None]:
# Test the model on new data.
# NOTE(review): the train/val splits pass test_size=0.2 but the test split does
# not — confirm MaqamDataset(mode='test') is disjoint from train/val.
test_dataset = MaqamDataset(mode='test')
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model.eval()
total_correct = 0
total_samples = 0
with torch.no_grad():
    for inputs, targets, mfcc in test_loader:
        outputs = model(inputs, mfcc)
        # argmax returns the first maximal index, same as torch.max(..., 1)[1].
        predicted_labels = outputs.argmax(dim=1)
        total_correct += (predicted_labels == targets).sum().item()
        total_samples += len(targets)

# Report NaN rather than raising ZeroDivisionError if the test split is empty.
test_acc = total_correct / total_samples if total_samples else float('nan')

print(f'Test accuracy: {test_acc:.5f}')
