In [1]:
!git clone https://github.com/intsystems/SToG.git

Cloning into 'SToG'...
remote: Enumerating objects: 307, done.[K
remote: Counting objects: 100% (307/307), done.[K
remote: Compressing objects: 100% (244/244), done.[K
remote: Total 307 (delta 95), reused 203 (delta 40), pack-reused 0 (from 0)[K
Receiving objects: 100% (307/307), 8.33 MiB | 13.08 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [2]:
!cd SToG/src/mylib

In [3]:
from SToG.src.mylib.stochastic_gating_complete import STGLayer, STELayer, GumbelLayer, CorrelatedSTGLayer, L1Layer

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from abc import ABC, abstractmethod
from sklearn.datasets import load_breast_cancer, load_wine, make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegressionCV
import warnings
warnings.filterwarnings('ignore')

class ModelFeatureSelection(nn.Module):
    def __init__(self, model: nn.Sequential, selection_layers):
        super().__init__()

        for el in selection_layers:
            if el[0] < 0 or el[0] > len(list(model)) - 1:
                raise ValueError("Selection layers must be in the range [1, num_operations]")

        layers = list(model)

        for idx, layer in sorted(selection_layers, key=lambda x: x[0], reverse=True):
            layers.insert(idx, layer)

        self.sel_layer_indices = [el[0] + i for i, el in enumerate(selection_layers)]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [38]:
import torch.optim as optim

lambda_reg = 0.0001
device = 'cpu'
criterion = nn.CrossEntropyLoss()
model = ModelFeatureSelection(nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 2)
), [(0, STGLayer(30)), (2, STGLayer(10))])

optimizer = optim.Adam([
    {"params": model.layers[i].parameters(), "lr": 1e-2} if i in model.sel_layer_indices else {"params": model.layers[i].parameters(), "lr": 1e-3, "weight_decay": 1e-4} for i in range(len(list(model.layers)))
])

In [33]:
def train_epoch(X_train, y_train, X_val, y_val):
    """Train for one epoch."""
    model.train()

    optimizer.zero_grad()

    predictions = model(X_train)

    classification_loss = criterion(predictions, y_train)
    regularization_loss = sum(model.layers[i].regularization_loss() for i in model.sel_layer_indices)
    total_loss = classification_loss + lambda_reg * regularization_loss

    total_loss.backward()

    torch.nn.utils.clip_grad_norm_(
        list(model.parameters()),
        max_norm=1.0
    )

    optimizer.step()

    model.eval()

    with torch.no_grad():
        val_predictions = model(X_val)
        val_loss = criterion(val_predictions, y_val)
        val_acc = (val_predictions.argmax(1) == y_val).float().mean().item() * 100
        sel_count = sum(model.layers[i].get_selected_features().sum() for i in model.sel_layer_indices)

    return {
        'train_loss': total_loss.item(),
        'val_loss': val_loss.item(),
        'val_acc': val_acc,
        'sel_count': sel_count,
        'reg_loss': regularization_loss.item()
    }

In [32]:
def fit(X_train, y_train, X_val, y_val, epochs=300,
            patience=50, verbose=False):
        """
        Train the model with early stopping.
        """
        best_val_acc = 0
        wait = 0
        history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'sel_count': [], 'reg_loss': []}
        best_state = {}

        for epoch in range(epochs):
            metrics = train_epoch(X_train, y_train, X_val, y_val)

            for key, value in metrics.items():
                history[key].append(value)

            if metrics['val_acc'] > best_val_acc:
                best_val_acc = metrics['val_acc']
                wait = 0
                best_state = {
                    'model': model.state_dict(),
                    'epoch': epoch,
                    'val_acc': best_val_acc,
                    'sel_count': metrics['sel_count']
                }
            else:
                wait += 1

            if wait >= patience and epoch >= 100:
                if verbose:
                    print(f"Early stopping at epoch {epoch+1}")
                break

            if verbose and (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch+1}: "
                      f"val_acc={metrics['val_acc']:.2f}%, "
                      f"sel={metrics['sel_count']}, "
                      f"Î»={lambda_reg:.4f}")

        if best_state:
            model.load_state_dict(best_state['model'])

        return history

