## Setup

In [None]:
import os, sys
sys.path.append(os.path.abspath('../../src/'))

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from model.spectogram_dataset import SpectrogramDataset

import wandb

# Log in to Weights & Biases up front; the run itself is created later in the
# notebook with wandb.init() and closed with wandb.finish().
wandb.login()

In [None]:
# Pick the compute device once for the whole notebook: the first CUDA GPU when
# one is available, otherwise the CPU. train() reads this module-level name.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

## CNN

In [None]:
class Net(nn.Module):
    """Small convolutional binary classifier for single-channel images.

    Two ``conv(5x5) -> ReLU -> max-pool(2x2)`` stages are followed by three
    fully connected layers. The last layer has one unit passed through a
    sigmoid, so ``forward`` returns one value in ``[0, 1]`` per sample.

    ``fc1`` is sized for a ``16 x 18 x 18`` feature map. With unpadded 5x5
    convolutions (each removes 4 pixels per side) and 2x2 pooling (floor
    halving), that corresponds to inputs of 84-87 pixels per spatial
    dimension, e.g. ``(1, 87, 86) -> (6, 83, 82) -> (6, 41, 41) ->
    (16, 37, 37) -> (16, 18, 18)``.

    NOTE(review): the notebook's INPUT_WIDTH / INPUT_HEIGHT constants are
    67 / 66, which would not produce an 18 x 18 map - confirm the real
    spectrogram size against the dataset.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor (layers are created in this order on purpose so
        # that seeded parameter initialisation stays reproducible).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head: 5184 -> 120 -> 84 -> 1.
        self.fc1 = nn.Linear(in_features=16 * 18 * 18, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=1)

    def forward(self, x):
        """Return sigmoid outputs of shape ``(batch_size, 1)`` for batch ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))

        # Keep the batch dimension, collapse channels and spatial dims.
        flat = torch.flatten(features, start_dim=1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return torch.sigmoid(self.fc3(hidden))

In [None]:
from sklearn.metrics import f1_score

def train(model, criterion, optimizer, num_epochs, train_loader, val_loader, model_name="cnn"):
    """Train a single-output binary classifier and save it to ``{model_name}.pth``.

    Relies on the notebook-level names ``device``, ``wandb`` and ``f1_score``.
    Model outputs are thresholded at 0.5 to obtain binary predictions, so the
    model is expected to emit probabilities of shape ``(batch, 1)``.

    Args:
        model: ``nn.Module`` to train; it is moved to ``device`` in place.
        criterion: Loss called as ``criterion(outputs, labels)`` with float
            labels of shape ``(batch, 1)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Maximum number of epochs to run.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        model_name: Stem of the file the trained model is saved to.

    Raises:
        ValueError: If ``train_loader`` contains no batches.

    Side effects:
        Logs windowed training loss/accuracy and per-epoch validation
        loss/accuracy to the active wandb run, prints a per-epoch summary and
        writes ``{model_name}.pth``. Training stops early after 10 consecutive
        epochs without a validation-accuracy improvement.
    """
    model.to(device)

    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    # Log roughly five times per epoch. Clamp to 1: for loaders with fewer
    # than 10 batches the raw value is <= 0, which would either divide by
    # zero in the modulo below or never trigger a log.
    PRINT_STEP = max(1, num_batches // 5 - 1)
    PATIENCE = 10
    epochs_without_val_acc_improvement = 0
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        model.train()

        # Window accumulators are reset after every wandb log; epoch_loss is
        # kept separately so the end-of-epoch summary covers every batch.
        window_loss = 0.0
        window_correct = 0
        window_samples = 0
        epoch_loss = 0.0
        all_labels = []
        all_preds = []

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            # Float labels shaped (batch, 1) to match the model output.
            labels = labels.to(device).float().unsqueeze(1)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = (outputs > 0.5).float()  # Binary prediction with threshold 0.5
            # Tensor.numpy() only works on CPU tensors, so copy back first.
            all_labels.extend(labels.detach().cpu().numpy().flatten())
            all_preds.extend(preds.detach().cpu().numpy().flatten())

            # Count samples rather than averaging per-batch means, so a short
            # final batch is not over-weighted.
            window_correct += (preds == labels).sum().item()
            window_samples += labels.numel()

            batch_loss = loss.item()
            window_loss += batch_loss
            epoch_loss += batch_loss

            if i % PRINT_STEP == PRINT_STEP - 1:
                wandb.log(
                    {
                        "train/accuracy": window_correct / max(1, window_samples),
                        "train/loss": window_loss / PRINT_STEP,
                    },
                    step=epoch * num_batches + i,
                )
                window_loss = 0.0
                window_correct = 0
                window_samples = 0

        f1 = f1_score(all_labels, all_preds, average='macro')
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / num_batches:.4f}, F1 Score: {f1:.4f}")

        # Validation
        model.eval()

        val_loss_sum = 0.0
        val_correct = 0
        val_samples = 0
        val_batches = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device).float().unsqueeze(1)

                outputs = model(inputs)
                val_loss_sum += criterion(outputs, labels).item()
                val_batches += 1

                # Binary predictions
                preds = (outputs > 0.5).float()
                val_correct += (preds == labels).sum().item()
                val_samples += labels.numel()

        # An empty validation loader yields 0.0 instead of a ZeroDivisionError.
        val_accuracy = val_correct / val_samples if val_samples else 0.0
        val_loss = val_loss_sum / val_batches if val_batches else 0.0
        wandb.log(
            {
                "validation/accuracy": val_accuracy,
                "validation/loss": val_loss,
            },
            # One step past the last training step of this epoch.
            step=(epoch + 1) * num_batches,
        )

        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            epochs_without_val_acc_improvement = 0
        else:
            epochs_without_val_acc_improvement += 1
        if epochs_without_val_acc_improvement >= PATIENCE:
            print("10 epochs without a val accuracy improvement. Stopping the train")
            break

    # Saves the whole module object (not just its state_dict), so loading the
    # file later requires the model class to be importable.
    torch.save(model, f"{model_name}.pth")
    print("Training complete.")
    
