In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
from torch.utils.data import DataLoader

# Define classes and spurious feature difficulty
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

root_dir = "./mnist_data/" 

# Initialize the train dataset
trainset = SpuCoMNIST(
    root=root_dir,
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.995,
    classes=classes,
    split="train"
)
trainset.initialize()

# Initialize the test dataset
testset = SpuCoMNIST(
    root=root_dir,
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test"
)
testset.initialize()

# Dataloader
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = DataLoader(testset, batch_size=64, shuffle=False)

# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(3 * 28 * 28, 128)  # For RGB images with size 28x28
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define train function
def train(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0
    for images, labels in loader:  # Expect only 2 values: images and labels
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model with ERM
epochs = 100
for epoch in range(epochs):
    train_loss = train(model, train_loader, criterion, optimizer)
    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {train_loss:.4f}')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|█████████████████████████████| 9912422/9912422 [00:16<00:00, 615852.62it/s]


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 28881/28881 [00:00<00:00, 371597.84it/s]


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 2637392.30it/s]


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 1013595.57it/s]


Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



100%|██████████████████████████████████| 48004/48004 [00:00<00:00, 51750.53it/s]
100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 50032.79it/s]


Epoch [1/100], Loss: 0.0504
Epoch [2/100], Loss: 0.0339
Epoch [3/100], Loss: 0.0302
Epoch [4/100], Loss: 0.0268
Epoch [5/100], Loss: 0.0246
Epoch [6/100], Loss: 0.0227
Epoch [7/100], Loss: 0.0208
Epoch [8/100], Loss: 0.0183
Epoch [9/100], Loss: 0.0180
Epoch [10/100], Loss: 0.0170
Epoch [11/100], Loss: 0.0141
Epoch [12/100], Loss: 0.0139
Epoch [13/100], Loss: 0.0118
Epoch [14/100], Loss: 0.0109
Epoch [15/100], Loss: 0.0110
Epoch [16/100], Loss: 0.0093
Epoch [17/100], Loss: 0.0092
Epoch [18/100], Loss: 0.0086
Epoch [19/100], Loss: 0.0068
Epoch [20/100], Loss: 0.0077
Epoch [21/100], Loss: 0.0084
Epoch [22/100], Loss: 0.0057
Epoch [23/100], Loss: 0.0060
Epoch [24/100], Loss: 0.0058
Epoch [25/100], Loss: 0.0062
Epoch [26/100], Loss: 0.0058
Epoch [27/100], Loss: 0.0049
Epoch [28/100], Loss: 0.0057
Epoch [29/100], Loss: 0.0038
Epoch [30/100], Loss: 0.0034
Epoch [31/100], Loss: 0.0060
Epoch [32/100], Loss: 0.0035
Epoch [33/100], Loss: 0.0026
Epoch [34/100], Loss: 0.0068
Epoch [35/100], Loss: 0

In [7]:
import torch
from sklearn.cluster import KMeans
import numpy as np
np.__config__.show()

# Faulty threadpool fix
from threadpoolctl import threadpool_limits
import os
os.environ["OPENBLAS_NUM_THREADS"] = "1"

# Extract embeddings from model
def extract_embeddings(model, loader):
    model.eval()  # Set model to evaluation mode
    embeddings = []
    labels_list = []
    with torch.no_grad():  # Disable gradient computation
        for images, labels in loader:
            # Flatten the images
            x = images.view(images.size(0), -1) 
            embedding = model.fc1(x)  
            embeddings.append(embedding.cpu()) 
            labels_list.append(labels.cpu()) 
    return torch.cat(embeddings), torch.cat(labels_list)


openblas64__info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None), ('BLAS_SYMBOL_SUFFIX', '64_'), ('HAVE_BLAS_ILP64', None)]
    runtime_library_dirs = ['/usr/local/lib']
blas_ilp64_opt_info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None), ('BLAS_SYMBOL_SUFFIX', '64_'), ('HAVE_BLAS_ILP64', None)]
    runtime_library_dirs = ['/usr/local/lib']
openblas64__lapack_info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None), ('BLAS_SYMBOL_SUFFIX', '64_'), ('HAVE_BLAS_ILP64', None), ('HAVE_LAPACKE', None)]
    runtime_library_dirs = ['/usr/local/lib']
lapack_ilp64_opt_info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None

