In [None]:
import pandas as pd
import numpy as np
import os
import fnmatch, re

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn import preprocessing
from PIL import Image as PILImage

import matplotlib

import matplotlib.pyplot as plt
import seaborn as sns

import scipy
from scipy import stats
from scipy.stats import pearsonr, kendalltau, spearmanr

import altair as alt
import base64, io

In [None]:
def drop_demos(data, data_type, num_row_to_drop):
    """Remove the leading demo rows and map the 0/1 ground-truth codes to labels.

    Parameters
    ----------
    data : pandas.DataFrame
        Raw trial table; rows are dropped by index *label* 0 .. num_row_to_drop-1.
    data_type : str
        "SAR": the string codes "0"/"1" in `ground_truth` become "fake"/"real".
        "real-fake-judgement": `ground_truth` is cast to int, then 0/1 are
        replaced by "fake"/"real" in *every* column of the frame.
        Any other value: rows are dropped and nothing else changes.
    num_row_to_drop : int
        Number of leading rows to discard.

    Returns
    -------
    pandas.DataFrame
        A new frame; `data` itself is not modified.
    """
    kept = data.drop(range(num_row_to_drop))
    if data_type == "SAR":
        kept["ground_truth"] = kept["ground_truth"].replace(["0", "1"], ["fake", "real"])
        return kept
    if data_type == "real-fake-judgement":
        kept['ground_truth'] = kept['ground_truth'].astype(int)
        # Whole-frame replace: any 0/1 value in any column is relabelled.
        return kept.replace([0, 1], ["fake", "real"])
    return kept

def process_individual_SAR(file_path, file_name, column_to_read, rows_to_skip):
    """Load one participant's rating CSV and strip its demo rows.

    Parameters
    ----------
    file_path : str
        Folder containing the participant file.
    file_name : str
        CSV file name inside `file_path`.
    column_to_read : list of str
        Columns to load (passed to `pandas.read_csv(usecols=...)`).
    rows_to_skip : int
        Leading demo rows removed via `drop_demos(..., "SAR", ...)`.

    Returns
    -------
    pandas.DataFrame
        The trimmed participant table with `ground_truth` relabelled.
    """
    csv_path = os.path.join(file_path, file_name)
    return drop_demos(
        pd.read_csv(csv_path, usecols=column_to_read),
        "SAR",
        rows_to_skip,
    )

def process_all_people_SAR(file_path, attribute_name, rows_to_skip=5):
    """Aggregate one attribute's ratings across every participant CSV in a folder.

    Each participant file is loaded with `process_individual_SAR`, its first
    `rows_to_skip` rows are dropped, and that participant's ratings are
    min-max normalised to [0, 1] before being placed side by side.

    Parameters
    ----------
    file_path : str
        Folder with one CSV per participant. Non-CSV files are ignored.
    attribute_name : str
        Rating column to aggregate, e.g. "glow_score.response".
    rows_to_skip : int, default 5
        Number of leading demo rows to drop from every file.

    Returns
    -------
    pandas.DataFrame
        Trial/image columns (Trial_num, image_name, image_path, ground_truth)
        taken from the first CSV in sorted order; one normalised rating column
        per participant (all named `attribute_name`); `<attribute_name>_all`
        holding each row's ratings as a list; and `mean_rating`, their mean.

    Raises
    ------
    FileNotFoundError
        If `file_path` contains no CSV files.
    """
    # Filter to CSVs *before* picking the template file, and sort so the
    # result does not depend on the arbitrary order of os.listdir.
    csv_files = sorted(
        f for f in os.listdir(file_path)
        if os.path.isfile(os.path.join(file_path, f)) and fnmatch.fnmatch(f, '*.csv')
    )
    if not csv_files:
        raise FileNotFoundError(f"No participant CSV files found in {file_path!r}")

    image_info = ["Trial_num", "image_name", "image_path", "ground_truth"]
    aggregate = pd.read_csv(os.path.join(file_path, csv_files[0]), usecols=image_info)
    aggregate = drop_demos(aggregate, "SAR", rows_to_skip)

    use_response = image_info + [attribute_name]
    ratings = []
    for person_file in csv_files:
        person_result = process_individual_SAR(file_path, person_file, use_response, rows_to_skip)
        scores = person_result[attribute_name]
        # Per-participant min-max normalisation to [0, 1]. A participant who
        # gave a single constant rating yields NaN (0/0).
        ratings.append((scores - scores.min()) / (scores.max() - scores.min()))

    # Ratings are joined on the row index.
    # NOTE(review): this assumes every participant saw the trials in the same
    # order as the template file — confirm against the experiment design.
    aggregate = pd.concat([aggregate] + ratings, axis=1)

    all_res_name = attribute_name + "_all"
    # Selecting with a list returns every (duplicate-named) rating column.
    aggregate[all_res_name] = aggregate[[attribute_name]].values.tolist()
    aggregate["mean_rating"] = aggregate[all_res_name].apply(np.mean)
    return aggregate

