In [1]:
%pip install spuco

Collecting spuco
  Downloading spuco-1.0.3-py3-none-any.whl (101 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/101.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m92.2/101.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.0/101.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting wilds>=2.0.0 (from spuco)
  Downloading wilds-2.0.0-py3-none-any.whl (126 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.2/126.2 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Collecting ogb>=1.2.6 (from wilds>=2.0.0->spuco)
  Downloading ogb-1.3.6-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting outdated>=0.2.0 (from wilds>=2.0.0->spuco)
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collect

In [2]:
!pip install torch
import torch

# Prefer the GPU when one is available, but fall back to the CPU so that the
# later `.to(device)` calls do not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [3]:
from spuco.utils import set_seed

# Fix the seed before any dataset or model is created so reruns are comparable.
# NOTE(review): which RNGs set_seed covers (random / numpy / torch) is not
# visible in this notebook — confirm against the spuco implementation.
set_seed(0)

In [4]:
from spuco.robust_train import ERM
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

# Ten digits grouped pairwise into five classes.
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

# Arguments that the train and test splits have in common.
shared_dataset_kwargs = dict(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
)

# Training split.
# NOTE(review): spurious_correlation_strength is presumably the fraction of
# examples that carry the spurious feature correlated with their class, and
# label_noise the fraction of flipped labels — confirm against SpuCoMNIST.
trainset = SpuCoMNIST(
    split="train",
    spurious_correlation_strength=0.995,
    label_noise=0.001,
    **shared_dataset_kwargs,
)
trainset.initialize()

# Test split: same classes and difficulty, no correlation strength or label
# noise passed.
testset = SpuCoMNIST(split="test", **shared_dataset_kwargs)
testset.initialize()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 51899768.99it/s]


Extracting /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 65161750.31it/s]


Extracting /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 35295965.07it/s]


Extracting /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17429578.01it/s]


Extracting /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw



100%|██████████| 48004/48004 [00:12<00:00, 3870.16it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5340.10it/s]


In [5]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

# Same pairwise digit grouping as the other splits: [[0, 1], [2, 3], ..., [8, 9]].
classes = [[digit, digit + 1] for digit in range(0, 10, 2)]

# Validation split; `difficulty` is the value defined in the dataset cell above.
valset = SpuCoMNIST(
    split="val",
    classes=classes,
    spurious_feature_difficulty=difficulty,
    root="/data/mnist/",
)
valset.initialize()


100%|██████████| 11996/11996 [00:03<00:00, 3309.74it/s]


## Group-Balancing

In [21]:
# Use package spuco
from spuco.group_inference import JTTInference
from spuco.utils import Trainer
from spuco.robust_train import UpSampleERM, CustomSampleERM
import random
import numpy as np
from spuco.utils.random_seed import seed_randomness
from spuco.evaluate import Evaluator
from torch.optim import SGD
from spuco.models import model_factory

# Stage 1: train a LeNet with plain ERM for a single epoch, then run JTT
# inference once on that model's training-set predictions.

model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)

print("Train before JTT Inference")
pre_infer_optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
pre_infer_trainer = Trainer(
    trainset=trainset,
    model=model,
    batch_size=64,
    optimizer=pre_infer_optimizer,
    device=device,
    verbose=True,
)
pre_infer_trainer.train(1)

# Do JTT Inference: take the arg-max class of every training example's output
# and let JTTInference partition the training set from predictions and labels.
trainset_outputs = pre_infer_trainer.get_trainset_outputs()
predictions = trainset_outputs.argmax(dim=-1).detach().cpu().tolist()
jtt_inference = JTTInference(predictions=predictions, class_labels=trainset.labels)
jtt_partition = jtt_inference.infer_groups()

# print("jtt_partition:", jtt_partition)
# Train with dataset after JTT Inference.
# The validation evaluator is handed to both post-inference training stages.
val_evaluator = Evaluator(
    testset=valset,
    group_partition=valset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)

upsample_amount = 3


def _upsampled_indices(partition, upsample_amount):
    """Build a training index list in which every smaller group is repeated.

    The largest group is included once. Each smaller, non-empty group is
    repeated ``int(len_max_group * upsample_amount / len(group))`` times.
    Groups are chosen by size instead of by hard-coded keys, and empty groups
    are skipped so the repeat factor can never divide by zero.

    Args:
        partition: mapping from group key to the list of example indices in
            that group.
        upsample_amount: target size of a smaller group, expressed as a
            multiple of the largest group's size.

    Returns:
        Tuple ``(indices, len_max_group)``: the index list and the size of the
        largest group.
    """
    len_max_group = max((len(members) for members in partition.values()), default=0)
    indices = []
    for key, members in partition.items():
        if len(members) == 0:
            print("Skipping empty group! key=", key)
        elif len(members) < len_max_group:
            repeat = int(len_max_group * upsample_amount / len(members))
            indices.extend(members * repeat)
        else:
            indices.extend(members)
    return indices, len_max_group


