In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os, pickle
import sys
sys.path.append("../")

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
## Function to obtain stats table
_STATS_COLUMNS = ['acc_h1', 'acc_h2', 'dacc', '1-NFR', 'nsamples']
_SPLITSETS = ['train', 'validation', 'test']
_PRED_PREFIX = 'ypred_'


def _prediction_columns(pd_eval):
    """Return ``(columns, class_ids)`` for the ``ypred_<k>`` score columns, sorted by k.

    The class set is read from the score columns themselves rather than from the
    labels that happen to occur in ``ygt``: if some class never appears as ground
    truth, counting unique labels selects the wrong columns and the argmax index
    no longer corresponds to the class id.
    """
    cols = [col for col in pd_eval.columns
            if isinstance(col, str) and col.startswith(_PRED_PREFIX)
            and col[len(_PRED_PREFIX):].isdigit()]
    if not cols:
        raise KeyError("eval file has no 'ypred_<k>' prediction columns")
    cols.sort(key=lambda col: int(col[len(_PRED_PREFIX):]))
    class_ids = np.array([int(col[len(_PRED_PREFIX):]) for col in cols])
    return cols, class_ids


def _split_stats_row(acc, prevmodel_acc):
    """Summarise one subset of samples as ``[acc_h1, acc_h2, dacc, 1-NFR, nsamples]``.

    acc           -- per-sample correctness of the new model (h2), boolean array.
    prevmodel_acc -- per-sample correctness of the previous model (h1), as stored in
                     the eval file's ``prevmodel_acc`` column (assumed 0/1).
    NFR is the fraction of samples h1 got right and h2 got wrong. An empty subset
    yields NaN means (numpy warns "Mean of empty slice").
    """
    acc_h1 = np.mean(prevmodel_acc)
    acc_h2 = np.mean(acc)
    negative_flip = (acc == 0) & (prevmodel_acc == 1)
    return [acc_h1, acc_h2, acc_h2 - acc_h1, 1 - np.mean(negative_flip), acc.shape[0]]


def get_stats(basedir, model_name_prefix, isave=1, split_list=(1,),
              seed_list=(42,), c_list=(0, 1, 2, 3)):
    """Build a per-group accuracy / negative-flip stats table for a family of runs.

    For every ``c`` in ``c_list``, ``split`` in ``split_list`` and ``seed`` in
    ``seed_list`` the file
    ``<basedir>/<model_name_prefix><c>_seed<seed>_split<split>/eval<isave>.csv``
    is read; missing files are reported with ``print`` and skipped.

    Each eval CSV must provide ``ypred_<k>`` score columns, ``ygt`` (true label),
    ``prevmodel_acc``, ``dataset`` (split name) and ``group_tag``.

    Parameters
    ----------
    basedir : str
        Directory containing the model run directories (trailing '/' optional).
    model_name_prefix : str
        Run-name prefix; ``c`` is appended directly to it.
    isave : int
        Selects ``eval<isave>.csv``.
    split_list, seed_list, c_list : iterables
        Values combined to form the run names.

    Returns
    -------
    pd.DataFrame
        For each run and each of train/validation/test, one row per ``group_tag``
        followed by a ``'sample_mean'`` row over the whole split. Columns are
        acc_h1, acc_h2, dacc, 1-NFR, nsamples, dNFR, group, group_added,
        splitset, seed, split. Metrics are fractions, not percentages. An empty
        DataFrame is returned when no eval file was found.
    """
    stats_frames = []
    for c in c_list:
        for split in split_list:
            for seed in seed_list:
                model_name = f'{model_name_prefix}{c}_seed{seed}_split{split}'
                eval_path = os.path.join(basedir, model_name, f'eval{isave}.csv')
                if not os.path.exists(eval_path):
                    print('Not found :', eval_path)
                    continue

                pd_aux = pd.read_csv(eval_path)
                pred_cols, class_ids = _prediction_columns(pd_aux)
                # Predicted class = id of the highest-scoring ypred column.
                pred_label = class_ids[np.argmax(pd_aux[pred_cols].values, axis=1)]
                pd_aux['acc'] = pred_label == pd_aux['ygt'].values

                groups = pd_aux['group_tag'].unique()
                for splitset in _SPLITSETS:
                    in_split = pd_aux['dataset'].values == splitset
                    rows, labels = [], []

                    for group in groups:
                        pd_filter = pd_aux.loc[in_split & (pd_aux['group_tag'].values == group)]
                        rows.append(_split_stats_row(pd_filter['acc'].values,
                                                     pd_filter['prevmodel_acc'].values))
                        labels.append('group ' + str(group))

                    ### Sample Mean ###
                    pd_filter = pd_aux.loc[in_split]
                    rows.append(_split_stats_row(pd_filter['acc'].values,
                                                 pd_filter['prevmodel_acc'].values))
                    labels.append('sample_mean')

                    pd_stats_i = pd.DataFrame(data=rows, columns=_STATS_COLUMNS)
                    # Gap between the NFR upper bound min(acc_h1, 1 - acc_h2) and the
                    # observed NFR (= 1 - '1-NFR').
                    pd_stats_i['dNFR'] = (np.minimum(pd_stats_i['acc_h1'].values,
                                                     1 - pd_stats_i['acc_h2'].values)
                                          - (1 - pd_stats_i['1-NFR'].values))
                    pd_stats_i['group'] = labels
                    pd_stats_i['group_added'] = c
                    pd_stats_i['splitset'] = splitset
                    pd_stats_i['seed'] = seed
                    pd_stats_i['split'] = split
                    stats_frames.append(pd_stats_i)

    if not stats_frames:
        return pd.DataFrame()
    # Single concat instead of growing the frame inside the loop.
    return pd.concat(stats_frames, axis=0)


