In [21]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns


model_names = ['inceptionv3','resnet18', 'MobileNetV2', 'resnet50']
hue_order = ['mean', 'max', 'gps', 'ours']
model_name_dict = {'resnet18': 'ResNet-18', 'resnet50': 'ResNet-50', 
                   'MobileNetV2': 'MobileNetV2', 'inceptionv3': 'InceptionV3',
                   'resnet18_added_data': 'ResNet-18 (+LD)',
                   'resnet50_added_data': 'ResNet-50 (+LD)',
                   'inceptionv3_added_data': 'InceptionV3 (+LD)',
                   'MobileNetV2_added_data': 'MobileNetV2 (+LD)',
                   'cifar100_cnn': "CIFAR100 CNN",
                   'stl10_cnn': "STL10 CNN"
                  }
dataset_name_dict = {'birds200': 'Birds200', 'flowers102':'Flowers102', 'imnet':'ImageNet',
                     'cifar100': 'CIFAR100', 'stl10': 'STL10'}
legend_name_dict = {'mean': 'Standard', 'gps': 'GPS', 'max': 'Max', 'ours': 'Ours', 
                    'partial_lr': 'AugTTA', 'class_lr': 'ClassTTA'}
agg_list = ['raw', 'max', 'mean', 'gps', 'ours']
#agg_list = ['raw', 'max', 'mean', 'gps', 'ours']

y_col = 'top5'

In [22]:
#all pil results
datasets = ['flowers102', 'imnet', 'cifar100', 'stl10']
model_names = [('inceptionv3','resnet18', 'MobileNetV2', 'resnet50'),
               ('inceptionv3','resnet18', 'MobileNetV2', 'resnet50'),
               ('cifar100_cnn',),
               ('stl10_cnn',)]
#policy = 'standard'
policy = 'pil'
all_expanded_results = []
for j, dataset in enumerate(datasets):
    dataset_model_names = model_names[j]
    for model_idx in range(len(dataset_model_names)):
        model_name = dataset_model_names[model_idx]
        policy_dir = policy
        if policy == 'standard':
            policy_dir = 'five_crop_hflip_scale'
#             if dataset in ['stl10']:
#                 policy_dir = 'hflip_modified_five_crop_scale'
            if dataset in ['cifar100', 'stl10']:
                policy_dir = 'five_crop_hflip_scale'
        results_path = "../results/" + dataset + "/" + policy_dir + "/val/" + model_name + "_agg_fs"
        results = pd.read_csv(results_path)
        results['dataset'] = dataset
        all_expanded_results.append(results)

In [23]:
all_results_df = pd.concat(all_expanded_results)
combo = all_results_df[all_results_df['aug'] == 'combo']
orig = all_results_df[all_results_df['aug'] == 'orig']
orig['aug'] = 'combo'
orig['agg'] = 'raw'
combo = pd.concat([combo, orig])
mean_df = combo.groupby(['dataset', 'model', 'agg']).mean().reset_index()
std_df = combo.groupby(['dataset', 'model', 'agg']).std().reset_index()
mean_df['top1_std'] = std_df['top1']
mean_df['top5_std'] = std_df['top5']
if y_col == 'top1':
    subset = mean_df[['dataset', 'model', 'agg', 'top1', 'top1_std']]
else:
    subset = mean_df[['dataset', 'model', 'agg', 'top5', 'top5_std']]

subset = subset[subset['agg'].isin(agg_list)]
pivoted_subset = subset.pivot_table(columns='agg', index=['dataset', 'model'])
pivoted_subset = pivoted_subset.reset_index()
pivoted_subset.columns = pivoted_subset.columns.map('|'.join).str.strip('|')

of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.


  """Entry point for launching an IPython kernel.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  after removing the cwd from sys.path.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """


In [24]:
pivoted_subset

Unnamed: 0,dataset,model,top5|gps,top5|max,top5|mean,top5|ours,top5|raw,top5_std|gps,top5_std|max,top5_std|mean,top5_std|ours,top5_std|raw
0,cifar100,cifar100_cnn,93.67,76.37,93.05,93.09,92.84,0.15,0.19,0.12,0.16,0.11
1,flowers102,MobileNetV2,97.75,93.34,97.58,98.62,97.65,0.1,0.06,0.1,0.05,0.12
2,flowers102,inceptionv3,97.73,95.04,97.65,98.77,97.31,0.1,0.17,0.1,0.06,0.16
3,flowers102,resnet18,97.74,91.28,97.48,97.66,97.56,0.11,0.06,0.1,0.09,0.1
4,flowers102,resnet50,98.12,95.69,98.16,99.18,97.89,0.05,0.08,0.1,0.06,0.1
5,imnet,MobileNetV2,90.73,84.64,90.32,90.69,90.24,0.08,0.08,0.05,0.05,0.05
6,imnet,inceptionv3,89.6,83.28,89.43,89.6,88.52,0.03,0.08,0.02,0.05,0.07
7,imnet,resnet18,89.55,84.22,89.2,89.69,89.02,0.06,0.07,0.05,0.07,0.07
8,imnet,resnet50,92.89,87.65,92.65,92.92,92.65,0.04,0.06,0.06,0.07,0.05
9,stl10,stl10_cnn,98.24,94.85,98.28,98.21,97.97,0.08,0.09,0.07,0.14,0.15


In [25]:
jj = combo[combo['agg'].isin(agg_list)]
kk = jj.groupby(['dataset', 'model', 'run', 'agg']).mean()
kk = kk.pivot_table(columns='agg', index=['dataset', 'model', 'run'])
kk_reset = kk.reset_index()
kk_reset.columns = kk_reset.columns.map('|'.join).str.strip('|')
from scipy.stats import wilcoxon
print(wilcoxon(kk_reset['top1|ours'], kk_reset['top1|raw']))
len(np.where(np.array(kk_reset['top1|ours']) > np.array(kk_reset['top1|gps']))[0])

