# KHSIC approach for disentangling content and style

## Data Prep

In [1]:
import time

In [2]:
%%time

alpha = 1.0
alpha_sk = 0.9 # for creating skewed data used to learn R
eta = 0.95
batch_size = 128


import numpy as np
from sklearn.linear_model import LogisticRegression
import torch
from numpy import load
import sys, json
from itertools import product
from sklearn import preprocessing


# Function for binarizing labels
def binarize(y):    
    y = np.copy(y) > 5
    return y.astype(int)

# Function for creating spurious correlations on 
def create_spurious_corr(z, z_t, y_og, spu_corr= 0.1, binarize_label=True):
    y_bin = binarize(y_og)
    mod_labels = np.logical_xor(y_bin, np.random.binomial(1, spu_corr, size=len(y_bin)))
    
    modified_images = z_t[mod_labels]
    unmodified_images = z[~mod_labels]
    all_z = np.concatenate((modified_images, unmodified_images), axis=0)
    
    all_img_labels = None
    
    if binarize_label:
        modified_imgs_labels = y_bin[mod_labels]
        unmodified_imgs_labels = y_bin[~mod_labels]
        all_img_labels = np.concatenate((modified_imgs_labels, unmodified_imgs_labels), axis=None)
    else:
        modified_imgs_labels = y_og[mod_labels]
        unmodified_imgs_labels = y_og[~mod_labels]
        all_img_labels = np.concatenate((modified_imgs_labels, unmodified_imgs_labels), axis=None)    
        
    return all_z, all_img_labels, mod_labels.astype(int)


# Load saved image features

z_train_og = load('./data/Z_train_og_mnist_mlp.npy')
z_train_t = load('./data/Z_train_green_mnist_mlp.npy')

z_test_og = load('./data/Z_test_og_mnist_mlp.npy')
z_test_t = load('./data/Z_test_green_mnist_mlp.npy')

y_train_og_ = load('./data/train_labels_mnist.npy')

y_test_og_ = load('./data/test_labels_mnist.npy')


# Create spurious correlations on train and test sets

z_train_sk, train_labels_sk, t_labels_sk = create_spurious_corr(z_train_og, z_train_t, y_train_og_, 
                                         spu_corr= alpha_sk, binarize_label=True)

z_train, train_labels, _ = create_spurious_corr(z_train_og, z_train_t, y_train_og_, 
                                         spu_corr= alpha, binarize_label=True)

z_test_indist, indist_test_labels, _ = create_spurious_corr(z_test_og, z_test_t, y_test_og_, 
                                                         spu_corr= alpha, binarize_label=True)

z_test_ood, ood_test_labels, _ = create_spurious_corr(z_test_og, z_test_t, y_test_og_, 
                                                         spu_corr= 1-alpha, binarize_label=True)

# binarize train and test labels
y_train_og = binarize(y_train_og_)
y_test_og = binarize(y_test_og_)

# concatenate original and colored features
z_train_og_t = np.concatenate((z_train_og, z_train_t), axis=0)
t_train_labels = np.concatenate((np.zeros(len(z_train_og)), np.ones(len(z_train_t))), axis=None) 
z_test_og_t = np.concatenate((z_test_og, z_test_t), axis=0)
t_test_labels = np.concatenate((np.zeros(len(z_test_og)), np.ones(len(z_test_t))), axis=None) 


# # concatenate features with sytle labels..style labels are in column 0
# t_labels_z_train_og_t = np.concatenate((t_train_labels.reshape(-1,1), z_train_og_t), axis=1)

# # shuffle data in t_labels_z_train_og_t
# np.random.shuffle(t_labels_z_train_og_t)

# shuffled_train_og_t = t_labels_z_train_og_t[:,1:]
# shuffled_t_train_labels = t_labels_z_train_og_t[:,:1]


CPU times: user 6.34 s, sys: 8.89 s, total: 15.2 s
Wall time: 2.46 s


## MNIST data class distribustions per domain

In [3]:
import pandas as pd

