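# Suitability Filter Example

This notebook trains a small CNN on MNIST and then uses a suitability filter to test, without user labels, whether the model's accuracy on user data is non-inferior to its accuracy on labeled test data.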

In [19]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms

from filter.suitability_filter import get_sf_features, SuitabilityFilter

## Prepare the MNIST data and train the model

### Data preparation

In [20]:
NUM_CLASSES = 10      # MNIST has 10 classes (digits 0-9)
IMG_SIZE = 28         # MNIST images are 28x28 pixels
BATCH_SIZE = 64       # Mini-batch size used by the DataLoaders below
DATA_ROOT = './data'  # Directory where torchvision stores/looks for MNIST

# Leading prefix of each MNIST split that is kept away from the CNN so the
# suitability filter can use it later: the first SF_CLASSIFIER_SIZE train
# images and the first SF_USER_SIZE test images.
SF_CLASSIFIER_SIZE = 1000
SF_USER_SIZE = 5000

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # Mean and std for MNIST
])

mnist_train_full = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
mnist_test_full = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# CNN training data: every train image after the held-out SF classifier prefix.
train_indices = list(range(SF_CLASSIFIER_SIZE, len(mnist_train_full)))
train_data = Subset(mnist_train_full, train_indices)
cnn_train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation data: every test image after the held-out user prefix. It is used
# both to measure CNN accuracy and as the filter's reference "test" set.
test_indices = list(range(SF_USER_SIZE, len(mnist_test_full)))
test_data = Subset(mnist_test_full, test_indices)
cnn_test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Fail early if a held-out prefix would swallow an entire split.
assert len(train_data) > 0 and len(test_data) > 0, "held-out prefixes leave no train/test data"

### Model definition

In [21]:
class SimpleCNN(nn.Module):
    """Small two-stage CNN: (conv3x3 -> ReLU -> maxpool2) x 2 -> linear head.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images (1 for MNIST).
        img_size: Height/width of the square input images (28 for MNIST).
            Each 2x2/stride-2 max-pool halves it (rounding down), so the
            linear head sees a 32 x (img_size//2//2) x (img_size//2//2) map.
    """

    def __init__(self, num_classes=10, in_channels=1, img_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # H x W -> H/2 x W/2
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/2 x W/2 -> H/4 x W/4
        pooled = (img_size // 2) // 2  # 28 -> 14 -> 7 for MNIST
        self.fc = nn.Linear(32 * pooled * pooled, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Expects x of shape (batch, in_channels, img_size, img_size).
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten all but the batch dim. Unlike view(-1, 32*7*7), a wrongly
        # sized input now fails in self.fc instead of being silently regrouped
        # into a different number of "samples".
        x = torch.flatten(x, 1)
        return self.fc(x)

### Model training

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters for CNN training
SEED = 0               # Seeds weight init and the shuffle order of cnn_train_loader
LEARNING_RATE = 0.0001
EPOCHS = 1
LOG_EVERY = 200        # Print the mean training loss every LOG_EVERY batches

# Seed before building the model so that both the initial weights and the
# DataLoader's shuffling (drawn from torch's global RNG) are reproducible.
torch.manual_seed(SEED)
model = SimpleCNN(num_classes=NUM_CLASSES).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"\nStarting SimpleCNN training on {device}...")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(cnn_train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % LOG_EVERY == 0:
            # Mean loss over the last LOG_EVERY batches.
            print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(cnn_train_loader)}], Loss: {running_loss/LOG_EVERY:.4f}")
            running_loss = 0.0
print("SimpleCNN training finished.")

# Evaluate the trained CNN on the held-out test subset (test_data / cnn_test_loader),
# not on the full MNIST test split: the leading test images are reserved as
# suitability-filter user data.
model.eval()
correct_cnn = 0
total_cnn = 0
with torch.no_grad():
    for images, labels in cnn_test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
        total_cnn += labels.size(0)
        correct_cnn += (predicted == labels).sum().item()

cnn_accuracy = 100 * correct_cnn / total_cnn
print(f"\nAccuracy of the trained SimpleCNN on the {total_cnn} test images: {cnn_accuracy:.2f}%")


Starting SimpleCNN training on cuda...


Epoch [1/1], Step [200/922], Loss: 1.4744
Epoch [1/1], Step [400/922], Loss: 0.5030
Epoch [1/1], Step [600/922], Loss: 0.3252
Epoch [1/1], Step [800/922], Loss: 0.2753
SimpleCNN training finished.

Accuracy of the trained SimpleCNN on the 5000 test images: 95.86%


## Suitability Filter

### Define datasets: suitability-filter data (to train the prediction-correctness classifier) and user data

In [35]:
# classifier_loader_sf: data used to train the prediction-correctness classifier.
classifier_indices_sf = list(range(0, 1000))  # Must be disjoint from train_data (the CNN's training subset)
classifier_dataset_sf = Subset(mnist_train_full, classifier_indices_sf)
classifier_loader_sf = DataLoader(classifier_dataset_sf, batch_size=BATCH_SIZE, shuffle=False)

# user_loader_sf: user data to be tested (its labels are only used for the
# "would not be known in practice" sanity check).
user_indices_sf = list(range(0, 5000))  # Must be disjoint from test_data (the reference test subset)
user_dataset_sf = Subset(mnist_test_full, user_indices_sf)
user_loader_sf = DataLoader(user_dataset_sf, batch_size=BATCH_SIZE, shuffle=False)

# Enforce the separation claimed in the messages below: the SF splits must not
# leak into the data the CNN was trained on or the reference test data.
assert set(classifier_indices_sf).isdisjoint(train_indices), "SF classifier data overlaps CNN training data"
assert set(user_indices_sf).isdisjoint(test_indices), "SF user data overlaps the reference test data"

print(f"Classifier SF data: {len(classifier_dataset_sf)} samples from MNIST train (not used for model training)")
print(f"User SF data: {len(user_dataset_sf)} samples from MNIST test")

print(f"Test data: {len(test_data)} (different) samples from MNIST test")

Classifier SF data: 1000 samples from MNIST train (not used for model training)
User SF data: 5000 samples from MNIST test
Test data: 5000 (different) samples from MNIST test


### Feature extraction

In [36]:
# Extract features on the device the trained model already lives on, instead
# of re-deriving it independently, so the two can never disagree.
device = next(model.parameters()).device


def _correctness_summary(corr):
    """Format a correctness vector as 'k correct out of n (approx accuracy: a)'."""
    return f"{np.sum(corr)} correct out of {len(corr)} (approx accuracy: {np.mean(corr):.2f})"


classifier_feats, classifier_corr = get_sf_features(classifier_loader_sf, model, device)
print(f"Shape of classifier_features: {classifier_feats.shape}")
print(f"SF data correctness: {_correctness_summary(classifier_corr)}")

user_feats, user_corr = get_sf_features(user_loader_sf, model, device)
print(f"Shape of user_features: {user_feats.shape}")

test_feats, test_corr = get_sf_features(cnn_test_loader, model, device)
print(f"Shape of test_features: {test_feats.shape}")
print(f"Test correctness: {_correctness_summary(test_corr)}")

# The filter compares these sets feature-by-feature, so their feature
# dimension must agree; each sample also needs exactly one correctness label.
# (Assumes features are 2-D (n_samples, n_features), as the shapes printed above show.)
assert classifier_feats.shape[1] == user_feats.shape[1] == test_feats.shape[1], "feature dimensions differ across splits"
assert len(classifier_corr) == classifier_feats.shape[0], "classifier features/correctness length mismatch"
assert len(test_corr) == test_feats.shape[0], "test features/correctness length mismatch"

print("--> THIS PART WOULD NOT BE KNOWN IN PRACTICE (NO ACCESS TO USER LABELS):")
print(f"    User correctness: {_correctness_summary(user_corr)}")

Shape of classifier_features: (1000, 12)
SF data correctness: 931 correct out of 1000 (approx accuracy: 0.93)
Shape of user_features: (5000, 12)
Shape of test_features: (5000, 12)
Test correctness: 4793 correct out of 5000 (approx accuracy: 0.96)
--> THIS PART WOULD NOT BE KNOWN IN PRACTICE (NO ACCESS TO USER LABELS):
    User correctness: 4601 correct out of 5000 (approx accuracy: 0.92)


### Suitability filter test

In [41]:
suitability_filter = SuitabilityFilter(
    test_features=test_feats,
    test_corr=test_corr,  # Correctness of primary model on SF's "test" data
    classifier_features=classifier_feats,
    classifier_corr=classifier_corr,  # Correctness of primary model on SF's "classifier training" data
    device=device,
    normalize=True
)

suitability_filter.train_classifier(classifier="logistic_regression", calibrated=True)
test_margin = 0  # Non-inferiority margin: how much worse user data may be and still pass

results = suitability_filter.suitability_test(user_features=user_feats, margin=test_margin)


def _format_result_value(value):
    """Format one suitability-test result for display.

    Real numbers (Python or NumPy) are shown with 4 decimals; everything else
    uses its default formatting. Booleans are checked first: Python's bool
    subclasses int and would otherwise be printed as 0.0000 / 1.0000.
    """
    if isinstance(value, (bool, np.bool_)):
        return f"{value}"
    if isinstance(value, (int, float, np.integer, np.floating)):
        return f"{value:.4f}"
    return f"{value}"


print("\nSuitability Test Results:")
for key, value in results.items():
    print(f"  {key}: {_format_result_value(value)}")

alpha = 0.05  # Significance level for the non-inferiority decision

# NOTE(review): `results` also carries 'reject_null', computed inside
# SuitabilityFilter; confirm it uses the same significance level as `alpha`.
if results['p_value'] < alpha:
    print(f"\nUser data IS considered non-inferior (p={results['p_value']:.4f} < {alpha}). The new data is not significantly worse than the test data by more than the margin.")
else:
    print(f"\nUser data is NOT proven non-inferior (p={results['p_value']:.4f} >= {alpha}). We cannot conclude that the new data is within the non-inferiority margin of the test data.")


Suitability Test Results:
  t_statistic: 11.7601
  p_value: 1.0000
  reject_null: False

User data is NOT proven non-inferior (p=1.0000 >= 0.05). We cannot conclude that the new data is within the non-inferiority margin of the test data.


This result is expected. The model's accuracy on the user data (92% in this example) is lower than on the test data (96%), so with a margin of 0 its performance on the user data is in fact inferior, rather than non-inferior, to its performance on the test data.