In [1]:
import numpy as np
import seaborn as sns
# Set seaborn's plotting style to "darkgrid".
sns.set(style="darkgrid")
from scipy.stats import poisson
import pandas as pd
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score

In [None]:
from pathlib import Path
import pickle
from pif_dsbmm_dpf.citation_real.process_real import label_propagation

# --- Locate and load the pickled data model for the chosen seed -------------
seed = 42
datadir = Path('/scratch/fitzgeraldj/data/caus_inf_data/')
data_model_str = f"real_seed{seed}"
data_model_path = datadir / f"{data_model_str}.pkl"
# NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
with data_model_path.open("rb") as f:
    data_model = pickle.load(f)

# Pull out the attributes of the data model that the rest of the notebook uses.
# NOTE(review): `Y` and `Y_heldout` are rebound by the later cell that reads
# the lastfm .npz file -- confirm which dataset downstream cells should see.
Y = data_model.Y
# For each of the last two slices of Y: indices whose row sum is non-zero.
old_aus = [np.flatnonzero(Y[t].sum(axis=1)) for t in (-2, -1)]
Y_heldout, full_A_end, test_aus = (
    data_model.Y_heldout,
    data_model.full_A_end,
    data_model.test_aus,
)

# --- Load the DSBMM sub-results matching the chosen model settings ----------
dsbmm_datadir = datadir / "dsbmm_data"
deg_corr, directed = True, True
variant = 'z-theta-joint'
dsbmm_res_str = f"{data_model_str}_{'dc' if deg_corr else 'ndc'}_{'dir' if directed else 'undir'}_{'meta' if variant=='z-theta-joint' else 'nometa'}"
# The first element of the pickled tuple is not used in this notebook.
with (dsbmm_datadir / f"{dsbmm_res_str}_subs.pkl").open("rb") as f:
    _, Z_trans, block_probs = pickle.load(f)



In [13]:
def predict(params, A, Y_p, model):
    """Return the predicted rate matrix for ``model`` at the last time index.

    Only the final timestep of every fitted parameter is used, because
    predictions are made for the held-out period only.

    Args:
        params: mapping of fitted parameters. 'Gamma_hat', 'Alpha_hat',
            'Z_hat' and 'W_hat' are indexed as ``[:, -1, :]`` and
            'Beta_hat' as ``[:, -1]``.
        A: matrix multiplied elementwise by the influence vector before the
            dot product with ``Y_p``.
        Y_p: right-hand operand of the influence dot product.
        model: model name. 'network_pref_only' adds ``z @ gamma.T``,
            'topic_only' adds ``alpha @ w.T``, and any name containing
            'dsbmm_dpf' adds both; every other name gets the influence
            term alone.

    Returns:
        The rate array plus 1e-10 (presumably to keep rates strictly
        positive -- confirm).

    Note:
        For names containing 'dsbmm_dpf' the block memberships come from
        ``label_propagation`` and the module-level names ``test_aus``,
        ``old_aus``, ``full_A_end``, ``Z_trans``, ``block_probs`` and
        ``deg_corr`` must already be defined.
    """
    uses_dsbmm = 'dsbmm_dpf' in model

    gamma = params['Gamma_hat'][:, -1, :]
    alpha = params['Alpha_hat'][:, -1, :]
    if uses_dsbmm:
        # Replace the fitted memberships with label-propagated ones built
        # from the last two timesteps of Z_hat; keep the final timestep.
        propagated = label_propagation(
            test_aus,
            old_aus,
            full_A_end,
            params['Z_hat'][:, -2:, :],
            Z_trans,
            block_probs,
            deg_corr=deg_corr,
        )
        z = propagated[:, -1, :]
    else:
        z = params['Z_hat'][:, -1, :]
    w = params['W_hat'][:, -1, :]
    beta = params['Beta_hat'][:, -1]

    # Influence term, common to every model.
    rate = (beta * A).dot(Y_p)

    # Model-specific preference term (if any), added in a single in-place step.
    if model == 'network_pref_only':
        extra = z.dot(gamma.T)
    elif model == 'topic_only':
        extra = alpha.dot(w.T)
    elif uses_dsbmm:
        extra = z.dot(gamma.T) + alpha.dot(w.T)
    else:
        extra = None

    if extra is not None:
        rate += extra
    return rate + 1e-10

def get_ll(predicted, truth, restrict_users=None):
    """Mean per-row Poisson log-likelihood of ``truth`` under rates ``predicted``.

    Args:
        predicted: 2-D array of Poisson rates.
        truth: 2-D array of observed counts, same layout as ``predicted``.
        restrict_users: optional row selector (e.g. boolean mask) applied to
            the first axis of both arrays; ``None`` keeps every row.

    Returns:
        The Poisson log-pmf summed over axis 1, then averaged over rows.
    """
    if restrict_users is None:
        rates, counts = predicted, truth
    else:
        rates, counts = predicted[restrict_users, :], truth[restrict_users, :]

    per_row_ll = poisson.logpmf(counts, rates).sum(axis=1)
    return per_row_ll.mean()

