<a href="https://colab.research.google.com/github/emmatliu/spuco-mnist/blob/main/spuco_mnist_ml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install spuco

Collecting spuco
  Downloading spuco-1.0.3-py3-none-any.whl (101 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/101.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m92.2/101.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.0/101.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting wilds>=2.0.0 (from spuco)
  Downloading wilds-2.0.0-py3-none-any.whl (126 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.2/126.2 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers>=3.5.0 (from spuco)
  Downloading transformers-4.34.1-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m101.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers>=3.5.0->spuco)
  Downloading huggingface_hub-0.18.0-py

In [None]:
import random

import numpy as np
import spuco
import torch

from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
from spuco.robust_train import ERM, GroupBalanceBatchERM
from spuco.evaluate import Evaluator
from spuco.models import model_factory
from spuco.group_inference import Cluster, ClusterAlg
from spuco.utils import Trainer

In [None]:
# Reproducibility: seed every RNG that dataset construction, weight initialisation
# and clustering below may draw from, so Restart & Run All gives the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = "/mnist/"  # MNIST download / cache directory
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE
# Each inner list is one class made of two MNIST digits
# (pairs presumably chosen for visual similarity — TODO confirm the rationale).
classes = [[1, 7], [2, 5], [3, 8], [4, 9], [0, 6]]
# Strength of the label/spurious-feature correlation in the training split only
# (presumably the fraction of examples whose spurious feature matches their class).
SPURIOUS_CORRELATION_STRENGTH = 0.99

train = SpuCoMNIST(
    root=DATA_ROOT,
    spurious_feature_difficulty=difficulty,
    classes=classes,
    spurious_correlation_strength=SPURIOUS_CORRELATION_STRENGTH,
    split="train",
)

# The test split leaves the correlation strength at the library default.
test = SpuCoMNIST(
    root=DATA_ROOT,
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test",
)

train.initialize()
test.initialize()

100%|██████████| 48004/48004 [00:08<00:00, 5901.41it/s]
100%|██████████| 10000/10000 [00:02<00:00, 4444.95it/s]


In [None]:
# Shape of the input part (element 0) of the first training item; it sizes the model below.
first_example = train[0]
train_shape = first_example[0].shape

# Step 1: Train a model using ERM

In [None]:
# LeNet built for the training input shape and len(classes) output classes.
model = model_factory('lenet', train_shape, len(classes))
# Adam with PyTorch's default hyper-parameters over all model weights.
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
# Stage 1: plain empirical risk minimisation on the spuriously-correlated train split.
erm = ERM(
    model=model,
    optimizer=optimizer,
    trainset=train,
    batch_size=64,
    num_epochs=5,
    verbose=True,
)
erm.train()

Epoch 0: 100%|██████████| 751/751 [00:16<00:00, 45.64batch/s, accuracy=100.0%, loss=0.00755]
Epoch 1: 100%|██████████| 751/751 [00:16<00:00, 45.29batch/s, accuracy=100.0%, loss=0.00587]
Epoch 2: 100%|██████████| 751/751 [00:15<00:00, 46.94batch/s, accuracy=100.0%, loss=0.0016]
Epoch 3: 100%|██████████| 751/751 [00:15<00:00, 47.47batch/s, accuracy=100.0%, loss=0.016]
Epoch 4: 100%|██████████| 751/751 [00:26<00:00, 28.09batch/s, accuracy=100.0%, loss=0.0162]


# Step 2: Infer groups by clustering the ERM model's outputs

In [None]:
# Wrap the ERM-trained model in a Trainer only to collect its outputs on the
# training set; train(0) requests zero epochs (presumably a no-op on the weights).
trainer = Trainer(
    model=model,
    trainset=train,
    optimizer=optimizer,
    batch_size=64,
    verbose=True,
)
trainer.train(0)

outs = trainer.get_trainset_outputs()

# Cluster those outputs (class-wise, per the progress bar) into 3 clusters each;
# the resulting partition is the pseudo-group labelling used in Step 3.
cluster = Cluster(
    Z=outs,
    num_clusters=3,
    class_labels=train.labels,
    verbose=True,
)
groups = cluster.infer_groups()

Getting Trainset Outputs: 100%|██████████| 751/751 [00:09<00:00, 79.57batch/s] 
Clustering class-wise: 100%|██████████| 5/5 [00:00<00:00,  9.31it/s]


# Step 3: Retrain with group-balanced batches

In [None]:
# Stage 2 starts from a freshly initialised model and optimizer. Reusing the
# Step 1 `model` and `optimizer` would only fine-tune the ERM solution (carrying
# over its Adam moment estimates), so the evaluation below would not reflect
# group-balanced training. Rebinding `model` keeps the evaluation cell pointed at
# this stage's model and makes re-running this cell idempotent.
model = model_factory('lenet', train_shape, len(classes))
optimizer = torch.optim.Adam(model.parameters())

# Each batch is balanced across the pseudo-groups inferred in Step 2.
gb_erm = GroupBalanceBatchERM(
    model=model,
    num_epochs=5,
    trainset=train,
    group_partition=groups,
    batch_size=64,
    optimizer=optimizer,
    verbose=True
)
gb_erm.train()

Epoch 0: 100%|██████████| 751/751 [00:16<00:00, 45.50batch/s, accuracy=100.0%, loss=0.00055]
Epoch 1: 100%|██████████| 751/751 [00:16<00:00, 46.71batch/s, accuracy=100.0%, loss=0.0686]
Epoch 2: 100%|██████████| 751/751 [00:16<00:00, 46.82batch/s, accuracy=100.0%, loss=2.17e-5]
Epoch 3: 100%|██████████| 751/751 [00:25<00:00, 29.84batch/s, accuracy=100.0%, loss=0.000432]
Epoch 4: 100%|██████████| 751/751 [00:16<00:00, 46.43batch/s, accuracy=100.0%, loss=0.000121]


# Step 4: Evaluate per-group accuracy on the test set

In [None]:
# Per-group test accuracy of the current `model`. Groups come from the test
# split's own partition (keys look like (class, spurious-feature) pairs — TODO
# confirm), not from the clusters inferred in Step 2.
evaluator = Evaluator(
    model=model,
    testset=test,
    group_partition=test.group_partition,
    group_weights=test.group_weights,
    batch_size=64,
    verbose=True,
)

evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:13,  1.79it/s]

