Vision Transformer with Spectrogram and Shapley Value Analysis

A modular implementation for training ViT models on spectrogram data with Shapley value interpretation.


In [1]:
import torch
import torch.nn as nn
import torchaudio
from vit_pytorch import SimpleViT
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.metrics import accuracy_score, roc_auc_score
import numpy as np
import json
from dataclasses import dataclass
from pathlib import Path
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

In [2]:
@dataclass
class ModelConfig:
    """Configuration for the ViT model and training parameters.

    The fields without defaults are forwarded verbatim to ``SimpleViT`` by
    ``VisionTransformerModel``; the defaulted fields control the optimiser,
    the data loaders and the train/validation split.

    Raises:
        ValueError: if ``train_split`` is not strictly between 0 and 1, or if
            ``batch_size`` is smaller than 1.
    """

    # --- SimpleViT architecture (passed straight through) ---
    image_size: tuple[int, int]
    patch_size: int
    num_classes: int
    dim: int
    depth: int
    heads: int
    mlp_dim: int
    channels: int
    dim_head: int
    # --- training hyper-parameters ---
    learning_rate: float = 1e-4  # Adam learning rate
    batch_size: int = 32  # samples per DataLoader batch
    max_epochs: int = 100  # upper bound on training epochs
    train_split: float = 0.8  # fraction of samples used for training; the rest validates

    def __post_init__(self) -> None:
        # A split of 1.0 leaves an empty validation set (validation would then
        # average over zero batches); a split of 0.0 leaves nothing to train on.
        if not 0.0 < self.train_split < 1.0:
            raise ValueError(f"train_split must be in (0, 1), got {self.train_split}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")


@dataclass
class SpectrogramConfig:
    """Configuration for spectrogram transformation.

    All three fields are passed unchanged to
    ``torchaudio.transforms.Spectrogram`` by ``SpectrogramTransform``.
    Leaving ``win_length`` / ``hop_length`` as ``None`` lets torchaudio pick
    its own defaults (per the torchaudio docs, ``win_length = n_fft`` and
    ``hop_length = win_length // 2`` — confirm for the installed version).
    """

    n_fft: int = 4  # FFT size, forwarded as ``n_fft``
    win_length: int | None = None  # None -> torchaudio's default window length
    hop_length: int | None = None  # None -> torchaudio's default hop length

