# KHSIC approach for disentangling content and style

## Data Prep

In [1]:
import time

In [2]:
%%time

# Module-level settings for this cell.
# NOTE: get_exp_results takes its own `alpha` and `eta` arguments, so the two
# module-level values of the same name below are not read inside that function.
alpha = 1.0
alpha_sk = 0.5 # for creating skewed data used to learn R
eta = 0.95
# Rows per mini-batch when optimizing R (read as a global by get_exp_results).
batch_size = 128
# Number of passes over the shuffled training features when optimizing R
# (read as a global by get_exp_results).
num_epochs = 200
# The first `ns` columns of R are treated as style directions in the objective
# (read as a global by get_exp_results).
ns = 20 #specify number of style features


import numpy as np
from sklearn.linear_model import LogisticRegression
from numpy import linalg as LA
import torch
from numpy import load
import sys, json
from itertools import product
from sklearn import preprocessing

import pandas as pd
import mctorch.nn as mnn
import mctorch.optim as moptim
from hsic_calculator import HSIC, normalized_HSIC




# Function for binarizing labels
def binarize(y):
    """Return an integer array that is 1 where `y` is greater than 5, else 0.

    The input is not modified; a new array is returned.
    """
    above_threshold = np.asarray(y) > 5
    return above_threshold.astype(int)

# Function for creating spurious correlations  
def create_spurious_corr(z, z_t, y_og, spu_corr= 0.1, binarize_label=True):
    """Mix original and transformed features so that style correlates with the label.

    A boolean mask is built as ``binarize(y_og) XOR Bernoulli(spu_corr)``. Rows where
    the mask is True are taken from `z_t`, the remaining rows from `z`. The output
    is ordered with all `z_t` rows first, followed by all `z` rows; labels are
    reordered the same way.

    Returns
    -------
    all_z : the stacked feature rows.
    all_img_labels : binarized labels if `binarize_label` is True, otherwise the
        entries of `y_og`, in the same row order as `all_z`.
    style_labels : int array, 0 for rows taken from `z_t` and 1 for rows taken from `z`.
    """
    y_bin = binarize(y_og)
    coin_flips = np.random.binomial(1, spu_corr, size=len(y_bin))
    take_transformed = np.logical_xor(y_bin, coin_flips)
    take_original = ~take_transformed

    transformed_rows = z_t[take_transformed]
    original_rows = z[take_original]
    all_z = np.concatenate((transformed_rows, original_rows), axis=0)

    style_labels = np.concatenate(
        (np.zeros(len(transformed_rows)), np.ones(len(original_rows))), axis=None
    ).astype(int)

    # Both label variants are split and re-stacked identically; only the source differs.
    label_source = y_bin if binarize_label else y_og
    all_img_labels = np.concatenate(
        (label_source[take_transformed], label_source[take_original]), axis=None
    )

    return all_z, all_img_labels, style_labels
    