def combine_attributes(translucency=None, seethroughness=None, glow=None):
    """Merge the per-attribute mean ratings into a wide and a long table.

    Parameters
    ----------
    translucency, seethroughness, glow : pandas.DataFrame, optional
        Aggregated rating tables, each with a `mean_rating` column.
        `translucency` must also provide `image_name`, `image_path` and
        `ground_truth`. When omitted, they fall back to the notebook globals
        `translucency_score`, `seethroughness_score` and `glow_score`, so
        existing zero-argument calls keep working.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        `all_attribute_wide`: one row per image with
        translucency_mean / see_throughness_mean / glow_mean.
        `all_attribute_long`: the same data melted to
        (image_name, image_path, ground_truth, attribute, mean_rating).
    """
    # Fall back to notebook globals only when not supplied explicitly.
    if translucency is None:
        translucency = translucency_score
    if seethroughness is None:
        seethroughness = seethroughness_score
    if glow is None:
        glow = glow_score

    # Columns are aligned on the row index, which assumes all three tables
    # were built from the same participant files in the same order.
    all_attribute_wide = pd.DataFrame({
        'image_name': translucency['image_name'],
        'image_path': translucency['image_path'],
        'ground_truth': translucency['ground_truth'],
        'translucency_mean': translucency["mean_rating"],
        'see_throughness_mean': seethroughness["mean_rating"],
        'glow_mean': glow["mean_rating"],
    })

    all_attribute_long = pd.melt(
        all_attribute_wide,
        id_vars=['image_name', 'image_path', 'ground_truth'],
        var_name='attribute',
        value_name='mean_rating',
        value_vars=['translucency_mean', 'see_throughness_mean', 'glow_mean'],
    )

    return all_attribute_wide, all_attribute_long


def corrfunc(x, y, **kws):
    """Annotate the current axes with Spearman's rho and p-value for (x, y).

    Intended for `PairGrid.map_lower` with a hue variable: the hue level is
    expected in the `label` keyword. The 'real' group is written at axes
    fraction (0.1, 0.9); any other group goes just below, at (0.1, 0.85), so
    the two annotations do not overlap.
    """
    r, p = stats.spearmanr(x, y)
    ax = plt.gca()
    label = kws.get('label')
    pos = (.1, .9) if label == 'real' else (.1, .85)
    ax.annotate("{}: rho = {:.3f}, p = {:.08f}".format(label, r, p), fontsize=20,
                xy=pos, xycoords=ax.transAxes)
    
def my_hist(x, label, color):
    """Draw a histogram of `x` on a new twin y-axis of the current axes.

    The twin axis shows its ticks and a 'Counts' label on the right; its
    left/top spines are removed.
    """
    twin_ax = plt.gca().twinx()

    sns.despine(ax=twin_ax, left=True, top=True, right=False)
    twin_ax.yaxis.tick_right()
    twin_ax.set_ylabel('Counts')

    twin_ax.hist(x, label=label, color=color)

    
