# Example of implementing the code from JSE

In this notebook, we briefly show how to implement the code in this repository for our Toy dataset. 

In [13]:

# import the necessary libraries
import numpy as np
import pandas as pd
import torch

# import from the JSE package
from JSE.data import *
from JSE.settings import data_info, optimizer_info
from JSE.models import *
from JSE.training import *
from JSE.helpers import *

# import in order 
import argparse
import os
import sys

In [14]:
# define several variables
device_type = 'cpu'
device = torch.device(device_type)
dataset = 'Toy'

# get the dataset information
dataset_setting = 'default'
dataset_settings = data_info[dataset][dataset_setting]
optimizer_settings = optimizer_info['All']

# Determine the spurious correlation for the dataset - in this case corresponding to the \rho value in the paper
spurious_ratio = 0.8

# set the random seed, get the dataset, and set the device
seed = 1
set_seed(seed)

# we will demean the data, apply no pca
demean = True
pca = False
k_components = None # number of pca components - only relevant if pca = True
d = 20 # number of features

# define settings for the model
solver = 'SGD'
lr = 0.01
weight_decay  = 0.0
early_stopping = True
epochs = 50
per_step = 5 # number of epochs between printing the loss
batch_size = 128


In [15]:
# get the dataset object
data_obj = get_dataset_obj(dataset, dataset_settings, spurious_ratio, data_info, seed, device, use_punctuation_MNLI=True)

# demean, pca
if demean:
    data_obj.demean_X(reset_mean=True, include_test=True)
if pca:
    data_obj.transform_data_to_k_components(k_components, reset_V_k=True, include_test=True)
    V_k_train = data_obj.V_k_train

# get the data
X_train, y_c_train, y_m_train = data_obj.X_train, data_obj.y_c_train, data_obj.y_m_train
X_val, y_c_val, y_m_val = data_obj.X_val, data_obj.y_c_val, data_obj.y_m_val
X_test, y_c_test, y_m_test = data_obj.X_test, data_obj.y_c_test, data_obj.y_m_test


