In [1]:
import warnings

warnings.simplefilter("ignore", category=UserWarning)
import logging

logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
logging.getLogger("torchvision.dataset").setLevel(logging.ERROR)
import numpy as np
import pandas as pd
import time
from functools import partial
import multiprocessing

n_cpu = multiprocessing.cpu_count()
import torch
import torch.nn as nn
from torchvision.datasets import DatasetFolder
from torch.utils.data import DataLoader, WeightedRandomSampler
from typing import Any
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.utils.class_weight import compute_class_weight
from numpy import dtype
import torch.nn.functional as F
from torchmetrics.classification import BinaryAUROC

### data location ###
# Root folder handed to DatasetFolder in the training cells below.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
data_dir = "/Users/jrudoler/data/small_scalp_features/"
# Output root for TensorBoard runs and result CSVs. Keep the trailing slash:
# the CSV paths below are built by plain string concatenation (log_dir + filename).
log_dir = "/Users/jrudoler/Library/CloudStorage/Box-Box/JR_CML/pytorch_logs/"

# Precondition Features

In [3]:
class LitPrecondition(pl.LightningModule):
    """Grouped 1-D convolutional "preconditioning" stack feeding a logistic output.

    The ``condition`` network applies three grouped ``Conv1d`` layers (each
    followed by ReLU and a pooling layer of kernel size 4) and flattens the
    result; ``logistic`` maps the flattened vector to a single probability via
    ``Linear`` + ``Sigmoid``.

    Args:
        input_dim: Number of input channels of the first ``Conv1d``.
        output_dim: Input width of the final ``Linear`` layer; ``None`` means
            "same as ``input_dim``".
        learning_rate: Learning rate passed to AdamW.
        weight_decay: Weight decay passed to AdamW.
        batch_size: Not used by the module itself; stored with the other
            arguments by ``save_hyperparameters``.

    Note:
        The flattened output of ``condition`` must have exactly ``output_dim``
        values per sample, otherwise ``logistic`` fails. Whether that holds
        depends on the length of the input's last axis, which is not checked
        here. Assumes inputs are shaped (batch, input_dim, length) - TODO confirm.
    """

    def __init__(self, input_dim, output_dim, learning_rate, weight_decay, batch_size):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim
        self.condition = nn.Sequential(
            # Two filters per input channel (groups == input_dim).
            nn.Conv1d(
                in_channels=input_dim,
                out_channels=2 * input_dim,
                kernel_size=2,
                padding=1,
                groups=input_dim,
            ),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
            # One filter per channel (groups == number of channels).
            nn.Conv1d(
                in_channels=2 * input_dim,
                out_channels=2 * input_dim,
                kernel_size=4,
                padding=1,
                groups=2 * input_dim,
            ),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4),
            # Each group of two channels is reduced back to one output channel.
            nn.Conv1d(
                in_channels=2 * input_dim,
                out_channels=input_dim,
                kernel_size=4,
                padding=1,
                groups=input_dim,
            ),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
            nn.Flatten(),
        )
        self.logistic = nn.Sequential(nn.Linear(output_dim, 1, bias=True), nn.Sigmoid())
        # Must be called from __init__ so the constructor arguments are captured.
        self.save_hyperparameters()

    def forward(self, x):
        """Return the sigmoid output of ``logistic`` applied to ``condition(x)``."""
        x_cond = self.condition(x)
        probs = self.logistic(x_cond)
        return probs

    def _shared_step(self, batch, stage):
        """Compute and log BCE loss and AUROC for one batch.

        Args:
            batch: ``(X, y)`` pair; both are cast to float.
            stage: Suffix of the logged keys (``Loss/<stage>``, ``AUC/<stage>``).

        Returns:
            The binary cross-entropy loss for the batch.
        """
        X, y = batch
        X, y = X.float(), y.float()
        # Drop only the trailing singleton axis. A bare torch.squeeze() would
        # also drop the batch axis for a batch of one, and binary_cross_entropy
        # rejects an input whose shape differs from the target's.
        y_hat = self.forward(X).squeeze(-1)
        loss = F.binary_cross_entropy(y_hat, y)
        # NOTE(review): AUROC is computed per batch with a fresh metric object
        # and then reduced over the epoch by self.log, so for multi-batch
        # epochs this is not the AUROC of the pooled predictions.
        auc = BinaryAUROC()(y_hat, y)
        log_kwargs = dict(on_epoch=True, on_step=False, prog_bar=False, logger=True)
        self.log(f"Loss/{stage}", loss, **log_kwargs)
        self.log(f"AUC/{stage}", auc, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``Loss/train`` and ``AUC/train`` and return the loss to optimise."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``Loss/test`` and ``AUC/test`` for a validation batch."""
        self._shared_step(batch, "test")

    def test_step(self, batch, batch_idx):
        """Log ``Loss/test`` and ``AUC/test`` for a test batch.

        The loss was previously computed here but never logged; it is now
        logged so the test results match those of LitLogisticRegression.
        """
        self._shared_step(batch, "test")

    def configure_optimizers(self) -> Any:
        """Return AdamW plus a ReduceLROnPlateau scheduler monitoring ``Loss/train``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams["learning_rate"],
            weight_decay=self.hparams["weight_decay"],
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                threshold=1e-4,
                threshold_mode="rel",
                patience=10,
            ),
            # The unit of the scheduler's step size, 'epoch' or 'step'.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "Loss/train",
            # If set to `True`, will enforce that the value specified 'monitor'
            # is available when the scheduler is updated
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": "learning_rate",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

In [5]:
### HYPERPARAMETERS ####
learning_rate = 1e-2
weight_decay = 0.5
batch_size = 256
########################
import re  # only needed in this cell, for the exact session match below

# Leave-one-session-out evaluation of LitPrecondition for a single subject:
# for each session, train on every other session of the subject, evaluate on
# the held-out one, and append the test metrics to a running CSV.
timestr = time.strftime("%Y%m%d-%H%M%S")
_ = pl.seed_everything(56)
subject = "LTP093"
test_result = []
for sess in range(24):
    # Match the session number exactly. A plain substring test would let
    # "sess_1" also match "sess_10".."sess_19" (and "sess_2" match "sess_20"..),
    # putting the wrong sessions in the held-out and training sets.
    # Assumes the session number in the path is not zero-padded - TODO confirm.
    sess_tag = re.compile(rf"sess_{sess}(?!\d)")
    try:
        test_file_crit = (
            lambda s, tag=sess_tag: s.endswith(".pt")
            and f"sub_{subject}" in s
            and tag.search(s) is not None
        )
        test_dataset = DatasetFolder(
            data_dir,
            loader=torch.load,
            is_valid_file=test_file_crit,
        )
        train_file_crit = (
            lambda s, tag=sess_tag: s.endswith(".pt")
            and f"sub_{subject}" in s
            and tag.search(s) is None
        )
        train_dataset = DatasetFolder(
            data_dir,
            loader=torch.load,
            is_valid_file=train_file_crit,
        )
    except FileNotFoundError:
        # No matching files for this session: record a placeholder row and move on.
        print(f"no session {sess}")
        test_result += [{"subject": subject, "session": sess}]
        continue
    ## class balancing ##
    cls_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.unique(train_dataset.targets),
        y=train_dataset.targets,
    )
    # Per-sample weight = weight of that sample's class index.
    weights = cls_weights[train_dataset.targets]
    sampler = WeightedRandomSampler(weights, len(train_dataset), replacement=True)  # type: ignore
    ## data loaders ##
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        num_workers=n_cpu,
        prefetch_factor=10,
        persistent_workers=True,
    )
    # The whole held-out session is evaluated as a single batch.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=False,
        pin_memory=True,
    )
    ## create model ##
    n_features = train_dataset[0][0].shape[0]
    model = LitPrecondition(
        n_features, n_features, learning_rate, weight_decay, batch_size
    )
    es = EarlyStopping("Loss/train", min_delta=1e-3, patience=25, mode="min")
    lr_mtr = LearningRateMonitor("epoch")
    # The checkpoint to reload is chosen by training AUC, not held-out AUC.
    check = ModelCheckpoint(monitor="AUC/train", mode="max")
    run_dir = f"run_{subject}_{sess}_{timestr}"
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name="precondition",
        version=run_dir,
        default_hp_metric=False,
    )
    logger.log_hyperparams(
        {
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "batch_size": batch_size,
        },
        {"AUC/train": 0, "Loss/train": 0, "AUC/test": 0, "Loss/test": 0},
    )
    trainer = Trainer(
        min_epochs=50,
        max_epochs=500,
        accelerator="mps",
        devices=1,
        callbacks=[lr_mtr, es, check],
        logger=logger,
        log_every_n_steps=5,
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
    model = LitPrecondition.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path  # type: ignore
    )  # Load best checkpoint after training
    test_result += trainer.test(model, dataloaders=test_dataloader, verbose=False)
    test_result[-1].update({"subject": subject, "session": sess})
    torch.mps.empty_cache()
    # Rewrite the CSVs after every session so partial results survive a crash.
    result_df = pd.DataFrame(test_result)
    result_df.to_csv(log_dir + f"precond_results_{subject}_{timestr}.csv")
    result_df.to_csv(f"precond_results_{subject}_{timestr}.csv")

Global seed set to 56


NameError: name 'LitPrecondition' is not defined

# Logistic Regression

In [2]:
class LitLogisticRegression(pl.LightningModule):
    """Single-layer logistic regression (``Linear`` + ``Sigmoid``) as a LightningModule.

    Args:
        input_dim: Number of input features of the ``Linear`` layer.
        learning_rate: Learning rate passed to SGD.
        weight_decay: Weight decay passed to SGD.
        batch_size: Not used by the module itself; stored with the other
            arguments by ``save_hyperparameters``.
    """

    def __init__(self, input_dim, learning_rate, weight_decay, batch_size):
        super().__init__()
        self.logistic = nn.Sequential(nn.Linear(input_dim, 1, bias=True), nn.Sigmoid())
        # Must be called from __init__ so the constructor arguments are captured.
        self.save_hyperparameters()

    def forward(self, X):
        """Return the sigmoid output of the linear layer for ``X``."""
        probs = self.logistic(X)
        return probs

    def _shared_step(self, batch, stage):
        """Compute and log BCE loss and AUROC for one batch.

        Args:
            batch: ``(X, y)`` pair; both are cast to float.
            stage: Suffix of the logged keys (``Loss/<stage>``, ``AUC/<stage>``).

        Returns:
            The binary cross-entropy loss for the batch.
        """
        X, y = batch
        X, y = X.float(), y.float()
        # Drop only the trailing singleton axis. A bare torch.squeeze() would
        # also drop the batch axis for a batch of one, and binary_cross_entropy
        # rejects an input whose shape differs from the target's.
        y_hat = self.forward(X).squeeze(-1)
        loss = F.binary_cross_entropy(y_hat, y)
        # NOTE(review): AUROC is computed per batch with a fresh metric object
        # and then reduced over the epoch by self.log, so for multi-batch
        # epochs this is not the AUROC of the pooled predictions.
        auc = BinaryAUROC()(y_hat, y)
        log_kwargs = dict(on_epoch=True, on_step=False, prog_bar=False, logger=True)
        self.log(f"Loss/{stage}", loss, **log_kwargs)
        self.log(f"AUC/{stage}", auc, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``Loss/train`` and ``AUC/train`` and return the loss to optimise."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``Loss/test`` and ``AUC/test`` for a validation batch."""
        self._shared_step(batch, "test")

    def test_step(self, batch, batch_idx):
        """Log ``Loss/test`` and ``AUC/test`` for a test batch."""
        self._shared_step(batch, "test")

    def configure_optimizers(self) -> Any:
        """Return SGD plus a ReduceLROnPlateau scheduler monitoring ``Loss/train``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams["learning_rate"],
            weight_decay=self.hparams["weight_decay"],
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                threshold=1e-4,
                patience=10,
                # NOTE(review): `verbose` may be deprecated or absent in newer
                # PyTorch releases - confirm against the installed version.
                verbose=True,
            ),
            # The unit of the scheduler's step size, 'epoch' or 'step'.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "Loss/train",
            # If set to `True`, will enforce that the value specified 'monitor'
            # is available when the scheduler is updated
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": "learning_rate",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

In [None]:
### HYPERPARAMETERS ####
learning_rate = 1e-2
weight_decay = 1
batch_size = 256
########################
import re  # only needed in this cell, for the exact session match below

# Leave-one-session-out evaluation of LitLogisticRegression for a single
# subject: for each session, train on every other session of the subject,
# evaluate on the held-out one, and append the test metrics to a running CSV.
timestr = time.strftime("%Y%m%d-%H%M%S")
_ = pl.seed_everything(56)
subject = "LTP093"
test_result = []
for sess in range(24):
    # Match the session number exactly. A plain substring test would let
    # "sess_1" also match "sess_10".."sess_19" (and "sess_2" match "sess_20"..),
    # putting the wrong sessions in the held-out and training sets.
    # Assumes the session number in the path is not zero-padded - TODO confirm.
    sess_tag = re.compile(rf"sess_{sess}(?!\d)")
    try:
        test_file_crit = (
            lambda s, tag=sess_tag: s.endswith(".pt")
            and f"sub_{subject}" in s
            and tag.search(s) is not None
        )
        test_dataset = DatasetFolder(
            data_dir,
            loader=torch.load,
            is_valid_file=test_file_crit,
            # Average each loaded sample over its last axis before batching.
            transform=partial(torch.mean, dim=-1),
        )
        train_file_crit = (
            lambda s, tag=sess_tag: s.endswith(".pt")
            and f"sub_{subject}" in s
            and tag.search(s) is None
        )
        train_dataset = DatasetFolder(
            data_dir,
            loader=torch.load,
            is_valid_file=train_file_crit,
            transform=partial(torch.mean, dim=-1),
        )
    except FileNotFoundError:
        # No matching files for this session: record a placeholder row and move on.
        print(f"no session {sess}")
        test_result += [{"subject": subject, "session": sess}]
        continue
    ## class balancing ##
    cls_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.unique(train_dataset.targets),
        y=train_dataset.targets,
    )
    # Per-sample weight = weight of that sample's class index.
    weights = cls_weights[train_dataset.targets]
    sampler = WeightedRandomSampler(weights, len(train_dataset), replacement=True)  # type: ignore
    ## data loaders ##
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        num_workers=n_cpu,
        prefetch_factor=10,
        persistent_workers=True,
    )
    # The whole held-out session is evaluated as a single batch.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=False,
        pin_memory=True,
    )
    ## create model ##
    n_features = train_dataset[0][0].shape[0]
    model = LitLogisticRegression(n_features, learning_rate, weight_decay, batch_size)
    es = EarlyStopping("Loss/train", min_delta=0.0001, patience=25, mode="min")
    lr_mtr = LearningRateMonitor("epoch")
    # The checkpoint to reload is chosen by training AUC, not held-out AUC.
    check = ModelCheckpoint(monitor="AUC/train", mode="max")
    run_dir = f"run_{subject}_{sess}_{timestr}"
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name="logreg",
        version=run_dir,  # , default_hp_metric=False
    )
    logger.log_hyperparams(
        {
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "batch_size": batch_size,
        },
        {"AUC/train": 0, "Loss/train": 0, "AUC/test": 0, "Loss/test": 0},
    )
    trainer = Trainer(
        min_epochs=50,
        max_epochs=500,
        accelerator="mps",
        devices=1,
        callbacks=[lr_mtr, es, check],
        logger=logger,
        log_every_n_steps=10,
        # fast_dev_run=True
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
    model = LitLogisticRegression.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path  # type: ignore
    )  # Load best checkpoint after training
    test_result += trainer.test(model, dataloaders=test_dataloader, verbose=False)
    test_result[-1].update({"subject": subject, "session": sess})
    torch.mps.empty_cache()
    # Rewrite the CSVs after every session so partial results survive a crash.
    result_df = pd.DataFrame(test_result)
    result_df.to_csv(log_dir + f"logreg_results_{subject}_{timestr}.csv")
    result_df.to_csv(f"logreg_results_{subject}_{timestr}.csv")

### Saving and loading models

In [35]:
# Load a whole pickled model object (not just a state dict) onto the CPU.
# SECURITY: torch.load unpickles arbitrary Python objects here; only load
# files you trust. The class of the saved model must already be defined in
# this kernel for unpickling to succeed.
model = torch.load(
    "/Users/jrudoler/Library/CloudStorage/Box-Box/JR_CML/pytorch_logs/models/foundation_model_20230427-152237.pt",
    # Remap storages at load time so loading does not depend on the device
    # the model was saved from being available.
    map_location="cpu",
    # The file holds a full module, so weights-only loading cannot be used.
    weights_only=False,
)

In [31]:
# Display the hyperparameters recorded on the loaded model (output shown below).
model.hparams_initial

"batch_size":    256
"input_dim":     992
"learning_rate": 0.01
"output_dim":    992
"weight_decay":  0.5

In [32]:
# Overwrite the batch_size entry in the loaded model's hparams mapping.
# NOTE(review): the execution counts show this cell (In [32]) ran before the
# model was reloaded (In [35]), so the hparams displayed under In [36] still
# read 256. Re-run in order to see the change.
model.hparams["batch_size"] = 512

In [36]:
# Display the model's current hparams mapping.
model.hparams

"batch_size":    256
"input_dim":     992
"learning_rate": 0.01
"output_dim":    992
"weight_decay":  0.5