def annotate_scatter(data, **kws):
    """Write the Pearson r / p between `mean_rating` and `dist_norm` on the current axes.

    Meant for `FacetGrid.map_dataframe`; extra keyword arguments are ignored.
    r is shown rounded to 2 decimals and p to 3.
    """
    r_raw, p_raw = stats.pearsonr(data['mean_rating'], data['dist_norm'])
    r = round(r_raw, 2)
    p = round(p_raw, 3)
    target = plt.gca()
    target.text(.1, .9, f"r_hc = {r}, p = {p}", transform=target.transAxes)


def plot_svm_predict(file_name, human_ratings=None,
                     long_csv_path="svm_vs_human_long_layer9.csv", verbose=True):
    """Correlate SVM `distance_bound` values with human ratings and plot them per attribute.

    Only generated ("fake") images are compared. `distance_bound` is min-max
    normalised to [0, 1] (`dist_norm`), and Pearson-correlated with each
    attribute's mean human rating.

    Parameters
    ----------
    file_name : str
        CSV of SVM predictions with at least `image_name`, `distance_bound`
        and `predicted` columns.
    human_ratings : pandas.DataFrame, optional
        Wide rating table as returned by `combine_attributes`. Defaults to the
        notebook global `all_attributes_wide`.
    long_csv_path : str or None, default "svm_vs_human_long_layer9.csv"
        Where the long-format merged table is written. The default name is the
        same for every call, so each call overwrites the previous file. Pass a
        per-input path to keep them all, or None to skip writing.
    verbose : bool, default True
        Print the merged table and the three correlations.

    Returns
    -------
    tuple
        (r_tran, r_see, r_glow, p_tran, p_see, p_glow): Pearson r and p-values.
    """
    if human_ratings is None:
        human_ratings = all_attributes_wide

    svm_prediction = pd.read_csv(file_name)
    fake_ratings = human_ratings[human_ratings["ground_truth"] == "fake"]
    # Left join: generated images without a prediction keep NaN distances.
    svm_vs_human = fake_ratings.merge(svm_prediction, how='left', on='image_name')

    distance = svm_vs_human["distance_bound"]
    svm_vs_human["dist_norm"] = (distance - distance.min()) / (distance.max() - distance.min())
    if verbose:
        print(svm_vs_human)

    r_tran, p_tran = stats.pearsonr(svm_vs_human['translucency_mean'], svm_vs_human["dist_norm"])
    r_see, p_see = stats.pearsonr(svm_vs_human['see_throughness_mean'], svm_vs_human["dist_norm"])
    r_glow, p_glow = stats.pearsonr(svm_vs_human['glow_mean'], svm_vs_human["dist_norm"])

    if verbose:
        print("Translucency r:", r_tran, p_tran)
        print("See-through r:", r_see, p_see)
        print("Glow r:", r_glow, p_glow)

    svm_vs_human_long = pd.melt(
        svm_vs_human,
        id_vars=['image_name', 'image_path', 'ground_truth', 'distance_bound', 'dist_norm', 'predicted'],
        var_name='attribute',
        value_name='mean_rating',
        value_vars=['translucency_mean', 'see_throughness_mean', 'glow_mean'],
    )

    if long_csv_path is not None:
        svm_vs_human_long.to_csv(long_csv_path)

    # `row=` must be a keyword. seaborn 0.11 only warned and treated the second
    # positional argument as `row`; from 0.12 only `data` may be positional.
    g = sns.FacetGrid(svm_vs_human_long, row="attribute", margin_titles=False,
                      hue="attribute",
                      hue_order=["see_throughness_mean", "glow_mean", "translucency_mean"],
                      hue_kws=dict(marker=["X", "s", "o"]), height=3, aspect=1.2,
                      palette="deep")

    g.map_dataframe(sns.scatterplot, x="dist_norm", y="mean_rating", s=70, alpha=1)
    g.set_titles(' ', ' ', ' ')

    # Axes are shared across facets; ticks are kept but all labels are blanked.
    g.set(ylim=(0, 1), xticks=[0, 0.5, 1], yticks=[0, 0.5, 1])
    g.set(xlabel=None, ylabel=None, xticklabels=[], yticklabels=[])

    return r_tran, r_see, r_glow, p_tran, p_see, p_glow

    
