In [55]:
# import warnings
# warnings.simplefilter("ignore", category=UserWarning)
import logging

logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
logging.getLogger("torchvision.dataset").setLevel(logging.ERROR)
import numpy as np
import pandas as pd
import time
from functools import partial
import multiprocessing

n_cpu = multiprocessing.cpu_count()
import torch
import torch.nn as nn
from torchvision.datasets import DatasetFolder
from torch.utils.data import DataLoader, WeightedRandomSampler
from typing import Any
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.utils.class_weight import compute_class_weight
from numpy import dtype
import torch.nn.functional as F
from torchmetrics.classification import BinaryAUROC

### data location ###
data_dir = "/Users/jrudoler/data/small_scalp_features/"
log_dir = "/Users/jrudoler/Library/CloudStorage/Box-Box/JR_CML/pytorch_logs/"

# Precondition Features

In [56]:
class LitPrecondition(pl.LightningModule):
    def __init__(self, input_dim, output_dim, learning_rate, weight_decay, batch_size):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim
        self.condition = nn.Sequential(
            nn.Conv1d(
                in_channels=input_dim,
                out_channels=2 * input_dim,
                kernel_size=2,
                padding=1,
                groups=input_dim,
            ),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
            nn.Conv1d(
                in_channels=2 * input_dim,
                out_channels=2 * input_dim,
                kernel_size=4,
                padding=1,
                groups=2 * input_dim,
            ),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4),
            nn.Conv1d(
                in_channels=2 * input_dim,
                out_channels=input_dim,
                kernel_size=4,
                padding=1,
                groups=input_dim,
            ),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
            nn.Flatten(),
        )
        self.logistic = nn.Sequential(nn.Linear(output_dim, 1, bias=True), nn.Sigmoid())
        self.save_hyperparameters()

    def forward(self, x):
        x_cond = self.condition(x)
        probs = self.logistic(x_cond)
        return probs

    def training_step(self, batch, batch_idx):
        X, y = batch
        X, y = X.float(), y.float()
        y_hat = torch.squeeze(self.forward(X))
        loss = F.binary_cross_entropy(y_hat, y)
        self.log(
            "Loss/train", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True
        )
        auroc = BinaryAUROC()
        train_auc = auroc(y_hat, y)
        self.log("AUC/train", train_auc, on_epoch=True, on_step=False, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch
        X, y = X.float(), y.float()
        y_hat = torch.squeeze(self.forward(X))
        loss = F.binary_cross_entropy(y_hat, y)
        self.log("Loss/test", loss, on_epoch=True, on_step=False, logger=True)
        auroc = BinaryAUROC()
        test_auc = auroc(y_hat, y)
        self.log(
            "AUC/test",
            test_auc,
            on_epoch=True,
            on_step=False,
            prog_bar=False,
            logger=True,
        )

    def configure_optimizers(self) -> Any:
        self.logger.log_hyperparams(
            {
                "learning_rate": self.hparams["learning_rate"],
                "weight_decay": self.hparams["weight_decay"],
                "batch_size": self.hparams["batch_size"],
            }
        )
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams["learning_rate"],
            weight_decay=self.hparams["weight_decay"],
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer, threshold=1e-4, verbose=True
            ),
            # The unit of the scheduler's step size, 'epoch' or 'step'.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "Loss/train",
            # If set to `True`, will enforce that the value specified 'monitor'
            # is available when the scheduler is updated
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": "learning_rate",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

In [58]:
### HYPERPARAMETERS ####
learning_rate = 1e-2
weight_decay = 1e-4  # 1e-4
batch_size = 512
########################
timestr = time.strftime("%Y%m%d-%H%M%S")
_ = pl.seed_everything(56)
subject = "LTP093"
test_result = []
for sess in range(24):
    try:
        test_file_crit = (
            lambda s: s.endswith(".pt")
            and s.count(f"sub_{subject}")
            and s.count(f"sess_{sess}")
        )
        test_dataset = DatasetFolder(
            data_dir,
            loader=partial(torch.load),
            is_valid_file=test_file_crit,
        )
        train_file_crit = (
            lambda s: s.endswith(".pt")
            and s.count(f"sub_{subject}")
            and not s.count(f"sess_{sess}")
        )
        train_dataset = DatasetFolder(
            data_dir,
            loader=partial(torch.load),
            is_valid_file=train_file_crit,
        )
    except FileNotFoundError:
        print(f"no session {sess}")
        continue
    ## class balancing ##
    cls_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.unique(train_dataset.targets),
        y=train_dataset.targets,
    )
    weights = cls_weights[train_dataset.targets]
    sampler = WeightedRandomSampler(weights, len(train_dataset), replacement=True)  # type: ignore
    ## data loaders ##
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        num_workers=n_cpu,
        prefetch_factor=10,
        persistent_workers=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=False,
        pin_memory=True,
        num_workers=n_cpu,
    )
    ## create model ##
    n_features = train_dataset[0][0].shape[0]
    model = LitPrecondition(
        n_features, n_features, learning_rate, weight_decay, batch_size
    )
    es = EarlyStopping("Loss/train", min_delta=0.0001, patience=25, mode="min")
    lr_mtr = LearningRateMonitor("epoch")
    check = ModelCheckpoint(monitor="AUC/train", mode="max")
    run_dir = f"run_{subject}_{sess}_{timestr}"
    logger = TensorBoardLogger(
        save_dir=log_dir, name="precondition", version=run_dir, default_hp_metric=False
    )
    trainer = Trainer(
        min_epochs=75,
        max_epochs=200,
        accelerator="mps",
        devices=1,
        callbacks=[lr_mtr, es, check],
        logger=logger,
        log_every_n_steps=10,
    )
    # trainer.logger._default_hp_metric = None
    trainer.fit(model, train_dataloaders=train_dataloader)
    model = LitPrecondition.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path  # type: ignore
    )  # Load best checkpoint after training
    test_result += trainer.test(model, dataloaders=test_dataloader, verbose=False)
    test_result[-1].update({"subject": subject, "session": sess})

Global seed set to 56
  rank_zero_warn(


Epoch 64: 100%|██████████| 7/7 [00:01<00:00,  5.82it/s, v_num=3738, Loss/train=0.692]Epoch 00065: reducing learning rate of group 0 to 1.0000e-03.
Epoch 75: 100%|██████████| 7/7 [00:01<00:00,  5.89it/s, v_num=3738, Loss/train=0.692]Epoch 00076: reducing learning rate of group 0 to 1.0000e-04.
Epoch 78: 100%|██████████| 7/7 [00:01<00:00,  5.87it/s, v_num=3738, Loss/train=0.692]


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 16.65it/s]


  rank_zero_warn(


Epoch 1:  25%|██▌       | 1/4 [00:00<00:00,  3.00it/s, v_num=3738, Loss/train=0.694]

In [17]:
result_df = pd.DataFrame(test_result)

In [18]:
result_df

Unnamed: 0,Loss/test,AUC/test,subject,session
0,0.691342,0.463612,LTP093,0.0
1,0.693647,0.492343,,
2,0.69162,0.536043,,
3,0.692682,0.498934,,
4,0.690038,0.594484,,
5,0.69346,0.510938,,
6,0.693099,0.471304,,
7,0.695018,0.452026,,
8,0.694144,0.504188,,
9,0.688533,0.574561,,


In [9]:
result_df.to_csv("precond_results.csv")