In [None]:
import torch
import torch.nn as nn

class _Branch(nn.Module):
    def __init__(self, channel_sequence, kernel_sizes, paddings, strides, pool_sizes, pool_strides, dropout_rates):
        super().__init__()
        layers = []
        in_channels = 1

        for i, (out_channels, k, p, s) in enumerate(zip(channel_sequence, kernel_sizes, paddings, strides)):
            layers.extend([
                nn.Conv1d(in_channels, out_channels, kernel_size=k, padding=p, stride=s, bias=False),
                nn.BatchNorm1d(out_channels),
                nn.ReLU()
            ])
            in_channels = out_channels

            # add pooling and dropout after each block except the last
            if i < len(pool_sizes):
                layers.append(nn.MaxPool1d(kernel_size=pool_sizes[i], stride=pool_strides[i]))
                if dropout_rates[i] > 0: # -> 0 dropout rate? Don't add the layer ya dingus
                    layers.append(nn.Dropout(dropout_rates[i]))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class CNN_BinaryClassifier(nn.Module):
    """
    Multi-branch 1-D CNN for binary sleep-stage classification (e.g. "N3 vs. not N3").

    Every branch in ``branch_configs`` is a ``_Branch`` that receives the same
    single-channel input of shape ``(batch, 1, n_samples)``. The branch outputs are
    flattened and concatenated. They then pass through an MLP
    (``total_features -> 64 -> 32``, where the 32-d vector acts as an embedding).
    A final ``Linear(32, 1)`` returns one raw logit per sample, shape ``(batch, 1)``.
    Pair it with ``nn.BCEWithLogitsLoss``, or apply ``torch.sigmoid`` to get a
    probability.

    ``n_samples`` must be long enough that every branch still produces a non-empty
    output after its strided convolutions and pooling. The author's rule of thumb is
    about 1000 samples or more.

    Example -- predicting N3 from 3000-sample windows:
        X -> 3000 samples of a single-channel signal
        y -> 0 (NOT N3) or 1 (N3)

        branch_configs = {
          "left": {
              "channel_sequence": [32, 64, 64],
              "kernel_sizes": [22, 8, 8],
              "paddings": [22//2, 3, 3],
              "strides": [6, 1, 1],
              "pool_sizes": [8, 4],
              "pool_strides": [8, 4],
              "dropout_rates": [0.1, 0.0]  # dropout only after first pool
            },
          "right": {
              "channel_sequence": [32, 64, 64],
              "kernel_sizes": [400, 6, 6],
              "paddings": [175, 2, 2],
              "strides": [50, 1, 1],
              "pool_sizes": [4, 2],
              "pool_strides": [4, 2],
              "dropout_rates": [0.1, 0.0]  # dropout only after first pool
            }
        }

        model_args = {
            "name": "MyN3Classifier",
            "n_samples": 3000,
            "branch_configs": branch_configs
        }

        model = CNN_BinaryClassifier(**model_args)

    Args:
        name: free-form label stored on ``self.name``.
        n_samples: input length used to size the first fully connected layer.
        branch_configs: mapping of branch name -> keyword arguments for ``_Branch``.

    Raises:
        ValueError: if ``branch_configs`` is empty.
    """
    # Sleep-stage label ids; not referenced inside this class itself.
    WAKE = 0
    N1 = 1
    N2 = 2
    N3 = 3
    REM = 4

    def __init__(self, name, n_samples, branch_configs):
        super().__init__()
        # Without branches, forward() would fail on torch.cat([]); fail at construction instead.
        if not branch_configs:
            raise ValueError("branch_configs must define at least one branch")

        self.name = name
        self.branches = nn.ModuleDict()
        self.branch_output_sizes = {}

        # Distinct loop name so the `name` argument is not shadowed.
        for branch_name, config in branch_configs.items():
            self.branches[branch_name] = _Branch(**config)

        # Probe each branch with a dummy signal to get its flattened output width.
        # eval(): Dropout is a no-op and BatchNorm uses (without updating) running stats.
        with torch.inference_mode():
            dummy = torch.zeros(1, 1, n_samples)
            for branch_name, branch in self.branches.items():
                branch.eval()
                out = branch(dummy)
                self.branch_output_sizes[branch_name] = out.numel() // out.shape[0]
                branch.train()

        total_features = sum(self.branch_output_sizes.values())

        self.fc = nn.Sequential(
            nn.Linear(total_features, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 32),  # 32-d embedding
            nn.ReLU(),
        )

        self.classifier = nn.Linear(32, 1)

    def forward(self, x):
        """Return one logit per sample, shape ``(batch, 1)``.

        ``x`` is expected to be ``(batch, 1, n_samples)`` as used at construction;
        otherwise the concatenated feature width may not match ``self.fc``.
        """
        # ModuleDict preserves insertion order, matching the order used for sizing.
        outputs = [branch(x).flatten(1) for branch in self.branches.values()]
        combined = torch.cat(outputs, dim=1)
        x = self.fc(combined)
        return self.classifier(x)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def _run_train_epoch(model, device, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device).float().reshape(-1)

        optimizer.zero_grad()
        # reshape(-1) rather than squeeze(): squeeze() turns a size-1 final batch into
        # a 0-d tensor, which BCEWithLogitsLoss rejects against a (1,) target.
        logits = model(X_batch).reshape(-1)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * X_batch.size(0)
    return total_loss / len(loader.dataset)