v_m/ v_c tensor([0., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]) tensor([3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])
train size:  1600
val size:  400
Sigma:  tensor([[1.0000, 0.8000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.8000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
    

In [16]:
# We first implement simple ERM

# set the loaders of the dataset object
data_obj.reset_X(X_train, X_val, batch_size=batch_size, reset_X_objects=False, include_weights=False, train_weights = None, val_weights = None)

# get the model
ERM_model = return_linear_model(d, 
                                               data_obj.main_loader,
                                              device,
                                              solver = solver,
                                              lr=lr,
                                              per_step=per_step,
                                              tol = optimizer_settings['tol'],
                                              early_stopping = early_stopping,
                                              patience = optimizer_settings['patience'],
                                              epochs = epochs,
                                              bias=True,
                                              weight_decay=weight_decay, 
                                              model_name=dataset+'_main_model')

Start training
Phase : train, Epoch [1/50], Step [5/13], Loss: 0.6749, accuracy: 0.8094
Phase : train, Epoch [1/50], Step [10/13], Loss: 0.6241, accuracy: 0.8180
Improving: current loss 0.5880314707756042 < inf, loss last epoch 
save best param
Phase : train, Epoch [2/50], Step [5/13], Loss: 0.5591, accuracy: 0.7922
Phase : train, Epoch [2/50], Step [10/13], Loss: 0.5048, accuracy: 0.8070
Improving: current loss 0.47776883840560913 < 0.5880314707756042, loss last epoch 
save best param
Phase : train, Epoch [3/50], Step [5/13], Loss: 0.4547, accuracy: 0.8250
Phase : train, Epoch [3/50], Step [10/13], Loss: 0.4706, accuracy: 0.8195
Improving: current loss 0.4265497028827667 < 0.47776883840560913, loss last epoch 
save best param
Phase : train, Epoch [4/50], Step [5/13], Loss: 0.3620, accuracy: 0.8250
Phase : train, Epoch [4/50], Step [10/13], Loss: 0.4506, accuracy: 0.8258
Improving: current loss 0.3518127202987671 < 0.4265497028827667, loss last epoch 
save best param
Phase : train, Epo

In [17]:
# Next, we show how to implement JSE

# define the loaders
balanced_training_concept = False
concept_weights_train = None
concept_weights_val = None
concept_first = True # if True, then inner loop is for the main-task, outer loop is for the spurious concept-task
loaders =  data_obj.create_loaders(batch_size=batch_size, workers=0, with_concept=True, include_weights=balanced_training_concept, train_weights=concept_weights_train, val_weights=concept_weights_val,concept_first=concept_first ) 

alpha = 0.05
Delta = 0
eval_balanced = True

# Run the JSE algorithm - returns the spurious concept basis, the main task concept basis, and the dimension of both
V_c, V_m, d_c, d_m = train_JSE(data_obj,
                                                        device=device,
                                                        batch_size=batch_size, 
                                                        solver=solver,
                                                        lr=lr,
                                                        per_step=per_step,
                                                        tol=optimizer_settings['tol'],
                                                        early_stopping=early_stopping,
                                                        patience=optimizer_settings['patience'],
                                                        epochs=epochs, 
                                                        Delta = Delta,
                                                        alpha=alpha,
                                                         null_is_concept = False,
                                                         eval_balanced=eval_balanced, 
                                                         weight_decay=weight_decay,
                                                         include_weights=balanced_training_concept,
                                                         train_weights=concept_weights_train,
                                                         val_weights=concept_weights_val,
                                                         model_base_name='JSE_'+dataset,
                                                         concept_first=concept_first,
                                                         )

Phase : train, Epoch [1/50], Step [5/13], Loss: 1.3690
accuracy concept predictor 1 : 0.7063
accuracy main predictor 1 : 0.7734
Phase : train, Epoch [1/50], Step [10/13], Loss: 1.3074
accuracy concept predictor 1 : 0.6930
accuracy main predictor 1 : 0.7914
Improving: current loss 1.279515266418457 < inf, loss last epoch 
save best param
Phase : train, Epoch [2/50], Step [5/13], Loss: 1.2044
accuracy concept predictor 1 : 0.7344
accuracy main predictor 1 : 0.8156
Phase : train, Epoch [2/50], Step [10/13], Loss: 1.1326
accuracy concept predictor 1 : 0.7445
accuracy main predictor 1 : 0.8086
Improving: current loss 1.1406502723693848 < 1.279515266418457, loss last epoch 
save best param
Phase : train, Epoch [3/50], Step [5/13], Loss: 1.1110
accuracy concept predictor 1 : 0.7703
accuracy main predictor 1 : 0.7937
Phase : train, Epoch [3/50], Step [10/13], Loss: 0.9752
accuracy concept predictor 1 : 0.7734
accuracy main predictor 1 : 0.7984
Improving: current loss 0.9877809286117554 < 1.140

In [18]:

# define the loaders
balanced_training_concept = False
concept_weights_train = None
concept_weights_val = None
concept_first = True # if True, then inner loop is for the main-task, outer loop is for the spurious concept-task
loaders =  data_obj.create_loaders(batch_size=batch_size, workers=0, with_concept=True, include_weights=balanced_training_concept, train_weights=concept_weights_train, val_weights=concept_weights_val,concept_first=concept_first ) 
 
rafvogel_with_joint = False 
orthogonality_constraint = False


# Train the model to get V_c
set_seed(seed)
V_c_INLP, d_c_INLP = train_INLP(data_obj,
                                                                device,
                                                                batch_size=batch_size, 
                                                                solver=solver,
                                                                lr=lr,
                                                                weight_decay=weight_decay,
                                                                per_step=per_step,
                                                                tol=optimizer_settings['tol'],
                                                                early_stopping=early_stopping,
                                                                patience=optimizer_settings['patience'],
                                                                epochs=epochs,
                                                                alpha=alpha,
                                                                model_base_name='waterbird_rafvogel',
                                                                bias=True,
                                                                joint_decision_rule=rafvogel_with_joint,
                                                                include_weights=balanced_training_concept,
                                                                train_weights=concept_weights_train,
                                                                val_weights=concept_weights_val,
                                                                orthogonality_constraint=orthogonality_constraint,
                                                                expected_diff=None,
                                                                var_diff = None,
                                                                )

False
Start training
Phase : train, Epoch [1/50], Step [5/13], Loss: 0.6839, accuracy: 0.7672
Phase : train, Epoch [1/50], Step [10/13], Loss: 0.6608, accuracy: 0.7477
Improving: current loss 0.6531729102134705 < inf, loss last epoch 
save best param
Phase : train, Epoch [2/50], Step [5/13], Loss: 0.6213, accuracy: 0.7328
Phase : train, Epoch [2/50], Step [10/13], Loss: 0.5904, accuracy: 0.7477
Improving: current loss 0.5691113471984863 < 0.6531729102134705, loss last epoch 
save best param
Phase : train, Epoch [3/50], Step [5/13], Loss: 0.5617, accuracy: 0.7766
Phase : train, Epoch [3/50], Step [10/13], Loss: 0.5731, accuracy: 0.7477
Improving: current loss 0.5361239910125732 < 0.5691113471984863, loss last epoch 
save best param
Phase : train, Epoch [4/50], Step [5/13], Loss: 0.5440, accuracy: 0.7500
Phase : train, Epoch [4/50], Step [10/13], Loss: 0.5222, accuracy: 0.7563
Improving: current loss 0.5085861086845398 < 0.5361239910125732, loss last epoch 
save best param
Phase : train,

In [19]:
# define the orthogonal projection matrix
P_c_orth = torch.eye(d) - create_P(V_c) 

# reset the data
X_train_transformed = torch.matmul(X_train, P_c_orth)
X_val_transformed = torch.matmul(X_val, P_c_orth)
X_test_transformed = torch.matmul(X_test, P_c_orth)

# set the loaders of the dataset object
balanced_training_main = False
main_weights_train = None
main_weights_val = None
data_obj.reset_X(X_train_transformed, X_val_transformed, batch_size=batch_size, reset_X_objects=True, include_weights=balanced_training_main, train_weights = main_weights_train, val_weights = main_weights_val, only_main=True)

# Train the model on the transformed embeddings
set_seed(seed)
ERM_after_JSE = return_linear_model(d, 
                                               data_obj.main_loader,
                                              device,
                                              solver = solver,
                                              lr=lr,
                                              per_step=per_step,
                                              tol = optimizer_settings['tol'],
                                              early_stopping = early_stopping,
                                              patience = optimizer_settings['patience'],
                                              epochs = epochs,
                                              bias=True,
                                              weight_decay=weight_decay, 
                                              model_name=dataset+'_main_model',
                                              save_best_model=True
                                              )

Start training
Phase : train, Epoch [1/50], Step [5/13], Loss: 0.6815, accuracy: 0.8203
Phase : train, Epoch [1/50], Step [10/13], Loss: 0.6478, accuracy: 0.8203
Improving: current loss 0.6304269433021545 < inf, loss last epoch 
save best param
Phase : train, Epoch [2/50], Step [5/13], Loss: 0.5849, accuracy: 0.8266
Phase : train, Epoch [2/50], Step [10/13], Loss: 0.5523, accuracy: 0.8297
Improving: current loss 0.5368441939353943 < 0.6304269433021545, loss last epoch 
save best param
Phase : train, Epoch [3/50], Step [5/13], Loss: 0.5343, accuracy: 0.8344
Phase : train, Epoch [3/50], Step [10/13], Loss: 0.5010, accuracy: 0.8297
Improving: current loss 0.5305685997009277 < 0.5368441939353943, loss last epoch 
save best param
Phase : train, Epoch [4/50], Step [5/13], Loss: 0.4392, accuracy: 0.8250
Phase : train, Epoch [4/50], Step [10/13], Loss: 0.4439, accuracy: 0.8352
Improving: current loss 0.506938099861145 < 0.5305685997009277, loss last epoch 
save best param
Phase : train, Epoch 

In [20]:
# define the orthogonal projection matrix for INLP
P_c_orth_INLP = torch.eye(d) - create_P(V_c_INLP) 

# reset the data
X_train_transformed_INLP = torch.matmul(X_train, P_c_orth_INLP)
X_val_transformed_INLP = torch.matmul(X_val, P_c_orth_INLP)
X_test_transformed_INLP = torch.matmul(X_test, P_c_orth_INLP)

# set the loaders of the dataset object
balanced_training_main = False
main_weights_train = None
main_weights_val = None
data_obj.reset_X(X_train_transformed_INLP, X_val_transformed_INLP, batch_size=batch_size, reset_X_objects=True, include_weights=balanced_training_main, train_weights = main_weights_train, val_weights = main_weights_val, only_main=True)

# Train the model on the transformed embeddings
set_seed(seed)
ERM_after_INLP= return_linear_model(d, 
                                               data_obj.main_loader,
                                              device,
                                              solver = solver,
                                              lr=lr,
                                              per_step=per_step,
                                              tol = optimizer_settings['tol'],
                                              early_stopping = early_stopping,
                                              patience = optimizer_settings['patience'],
                                              epochs = epochs,
                                              bias=True,
                                              weight_decay=weight_decay, 
                                              model_name=dataset+'_main_model',
                                              save_best_model=True
                                              )

Start training
Phase : train, Epoch [1/50], Step [5/13], Loss: 0.6848, accuracy: 0.7422
Phase : train, Epoch [1/50], Step [10/13], Loss: 0.6601, accuracy: 0.7312
Improving: current loss 0.6517851948738098 < inf, loss last epoch 
save best param
Phase : train, Epoch [2/50], Step [5/13], Loss: 0.6218, accuracy: 0.7656
Phase : train, Epoch [2/50], Step [10/13], Loss: 0.5925, accuracy: 0.7555
Improving: current loss 0.5815144181251526 < 0.6517851948738098, loss last epoch 
save best param
Phase : train, Epoch [3/50], Step [5/13], Loss: 0.5997, accuracy: 0.7641
Phase : train, Epoch [3/50], Step [10/13], Loss: 0.5736, accuracy: 0.7570
Improving: current loss 0.5787176489830017 < 0.5815144181251526, loss last epoch 
save best param
Phase : train, Epoch [4/50], Step [5/13], Loss: 0.5240, accuracy: 0.7578
Phase : train, Epoch [4/50], Step [10/13], Loss: 0.5260, accuracy: 0.7609
Improving: current loss 0.547160804271698 < 0.5787176489830017, loss last epoch 
save best param
Phase : train, Epoch 

In [21]:

# Train the model on the transformed embeddings
set_seed(seed)
ERM_after_INLP= return_linear_model(d, 
                                               data_obj.main_loader,
                                              device,
                                              solver = solver,
                                              lr=lr,
                                              per_step=per_step,
                                              tol = optimizer_settings['tol'],
                                              early_stopping = early_stopping,
                                              patience = optimizer_settings['patience'],
                                              epochs = epochs,
                                              bias=True,
                                              weight_decay=weight_decay, 
                                              model_name=dataset+'_main_model',
                                              save_best_model=True
                                              )

Start training
Phase : train, Epoch [1/50], Step [5/13], Loss: 0.6848, accuracy: 0.7422
Phase : train, Epoch [1/50], Step [10/13], Loss: 0.6601, accuracy: 0.7312
Improving: current loss 0.6517851948738098 < inf, loss last epoch 
save best param
Phase : train, Epoch [2/50], Step [5/13], Loss: 0.6218, accuracy: 0.7656
Phase : train, Epoch [2/50], Step [10/13], Loss: 0.5925, accuracy: 0.7555
Improving: current loss 0.5815144181251526 < 0.6517851948738098, loss last epoch 
save best param
Phase : train, Epoch [3/50], Step [5/13], Loss: 0.5997, accuracy: 0.7641
Phase : train, Epoch [3/50], Step [10/13], Loss: 0.5736, accuracy: 0.7570
Improving: current loss 0.5787176489830017 < 0.5815144181251526, loss last epoch 
save best param
Phase : train, Epoch [4/50], Step [5/13], Loss: 0.5240, accuracy: 0.7578
Phase : train, Epoch [4/50], Step [10/13], Loss: 0.5260, accuracy: 0.7609
Improving: current loss 0.547160804271698 < 0.5787176489830017, loss last epoch 
save best param
Phase : train, Epoch 

In [22]:
# get the accuracy of the main model after JSE
y_m_pred_test = ERM_after_JSE(X_test_transformed)

# get the accuracy of the main model overall after JSE
main_acc_after = get_acc_pytorch_model(y_m_test, y_m_pred_test)

# get the accuracy of the main model per group
result_per_group, _ = get_acc_per_group(y_m_pred_test, y_m_test, y_c_test)

print("Overall Accuracy of JSE (test): ", main_acc_after)

print("Accuracy per group of JSE (test): ", result_per_group)

   


Overall Accuracy of JSE (test):  tensor(0.8400)
Accuracy per group of JSE (test):                count      mean
main concept                 
0.0  0.0        508  0.840551
     1.0        478  0.859833
1.0  0.0        477  0.836478
     1.0        537  0.824953


In [23]:
# get the accuracy of the main model after JSE
y_m_pred_test_INLP = ERM_after_INLP(X_test_transformed_INLP)

# get the accuracy of the main model overall after JSE
main_acc_after_INLP = get_acc_pytorch_model(y_m_test, y_m_pred_test_INLP)

# get the accuracy of the main model per group
result_per_group_INLP, _ = get_acc_per_group(y_m_pred_test_INLP, y_m_test, y_c_test)

print("Overall Accuracy of INLP (test): ", main_acc_after_INLP)

print("Accuracy per group of INLP (test): ", result_per_group_INLP)


Overall Accuracy of INLP (test):  tensor(0.5100)
Accuracy per group of INLP (test):                count      mean
main concept                 
0.0  0.0        508  0.816929
     1.0        478  0.144351
1.0  0.0        477  0.171908
     1.0        537  0.845438


In [24]:
# get the accuracy of the main model after JSE
y_m_pred_test_ERM = ERM_model(X_test)

# get the accuracy of the main model overall after JSE
main_acc_ERM = get_acc_pytorch_model(y_m_test, y_m_pred_test_ERM)

# get the accuracy of the main model per group
result_per_group_ERM, _ = get_acc_per_group(y_m_pred_test_ERM, y_m_test, y_c_test)

print("Overall Accuracy of ERM (test): ", main_acc_ERM)

print("Accuracy per group of ERM (test): ", result_per_group_ERM)

   


Overall Accuracy of ERM (test):  tensor(0.7445)
Accuracy per group of ERM (test):                count      mean
main concept                 
0.0  0.0        508  0.901575
     1.0        478  0.550209
1.0  0.0        477  0.576520
     1.0        537  0.918063
