Coding Assignment: Implementing a Solution for the Spurious Correlations Problem

[Self-assessment for those interested in Empirical Work]. If this is too difficult for you, please consider taking more ML courses first.

This sample task helps us gauge how familiar you are with PyTorch and related tools, and how comfortable you are with machine learning in general.

Deep neural networks often exploit non-predictive features that are spuriously correlated with class labels, which leads to poor performance on groups of examples that lack those features. Using the SpuCo package (documentation: https://spuco.readthedocs.io/en/latest/), we'd like you to implement a simple method that mitigates spurious correlations in SpuCoMNIST (initialize the dataset with its default parameters).

The method we'd like you to implement, GEORGE (https://arxiv.org/abs/2011.12945), is a three-step pipeline:
1. Train a model using ERM.
2. Cluster the inputs of each class based on the outputs the ERM model produces for them.
3. Retrain using "group balancing", so that every group appears equally often in each batch.

We'd like you to send us a notebook with your code and outputs (similar to the SpuCo Quickstart Notebooks).

Upload your notebook (with both code and outputs) to a public GitHub repository and share the link here. 

In [9]:
import warnings
from tqdm import TqdmWarning

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=TqdmWarning)

In [None]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

# Default settings from the SpuCo example notebooks
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

trainset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.995,
    classes=classes,
    split="train",
)
trainset.initialize()

testset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test",
)
testset.initialize()

## Step 1:
Train a model with empirical risk minimization (ERM).
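For reference, ERM minimizes the average training loss, weighting every example equally:

$$
\hat{\theta}_{\mathrm{ERM}} = \arg\min_{\theta} \frac{1}{n} \sum_{i=1}^{n} \ell\big(f_{\theta}(x_i), y_i\big)
$$

where $\ell$ is the per-example loss (typically cross-entropy for classification; we assume SpuCo's `ERM` uses it by default). With `spurious_correlation_strength=0.995`, the examples whose spurious feature disagrees with their label are only a small fraction of the $n$ training examples, so they contribute little to this average and the model can fit the data well by relying on the spurious feature.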

In [None]:
from spuco.robust_train import ERM
from torch.optim import SGD
from spuco.models import model_factory
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Using LeNet for images
model = model_factory(
    arch="lenet", input_shape=trainset[0][0].shape, num_classes=trainset.num_classes
).to(device)

erm = ERM(
    model=model,
    num_epochs=1,
    trainset=trainset,
    batch_size=64,
    # For simplicity, use the same optimizer settings as the SpuCo examples
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    device=device,
    verbose=False
)
erm.train()

##### Evaluation before balancing sub-classes

In [4]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=False
)
evaluator.evaluate()

Evaluating group-wise accuracy: 100%|██████████| 25/25 [07:29<00:00, 17.96s/it]


{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 100.0}

## Step 2:
Cluster the training inputs of each class based on the outputs (logits) the ERM model produces for them.
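SpuCo's `Cluster` handles this step in the notebook. To make the idea concrete, here is a small NumPy-only sketch of class-wise clustering where the number of clusters per class is chosen by the highest average silhouette score. It is not SpuCo's implementation: it uses plain Lloyd's k-means with a deterministic farthest-point initialization (our choice, not necessarily what SpuCo or the GEORGE paper uses), and `kmeans`, `silhouette` and `infer_groups` are hypothetical names. The returned dict has the same shape as `group_partition`: `(class, cluster)` keys mapped to lists of training indices.

```python
import numpy as np


def kmeans(X, k, n_iter=50):
    """Lloyd's k-means with deterministic farthest-point initialization."""
    centers = [X[0]]
    for _ in range(1, k):
        # Distance from every point to its nearest already-chosen center
        nearest = np.min([np.linalg.norm(X - c, axis=1) for c in centers], axis=0)
        centers.append(X[np.argmax(nearest)])
    centers = np.array(centers, dtype=float)
    labels = np.zeros(len(X), dtype=int)
    for _ in range(n_iter):
        dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        for j in range(k):
            if np.any(labels == j):  # leave an empty cluster's center unchanged
                centers[j] = X[labels == j].mean(axis=0)
    return labels


def silhouette(X, labels):
    """Mean silhouette coefficient (b - a) / max(a, b) over all points."""
    clusters = np.unique(labels)
    if len(clusters) < 2:
        return -1.0
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)  # pairwise distances
    scores = []
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False
        if not same.any():  # singleton cluster: silhouette defined as 0
            scores.append(0.0)
            continue
        a = D[i, same].mean()  # mean distance to the rest of its own cluster
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))


def infer_groups(logits, class_labels, max_clusters=10):
    """Cluster each class's logits, picking k in [2, max_clusters] by silhouette.

    Returns {(class, cluster): [training indices]}.
    """
    logits = np.asarray(logits, dtype=float)
    class_labels = np.asarray(class_labels)
    partition = {}
    for c in np.unique(class_labels):
        idx = np.flatnonzero(class_labels == c)
        Z = logits[idx]
        best_score, best_labels = -np.inf, np.zeros(len(idx), dtype=int)
        # The silhouette score requires 2 <= k <= n - 1
        for k in range(2, min(max_clusters, len(idx) - 1) + 1):
            labels = kmeans(Z, k)
            score = silhouette(Z, labels)
            if score > best_score:
                best_score, best_labels = score, labels
        for j in np.unique(best_labels):
            partition[(int(c), int(j))] = idx[best_labels == j].tolist()
    return partition
```

Because it builds a full pairwise-distance matrix, this sketch is only meant for small toy inputs, not the roughly 10,000 training examples per class used here.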

In [None]:
# Logits of the ERM model for every training example (one score per class)
logits = erm.trainer.get_trainset_outputs()

In [None]:
from spuco.group_inference import Cluster

cluster = Cluster(
    Z=logits,
    class_labels=trainset.labels,
    # Search k in [2, 10], following the GEORGE paper (p. 23)
    max_clusters=10,
    device=device,
    verbose=True
)

# Pick k for each superclass (class) by silhouette score (GEORGE paper, pp. 6 and 23)
group_partition = cluster.infer_groups()

For n_clusters = 2 The average silhouette_score is : 0.9629660248756409
For n_clusters = 3 The average silhouette_score is : 0.962044358253479
For n_clusters = 4 The average silhouette_score is : 0.9633027911186218
For n_clusters = 5 The average silhouette_score is : 0.5213839411735535
For n_clusters = 6 The average silhouette_score is : 0.5249813199043274
For n_clusters = 7 The average silhouette_score is : 0.4771381914615631
For n_clusters = 8 The average silhouette_score is : 0.41998812556266785
For n_clusters = 9 The average silhouette_score is : 0.39356228709220886
For n_clusters = 10 The average silhouette_score is : 0.35772427916526794

For n_clusters = 2 The average silhouette_score is : 0.9450042843818665
For n_clusters = 3 The average silhouette_score is : 0.6658995747566223
For n_clusters = 4 The average silhouette_score is : 0.6672133207321167
For n_clusters = 5 The average silhouette_score is : 0.6685448288917542
For n_clusters = 6 The average silhouette_score is : 0.6695531606674194
For n_clusters = 7 The average silhouette_score is : 0.6078675985336304
For n_clusters = 8 The average silhouette_score is : 0.5511408448219299
For n_clusters = 9 The average silhouette_score is : 0.476537823677063
For n_clusters = 10 The average silhouette_score is : 0.4567190110683441

For n_clusters = 2 The average silhouette_score is : 0.9456753730773926
For n_clusters = 3 The average silhouette_score is : 0.5008295774459839
For n_clusters = 4 The average silhouette_score is : 0.5015779733657837
For n_clusters = 5 The average silhouette_score is : 0.5037319660186768
For n_clusters = 6 The average silhouette_score is : 0.504131555557251
For n_clusters = 7 The average silhouette_score is : 0.438045471906662
For n_clusters = 8 The average silhouette_score is : 0.3961904048919678
For n_clusters = 9 The average silhouette_score is : 0.3542941212654114
For n_clusters = 10 The average silhouette_score is : 0.33044400811195374

For n_clusters = 2 The average silhouette_score is : 0.9811540246009827
For n_clusters = 3 The average silhouette_score is : 0.9810868501663208
For n_clusters = 4 The average silhouette_score is : 0.9833795428276062
For n_clusters = 5 The average silhouette_score is : 0.9829742908477783
For n_clusters = 6 The average silhouette_score is : 0.4533403217792511
For n_clusters = 7 The average silhouette_score is : 0.38894420862197876
For n_clusters = 8 The average silhouette_score is : 0.3568340539932251
For n_clusters = 9 The average silhouette_score is : 0.3178868591785431
For n_clusters = 10 The average silhouette_score is : 0.2881195545196533

For n_clusters = 2 The average silhouette_score is : 0.9447611570358276
For n_clusters = 3 The average silhouette_score is : 0.5244637727737427
For n_clusters = 4 The average silhouette_score is : 0.5257400274276733
For n_clusters = 5 The average silhouette_score is : 0.5266082286834717
For n_clusters = 6 The average silhouette_score is : 0.5279449224472046
For n_clusters = 7 The average silhouette_score is : 0.4734761714935303
For n_clusters = 8 The average silhouette_score is : 0.43740060925483704
For n_clusters = 9 The average silhouette_score is : 0.40999120473861694
For n_clusters = 10 The average silhouette_score is : 0.37849894165992737

Clustering class-wise: 100%|██████████| 5/5 [00:49<00:00,  9.94s/it]

In [41]:
# Distribution of subclasses
for key in sorted(group_partition.keys()):
    print(key, len(group_partition[key]))

(0, 0) 51
(0, 1) 10082
(1, 0) 9623
(1, 1) 11
(1, 2) 17
(1, 3) 21
(2, 0) 8965
(2, 1) 14
(2, 2) 11
(2, 3) 21
(3, 0) 9698
(3, 1) 49
(4, 0) 9393
(4, 1) 48
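The printed sizes show how skewed the inferred partition is. As a back-of-the-envelope check (plain Python, with the counts copied by hand from the output above), we can compare how often the smallest group would be drawn per batch of 64 under uniform sampling versus equal per-group sampling:

```python
# Group sizes copied from the printout above: (class, cluster) -> count
sizes = {
    (0, 0): 51, (0, 1): 10082,
    (1, 0): 9623, (1, 1): 11, (1, 2): 17, (1, 3): 21,
    (2, 0): 8965, (2, 1): 14, (2, 2): 11, (2, 3): 21,
    (3, 0): 9698, (3, 1): 49,
    (4, 0): 9393, (4, 1): 48,
}
total = sum(sizes.values())
smallest = min(sizes, key=sizes.get)  # first group with the minimum count
largest = max(sizes, key=sizes.get)
ratio = sizes[largest] / sizes[smallest]

batch_size = 64
# Expected draws per batch: uniform over examples vs. equal share per group
uniform_draws = {g: batch_size * n / total for g, n in sizes.items()}
balanced_draws = batch_size / len(sizes)

print(f"{total} examples in {len(sizes)} groups")
print(f"largest {largest} / smallest {smallest}: {ratio:.1f}x")
print(f"group {smallest}: {uniform_draws[smallest]:.3f} per batch uniformly, "
      f"{balanced_draws:.2f} per batch when balanced")
```

Under uniform sampling, group (1, 1) would show up in roughly one batch out of seventy; an equal split over the 14 groups gives it about 4.6 slots per batch. How `GroupBalanceBatchERM` rounds the per-group share is not something this sketch models.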


## Step 3:
Retrain using group-balanced batches, so that every inferred group appears equally often in each batch.
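The idea behind group balancing: if each batch holds $m$ examples drawn uniformly (with replacement) from each of the $G$ groups, the expected batch loss is $\frac{1}{G}\sum_{g=1}^{G} \frac{1}{n_g}\sum_{i \in g} \ell(f_\theta(x_i), y_i)$, so every group counts equally regardless of its size $n_g$. A minimal pure-Python sketch of such a batch sampler follows; `group_balanced_batches` is a hypothetical helper, not the internals of SpuCo's `GroupBalanceBatchERM`, and drawing with replacement is our assumption.

```python
import random


def group_balanced_batches(group_partition, batch_size, num_batches, seed=0):
    """Yield index batches containing the same number of examples from every group.

    Hypothetical helper: each group contributes batch_size // num_groups indices,
    drawn uniformly with replacement, so tiny groups are oversampled.
    """
    rng = random.Random(seed)
    groups = sorted(group_partition)
    per_group = batch_size // len(groups)
    if per_group == 0:
        raise ValueError("batch_size must be at least the number of groups")
    for _ in range(num_batches):
        batch = []
        for g in groups:
            batch.extend(rng.choices(group_partition[g], k=per_group))
        rng.shuffle(batch)  # mix the groups within the batch
        yield batch
```

With the 14 inferred groups and `batch_size=64`, this sketch would take 64 // 14 = 4 examples per group and produce batches of 56; the library may distribute the remainder differently.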

In [None]:
from spuco.robust_train import GroupBalanceBatchERM

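# Sample each batch so that every inferred group is equally represented.
# Note: `model` is the network already trained with ERM in Step 1, so this
# continues training it. GEORGE as described in the paper trains a new model;
# to retrain from scratch, re-create `model` with model_factory first.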
gpb_erm = GroupBalanceBatchERM(
    model=model,
    trainset=trainset,
    group_partition=group_partition,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    num_epochs=1,
    device=device,
    verbose=False
)

gpb_erm.train()

##### Evaluation after balancing sub-classes

In [None]:
evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=False
)
evaluator.evaluate()

Evaluating group-wise accuracy: 100%|██████████| 25/25 [07:54<00:00, 18.99s/it]

{(0, 0): 99.29078014184397,
 (0, 1): 86.05200945626477,
 (0, 2): 89.12529550827423,
 (0, 3): 80.61465721040189,
 (0, 4): 77.77777777777777,
 (1, 0): 74.81662591687042,
 (1, 1): 98.0440097799511,
 (1, 2): 76.7156862745098,
 (1, 3): 67.6470588235294,
 (1, 4): 82.84313725490196,
 (2, 0): 79.2,
 (2, 1): 17.6,
 (2, 2): 94.4,
 (2, 3): 52.0,
 (2, 4): 71.65775401069519,
 (3, 0): 76.38190954773869,
 (3, 1): 75.31486146095718,
 (3, 2): 83.62720403022671,
 (3, 3): 98.99244332493703,
 (3, 4): 58.94206549118388,
 (4, 0): 80.10075566750629,
 (4, 1): 69.0176322418136,
 (4, 2): 75.31486146095718,
 (4, 3): 63.38383838383838,
 (4, 4): 94.6969696969697}
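Worst-group accuracy is the usual headline metric for spurious correlations. A small plain-Python sketch (the `summarize` helper is hypothetical) that reduces both evaluation dicts to their worst group and their unweighted mean over the 25 test groups; note that this mean ignores the `group_weights` passed to the `Evaluator`:

```python
def summarize(results):
    """Return (worst group, its accuracy, unweighted mean accuracy over groups)."""
    worst_group = min(results, key=results.get)
    mean_acc = sum(results.values()) / len(results)
    return worst_group, results[worst_group], mean_acc


# ERM output from Step 1: 100% on the groups (i, i), 0% on every other group
erm_results = {(i, j): 100.0 if i == j else 0.0 for i in range(5) for j in range(5)}

# Output after group-balanced retraining, copied from the dict above
balanced_results = {
    (0, 0): 99.29078014184397, (0, 1): 86.05200945626477, (0, 2): 89.12529550827423,
    (0, 3): 80.61465721040189, (0, 4): 77.77777777777777, (1, 0): 74.81662591687042,
    (1, 1): 98.0440097799511, (1, 2): 76.7156862745098, (1, 3): 67.6470588235294,
    (1, 4): 82.84313725490196, (2, 0): 79.2, (2, 1): 17.6, (2, 2): 94.4,
    (2, 3): 52.0, (2, 4): 71.65775401069519, (3, 0): 76.38190954773869,
    (3, 1): 75.31486146095718, (3, 2): 83.62720403022671, (3, 3): 98.99244332493703,
    (3, 4): 58.94206549118388, (4, 0): 80.10075566750629, (4, 1): 69.0176322418136,
    (4, 2): 75.31486146095718, (4, 3): 63.38383838383838, (4, 4): 94.6969696969697,
}

for name, res in [("ERM", erm_results), ("Group-balanced", balanced_results)]:
    group, worst, mean = summarize(res)
    print(f"{name}: worst group {group} = {worst:.1f}%, mean over groups = {mean:.1f}%")
```

Worst-group accuracy rises from 0.0 to 17.6 (group (2, 1)), and the unweighted mean over groups rises from 20.0 to about 76.9.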