def get_classification_metrics(pred, truth, restrict_users=None):
    """ROC AUC of the flattened predictions against the flattened truth.

    Args:
        pred: 2-D array of predicted scores.
        truth: 2-D array of labels, same layout as ``pred``.
        restrict_users: optional row selector (e.g. boolean mask) applied to
            the first axis of both arrays; ``None`` keeps every row.

    Returns:
        The value of ``roc_auc_score`` over all (row, column) entries.
    """
    if restrict_users is None:
        scores, labels = pred, truth
    else:
        scores, labels = pred[restrict_users, :], truth[restrict_users, :]

    return roc_auc_score(labels.flatten(), scores.flatten())
    

def get_influence_rates(params, A, Y_p):
    """Mean influence rate per row, from the fitted influence parameters.

    Args:
        params: mapping providing 'Beta_hat'. It may be 1-D, or have a
            trailing time axis, in which case the final timestep is used
            (the same ``[:, -1]`` slice that ``predict`` takes).
        A: matrix multiplied elementwise by the influence vector.
        Y_p: right-hand operand of the influence dot product.

    Returns:
        1-D array: ``((beta * A) @ Y_p).mean(axis=1)``.
    """
    beta = np.asarray(params['Beta_hat'])
    if beta.ndim > 1:
        # A time-indexed Beta_hat cannot be broadcast against A as-is (or,
        # for an (n, 1) array, would scale rows instead of columns); reduce
        # it to the final timestep so it matches how predict() uses it.
        beta = beta[:, -1]
    rate = (beta * A).dot(Y_p)
    return rate.mean(axis=1)

### Load adjacency matrix and past, current and future (held-out) song-listens

In [7]:
# Read the processed lastfm arrays from disk.
# NOTE(review): this rebinds `Y` and `Y_heldout`, which the earlier cell
# assigns from the pickled data model -- confirm which dataset is intended.
dat = '../../dat/lastfm/lastfm_processed.npz'
array = np.load(dat)
A, Y_p, Y_heldout, Y = (array[key] for key in ('adj', 'y_past', 'y_heldout', 'y'))

# Boolean mask over rows of Y_heldout: True where the row sum is positive,
# i.e. the user has at least one held-out listen.
users_to_predict = Y_heldout.sum(axis=1) > 0
print("Num users that listen to at least one song in the held-out data:", users_to_predict.sum())

Num users that listen to at least one song in the held-out data: 3212


### Load results; print average influence and heldout prediction results.

In [12]:
out = '../../out/lastfm/'
b = 'Beta_hat'

# Display labels for the results table, keyed by method identifier.
clean_names = {
    'unadjusted.main': 'Unadjusted',
    'network_pref_only.main': 'Network-Only',
    'dsbmm_dpf.z-theta-joint': 'Ours',
    'topic_only': 'Topic-Only',
    # Labels for the identifiers listed in `methods` below. Without these,
    # the table construction raised KeyError because none of those
    # identifiers had an entry.
    'unadjusted': 'Unadjusted',
    'spf': 'mSPF',
    'network_pref_only': 'Network-Only',
    'pif': 'PIF',
}

# NOTE(review): predict() only adds preference terms for 'network_pref_only',
# 'topic_only' and names containing 'dsbmm_dpf'; 'spf' and 'pif' are therefore
# scored on the influence term alone -- confirm this is intended.
methods = ['unadjusted', 'spf', 'network_pref_only', 'pif']
results = {m: np.load(out + m + '_fitted_params.npz') for m in methods}

# Held-out log-likelihood and AUC per method. The prediction is computed once
# per method and shared by both metrics rather than being recomputed.
hol = {}
auc = {}
for m in methods:
    pred = predict(results[m], A, Y_p, m)
    hol[m] = get_ll(pred, Y_heldout, restrict_users=users_to_predict)
    auc[m] = get_classification_metrics(pred, Y_heldout, restrict_users=users_to_predict)

# Fall back to the raw identifier if a method has no display label.
data = [[clean_names.get(m, m), results[m][b].mean(), hol[m], auc[m]] for m in methods]

df = pd.DataFrame(data, columns=['Method', 'Average Estimated Influence', 'HOL', 'AUC'])
df

0.04085443425128536
0.04085443425128536


Unnamed: 0,Method,Average Estimated Influence,HOL,AUC
0,Unadjusted,0.003649,-331.743795,0.545955
1,mSPF,0.000377,-198.366396,0.659936
2,Network-Only,0.001833,-191.551117,0.547117
3,PIF,0.000627,-186.010893,0.667357