In [3]:
class OutputFunction(nn.Module):
    """Map raw logits to hard 0/1 class predictions.

    Each logit is squashed to a probability with a sigmoid and then rounded
    to the nearest integer, i.e. thresholded at 0.5.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        probabilities = x.sigmoid()
        return probabilities.round()


class SpectrogramTransform(nn.Module):
    """``nn.Module`` wrapper around ``torchaudio.transforms.Spectrogram``.

    The STFT settings are read from a ``SpectrogramConfig``; ``None`` values
    are passed through so torchaudio applies its own defaults.
    """

    def __init__(self, config: SpectrogramConfig):
        super().__init__()
        stft_settings = dict(
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        # The submodule stays named ``transform`` so state_dict keys are unchanged.
        self.transform = torchaudio.transforms.Spectrogram(**stft_settings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spectrogram = self.transform(x)
        return spectrogram


class VisionTransformerModel(nn.Module):
    """Spectrogram front-end followed by a ``SimpleViT`` classifier.

    The input signal is first turned into a spectrogram, which is then fed
    to the ViT. Axis 1 of the ViT output is squeezed away when it has size 1,
    so a single-logit model returns one value per sample.
    """

    def __init__(self, model_config: ModelConfig, spec_config: SpectrogramConfig):
        super().__init__()
        # Registration order (spectrogram first, then vit) is kept so the
        # parameter and state_dict ordering is unchanged.
        self.spectrogram = SpectrogramTransform(spec_config)
        vit_field_names = (
            "image_size",
            "patch_size",
            "num_classes",
            "dim",
            "depth",
            "heads",
            "mlp_dim",
            "channels",
            "dim_head",
        )
        vit_kwargs = {name: getattr(model_config, name) for name in vit_field_names}
        self.vit = SimpleViT(**vit_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.vit(self.spectrogram(x))
        return logits.squeeze(1)

In [4]:
class ShapleyAnalyzer:
    """Handles Shapley value computation for feature importance analysis.

    A "feature" is one index along axis 1 of the input tensor. Features that
    are not part of a coalition are replaced by zeros, and a coalition's value
    is the mean sigmoid of the model output over all samples.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        """Store the model to explain and the device its inputs should live on."""
        self.model = model
        self.device = device

    def compute_shapley_values(self, inputs: torch.Tensor, num_samples: int = 100) -> np.ndarray:
        """Compute Shapley values for each feature using Monte Carlo sampling.

        For every feature ``i`` a coalition size ``k`` is drawn uniformly from
        ``0 .. n_features - 1`` and then ``k`` of the other features are drawn
        uniformly without replacement. Under this scheme each coalition ``S``
        is selected with probability ``|S|! (n - |S| - 1)! / n!``, which is
        exactly the Shapley weight. The plain mean of the sampled marginal
        contributions is therefore an unbiased estimate; the weight must not
        be applied a second time.

        Args:
            inputs: Sample batch; axis 1 indexes the features to attribute.
            num_samples: Monte Carlo coalitions drawn per feature.

        Returns:
            Array of shape ``(n_features,)`` with the estimated Shapley values.

        Raises:
            ValueError: if ``num_samples`` is smaller than 1.

        Note:
            Cost is ``2 * n_features * num_samples`` forward passes over the
            whole of ``inputs``.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        inputs = inputs.to(self.device)
        n_features = inputs.shape[1]
        shapley_values = np.zeros(n_features)
        all_features = np.arange(n_features)

        def model_prediction(coalition: np.ndarray) -> float:
            """Mean sigmoid output with every feature outside ``coalition`` zeroed."""
            masked = inputs.clone()
            # Build the mask on the inputs' device so CUDA inputs are indexed
            # by a CUDA mask.
            drop = torch.ones(n_features, dtype=torch.bool, device=inputs.device)
            keep = torch.as_tensor(coalition, dtype=torch.long, device=inputs.device)
            drop[keep] = False
            masked[:, drop] = 0

            with torch.no_grad():
                output = self.model(masked)
            return torch.sigmoid(output).mean().item()

        # Evaluate in eval mode (deterministic dropout/normalisation, if any)
        # and restore the caller's mode afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            for i in range(n_features):
                others = np.delete(all_features, i)
                marginals = []
                for _ in range(num_samples):
                    subset_size = np.random.randint(0, n_features)
                    coalition = np.random.choice(others, size=subset_size, replace=False)

                    with_i = model_prediction(np.append(coalition, i))
                    without_i = model_prediction(coalition)
                    marginals.append(with_i - without_i)

                shapley_values[i] = np.mean(marginals)
        finally:
            self.model.train(was_training)

        return shapley_values

In [5]:
class ExperimentManager:
    """Manages the training process and experiment tracking.

    Owns the model, loss and optimiser. It loads a JSON dataset and runs the
    train/validate loop, appending per-epoch metrics to ``metrics.csv`` and
    checkpointing the best-validation-accuracy weights to ``best_model.pt``.
    Finally it saves Monte Carlo Shapley values to ``shapley_values.npy``.
    All artefacts are written under ``experiment_dir``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        spec_config: SpectrogramConfig,
        experiment_dir: Path | None = None,
    ):
        """Build the model, loss and optimiser and create the output directory.

        Args:
            model_config: ViT architecture and training hyper-parameters.
            spec_config: Spectrogram (STFT) settings.
            experiment_dir: Output directory; defaults to
                ``data/<YYYYmmdd_HHMMSS>`` from the current local time.
        """
        self.model_config = model_config
        self.spec_config = spec_config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.experiment_dir = experiment_dir or Path(f"data/{datetime.now():%Y%m%d_%H%M%S}")
        self.experiment_dir.mkdir(parents=True, exist_ok=True)

        self.model = VisionTransformerModel(model_config, spec_config).to(self.device)
        # Hard 0/1 predictions, used for accuracy.
        self.output_fn = OutputFunction().to(self.device)
        # BCEWithLogitsLoss applies the sigmoid itself, so it receives raw model outputs.
        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=model_config.learning_rate)

    def load_data(self, data_path: Path) -> tuple[DataLoader, DataLoader]:
        """Load the JSON dataset and split it into train/validation loaders.

        Args:
            data_path: JSON file with a ``"respMatrix"`` entry (model inputs)
                and a ``"Y"`` entry (binary targets).

        Returns:
            ``(train_loader, val_loader)`` over a random split in which
            ``model_config.train_split`` of the samples are used for training.
            Only the training loader shuffles.
        """
        with open(data_path, encoding="utf-8") as f:
            data = json.load(f)

        # Force a floating-point input tensor: integer-valued JSON would
        # otherwise become int64, which the spectrogram/ViT layers cannot consume.
        inputs = torch.tensor(data["respMatrix"], dtype=torch.float32).to(self.device)
        targets = torch.tensor(data["Y"]).to(self.device)

        # Merge the two leading feature axes into a single channel axis:
        # [n_samples, dim0, dim1, frames] -> [n_samples, dim0 * dim1, frames].
        if inputs.dim() == 4:
            n_samples, dim0, dim1, frames = inputs.shape
            inputs = inputs.view(n_samples, dim0 * dim1, frames)

        dataset = TensorDataset(inputs, targets)
        train_size = int(self.model_config.train_split * len(dataset))
        val_size = len(dataset) - train_size

        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        return (
            DataLoader(train_dataset, batch_size=self.model_config.batch_size, shuffle=True),
            DataLoader(val_dataset, batch_size=self.model_config.batch_size),
        )

    def train_epoch(self, train_loader: DataLoader) -> dict[str, float]:
        """Train for one epoch; return loss/accuracy/AUROC over the batches seen."""
        self.model.train()
        total_loss = 0.0
        probabilities, predictions, targets = [], [], []

        for data, target in train_loader:
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target.float())

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            # Detach first: ``output`` is part of the autograd graph, and
            # ``.numpy()`` refuses tensors that require grad.
            self._collect(output.detach(), target, probabilities, predictions, targets)

        return self._summarize(total_loss, len(train_loader), targets, probabilities, predictions)

    def validate(self, val_loader: DataLoader) -> dict[str, float]:
        """Evaluate on ``val_loader`` without gradient tracking."""
        self.model.eval()
        total_loss = 0.0
        probabilities, predictions, targets = [], [], []

        with torch.no_grad():
            for data, target in val_loader:
                output = self.model(data)
                loss = self.criterion(output, target.float())

                total_loss += loss.item()
                self._collect(output, target, probabilities, predictions, targets)

        return self._summarize(total_loss, len(val_loader), targets, probabilities, predictions)

    def _collect(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
        probabilities: list,
        predictions: list,
        targets: list,
    ) -> None:
        """Append per-sample probabilities, hard predictions and targets to the lists."""
        probabilities.extend(torch.sigmoid(logits).cpu().numpy())
        predictions.extend(self.output_fn(logits).cpu().numpy())
        targets.extend(target.cpu().numpy())

    @staticmethod
    def _summarize(
        total_loss: float,
        n_batches: int,
        targets: list,
        probabilities: list,
        predictions: list,
    ) -> dict[str, float]:
        """Build the ``loss``/``accuracy``/``auroc`` dict for one pass over a loader.

        An empty loader yields NaN for every metric instead of raising. AUROC
        is NaN when the targets contain a single class, where it is undefined.
        """
        nan = float("nan")
        if n_batches == 0 or not targets:
            return {"loss": nan, "accuracy": nan, "auroc": nan}

        if len(np.unique(targets)) < 2:
            auroc = nan
        else:
            # AUROC must be computed from scores, not from thresholded labels.
            auroc = roc_auc_score(targets, probabilities)

        return {
            "loss": total_loss / n_batches,
            "accuracy": accuracy_score(targets, predictions),
            "auroc": auroc,
        }

    @staticmethod
    def _dataset_inputs(loader: DataLoader) -> torch.Tensor:
        """Return the input tensor behind ``loader``, resolving a ``random_split`` Subset."""
        from torch.utils.data import Subset

        dataset = loader.dataset
        if isinstance(dataset, Subset):
            # A Subset has no ``.tensors``; index the wrapped TensorDataset's
            # inputs with the subset's indices instead.
            base = dataset.dataset.tensors[0]
            index = torch.as_tensor(dataset.indices, dtype=torch.long, device=base.device)
            return base[index]
        return dataset.tensors[0]

    def run_experiment(self, data_path: Path) -> None:
        """Run the complete experiment: train, checkpoint, log, then Shapley analysis."""
        logger.info(f"Starting experiment in {self.experiment_dir}")
        train_loader, val_loader = self.load_data(data_path)

        metrics_path = self.experiment_dir / "metrics.csv"
        with open(metrics_path, "w") as f:
            f.write("epoch,train_loss,train_acc,train_auroc,val_loss,val_acc,val_auroc\n")

        # Start at -inf (not 0) so a checkpoint is written even when validation
        # accuracy is 0.
        best_val_accuracy = float("-inf")
        for epoch in range(self.model_config.max_epochs):
            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)

            # Log metrics. The file is re-opened each epoch so partial results
            # survive an interrupted run.
            with open(metrics_path, "a") as f:
                f.write(
                    f"{epoch},{train_metrics['loss']:.4f},{train_metrics['accuracy']:.4f},"
                    f"{train_metrics['auroc']:.4f},{val_metrics['loss']:.4f},"
                    f"{val_metrics['accuracy']:.4f},{val_metrics['auroc']:.4f}\n"
                )

            # Save best model
            if val_metrics["accuracy"] > best_val_accuracy:
                best_val_accuracy = val_metrics["accuracy"]
                torch.save(self.model.state_dict(), self.experiment_dir / "best_model.pt")

            logger.info(
                f"Epoch {epoch}: Train Acc={train_metrics['accuracy']:.4f}, "
                f"Val Acc={val_metrics['accuracy']:.4f}"
            )

            # Early stopping: stop once either split is classified perfectly.
            if val_metrics["accuracy"] == 1.0 or train_metrics["accuracy"] == 1.0:
                logger.info("Early stopping triggered")
                break

        # Shapley values are computed for the final-epoch weights, not for the
        # best checkpoint saved above.
        logger.info("Computing Shapley values...")
        analyzer = ShapleyAnalyzer(self.model, self.device)
        inputs = self._dataset_inputs(train_loader)
        shapley_values = analyzer.compute_shapley_values(inputs)
        np.save(self.experiment_dir / "shapley_values.npy", shapley_values)