up_indices, len_max_group = _upsampled_indices(jtt_partition, upsample_amount)
print("len_max_group=", len_max_group)

# NOTE(review): the partition is inferred only once, so both stages below train
# on the same index list, and they keep updating the stage-1 `model` instead of
# a freshly initialised one — confirm that this is the intended procedure.
for stage in (1, 2):
    print(jtt_partition.keys())
    print(f"\n Train after JTT Inference({stage})")
    # A new optimizer is created per stage, as in the original two copies.
    post_infer_trainer = CustomSampleERM(
        model=model,
        num_epochs=1,
        trainset=trainset,
        val_evaluator=val_evaluator,
        batch_size=64,
        optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
        indices=up_indices,
        device=device,
        verbose=True
    )
    post_infer_trainer.train()

Train before JTT Inference


Epoch 0: 100%|██████████| 751/751 [00:07<00:00, 103.32batch/s, accuracy=100.0%, loss=0.00644]
Getting Trainset Outputs: 100%|██████████| 751/751 [00:02<00:00, 304.79batch/s]


len_max_group= 47713
dict_keys([(0, 0), (0, 1)])

 Train after JTT Inference(1)


Epoch 0: 100%|██████████| 2979/2979 [00:30<00:00, 96.85batch/s, accuracy=100.0%, loss=4.23e-5] 
Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:06,  3.84it/s]

Group (0, 0) Accuracy: 99.80276134122288


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:06,  3.76it/s]

Group (0, 1) Accuracy: 90.92702169625247


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:00<00:05,  3.81it/s]

Group (0, 2) Accuracy: 76.6798418972332


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:05,  3.85it/s]

Group (0, 3) Accuracy: 81.81818181818181


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:05,  3.84it/s]

Group (0, 4) Accuracy: 81.42292490118577


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:01<00:04,  3.83it/s]

Group (1, 0) Accuracy: 60.53719008264463


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:01<00:04,  3.82it/s]

Group (1, 1) Accuracy: 98.34710743801652


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:04,  3.81it/s]

Group (1, 2) Accuracy: 58.799171842650104


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:02<00:04,  3.83it/s]

Group (1, 3) Accuracy: 72.67080745341615


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:02<00:03,  3.81it/s]

Group (1, 4) Accuracy: 51.13871635610766


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:02<00:03,  3.82it/s]

Group (2, 0) Accuracy: 64.07982261640798


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.69it/s]

Group (2, 1) Accuracy: 37.472283813747225


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:03<00:03,  3.22it/s]

Group (2, 2) Accuracy: 98.66666666666667


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:03<00:03,  2.92it/s]

Group (2, 3) Accuracy: 54.22222222222222


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:03,  2.78it/s]

Group (2, 4) Accuracy: 13.555555555555555


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:03,  2.63it/s]

Group (3, 0) Accuracy: 50.0


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:05<00:03,  2.53it/s]

Group (3, 1) Accuracy: 75.35934291581108


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:05<00:02,  2.41it/s]

Group (3, 2) Accuracy: 81.10882956878851


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:06<00:02,  2.33it/s]

Group (3, 3) Accuracy: 99.38398357289527


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:06<00:02,  2.37it/s]

Group (3, 4) Accuracy: 65.50308008213553


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:06<00:01,  2.42it/s]

Group (4, 0) Accuracy: 66.73728813559322


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:07<00:01,  2.65it/s]

Group (4, 1) Accuracy: 49.152542372881356


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:07<00:00,  2.88it/s]

Group (4, 2) Accuracy: 49.78813559322034


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:07<00:00,  3.06it/s]

Group (4, 3) Accuracy: 76.90677966101696


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  3.08it/s]


Group (4, 4) Accuracy: 100.0
Epoch 0: Val Worst-Group Accuracy: 13.555555555555555
Best Val Worst-Group Accuracy: 13.555555555555555
dict_keys([(0, 0), (0, 1)])

 Train after JTT Inference(2)