# call this function to get experiments results for different parameters    
def get_exp_results(alpha = 1.0, seed=0, lamda=1, extractor='simclr', transf_type='contrasted', 
                    dataset='cifar10', eta=0.95):
    """Run one HSIC-based content/style experiment and return test accuracies.

    Steps: load pre-extracted features from ``./data``, build spuriously correlated
    train/test splits, fit baseline logistic regressions on the raw features, learn
    a matrix R on the Stiefel manifold by minimising an HSIC objective over
    mini-batches, project every feature set with the best R seen, and refit the
    classifiers on the projected features with the first `ns` columns removed.

    Reads the module-level settings `ns`, `num_epochs` and `batch_size`, and writes
    the best R to a file named 'checkpoint' in the working directory.

    Parameters
    ----------
    alpha : float
        Passed as `spu_corr` to create_spurious_corr for the train and
        in-distribution test sets; the OOD test set uses ``1 - alpha``.
    seed : int
        Seed for NumPy's global RNG (only NumPy is seeded here).
    lamda : float
        Not referenced by this HSIC variant; kept so the signature is unchanged.
    extractor, transf_type, dataset : str
        Used to build the feature/label file names under ``./data``.
    eta : float
        R has ``int(d * eta)`` columns, where d is the input feature dimension.

    Returns
    -------
    dict
        Maps descriptive strings to in-distribution / OOD test accuracies for the
        baseline and the HSIC-projected features.

    Raises
    ------
    RuntimeError
        If no mini-batch loss ever falls below the initial threshold, so no R
        was selected.
    """
    np.random.seed(seed)

    def load_features(split, variant):
        # Every feature file follows Z_<split>_<variant>_<dataset>_<extractor>.npy
        return load(f'./data/Z_{split}_{variant}_{dataset}_{extractor}.npy')

    def fit_classifier(features, labels):
        return LogisticRegression(multi_class='multinomial', solver='lbfgs',
                                  random_state=0).fit(features, labels)

    # Load saved image features
    z_train_og = load_features('train', 'og')
    z_train_t = load_features('train', transf_type)

    z_test_og = load_features('test', 'og')
    z_test_t = load_features('test', transf_type)

    y_train_og = load(f'./data/train_labels_{dataset}.npy')
    y_test_og = load(f'./data/test_labels_{dataset}.npy')

    # Create spurious correlations on train and test sets
    # (call order matters: each call draws from the global NumPy RNG)
    z_train, train_labels, _ = create_spurious_corr(z_train_og, z_train_t, y_train_og,
                                                    spu_corr=alpha, binarize_label=False)

    z_test_indist, indist_test_labels, _ = create_spurious_corr(z_test_og, z_test_t, y_test_og,
                                                                spu_corr=alpha, binarize_label=False)

    z_test_ood, ood_test_labels, _ = create_spurious_corr(z_test_og, z_test_t, y_test_og,
                                                          spu_corr=1-alpha, binarize_label=False)

    # Original and transformed test features stacked, with a 0/1 transformation indicator
    z_test_og_t = np.concatenate((z_test_og, z_test_t), axis=0)
    t_test_labels = np.concatenate((np.zeros(len(z_test_og)), np.ones(len(z_test_t))), axis=None)

    # Prediction accuracies on image features extracted using a baseline model
    logistic_regression_on_baseline = fit_classifier(z_train, train_labels)
    baseline_accuracy0 = logistic_regression_on_baseline.score(z_train, train_labels)
    baseline_accuracy1 = logistic_regression_on_baseline.score(z_test_indist, indist_test_labels)
    baseline_accuracy2 = logistic_regression_on_baseline.score(z_test_ood, ood_test_labels)

    # Trained on original baseline features, tested on transformed features - no spurious correlations here
    logistic_regression_on_baseline_og = fit_classifier(z_train_og, y_train_og)
    baseline_og_accuracy0 = logistic_regression_on_baseline_og.score(z_train_og, y_train_og)
    baseline_og_accuracy1 = logistic_regression_on_baseline_og.score(z_test_og, y_test_og)
    baseline_transf_accuracy2 = logistic_regression_on_baseline_og.score(z_test_t, y_test_og)

    # Training features for all four transformations. These are row-aligned with
    # y_train_og below, so they are loaded for `dataset` rather than a fixed name.
    environment_blocks = [z_train_og] + [
        load_features('train', name)
        for name in ('rotated', 'contrasted', 'blurred', 'saturated')
    ]

    # Find R, get post-processed features, and perform predictions
    z_train_og_4_ts = np.concatenate(environment_blocks, axis=0)

    # Environment label: 0 = original, 1..4 = rotated, contrasted, blurred, saturated
    og_4_ts_labels = np.concatenate(
        [np.ones(len(block)) * env for env, block in enumerate(environment_blocks)], axis=None)

    image_labels = np.concatenate([y_train_og] * len(environment_blocks), axis=None)

    # Prepend environment (style) labels in column 0 and class labels in column 1
    og_4_ts_labels_z_train_og_4_ts = np.concatenate(
        (og_4_ts_labels.reshape(-1, 1), image_labels.reshape(-1, 1), z_train_og_4_ts), axis=1)

    # Shuffle rows in place so labels stay attached to their features
    np.random.shuffle(og_4_ts_labels_z_train_og_4_ts)

    shuffled_train_og_t = og_4_ts_labels_z_train_og_4_ts[:, 2:]
    shuffled_t_train_labels = og_4_ts_labels_z_train_og_4_ts[:, :1]

    # class distribution
    labels_and_z_train_df = pd.DataFrame(og_4_ts_labels_z_train_og_4_ts)

    print("class distribution - column 1 is class labels - column 0 is domain/environment labels")
    class_distribution_per_domain = labels_and_z_train_df.groupby([1,0]).count().iloc[:,0:1]
    display(class_distribution_per_domain)

    d = shuffled_train_og_t.shape[1]
    k = int(d*eta) # % of original number of features

    # Initialize R
    R = mnn.Parameter(manifold=mnn.Stiefel(d,k)).float()

    # Objective: reduce dependence between content and style and between content
    # and environment, while increasing dependence between style and environment.
    def obj(z, e, W, n_s=1):
        z = torch.from_numpy(z).float()
        e = torch.from_numpy(e).float()
        style = torch.matmul(z, W[:, :n_s])
        content = torch.matmul(z, W[:, n_s:])
        MI_content_style = normalized_HSIC(style, content)
        MI_content_env = normalized_HSIC(content, e)
        MI_style_env = normalized_HSIC(style, e)
        return (MI_content_style + MI_content_env) - MI_style_env

    # Optimize - passing data in mini-batches
    optimizer = moptim.rAdagrad(params = [R], lr=1e-2)

    best_loss = 1e5
    best_R = None
    for epoch in range(num_epochs):
        for index in range(0, len(shuffled_train_og_t), batch_size):
            train_data_subset = shuffled_train_og_t[index:index+batch_size]
            style_labels_subset = shuffled_t_train_labels[index:index+batch_size]
            loss = obj(train_data_subset, style_labels_subset, R, ns)
            # saving R with the smallest (mini-batch) loss value so far
            if loss < best_loss:
                # detach so the best loss does not keep this batch's autograd graph alive
                best_loss = loss.detach()
                # np.array copies, so later optimizer steps cannot alter the snapshot
                best_R = np.array(R.detach().numpy())
                print("Saving R, at epoch ", epoch)
                checkpoint = {'epoch': epoch, 'loss': loss, 'R': R}
                torch.save(checkpoint, 'checkpoint')
                print("loss: ", loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # Use the in-memory snapshot instead of re-reading 'checkpoint': the file on
    # disk could be left over from an earlier run if no batch ever improved.
    if best_R is None:
        raise RuntimeError("no mini-batch loss fell below the initial threshold; R was never selected")
    R_np = best_R

    # Obtain post-processed features
    f_train_og = z_train_og @ R_np
    f_train = z_train @ R_np
    f_test_indist = z_test_indist @ R_np
    f_test_ood = z_test_ood @ R_np
    f_test_og = z_test_og @ R_np
    f_test_t = z_test_t @ R_np
    f_test_og_t = z_test_og_t @ R_np

    # Correlation Matrix Analysis (diagnostics only; not part of the returned results).
    # NOTE(review): the column indices below (1..4 and 5:) are the same as in the
    # PISCO variant, which has 4 style columns; here there are `ns` - confirm intent.
    special_column = {'rotated': 1, 'contrasted': 2, 'blurred': 3, 'saturated': 4}.get(transf_type)
    if special_column is not None:
        # concatenate transformation labels with the projected / raw test features
        t_column = t_test_labels.reshape(-1, 1)
        t_labels_f_test_og_t = np.concatenate((t_column, f_test_og_t), axis=1)
        t_labels_z_test_og_t = np.concatenate((t_column, z_test_og_t), axis=1)
        corr_matrix = np.corrcoef(t_labels_f_test_og_t.T)
        corr_z_matrix = np.corrcoef(t_labels_z_test_og_t.T)
        corr_special = np.abs(corr_matrix[0, special_column])
        corr_ns_f_norm = np.sqrt((corr_matrix[0,5:]**2).mean())
        z_corr_ns_f_norm = np.sqrt((corr_z_matrix[0,:]**2).mean())

    # Columns [0, ns) of the projection are the style features used in training;
    # the classifiers below use everything after them.
    content = slice(ns, None)

    # Classification task using all post-processed features except style features
    lr_model_hsic_sp = fit_classifier(f_train[:, content], train_labels)
    hsic_sp_accuracy0 = lr_model_hsic_sp.score(f_train[:, content], train_labels)
    hsic_sp_accuracy1 = lr_model_hsic_sp.score(f_test_indist[:, content], indist_test_labels)
    hsic_sp_accuracy2 = lr_model_hsic_sp.score(f_test_ood[:, content], ood_test_labels)

    # Trained on original post-processed features, tested on transformed
    # post-processed features, in both cases without the style features
    lr_model_hsic_no_sp = fit_classifier(f_train_og[:, content], y_train_og)
    hsic_no_sp_accuracy0 = lr_model_hsic_no_sp.score(f_train_og[:, content], y_train_og)
    hsic_no_sp_accuracy1 = lr_model_hsic_no_sp.score(f_test_og[:, content], y_test_og)
    hsic_no_sp_accuracy2 = lr_model_hsic_no_sp.score(f_test_t[:, content], y_test_og)

    # put all the results in a dictionary
    results_log = {}
    results_log['Baseline indist accuracy - spurious corr: '] = baseline_accuracy1
    results_log['HSIC indist accuracy - spurious corr: '] = hsic_sp_accuracy1

    results_log['Baseline ood accuracy - spurious corr: '] = baseline_accuracy2
    results_log['HSIC ood accuracy- spurious corr: '] = hsic_sp_accuracy2

    results_log['Baseline indist accuracy - no spurious corr: '] = baseline_og_accuracy1
    results_log['HSIC indist accuracy - no spurious corr: '] = hsic_no_sp_accuracy1

    results_log['Baseline ood accuracy - no spurious corr: '] = baseline_transf_accuracy2
    results_log['HSIC ood accuracy - no spurious corr: '] = hsic_no_sp_accuracy2

    return results_log



# if __name__ == "__main__":
#     ITERS = range(10)
#     datasets = ['cifar10'] 
#     extractors= ['resnet', 'simclr']  
#     transf_types = ['contrasted', 'rotated', 'blurred', 'saturated']  
#     alphas = [0.5,0.75,0.90,0.95,0.99,1.0] 
#     lamdas= [0,1,10,50]
#     etas = [0.90,0.93,0.95,0.98,1.0]

#     grid = list(product(datasets, extractors, transf_types, alphas, lamdas,etas,ITERS))
    
#     i = int(float(sys.argv[1]))
#     dataset, extractor, transf_type, alpha, lamda, eta, ITER = grid[i]    

#     results_log = get_exp_results(alpha = alpha, seed=int(ITER), lamda=lamda, extractor=extractor, 
#                                   transf_type=transf_type, dataset=dataset, eta=eta)
    
#     with open(f'summary_cifar10/summary_{i}.json', 'w') as fp:
#         json.dump(results_log, fp)


# Single run of the HSIC experiment on ResNet features with the 'rotated' transformation.
# NOTE: lamda=10 is passed, but the get_exp_results defined in this cell never references lamda.
get_exp_results(alpha = 1.0, seed=0, lamda=10, extractor='resnet', transf_type='rotated', dataset='cifar10', eta=0.95)




STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


class distribution - column 1 is class labels - column 0 is domain/environment labels


Unnamed: 0_level_0,Unnamed: 1_level_0,2
1,0,Unnamed: 2_level_1
0.0,0.0,5000
0.0,1.0,5000
0.0,2.0,5000
0.0,3.0,5000
0.0,4.0,5000
1.0,0.0,5000
1.0,1.0,5000
1.0,2.0,5000
1.0,3.0,5000
1.0,4.0,5000


The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1980.)
  q, r = torch.qr(X)


