In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms

from suitability.filter.suitability_filter import get_sf_features, SuitabilityFilter

In [None]:
NUM_CLASSES = 10  # MNIST has 10 classes (digits 0-9)
IMG_SIZE = 28     # MNIST images are 28x28 pixels
BATCH_SIZE = 64   # batch size used by every DataLoader in this notebook
SEED = 0          # seed for torch and numpy so Restart & Run All is reproducible

# Seed before anything stochastic (model initialisation, shuffled loaders) runs
# in later cells; otherwise correctness labels and p-values change on every run.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Sizes of the three suitability-filter (SF) data groups
N_CLASSIFIER_SF = 1000  # taken from the MNIST train split
N_TEST_SF = 500         # taken from the MNIST test split
N_USER_SF = 500         # taken from the MNIST test split, right after the test group

### DATA PREPARATION

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # Mean and std for MNIST
])

mnist_train_full = MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test_full = MNIST(root='./data', train=False, download=True, transform=transform)

# Create subsets for our three data groups for the suitability filter.
# classifier_loader_sf: data used to train the prediction-correctness classifier
classifier_indices_sf = list(range(0, N_CLASSIFIER_SF))
classifier_dataset_sf = Subset(mnist_train_full, classifier_indices_sf)
classifier_loader_sf = DataLoader(classifier_dataset_sf, batch_size=BATCH_SIZE, shuffle=False)

# test_loader_sf: reference data the user data is compared against
test_indices_sf = list(range(0, N_TEST_SF))
test_dataset_sf = Subset(mnist_test_full, test_indices_sf)
test_loader_sf = DataLoader(test_dataset_sf, batch_size=BATCH_SIZE, shuffle=False)

# user_loader_sf: the "new" data whose suitability is being tested
user_indices_sf = list(range(N_TEST_SF, N_TEST_SF + N_USER_SF))
user_dataset_sf = Subset(mnist_test_full, user_indices_sf)
user_loader_sf = DataLoader(user_dataset_sf, batch_size=BATCH_SIZE, shuffle=False)

# Test and user groups both come from MNIST test, so they must not overlap.
assert not set(test_indices_sf) & set(user_indices_sf), "test and user SF splits overlap"

print(f"Classifier SF data: {len(classifier_dataset_sf)} samples from MNIST train")
print(f"Test SF data: {len(test_dataset_sf)} samples from MNIST test")
print(f"User SF data: {len(user_dataset_sf)} samples from MNIST test")

### MODEL DEFINITION

class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel square images.

    Each stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool, so the spatial
    side length is divided by 4 overall before the linear classification head.

    Args:
        num_classes: number of output logits.
        img_size: side length of the square input images; 28 matches MNIST.
    """

    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)  # MNIST is 1 channel
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 14x14 -> 7x7
        # Two floor-halving pools: side // 2 // 2 == side // 4 (28 -> 7).
        self.fc = nn.Linear(32 * (img_size // 4) ** 2, num_classes)

    def forward(self, x):
        """Map a (batch, 1, img_size, img_size) tensor to (batch, num_classes) logits."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension fixed, so an input
        # of the wrong spatial size fails in self.fc instead of being silently
        # reshaped into a different batch size.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

### FEATURE EXTRACTION

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=NUM_CLASSES).to(device)

# Train the primary model briefly before measuring its correctness. Without this
# step the weights are random, every correctness label below is near chance level,
# and the suitability filter ends up judging an untrained classifier.
# Training images come from MNIST train but start after classifier_indices_sf, so
# the correctness classifier never sees data the primary model was fit on.
MODEL_TRAIN_SIZE = 10_000  # kept small so Restart & Run All stays fast
MODEL_TRAIN_EPOCHS = 1
MODEL_LR = 1e-3

model_train_start = max(classifier_indices_sf) + 1
model_train_indices = list(range(model_train_start, min(model_train_start + MODEL_TRAIN_SIZE, len(mnist_train_full))))
assert not set(model_train_indices) & set(classifier_indices_sf), "primary-model training data overlaps the SF classifier data"
model_train_loader = DataLoader(Subset(mnist_train_full, model_train_indices), batch_size=BATCH_SIZE, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=MODEL_LR)
loss_fn = nn.CrossEntropyLoss()
model.train()
for _ in range(MODEL_TRAIN_EPOCHS):
    for images, labels in model_train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
model.eval()
print(f"Trained primary model on {len(model_train_indices)} MNIST train samples for {MODEL_TRAIN_EPOCHS} epoch(s)")

classifier_feats, classifier_corr = get_sf_features(classifier_loader_sf, model, device)
print(f"Shape of classifier_features: {classifier_feats.shape}")
print(f"SF data correctness: {np.sum(classifier_corr)} correct out of {len(classifier_corr)} (approx accuracy: {np.mean(classifier_corr):.2f})")

test_feats, test_corr = get_sf_features(test_loader_sf, model, device)
print(f"Shape of test_features: {test_feats.shape}")
print(f"Test correctness: {np.sum(test_corr)} correct out of {len(test_corr)} (approx accuracy: {np.mean(test_corr):.2f})")

user_feats, _ = get_sf_features(user_loader_sf, model, device)
print(f"Shape of user_features: {user_feats.shape}")

### SUITABILITY FILTER

# Decision settings for the non-inferiority test.
test_margin = 0.01  # non-inferiority margin passed to suitability_test (units assumed to be accuracy — TODO confirm)
alpha = 0.05        # significance level applied to the returned p-value

sf_filter = SuitabilityFilter(
    test_features=test_feats,
    test_corr=test_corr, # Correctness of primary model on SF's "test" data
    classifier_features=classifier_feats,
    classifier_corr=classifier_corr, # Correctness of primary model on SF's "classifier training" data
    device=device,
    normalize=True
)

sf_filter.train_classifier(classifier_name="logistic_regression", calibrated=True)

results = sf_filter.suitability_test(user_features=user_feats, margin=test_margin)


def _format_result_value(value):
    """Render one suitability-test result value for display.

    Booleans (Python or NumPy) are checked first because ``bool`` is a subclass
    of ``int`` and would otherwise be printed as ``1.0000``. Python and NumPy
    floats get four decimals; everything else (ints, strings, arrays) uses its
    default string form.
    """
    if isinstance(value, (bool, np.bool_)):
        return str(value)
    if isinstance(value, (float, np.floating)):
        return f"{value:.4f}"
    return f"{value}"


print("\nSuitability Test Results:")
for key, value in results.items():
    print(f"  {key}: {_format_result_value(value)}")

p_value = results['p_value']
if p_value < alpha:
    print(f"\nUser data IS considered non-inferior (p={p_value:.4f} < {alpha}). The new data is not significantly worse than the test data by more than the margin.")
else:
    print(f"\nUser data is NOT proven non-inferior (p={p_value:.4f} >= {alpha}). We cannot conclude that the new data is within the non-inferiority margin of the test data.")