# class distribution in original and colored images - number of original and colored images is the same
img_labels = np.concatenate((y_train_og, y_train_og), axis=None).reshape(-1,1)
style_and_img_labels_z_train_og_t_df = pd.DataFrame(np.concatenate((t_train_labels.reshape(-1,1),img_labels, 
                                                                 z_train_og_t), axis=1))

class_distribution_per_domain = style_and_img_labels_z_train_og_t_df.groupby([0,1]).count().iloc[:,0:1]

print("When number of original and colored images is the same")
display(class_distribution_per_domain)


# class distribution in original and colored images - class distribution is skewed
style_and_img_labels_z_train_sk_df = pd.DataFrame(np.concatenate((t_labels_sk.reshape(-1,1),
                                                                  train_labels_sk.reshape(-1,1),z_train_sk), axis=1))

print("When class distribution is skewed")
class_distribution_per_domain = style_and_img_labels_z_train_sk_df.groupby([0,1]).count().iloc[:,0:1]
display(class_distribution_per_domain)


# shuffle data in style_and_img_labels_z_train_sk_df
style_and_img_labels_z_train_sk = style_and_img_labels_z_train_sk_df.to_numpy()
np.random.shuffle(style_and_img_labels_z_train_sk)

shuffled_train_og_t = style_and_img_labels_z_train_sk[:,2:]
shuffled_t_train_labels = style_and_img_labels_z_train_sk[:,:1]




When number of original and colored images is the same


Unnamed: 0_level_0,Unnamed: 1_level_0,2
0,1,Unnamed: 2_level_1
0.0,0.0,36017
0.0,1.0,23983
1.0,0.0,36017
1.0,1.0,23983


When class distribution is skewed


Unnamed: 0_level_0,Unnamed: 1_level_0,2
0,1,Unnamed: 2_level_1
0.0,0.0,15187
0.0,1.0,10056
1.0,0.0,20830
1.0,1.0,13927


# Find rotation matrix R by optimization----using KHSIC loss

In [4]:
%%time

import torch
import mctorch.nn as mnn
import mctorch.optim as moptim
from hsic_calculator import HSIC


# # Reduce the samples size
# shuffled_train_og_t = shuffled_train_og_t[:10000]
# shuffled_t_train_labels = shuffled_t_train_labels[:10000]

dtype = torch.FloatTensor
n = shuffled_train_og_t.shape[0]
d = shuffled_train_og_t.shape[1]
k = int(shuffled_train_og_t.shape[1]*eta) # % of original number of features
ns = 1 #specify number of style features

# Initialize R
R = mnn.Parameter(manifold=mnn.Stiefel(d,k)).float()

# print("Initial R")
# display(R)

# Define Objective function 
def obj(z, e, W, n_s=1):
    z = torch.from_numpy(z).float()
    e = torch.from_numpy(e).float()
    MI_content_style = HSIC(torch.matmul(z, W[:,:n_s]), torch.matmul(z, W[:,n_s:]))
    MI_conten_env = HSIC(torch.matmul(z,W[:,n_s:]), e)
    MI_style_env = HSIC(torch.matmul(z,W[:,:n_s]), e)
    loss = (MI_content_style + MI_conten_env) - MI_style_env
    return loss

# Optimize - passing data in mini-batches
optimizer = moptim.rAdagrad(params = [R], lr=1e-2)

best_loss = 1e5
checkpoint = {}
for epoch in range(200):
    for index in range(0, len(shuffled_train_og_t), batch_size):
        train_data_subset = shuffled_train_og_t[index:index+batch_size]
        style_labels_subset = shuffled_t_train_labels[index:index+batch_size]
        loss = obj(train_data_subset, style_labels_subset, R, ns)        
        # saving R with the smallest loss value so far
        if loss < best_loss:
            best_loss = loss
            print("Saving R, at epoch ", epoch)
            checkpoint = {'epoch': epoch, 'loss': loss, 'R': R}
            torch.save(checkpoint, 'checkpoint') 
            print("loss: ", loss)            
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    
# print("R after optimization")
# display(R)
# (R.T)@R


checkpoint

Saving R, at epoch  0
loss:  tensor(0.2303, grad_fn=<SubBackward0>)