Epoch 0: 100%|██████████| 2979/2979 [00:28<00:00, 105.00batch/s, accuracy=100.0%, loss=5.36e-7]
Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:10,  2.28it/s]

Group (0, 0) Accuracy: 99.40828402366864


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:09,  2.30it/s]

Group (0, 1) Accuracy: 86.58777120315582


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:09,  2.31it/s]

Group (0, 2) Accuracy: 70.94861660079052


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:08,  2.39it/s]

Group (0, 3) Accuracy: 74.90118577075098


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:08,  2.48it/s]

Group (0, 4) Accuracy: 74.90118577075098


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:06,  2.79it/s]

Group (1, 0) Accuracy: 73.96694214876032


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:05,  3.09it/s]

Group (1, 1) Accuracy: 98.96694214876032


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:05,  3.29it/s]

Group (1, 2) Accuracy: 59.006211180124225


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:03<00:04,  3.44it/s]

Group (1, 3) Accuracy: 71.42857142857143


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:03<00:04,  3.53it/s]

Group (1, 4) Accuracy: 47.61904761904762


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:03,  3.62it/s]

Group (2, 0) Accuracy: 62.971175166297115


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.65it/s]

Group (2, 1) Accuracy: 43.90243902439025


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:04<00:03,  3.72it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:04<00:02,  3.72it/s]

Group (2, 3) Accuracy: 60.0


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:02,  3.79it/s]

Group (2, 4) Accuracy: 13.333333333333334


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:02,  3.82it/s]

Group (3, 0) Accuracy: 45.69672131147541


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:05<00:02,  3.79it/s]

Group (3, 1) Accuracy: 59.95893223819302


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:05<00:01,  3.75it/s]

Group (3, 2) Accuracy: 69.19917864476386


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:05<00:01,  3.79it/s]

Group (3, 3) Accuracy: 99.79466119096509


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:06<00:01,  3.85it/s]

Group (3, 4) Accuracy: 52.772073921971256


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:06<00:01,  3.87it/s]

Group (4, 0) Accuracy: 72.88135593220339


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:06<00:00,  3.80it/s]

Group (4, 1) Accuracy: 54.8728813559322


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:06<00:00,  3.85it/s]

Group (4, 2) Accuracy: 33.47457627118644


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:07<00:00,  3.86it/s]

Group (4, 3) Accuracy: 59.53389830508475


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.42it/s]

Group (4, 4) Accuracy: 100.0
Epoch 0: Val Worst-Group Accuracy: 13.333333333333334
Best Val Worst-Group Accuracy: 13.333333333333334





In [22]:
# Group-wise evaluation of `model` on the held-out test split; the group
# weights passed in are those of the training set.
evaluator = Evaluator(
    model=model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)
evaluator.evaluate()

# Last expression of the cell, so the notebook displays it.
evaluator.worst_group_accuracy

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:06,  3.63it/s]

Group (0, 0) Accuracy: 99.76359338061465


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:06,  3.75it/s]

Group (0, 1) Accuracy: 86.99763593380615


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:00<00:05,  3.81it/s]

Group (0, 2) Accuracy: 64.77541371158392


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:05,  3.77it/s]

Group (0, 3) Accuracy: 73.99527186761229


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:05,  3.71it/s]

Group (0, 4) Accuracy: 72.10401891252955


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:01<00:05,  3.78it/s]

Group (1, 0) Accuracy: 78.239608801956


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:01<00:04,  3.75it/s]

Group (1, 1) Accuracy: 99.75550122249389


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:04,  3.70it/s]

Group (1, 2) Accuracy: 63.23529411764706


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:02<00:04,  3.75it/s]

Group (1, 3) Accuracy: 72.05882352941177


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:02<00:03,  3.78it/s]

Group (1, 4) Accuracy: 41.911764705882355


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:02<00:03,  3.78it/s]

Group (2, 0) Accuracy: 63.2


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.63it/s]

Group (2, 1) Accuracy: 46.93333333333333


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:03<00:03,  3.65it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:03<00:03,  3.62it/s]

Group (2, 3) Accuracy: 67.2


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:02,  3.60it/s]

Group (2, 4) Accuracy: 12.834224598930481


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:02,  3.68it/s]

Group (3, 0) Accuracy: 49.74874371859296


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:04<00:02,  3.75it/s]

Group (3, 1) Accuracy: 61.9647355163728


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:04<00:02,  3.40it/s]

Group (3, 2) Accuracy: 69.26952141057934


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:05<00:02,  2.95it/s]