In [13]:
class DatasetLoader:
    """Load and prepare datasets for benchmarking."""

    @staticmethod
    def load_breast_cancer():
        """Load breast cancer dataset."""
        data = load_breast_cancer()
        return {
            'name': 'Breast Cancer',
            'X': data.data,
            'y': data.target,
            'n_important': 10,
            'description': 'Binary classification, 30 features'
        }

    @staticmethod
    def load_wine():
        """Load wine dataset."""
        data = load_wine()
        return {
            'name': 'Wine',
            'X': data.data,
            'y': data.target,
            'n_important': 7,
            'description': '3-class classification, 13 features'
        }

    @staticmethod
    def create_synthetic_high_dim():
        """Create synthetic high-dimensional dataset (MADELON-like)."""
        X, y = make_classification(
            n_samples=600,
            n_features=100,
            n_informative=5,
            n_redundant=10,
            n_repeated=0,
            n_classes=2,
            n_clusters_per_class=2,
            flip_y=0.03,
            class_sep=1.0,
            random_state=42
        )
        return {
            'name': 'Synthetic-HighDim',
            'X': X,
            'y': y,
            'n_important': 5,
            'description': 'Binary classification, 100 features, 5 informative'
        }

    @staticmethod
    def create_synthetic_correlated():
        """Create synthetic dataset with correlated features."""
        np.random.seed(42)
        n_samples = 500
        n_informative = 5
        n_total = 50

        X_inform = np.random.randn(n_samples, n_informative)

        X_redundant = []
        for i in range(n_informative):
            for _ in range(2):
                noise = np.random.randn(n_samples) * 0.1
                X_redundant.append(X_inform[:, i] + noise)
        X_redundant = np.column_stack(X_redundant)

        n_noise = n_total - n_informative - X_redundant.shape[1]
        X_noise = np.random.randn(n_samples, n_noise)

        X = np.column_stack([X_inform, X_redundant, X_noise])

        y = (X_inform[:, 0] + X_inform[:, 1] * X_inform[:, 2] > 0).astype(int)

        return {
            'name': 'Synthetic-Correlated',
            'X': X,
            'y': y,
            'n_important': n_informative,
            'description': f'Binary classification, {n_total} features, {n_informative} informative with correlated copies'
        }

In [36]:
dataset = DatasetLoader().load_breast_cancer()
random_state = 42

X, y = dataset['X'], dataset['y']
n_features = X.shape[1]
n_classes = len(np.unique(y))

X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=random_state, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.2, random_state=random_state, stratify=y_temp
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

X_train_t = torch.FloatTensor(X_train).to(device)
y_train_t = torch.LongTensor(y_train).to(device)
X_val_t = torch.FloatTensor(X_val).to(device)
y_val_t = torch.LongTensor(y_val).to(device)
X_test_t = torch.FloatTensor(X_test).to(device)
y_test_t = torch.LongTensor(y_test).to(device)

In [39]:
fit(X_train_t, y_train_t, X_val_t, y_val_t)

{'train_loss': [0.9246473908424377,
  0.9194371104240417,
  0.9222840070724487,
  0.92277592420578,
  0.920691192150116,
  0.9165613651275635,
  0.915504515171051,
  0.92184978723526,
  0.9157575964927673,
  0.9109681844711304,
  0.9113531708717346,
  0.9122119545936584,
  0.9148720502853394,
  0.9090137481689453,
  0.9081457257270813,
  0.9051733016967773,
  0.9075988531112671,
  0.9025763869285583,
  0.9009677171707153,
  0.9006478190422058,
  0.9036339521408081,
  0.9004407525062561,
  0.8988052010536194,
  0.901374876499176,
  0.9007357358932495,
  0.8943776488304138,
  0.8972218036651611,
  0.8935827016830444,
  0.892792284488678,
  0.8854230046272278,
  0.8862276077270508,
  0.8820794820785522,
  0.8795429468154907,
  0.8821713924407959,
  0.8828898668289185,
  0.8850051760673523,
  0.8800486922264099,
  0.8800867199897766,
  0.8775691390037537,
  0.8737404346466064,
  0.8741771578788757,
  0.8746575713157654,
  0.8701006770133972,
  0.8713622689247131,
  0.8634679317474365,
  0.