In [2]:
# import warnings
# warnings.simplefilter("ignore", category=UserWarning)
import logging

logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
logging.getLogger("torchvision.dataset").setLevel(logging.ERROR)
import numpy as np
import pandas as pd
from functools import partial
import multiprocessing

n_cpu = multiprocessing.cpu_count()
import torch
import torch.nn as nn
from torchvision.datasets import DatasetFolder
from torch.utils.data import DataLoader, WeightedRandomSampler
from typing import Any
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from sklearn.utils.class_weight import compute_class_weight
from numpy import dtype
import torch.nn.functional as F
from torchmetrics.classification import BinaryAUROC

_ = pl.seed_everything(56)
### data location ###
data_dir = "/Users/jrudoler/data/small_scalp_features/"

Global seed set to 56


In [3]:
class SmallerPrecondFeatLogisticRegressionTorch(torch.nn.Module):
    def __init__(self, input_dim, output_dim=None):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim
        self.condition = nn.Sequential(
            nn.Conv1d(
                in_channels=input_dim,
                out_channels=2 * input_dim,
                kernel_size=2,
                padding=1,
                groups=input_dim,
            ),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
            nn.Conv1d(
                in_channels=2 * input_dim,
                out_channels=2 * input_dim,
                kernel_size=4,
                padding=1,
                groups=2 * input_dim,
            ),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4),
            nn.Conv1d(
                in_channels=2 * input_dim,
                out_channels=input_dim,
                kernel_size=4,
                padding=1,
                groups=input_dim,
            ),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
            nn.Flatten(),
        )
        self.logistic = nn.Sequential(nn.Linear(output_dim, 1, bias=True), nn.Sigmoid())

    def forward(self, x):
        x_cond = self.condition(x)
        probs = self.logistic(x_cond)
        return probs

In [14]:
### HYPERPARAMETERS ####
lr = 1e-2
weight_decay = 1e-4  # 1e-4
batch_size = 256
########################


class LitPrecondition(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args: Any, **kwargs: Any):
        return self.model.forward(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        X, y = batch
        X, y = X.float(), y.float()
        y_hat = torch.squeeze(self.model(X))
        loss = F.binary_cross_entropy(y_hat, y)
        self.log(
            "Loss/train", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True
        )
        auroc = BinaryAUROC()
        train_auc = auroc(y_hat, y)
        self.log("AUC/train", train_auc, on_epoch=True, on_step=False, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch
        X, y = X.float(), y.float()
        y_hat = torch.squeeze(self.model(X))
        loss = F.binary_cross_entropy(y_hat, y)
        self.log("Loss/test", loss, on_epoch=True, on_step=False, logger=True)
        auroc = BinaryAUROC()
        test_auc = auroc(y_hat, y)
        self.log(
            "AUC/test",
            test_auc,
            on_epoch=True,
            on_step=False,
            prog_bar=False,
            logger=True,
        )

    def configure_optimizers(self) -> Any:
        self.logger.log_hyperparams(
            {
                "learning_rate": lr,
                "weight_decay": weight_decay,
                "batch_size": batch_size,
            }
        )
        optimizer = torch.optim.SGD(self.parameters(), lr=lr, weight_decay=weight_decay)
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer, threshold=1e-3
            ),
            # The unit of the scheduler's step size, 'epoch' or 'step'.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "Loss/train",
            # If set to `True`, will enforce that the value specified 'monitor'
            # is available when the scheduler is updated
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": None,
        }
        return

In [15]:
subject = "LTP093"
sess = 5
for sess in [5]:
    try:
        test_file_crit = (
            lambda s: s.endswith(".pt")
            and s.count(f"sub_{subject}")
            and s.count(f"sess_{sess}")
        )
        test_dataset = DatasetFolder(
            data_dir,
            loader=partial(torch.load),
            is_valid_file=test_file_crit,
        )
        train_file_crit = (
            lambda s: s.endswith(".pt")
            and s.count(f"sub_{subject}")
            and not s.count(f"sess_{sess}")
        )
        train_dataset = DatasetFolder(
            data_dir,
            loader=partial(torch.load),
            is_valid_file=train_file_crit,
        )
    except FileNotFoundError:
        print(f"no session {sess}")
        continue
    ## class balancing ##
    cls_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.unique(train_dataset.targets),
        y=train_dataset.targets,
    )
    weights = cls_weights[train_dataset.targets]
    sampler = WeightedRandomSampler(weights, len(train_dataset), replacement=True)
    ## data loaders ##
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        num_workers=n_cpu,
        prefetch_factor=10,
        persistent_workers=True,
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=len(test_dataset), shuffle=False, pin_memory=True
    )
    ## create model ##
    n_features = train_dataset[0][0].shape[0]
    model = LitPrecondition(SmallerPrecondFeatLogisticRegressionTorch(n_features))
    es = EarlyStopping("AUC/train", min_delta=0.01, patience=15)
    trainer = Trainer(min_epochs=10, max_epochs=150, accelerator="mps", devices=1)
    trainer.logger._default_hp_metric = None
    trainer.fit(model, train_dataloaders=train_dataloader)
    test_result = trainer.test(model, dataloaders=test_dataloader, verbose=False)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                                      | Params
--------------------------------------------------------------------
0 | model | SmallerPrecondFeatLogisticRegressionTorch | 25.8 K
--------------------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=150` reached.


Testing: 0it [00:00, ?it/s]

In [7]:
pd.DataFrame(test_result)

Unnamed: 0,Loss/test,AUC/test
0,0.695903,0.431641
