# Import data

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np


# Metadata columns kept from the dataset csv; every other column of the
# merged table is treated as the score column of one detection method.
meta_cols = ['filename', 'label', 'days_since_1st_post']

tab = pd.read_csv('BFree_viral_images.csv')[meta_cols]
score_csv = pd.read_csv('/path/to/score.csv')
# No explicit key is given: pandas joins on the columns the two tables share.
tab = tab.merge(score_csv)

# Requirements on the score csv:
#   - it must contain the column 'filename'
#   - it must contain one column per method (method1, method2, ...)
#   - the predictions must be LOGITS (in the range [-inf,inf])

# merged csv example:
#            filename               | label | days_since_1st_post |  B-Free  |    DMID
# ----------------------------------|-------|---------------------|----------|----------
# REAL/Ed7JvuVXsAUTp-j/img00400.jpg | REAL  |     0.000000        |-4.431049 | -9.541626
# FAKE/FsL1ChiXwAAaVFk/img03517.jpg | FAKE  |     5.747164        | 5.462078 |  0.092956

# Method names = all non-metadata columns, in the order they appear in the table.
algs = [col for col in tab.columns if col not in meta_cols]
print('Methods: ', algs)
# Bare expression: displays the merged table when run as a notebook cell.
tab

# Compute metrics

In [None]:
# ---- if you have SCORE PROBABILITIES instead of LOGITS,
# ---- please change the threshold accordingly (e.g., from 0 to 0.5)
threshold = 0

# Order the samples by time since first post, so every cumulative quantity
# below means "over all samples posted up to this time".
# A stable sort keeps tied timestamps in their original row order; the default
# algorithm is not stable, which made the curves (and the rows dropped at the
# end of this cell) depend on an arbitrary ordering of ties.
tab_delta = tab.sort_values('days_since_1st_post', kind='mergesort').reset_index(drop=True)

# The ground-truth masks and their running counts do not depend on the method,
# so they are computed once instead of once per method.
is_fake = tab_delta['label'] == 'FAKE'
is_real = tab_delta['label'] != 'FAKE'
n_fake_so_far = is_fake.astype(int).cumsum()
n_real_so_far = is_real.astype(int).cumsum()

for alg in algs:
    scores = tab_delta[alg]
    # A score strictly above the threshold is a FAKE prediction; a score at or
    # below it is a REAL prediction. Both comparisons are written out (rather
    # than negating one) so that a NaN score counts as neither.
    hits_fake = is_fake & (scores > threshold)
    hits_real = is_real & (scores <= threshold)

    # Cumulative TPR, TNR. Rows that precede the first FAKE (resp. REAL)
    # sample evaluate to 0/0 = NaN.
    tab_delta[alg+'_tpr'] = hits_fake.astype(int).cumsum() / n_fake_so_far
    tab_delta[alg+'_tnr'] = hits_real.astype(int).cumsum() / n_real_so_far
    # balanced accuracy, in percent
    tab_delta[alg+'_acc'] = 100*(tab_delta[alg+'_tpr'] + tab_delta[alg+'_tnr'])/2

# ignoring the first few, as we have too few samples for accuracy
tab_delta = tab_delta.iloc[5:]

# Plot graph

In [None]:
sns.set_theme(context='notebook', style='darkgrid')

# Create the axes explicitly instead of picking up whatever the last
# sns.lineplot call returned: the axis styling below then also works when
# `algs` is empty (previously `ax` was undefined in that case).
fig, ax = plt.subplots(figsize=(5,3), dpi=100)

# One balanced-accuracy curve per method, all drawn on the same axes.
for alg in algs:
    sns.lineplot(data=tab_delta, x='days_since_1st_post', y=alg+'_acc', label=alg, linewidth=2.5, ax=ax)
# Legend is placed outside the plot area, to the right of the axes.
legend = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=14)

ax.set_xlim([0.1,100])
ax.set_ylim([50,100])
ax.set_ylabel('bAcc (%)')
ax.set_xlabel('Period (days)')
ax.set_xscale("log")

ax.minorticks_on()
ax.grid(True, which='major', axis='x', linestyle='-',  linewidth=1)   # Log grid on x-axis
ax.grid(True, which='minor', axis='x', linestyle='--', linewidth=0.7) # Log grid on x-axis
ax.grid(True, which='major', axis='y', linestyle='-',  linewidth=1)   # Normal grid on y-axis

# Fixed decade ticks; note the leftmost tick sits at x=0.1 but is labelled "0".
ax.set_xticks([0.1, 1, 10, 100])
ax.set_xticklabels(["0", "1", "10", "100"])

plt.show()