## Waterbirds / CelebA datasets

In [32]:
# Dataset whose h2 models are evaluated.
data = 'CUB'
# data = 'celebA_blond'

# Previous model (h1) the h2 runs were trained against: 'erm' or 'gmmf'.
previous_model = 'erm'
# previous_model = 'gmmf' #other option is gmmf

# Root of the per-dataset model directories; override with the MODELS_DIR env var.
# The trailing '' keeps the trailing separator ('/data/natalia/models/CUB/').
basedir = os.path.join(os.environ.get('MODELS_DIR', '/data/natalia/models/'), data, '')

c_list = np.arange(4)

if previous_model not in ('erm', 'gmmf'):
    raise ValueError(f"previous_model must be 'erm' or 'gmmf', got {previous_model!r}")

# Optimiser/regularisation segment of each h2 run name; the rest of the name is
# shared and only the h1 tag (h1erm42 / h1gmmf42) depends on previous_model.
h2_run_config = {
    'erm': 'ManualLRDecayPlateau_reg1e4',
    'gmmf': 'ManualLRDecayNWReset_reg1e4_mw0wc05c025',
    'grm': 'ManualLRDecayNWReset_reg1e4_mw0wc05c025',
    'srm': 'ManualLRDecayNWReset_reg1e4_mw0wc05c025',
}
model_prefix_dic = {
    model: (f'h2{model}_resnet34_pretrained_batchnorm_sgd1e3_{config}'
            f'_CE_h1{previous_model}42data04addC')
    for model, config in h2_run_config.items()
}

stats_per_model = []
for model, prefix in model_prefix_dic.items():
    pd_stats = get_stats(basedir, prefix, isave=1,
                         split_list=[1], seed_list=[42],
                         c_list=c_list)
    pd_stats['model'] = model
    stats_per_model.append(pd_stats)
pd_table = pd.concat(stats_per_model, axis=0)

if pd_table.empty:
    raise FileNotFoundError(f'No eval files found under {basedir}')

# Report accuracies as percentages; '1-NFR' and 'dNFR' stay as fractions.
for col in ['acc_h1', 'acc_h2', 'dacc']:
    pd_table[col] = pd_table[col].values * 100

In [33]:
### Best and worst case summary table ###
splitset = 'test'

# Worst (min) and best (max) value of each metric across the individual groups.
# String aggregators avoid pandas' FutureWarning about passing builtin min/max.
agg = {'dacc': ['min', 'max'],
       '1-NFR': ['min', 'max'],
       'dNFR': ['min', 'max']}

per_group_rows = (pd_table['splitset'] == splitset) & (pd_table['group'] != 'sample_mean')
pd_aux = pd_table.loc[per_group_rows].groupby(['group_added', 'model']).agg(agg)

# pd_aux already has exactly one row per (group_added, model); no re-grouping needed.
print('--- Worst and best per group added --- ')
print()
print(pd_aux)
print()
print()

# Mean of the per-group_added worst/best values, per model.
print('--- Average --- ')
print(pd_aux.groupby(['model']).mean())
print()

--- Worst and best per group added --- 

                       dacc                1-NFR                dNFR          
                        min        max       min       max       min       max
group_added model                                                             
0           erm   -4.301552   1.401869  0.925234  0.996896  0.003548  0.362928
            gmmf  -7.050998  24.454829  0.928603  0.992212  0.008869  0.149533
            grm   -0.934579   4.257206  0.956386  0.996452  0.005765  0.338006
            srm   -4.205607   0.576497  0.923676  0.999557  0.003548  0.367601
