In [1]:
import numpy as np

# import seaborn as sns
# sns.set(style="darkgrid")
from scipy.stats import poisson
import pandas as pd
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score


In [2]:
from pathlib import Path
import pickle
from pif_dsbmm_dpf.citation_real.process_real import label_propagation

seed = 42
datadir = Path("/scratch/fitzgeraldj/data/caus_inf_data/")
data_model_str = f"real_seed{seed}"
data_model_path = datadir / f"{data_model_str}.pkl"
with open(data_model_path, "rb") as f:
    data_model = pickle.load(f)

Y = data_model.Y
old_aus = [np.flatnonzero(Y[-2].sum(axis=1)), np.flatnonzero(Y[-1].sum(axis=1))]
Y_heldout = data_model.Y_heldout
full_A_end = data_model.full_A_end
test_aus = data_model.test_aus

dsbmm_datadir = datadir / "dsbmm_data"
deg_corr = True
directed = True


In [75]:
def predict(params, A, Y_p, model):
    # only do for heldout aus, so final timestep
    gamma = params["Gamma_hat"][:, -1, :]
    alpha = params["Alpha_hat"][:, -1, :]
    z = params["Z_hat"][:, -1, :]
    variant = model.split(".")[-1]
    print(f"Model {model.split('.')[0]}; variant {variant}")
    if "dsbmm_dpf" in model or "network_pref_only" in model:
        dsbmm_res_str = f"{data_model_str}_{'dc' if deg_corr else 'ndc'}_{'dir' if directed else 'undir'}_{'meta' if variant=='z-theta-joint' else 'nometa'}"
        with open(dsbmm_datadir / f"{dsbmm_res_str}_subs.pkl", "rb") as f:
            _, Z_trans, block_probs = pickle.load(f)
        full_node_probs = label_propagation(
            test_aus,
            old_aus,
            full_A_end,
            params["Z_hat"][:, -2:, :].copy(),
            Z_trans,
            block_probs[..., -1],
            deg_corr=deg_corr,
        )
        z = full_node_probs[:, -1, :]
        if "dsbmm_dpf" in model:
            # need to expand to match form of rho
            z = np.pad(z, ((0, 0), (0, alpha.shape[-1])))

    # may want to do similar for w
    w = params["W_hat"][:, -1, :]
    beta = params["Beta_hat"][:, -1]

    rate = (beta * A).dot(Y_p)

    if model == "network_pref_only":
        rate += z.dot(gamma.T)
    elif model == "topic_only":
        rate += alpha.dot(w.T)
    elif "dsbmm_dpf" in model:
        rate += z.dot(gamma.T) + alpha.dot(w.T)
    try:
        return rate.toarray() + 1e-10
    except AttributeError:
        return rate + 1e-10


def get_ll(predicted, truth, restrict_users=None):
    if restrict_users is not None:
        predicted = predicted[restrict_users, :]
        # truth = truth[restrict_users,:] already restricted now
    return poisson.logpmf(truth.toarray(), predicted).sum(axis=1).mean()


def get_classification_metrics(pred, truth, restrict_users=None):
    if restrict_users is not None:
        pred = pred[restrict_users, :]
        # truth = truth[restrict_users,:] again already restricted
    return roc_auc_score(truth.toarray().flatten(), pred.flatten())


def get_influence_rates(params, A, Y_p):
    # again only want final timestep
    beta = params["Beta_hat"][:, -1]
    rate = (beta * A).dot(Y_p)
    mean_inf_rate = rate.mean(axis=1)
    return mean_inf_rate


def get_pres_beta_mean(full_beta):
    # take beta in shape (N,T-1) but w val 1.0
    # when author is missing, and only return mean of
    # (ragged) list of arrays length (T-1) with only
    # pres authors, for proper mean infl calc
    return [beta[beta != 1.0].mean() for beta in full_beta.T]


In [47]:
## Filter aus that publish at least 1 paper in the held-out period
aus_to_predict = Y_heldout.sum(axis=1) > 0
assert len(test_aus) == len(aus_to_predict)
print(
    "Num aus that publish at least one paper in the held-out data:",
    aus_to_predict.sum(),
)
aus_to_predict = test_aus


