# Plot graphs of model performance and fairness
## Imports

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

from lib.utils import *

# Display names of the evaluated attributes; entry i labels index i along the
# attribute axis (axis 1) of the stat arrays handled by the functions below.
faceattrmodel_attributes = ['Race', 'Gender', 'Age']

def resolve_categori_stat(attr, stat, length):
    """Turn the raw counts of one attribute into per-epoch accuracy curves.

    Args:
        attr: Index of the attribute along axis 1 of ``stat``.
        stat: Count array of shape (N, attributes, 4). The last axis holds
            group 1 correct / group 1 wrong / group 2 correct / group 2 wrong.
        length: Number of leading entries along axis 0 to process.

    Returns:
        Four lists, each with ``length`` entries: group 1 accuracy, group 2
        accuracy, total accuracy, and the absolute difference between the
        two group accuracies.
    """
    g1_curve, g2_curve, total_curve, gap_curve = [], [], [], []
    for step in range(length):
        # the four counts of this attribute at this step
        counts = stat[step, attr]
        g1_hit, g1_miss = counts[0], counts[1]
        g2_hit, g2_miss = counts[2], counts[3]

        acc_g1 = g1_hit / (g1_hit + g1_miss)
        acc_g2 = g2_hit / (g2_hit + g2_miss)

        g1_curve.append(acc_g1)
        g2_curve.append(acc_g2)
        # total accuracy: all correct predictions over every counted sample
        total_curve.append((g1_hit + g2_hit) / np.sum(counts))
        gap_curve.append(abs(acc_g1 - acc_g2))
    # one list per statistic, all for the same attribute
    return g1_curve, g2_curve, total_curve, gap_curve

In [None]:
def show_faceattrmodel_stat(val_stat, train_stat=np.array([]), length=None, marker=".", markersize=4, save_name='default', root_folder='./eval/celeba'):
    """Plot accuracy and accuracy-gap curves per attribute and save them as a PNG.

    The figure is a 2x3 grid: the top row shows accuracy and the bottom row
    shows the absolute accuracy difference between the two groups, with one
    column per attribute. If ``train_stat`` is non-empty, each panel compares
    training and validation (total accuracy on top); otherwise the top row
    shows group 1, group 2 and total validation accuracy.

    Args:
        val_stat: Validation counts, in shape (N, attributes, 4).
        train_stat: Optional training counts in the same layout; an empty
            array (the default) disables the training curves.
        length: Number of leading epochs to plot. A falsy value (None or 0)
            plots every epoch in ``val_stat``.
        marker: Matplotlib marker used for every curve.
        markersize: Matplotlib marker size used for every curve.
        save_name: File name (without extension) of the saved figure.
        root_folder: Output directory; created if it does not exist.

    The image is written to ``<root_folder>/<save_name>.png``. The figure is
    always closed, even when plotting or saving raises, so repeated calls do
    not accumulate open figures in pyplot's global state.
    """
    # resolve the output file path
    out_dir = Path(root_folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{save_name}.png"

    # parse the stat, Face attributes model stats are in shape (N, attributes, 4)
    n_epochs = length if length else val_stat.shape[0]
    x_axis = np.linspace(0, n_epochs - 1, n_epochs)
    line_style = {'marker': marker, 'markersize': markersize}

    def _plot_with_legend(ax, series, loc):
        # Draw the curves in the given order (this fixes their colors in the
        # default color cycle) and attach a legend listing them.
        handles = tuple(ax.plot(x_axis, ys, **line_style)[0] for ys, _ in series)
        labels = tuple(label for _, label in series)
        ax.legend(handles, labels, loc=loc)

    fig, axs = plt.subplots(2, 3, figsize=(14, 8))
    try:
        for attr in range(val_stat.shape[1]):  # for each attribute
            acc_ax, gap_ax = axs[0][attr], axs[1][attr]
            val_g1, val_g2, val_total, val_gap = resolve_categori_stat(attr, val_stat, x_axis.shape[0])

            acc_ax.set_title(faceattrmodel_attributes[attr])
            acc_ax.set_xlabel('Epochs')
            acc_ax.set_ylabel('Accuracy')
            acc_ax.set_ylim([0.5, 1.0])
            gap_ax.set_xlabel('Epochs')
            gap_ax.set_ylabel('Fairness, (lower the better)')
            gap_ax.set_ylim([0.0, 1.0])

            if len(train_stat):
                _, _, train_total, train_gap = resolve_categori_stat(attr, train_stat, x_axis.shape[0])
                _plot_with_legend(
                    acc_ax,
                    [(train_total, 'Training Acc.'), (val_total, 'Validation Acc.')],
                    'lower right',
                )
                _plot_with_legend(
                    gap_ax,
                    [(train_gap, 'Training Acc. differences'), (val_gap, 'Validation Acc. differences')],
                    'upper right',
                )
            else:
                _plot_with_legend(
                    acc_ax,
                    [(val_g1, 'Group 1'), (val_g2, 'Group 2'), (val_total, 'Total')],
                    'lower right',
                )
                _plot_with_legend(gap_ax, [(val_gap, 'Acc. differences')], 'upper right')
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure even on failure; pyplot keeps every unclosed
        # figure alive until it is explicitly closed.
        plt.close(fig)

def print_faceattrmodel_stat_by_epoch(epoch, val_stat, train_stat=np.array([])):
    """Print the accuracy figures of a single epoch for every attribute.

    For each attribute a header line is printed, followed by the group 1,
    group 2 and total accuracy plus the accuracy difference at ``epoch``.
    The training figures are printed before the validation ones, and only
    when ``train_stat`` is non-empty.

    Args:
        epoch: Index of the epoch whose figures are printed.
        val_stat: Validation counts, in shape (N, attributes, 4).
        train_stat: Optional training counts in the same layout.
    """
    def _report(split_header, curves):
        # curves holds the four per-epoch lists of one split
        g1_curve, g2_curve, total_curve, gap_curve = curves
        print(split_header)
        print(f'    Group 1 Acc.: {g1_curve[epoch]:.4f}')
        print(f'    Group 2 Acc.: {g2_curve[epoch]:.4f}')
        print(f'    Total   Acc.: {total_curve[epoch]:.4f}')
        print(f'        Acc. differences: {gap_curve[epoch]:.4f}')

    for attr_idx in range(val_stat.shape[1]):  # for each attribute
        print(f'==== {faceattrmodel_attributes[attr_idx]} ====')
        val_curves = resolve_categori_stat(attr_idx, val_stat, val_stat.shape[0])
        if len(train_stat):
            train_curves = resolve_categori_stat(attr_idx, train_stat, train_stat.shape[0])
            _report('Training:', train_curves)
        _report('Validation:', val_curves)
        print('')


# NOTE(review): presumably the directory holding the saved model statistics;
# it is not referenced by any visible code yet — confirm against the cells
# that load the stats.
model_ckpt_root = Path('/tmp2/npfe/model_stats')
# draw FairFace stats here

# draw UTKFace stats here