def plot_r_trend(file_path, n_layers=18):
    """Run `plot_svm_predict` for each layer and tabulate |r| and p per attribute.

    For each layer i in range(n_layers), reads `file_path + str(i) + ".csv"`
    (each call also draws that layer's scatter figure).

    Parameters
    ----------
    file_path : str
        Path prefix of the per-layer prediction files.
    n_layers : int, default 18
        Number of layers to process (0-based file suffixes).

    Returns
    -------
    pandas.DataFrame
        Columns: layer_name (0-based, as str); absolute Pearson r in
        translucent / see-throughness / glow; and the matching p-values in
        tran-p / see-p / glow-p.
    """
    columns = ['layer_name', 'translucent', 'see-throughness', 'glow',
               'tran-p', 'see-p', 'glow-p']
    rows = []
    for i in range(n_layers):
        print("Layer:", i)
        r_tran, r_see, r_glow, p_tran, p_see, p_glow = plot_svm_predict(file_path + str(i) + ".csv")
        rows.append({
            'layer_name': str(i),
            'translucent': np.abs(r_tran),
            'see-throughness': np.abs(r_see),
            'glow': np.abs(r_glow),
            'tran-p': p_tran,
            'see-p': p_see,
            'glow-p': p_glow,
        })

    return pd.DataFrame(rows, columns=columns)


# Read data

In [None]:
## Read the data: one aggregated rating table per attribute
SAR_folder_path = "data/Experiment2"

seethroughness_score, translucency_score, glow_score = (
    process_all_people_SAR(SAR_folder_path, response_column)
    for response_column in (
        "seethroughness_score.response",
        "translucency_score.response",
        "glow_score.response",
    )
)

# Plot material attribute ratings

In [None]:
## Shared colours and ordering for real vs. generated ("fake") images
real_fake_color_palette = {"real": "#bdbdbd", "fake": "#2ca02c"}
real_fake_hue_order = list(("fake", "real"))


In [None]:
def plot_hist(df, title):
    """Histogram of per-image mean ratings, split into generated vs. real images.

    Each group is normalised to percent separately (`common_norm=False`), so
    groups of different sizes are comparable.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `mean_rating` and `ground_truth` ("fake"/"real").
    title : str
        Axes title.
    """
    # Set the context before drawing. Set afterwards, it only affected later
    # figures, so the first histogram was styled differently.
    sns.set_context(context='poster', font_scale=0.8)

    plot = sns.histplot(data=df, x="mean_rating", hue="ground_truth", binwidth=0.2, stat="percent",
                        alpha=0.7, hue_order=real_fake_hue_order, edgecolor="white", linewidth=1,
                        shrink=1, common_norm=False,
                        palette=real_fake_color_palette, legend=True)

    plot.set(title=title, xticks=np.arange(0, 1.2, 0.2), yticks=np.arange(0, 61, 20))

    sns.move_legend(plot, "upper left", bbox_to_anchor=(1, 1), title='Ground truth')
    sns.despine(offset=10, left=False, right=True)



In [None]:
print("Plot Translucency ratings")
plot_hist(df=translucency_score, title="Trans")


In [None]:
print("Plot See-throughness ratings")
plot_hist(df=seethroughness_score, title="See-throughness")

In [None]:
print("Plot Glow ratings")
plot_hist(df=glow_score, title="Glow")


In [None]:
## Merge all attributes into one wide table (and a melted long version)
all_attributes_wide, all_attribute_long = combine_attributes()