Num aus that publish at least one paper in the held-out data: 1235


### Load results; print average influence and heldout prediction results.

In [104]:
out = Path("/scratch/fitzgeraldj/data/caus_inf_data/real_results/")
b = "Beta_hat"
clean_names = {
    "unadjusted.main": "Unadjusted",
    #   'spf':'mSPF',
    "network_pref_only.main": "Network-Only",
    "topic_only.main": "Topic-Only",
    "dsbmm_dpf.z-theta-joint": "Ours",
}

methods = [
    "unadjusted.main",
    #  'network_pref_only.main',
    "topic_only.main",
    "dsbmm_dpf.z-theta-joint",
]
results = {
    m: np.load(
        (out / (m + "_pres_subs_ewcnone_model_fitted_params")) / "all_params.npz"
    )
    for m in methods
}

recalc_preds = False
try:
    preds
    assert not recalc_preds
except NameError:
    # only run if not already calc, or told to recalc
    preds = [predict(results[m], full_A_end, Y[-1], m) for m in methods]

hol = {
    m: get_ll(pred, Y_heldout, restrict_users=aus_to_predict)
    for m, pred in zip(methods, preds)
}
auc = {
    m: get_classification_metrics(pred, Y_heldout, restrict_users=aus_to_predict)
    for m, pred in zip(methods, preds)
}

data = [
    [clean_names[m], get_pres_beta_mean(results[m][b]), hol[m], auc[m]] for m in methods
]

df = pd.DataFrame(data, columns=["Method", "Average Estimated Influence", "HOL", "AUC"])
df


Unnamed: 0,Method,Average Estimated Influence,HOL,AUC
0,Unadjusted,"[0.08038918390170302, 0.0593276891523302, 0.04...",-16.90877,0.887669
1,Topic-Only,"[0.08913982282330712, 0.06904576833187484, 0.0...",-16.932607,0.887687
2,Ours,"[0.0714931751872941, 0.05125730121430877, 0.03...",-20.286492,0.926881


In [105]:
df[
    [
        f"$\langle\\beta^{t+1}\\rangle$"
        for t in range(df["Average Estimated Influence"].str.len()[0])
    ]
] = pd.DataFrame([*df["Average Estimated Influence"]])


In [106]:
df


Unnamed: 0,Method,Average Estimated Influence,HOL,AUC,$\langle\beta^1\rangle$,$\langle\beta^2\rangle$,$\langle\beta^3\rangle$,$\langle\beta^4\rangle$
0,Unadjusted,"[0.08038918390170302, 0.0593276891523302, 0.04...",-16.90877,0.887669,0.080389,0.059328,0.045317,0.029603
1,Topic-Only,"[0.08913982282330712, 0.06904576833187484, 0.0...",-16.932607,0.887687,0.08914,0.069046,0.048551,0.031316
2,Ours,"[0.0714931751872941, 0.05125730121430877, 0.03...",-20.286492,0.926881,0.071493,0.051257,0.038345,0.026874


In [313]:
out = df.copy()
out.drop(columns=["Average Estimated Influence"], inplace=True)
out.set_index("Method", inplace=True)
out.index.name = ""
out.columns = ["HOL", "AUC"] + out.columns[2:].tolist()
out = np.round(out, 2)
out = out.astype(str)
out["HOL"][out["HOL"] == out["HOL"].min(axis=0)] = (
    "\textbf{" + out["HOL"][out["HOL"] == out["HOL"].min(axis=0)] + "}"
)
out["AUC"][out["AUC"] == out["AUC"].max(axis=0)] = (
    "\textbf{" + out["AUC"][out["AUC"] == out["AUC"].max(axis=0)] + "}"
)
out.to_latex("./results/real_results.tex", escape=False)

out


  out.to_latex("./results/real_results.tex", escape=False)


Unnamed: 0,HOL,AUC,$\langle\beta^1\rangle$,$\langle\beta^2\rangle$,$\langle\beta^3\rangle$,$\langle\beta^4\rangle$
,,,,,,
Unadjusted,\textbf{-16.91},0.89,0.08,0.06,0.05,0.03
Topic-Only,-16.93,0.89,0.09,0.07,0.05,0.03
Ours,-20.29,\textbf{0.93},0.07,0.05,0.04,0.03