Saving R, at epoch  0
loss:  tensor(0.6990, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(0.4178, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.0847, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.0952, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.2505, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.3239, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.3639, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.3757, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.3883, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.3917, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.4039, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.4575, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.5081, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(-0.5719, grad_fn=<SubBackward0>)
Saving R, at epoch  0
loss:  tensor(

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


CPU times: user 7d 23h 43min 19s, sys: 8h 7min 8s, total: 8d 7h 50min 28s
Wall time: 5h 32min 14s


{'Baseline indist accuracy - spurious corr: ': 0.8952,
 'HSIC indist accuracy - spurious corr: ': 0.8602,
 'Baseline ood accuracy - spurious corr: ': 0.1079,
 'HSIC ood accuracy- spurious corr: ': 0.086,
 'Baseline indist accuracy - no spurious corr: ': 0.8731,
 'HSIC indist accuracy - no spurious corr: ': 0.8162,
 'Baseline ood accuracy - no spurious corr: ': 0.6833,
 'HSIC ood accuracy - no spurious corr: ': 0.4691}

# Comparison with PISCO results on CIFAR-10

In [3]:
import numpy as np
from sklearn.linear_model import LogisticRegression
from numpy import linalg as LA
import torch
from numpy import load
import sys, json
from itertools import product
from sklearn import preprocessing


# Function for binarizing labels
def binarize(y):
    """Map labels to 0/1: 1 for values strictly greater than 5, 0 otherwise.

    Works on a fresh copy, so the caller's `y` is left untouched.
    """
    labels = np.array(y, copy=True)
    return (labels > 5).astype(int)

# Function for creating spurious correlations
def create_spurious_corr(z, z_t, y_og, spu_corr= 0.1, binarize_label=True):
    """Build a feature set in which the transformation correlates with the label.

    The selection mask is ``binarize(y_og) XOR Bernoulli(spu_corr)``: masked rows
    come from `z_t`, the others from `z`. Rows from `z_t` are placed first in the
    output, followed by rows from `z`, and the labels follow the same order.

    Returns
    -------
    all_z : the stacked feature rows.
    all_img_labels : binarized labels when `binarize_label` is True, otherwise
        the matching entries of `y_og`.
    """
    y_bin = binarize(y_og)
    mask = np.logical_xor(y_bin, np.random.binomial(1, spu_corr, size=len(y_bin)))
    rest = ~mask

    all_z = np.concatenate((z_t[mask], z[rest]), axis=0)

    # Pick the label array once; the split/re-stack is the same for both variants.
    if binarize_label:
        chosen_labels = y_bin
    else:
        chosen_labels = y_og
    all_img_labels = np.concatenate((chosen_labels[mask], chosen_labels[rest]), axis=None)

    return all_z, all_img_labels
    

# call this function to get experiments results for different parameters    
def get_exp_results(alpha = 1.0, seed=0, lamda=1, extractor='simclr', transf_type='contrasted', 
                    dataset='cifar10', eta=0.95):
    """Run one PISCO experiment and return test accuracies.

    Steps: load pre-extracted features from ``./data``, build spuriously correlated
    train/test splits, fit baseline logistic regressions on the raw features, build
    a projection P whose first four columns are normalised logistic-regression
    directions separating original from rotated / contrasted / blurred / saturated
    features and whose remaining columns are eigenvectors of M (below), then refit
    the classifiers on the projected features with those four columns removed.

    Parameters
    ----------
    alpha : float
        Passed as `spu_corr` to create_spurious_corr for the train and
        in-distribution test sets; the OOD test set uses ``1 - alpha``.
    seed : int
        Seed for NumPy's global RNG.
    lamda : float
        Weight of the transformation-difference term in M.
    extractor, transf_type, dataset : str
        Used to build the feature/label file names under ``./data``.
    eta : float
        P has ``int(d * eta)`` columns, where d is the input feature dimension.

    Returns
    -------
    dict
        Maps descriptive strings to in-distribution / OOD test accuracies for the
        baseline and the PISCO-projected features.
    """
    np.random.seed(seed)

    def load_features(split, variant):
        # Every feature file follows Z_<split>_<variant>_<dataset>_<extractor>.npy
        return load(f'./data/Z_{split}_{variant}_{dataset}_{extractor}.npy')

    def fit_classifier(features, labels):
        return LogisticRegression(multi_class='multinomial', solver='lbfgs',
                                  random_state=0).fit(features, labels)

    # Load saved image features
    z_train_og = load_features('train', 'og')
    z_train_t = load_features('train', transf_type)

    z_test_og = load_features('test', 'og')
    z_test_t = load_features('test', transf_type)

    y_train_og = load(f'./data/train_labels_{dataset}.npy')
    y_test_og = load(f'./data/test_labels_{dataset}.npy')

    # Create spurious correlations on train and test sets
    # (call order matters: each call draws from the global NumPy RNG)
    z_train, train_labels = create_spurious_corr(z_train_og, z_train_t, y_train_og,
                                                 spu_corr=alpha, binarize_label=False)

    z_test_indist, indist_test_labels = create_spurious_corr(z_test_og, z_test_t, y_test_og,
                                                             spu_corr=alpha, binarize_label=False)

    z_test_ood, ood_test_labels = create_spurious_corr(z_test_og, z_test_t, y_test_og,
                                                       spu_corr=1-alpha, binarize_label=False)

    # Original and transformed test features stacked, with a 0/1 transformation indicator
    z_test_og_t = np.concatenate((z_test_og, z_test_t), axis=0)
    t_test_labels = np.concatenate((np.zeros(len(z_test_og)), np.ones(len(z_test_t))), axis=None)

    # Prediction accuracies on image features extracted using a baseline model
    logistic_regression_on_baseline = fit_classifier(z_train, train_labels)
    baseline_accuracy0 = logistic_regression_on_baseline.score(z_train, train_labels)
    baseline_accuracy1 = logistic_regression_on_baseline.score(z_test_indist, indist_test_labels)
    baseline_accuracy2 = logistic_regression_on_baseline.score(z_test_ood, ood_test_labels)

    # Trained on original baseline features, tested on transformed features - no spurious correlations here
    logistic_regression_on_baseline_og = fit_classifier(z_train_og, y_train_og)
    baseline_og_accuracy0 = logistic_regression_on_baseline_og.score(z_train_og, y_train_og)
    baseline_og_accuracy1 = logistic_regression_on_baseline_og.score(z_test_og, y_test_og)
    baseline_transf_accuracy2 = logistic_regression_on_baseline_og.score(z_test_t, y_test_og)

    # Training features for each transformation. They are differenced row-by-row
    # against z_train_og below, so they are loaded for `dataset` rather than a fixed name.
    transformations = ('rotated', 'contrasted', 'blurred', 'saturated')
    transformed_train = [load_features('train', name) for name in transformations]

    # One unit-norm direction per transformation: coefficients of a binary
    # logistic regression separating original (0) from transformed (1) features.
    thetas = []
    for z_transformed in transformed_train:
        pair_features = np.concatenate((z_train_og, z_transformed), axis=0)
        pair_labels = np.concatenate((np.zeros(len(z_train_og)), np.ones(len(z_transformed))), axis=None)
        coefficients = LogisticRegression(random_state=0).fit(pair_features, pair_labels).coef_.reshape(-1,1)
        thetas.append(coefficients / np.linalg.norm(coefficients))
    n_style = len(thetas)

    # Find P, get post-processed features, and perform predictions
    combined_delta_z_matrix = np.concatenate(
        [z_train_og - z_transformed for z_transformed in transformed_train], axis=0)

    z_train_og_4_ts = np.concatenate([z_train_og] + transformed_train, axis=0)

    k = int(z_train_og_4_ts.shape[1]*eta) # % of original number of features
    n = z_train_og_4_ts.shape[0]
    n_delt = combined_delta_z_matrix.shape[0]

    M = - z_train_og_4_ts.T @ z_train_og_4_ts/n + lamda * combined_delta_z_matrix.T @ combined_delta_z_matrix /n_delt

    # Eigendecomposition of the symmetric matrix M; eigh returns eigenvalues in
    # ascending order, so the leading columns belong to the smallest eigenvalues.
    _, eigenvectors = LA.eigh(M)

    P_2 = eigenvectors[:, :(k - n_style)]

    P = np.concatenate(thetas + [P_2], axis=1)

    # Obtain post-processed features
    f_train_og = z_train_og @ P
    f_train = z_train @ P
    f_test_indist = z_test_indist @ P
    f_test_ood = z_test_ood @ P
    f_test_og = z_test_og @ P
    f_test_t = z_test_t @ P
    f_test_og_t = z_test_og_t @ P

    # Correlation Matrix Analysis (diagnostics only; not part of the returned results).
    # Column 0 is the transformation indicator; columns 1..4 are the projections
    # onto theta_1..theta_4, in the same order as `transformations`.
    if transf_type in transformations:
        special_column = 1 + transformations.index(transf_type)
        t_column = t_test_labels.reshape(-1, 1)
        t_labels_f_test_og_t = np.concatenate((t_column, f_test_og_t), axis=1)
        t_labels_z_test_og_t = np.concatenate((t_column, z_test_og_t), axis=1)
        corr_matrix = np.corrcoef(t_labels_f_test_og_t.T)
        corr_z_matrix = np.corrcoef(t_labels_z_test_og_t.T)
        corr_special = np.abs(corr_matrix[0, special_column])
        corr_ns_f_norm = np.sqrt((corr_matrix[0, 1 + n_style:]**2).mean())
        z_corr_ns_f_norm = np.sqrt((corr_z_matrix[0,:]**2).mean())

    # The first n_style columns of the projection are the style directions;
    # the classifiers below use everything after them.
    content = slice(n_style, None)

    # Classification task using all post-processed features except style features
    lr_model_pisco_sp = fit_classifier(f_train[:, content], train_labels)
    pisco_sp_accuracy0 = lr_model_pisco_sp.score(f_train[:, content], train_labels)
    pisco_sp_accuracy1 = lr_model_pisco_sp.score(f_test_indist[:, content], indist_test_labels)
    pisco_sp_accuracy2 = lr_model_pisco_sp.score(f_test_ood[:, content], ood_test_labels)

    # Trained on original post-processed features, tested on transformed
    # post-processed features, in both cases without the style features
    lr_model_pisco_no_sp = fit_classifier(f_train_og[:, content], y_train_og)
    pisco_no_sp_accuracy0 = lr_model_pisco_no_sp.score(f_train_og[:, content], y_train_og)
    pisco_no_sp_accuracy1 = lr_model_pisco_no_sp.score(f_test_og[:, content], y_test_og)
    pisco_no_sp_accuracy2 = lr_model_pisco_no_sp.score(f_test_t[:, content], y_test_og)

    # put all the results in a dictionary
    results_log = {}
    results_log['Baseline indist accuracy - spurious corr: '] = baseline_accuracy1
    results_log['PISCO indist accuracy - spurious corr: '] = pisco_sp_accuracy1

    results_log['Baseline ood accuracy - spurious corr: '] = baseline_accuracy2
    results_log['PISCO ood accuracy- spurious corr: '] = pisco_sp_accuracy2

    results_log['Baseline indist accuracy - no spurious corr: '] = baseline_og_accuracy1
    results_log['PISCO indist accuracy - no spurious corr: '] = pisco_no_sp_accuracy1

    results_log['Baseline ood accuracy - no spurious corr: '] = baseline_transf_accuracy2
    results_log['PISCO ood accuracy - no spurious corr: '] = pisco_no_sp_accuracy2

    return results_log



# if __name__ == "__main__":
#     ITERS = range(10)
#     datasets = ['cifar10'] 
#     extractors= ['resnet', 'simclr']  
#     transf_types = ['contrasted', 'rotated', 'blurred', 'saturated']  
#     alphas = [0.5,0.75,0.90,0.95,0.99,1.0] 
#     lamdas= [0,1,10,50]
#     etas = [0.90,0.93,0.95,0.98,1.0]

#     grid = list(product(datasets, extractors, transf_types, alphas, lamdas,etas,ITERS))
    
#     i = int(float(sys.argv[1]))
#     dataset, extractor, transf_type, alpha, lamda, eta, ITER = grid[i]    

#     results_log = get_exp_results(alpha = alpha, seed=int(ITER), lamda=lamda, extractor=extractor, 
#                                   transf_type=transf_type, dataset=dataset, eta=eta)
    
#     with open(f'summary_cifar10/summary_{i}.json', 'w') as fp:
#         json.dump(results_log, fp)


# Single run of the PISCO experiment on ResNet features with the 'rotated' transformation.
# This calls the get_exp_results defined in this cell, which replaces the earlier HSIC definition.
get_exp_results(alpha = 1.0, seed=0, lamda=10, extractor='resnet', transf_type='rotated', dataset='cifar10', eta=0.95)



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

{'Baseline indist accuracy - spurious corr: ': 0.8952,
 'PISCO indist accuracy - spurious corr: ': 0.8368,
 'Baseline ood accuracy - spurious corr: ': 0.1079,
 'PISCO ood accuracy- spurious corr: ': 0.6547,
 'Baseline indist accuracy - no spurious corr: ': 0.8731,
 'PISCO indist accuracy - no spurious corr: ': 0.8433,
 'Baseline ood accuracy - no spurious corr: ': 0.6833,
 'PISCO ood accuracy - no spurious corr: ': 0.7299}