def _run_eval_epoch(model, device, loader, criterion, threshold):
    """Evaluate on ``loader``.

    Returns (mean loss, targets, hard predictions, probabilities) as numpy arrays.
    """
    model.eval()
    total_loss = 0.0
    targets, preds, probs = [], [], []
    with torch.inference_mode():
        for X_batch, y_batch in loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device).float().reshape(-1)
            logits = model(X_batch).reshape(-1)
            total_loss += criterion(logits, y_batch).item() * X_batch.size(0)

            batch_probs = torch.sigmoid(logits)
            probs.append(batch_probs.cpu().numpy())
            preds.append((batch_probs > threshold).cpu().numpy())
            targets.append(y_batch.cpu().numpy())
    mean_loss = total_loss / len(loader.dataset)
    return mean_loss, np.concatenate(targets), np.concatenate(preds), np.concatenate(probs)


def _plot_loss_curves(train_losses_data, test_losses_data):
    """Plot per-epoch train/test loss and annotate the final values."""
    epochs = len(train_losses_data)
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, epochs+1), train_losses_data, 'b-o', linewidth=2, markersize=4, label="Training Loss")
    plt.plot(range(1, epochs+1), test_losses_data, 'r--s', linewidth=2, markersize=4, label="Test Loss")
    plt.title("Training Progress", fontsize=14, fontweight="bold")
    plt.xlabel("Epochs", fontsize=12)
    plt.ylabel("Loss", fontsize=12)
    plt.legend(fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.xticks(range(1, epochs+1, 5))
    plt.annotate(f"Final Train: {train_losses_data[-1]:.4f}",
                xy=(epochs, train_losses_data[-1]),
                xytext=(epochs-0.1*epochs, train_losses_data[-1]),
                arrowprops=dict(arrowstyle="->"))
    plt.annotate(f"Final Test: {test_losses_data[-1]:.4f}",
                xy=(epochs, test_losses_data[-1]),
                xytext=(epochs-0.1*epochs, test_losses_data[-1]*1.1),
                arrowprops=dict(arrowstyle="->"))
    plt.tight_layout()
    plt.show()


def train_model(model, device, train_loader, test_loader, pos_weight, lr=2.5e-5, wd=1e-4, p=5, f=0.5, epochs=25, output_period=1, log_tensorboard=False, threshold=0.5, plot_losses=True):
    """Train a single-logit binary classifier and return the weights of its best epoch.

    Each epoch trains on ``train_loader`` and then evaluates on ``test_loader``. It
    steps a ``ReduceLROnPlateau`` scheduler on the test F1 score and snapshots the
    model whenever test F1 improves.

    Args:
        model: module producing one logit per sample (``(B,)`` or ``(B, 1)``). It is
            not moved here; presumably it is already on ``device`` -- confirm at call sites.
        device: ``torch.device`` that batches are moved to.
        train_loader, test_loader: iterables of ``(X_batch, y_batch)`` with a
            ``.dataset`` attribute; ``y_batch`` holds 0/1 labels.
        pos_weight: tensor forwarded to ``BCEWithLogitsLoss`` to weight positives.
        lr, wd: AdamW learning rate and weight decay.
        p, f: ``ReduceLROnPlateau`` patience and reduction factor.
        epochs: number of epochs (must be >= 1).
        output_period: print a progress line every ``output_period`` epochs and on the last.
        log_tensorboard: log losses, metrics, PR curves and histograms to
            ``runs/experiment_<timestamp>``.
        threshold: probability cutoff for predicting the positive class.
        plot_losses: show the train/test loss curve with matplotlib at the end.

    Returns:
        A deep copy of ``model.state_dict()`` from the epoch with the highest test F1
        (the first epoch if F1 never rises above its initial value).

    Raises:
        ValueError: if ``epochs < 1``.
    """
    import copy

    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    writer = None
    if log_tensorboard:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        writer = SummaryWriter(f'runs/experiment_{timestamp}')

    if device.type == "cpu":
        print("WARNING: Using CPU as device. This may take a while...")

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))  # binary cross-entropy on logits
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)  # Adam with decoupled weight decay
    # "max" mode on F1: multiply the LR by f once F1 has not improved for more than p epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=p, factor=f)

    best_f1 = -1.0  # below any F1 value, so the first epoch is always snapshotted
    best_epoch = -1
    best_metrics = None
    best_model_state = None
    print(f"running for {epochs} epochs.")

    train_losses_data, test_losses_data = [], []

    try:
        for epoch in range(epochs):
            train_loss = _run_train_epoch(model, device, train_loader, criterion, optimizer)
            test_loss, targets_np, preds_np, probs_np = _run_eval_epoch(
                model, device, test_loader, criterion, threshold
            )
            train_losses_data.append(train_loss)
            test_losses_data.append(test_loss)

            precision, recall, f1, _ = precision_recall_fscore_support(
                targets_np, preds_np, average="binary", zero_division=0
            )
            accuracy = accuracy_score(targets_np, preds_np)

            scheduler.step(f1)
            current_lr = optimizer.param_groups[0]["lr"]

            if writer is not None:
                writer.add_scalar('Loss/Train', train_loss, epoch)
                writer.add_scalar('Loss/Test', test_loss, epoch)
                writer.add_scalar('Metrics/Precision', precision, epoch)
                writer.add_scalar('Metrics/Recall', recall, epoch)
                writer.add_scalar('Metrics/F1', f1, epoch)
                writer.add_scalar('Metrics/Accuracy', accuracy, epoch)
                writer.add_scalar('Learning Rate', current_lr, epoch)
                writer.add_pr_curve('PR_Curve', targets_np, probs_np, global_step=epoch)

                if epoch % 5 == 0:
                    # Gradients here are those left over from the last training batch.
                    for param_name, param in model.named_parameters():
                        writer.add_histogram(f'Weights/{param_name}', param, epoch)
                        if param.grad is not None:
                            writer.add_histogram(f'Gradients/{param_name}', param.grad, epoch)

            if epoch % output_period == 0 or epoch == epochs-1:
                print(f"Epoch {epoch+1:2}/{epochs} -> "
                      f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:2.4f} | "
                      f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f} | "
                      f"Accuracy: {accuracy:.3f} ---> Learning rate: \x1b[31m{current_lr}\x1b[0m")

            if f1 > best_f1:
                best_f1 = f1
                best_epoch = epoch+1
                best_metrics = (precision, recall, accuracy)
                # state_dict() returns references to the live tensors (and dict.copy() is
                # shallow), so deep-copy or the snapshot would follow later training.
                best_model_state = copy.deepcopy(model.state_dict())
    finally:
        if writer is not None:
            writer.close()

    if plot_losses:
        _plot_loss_curves(train_losses_data, test_losses_data)

    # Report the metrics of the returned (best) epoch, not of the last epoch.
    best_precision, best_recall, best_accuracy = best_metrics
    print(f"\nBest model from epoch {best_epoch} with F1: {best_f1:.4f}")
    print(f"precision: {best_precision:.3f} | Recall: {best_recall:.3f} | Accuracy: {best_accuracy:.3f}")

    return best_model_state