1           erm   -5.451713   8.381375  0.943925  0.999557  0.004435  0.376947
            gmmf  -6.119734  23.520249  0.938803  0.998442  0.009756  0.165109
            grm   -0.221729   3.426791  0.984424  0.995122  0.007095  0.352025
            srm   -0.842572   8.099688  0.985981  0.992461  0.008869  0.313084
2           erm   -1.419069  11.214953  0.977384  0.998442  0.009313  0.288162
           

In [34]:

#LATEX
# pd_aux is already indexed by (group_added, model), so it is exported directly.
print(pd_aux.to_latex(float_format="%.2f"))
print()

# Per-model average of the worst/best values across group_added settings.
print(pd_aux.groupby(['model']).mean().to_latex(float_format="%.2f"))
print()

\begin{tabular}{llrrrrrr}
\toprule
  &     & \multicolumn{2}{l}{dacc} & \multicolumn{2}{l}{1-NFR} & \multicolumn{2}{l}{dNFR} \\
  &     &   min &   max &   min &  max &  min &  max \\
group\_added & model &       &       &       &      &      &      \\
\midrule
0 & erm & -4.30 &  1.40 &  0.93 & 1.00 & 0.00 & 0.36 \\
  & gmmf & -7.05 & 24.45 &  0.93 & 0.99 & 0.01 & 0.15 \\
  & grm & -0.93 &  4.26 &  0.96 & 1.00 & 0.01 & 0.34 \\
  & srm & -4.21 &  0.58 &  0.92 & 1.00 & 0.00 & 0.37 \\
1 & erm & -5.45 &  8.38 &  0.94 & 1.00 & 0.00 & 0.38 \\
  & gmmf & -6.12 & 23.52 &  0.94 & 1.00 & 0.01 & 0.17 \\
  & grm & -0.22 &  3.43 &  0.98 & 1.00 & 0.01 & 0.35 \\
  & srm & -0.84 &  8.10 &  0.99 & 0.99 & 0.01 & 0.31 \\
2 & erm & -1.42 & 11.21 &  0.98 & 1.00 & 0.01 & 0.29 \\
  & gmmf & -5.06 & 23.83 &  0.95 & 1.00 & 0.01 & 0.16 \\
  & grm & -0.80 &  6.39 &  0.96 & 0.99 & 0.00 & 0.31 \\
  & srm & -1.42 & 11.21 &  0.98 & 1.00 & 0.01 & 0.29 \\
3 & erm & -4.88 &  6.54 &  0.92 & 1.00 & 0.01 & 0.34 \\
  & gmm

In [35]:

### Per Group BC analysis
# Metrics reported for every individual group ('sample_mean' rows are skipped).
bc_metrics = ['acc_h1', 'dacc', '1-NFR', 'dNFR']

for group in pd_table['group'].unique():
    if group == 'sample_mean':
        continue
    # `splitset` is set in the best/worst summary cell above.
    group_rows = pd_table.loc[(pd_table['splitset'] == splitset) &
                              (pd_table['group'] == group)]
    print('------------ ', group, ' ----------------')
    # GroupBy column subsets must be a list: groupby(...)['a', 'b'] (a tuple key)
    # raises ValueError in pandas >= 2.0.
    print(group_rows.groupby(['group_added', 'model'])[bc_metrics].mean())
    print()
    print('-Average')
    print(group_rows.groupby(['model'])[bc_metrics].mean())
    print()
        



------------  group 3  ----------------
                      acc_h1      dacc     1-NFR      dNFR
group_added model                                         
0           erm    84.423676  1.401869  0.976636  0.118380
            gmmf   84.423676  2.336449  0.964174  0.096573
            grm    84.423676 -0.934579  0.964174  0.129283
            srm    84.423676  0.467290  0.973520  0.124611
1           erm    84.423676 -5.451713  0.943925  0.154206
            gmmf   84.423676  6.230530  0.996885  0.090343
            grm    84.423676  0.155763  0.987539  0.141745
            srm    84.423676  1.557632  0.985981  0.126168
2           erm    84.423676  2.492212  0.996885  0.127726
            gmmf   84.423676  4.361371  0.987539  0.099688
            grm    84.423676  3.115265  0.990654  0.115265
            srm    84.423676  2.647975  0.996885  0.126168
3           erm    84.423676  6.542056  0.995327  0.085670
            gmmf   84.423676  7.320872  0.996885  0.079439
            grm 

  
  # Remove the CWD from sys.path while we load stuff.
  
  # Remove the CWD from sys.path while we load stuff.
  
  # Remove the CWD from sys.path while we load stuff.
  
  # Remove the CWD from sys.path while we load stuff.


In [None]:
### Per group summary table ###