Group (0, 0) Accuracy: 99.7690531177829


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:01<00:11,  1.93it/s]

Group (0, 1) Accuracy: 52.19399538106236


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:09,  2.31it/s]

Group (0, 2) Accuracy: 42.263279445727484


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:08,  2.55it/s]

Group (0, 3) Accuracy: 75.23148148148148


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:07,  2.74it/s]

Group (0, 4) Accuracy: 93.98148148148148


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:06,  2.89it/s]

Group (1, 0) Accuracy: 32.72727272727273


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:06,  2.94it/s]

Group (1, 1) Accuracy: 99.48051948051948


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:05,  3.02it/s]

Group (1, 2) Accuracy: 14.025974025974026


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:03<00:05,  3.05it/s]

Group (1, 3) Accuracy: 36.62337662337662


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:03<00:04,  3.07it/s]

Group (1, 4) Accuracy: 45.052083333333336


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:04,  3.11it/s]

Group (2, 0) Accuracy: 71.53652392947103


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:04<00:04,  3.19it/s]

Group (2, 1) Accuracy: 75.56675062972292


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:04<00:03,  3.08it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:04<00:03,  3.08it/s]

Group (2, 3) Accuracy: 66.49874055415617


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:05<00:03,  2.70it/s]

Group (2, 4) Accuracy: 62.121212121212125


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:05<00:03,  2.46it/s]

Group (3, 0) Accuracy: 74.93734335839599


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:06<00:03,  2.34it/s]

Group (3, 1) Accuracy: 85.92964824120602


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:06<00:03,  2.21it/s]

Group (3, 2) Accuracy: 74.37185929648241


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:07<00:02,  2.21it/s]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:07<00:02,  2.20it/s]

Group (3, 4) Accuracy: 85.17587939698493


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:08<00:01,  2.18it/s]

Group (4, 0) Accuracy: 96.64948453608247


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:08<00:01,  2.18it/s]

Group (4, 1) Accuracy: 85.82474226804123


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:09<00:00,  2.15it/s]

Group (4, 2) Accuracy: 87.37113402061856


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:09<00:00,  2.15it/s]

Group (4, 3) Accuracy: 81.65374677002583


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.49it/s]

Group (4, 4) Accuracy: 100.0





{(0, 0): 99.7690531177829,
 (0, 1): 52.19399538106236,
 (0, 2): 42.263279445727484,
 (0, 3): 75.23148148148148,
 (0, 4): 93.98148148148148,
 (1, 0): 32.72727272727273,
 (1, 1): 99.48051948051948,
 (1, 2): 14.025974025974026,
 (1, 3): 36.62337662337662,
 (1, 4): 45.052083333333336,
 (2, 0): 71.53652392947103,
 (2, 1): 75.56675062972292,
 (2, 2): 100.0,
 (2, 3): 66.49874055415617,
 (2, 4): 62.121212121212125,
 (3, 0): 74.93734335839599,
 (3, 1): 85.92964824120602,
 (3, 2): 74.37185929648241,
 (3, 3): 100.0,
 (3, 4): 85.17587939698493,
 (4, 0): 96.64948453608247,
 (4, 1): 85.82474226804123,
 (4, 2): 87.37113402061856,
 (4, 3): 81.65374677002583,
 (4, 4): 100.0}

In [None]:
# Going by its name, this scores how well the spurious attribute can be predicted
# from the model — TODO confirm the exact definition in the SpuCo docs.
spurious_attr_score = evaluator.evaluate_spurious_attribute_prediction()
spurious_attr_score



41.22