In [None]:
import torch
import torch.nn as nn

# class _Branch(nn.Module):
#     def __init__(self, channel_sequence, kernel_sizes, paddings, strides, pool_sizes, pool_strides, dropout_rates):
#         super().__init__()
#         self.conv1 = nn.Conv1d(1, channel_sequence[0], kernel_sizes[0], stride=strides[0], padding=paddings[0])
#         self.conv2 = nn.Conv1d(channel_sequence[0], channel_sequence[1], kernel_sizes[1], stride=strides[1], padding=paddings[1])
#         self.conv3 = nn.Conv1d(channel_sequence[1], channel_sequence[2], kernel_sizes[2], stride=strides[2], padding=paddings[2])

#         self.pool1 = nn.MaxPool1d(pool_sizes[0], stride=pool_strides[0])
#         self.pool2 = nn.MaxPool1d(pool_sizes[1], stride=pool_strides[1])

#         self.dropout1 = nn.Dropout(dropout_rates[0])
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         x = self.relu(self.conv1(x))
#         x = self.pool1(x)
#         x = self.dropout1(x)

#         x = self.relu(self.conv2(x))
#         x = self.pool2(x)

#         x = self.conv3(x)
#         return x

# Sanity check: per-branch output shape / feature count for an n_samples-long input.
# kz1 / kz2 are the first-layer kernel sizes of the left / right branches. With
# padding kz//2 - 1 and n_samples = 3000, the first conv emits exactly
# n_samples / stride positions (500 for the left branch, 60 for the right).
kz1 = 22
kz2 = 400
branch_configs = {
    "left": {
        "channel_sequence": [32, 64, 64],
        "kernel_sizes": [kz1, 8, 8],
        "paddings": [kz1//2 - 1, 3, 3],
        "strides": [6, 1, 1],
        "pool_sizes": [8, 4],
        "pool_strides": [8, 4],
        "dropout_rates": [0.1, 0.0]
    },
    "right": {
        "channel_sequence": [32, 64, 64],
        "kernel_sizes": [kz2, 8, 8],
        "paddings": [kz2//2 - 1, 3, 3],
        "strides": [50, 1, 1],
        "pool_sizes": [4, 2],
        "pool_strides": [4, 2],
        "dropout_rates": [0.1, 0.0]
    }
}

n_samples = 3000
dummy_input = torch.zeros(1, 1, n_samples)

print("Branch Output Dimensions:")
for branch_name, config in branch_configs.items():
    branch = _Branch(**config)
    # Probe in eval mode, as CNN_BinaryClassifier.__init__ does: Dropout is off and
    # BatchNorm reads, but does not update, its running statistics.
    branch.eval()

    with torch.inference_mode():
        output = branch(dummy_input)
        features = output.numel() // output.shape[0]

    print(f"{branch_name.upper()} branch:")
    print(f"\tOutput shape: {output.shape}")
    print(f"\tFeatures per sample: {features}\n")

Branch Output Dimensions:
LEFT branch:
	Output shape: torch.Size([1, 64, 14])
	Features per sample: 896

RIGHT branch:
	Output shape: torch.Size([1, 64, 6])
	Features per sample: 384