# Generated images only, from least to most translucent.
all_attributes_wide.loc[lambda wide: wide["ground_truth"] == "fake"].sort_values(by="translucency_mean")

In [None]:
all_attributes_wide.to_csv(path_or_buf="human_rating.csv")

# Pair plot (Figure 2C)

In [None]:
sns.set_context(context='poster', font_scale=0.8)

# Lower-triangle scatter of the three attribute means, dodged histograms on
# the diagonal, coloured by ground truth.
pair_plot_attribute = sns.pairplot(
    all_attributes_wide,
    hue='ground_truth',
    hue_order=real_fake_hue_order,
    palette=real_fake_color_palette,
    markers=["o", "o"],
    diag_kind='hist',
    diag_kws=dict(alpha=0.3, binwidth=0.2, multiple="dodge", kde=False),
    plot_kws=dict(alpha=0.8, s=100, edgecolor='k'),
    grid_kws=dict(diag_sharey=False),
    height=5,
    corner=True,
)
sns.despine(offset=10, left=False, right=True)

# Spearman annotation on every off-diagonal panel, then fix both axes to [0, 1].
pair_plot_attribute = pair_plot_attribute.map_lower(corrfunc).set(xlim=(0, 1), ylim=(0, 1))


In [None]:
def plot_scatter(data, feature1, feature2):
    """Scatter two attribute ratings against each other, coloured by ground truth.

    Parameters
    ----------
    data : pandas.DataFrame
        Wide rating table with `ground_truth` and both feature columns.
    feature1, feature2 : str
        Columns for the x and y axes; both axes are fixed to [0, 1].
    """
    sns.set_context(context='poster', font_scale=0.8)

    _, ax = plt.subplots(figsize=(5, 5), dpi=100)

    # Draw explicitly on the axes created above rather than the "current" one.
    plot = sns.scatterplot(data=data, x=feature1, y=feature2, hue="ground_truth",
                           alpha=0.7, hue_order=real_fake_hue_order, s=200,
                           palette=real_fake_color_palette,
                           legend=True, ax=ax)
    plot.set(xlim=(0, 1), ylim=(0, 1))
    sns.despine(ax=ax, offset=10, left=False, right=True)
    sns.move_legend(plot, "upper left", bbox_to_anchor=(1, 1), title='Ground truth')


plot_scatter(all_attributes_wide, "translucency_mean", "see_throughness_mean")

In [None]:
plot_scatter(data=all_attributes_wide, feature1="translucency_mean", feature2="glow_mean")

In [None]:
plot_scatter(data=all_attributes_wide, feature1="see_throughness_mean", feature2="glow_mean")

# Generated images ranked by mean translucency rating

In [None]:
all_attributes_wide.query('ground_truth == "fake"').sort_values("translucency_mean")

# Compare with SVM prediction (Figure 5B and Figure 5C)

In [None]:
sns.set_context(context='poster', font_scale=0.5)

# Prefix of the per-layer prediction files; plot_r_trend appends "<layer>.csv".
# Alternative prediction set: "data/svm_c0001/svm_prediction_"
file_path_svm_pred = "data/svm_c001/svm_prediction_"
r_table = plot_r_trend(file_path=file_path_svm_pred)

# Plot layer-wise correlation between SVM predictions and human perception

In [None]:
# Melt only the |r| columns. Melting every non-id column also pulled the
# p-value columns in as extra style levels, which polluted the legend.
r_trend_long = r_table.melt('layer_name',
                            value_vars=["translucent", "see-throughness", "glow"],
                            var_name='attribute', value_name='correlation')

# Set the context before creating the figure so all text uses it.
sns.set_context(context='poster', font_scale=1)
fig, ax = plt.subplots(figsize=(10, 7))

