# Example of implementing the code from JSE

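In this notebook, we briefly show how to use the code in this repository on our Toy dataset. We train a plain ERM classifier, estimate the spurious-concept subspace with JSE and with INLP, retrain on the projected features, and compare test accuracy overall and per group.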

In [None]:

# import the necessary libraries
import numpy as np
import pandas as pd
import torch

# import from the JSE package
from JSE.data import *
from JSE.settings import data_info, optimizer_info
from JSE.models import *
from JSE.training import *

# import in order 
import argparse
import os
import sys

In [None]:
# ---- Experiment configuration ----

# Device handed to the dataset and training helpers
device_type = 'cpu'
device = torch.device(device_type)

# Dataset name and the settings registered for it in JSE.settings
dataset = 'Toy'
dataset_setting = 'default'
dataset_settings = data_info[dataset][dataset_setting]
optimizer_settings = optimizer_info['All']

# Strength of the spurious correlation - the \rho value in the paper
spurious_ratio = 0.8

# Fix the random seed up front
seed = 0
set_seed(seed)

# Pre-processing switches: demean the data, apply no PCA
d = 20             # number of features
demean = True
pca = False
k_components = 20  # number of PCA components - only relevant if pca = True

# Optimisation settings shared by all models trained below
solver = 'SGD'
lr = 0.01
weight_decay = 0.0
epochs = 50
batch_size = 128
early_stopping = True
per_step = 5       # number of epochs between printing the loss


In [None]:
# Build the dataset object for the chosen dataset and spurious ratio
data_obj = get_dataset_obj(
    dataset,
    dataset_settings,
    spurious_ratio,
    data_info,
    seed,
    device,
    use_punctuation_MNLI=True,
)

# Optional pre-processing (both helpers are called with include_test=True)
if demean:
    data_obj.demean_X(reset_mean=True, include_test=True)
if pca:
    data_obj.transform_data_to_k_components(
        k_components,
        reset_V_k=True,
        include_test=True,
    )
    V_k_train = data_obj.V_k_train

# Features (X), concept labels (y_c) and main-task labels (y_m) for each split
X_train = data_obj.X_train
y_c_train = data_obj.y_c_train
y_m_train = data_obj.y_m_train

X_val = data_obj.X_val
y_c_val = data_obj.y_c_val
y_m_val = data_obj.y_m_val

X_test = data_obj.X_test
y_c_test = data_obj.y_c_test
y_m_test = data_obj.y_m_test


In [None]:
# Baseline: simple ERM trained on X_train / X_val

# Point the loaders of the dataset object at the train/validation features,
# without sample weights
data_obj.reset_X(
    X_train,
    X_val,
    batch_size=batch_size,
    reset_X_objects=False,
    include_weights=False,
    train_weights=None,
    val_weights=None,
)

# Fit the linear model on the main-task loader
ERM_model = return_linear_model(
    d,
    data_obj.main_loader,
    device,
    solver=solver,
    lr=lr,
    per_step=per_step,
    tol=optimizer_settings['tol'],
    early_stopping=early_stopping,
    patience=optimizer_settings['patience'],
    epochs=epochs,
    bias=True,
    weight_decay=weight_decay,
    model_name=dataset + '_main_model',
)

In [None]:
# JSE: estimate the spurious-concept and main-task subspaces

# Loader options - no sample weights are passed
balanced_training_concept = False
concept_weights_train = None
concept_weights_val = None
concept_first = True  # if True, then inner loop is for the main-task, outer loop is for the spurious concept-task
loaders = data_obj.create_loaders(
    batch_size=batch_size,
    workers=0,
    with_concept=True,
    include_weights=balanced_training_concept,
    train_weights=concept_weights_train,
    val_weights=concept_weights_val,
    concept_first=concept_first,
)

# Settings specific to the JSE run
alpha = 0.05
Delta = 0
eval_balanced = True

# Run JSE - returns the spurious concept basis (V_c), the main task concept
# basis (V_m), and the dimension of both (d_c, d_m)
V_c, V_m, d_c, d_m = train_JSE(
    data_obj,
    device=device,
    batch_size=batch_size,
    solver=solver,
    lr=lr,
    per_step=per_step,
    tol=optimizer_settings['tol'],
    early_stopping=early_stopping,
    patience=optimizer_settings['patience'],
    epochs=epochs,
    Delta=Delta,
    alpha=alpha,
    null_is_concept=False,
    eval_balanced=eval_balanced,
    weight_decay=weight_decay,
    include_weights=balanced_training_concept,
    train_weights=concept_weights_train,
    val_weights=concept_weights_val,
    model_base_name='JSE_' + dataset,
    concept_first=concept_first,
)

In [None]:

# Next, we run INLP as a comparison method for estimating the spurious concept subspace

# define the loaders (same options as for JSE, repeated so this cell is self-contained)
balanced_training_concept = False
concept_weights_train = None
concept_weights_val = None
concept_first = True # if True, then inner loop is for the main-task, outer loop is for the spurious concept-task
loaders = data_obj.create_loaders(
    batch_size=batch_size,
    workers=0,
    with_concept=True,
    include_weights=balanced_training_concept,
    train_weights=concept_weights_train,
    val_weights=concept_weights_val,
    concept_first=concept_first,
)

# INLP-specific switches
rafvogel_with_joint = False
orthogonality_constraint = False

# Train the model to get V_c (reset the seed first so the run is repeatable)
set_seed(seed)
V_c_INLP, d_c_INLP = train_INLP(
    data_obj,
    device,
    batch_size=batch_size,
    solver=solver,
    lr=lr,
    weight_decay=weight_decay,
    per_step=per_step,
    tol=optimizer_settings['tol'],
    early_stopping=early_stopping,
    patience=optimizer_settings['patience'],
    epochs=epochs,
    alpha=alpha,
    # name the models after the dataset in use, mirroring 'JSE_'+dataset in the JSE cell
    model_base_name='INLP_'+dataset,
    bias=True,
    joint_decision_rule=rafvogel_with_joint,
    include_weights=balanced_training_concept,
    train_weights=concept_weights_train,
    val_weights=concept_weights_val,
    orthogonality_constraint=orthogonality_constraint,
    expected_diff=None,
    var_diff=None,
)

In [None]:
# Orthogonal projection matrix for JSE: identity minus create_P(V_c)
P_c_orth = torch.eye(d) - create_P(V_c)

# Apply the projection to every split
X_train_transformed = X_train @ P_c_orth
X_val_transformed = X_val @ P_c_orth
X_test_transformed = X_test @ P_c_orth

# Point the loaders of the dataset object at the projected features (no sample weights)
balanced_training_main = False
main_weights_train = None
main_weights_val = None
data_obj.reset_X(
    X_train_transformed,
    X_val_transformed,
    batch_size=batch_size,
    reset_X_objects=True,
    include_weights=balanced_training_main,
    train_weights=main_weights_train,
    val_weights=main_weights_val,
    only_main=True,
)

# Train the main-task model on the transformed embeddings
# NOTE(review): model_name is the same as for the ERM baseline; if saved models
# are keyed by this name they would overwrite each other - confirm.
set_seed(seed)
ERM_after_JSE = return_linear_model(
    d,
    data_obj.main_loader,
    device,
    solver=solver,
    lr=lr,
    per_step=per_step,
    tol=optimizer_settings['tol'],
    early_stopping=early_stopping,
    patience=optimizer_settings['patience'],
    epochs=epochs,
    bias=True,
    weight_decay=weight_decay,
    model_name=dataset + '_main_model',
    save_best_model=True,
)

In [None]:
# Orthogonal projection matrix for INLP: identity minus create_P(V_c_INLP)
P_c_orth_INLP = torch.eye(d) - create_P(V_c_INLP)

# Apply the projection to every split
X_train_transformed_INLP = X_train @ P_c_orth_INLP
X_val_transformed_INLP = X_val @ P_c_orth_INLP
X_test_transformed_INLP = X_test @ P_c_orth_INLP

# Point the loaders of the dataset object at the projected features (no sample weights)
balanced_training_main = False
main_weights_train = None
main_weights_val = None
data_obj.reset_X(
    X_train_transformed_INLP,
    X_val_transformed_INLP,
    batch_size=batch_size,
    reset_X_objects=True,
    include_weights=balanced_training_main,
    train_weights=main_weights_train,
    val_weights=main_weights_val,
    only_main=True,
)

# Train the main-task model on the transformed embeddings
set_seed(seed)
ERM_after_INLP = return_linear_model(
    d,
    data_obj.main_loader,
    device,
    solver=solver,
    lr=lr,
    per_step=per_step,
    tol=optimizer_settings['tol'],
    early_stopping=early_stopping,
    patience=optimizer_settings['patience'],
    epochs=epochs,
    bias=True,
    weight_decay=weight_decay,
    model_name=dataset + '_main_model',
    save_best_model=True,
)

In [None]:

# Train the main-task model on the embeddings currently loaded in data_obj
# NOTE(review): this cell issues the same training call as the end of the cell
# above and overwrites ERM_after_INLP; it looks redundant - confirm and consider
# removing it.
set_seed(seed)
ERM_after_INLP = return_linear_model(
    d,
    data_obj.main_loader,
    device,
    solver=solver,
    lr=lr,
    per_step=per_step,
    tol=optimizer_settings['tol'],
    early_stopping=early_stopping,
    patience=optimizer_settings['patience'],
    epochs=epochs,
    bias=True,
    weight_decay=weight_decay,
    model_name=dataset + '_main_model',
    save_best_model=True,
)

In [31]:
# Predict the main-task label on the JSE-projected test set.
# This is inference only, so gradient tracking is switched off.
with torch.no_grad():
    y_m_pred_test = ERM_after_JSE(X_test_transformed)

# get the accuracy of the main model overall after JSE
main_acc_after = get_acc_pytorch_model(y_m_test, y_m_pred_test)

# get the accuracy of the main model per (main label, concept label) group
result_per_group, _ = get_acc_per_group(y_m_pred_test, y_m_test, y_c_test)

print("Overall Accuracy of JSE (test): ", main_acc_after)

print("Accuracy per group of JSE (test): ", result_per_group)

   


Overall Accuracy of JSE (test):  tensor(0.8325)
Accuracy per group of JSE (test):                count      mean
main concept                 
0.0  0.0        509  0.842829
     1.0        467  0.839400
1.0  0.0        522  0.827586
     1.0        502  0.820717


In [32]:
# Predict the main-task label on the INLP-projected test set.
# This is inference only, so gradient tracking is switched off.
with torch.no_grad():
    y_m_pred_test_INLP = ERM_after_INLP(X_test_transformed_INLP)

# get the accuracy of the main model overall after INLP
main_acc_after_INLP = get_acc_pytorch_model(y_m_test, y_m_pred_test_INLP)

# get the accuracy of the main model per (main label, concept label) group
result_per_group_INLP, _ = get_acc_per_group(y_m_pred_test_INLP, y_m_test, y_c_test)

# These results belong to INLP, so label them as such (previously printed as "JSE")
print("Overall Accuracy of INLP (test): ", main_acc_after_INLP)

print("Accuracy per group of INLP (test): ", result_per_group_INLP)

   


Overall Accuracy of JSE (test):  tensor(0.5125)
Accuracy per group of JSE (test):                count      mean
main concept                 
0.0  0.0        509  0.842829
     1.0        467  0.203426
1.0  0.0        522  0.187739
     1.0        502  0.802789


In [33]:
# Predict the main-task label with the ERM baseline on the test set (no projection applied).
# This is inference only, so gradient tracking is switched off.
with torch.no_grad():
    y_m_pred_test_ERM = ERM_model(X_test)

# get the accuracy of the ERM model overall
main_acc_ERM = get_acc_pytorch_model(y_m_test, y_m_pred_test_ERM)

# get the accuracy of the ERM model per (main label, concept label) group
result_per_group_ERM, _ = get_acc_per_group(y_m_pred_test_ERM, y_m_test, y_c_test)

print("Overall Accuracy of ERM (test): ", main_acc_ERM)

print("Accuracy per group of ERM (test): ", result_per_group_ERM)

   


Overall Accuracy of ERM (test):  tensor(0.8170)
Accuracy per group of ERM (test):                count      mean
main concept                 
0.0  0.0        509  0.899804
     1.0        467  0.747323
1.0  0.0        522  0.739464
     1.0        502  0.878486
