In [102]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os, pickle
import sys
sys.path.append("../")

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [103]:
## Function to obtain stats table
def get_stats(basedir, model_name_prefix, isave = 1, split_list = [1],
              seed_list=[42], c_list = [0,1,2,3]):
    
    pd_stats = pd.DataFrame()
    for c in c_list:
        for split in split_list:
            for seed in seed_list:
                model_name = model_name_prefix+ str(c) + '_seed' + str(seed) + '_split' + str(split)
                if not os.path.exists(basedir+model_name+'/eval'+str(isave)+'.csv'):
                    print('Not found :', basedir+model_name+'/eval'+str(isave)+'.csv')
                else:
                    pd_aux = pd.read_csv(basedir+model_name+'/eval'+str(isave)+'.csv')
                    pd_aux['isave'] = isave
                    
                    n_utility = np.unique(pd_aux['ygt'].values).shape[0]
                    columns = ['ypred_'+ str(i) for i in range(n_utility)]

                    ## include stats columns :
                    ypred1 = pd_aux[columns].values #get softmax predictions
                    ygt = pd_aux['ygt'].values #gt label
                    acc = np.argmax(ypred1,axis = 1) == ygt #accuracy
                    pd_aux['acc'] = acc
                    
                    samples_list = []
                    for splitset in ['train','validation','test']:
                        list_stats_group = []
                        list_group = []
                        
                        for group in pd_aux['group_tag'].unique():
                            pd_filter = pd_aux.loc[(pd_aux['dataset'].values == splitset) &\
                                               (pd_aux['group_tag'].values == group)]
                            
                            acc = pd_filter['acc'].values
                            prevmodel_acc = pd_filter['prevmodel_acc'].values

                            mNFR = 1 - np.mean((acc==0)&(prevmodel_acc == 1)) # 1 - NFR
                            psamples = acc.shape[0]
                            
                            list_stats_group.append([np.mean(prevmodel_acc),np.mean(acc),
                                                 np.mean(acc)-np.mean(prevmodel_acc),
                                                 mNFR, psamples])

                            list_group.append('group '+str(group))
                            if splitset == 'train':
                                samples_list.append(psamples)
                                
                        ### Sample Mean ###
                        pd_filter = pd_aux.loc[(pd_aux['dataset'].values == splitset)]

                        acc = pd_filter['acc'].values
                        prevmodel_acc = pd_filter['prevmodel_acc'].values
                        mNFR = 1 - np.mean((acc==0)&(prevmodel_acc == 1))
                        psamples = acc.shape[0]

                        list_group.append('sample_mean')
                        list_stats_group.append([np.mean(prevmodel_acc),np.mean(acc),
                                                     np.mean(acc)-np.mean(prevmodel_acc),
                                                     mNFR, psamples])

                        pd_stats_i = pd.DataFrame(data = list_stats_group,columns = ['acc_h1','acc_h2',
                                                                                     'dacc','1-NFR',
                                                                                     'nsamples'])
                        
                        #difference w.r.t NFR bound
                        pd_stats_i['dNFR'] = np.minimum(pd_stats_i['acc_h1'].values,1-pd_stats_i['acc_h2'].values)-(1-pd_stats_i['1-NFR'].values)
                        pd_stats_i['group'] = list_group
                        pd_stats_i['group_added'] = c
                        pd_stats_i['splitset'] = splitset
                        pd_stats_i['seed'] = seed
                        pd_stats = pd.concat([pd_stats,pd_stats_i], axis = 0)
                        
            
                        
    return pd_stats


## WATERBIRD DATASET 

In [104]:
data = 'CUB'
# data = 'celebA_blond'


previous_model = 'erm'
# previous_model = 'gmmf' #other option is gmmf


basedir = '/data/natalia/models/' + data + '/'


c_list = np.arange(4)

if previous_model == 'erm':
    model_prefix_dic = {'erm':'h2erm_resnet34_pretrained_batchnorm_sgd1e4_ManualLRDecayNWReset_reg1e4_CE_h1erm42data04addC',
                        'gmmf':'h2gmmf_resnet34_pretrained_batchnorm_sgd1e3_ManualLRDecayNWReset_reg1e4_mw0wc05c025_CE_h1erm42data04addC',
                        'grm':'h2grm_resnet34_pretrained_batchnorm_sgd1e3_ManualLRDecayNWReset_reg1e4_mw0wc05c025_CE_h1erm42data04addC',
                        'srm':'h2srm_resnet34_pretrained_batchnorm_sgd1e3_ManualLRDecayNWReset_reg1e4_mw0wc05c025_CE_h1erm42data04addC',
                        }
elif previous_model == 'gmmf':
    model_prefix_dic = {'erm':'h2erm_resnet34_pretrained_batchnorm_sgd1e4_ManualLRDecayNWReset_reg1e4_CE_h1gmmf42data04addC',
                        'gmmf':'h2gmmf_resnet34_pretrained_batchnorm_sgd1e3_ManualLRDecayNWReset_reg1e4_mw0wc05c025_CE_h1gmmf42data04addC',
                        'grm':'h2grm_resnet34_pretrained_batchnorm_sgd1e3_ManualLRDecayNWReset_reg1e4_mw0wc05c025_CE_h1gmmf42data04addC',
                        'srm':'h2srm_resnet34_pretrained_batchnorm_sgd1e3_ManualLRDecayNWReset_reg1e4_mw0wc05c025_CE_h1gmmf42data04addC',
                       }