In [6]:
"""Example usage of the experiment framework."""
# Example configuration. Constructing ExperimentManager already has side
# effects (it picks a device, builds the model and creates a timestamped
# data/<YYYYmmdd_HHMMSS> directory); the commented-out call at the bottom is
# what actually loads data, trains and runs the Shapley analysis.
model_config = ModelConfig(
    # NOTE(review): presumably (frequency bins, time frames) of the spectrogram;
    # with a onesided STFT and n_fft=4 that would be n_fft // 2 + 1 = 3 bins — confirm.
    image_size=(3, 16),  # to adjust based on spectrogram output
    patch_size=1,
    num_classes=1,  # single logit, matching the BCEWithLogitsLoss training objective
    dim=1024,
    depth=2,
    heads=16,
    mlp_dim=2048,
    # Should equal the size of axis 1 of the inputs produced by load_data
    # (dim0 * dim1 for 4-D data) — confirm against the dataset.
    channels=1720,  # to adjust based on input size
    dim_head=64,
)

spec_config = SpectrogramConfig(n_fft=4, win_length=None, hop_length=None)

# NOTE(review): ExperimentManager.load_data parses this file with json.load and
# reads its "respMatrix" / "Y" entries, yet the path has a .csv suffix —
# confirm the file is actually JSON.
data_path = Path("data/prelick_data_no_zeros1.csv")
experiment = ExperimentManager(model_config, spec_config)
# experiment.run_experiment(data_path)