Group (3, 3) Accuracy: 99.49622166246851


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:05<00:01,  2.78it/s]

Group (3, 4) Accuracy: 58.4382871536524


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:06<00:01,  2.67it/s]

Group (4, 0) Accuracy: 75.56675062972292


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:06<00:01,  2.54it/s]

Group (4, 1) Accuracy: 57.9345088161209


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:07<00:00,  2.49it/s]

Group (4, 2) Accuracy: 29.722921914357684


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:07<00:00,  2.50it/s]

Group (4, 3) Accuracy: 56.06060606060606


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.18it/s]

Group (4, 4) Accuracy: 99.74747474747475





((2, 4), 12.834224598930481)

#### Backup Code: Self-implemented JTT

In [17]:
# Self-implemented JTT
from spuco.group_inference import JTTInference
from spuco.utils import Trainer
from spuco.robust_train import UpSampleERM, CustomSampleERM
import random
import numpy as np
from spuco.utils.random_seed import seed_randomness
from spuco.evaluate import Evaluator
from torch.optim import SGD
from spuco.models import model_factory

# Stage 1 for the hand-written variant: train a second LeNet with plain ERM for
# one epoch and collect its training-set predictions.

model2 = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)

print("Train before JTT Inference")
pre_infer_optimizer = SGD(model2.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
pre_infer_trainer = Trainer(
    trainset=trainset,
    model=model2,
    batch_size=64,
    optimizer=pre_infer_optimizer,
    device=device,
    verbose=True,
)
pre_infer_trainer.train(1)

# Do JTT Inference after first epoch: arg-max class for every training example.
trainset_outputs = pre_infer_trainer.get_trainset_outputs()
predictions = trainset_outputs.argmax(dim=-1).detach().cpu().tolist()
print("len(predictions) epoch 0:", len(predictions))

def myJTTUpsample(predictions, labels, upsample_amount: int=3):
  """Build a shuffled training index list that over-samples misclassified examples.

  Every index whose prediction equals its label appears once. Every index whose
  prediction differs from its label is repeated
  ``int(upsample_amount * n_correct / n_wrong)`` times, but never fewer than
  once, so no example is dropped from the result.

  Args:
    predictions: predicted class for each example; ``labels`` is indexed with
      the same positions.
    labels: ground-truth class for each example.
    upsample_amount: target size of the misclassified set, as a multiple of the
      number of correctly classified examples.

  Returns:
    A list of example indices in random order. The shuffle uses the global
    ``random`` module state.
  """
  res = []
  mislabel = []
  for idx, pred in enumerate(predictions):
    if pred == labels[idx]:
      res.append(idx)
    else:
      mislabel.append(idx)

  # Only up-sample when something was misclassified; otherwise the ratio below
  # would divide by zero.
  if mislabel:
    repeat = int(upsample_amount * len(res) / len(mislabel))
    # The ratio truncates to 0 when misclassified examples outnumber
    # upsample_amount times the correct ones; keep each of them at least once.
    res.extend(mislabel * max(1, repeat))

  random.shuffle(res)
  return res


# Re-balanced index list built from the stage-1 predictions.
up_indices = myJTTUpsample(predictions, trainset.labels, upsample_amount=3)
print(len(up_indices), "up_indices:", up_indices[:50])

# Validation evaluator handed to the post-inference trainer.
val_evaluator = Evaluator(
    model=model2,
    testset=valset,
    group_partition=valset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)

print("\n Train after JTT Inference(1)")
# Two more epochs on the up-sampled indices, continuing from the stage-1
# weights of model2 with a freshly created optimizer.
post_infer_optimizer = SGD(model2.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
post_infer_trainer = CustomSampleERM(
    model=model2,
    trainset=trainset,
    indices=up_indices,
    num_epochs=2,
    batch_size=64,
    optimizer=post_infer_optimizer,
    val_evaluator=val_evaluator,
    device=device,
    verbose=True,
)
post_infer_trainer.train()

# # Do JTT Inference after second epoch
# predictions = torch.argmax(pre_infer_trainer.get_trainset_outputs(), dim=-1).detach().cpu().tolist()
# up_indices = myJTTUpsample(predictions, trainset.labels, 10)

# print("\n Train after JTT Inference(2)")
# post_infer_trainer = CustomSampleERM(
#     model=model2,
#     num_epochs=1,
#     trainset=trainset,
#     val_evaluator=val_evaluator,
#     batch_size=64,
#     optimizer=SGD(model2.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
#     indices=up_indices,
#     device=device,
#     verbose=True
# )
# post_infer_trainer.train()



Train before JTT Inference


Epoch 0: 100%|██████████| 751/751 [00:07<00:00, 102.68batch/s, accuracy=100.0%, loss=0.00644]
Getting Trainset Outputs: 100%|██████████| 751/751 [00:02<00:00, 286.55batch/s]


len(predictions) epoch 0: 48004
190594 up_indices: [17460, 22433, 4901, 6524, 160, 29032, 33459, 33761, 36418, 7370, 22439, 13236, 36975, 47080, 11000, 747, 29009, 44156, 6002, 46902, 22120, 3079, 33840, 18188, 45195, 6045, 3934, 19429, 22221, 46902, 33189, 19132, 44780, 38116, 37675, 19118, 24102, 40997, 5986, 5986, 19645, 2978, 31116, 21186, 31982, 43999, 45297, 21025, 6991, 2042]

 Train after JTT Inference(1)


Epoch 0: 100%|██████████| 2979/2979 [00:26<00:00, 113.99batch/s, accuracy=100.0%, loss=0.000126]
Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:09,  2.41it/s]

Group (0, 0) Accuracy: 99.0138067061144


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:09,  2.40it/s]