r_trend = sns.lineplot(
    data=r_trend_long,
    x="layer_name", y="correlation", hue="attribute", style="attribute",
    markers=True, dashes=True, alpha=0.9, markersize=20,
    palette="deep", hue_order=["see-throughness", "glow", "translucent"],
    # Same marker/dash assignment as before (data order of the r columns).
    style_order=["translucent", "see-throughness", "glow"],
    ax=ax,
)

# Layers are stored 0-based; label them 1-based in the figure.
ticks_val = range(len(r_table))
tick_val_str = [str(i + 1) for i in ticks_val]

ax.set_xticks(list(ticks_val))
ax.set_xticklabels(tick_val_str, rotation=90, fontsize=20)
sns.move_legend(r_trend, "upper left", bbox_to_anchor=(1, 1), title='Attribute')
sns.despine()
r_trend.set(xlabel='Layer of W+ latent space', ylabel='Correlation',
            title='Correlation with human perception');

In [None]:
# Relabel layers 1-based (as in the figure). Computed here rather than reusing
# the plotting cell's tick labels, so this cell has no hidden dependency on it.
r_table["layer_name"] = [str(i + 1) for i in range(len(r_table))]

# Round the |r| columns for the summary table.
for r_col in ["translucent", "see-throughness", "glow"]:
    r_table[r_col] = np.round(r_table[r_col], 3)



In [None]:
## Replace each p-value column with a significance label
def get_conditions(col_name):
    """Return boolean masks that bin the p-values in `r_table[col_name]`.

    The bands are mutually exclusive and ascending:
    p <= 0.0005, 0.0005 < p <= 0.001, 0.001 < p <= 0.005, p > 0.005.
    Their order must match the label list passed to `np.select` below.
    (Previously the third band was written as 0.005 < p <= 0.001, which is
    empty, so the '<0.001' label could never be assigned.)
    """
    p = r_table[col_name]
    return [
        p <= 0.0005,
        (p > 0.0005) & (p <= 0.001),
        (p > 0.001) & (p <= 0.005),
        p > 0.005,
    ]


for col_name in ["tran-p", "see-p", "glow-p"]:
    # One label per band above; non-significant p-values are shown rounded.
    # All choices are strings, so np.select never has to promote str/float.
    # NaN p-values match no band and get "n/a" (previously shown as "0").
    values = ['<0.0005', '<0.001', '<0.005', np.round(r_table[col_name], 3).astype(str)]
    r_table[col_name] = np.select(get_conditions(col_name), values, default="n/a")

# Trailing spaces keep the p columns distinct from the |r| columns of the same name.
r_table = r_table.rename({"tran-p": "translucent ",
                          "see-p": "see-throughness ",
                          "glow-p": "glow ",
                          }, axis=1)
r_table

# Show layer-wise correlation with human perceptual rating (Supplementary Figure S6)

In [None]:
# Use the full key: since pandas 1.4 the short 'precision' also matches
# 'styler.format.precision', so set_option raises an OptionError.
pd.set_option('display.precision', 2)

# Number formats belong in Styler.format. set_properties only takes CSS
# properties, so format callables passed there were emitted as bogus CSS.
r_table2 = (
    r_table.style
    .format({'translucent': '{:,.2f}',
             'see-throughness': '{:,.2f}',
             'glow': '{:,.2f}'})
    .set_properties(**{'text-align': 'center',
                       'background-color': 'white'})
)

r_table2

# Test correlation with t-SNE and MDS features (Supplementary Figure S10)

In [None]:
sns.set_context(context='poster', font_scale=0.5)

# SVM predictions from the t-SNE feature set.
file_name = "data/svm_dim_reduction/svm_prediction_tsne_5_norm.csv"
plot_svm_predict(file_name=file_name)

In [None]:
# SVM predictions from the MDS feature set.
file_name = "data/svm_dim_reduction/svm_prediction_mds_norm.csv"
plot_svm_predict(file_name=file_name)

In [None]:
plot_svm_predict(file_name="data/svm_c0001/svm_prediction_9.csv")