pd_table = pd.DataFrame()
for model in model_prefix_dic.keys():
    pd_stats = get_stats(basedir, model_prefix_dic[model], isave = 1,
                         split_list = [1], seed_list=[42],
                         c_list = [0,1,2,3])
    pd_stats['model'] = model
    pd_table = pd.concat([pd_table,pd_stats],axis = 0)



In [105]:
pd_table

Unnamed: 0,acc_h1,acc_h2,dacc,1-NFR,nsamples,dNFR,group,group_added,splitset,seed,model
0,0.976501,0.976501,0.000000,0.997389,383,0.020888,group 3,0,train,42,erm
1,0.827586,0.827586,0.000000,0.982759,58,0.155172,group 2,0,train,42,erm
2,0.992855,0.999725,0.006870,1.000000,3639,0.000275,group 0,0,train,42,erm
3,0.990741,0.995370,0.004630,1.000000,216,0.004630,group 1,0,train,42,erm
4,0.989060,0.995112,0.006052,0.999534,4296,0.004423,sample_mean,0,train,42,erm
...,...,...,...,...,...,...,...,...,...,...,...
0,0.845794,0.869159,0.023364,0.979751,642,0.110592,group 3,3,test,42,srm
1,0.520249,0.523364,0.003115,0.962617,642,0.439252,group 2,3,test,42,srm
2,0.991574,0.994235,0.002661,0.997339,2255,0.003104,group 0,3,test,42,srm
3,0.846120,0.822173,-0.023947,0.952993,2255,0.130820,group 1,3,test,42,srm


In [106]:
### Best and worst case summary table ###
splitset = 'test'

agg = {'dacc':[min,max],
       '1-NFR':[min,max],
       'dNFR':[min,max]}

pd_aux = pd_table.loc[(pd_table['splitset'] == splitset) &\
                      (pd_table['group'] != 'sample_mean')].groupby(['group_added','model']).agg(agg)


print('--- Worst and best per group added --- ')
print()
print(pd_aux.groupby(['group_added','model']).mean()) #.to_latex(float_format="%.3f"))
print()
print()

print('--- Average --- ')
print(pd_aux.groupby(['model']).mean()) #.to_latex(float_format="%.3f"))
print()

--- Worst and best per group added --- 

                       dacc               1-NFR                dNFR          
                        min       max       min       max       min       max
group_added model                                                            
0           erm   -0.065421  0.005765  0.914330  0.998670  0.002661  0.434579
            gmmf  -0.043459  0.257009  0.954324  0.995327  0.006208  0.218069
            grm   -0.013747  0.140187  0.937472  0.982705  0.004878  0.302181
            srm   -0.020249  0.011086  0.922118  0.997339  0.002661  0.422118
1           erm   -0.028037  0.074766  0.957944  0.999113  0.006208  0.398754
            gmmf  -0.051885  0.292835  0.946785  1.000000  0.007095  0.186916
            grm   -0.033703  0.155763  0.952550  0.998442  0.007095  0.322430
            srm   -0.032710  0.043614  0.953271  0.991574  0.006208  0.414330
2           erm   -0.009313  0.132399  0.984479  1.000000  0.008426  0.347352
            gmmf  -0.06

In [107]:

### Per Group BC analysis
for group in pd_table['group'].unique():
    if group not in ['sample_mean']:
        print('------------ ', group, ' ----------------')
        print(pd_table.loc[(pd_table['splitset'] == splitset) &\
                           (pd_table['group'] == group)].groupby(['group_added','model'])['acc_h1','dacc','1-NFR','dNFR'].mean())#.to_latex(float_format="%.3f"))
        print()
        print('-Average')
        print(pd_table.loc[(pd_table['splitset'] == splitset) &\
                           (pd_table['group'] == group)].groupby(['model'])['acc_h1','dacc','1-NFR','dNFR'].mean())#.to_latex(float_format="%.3f"))
        print()
        



------------  group 3  ----------------
                     acc_h1      dacc     1-NFR      dNFR
group_added model                                        
0           erm    0.845794 -0.009346  0.971963  0.135514
            gmmf   0.845794  0.034268  0.971963  0.091900
            grm    0.845794  0.015576  0.967290  0.105919
            srm    0.845794  0.000000  0.970405  0.124611
1           erm    0.845794 -0.028037  0.957944  0.140187
            gmmf   0.845794  0.049844  0.987539  0.091900
            grm    0.845794  0.040498  0.992212  0.105919
            srm    0.845794 -0.032710  0.953271  0.140187
2           erm    0.845794  0.021807  0.992212  0.124611
            gmmf   0.845794  0.037383  0.976636  0.093458
            grm    0.845794  0.004673  0.975078  0.124611
            srm    0.845794  0.020249  0.992212  0.126168
3           erm    0.845794  0.065421  0.995327  0.084112
            gmmf   0.845794  0.076324  0.993769  0.071651
            grm    0.845794  0.0

  
  # Remove the CWD from sys.path while we load stuff.
  
  # Remove the CWD from sys.path while we load stuff.
  
  # Remove the CWD from sys.path while we load stuff.
  
  # Remove the CWD from sys.path while we load stuff.


In [None]:
### Per group summary table ###