Group (0, 1) Accuracy: 82.05128205128206


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:09,  2.43it/s]

Group (0, 2) Accuracy: 84.38735177865613


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:08,  2.49it/s]

Group (0, 3) Accuracy: 92.4901185770751


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:07,  2.53it/s]

Group (0, 4) Accuracy: 83.39920948616601


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:06,  2.83it/s]

Group (1, 0) Accuracy: 68.38842975206612


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:05,  3.12it/s]

Group (1, 1) Accuracy: 99.17355371900827


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:05,  3.30it/s]

Group (1, 2) Accuracy: 73.2919254658385


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:03<00:04,  3.49it/s]

Group (1, 3) Accuracy: 71.84265010351967


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:03<00:04,  3.48it/s]

Group (1, 4) Accuracy: 68.32298136645963


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:03,  3.61it/s]

Group (2, 0) Accuracy: 63.85809312638581


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.66it/s]

Group (2, 1) Accuracy: 28.824833702882483


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:04<00:03,  3.74it/s]

Group (2, 2) Accuracy: 99.33333333333333


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:04<00:02,  3.78it/s]

Group (2, 3) Accuracy: 62.888888888888886


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:02,  3.77it/s]

Group (2, 4) Accuracy: 44.888888888888886


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:02,  3.73it/s]

Group (3, 0) Accuracy: 50.40983606557377


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:05<00:02,  3.80it/s]

Group (3, 1) Accuracy: 74.94866529774127


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:05<00:01,  3.83it/s]

Group (3, 2) Accuracy: 79.87679671457906


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:05<00:01,  3.86it/s]

Group (3, 3) Accuracy: 97.3305954825462


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:05<00:01,  3.84it/s]

Group (3, 4) Accuracy: 62.42299794661191


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:06<00:01,  3.87it/s]

Group (4, 0) Accuracy: 50.63559322033898


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:06<00:00,  3.88it/s]

Group (4, 1) Accuracy: 22.66949152542373


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:06<00:00,  3.90it/s]

Group (4, 2) Accuracy: 19.915254237288135


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:06<00:00,  3.91it/s]

Group (4, 3) Accuracy: 63.983050847457626


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.47it/s]


Group (4, 4) Accuracy: 96.39065817409767
Epoch 0: Val Worst-Group Accuracy: 19.915254237288135
Best Val Worst-Group Accuracy: 19.915254237288135


Epoch 1: 100%|██████████| 2979/2979 [00:26<00:00, 112.34batch/s, accuracy=100.0%, loss=0.00146]
Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:09,  2.54it/s]

Group (0, 0) Accuracy: 99.40828402366864


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:09,  2.53it/s]

Group (0, 1) Accuracy: 71.20315581854044


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:07,  2.90it/s]

Group (0, 2) Accuracy: 72.72727272727273


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:06,  3.21it/s]

Group (0, 3) Accuracy: 79.24901185770752


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:05,  3.44it/s]

Group (0, 4) Accuracy: 59.88142292490119


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:01<00:05,  3.53it/s]

Group (1, 0) Accuracy: 69.21487603305785


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:04,  3.66it/s]

Group (1, 1) Accuracy: 99.79338842975207


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:04,  3.72it/s]