WilcoxonResult(statistic=65.0, pvalue=3.266185610932541e-08)


34

In [26]:
pivoted_subset['dataset_name'] = pivoted_subset['dataset'].map(dataset_name_dict)
pivoted_subset['model_name'] = pivoted_subset['model'].map(model_name_dict)


In [27]:
print(wilcoxon(pivoted_subset['top1|ours'],pivoted_subset['top1|gps']))

KeyError: 'top1|ours'

In [28]:
# create new columns for each model name
pd.options.display.float_format = '{:,.2f}'.format
for method in agg_list:
    if  y_col == 'top1':
        mean_key = 'top1|' + method
        std_key = 'top1_std|' + method
    else:
        mean_key = 'top5|' + method
        std_key = 'top5_std|' + method


    mean_values = pivoted_subset[mean_key].round(2).map('{:.2f}'.format)
    std_values = pivoted_subset[std_key].round(2).map('{:.2f}'.format)
    values = "$" + mean_values + ' ± ' + std_values + "$"
    pivoted_subset[method] = values

In [29]:
select_idxs = ['dataset_name', 'model_name'] + agg_list
table = pivoted_subset[select_idxs]
table

Unnamed: 0,dataset_name,model_name,raw,max,mean,gps,ours
0,CIFAR100,CIFAR100 CNN,$92.84 ± 0.11$,$76.37 ± 0.19$,$93.05 ± 0.12$,$93.67 ± 0.15$,$93.09 ± 0.16$
1,Flowers102,MobileNetV2,$97.65 ± 0.12$,$93.34 ± 0.06$,$97.58 ± 0.10$,$97.75 ± 0.10$,$98.62 ± 0.05$
2,Flowers102,InceptionV3,$97.31 ± 0.16$,$95.04 ± 0.17$,$97.65 ± 0.10$,$97.73 ± 0.10$,$98.77 ± 0.06$
3,Flowers102,ResNet-18,$97.56 ± 0.10$,$91.28 ± 0.06$,$97.48 ± 0.10$,$97.74 ± 0.11$,$97.66 ± 0.09$
4,Flowers102,ResNet-50,$97.89 ± 0.10$,$95.69 ± 0.08$,$98.16 ± 0.10$,$98.12 ± 0.05$,$99.18 ± 0.06$
5,ImageNet,MobileNetV2,$90.24 ± 0.05$,$84.64 ± 0.08$,$90.32 ± 0.05$,$90.73 ± 0.08$,$90.69 ± 0.05$
6,ImageNet,InceptionV3,$88.52 ± 0.07$,$83.28 ± 0.08$,$89.43 ± 0.02$,$89.60 ± 0.03$,$89.60 ± 0.05$
7,ImageNet,ResNet-18,$89.02 ± 0.07$,$84.22 ± 0.07$,$89.20 ± 0.05$,$89.55 ± 0.06$,$89.69 ± 0.07$
8,ImageNet,ResNet-50,$92.65 ± 0.05$,$87.65 ± 0.06$,$92.65 ± 0.06$,$92.89 ± 0.04$,$92.92 ± 0.07$
9,STL10,STL10 CNN,$97.97 ± 0.15$,$94.85 ± 0.09$,$98.28 ± 0.07$,$98.24 ± 0.08$,$98.21 ± 0.14$


In [30]:
latex_string = table.to_latex(escape=False)
latex_string = latex_string.replace('±','\pm')
print(latex_string)

\begin{tabular}{llllllll}
\toprule
{} & dataset_name &    model_name &             raw &             max &            mean &             gps &            ours \\
\midrule
0 &     CIFAR100 &  CIFAR100 CNN &  $92.84 \pm 0.11$ &  $76.37 \pm 0.19$ &  $93.05 \pm 0.12$ &  $93.67 \pm 0.15$ &  $93.09 \pm 0.16$ \\
1 &   Flowers102 &   MobileNetV2 &  $97.65 \pm 0.12$ &  $93.34 \pm 0.06$ &  $97.58 \pm 0.10$ &  $97.75 \pm 0.10$ &  $98.62 \pm 0.05$ \\
2 &   Flowers102 &   InceptionV3 &  $97.31 \pm 0.16$ &  $95.04 \pm 0.17$ &  $97.65 \pm 0.10$ &  $97.73 \pm 0.10$ &  $98.77 \pm 0.06$ \\
3 &   Flowers102 &     ResNet-18 &  $97.56 \pm 0.10$ &  $91.28 \pm 0.06$ &  $97.48 \pm 0.10$ &  $97.74 \pm 0.11$ &  $97.66 \pm 0.09$ \\
4 &   Flowers102 &     ResNet-50 &  $97.89 \pm 0.10$ &  $95.69 \pm 0.08$ &  $98.16 \pm 0.10$ &  $98.12 \pm 0.05$ &  $99.18 \pm 0.06$ \\
5 &     ImageNet &   MobileNetV2 &  $90.24 \pm 0.05$ &  $84.64 \pm 0.08$ &  $90.32 \pm 0.05$ &  $90.73 \pm 0.08$ &  $90.69 \pm 0.05$ \\
6 &     Image

In [20]:
# for standard
# CIFAR uses partial_lr --> more classes, goes ot partial_lr
# STL-10 users full_lr  --> fewer classes, uses full_lr 
# for PIL
# CIFAR is inbetween --> partial
# STL is still full 