In [8]:
# Extract embeddings from the trained model
train_embeddings, train_labels = extract_embeddings(model, train_loader)

# Convert embeddings to NumPy format for KMeans clustering
train_embeddings_np = train_embeddings.numpy()  # Convert to NumPy array
print(train_embeddings_np.shape)


(48004, 128)


In [9]:
# Check the number of samples in the dataset
num_samples_in_dataset = len(train_loader.dataset)
print(f"Number of samples in the dataset: {num_samples_in_dataset}")


Number of samples in the dataset: 48004


In [10]:
num_samples = train_embeddings.shape[0]
print("Number of samples in train_embeddings:", num_samples)


Number of samples in train_embeddings: 48004


In [11]:
# K-Means clustering with threadpool control
n_clusters = 3  # 
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(train_embeddings_np)

# Assign each sample to a cluster
train_clusters = kmeans.labels_

# Print the cluster assignments
print("Cluster assignments:", train_clusters)

Cluster assignments: [1 2 0 ... 2 1 0]


In [12]:
from torch.utils.data import Subset

# Create group-wise datasets
# Use the cluster assignments from K-Means
train_groups = kmeans.labels_

# Create group-wise datasets
group_indices = [[] for _ in range(n_clusters)]
for idx, group in enumerate(train_groups):
    group_indices[group].append(idx)

group_dataloaders = [DataLoader(Subset(trainset, indices), batch_size=64, shuffle=True) for indices in group_indices]

# Balanced group training loop
import itertools

def balanced_train(model, group_dataloaders, criterion, optimizer, max_steps):
    model.train()
    group_iters = [iter(loader) for loader in group_dataloaders]
    group_cycle = itertools.cycle(enumerate(group_iters))
    
    steps = 0
    while steps < max_steps:
        try:
            i, group_iter = next(group_cycle)
            images, labels = next(group_iter)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            steps += 1
        except StopIteration:
            # Reset the exhausted iterator
            group_iters[i] = iter(group_dataloaders[i])
            continue

# Retrain using group-balanced batches
max_steps_per_epoch = len(trainset) // 64  # Adjust based on batch size
for epoch in range(epochs):
    balanced_train(model, group_dataloaders, criterion, optimizer, max_steps=max_steps_per_epoch)


In [13]:
def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return total_loss / len(loader), accuracy
    
for epoch in range(epochs):
    balanced_train(model, group_dataloaders, criterion, optimizer, max_steps=max_steps_per_epoch)
    test_loss, test_accuracy = evaluate(model, test_loader, criterion)
    print(f'Epoch [{epoch + 1}/{epochs}], Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')


Epoch [1/100], Test Loss: 10.0623, Test Accuracy: 37.80%
Epoch [2/100], Test Loss: 10.1812, Test Accuracy: 38.31%
Epoch [3/100], Test Loss: 10.3076, Test Accuracy: 37.30%
Epoch [4/100], Test Loss: 10.2932, Test Accuracy: 38.56%
Epoch [5/100], Test Loss: 10.3994, Test Accuracy: 38.30%
Epoch [6/100], Test Loss: 10.3980, Test Accuracy: 38.86%
Epoch [7/100], Test Loss: 10.8612, Test Accuracy: 38.11%
Epoch [8/100], Test Loss: 10.6602, Test Accuracy: 39.14%
Epoch [9/100], Test Loss: 10.4847, Test Accuracy: 39.04%
Epoch [10/100], Test Loss: 10.5360, Test Accuracy: 39.29%
Epoch [11/100], Test Loss: 10.8167, Test Accuracy: 38.73%
Epoch [12/100], Test Loss: 10.7121, Test Accuracy: 39.30%
Epoch [13/100], Test Loss: 10.9572, Test Accuracy: 38.56%
Epoch [14/100], Test Loss: 10.9274, Test Accuracy: 39.08%
Epoch [15/100], Test Loss: 11.0578, Test Accuracy: 38.60%
Epoch [16/100], Test Loss: 11.1447, Test Accuracy: 38.88%
Epoch [17/100], Test Loss: 11.3127, Test Accuracy: 38.92%
Epoch [18/100], Test Lo

In [14]:
import os
notebook_path = os.path.abspath("")
print("Notebook path:", notebook_path)

Notebook path: /Users/andrewsuh/Desktop/Models
