In [1]:
import os
import sys

sys.path.insert(0, '../')

from modules.spec_dataset import *
from modules.train_prep import *
from modules.plot_results import *

In [2]:
class SpectrogramCNN(nn.Module):
    """Single-convolution CNN classifier for 2-D spectrogram inputs.

    Layer order (see ``forward``): Conv2d(1 -> 32, kernel 3x10) ->
    BatchNorm2d -> LeakyReLU -> MaxPool2d(2x4) -> Dropout(0.25) ->
    flatten -> Linear(num_classes).

    Args:
        num_classes: Number of output logits produced by the final layer.
        input_size: Optional ``(height, width)`` of a single spectrogram.
            When given, the input width of the fully connected layer is
            derived from it, so the model works for any spectrogram size.
            When omitted, the layer keeps the historical hard-coded width
            (``DEFAULT_FLAT_FEATURES``), which only fits the one input size
            the value was originally worked out for.
    """

    # Flattened feature count that used to be hard-coded in the fc layer.
    DEFAULT_FLAT_FEATURES = 430080

    def __init__(self, num_classes=4, input_size=None):
        super().__init__()
        # Padding (1, 5) with a 3x10 kernel keeps the height and adds one
        # column to the width (even kernel width).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 10), padding=(1, 5))
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d((2, 4))
        self.dropout = nn.Dropout(0.25)

        if input_size is None:
            flat_features = self.DEFAULT_FLAT_FEATURES
        else:
            flat_features = self._flattened_size(input_size)

        # Single linear layer mapping the flattened feature map to the logits.
        self.fc1 = nn.Linear(flat_features, num_classes)

    def _flattened_size(self, input_size):
        """Return the per-sample feature count after conv + pool for ``input_size``.

        A zero tensor is pushed through the shape-changing layers instead of
        re-deriving the arithmetic by hand, so the result stays correct if the
        kernel, padding or pooling sizes are edited. BatchNorm, the activation
        and dropout do not change the shape and are skipped.
        """
        height, width = input_size
        with torch.no_grad():
            probe = torch.zeros(1, 1, height, width)
            return self.pool(self.conv1(probe)).flatten(1).shape[1]

    def forward(self, x):
        """Compute class logits for a batch of spectrograms.

        Args:
            x: Tensor without a channel axis; a channel axis is inserted at
                dim 1 before the convolution (so presumably
                ``(batch, height, width)`` - confirm against the dataset).

        Returns:
            Tensor of raw logits with ``num_classes`` columns per sample.
        """
        x = x.unsqueeze(1)  # add the single input channel expected by conv1
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = torch.flatten(x, 1)  # keep the batch axis, flatten everything else
        x = self.fc1(x)
        return x

In [None]:
# ---------------------------------------------------------------------------
# Training cell: choose a device, build the data loaders, train the CNN.
#
# NOTE(review): cuda_device, train_dir, train_behav_file, val_dir,
# val_behav_file, b_size, n_epochs and exp_lr_scheduler are not defined in
# this notebook; presumably they arrive through the wildcard imports from
# modules.* - confirm, and consider importing them explicitly.
# ---------------------------------------------------------------------------

# Resolve the compute device exactly once, before anything is allocated.
if torch.cuda.is_available():
    # Select the configured GPU so the plain "cuda" device below refers to it.
    torch.cuda.set_device(cuda_device)
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Both splits currently share the same preprocessing; they are kept as two
# entries so split-specific transforms can be added independently.
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

# Setup datasets and loaders
dsets = {
    'train': SpectroDataset(train_dir, train_behav_file, data_transforms['train']),
    'val': SpectroDataset(val_dir, val_behav_file, data_transforms['val'])
}
dset_sizes = {split: len(dsets[split]) for split in ['train', 'val']}

dset_loaders = {}
for split in ['train', 'val']:
    dataset = dsets[split]
    targets = np.array([dataset.get_label(i) for i in range(len(dataset))])
    class_counts = dataset.get_class_counts()
    # One weight per *sample*: the inverse frequency of that sample's class,
    # so every class is drawn with roughly equal probability.
    sample_weights = np.array(
        [1.0 / class_counts[label] if class_counts[label] > 0 else 0 for label in targets]
    )
    # NOTE(review): the 'val' split is also resampled with replacement, so
    # validation metrics are computed on a class-balanced random draw rather
    # than on each validation sample exactly once - confirm this is intended.
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    dset_loaders[split] = torch.utils.data.DataLoader(dataset, batch_size=b_size, num_workers=0, sampler=sampler)
    print('done making loader:', split)

# Initialize model and criterion, and place them on the device *before*
# building the optimizer so it is constructed over the on-device parameters.
model_ft = SpectrogramCNN(num_classes=4)
model_ft.to(device)
criterion = nn.CrossEntropyLoss()
criterion.to(device)
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.0001, weight_decay=1e-5)
print(f"Training on {device}")

# Train model
model_ft, accuracies, losses, preds, labels = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dset_loaders, dset_sizes, num_epochs=n_epochs)

In [None]:
# Persist the trained weights first, so a failure in the reporting/plotting
# code below cannot cost the result of a long training run. The target
# directory is created if it does not exist yet.
model_path = '../models/CNN_spec_best_model.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model_ft.state_dict(), model_path)

# Output results
for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])

# Convert per-epoch accuracies to host numpy values before they are handed to
# the plotting helper. The tensor check makes this cell safe to re-run: values
# that were already converted on a previous run are passed through unchanged
# instead of failing on the missing `.cpu()` method.
accuracies = {
    split: [v.detach().cpu().numpy() if torch.is_tensor(v) else v for v in values]
    for split, values in accuracies.items()
}

plot_training_history(accuracies, losses, 'CNN_spec')
plot_confusion_matrix(labels, preds, 'CNN_spec')