In [195]:
# look at per author influence inferred
beta_hats = [results[m]["Beta_hat"] for m in methods]
beta_idxs = [[np.flatnonzero(beta != 1.0) for beta in bhat.T] for bhat in beta_hats]
beta_hats_nz = [[beta[beta != 1.0] for beta in bhat.T] for bhat in beta_hats]


In [211]:
k = 1000
top_k_au_idxs = [
    [
        beta_idx[np.argsort(beta_hat)[-k:][::-1]]
        for beta_idx, beta_hat in zip(bidx, bhat)
    ]
    for bidx, bhat in zip(beta_idxs, beta_hats_nz)
]


In [212]:
with open(datadir / "real_seed42.pkl", "rb") as f:
    data_loader = pickle.load(f)
subset_au_idxs = data_loader.aus

full_net_paths = (dsbmm_datadir).glob("net_*.pkl")
nets = []
for fname in full_net_paths:
    with open(fname, "rb") as f:
        nets.append(pickle.load(f))


In [271]:
# see available metadata
list(dict(nets[-1].nodes(data=True)).values())[0].keys()


dict_keys(['eid', 'au_name', 'CitationCount', 'career_age', 'ASJC_cnt_vec', 'subjareas_cnt_vec', 'tpc_clust_id_cnt_vec', 'all_au_insts_cnt_vec', 'weighted_ASJC', 'weighted_subjareas', 'weighted_tpc_clusts', 'weighted_insts', 'main_ctry', 'main_adm1', 'main_ctry_1hot', 'main_adm1_1hot'])

In [295]:
tot_weighted_tpcs = [
    np.array(
        list(
            map(
                lambda x: x.sum() if x is not None else x,
                list(dict(net.nodes(data="weighted_tpc_clusts")).values()),
            )
        ),
        dtype=float,
    )
    for net in nets
]


In [296]:
tot_pubs = [
    np.array(list(dict(net.nodes(data="eid")).values()), dtype=float) for net in nets
]


In [306]:
av_n_aus = np.stack([tot_wt / tot_pub for tot_wt, tot_pub in zip(tot_weighted_tpcs, tot_pubs)],axis=1) 
av_n_aus = np.nanmean(av_n_aus,axis=1)


  av_n_aus = np.nanmean(av_n_aus,axis=1)


In [213]:
from functools import reduce

node_order = list(reduce(lambda x, y: set(x) | set(y), [net.nodes() for net in nets]))


In [214]:
for net in nets:
    missing_nodes = set(node_order) - set(net.nodes())
    net.add_nodes_from(missing_nodes)


In [250]:
au_names = [
    [
        node.get("au_name", f"unknown_{idx}")
        for idx, node in enumerate(list(dict(net.nodes(data=True)).values()))
    ]
    for net in nets
]
use_curr_cits = False
if use_curr_cits:
    # use cit count metadata from Elsevier -- for pubs in the period only
    au_cits = [
        [
            node.get("CitationCount", np.nan)
            for idx, node in enumerate(list(dict(net.nodes(data=True)).values()))
        ]
        for net in nets
    ]
else:
    # use net -- won't apply weighting, as otherwise will greatly exaggerate citations from high au papers
    # -- can address in future
    au_cits = [list(dict(net.in_degree()).values()) for net in nets]


In [251]:
all_known_names = np.zeros(len(node_order), dtype="O")
for au_name in au_names:
    name_arr = pd.Series(au_name)
    all_known_names[~name_arr.str.contains("unknown_")] = name_arr[
        ~name_arr.str.contains("unknown_")
    ]
all_known_names[all_known_names == 0] = "unknown_" + pd.Series(
    np.arange(len(all_known_names))[all_known_names == 0]
).astype(str)


In [252]:
subset_aus = all_known_names[subset_au_idxs]
subset_cits = [np.array(cit_t)[subset_au_idxs] for cit_t in au_cits]
subset_cumcits = np.nancumsum(np.stack(subset_cits, axis=1), axis=1)