def test(model, test_loader):
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on test images: {100 * correct // total} %')

## Data

In [None]:
from utils.dataset_creator import SpecgramsRandomFilter, SpecgramsSilentFilter

# Folder names of the dataset collection and of the specific dataset used here.
DATASETS_PARENT_PATH, DATASET_PATH = 'datasets', 'dataset'

# Dataset root as seen from this notebook: two directory levels up, then
# <DATASETS_PARENT_PATH>/<DATASET_PATH>.
DATA_DIR = os.path.join('../../', DATASETS_PARENT_PATH, DATASET_PATH)

In [None]:
from prepare_datasets import create_datasets

# Build the dataset, passing one instance of each spectrogram filter.
# NOTE(review): this call passes the bare folder names, while DATA_DIR above
# prefixes them with '../../' - presumably create_datasets resolves the
# location itself; confirm both point at the same directory.
create_datasets(DATASETS_PARENT_PATH, DATASET_PATH, [
    SpecgramsRandomFilter(),
    SpecgramsSilentFilter()
])

In [None]:
# Input resolution. In this notebook it is only recorded in the wandb config.
# NOTE(review): Net's first linear layer is sized for a 16 x 18 x 18 feature
# map, which corresponds to inputs of roughly 87 x 86, not 67 x 66 - confirm
# which value matches the real spectrograms.
INPUT_WIDTH = 67
INPUT_HEIGHT = 66

# Optimiser / loop settings for the single training run below.
LEARNING_RATE = 0.03
# NOTE(review): LR_DECAY is logged to wandb but no scheduler is created in the
# visible code, so the learning rate is not actually decayed.
LR_DECAY = 0.95
BATCH_SIZE = 64
# train() only stops early after 10 epochs without improvement, so with 10
# epochs in total early stopping cannot shorten a run.
EPOCHS = 10

# Candidate values, presumably for a hyperparameter sweep; not referenced
# elsewhere in the visible part of the notebook.
LEARNING_RATE_ARRAY = [0.03, 0.01, 0.003, 0.001]
BATCH_SIZE_ARRAY = [32, 64, 128]

## Training

In [None]:
# Start the wandb run and record the hyperparameters used for this experiment.
wandb.init(
    project="iml_test",
    config=dict(
        learning_rate=LEARNING_RATE,
        learning_rate_decay=LR_DECAY,
        batch_size=BATCH_SIZE,
        input_resolution=(INPUT_WIDTH, INPUT_HEIGHT),
        architecture="CNN",
        dataset="DAPS",
    ),
)

In [None]:
# Model, loss and optimiser. BCELoss expects probabilities, which matches the
# sigmoid applied at the end of Net.forward.
model = Net()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
transform = transforms.ToTensor()

# One dataset/loader per split, read from the "train", "validation" and "test"
# sub-folders of DATA_DIR. Only the training and test loaders are shuffled.
train_dataset = SpectrogramDataset(data_dir=os.path.join(DATA_DIR, "train"), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = SpectrogramDataset(data_dir=os.path.join(DATA_DIR, "validation"), transform=transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# NOTE(review): shuffling the test split has no effect on the accuracy that
# test() reports; shuffle=False would make evaluation order deterministic.
test_dataset = SpectrogramDataset(data_dir=os.path.join(DATA_DIR, "test"),transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Train (this also writes simple_cnn.pth), then report accuracy on the test split.
train(model, criterion, optimizer, EPOCHS, train_loader, val_loader, model_name="simple_cnn")
test(model, test_loader)

In [None]:
# Close the wandb run that was opened with wandb.init() above.
wandb.finish()