In [1]:
import warnings
from tqdm import TqdmWarning

# Silence two warning categories for the rest of the notebook. The filters are
# registered in the same order as before: UserWarning first, then TqdmWarning.
for _ignored_category in (UserWarning, TqdmWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

In [2]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

import os

# Dataset location. Defaults to the path this notebook has always used, but it
# can be overridden through the SPUCO_DATA_ROOT environment variable so the
# notebook also runs on machines where that absolute path is not usable.
# An unset or empty variable falls back to the default.
data_root = os.environ.get("SPUCO_DATA_ROOT") or "/data/mnist/"

# Defaults from the notebook examples: five classes, each a pair of labels.
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

# Training split; the spurious correlation strength is set explicitly here.
trainset = SpuCoMNIST(
    root=data_root,
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.995,
    classes=classes,
    split="train",
)
trainset.initialize()

# Test split built from the same root, difficulty and classes. No correlation
# strength is passed, so whatever default SpuCoMNIST defines is used.
testset = SpuCoMNIST(
    root=data_root,
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test",
)
testset.initialize()

## Step 1:
Train a model using empirical risk minimization (ERM).

In [3]:
from spuco.robust_train import ERM
from torch.optim import SGD
from spuco.models import model_factory
import torch

# Use CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LeNet whose input shape is read from element 0 of the first training item and
# whose output size is the dataset's number of classes.
model = model_factory(
    arch="lenet",
    input_shape=trainset[0][0].shape,
    num_classes=trainset.num_classes,
)
model = model.to(device)

# For simplicity, same optimizer as in sample notebooks
erm_optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

# Plain ERM: a single epoch over the training set in batches of 64.
erm = ERM(
    model=model,
    trainset=trainset,
    optimizer=erm_optimizer,
    batch_size=64,
    num_epochs=1,
    device=device,
    verbose=False,
)
erm.train()

##### Evaluation before balancing sub-classes

In [4]:
from spuco.evaluate import Evaluator

# Evaluate the ERM-trained model on the test set. The group partition comes
# from the test set, while the group weights are taken from the training set.
evaluator = Evaluator(
    model=model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=False,
)

# Last expression of the cell, so the returned per-group results are displayed.
evaluator.evaluate()

Evaluating group-wise accuracy: 100%|██████████| 25/25 [07:05<00:00, 17.00s/it]


{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 100.0}

In [5]:
# Summary numbers for the ERM model. Each value is read and printed in turn, so
# the output order stays: worst group, average accuracy, spurious-attribute
# prediction result.
erm_worst_group = evaluator.worst_group_accuracy
print(erm_worst_group)

erm_average_accuracy = evaluator.average_accuracy
print(erm_average_accuracy)

erm_spurious_prediction = evaluator.evaluate_spurious_attribute_prediction()
print(erm_spurious_prediction)

((0, 1), 0.0)
99.49379218398467
20.27


## Step 2:
Cluster the training inputs within each class, based on the outputs (logits) that the ERM-trained model produces for them.

In [6]:
# Collect the ERM trainer's outputs (logits) over the training set; they are
# the features passed to the clustering step.
erm_trainer = erm.trainer
logits = erm_trainer.get_trainset_outputs()

In [7]:
from spuco.group_inference import Cluster

# Upper bound on the number of clusters; k in [2, 10] pg. 23
max_clusters = 10

# Cluster the ERM logits, using the training labels as the class labels.
cluster = Cluster(
    Z=logits,
    class_labels=trainset.labels,
    max_clusters=max_clusters,
    device=device,
    verbose=False,
)

# Uses silhouette scores per superclass to get k values, pg. 6, pg. 23
group_partition = cluster.infer_groups()

In [8]:
# Distribution of subclasses: one line per inferred group, in sorted key order,
# showing the key and how many entries the group holds.
for group_key in sorted(group_partition):
    group_size = len(group_partition[group_key])
    print(group_key, group_size)

(0, 0) 10081
(0, 1) 39
(0, 2) 13
(1, 0) 9623
(1, 1) 49
(2, 0) 8965
(2, 1) 15
(2, 2) 10
(2, 3) 13
(2, 4) 8
(3, 0) 9710
(3, 1) 13
(3, 2) 24
(4, 0) 48
(4, 1) 9393


## Step 3:
Retrain using "Group-Balancing" so that every inferred group is equally represented in each batch.

In [9]:
from spuco.robust_train import GroupBalanceBatchERM

# New SGD optimizer with the same hyperparameters as the ERM step.
# NOTE(review): `model` is the network already trained by ERM earlier in the
# notebook, so this pass continues from those weights rather than from a new
# initialization. Confirm that this is intended.
balance_optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

# Rebalance each batch using group partition
gpb_erm = GroupBalanceBatchERM(
    model=model,
    trainset=trainset,
    group_partition=group_partition,
    optimizer=balance_optimizer,
    batch_size=64,
    num_epochs=1,
    device=device,
    verbose=False,
)

gpb_erm.train()

##### Evaluation after balancing sub-classes

In [10]:
# Re-evaluate the model after the group-balanced training pass, with the same
# settings as the earlier evaluation. This rebinds `evaluator`.
evaluator = Evaluator(
    model=model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=False,
)

# Last expression of the cell, so the returned per-group results are displayed.
evaluator.evaluate()

Evaluating group-wise accuracy: 100%|██████████| 25/25 [07:00<00:00, 16.83s/it]


{(0, 0): 97.16312056737588,
 (0, 1): 71.63120567375887,
 (0, 2): 84.16075650118204,
 (0, 3): 74.46808510638297,
 (0, 4): 94.56264775413712,
 (1, 0): 90.2200488997555,
 (1, 1): 99.26650366748166,
 (1, 2): 53.431372549019606,
 (1, 3): 75.24509803921569,
 (1, 4): 58.8235294117647,
 (2, 0): 54.93333333333333,
 (2, 1): 46.666666666666664,
 (2, 2): 95.2,
 (2, 3): 80.8,
 (2, 4): 60.160427807486634,
 (3, 0): 52.26130653266332,
 (3, 1): 49.87405541561713,
 (3, 2): 90.42821158690177,
 (3, 3): 95.46599496221663,
 (3, 4): 7.8085642317380355,
 (4, 0): 71.28463476070529,
 (4, 1): 42.821158690176325,
 (4, 2): 45.08816120906801,
 (4, 3): 74.4949494949495,
 (4, 4): 94.94949494949495}

In [11]:
# Summary numbers after group-balanced training. Each value is read and printed
# in turn, so the output order stays: worst group, average accuracy,
# spurious-attribute prediction result.
balanced_worst_group = evaluator.worst_group_accuracy
print(balanced_worst_group)

balanced_average_accuracy = evaluator.average_accuracy
print(balanced_average_accuracy)

balanced_spurious_prediction = evaluator.evaluate_spurious_attribute_prediction()
print(balanced_spurious_prediction)

((3, 4), 7.8085642317380355)
96.28908895570329
20.36
