# Train a model using ERM

In [2]:
#install the SpuCo package
!pip install SpuCo --user



In [1]:
import torch
from spuco.utils import set_seed
set_seed(0)

In [3]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

valset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="val"
)
valset.initialize()

100%|█████████████████████████████████████████████████████████████████████████| 11996/11996 [00:00<00:00, 12166.43it/s]


In [4]:
from spuco.robust_train import ERM
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

trainset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.99,
    classes=classes,
    split="train"
)
trainset.initialize()

testset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test"
)
testset.initialize()

100%|█████████████████████████████████████████████████████████████████████████| 48004/48004 [00:04<00:00, 11403.43it/s]
100%|██████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 9956.71it/s]


In [5]:
from spuco.models import model_factory 

model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes)

In [6]:
from torch.optim import SGD

erm = ERM(
    model=model,
    num_epochs=1,
    trainset=trainset,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    verbose=True
)

In [7]:
erm.train()

Epoch 0: 100%|█████████████████████████████████████| 751/751 [00:42<00:00, 17.68batch/s, accuracy=100.0%, loss=0.00865]


In [8]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|██                                                 | 1/25 [00:15<06:07, 15.32s/it]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:   8%|████                                               | 2/25 [00:32<06:11, 16.14s/it]

Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  12%|██████                                             | 3/25 [00:52<06:34, 17.94s/it]

