In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot

In [None]:
# Load the validation-set predictions. Later cells read the 'age', 'gender'
# and 'Predicted' columns from this frame.
val_df = pd.read_csv(
    filepath_or_buffer='../data/validation_predictions.csv',
)

In [None]:
# Cast age to float so it is grouped and plotted as a numeric feature.
# Note: this overwrites the 'age' column of val_df in place.
val_df['age'] = val_df['age'].astype(float)

# Baseline for statistical parity: the mean of 'Predicted' over every row.
overall_mean_pred = val_df['Predicted'].mean()

# Mean of 'Predicted' for each (age, gender) pair, with the group keys
# restored as ordinary columns so the result can be filtered by gender.
age_gender_grouped = (
    val_df
    .groupby(['age', 'gender'])['Predicted']
    .mean()
    .reset_index()
)

In [None]:

# Fairness partial dependence plot: for each gender, plot how far the mean
# prediction at each age sits from the overall mean prediction. A dashed line
# at zero marks statistical parity (group mean equals the overall mean).


def _deviation_trace(grouped, baseline, gender_code, label, marker_color, line_color):
    """Build the lines+markers trace for a single gender.

    Args:
        grouped: Frame with 'age', 'gender' and 'Predicted' columns, one row
            per (age, gender) pair.
        baseline: Value subtracted from 'Predicted' to obtain the deviation.
        gender_code: Value of the 'gender' column selecting this group.
        label: Legend name for the trace.
        marker_color: Fill colour of the markers.
        line_color: Colour of the connecting line.

    Returns:
        A ``go.Scatter`` of (age, Predicted - baseline) for the selected rows.
    """
    # Filter once and reuse the slice for both axes.
    rows = grouped[grouped['gender'] == gender_code]
    return go.Scatter(
        x=rows['age'],
        y=rows['Predicted'] - baseline,
        mode='lines+markers',
        name=label,
        marker=dict(color=marker_color, size=10, line=dict(width=2)),
        line=dict(color=line_color, width=2),
    )


# (gender code, legend label, marker colour, line colour) per group.
# NOTE(review): the 0 -> Female / 1 -> Male coding is carried over from the
# original cell; confirm it against the dataset's encoding.
gender_styles = [
    (0, 'Female', 'LightPink', 'HotPink'),
    (1, 'Male', 'LightSkyBlue', 'DodgerBlue'),
]

traces = [
    _deviation_trace(age_gender_grouped, overall_mean_pred, *style)
    for style in gender_styles
]

# Statistical parity reference line. age_gender_grouped holds one row per
# (age, gender) pair, so its 'age' column repeats ages shared by both genders;
# de-duplicate so the line carries a single point per age.
parity_ages = age_gender_grouped['age'].drop_duplicates()
traces.append(go.Scatter(
    x=parity_ages,
    y=[0] * len(parity_ages),
    mode='lines',
    name='Statistical Parity',
    line=dict(color='Gray', width=2, dash='dash'),
))

# layout
layout = go.Layout(
    title='Fairness Partial Dependence Plot (FPDP) with Statistical Parity for Age',
    xaxis=dict(title='Age'),
    yaxis=dict(title='Deviation from Overall Mean Probability of ASD'),
    margin=dict(l=40, r=40, t=40, b=40),
)

fig = go.Figure(data=traces, layout=layout)
fig.show()
