In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load test results for each mode

In [None]:
def get_df(mode, polluted_ratio):
    """Load the three result tables for one mode and tag them with the run settings.

    Reads, relative to the current working directory,
    ``test_{mode}_accuracy_p{polluted_ratio}.csv``,
    ``test_{mode}_accuracy_cls_p{polluted_ratio}.csv`` and
    ``test_{mode}_violation_p{polluted_ratio}.csv``, and adds constant
    ``mode`` and ``polluted_ratio`` columns to each so the tables of several
    modes can be stacked later.

    Returns:
        tuple: ``(acc, acc_cls, violation)`` DataFrames, in that order.
    """
    csv_paths = (
        f"test_{mode}_accuracy_p{polluted_ratio}.csv",
        f"test_{mode}_accuracy_cls_p{polluted_ratio}.csv",
        f"test_{mode}_violation_p{polluted_ratio}.csv",
    )

    def _load_tagged(path):
        # Read one CSV and stamp every row with the run settings.
        frame = pd.read_csv(path)
        frame["mode"] = mode
        frame["polluted_ratio"] = polluted_ratio
        return frame

    acc, acc_cls, violation = (_load_tagged(path) for path in csv_paths)
    return acc, acc_cls, violation

In [None]:
modes = ['baseline', 'ideal', 'lb', 'fair_naive', 'fair']
polluted_ratio = 0.3

# Collect each mode's frames in lists and concatenate once at the end:
# DataFrame.append was removed in pandas 2.0, and growing a frame inside a
# loop copies it on every iteration (quadratic).
acc_parts, acc_cls_parts, violation_parts = [], [], []
for mode in modes:
    acc, acc_cls, violation = get_df(mode, polluted_ratio)
    acc_parts.append(acc)
    acc_cls_parts.append(acc_cls)
    violation_parts.append(violation)

# The unnamed first CSV column is the saved row index, i.e. the iteration number.
accs       = pd.concat(acc_parts).rename(columns={'Unnamed: 0': 'Iteration', '0': 'Accuracy'})
accs_cls   = pd.concat(acc_cls_parts).rename(columns={'Unnamed: 0': 'Iteration'})
violations = pd.concat(violation_parts).rename(columns={'Unnamed: 0': 'Iteration'})


def _class_columns(frame):
    """Return the per-class metric columns, i.e. those named by a class index ('0', '1', ...)."""
    return [col for col in frame.columns if str(col).isdigit()]


# Average over the class columns selected by name, so the result does not
# depend on column position or on there being exactly 10 classes.
accs_cls['Average']   = accs_cls[_class_columns(accs_cls)].mean(axis=1)
violations['Average'] = violations[_class_columns(violations)].mean(axis=1)

In [None]:
def plot1(dat):
    """Plot overall accuracy across iterations for the iterative modes.

    Draws one line per mode ('lb', 'fair_naive', 'fair') of 'Accuracy' over
    'Iteration', plus the first 'baseline' and 'ideal' accuracy values as
    dashed horizontal reference lines.

    Args:
        dat: DataFrame with 'mode', 'Iteration' and 'Accuracy' columns
            (such as ``accs`` built above).

    Raises:
        ValueError: if ``dat`` has no 'baseline' or no 'ideal' rows.
    """
    # Legend text per mode. Labels are looked up by name so they cannot drift
    # from the lines if the order of modes in ``dat`` changes.
    mode_labels = {"lb": "Label Bias", "fair_naive": "FAIR Naive", "fair": "FAIR"}

    def _reference(mode):
        """Return the first 'Accuracy' value for ``mode``; drawn as a constant line."""
        values = dat.loc[dat["mode"] == mode, 'Accuracy'].to_numpy()
        if values.size == 0:
            raise ValueError(f"no rows with mode == {mode!r} in dat")
        return values[0]

    _, ax = plt.subplots()

    # lb, fair_naive, fair
    sns.lineplot(
        data=dat[dat["mode"].isin(list(mode_labels))],
        x='Iteration', y='Accuracy',
        hue='mode', hue_order=list(mode_labels), ax=ax,
    )

    # baseline and ideal as reference lines
    ax.axhline(y=_reference("baseline"), c="y", linestyle="dashed", label="Baseline")
    ax.axhline(y=_reference("ideal"), c="r", linestyle="dashed", label="Ideal")

    handles, raw_labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=[mode_labels.get(lab, lab) for lab in raw_labels])
    ax.set_title("Test accuracy")

    plt.show()

## Accuracy

In [None]:
# Overall accuracy per iteration for each mode
plot1(dat=accs)

In [None]:
def plot2(dat, num_class=-1):
    """Plot one per-class metric column across iterations for the iterative modes.

    Draws one line per mode ('lb', 'fair_naive', 'fair') over 'Iteration', plus
    the first 'baseline' and 'ideal' values of the same column as dashed
    horizontal reference lines.

    Args:
        dat: DataFrame with 'mode' and 'Iteration' columns, one column per class
            named by its index as a string ('0', '1', ...), and an 'Average'
            column (such as ``violations`` or ``accs_cls`` built above).
        num_class: class index whose column to plot; -1 (default) plots 'Average'.

    Raises:
        ValueError: if ``dat`` has no 'baseline' or no 'ideal' rows.
    """
    # Per-class columns are named by the class index as a string.
    column = "Average" if num_class == -1 else str(num_class)

    # Legend text per mode. Labels are looked up by name so they cannot drift
    # from the lines if the order of modes in ``dat`` changes.
    mode_labels = {"lb": "Label Bias", "fair_naive": "FAIR Naive", "fair": "FAIR"}

    def _reference(mode):
        """Return the first value of ``column`` for ``mode``; drawn as a constant line."""
        values = dat.loc[dat["mode"] == mode, column].to_numpy()
        if values.size == 0:
            raise ValueError(f"no rows with mode == {mode!r} in dat")
        return values[0]

    _, ax = plt.subplots()

    # lb, fair_naive, fair
    sns.lineplot(
        data=dat[dat["mode"].isin(list(mode_labels))],
        x='Iteration', y=column,
        hue='mode', hue_order=list(mode_labels), ax=ax,
    )

    # baseline and ideal as reference lines
    ax.axhline(y=_reference("baseline"), c="y", linestyle="dashed", label="Baseline")
    ax.axhline(y=_reference("ideal"), c="r", linestyle="dashed", label="Ideal")

    handles, raw_labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=[mode_labels.get(lab, lab) for lab in raw_labels])
    ax.set_title("Average over classes" if column == "Average" else f"Class {column}")

    plt.show()

## Violation

In [None]:
# Violation averaged over all classes
plot2(dat=violations, num_class=-1)

In [None]:
# Violation for class 0 only
plot2(dat=violations, num_class=0)

In [None]:
# Violation for class 2 only
plot2(dat=violations, num_class=2)

## Per-class accuracy (averaged over classes)

In [None]:
# Per-class accuracy averaged over all classes
plot2(dat=accs_cls, num_class=-1)