In [253]:
top_k_aus = [
    [subset_aus[au_idxs] for au_idxs in top_k_idxs] for top_k_idxs in top_k_au_idxs
]
top_k_cits = [
    [cit_t[au_idxs] for cit_t, au_idxs in zip(subset_cits[1:], top_k_idxs)]
    for top_k_idxs in top_k_au_idxs
]
top_k_cumcits = [
    [cit_t[au_idxs] for cit_t, au_idxs in zip(subset_cumcits.T[1:], top_k_idxs)]
    for top_k_idxs in top_k_au_idxs
]


In [309]:
top_k_naus = [[av_n_aus[subset_au_idxs][au_idxs] for au_idxs in top_k_idxs] for top_k_idxs in top_k_au_idxs]

In [311]:
top_k_naus

[[array([1.10337692, 1.0154321 , 1.22846522, 1.02314815, 1.07106902,
         0.98214286,        nan, 0.96666667,        nan, 1.        ,
         0.99784123, 1.02361027, 1.01388889, 1.02777778, 1.18783422,
         1.        , 1.05555556, 0.97166667, 1.        , 1.01038679,
         1.06551852, 1.        , 1.01091103, 0.98714711, 0.96609269,
         0.98      , 0.95333333, 0.97619048, 1.00104157, 1.04166667,
         1.01155556, 1.        , 1.04481962, 1.08744728, 1.        ,
         1.00555556, 1.04221491, 1.        , 1.        , 1.08985806,
         1.        , 1.07710803, 1.125     , 1.        , 0.96844709,
         0.98148148, 0.99470899, 0.98333333, 1.05059524, 1.0212585 ,
         1.02324735, 1.07397504,        nan, 1.03016775, 1.        ,
         1.        , 1.08450099, 0.96637982, 1.02644051, 0.98731399,
         1.        , 0.98367347, 0.97992725, 1.        , 1.06060606,
         1.07407407, 1.06471289, 1.00347222, 1.01041667, 1.01825397,
         0.98333333, 1.        , 0

In [254]:
top_k_aus[-1][-1][:10]


array(['Takahashi R.H.C.', 'Maioli L.S.', 'Foschini C.R.',
       'Falciano F.T.', 'Campana-Filho S.P.', 'Munoz-Martinez L.F.',
       'Oliveira R.S.', 'Perez C.B.', 'Wan X.', 'Lu X.'], dtype=object)

In [255]:
from scipy.stats import spearmanr

[
    [
        spearmanr(cits[~np.isnan(cits)], bhat[top_k, t][~np.isnan(cits)])
        for t, (cits, top_k) in enumerate(zip(top_cits, top_idxs))
    ]
    for bhat, top_idxs, top_cits in zip(beta_hats, top_k_au_idxs, top_k_cits)
]


[[SpearmanrResult(correlation=0.06274684290027917, pvalue=0.0472889998661006),
  SpearmanrResult(correlation=0.0669488854174967, pvalue=0.0342736805138755),
  SpearmanrResult(correlation=-0.03234977324590226, pvalue=0.30679147384168903),
  SpearmanrResult(correlation=-0.009406848312849554, pvalue=0.7663862069438251)],
 [SpearmanrResult(correlation=0.06239526442492627, pvalue=0.04854503818822524),
  SpearmanrResult(correlation=-0.003631355487945088, pvalue=0.9086906050984273),
  SpearmanrResult(correlation=-0.053665935435837156, pvalue=0.08985450102890319),
  SpearmanrResult(correlation=-0.01700556485603804, pvalue=0.5911782152809242)],
 [SpearmanrResult(correlation=0.057999073729085776, pvalue=0.0667518382234787),
  SpearmanrResult(correlation=0.04504028334527306, pvalue=0.15466667668490267),
  SpearmanrResult(correlation=-0.033898774206585866, pvalue=0.2841962216015558),
  SpearmanrResult(correlation=-0.02859275815485218, pvalue=0.3664005524644348)]]

In [256]:
[
    [
        spearmanr(cits[cits != 0], bhat[top_k, t][cits != 0])
        for t, (cits, top_k) in enumerate(zip(top_cits, top_idxs))
    ]
    for bhat, top_idxs, top_cits in zip(beta_hats, top_k_au_idxs, top_k_cumcits)
]


[[SpearmanrResult(correlation=0.013139152443945578, pvalue=0.7777328209472075),
  SpearmanrResult(correlation=0.04331187621009696, pvalue=0.21147619830759237),
  SpearmanrResult(correlation=-0.017377577111408783, pvalue=0.595431751022006),
  SpearmanrResult(correlation=0.0029594565289727255, pvalue=0.927263143171816)],
 [SpearmanrResult(correlation=0.03164151023108008, pvalue=0.5050790428271119),
  SpearmanrResult(correlation=-0.04638930477461424, pvalue=0.17637430647586727),
  SpearmanrResult(correlation=-0.06260376894871404, pvalue=0.05514906558888096),
  SpearmanrResult(correlation=0.0019140875590886318, pvalue=0.9530913583652886)],
 [SpearmanrResult(correlation=0.05068868348677999, pvalue=0.27175641863941696),
  SpearmanrResult(correlation=0.006448688936904561, pvalue=0.8523926404778217),
  SpearmanrResult(correlation=-0.007707919297586622, pvalue=0.8139144373383917),
  SpearmanrResult(correlation=-0.04365604287830059, pvalue=0.17880645702521816)]]

In [263]:
# look at prop of top k missing cit info - basically just too high for this to be particularly meaningful if using meta
[
    [
        print(f"mod {mod_idx}, t {t}: {1-len(cits[~np.isnan(cits)])/k}")
        for t, (cits, top_k) in enumerate(zip(top_cits, top_idxs))
    ]
    for mod_idx, (bhat, top_idxs, top_cits) in enumerate(
        zip(beta_hats, top_k_au_idxs, top_k_cits)
    )
]
"show prop missing cit info for each model at each timestep"


mod 0, t 0: 0.0
mod 0, t 1: 0.0
mod 0, t 2: 0.0
mod 0, t 3: 0.0
mod 1, t 0: 0.0
mod 1, t 1: 0.0
mod 1, t 2: 0.0
mod 1, t 3: 0.0
mod 2, t 0: 0.0
mod 2, t 1: 0.0
mod 2, t 2: 0.0
mod 2, t 3: 0.0


'show prop missing cit info for each model at each timestep'

In [262]:
[
    [
        print(f"mod {mod_idx}, t {t}: {1-len(cits[cits!=0.0])/k}")
        for t, (cits, top_k) in enumerate(zip(top_cits, top_idxs))
    ]
    for mod_idx, (bhat, top_idxs, top_cits) in enumerate(
        zip(beta_hats, top_k_au_idxs, top_k_cumcits)
    )
]
"same for cum cits"


mod 0, t 0: 0.536
mod 0, t 1: 0.16600000000000004
mod 0, t 2: 0.06399999999999995
mod 0, t 3: 0.04600000000000004
mod 1, t 0: 0.554
mod 1, t 1: 0.14900000000000002
mod 1, t 2: 0.061000000000000054
mod 1, t 3: 0.05300000000000005
mod 2, t 0: 0.528
mod 2, t 1: 0.16500000000000004
mod 2, t 2: 0.06499999999999995
mod 2, t 3: 0.050000000000000044


'same for cum cits'

In [266]:
list(dict(nets[0].nodes(data=True)).values())[0]["main_adm1"]


'Sao Paulo'

In [269]:
au_adm1s = [
    [
        node.get("main_adm1", f"unknown_{idx}")
        for idx, node in enumerate(list(dict(net.nodes(data=True)).values()))
    ]
    for net in nets
]
subset_adm1s = [np.array(adm1)[subset_au_idxs] for adm1 in au_adm1s]
top_k_adm1s = [
    [sub_adm1[top_idxs] for sub_adm1, top_idxs in zip(subset_adm1s, top_k_idxs)]
    for top_k_idxs in top_k_au_idxs
]


In [None]:
# try instead using scholarly to get google scholar profiles
# from scholarly import scholarly
# would be interesting but need to link to proxy, and won't on maths machines