Group (1, 2) Accuracy: 72.46376811594203


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:02<00:04,  3.67it/s]

Group (1, 3) Accuracy: 65.42443064182194


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:02<00:04,  3.73it/s]

Group (1, 4) Accuracy: 57.7639751552795


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:03,  3.80it/s]

Group (2, 0) Accuracy: 74.94456762749445


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.84it/s]

Group (2, 1) Accuracy: 38.35920177383592


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:03<00:03,  3.81it/s]

Group (2, 2) Accuracy: 99.77777777777777


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:03<00:02,  3.82it/s]

Group (2, 3) Accuracy: 59.333333333333336


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:02,  3.87it/s]

Group (2, 4) Accuracy: 37.55555555555556


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:02,  3.85it/s]

Group (3, 0) Accuracy: 59.01639344262295


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:04<00:02,  3.83it/s]

Group (3, 1) Accuracy: 59.75359342915811


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:04<00:01,  3.84it/s]

Group (3, 2) Accuracy: 64.06570841889118


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:05<00:01,  3.83it/s]

Group (3, 3) Accuracy: 99.58932238193019


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:05<00:01,  3.82it/s]

Group (3, 4) Accuracy: 58.31622176591376


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:05<00:01,  3.83it/s]

Group (4, 0) Accuracy: 68.64406779661017


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:06<00:00,  3.84it/s]

Group (4, 1) Accuracy: 11.864406779661017


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:06<00:00,  3.88it/s]

Group (4, 2) Accuracy: 32.41525423728814


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:06<00:00,  3.89it/s]

Group (4, 3) Accuracy: 56.567796610169495


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:06<00:00,  3.68it/s]

Group (4, 4) Accuracy: 98.93842887473461
Epoch 1: Val Worst-Group Accuracy: 11.864406779661017
Best Val Worst-Group Accuracy: 19.915254237288135





In [20]:
# Group-wise evaluation of `model2` on the held-out test split; the group
# weights passed in are those of the training set.
evaluator = Evaluator(
    model=model2,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)
evaluator.evaluate()

# Last expression of the cell, so the notebook displays it.
evaluator.worst_group_accuracy

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:06,  3.94it/s]

Group (0, 0) Accuracy: 99.29078014184397


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:06,  3.81it/s]

Group (0, 1) Accuracy: 73.75886524822695


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:00<00:05,  3.78it/s]

Group (0, 2) Accuracy: 69.73995271867612


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:05,  3.65it/s]

Group (0, 3) Accuracy: 80.61465721040189


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:05,  3.75it/s]

Group (0, 4) Accuracy: 63.12056737588652


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:01<00:05,  3.80it/s]

Group (1, 0) Accuracy: 72.12713936430318


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:01<00:04,  3.84it/s]

Group (1, 1) Accuracy: 99.75550122249389


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:04,  3.76it/s]

Group (1, 2) Accuracy: 73.03921568627452


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:02<00:04,  3.77it/s]

Group (1, 3) Accuracy: 71.56862745098039


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:02<00:03,  3.79it/s]

Group (1, 4) Accuracy: 55.14705882352941


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:02<00:03,  3.84it/s]

Group (2, 0) Accuracy: 73.06666666666666


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.90it/s]

Group (2, 1) Accuracy: 43.46666666666667


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:03<00:03,  3.89it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:03<00:02,  3.83it/s]

Group (2, 3) Accuracy: 62.666666666666664


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:03<00:02,  3.86it/s]

Group (2, 4) Accuracy: 37.4331550802139


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:02,  3.87it/s]

Group (3, 0) Accuracy: 66.33165829145729


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:04<00:02,  3.91it/s]

Group (3, 1) Accuracy: 58.94206549118388


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:04<00:01,  3.90it/s]

Group (3, 2) Accuracy: 66.24685138539043


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:04<00:01,  3.90it/s]

Group (3, 3) Accuracy: 99.74811083123426


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:05<00:01,  3.92it/s]

Group (3, 4) Accuracy: 54.659949622166245


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:05<00:01,  3.40it/s]

Group (4, 0) Accuracy: 74.55919395465995


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:05<00:00,  3.07it/s]

Group (4, 1) Accuracy: 11.083123425692696


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:06<00:00,  2.80it/s]

Group (4, 2) Accuracy: 34.76070528967254


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:06<00:00,  2.62it/s]

Group (4, 3) Accuracy: 57.323232323232325


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.44it/s]

Group (4, 4) Accuracy: 99.4949494949495





((4, 1), 11.083123425692696)