The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1980.)
  q, r = torch.qr(X)


Saving R, at epoch  0
loss:  tensor(0.2286, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.2202, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1802, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1720, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1703, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1536, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1498, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1245, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1175, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1127, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.1115, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.0906, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.0731, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.0681, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.0543, grad

{'epoch': 1,
 'loss': tensor(-0.0962, grad_fn=<SubBackward0>),
 'R': Parameter containing:
 tensor([[-0.1453,  0.0065,  0.1592,  ..., -0.0611,  0.0403, -0.1702],
         [ 0.2068, -0.0730,  0.0924,  ...,  0.4574,  0.0249,  0.0738],
         [ 0.0011,  0.2070, -0.0631,  ..., -0.0331,  0.1184,  0.0229],
         ...,
         [-0.2019,  0.0185, -0.0606,  ...,  0.0669, -0.0993,  0.2858],
         [ 0.0921, -0.0596,  0.0828,  ...,  0.1072, -0.2401, -0.1000],
         [-0.0202, -0.0783, -0.0553,  ..., -0.1076, -0.1725,  0.1708]],
        requires_grad=True)}

## Use the obtained rotation matrix R to disentangle content and style for OOD generalization

In [5]:
%%time

# Load saved R
R_mat = torch.load('checkpoint')['R']

# Obtain post-processed features
f_train_og = z_train_og @ R_mat.detach().numpy()  
f_train = z_train @ R_mat.detach().numpy()
f_test_indist = z_test_indist @ R_mat.detach().numpy()
f_test_ood = z_test_ood @ R_mat.detach().numpy()
f_test_og = z_test_og @ R_mat.detach().numpy()
f_test_t = z_test_t @ R_mat.detach().numpy()
f_test_og_t = z_test_og_t @ R_mat.detach().numpy()



########### Baseline results ###############
# Prediction Accuracies on image features extracted using a baseline model (mlp)
logistic_regression_on_baseline = LogisticRegression(multi_class='multinomial', solver='lbfgs', 
                                              random_state=0).fit(z_train,train_labels)                                                                                     
baseline_accuracy0 = logistic_regression_on_baseline.score(z_train, train_labels)
baseline_accuracy1 = logistic_regression_on_baseline.score(z_test_indist, indist_test_labels)
baseline_accuracy2 = logistic_regression_on_baseline.score(z_test_ood, ood_test_labels)   

# Trained on original baseline features, tested on colored features - no spurious correlations here
logistic_regression_on_baseline_og = LogisticRegression(multi_class='multinomial', solver='lbfgs',
                                                 random_state=0).fit(z_train_og,y_train_og)                                                                                     
baseline_og_accuracy0 = logistic_regression_on_baseline_og.score(z_train_og, y_train_og)
baseline_og_accuracy1 = logistic_regression_on_baseline_og.score(z_test_og, y_test_og)
baseline_transf_accuracy2 = logistic_regression_on_baseline_og.score(z_test_t, y_test_og)
####################################

# Classification task using all post-processed features except style features   
lr_model_new_HSIC_sp = LogisticRegression(multi_class='multinomial', solver='lbfgs',
                                    random_state=0).fit(f_train[:,1:],train_labels)
new_HSIC_sp_accuracy0 = lr_model_new_HSIC_sp.score(f_train[:,1:], train_labels)
new_HSIC_sp_accuracy1 = lr_model_new_HSIC_sp.score(f_test_indist[:,1:], indist_test_labels)
new_HSIC_sp_accuracy2 = lr_model_new_HSIC_sp.score(f_test_ood[:,1:], ood_test_labels)

# trained on original post-processed features, tested on transformed post-processed 
# features without style features   
lr_model_new_HSIC_no_sp = LogisticRegression(multi_class='multinomial', solver='lbfgs', 
                                    random_state=0).fit(f_train_og[:,1:],y_train_og)
new_HSIC_no_sp_accuracy0 = lr_model_new_HSIC_no_sp.score(f_train_og[:,1:], y_train_og)
new_HSIC_no_sp_accuracy1 = lr_model_new_HSIC_no_sp.score(f_test_og[:,1:], y_test_og)
new_HSIC_no_sp_accuracy2 = lr_model_new_HSIC_no_sp.score(f_test_t[:,1:], y_test_og)

# put all the results in a dictionary
results_log = {}

results_log['Baseline indist accuracy - spurious corr: '] = baseline_accuracy1
results_log['HSIC Approach indist accuracy - spurious corr: '] = new_HSIC_sp_accuracy1

results_log['Baseline ood accuracy - spurious corr: '] = baseline_accuracy2 
results_log['HSIC Approach ood accuracy- spurious corr: '] = new_HSIC_sp_accuracy2    

results_log['Baseline indist accuracy - no spurious corr: '] = baseline_og_accuracy1
results_log['HSIC Approach indist accuracy - no spurious corr: '] = new_HSIC_no_sp_accuracy1

results_log['Baseline ood accuracy - no spurious corr: '] = baseline_transf_accuracy2            
results_log['HSIC Approach ood accuracy - no spurious corr: '] = new_HSIC_no_sp_accuracy2   


results_log
 


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


CPU times: user 2min 54s, sys: 4min 35s, total: 7min 30s
Wall time: 7.05 s


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


{'Baseline indist accuracy - spurious corr: ': 1.0,
 'HSIC Approach indist accuracy - spurious corr: ': 0.8585,
 'Baseline ood accuracy - spurious corr: ': 0.0,
 'HSIC Approach ood accuracy- spurious corr: ': 0.3368,
 'Baseline indist accuracy - no spurious corr: ': 0.9104,
 'HSIC Approach indist accuracy - no spurious corr: ': 0.9081,
 'Baseline ood accuracy - no spurious corr: ': 0.7876,
 'HSIC Approach ood accuracy - no spurious corr: ': 0.5839}

<br>

<br>

# Compare with PISCO Results on MNIST 

In [6]:
%%time

import numpy as np
from sklearn.linear_model import LogisticRegression
from numpy import linalg as LA
import torch
from numpy import load
import sys, json
from itertools import product
from sklearn import preprocessing


# Function for binarizing labels
def binarize(y):    
    y = np.copy(y) > 5
    return y.astype(int)

# Function for creating spurious correlations
def create_spurious_corr(z, z_t, y_og, spu_corr= 0.1, binarize_label=True):
    y_bin = binarize(y_og)
    mod_labels = np.logical_xor(y_bin, np.random.binomial(1, spu_corr, size=len(y_bin)))
    
    modified_images = z_t[mod_labels]
    unmodified_images = z[~mod_labels]
    all_z = np.concatenate((modified_images, unmodified_images), axis=0)
    
    all_img_labels = None
    
    if binarize_label:
        modified_imgs_labels = y_bin[mod_labels]
        unmodified_imgs_labels = y_bin[~mod_labels]
        all_img_labels = np.concatenate((modified_imgs_labels, unmodified_imgs_labels), axis=None)
    else:
        modified_imgs_labels = y_og[mod_labels]
        unmodified_imgs_labels = y_og[~mod_labels]
        all_img_labels = np.concatenate((modified_imgs_labels, unmodified_imgs_labels), axis=None)    
        
    return all_z, all_img_labels 
    


# call this function to get experiments results for different parameters
def get_exp_results(alpha = 1.0, seed=0, lamda=1, extractor='mlp', transf_type='colored', 
                    dataset='mnist', eta=0.95):
    
    np.random.seed(seed)
    
    # Load saved image features
    z_train_og = load('./data/Z_train_og_mnist_mlp.npy')
    z_train_t = load('./data/Z_train_green_mnist_mlp.npy')

    z_test_og = load('./data/Z_test_og_mnist_mlp.npy')
    z_test_t = load('./data/Z_test_green_mnist_mlp.npy')

    y_train_og_ = load('./data/train_labels_mnist.npy')

    y_test_og_ = load('./data/test_labels_mnist.npy')
    
    
    # Create spurious correlations on train and test sets
    z_train, train_labels = create_spurious_corr(z_train_og, z_train_t, y_train_og_, 
                                             spu_corr= alpha, binarize_label=True)

    z_test_indist, indist_test_labels = create_spurious_corr(z_test_og, z_test_t, y_test_og_, 
                                                             spu_corr= alpha, binarize_label=True)

    z_test_ood, ood_test_labels = create_spurious_corr(z_test_og, z_test_t, y_test_og_, 
                                                             spu_corr= 1-alpha, binarize_label=True)
    
    # binarize train and test labels
    y_train_og = binarize(y_train_og_)
    y_test_og = binarize(y_test_og_)

    # concatenate original and colored features
    z_train_og_t = np.concatenate((z_train_og, z_train_t), axis=0)
    t_train_labels = np.concatenate((np.zeros(len(z_train_og)), np.ones(len(z_train_t))), axis=None) 
    z_test_og_t = np.concatenate((z_test_og, z_test_t), axis=0)
    t_test_labels = np.concatenate((np.zeros(len(z_test_og)), np.ones(len(z_test_t))), axis=None) 
    

    # Obtain prediction coefficients of color
    lr_model_t = LogisticRegression(random_state=0).fit(z_train_og_t, t_train_labels)
    t_coefficients = lr_model_t.coef_.reshape(-1,1)
    theta_1 = t_coefficients / np.linalg.norm(t_coefficients)

    # Prediction Accuracies on image features extracted using a baseline model (mlp)
    logistic_regression_on_baseline = LogisticRegression(multi_class='multinomial', solver='lbfgs', 
                                                  random_state=0).fit(z_train,train_labels)                                                                                     
    baseline_accuracy0 = logistic_regression_on_baseline.score(z_train, train_labels)
    baseline_accuracy1 = logistic_regression_on_baseline.score(z_test_indist, indist_test_labels)
    baseline_accuracy2 = logistic_regression_on_baseline.score(z_test_ood, ood_test_labels)   
    
    # Trained on original baseline features, tested on colored features - no spurious correlations here
    logistic_regression_on_baseline_og = LogisticRegression(multi_class='multinomial', solver='lbfgs',
                                                     random_state=0).fit(z_train_og,y_train_og)                                                                                     
    baseline_og_accuracy0 = logistic_regression_on_baseline_og.score(z_train_og, y_train_og)
    baseline_og_accuracy1 = logistic_regression_on_baseline_og.score(z_test_og, y_test_og)
    baseline_transf_accuracy2 = logistic_regression_on_baseline_og.score(z_test_t, y_test_og)

    # Find P, get post-processed features, and perform predictions
    k = int(z_train_og_t.shape[1]*eta) # % of original number of features
    n = z_train_og_t.shape[0]

    delta_z_matrix = z_train_og - z_train_t 

    M = - z_train_og_t.T @ z_train_og_t/n + lamda * delta_z_matrix.T @ delta_z_matrix / (n // 2 ) 

    # Performing SVD to get eigenvectors and eigenvalues
    eigenvalues, eigenvectors = LA.eigh(M)

    # Forming P from eigenvectors and coeficients of color
    P_1 = theta_1

    P_2 = eigenvectors[:,:(k-1)]

    P = np.concatenate((P_1,P_2), axis=1)
    

    # Obtain post-processed features
    f_train_og = z_train_og @ P  
    f_train = z_train @ P
    f_test_indist = z_test_indist @ P
    f_test_ood = z_test_ood @ P
    f_test_og = z_test_og @ P
    f_test_t = z_test_t @ P
    f_test_og_t = z_test_og_t @ P
    
    

    # Correlation Matrix Analysis
    # concatenate transformation labels with f_test_og_t
    t_labels_f_test_og_t = np.concatenate((t_test_labels.reshape(-1,1), f_test_og_t), axis=1)
    t_labels_z_test_og_t = np.concatenate((t_test_labels.reshape(-1,1), z_test_og_t), axis=1)
    corr_matrix = np.corrcoef(t_labels_f_test_og_t.T)
    corr_z_matrix = np.corrcoef(t_labels_z_test_og_t.T)
    corr_special = np.abs(corr_matrix[0,1])
    corr_ns_f_norm = np.sqrt((corr_matrix[0,5:]**2).mean()) 
    z_corr_ns_f_norm = np.sqrt((corr_z_matrix[0,:]**2).mean()) 
     
     
    

    # Classification task using all post-processed features except style features   
    lr_model_pisco_sp = LogisticRegression(multi_class='multinomial', solver='lbfgs',
                                        random_state=0).fit(f_train[:,1:],train_labels)
    pisco_sp_accuracy0 = lr_model_pisco_sp.score(f_train[:,1:], train_labels)
    pisco_sp_accuracy1 = lr_model_pisco_sp.score(f_test_indist[:,1:], indist_test_labels)
    pisco_sp_accuracy2 = lr_model_pisco_sp.score(f_test_ood[:,1:], ood_test_labels)
    
    # trained on original post-processed features, tested on transformed post-processed 
    # features without style features   
    lr_model_pisco_no_sp = LogisticRegression(multi_class='multinomial', solver='lbfgs', 
                                        random_state=0).fit(f_train_og[:,1:],y_train_og)
    pisco_no_sp_accuracy0 = lr_model_pisco_no_sp.score(f_train_og[:,1:], y_train_og)
    pisco_no_sp_accuracy1 = lr_model_pisco_no_sp.score(f_test_og[:,1:], y_test_og)
    pisco_no_sp_accuracy2 = lr_model_pisco_no_sp.score(f_test_t[:,1:], y_test_og)
    
    # put all the results in a dictionary
    results_log = {}
    results_log['Baseline indist accuracy - spurious corr: '] = baseline_accuracy1
    results_log['PISCO indist accuracy - spurious corr: '] = pisco_sp_accuracy1

    results_log['Baseline ood accuracy - spurious corr: '] = baseline_accuracy2 
    results_log['PISCO ood accuracy- spurious corr: '] = pisco_sp_accuracy2    

    results_log['Baseline indist accuracy - no spurious corr: '] = baseline_og_accuracy1
    results_log['PISCO indist accuracy - no spurious corr: '] = pisco_no_sp_accuracy1

    results_log['Baseline ood accuracy - no spurious corr: '] = baseline_transf_accuracy2            
    results_log['PISCO ood accuracy - no spurious corr: '] = pisco_no_sp_accuracy2   

    
    return results_log



# if __name__ == "__main__":
#     ITERS = range(10)
#     datasets = ['mnist']
#     extractors= ['mlp']
#     transf_types = ['colored']
#     alphas = [0.5, 0.75, 0.90, 0.95, 0.99,1.0]
#     lamdas= [0,1,10,50] 
#     etas = [0.90]

#     grid = list(product(datasets, extractors, transf_types, alphas, lamdas,etas,ITERS))
    
#     i = int(float(sys.argv[1]))
#     dataset, extractor, transf_type, alpha, lamda, eta, ITER = grid[i]
    

#     results_log = get_exp_results(alpha = alpha, seed=ITER, lamda=lamda, extractor=extractor, 
#                                   transf_type=transf_type, dataset=dataset, eta=eta)

#     with open(f'summary_mnist/summary_{i}.json', 'w') as fp:
#         json.dump(results_log, fp)



get_exp_results(alpha = 1.0, seed=0, lamda=1, extractor='mlp', transf_type='colored', dataset='mnist', eta=0.95) 
        

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


CPU times: user 3min, sys: 4min 33s, total: 7min 34s
Wall time: 7.15 s
Compiler : 140 ms
Parser   : 228 ms


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


{'Baseline indist accuracy - spurious corr: ': 1.0,
 'PISCO indist accuracy - spurious corr: ': 0.9357,
 'Baseline ood accuracy - spurious corr: ': 0.0,
 'PISCO ood accuracy- spurious corr: ': 0.8022,
 'Baseline indist accuracy - no spurious corr: ': 0.9104,
 'PISCO indist accuracy - no spurious corr: ': 0.9099,
 'Baseline ood accuracy - no spurious corr: ': 0.7876,
 'PISCO ood accuracy - no spurious corr: ': 0.8389}