Group (0, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  16%|████████▏                                          | 4/25 [01:12<06:35, 18.83s/it]

Group (0, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  20%|██████████▏                                        | 5/25 [01:30<06:13, 18.70s/it]

Group (0, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  24%|████████████▏                                      | 6/25 [01:50<05:59, 18.94s/it]

Group (1, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  28%|██████████████▎                                    | 7/25 [02:08<05:38, 18.79s/it]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|████████████████▎                                  | 8/25 [02:25<05:06, 18.04s/it]

Group (1, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  36%|██████████████████▎                                | 9/25 [02:43<04:51, 18.22s/it]

Group (1, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  40%|████████████████████                              | 10/25 [02:59<04:21, 17.43s/it]

Group (1, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  44%|██████████████████████                            | 11/25 [03:17<04:05, 17.54s/it]

Group (2, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  48%|████████████████████████                          | 12/25 [03:33<03:41, 17.06s/it]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  52%|██████████████████████████                        | 13/25 [03:49<03:22, 16.86s/it]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|████████████████████████████                      | 14/25 [04:05<03:04, 16.73s/it]

Group (2, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  60%|██████████████████████████████                    | 15/25 [04:22<02:45, 16.53s/it]

Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  64%|████████████████████████████████                  | 16/25 [04:38<02:28, 16.51s/it]

Group (3, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  68%|██████████████████████████████████                | 17/25 [04:55<02:14, 16.78s/it]

Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  72%|████████████████████████████████████              | 18/25 [05:12<01:56, 16.67s/it]

Group (3, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  76%|██████████████████████████████████████            | 19/25 [05:27<01:38, 16.35s/it]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████████████████████████████████████          | 20/25 [05:44<01:21, 16.37s/it]

Group (3, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  84%|██████████████████████████████████████████        | 21/25 [06:03<01:08, 17.12s/it]

Group (4, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  88%|████████████████████████████████████████████      | 22/25 [06:20<00:51, 17.17s/it]

Group (4, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  92%|██████████████████████████████████████████████    | 23/25 [06:36<00:33, 16.80s/it]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|████████████████████████████████████████████████  | 24/25 [06:52<00:16, 16.59s/it]

Group (4, 3) Accuracy: 1.7676767676767677


Evaluating group-wise accuracy: 100%|██████████████████████████████████████████████████| 25/25 [07:08<00:00, 17.13s/it]

Group (4, 4) Accuracy: 100.0





{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 1.7676767676767677,
 (4, 4): 100.0}

In [9]:
evaluator.worst_group_accuracy

((0, 1), 0.0)

In [10]:
evaluator.average_accuracy

98.99453349426267

In [11]:
evaluator.evaluate_spurious_attribute_prediction()

99.89

# Cluster inputs based on the output they produce for ERM

In [12]:
from spuco.utils import Trainer

model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes)
trainer = Trainer(
    trainset=trainset,
    model=model,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-3, weight_decay=5e-4, momentum=0.9, nesterov=True),
    verbose=True
)

trainer.train(1)

Epoch 0: 100%|██████████████████████████████████████| 751/751 [00:44<00:00, 16.81batch/s, accuracy=100.0%, loss=0.0113]


In [13]:
from spuco.group_inference import Cluster, ClusterAlg

logits = trainer.get_trainset_outputs()
cluster = Cluster(
    Z=logits,
    class_labels=trainset.labels,
    cluster_alg=ClusterAlg.KMEANS,
    num_clusters=2,
    verbose=True
)
group_partition = cluster.infer_groups()

Getting Trainset Outputs: 100%|███████████████████████████████████████████████████| 751/751 [00:40<00:00, 18.52batch/s]
Clustering class-wise: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.45it/s]


In [14]:
for key in sorted(group_partition.keys()):
    print(key, len(group_partition[key]))

(0, 0) 10061
(0, 1) 72
(1, 0) 9607
(1, 1) 65
(2, 0) 8936
(2, 1) 75
(3, 0) 9703
(3, 1) 44
(4, 0) 9365
(4, 1) 76


In [15]:
from spuco.evaluate import Evaluator 

evaluator = Evaluator(
    testset=trainset,
    group_partition=group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:  10%|█████                                              | 1/10 [00:39<05:57, 39.69s/it]

Group (0, 0) Accuracy: 99.70181890468145


Evaluating group-wise accuracy:  20%|██████████▏                                        | 2/10 [01:12<04:43, 35.48s/it]

Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  30%|███████████████▎                                   | 3/10 [01:52<04:22, 37.49s/it]

Group (1, 0) Accuracy: 99.66690954512335


Evaluating group-wise accuracy:  40%|████████████████████▍                              | 4/10 [02:33<03:53, 38.97s/it]

Group (1, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  50%|█████████████████████████▌                         | 5/10 [03:10<03:11, 38.34s/it]

Group (2, 0) Accuracy: 99.82094897045658


Evaluating group-wise accuracy:  60%|██████████████████████████████▌                    | 6/10 [03:41<02:22, 35.68s/it]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  70%|███████████████████████████████████▋               | 7/10 [04:15<01:45, 35.21s/it]

Group (3, 0) Accuracy: 99.44347109141502


Evaluating group-wise accuracy:  80%|████████████████████████████████████████▊          | 8/10 [04:48<01:09, 34.70s/it]

Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  90%|█████████████████████████████████████████████▉     | 9/10 [05:25<00:35, 35.32s/it]

Group (4, 0) Accuracy: 99.7971169247197


Evaluating group-wise accuracy: 100%|██████████████████████████████████████████████████| 10/10 [05:59<00:00, 35.99s/it]

Group (4, 1) Accuracy: 0.0





{(0, 0): 99.70181890468145,
 (0, 1): 0.0,
 (1, 0): 99.66690954512335,
 (1, 1): 0.0,
 (2, 0): 99.82094897045658,
 (2, 1): 0.0,
 (3, 0): 99.44347109141502,
 (3, 1): 0.0,
 (4, 0): 99.7971169247197,
 (4, 1): 0.0}

In [16]:
evaluator.worst_group_accuracy

((0, 1), 0.0)

In [17]:
evaluator.average_accuracy

21.037347177320918

In [18]:
evaluator.evaluate_spurious_attribute_prediction()

21.123239730022497

# Retrain using "Group-Balancing" to ensure each you sample batches s.t. each group appears equally

In [19]:
from torch.optim import SGD
from spuco.robust_train import GroupBalanceBatchERM, ClassBalanceBatchERM
from spuco.models import model_factory 

model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes)
group_balance_erm = GroupBalanceBatchERM(
    model=model,
    num_epochs=5,
    trainset=trainset,
    group_partition=group_partition,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-3, weight_decay=5e-4, momentum=0.9, nesterov=True),
    verbose=True
)
group_balance_erm.train()

Epoch 0: 100%|█████████████████████████████████████████| 751/751 [00:40<00:00, 18.55batch/s, accuracy=75.0%, loss=1.26]
Epoch 1: 100%|████████████████████████████████████████| 751/751 [00:46<00:00, 16.19batch/s, accuracy=75.0%, loss=0.858]
Epoch 2: 100%|███████████████████████████████████████| 751/751 [00:49<00:00, 15.18batch/s, accuracy=100.0%, loss=0.168]
Epoch 3: 100%|████████████████████████████████████████| 751/751 [00:47<00:00, 15.97batch/s, accuracy=100.0%, loss=0.43]
Epoch 4: 100%|██████████████████████████████████████| 751/751 [00:49<00:00, 15.30batch/s, accuracy=100.0%, loss=0.0261]


In [20]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|██                                                 | 1/25 [00:14<05:45, 14.41s/it]

Group (0, 0) Accuracy: 98.58156028368795


Evaluating group-wise accuracy:   8%|████                                               | 2/25 [00:27<05:18, 13.87s/it]

Group (0, 1) Accuracy: 13.238770685579196


Evaluating group-wise accuracy:  12%|██████                                             | 3/25 [00:43<05:19, 14.53s/it]

Group (0, 2) Accuracy: 93.3806146572104


Evaluating group-wise accuracy:  16%|████████▏                                          | 4/25 [01:00<05:26, 15.53s/it]

Group (0, 3) Accuracy: 85.81560283687944


Evaluating group-wise accuracy:  20%|██████████▏                                        | 5/25 [01:18<05:27, 16.38s/it]

Group (0, 4) Accuracy: 92.90780141843972


Evaluating group-wise accuracy:  24%|████████████▏                                      | 6/25 [01:34<05:11, 16.41s/it]

Group (1, 0) Accuracy: 32.02933985330073


Evaluating group-wise accuracy:  28%|██████████████▎                                    | 7/25 [01:52<05:02, 16.80s/it]

Group (1, 1) Accuracy: 88.75305623471883


Evaluating group-wise accuracy:  32%|████████████████▎                                  | 8/25 [02:09<04:47, 16.93s/it]

Group (1, 2) Accuracy: 68.38235294117646


Evaluating group-wise accuracy:  36%|██████████████████▎                                | 9/25 [02:28<04:41, 17.61s/it]

Group (1, 3) Accuracy: 77.69607843137256


Evaluating group-wise accuracy:  40%|████████████████████                              | 10/25 [02:45<04:18, 17.26s/it]

Group (1, 4) Accuracy: 67.6470588235294


Evaluating group-wise accuracy:  44%|██████████████████████                            | 11/25 [03:02<04:01, 17.22s/it]

Group (2, 0) Accuracy: 58.666666666666664


Evaluating group-wise accuracy:  48%|████████████████████████                          | 12/25 [03:18<03:41, 17.06s/it]

Group (2, 1) Accuracy: 84.8


Evaluating group-wise accuracy:  52%|██████████████████████████                        | 13/25 [03:34<03:20, 16.72s/it]

Group (2, 2) Accuracy: 97.06666666666666


Evaluating group-wise accuracy:  56%|████████████████████████████                      | 14/25 [03:52<03:06, 16.96s/it]

Group (2, 3) Accuracy: 5.066666666666666


Evaluating group-wise accuracy:  60%|██████████████████████████████                    | 15/25 [04:11<02:55, 17.58s/it]

Group (2, 4) Accuracy: 64.97326203208556


Evaluating group-wise accuracy:  64%|████████████████████████████████                  | 16/25 [04:29<02:40, 17.82s/it]

Group (3, 0) Accuracy: 80.90452261306532


Evaluating group-wise accuracy:  68%|██████████████████████████████████                | 17/25 [04:47<02:22, 17.76s/it]

Group (3, 1) Accuracy: 86.64987405541562


Evaluating group-wise accuracy:  72%|████████████████████████████████████              | 18/25 [05:04<02:04, 17.71s/it]

Group (3, 2) Accuracy: 6.801007556675063


Evaluating group-wise accuracy:  76%|██████████████████████████████████████            | 19/25 [05:22<01:46, 17.76s/it]

Group (3, 3) Accuracy: 98.48866498740554


Evaluating group-wise accuracy:  80%|████████████████████████████████████████          | 20/25 [05:42<01:31, 18.25s/it]

Group (3, 4) Accuracy: 10.075566750629722


Evaluating group-wise accuracy:  84%|██████████████████████████████████████████        | 21/25 [06:03<01:16, 19.04s/it]

Group (4, 0) Accuracy: 81.61209068010075


Evaluating group-wise accuracy:  88%|████████████████████████████████████████████      | 22/25 [06:20<00:56, 18.69s/it]

Group (4, 1) Accuracy: 66.24685138539043


Evaluating group-wise accuracy:  92%|██████████████████████████████████████████████    | 23/25 [06:38<00:36, 18.40s/it]

Group (4, 2) Accuracy: 66.75062972292191


Evaluating group-wise accuracy:  96%|████████████████████████████████████████████████  | 24/25 [06:55<00:18, 18.07s/it]

Group (4, 3) Accuracy: 0.7575757575757576


Evaluating group-wise accuracy: 100%|██████████████████████████████████████████████████| 25/25 [07:14<00:00, 17.37s/it]

Group (4, 4) Accuracy: 98.73737373737374





{(0, 0): 98.58156028368795,
 (0, 1): 13.238770685579196,
 (0, 2): 93.3806146572104,
 (0, 3): 85.81560283687944,
 (0, 4): 92.90780141843972,
 (1, 0): 32.02933985330073,
 (1, 1): 88.75305623471883,
 (1, 2): 68.38235294117646,
 (1, 3): 77.69607843137256,
 (1, 4): 67.6470588235294,
 (2, 0): 58.666666666666664,
 (2, 1): 84.8,
 (2, 2): 97.06666666666666,
 (2, 3): 5.066666666666666,
 (2, 4): 64.97326203208556,
 (3, 0): 80.90452261306532,
 (3, 1): 86.64987405541562,
 (3, 2): 6.801007556675063,
 (3, 3): 98.48866498740554,
 (3, 4): 10.075566750629722,
 (4, 0): 81.61209068010075,
 (4, 1): 66.24685138539043,
 (4, 2): 66.75062972292191,
 (4, 3): 0.7575757575757576,
 (4, 4): 98.73737373737374}

In [21]:
evaluator.worst_group_accuracy

((4, 3), 0.7575757575757576)

In [22]:
evaluator.average_accuracy

95.93536387141634

In [23]:
evaluator.evaluate_spurious_attribute_prediction()

39.01