In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../../src/generic')
import csv
import os
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_theme()

In [76]:
from dataset.amazon_reviews_clf_dataset import AmazonClfDataset
from results.process_results import ResultProcessor

In [4]:
main_result_dir = "/data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf"

In [67]:
test_cols = ["test_{}_accuracy".format(i) for i in range(5)]

## Majority Class Baseline

In [78]:
data_dir = "/data/ddmg/redditlanguagemodeling/data/AmazonReviews/data"

In [79]:
data_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0/reviews.csv'),
                      dtype={'reviewerID':str, 'asin':str, 'reviewTime':str,'unixReviewTime':int,
                             'reviewText':str,'summary':str,'verified':bool,'category':str, 'reviewYear':int},
                      keep_default_na=False, na_values=[], quoting=csv.QUOTE_NONNUMERIC)

In [80]:
split_df = pd.read_csv(os.path.join(data_dir, 'amazon_v2.0', 'splits', 'my_user_split.csv'))

In [95]:
# get select people
select_file = "/data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/selected_people.txt"
with open(select_file, 'r') as f:
    people = f.read().splitlines()

In [98]:
data_df["split"] = split_df["split"]

In [99]:
select_df = data_df[data_df["reviewerID"].isin(people)]

In [97]:
len(select_df)

140326

In [103]:
test_df = select_df[split_df["split"] == 2]

  test_df = select_df[split_df["split"] == 2]


In [228]:
train_df = select_df[split_df["split"] == 0]

  train_df = select_df[split_df["split"] == 0]


In [229]:
len(train_df)

43555

In [104]:
len(test_df)

14687

In [113]:
user_maj_cls = test_df.groupby(["reviewerID"])[["overall"]].agg(lambda x: x.value_counts(normalize=True).iloc[0])
user_maj_cls

Unnamed: 0_level_0,overall
reviewerID,Unnamed: 1_level_1
A101S5PLO0VRHQ,0.314286
A10E0V7PGY34UZ,0.933333
A10O7THJ2O20AG,0.800000
A11P853U6FIKAM,0.421053
A12O5K3EQ4MC7Z,0.411765
...,...
AYT4FJYVCHYLE,0.466667
AYVW3O6W8S5S4,0.533333
AZD488SA9QMYF,0.933333
AZJ4DFLH9O4FZ,0.400000


In [114]:
print(user_maj_cls.mean())
print(user_maj_cls.std())

overall    0.650542
dtype: float64
overall    0.195295
dtype: float64


In [227]:
print(user_maj_cls.quantile(q=[.2, .4, .6, .8, 1]))

     test_accuracy_mb
0.2          0.466667
0.4          0.577671
0.6          0.686352
0.8          0.866667
1.0          1.000000


In [230]:
# what if you took majority class from train data and predicted that for test data?
# get majority class from train data
user_maj_cls_train = train_df.groupby(["reviewerID"])[["overall"]].agg(lambda x: x.value_counts().index[0])

In [231]:
user_maj_cls_train

Unnamed: 0_level_0,overall
reviewerID,Unnamed: 1_level_1
A101S5PLO0VRHQ,2.0
A10E0V7PGY34UZ,5.0
A10O7THJ2O20AG,5.0
A11P853U6FIKAM,5.0
A12O5K3EQ4MC7Z,5.0
...,...
AYT4FJYVCHYLE,4.0
AYVW3O6W8S5S4,4.0
AZD488SA9QMYF,5.0
AZJ4DFLH9O4FZ,5.0


In [232]:
score_df = test_df[["reviewerID", "overall"]]
score_df.head(5)

Unnamed: 0,reviewerID,overall
260,A1AEPMPA12GUJ7,4.0
1705,A4UWNRY0WWECK,5.0
1837,A23URR08HKOXIN,5.0
3229,A2PGJP6GV2ZC02,5.0
6168,A67ZXSOC2XH4O,4.0


In [235]:
# predict that for test data
users = []
perfs = []
for user, row in user_maj_cls_train.iterrows():
    user_gt = score_df[score_df["reviewerID"] == user]["overall"].values
    pred_score = row["overall"]
    acc = sum(user_gt == pred_score) / len(user_gt)
    users.append(user)
    perfs.append(acc)

In [236]:
train_baseline_df = pd.DataFrame({"reviewerID": users, "perf": perfs})
train_baseline_df

Unnamed: 0,reviewerID,perf
0,A101S5PLO0VRHQ,0.314286
1,A10E0V7PGY34UZ,0.933333
2,A10O7THJ2O20AG,0.800000
3,A11P853U6FIKAM,0.421053
4,A12O5K3EQ4MC7Z,0.392157
...,...,...
495,AYT4FJYVCHYLE,0.333333
496,AYVW3O6W8S5S4,0.533333
497,AZD488SA9QMYF,0.933333
498,AZJ4DFLH9O4FZ,0.333333


In [238]:
train_baseline_df.agg(["mean", "std"])

Unnamed: 0,perf
mean,0.629733
std,0.219778


In [239]:
train_baseline_df.quantile(q=[.2, .4, .6, .8, 1])

Unnamed: 0,perf
0.2,0.411132
0.4,0.545455
0.6,0.679622
0.8,0.866667
1.0,1.0


## Predict w/ Test Prob Baseline

In [181]:
def get_val_counts(x, val):
    return sum(x == val)

In [143]:
score_df = test_df[["reviewerID", "overall"]]
score_df.head(5)

Unnamed: 0,reviewerID,overall
260,A1AEPMPA12GUJ7,4.0
1705,A4UWNRY0WWECK,5.0
1837,A23URR08HKOXIN,5.0
3229,A2PGJP6GV2ZC02,5.0
6168,A67ZXSOC2XH4O,4.0


In [186]:
def count_1(x):
    return sum(x == 1)

count_1.__name__ = "count_1"

def count_2(x):
    return sum(x == 2)

count_2.__name__ = "count_2"

def count_3(x):
    return sum(x == 3)

count_3.__name__ = "count_3"

def count_4(x):
    return sum(x == 4)

count_4.__name__ = "count_4"

def count_5(x):
    return sum(x == 5)

count_5.__name__ = "count_5"

In [187]:
count_fns = [count_1, count_2, count_3, count_4, count_5]

In [188]:
test_dist_by_user = score_df.groupby(["reviewerID"]).agg(count_fns)
test_dist_by_user

Unnamed: 0_level_0,overall,overall,overall,overall,overall
Unnamed: 0_level_1,count_1,count_2,count_3,count_4,count_5
reviewerID,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
A101S5PLO0VRHQ,2.0,11.0,11.0,5.0,6.0
A10E0V7PGY34UZ,0.0,0.0,0.0,1.0,14.0
A10O7THJ2O20AG,1.0,0.0,1.0,1.0,12.0
A11P853U6FIKAM,1.0,0.0,10.0,22.0,24.0
A12O5K3EQ4MC7Z,0.0,0.0,10.0,21.0,20.0
...,...,...,...,...,...
AYT4FJYVCHYLE,0.0,1.0,2.0,5.0,7.0
AYVW3O6W8S5S4,0.0,1.0,2.0,8.0,4.0
AZD488SA9QMYF,0.0,0.0,0.0,1.0,14.0
AZJ4DFLH9O4FZ,0.0,0.0,4.0,6.0,5.0


In [208]:
# make preds according to these probabilities
users = []
perfs = []
seeds = []
for user, row in test_dist_by_user.iterrows():
    user_gt = score_df[score_df["reviewerID"] == user]["overall"].values
    for seed in [42, 43, 44]:
        np.random.seed(seed)
        score_counts = np.array([row["overall", "count_{}".format(i + 1)] for i in range(5)])
        norm_score_counts = score_counts / sum(score_counts)
        preds = np.random.choice([1, 2, 3, 4, 5], size=int(sum(score_counts)), p=norm_score_counts)
        # get perf
        acc = sum(preds == user_gt) / len(preds)
        users.append(user)
        seeds.append(seed)
        perfs.append(acc)

In [210]:
baseline_prob_df = pd.DataFrame({"reviewerID": users, "perf": perfs, "seed": seeds})

In [213]:
user_baseline_prob_df = baseline_prob_df.groupby(["reviewerID"]).agg(["mean", "std"])
user_baseline_prob_df

Unnamed: 0_level_0,perf,perf,seed,seed
Unnamed: 0_level_1,mean,std,mean,std
reviewerID,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
A101S5PLO0VRHQ,0.257143,0.049487,43,1.0
A10E0V7PGY34UZ,0.844444,0.076980,43,1.0
A10O7THJ2O20AG,0.666667,0.133333,43,1.0
A11P853U6FIKAM,0.368421,0.030387,43,1.0
A12O5K3EQ4MC7Z,0.392157,0.019608,43,1.0
...,...,...,...,...
AYT4FJYVCHYLE,0.355556,0.076980,43,1.0
AYVW3O6W8S5S4,0.311111,0.038490,43,1.0
AZD488SA9QMYF,0.844444,0.076980,43,1.0
AZJ4DFLH9O4FZ,0.244444,0.038490,43,1.0


In [214]:
print(user_baseline_prob_df["perf", "mean"].mean())
print(user_baseline_prob_df["perf", "mean"].std())

0.5404806448624055
0.21367634666530266


## Predict w/ Train Prob Baseline

In [None]:
# may also want to add baseline where you predict with probs according to user data

## Mean + Median Basline

In [121]:
user_means = test_df.groupby(["reviewerID"]).agg("mean")["overall"]
user_means

reviewerID
A101S5PLO0VRHQ    3.057143
A10E0V7PGY34UZ    4.933333
A10O7THJ2O20AG    4.533333
A11P853U6FIKAM    4.192982
A12O5K3EQ4MC7Z    4.196078
                    ...   
AYT4FJYVCHYLE     4.200000
AYVW3O6W8S5S4     4.000000
AZD488SA9QMYF     4.933333
AZJ4DFLH9O4FZ     4.066667
AZZV9PDNMCOZW     3.933333
Name: overall, Length: 500, dtype: float64

In [None]:
def get_mean_pred(x):
    return round(x.loc[x['reviewerID']])

## Global Model

In [5]:
base_result_dir = os.path.join(main_result_dir, "from_embeds", "eval_train_all_my_user_split_from_my_user_split_clf_embeddings")
levels = ["user", "seed"]
global_results = ResultProcessor(base_result_dir, levels, verbose=True)

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings: Found results for 500 users
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1JY6HFCL4PZI4: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A11P853U6FIKAM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2AKH66IWM5O5: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2BAAKZHSUGCDP: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embe

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3P738KVXL2YYM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/AMCZLPIRP0QTE: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3V39KWHCBSF30: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1HQP7190B0WJU: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2MF4TISBBQT5A: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2UM2ABAII4QTT: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1U0RS0JIDAHDM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1GARI2JT6EAWA: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/AYT4FJYVCHYLE: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1ZSQ0ZRYGPK7D: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A4S8JJMA33F2B: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3QRR8PSCBI07C: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2F1A7DANSLFJ9: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/ANWAMG5B44UU5: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A30V9M9DZW8SFU: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/c

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2NB2E5DXE319Z: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A249G4SVEWV9UX: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2LLGLMTGAQ8B4: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1QXR4HL9JW1HI: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2NV86LTJDQ2BB: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3LW8GY42A4URK: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/AQ8OO59DJFJNZ: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A36EDWL4F3AASU: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3MW8B6I2LXVWR: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A7AO0PBCKSW82: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/c

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2YOFCOEKH3KB: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2QHM5HBSIXRL4: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A20IYX6BSPQ5PR: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1C97CZ8GVFMY5: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1EXGL6L0QQ0M5: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3H9JSM1SUTE4O: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/ASI65UKWLTDJQ: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/AGARMSTYE4ZYE: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2Y9088O384NIW: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A2UE9D1TQ3XGUH: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/c

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3LCVGMQD8HK43: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3HYB9AL7BZY4: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1TPWGBNIYJ76F: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A33OM5IX0UBUS7: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A86KXT0G63WEO: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/c

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A382NTLH5U16W5: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A1MLHNQK1LV6WI: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/ARW1MQYTDO8KM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/ABYVXJZ41TCS4: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/eval_train_all_my_user_split_from_my_user_split_clf_embeddings/A3G0123D15ORSW: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/c

In [16]:
user_results_global = global_results.results_df.groupby(["user"]).agg(["mean", "std"])[["eval_accuracy", "test_accuracy"]]
user_results_global

Unnamed: 0_level_0,eval_accuracy,eval_accuracy,test_accuracy,test_accuracy
Unnamed: 0_level_1,mean,std,mean,std
user,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
A101S5PLO0VRHQ,0.539216,0.044927,0.485714,0.049487
A10E0V7PGY34UZ,1.000000,0.000000,0.933333,0.000000
A10O7THJ2O20AG,0.777778,0.038490,0.555556,0.076980
A11P853U6FIKAM,0.533333,0.020995,0.479532,0.010129
A12O5K3EQ4MC7Z,0.653061,0.053995,0.725490,0.033962
...,...,...,...,...
AYT4FJYVCHYLE,0.444444,0.038490,0.733333,0.000000
AYVW3O6W8S5S4,0.688889,0.038490,0.333333,0.000000
AZD488SA9QMYF,0.800000,0.000000,0.955556,0.038490
AZJ4DFLH9O4FZ,0.644444,0.038490,0.644444,0.038490


In [17]:
user_results_global.mean()

eval_accuracy  mean    0.693647
               std     0.037806
test_accuracy  mean    0.694542
               std     0.035038
dtype: float64

In [21]:
user_results_global.std()

eval_accuracy  mean    0.172218
               std     0.034061
test_accuracy  mean    0.177386
               std     0.031738
dtype: float64

In [20]:
user_results_global.quantile(q=[.1, .25, .5, .75, .9])

Unnamed: 0_level_0,eval_accuracy,eval_accuracy,test_accuracy,test_accuracy
Unnamed: 0_level_1,mean,std,mean,std
0.1,0.475238,0.0,0.466667,0.0
0.25,0.577778,0.009545,0.577778,0.005308
0.5,0.690936,0.03849,0.708333,0.03849
0.75,0.822222,0.051551,0.822222,0.044227
0.9,0.920794,0.07698,0.933333,0.07698


In [69]:
user_results2 = global_results.results_df.groupby(["user"]).agg(["mean", "std"])[test_cols]
user_results2

Unnamed: 0_level_0,test_0_accuracy,test_0_accuracy,test_1_accuracy,test_1_accuracy,test_2_accuracy,test_2_accuracy,test_3_accuracy,test_3_accuracy,test_4_accuracy,test_4_accuracy
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
user,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
A101S5PLO0VRHQ,0.333333,0.288675,0.333333,0.052486,0.727273,0.157459,0.800000,0.000000,0.111111,0.096225
A10E0V7PGY34UZ,,,,,,,0.333333,0.577350,0.976190,0.041239
A10O7THJ2O20AG,0.000000,0.000000,,,0.666667,0.577350,0.000000,0.000000,0.638889,0.048113
A11P853U6FIKAM,0.000000,0.000000,,,0.400000,0.000000,0.242424,0.069433,0.750000,0.083333
A12O5K3EQ4MC7Z,,,,,0.566667,0.057735,0.793651,0.027493,0.733333,0.076376
...,...,...,...,...,...,...,...,...,...,...
AYT4FJYVCHYLE,,,0.000000,0.000000,1.000000,0.000000,0.600000,0.000000,0.857143,0.000000
AYVW3O6W8S5S4,,,0.000000,0.000000,0.000000,0.000000,0.166667,0.072169,0.916667,0.144338
AZD488SA9QMYF,,,,,,,1.000000,0.000000,0.952381,0.041239
AZJ4DFLH9O4FZ,,,,,0.666667,0.144338,0.722222,0.096225,0.533333,0.115470


In [70]:
user_results2.mean()

test_0_accuracy  mean    0.284548
                 std     0.029801
test_1_accuracy  mean    0.364809
                 std     0.062921
test_2_accuracy  mean    0.459106
                 std     0.040381
test_3_accuracy  mean    0.486014
                 std     0.085739
test_4_accuracy  mean    0.852480
                 std     0.046945
dtype: float64

In [71]:
user_results2.std()

test_0_accuracy  mean    0.376999
                 std     0.108371
test_1_accuracy  mean    0.360599
                 std     0.151442
test_2_accuracy  mean    0.365721
                 std     0.103273
test_3_accuracy  mean    0.287129
                 std     0.110596
test_4_accuracy  mean    0.170055
                 std     0.055394
dtype: float64

## Person-Specific Models

In [54]:
base_result_dir = os.path.join(main_result_dir, "from_embeds", "person_specific_my_split_n_500")
levels = ["user", "seed"]
person_specific_results = ResultProcessor(base_result_dir, levels, verbose=True)

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500: Found results for 500 users
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2RQOO8VYAEZZG: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2WDC81C1MQUAS: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A38CKQUHA9POY0: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A14R9XMZVJ6INB: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2MF4TISBBQT5A: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/p

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A11P853U6FIKAM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1JY6HFCL4PZI4: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2T5O7MHGONT6S: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A18Y6RF6S79076: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2BAAKZHSUGCDP: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AVCAXJ845TL8S: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/f

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A21G2H64TFS4JO: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AECD1QOMZ2F2I: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1UOS0IM2GP87S: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/ATANE2SC44592: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A22F9L73A92U6B: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A3IH73YPH07FTP: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/fr

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A13WOT3RSXKRD5: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AZJ4DFLH9O4FZ: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/ADVTJ03JD4RQ2: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AP4FQR3BIIYEW: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1YFB1OF0XKJOD: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A37SZWL3R0LEQ3: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/fro

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AN81JUYW2SL24: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AF1IU3K4DB1XI: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1TUT3W4Q9KN8E: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2HWCU87BKZ8M0: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1W1UTQ5SZNE8J: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1WR5OUT03E3M8: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/fr

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/ACUJMLOJEVYTB: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2MYC0P0L0W7BU: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1QS1B2IW9SWHC: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A32QERE04I60K9: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A3A1HHLJZL97DP: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1R796P7A9BKMH: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/f

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A4X56LVVL2X2U: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2X75UXQLLI32H: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A3I5J6JJHQY7H7: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A27VAEBHL9FQDV: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/AAEIK0DZ1F537: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A26T2MC3VCLVYB: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/fr

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2MHCZISNWHQFR: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1TKZM4ZQXC4HY: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A33CY1MZDI8944: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A3BO9I25753U4C: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1WQM564J3V3P2: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/ATC0DD938W4QM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/f

Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A17EWTSBIHB4QM: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1ZON6G8O4BDH3: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2Y29IRSI08F0I: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A1CGOR398UH1IB: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A17BTP1QHK2I3I: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/from_embeds/person_specific_my_split_n_500/A2HZRORRKBERKH: Found results for 3 seeds
Base dir /data/ddmg/redditlanguagemodeling/results/amazon_reviews/clf/

In [55]:
user_results_ps = person_specific_results.results_df.groupby(["user"]).agg(["mean", "std"])[["eval_accuracy", "test_accuracy"]]
user_results_ps

Unnamed: 0_level_0,eval_accuracy,eval_accuracy,test_accuracy,test_accuracy
Unnamed: 0_level_1,mean,std,mean,std
user,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
A101S5PLO0VRHQ,0.490196,0.089854,0.466667,0.016496
A10E0V7PGY34UZ,0.977778,0.038490,0.933333,0.000000
A10O7THJ2O20AG,0.933333,0.000000,0.800000,0.000000
A11P853U6FIKAM,0.490909,0.000000,0.421053,0.000000
A12O5K3EQ4MC7Z,0.489796,0.040816,0.614379,0.059903
...,...,...,...,...
AYT4FJYVCHYLE,0.600000,0.066667,0.355556,0.038490
AYVW3O6W8S5S4,0.377778,0.038490,0.533333,0.000000
AZD488SA9QMYF,0.800000,0.000000,0.911111,0.038490
AZJ4DFLH9O4FZ,0.666667,0.133333,0.555556,0.101835


In [56]:
user_results_ps.mean()

eval_accuracy  mean    0.677737
               std     0.037296
test_accuracy  mean    0.679493
               std     0.036073
dtype: float64

In [57]:
user_results_ps.std()

eval_accuracy  mean    0.191658
               std     0.052216
test_accuracy  mean    0.190525
               std     0.050658
dtype: float64

In [58]:
user_results_ps.quantile(q=[.1, .25, .5, .75, .9])

Unnamed: 0_level_0,eval_accuracy,eval_accuracy,test_accuracy,test_accuracy
Unnamed: 0_level_1,mean,std,mean,std
0.1,0.429954,0.0,0.422222,0.0
0.25,0.533333,0.0,0.5435,0.0
0.5,0.666667,0.015604,0.666667,0.016615
0.75,0.822222,0.061114,0.823016,0.042649
0.9,0.938571,0.11547,0.933333,0.101835


In [73]:
user_results2 = person_specific_results.results_df.groupby(["user"]).agg(["mean", "std"])[test_cols]
user_results2

Unnamed: 0_level_0,test_0_accuracy,test_0_accuracy,test_1_accuracy,test_1_accuracy,test_2_accuracy,test_2_accuracy,test_3_accuracy,test_3_accuracy,test_4_accuracy,test_4_accuracy
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
user,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
A101S5PLO0VRHQ,0.0,0.0,0.484848,0.104973,0.484848,0.138866,0.266667,0.305505,0.722222,0.346944
A10E0V7PGY34UZ,,,,,,,0.000000,0.000000,1.000000,0.000000
A10O7THJ2O20AG,0.0,0.0,,,0.000000,0.000000,0.000000,0.000000,1.000000,0.000000
A11P853U6FIKAM,0.0,0.0,,,0.000000,0.000000,0.000000,0.000000,1.000000,0.000000
A12O5K3EQ4MC7Z,,,,,0.233333,0.230940,0.666667,0.312259,0.750000,0.390512
...,...,...,...,...,...,...,...,...,...,...
AYT4FJYVCHYLE,,,0.000000,0.000000,0.000000,0.000000,1.000000,0.000000,0.047619,0.082479
AYVW3O6W8S5S4,,,0.000000,0.000000,0.000000,0.000000,0.791667,0.190941,0.416667,0.381881
AZD488SA9QMYF,,,,,,,0.000000,0.000000,0.976190,0.041239
AZJ4DFLH9O4FZ,,,,,0.333333,0.144338,0.333333,0.333333,1.000000,0.000000


In [74]:
user_results2.mean()

test_0_accuracy  mean    0.112555
                 std     0.057781
test_1_accuracy  mean    0.057553
                 std     0.036199
test_2_accuracy  mean    0.140687
                 std     0.059845
test_3_accuracy  mean    0.315754
                 std     0.138338
test_4_accuracy  mean    0.827407
                 std     0.085961
dtype: float64

In [75]:
user_results2.std()

test_0_accuracy  mean    0.244238
                 std     0.155373
test_1_accuracy  mean    0.157676
                 std     0.104961
test_2_accuracy  mean    0.255461
                 std     0.122826
test_3_accuracy  mean    0.350169
                 std     0.171026
test_4_accuracy  mean    0.292800
                 std     0.145554
dtype: float64

## Compare Perf of Global vs Local

In [60]:
combined_results = user_results_global.merge(user_results_ps, how='left', left_index=True, right_index=True, suffixes=["_global", "_local"])

In [44]:
def get_diff(x, split):
    diff = x['{}_accuracy_global'.format(split), 'mean'] - x['{}_accuracy_local'.format(split), 'mean']
    return diff

In [61]:
for split in ['eval', 'test']:
    combined_results['{}_accuracy_diff'.format(split)] = combined_results.apply(lambda x: get_diff(x, split), axis=1)

In [62]:
combined_results

Unnamed: 0_level_0,eval_accuracy_global,eval_accuracy_global,test_accuracy_global,test_accuracy_global,eval_accuracy_local,eval_accuracy_local,test_accuracy_local,test_accuracy_local,eval_accuracy_diff,test_accuracy_diff
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,Unnamed: 9_level_1,Unnamed: 10_level_1
user,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
A101S5PLO0VRHQ,0.539216,0.044927,0.485714,0.049487,0.490196,0.089854,0.466667,0.016496,0.049020,0.019048
A10E0V7PGY34UZ,1.000000,0.000000,0.933333,0.000000,0.977778,0.038490,0.933333,0.000000,0.022222,0.000000
A10O7THJ2O20AG,0.777778,0.038490,0.555556,0.076980,0.933333,0.000000,0.800000,0.000000,-0.155556,-0.244444
A11P853U6FIKAM,0.533333,0.020995,0.479532,0.010129,0.490909,0.000000,0.421053,0.000000,0.042424,0.058480
A12O5K3EQ4MC7Z,0.653061,0.053995,0.725490,0.033962,0.489796,0.040816,0.614379,0.059903,0.163265,0.111111
...,...,...,...,...,...,...,...,...,...,...
AYT4FJYVCHYLE,0.444444,0.038490,0.733333,0.000000,0.600000,0.066667,0.355556,0.038490,-0.155556,0.377778
AYVW3O6W8S5S4,0.688889,0.038490,0.333333,0.000000,0.377778,0.038490,0.533333,0.000000,0.311111,-0.200000
AZD488SA9QMYF,0.800000,0.000000,0.955556,0.038490,0.800000,0.000000,0.911111,0.038490,0.000000,0.044444
AZJ4DFLH9O4FZ,0.644444,0.038490,0.644444,0.038490,0.666667,0.133333,0.555556,0.101835,-0.022222,0.088889


In [63]:
print(combined_results['test_accuracy_diff'].abs().mean())
print(combined_results['test_accuracy_diff'].abs().std())
print(combined_results['test_accuracy_diff'].abs().quantile(q=[.1, .25, .5, .75, .9]))

0.10519796985884508
0.10265003610859712
0.10    0.010071
0.25    0.030303
0.50    0.075730
0.75    0.155556
0.90    0.244444
Name: test_accuracy_diff, dtype: float64


In [64]:
print(combined_results['test_accuracy_diff'].mean())
print(combined_results['test_accuracy_diff'].std())
print(combined_results['test_accuracy_diff'].quantile(q=[.1, .25, .5, .75, .9]))

0.015049269820253054
0.14628358064219235
0.10   -0.155556
0.25   -0.044444
0.50    0.022222
0.75    0.102431
0.90    0.177778
Name: test_accuracy_diff, dtype: float64


In [65]:
print(sum(combined_results['test_accuracy_diff'] > 0))
print(sum(combined_results['test_accuracy_diff'] == 0))
print(sum(combined_results['test_accuracy_diff'] < 0))

275
34
191


## Get Difference Compared to Majority Class Baseline

In [215]:
user_maj_cls.head(5)

Unnamed: 0_level_0,overall
reviewerID,Unnamed: 1_level_1
A101S5PLO0VRHQ,0.314286
A10E0V7PGY34UZ,0.933333
A10O7THJ2O20AG,0.8
A11P853U6FIKAM,0.421053
A12O5K3EQ4MC7Z,0.411765


In [216]:
user_maj_cls = user_maj_cls.rename(columns = {"overall": "test_accuracy_mb"})

In [217]:
combined_results2 = combined_results.merge(user_maj_cls, how='left', left_index=True, right_index=True)



In [218]:
combined_results2

Unnamed: 0_level_0,"(eval_accuracy_global, mean)","(eval_accuracy_global, std)","(test_accuracy_global, mean)","(test_accuracy_global, std)","(eval_accuracy_local, mean)","(eval_accuracy_local, std)","(test_accuracy_local, mean)","(test_accuracy_local, std)","(eval_accuracy_diff, )","(test_accuracy_diff, )",test_accuracy_mb
user,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
A101S5PLO0VRHQ,0.539216,0.044927,0.485714,0.049487,0.490196,0.089854,0.466667,0.016496,0.049020,0.019048,0.314286
A10E0V7PGY34UZ,1.000000,0.000000,0.933333,0.000000,0.977778,0.038490,0.933333,0.000000,0.022222,0.000000,0.933333
A10O7THJ2O20AG,0.777778,0.038490,0.555556,0.076980,0.933333,0.000000,0.800000,0.000000,-0.155556,-0.244444,0.800000
A11P853U6FIKAM,0.533333,0.020995,0.479532,0.010129,0.490909,0.000000,0.421053,0.000000,0.042424,0.058480,0.421053
A12O5K3EQ4MC7Z,0.653061,0.053995,0.725490,0.033962,0.489796,0.040816,0.614379,0.059903,0.163265,0.111111,0.411765
...,...,...,...,...,...,...,...,...,...,...,...
AYT4FJYVCHYLE,0.444444,0.038490,0.733333,0.000000,0.600000,0.066667,0.355556,0.038490,-0.155556,0.377778,0.466667
AYVW3O6W8S5S4,0.688889,0.038490,0.333333,0.000000,0.377778,0.038490,0.533333,0.000000,0.311111,-0.200000,0.533333
AZD488SA9QMYF,0.800000,0.000000,0.955556,0.038490,0.800000,0.000000,0.911111,0.038490,0.000000,0.044444,0.933333
AZJ4DFLH9O4FZ,0.644444,0.038490,0.644444,0.038490,0.666667,0.133333,0.555556,0.101835,-0.022222,0.088889,0.400000


In [219]:
# add columns to get diff with local and global perf
def get_diff_w_baseline(x, model_type):
    diff = x['test_accuracy_{}'.format(model_type), "mean"] - x['test_accuracy_mb']
    return diff

In [221]:
for model_type in ["global", "local"]:
    combined_results2["{}_mb_diff".format(model_type)] = combined_results2.apply(lambda x: get_diff_w_baseline(x, model_type), axis=1)

In [222]:
combined_results2

Unnamed: 0_level_0,"(eval_accuracy_global, mean)","(eval_accuracy_global, std)","(test_accuracy_global, mean)","(test_accuracy_global, std)","(eval_accuracy_local, mean)","(eval_accuracy_local, std)","(test_accuracy_local, mean)","(test_accuracy_local, std)","(eval_accuracy_diff, )","(test_accuracy_diff, )",test_accuracy_mb,global_mb_diff,local_mb_diff
user,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
A101S5PLO0VRHQ,0.539216,0.044927,0.485714,0.049487,0.490196,0.089854,0.466667,0.016496,0.049020,0.019048,0.314286,1.714286e-01,1.523810e-01
A10E0V7PGY34UZ,1.000000,0.000000,0.933333,0.000000,0.977778,0.038490,0.933333,0.000000,0.022222,0.000000,0.933333,3.973643e-09,3.973643e-09
A10O7THJ2O20AG,0.777778,0.038490,0.555556,0.076980,0.933333,0.000000,0.800000,0.000000,-0.155556,-0.244444,0.800000,-2.444444e-01,1.192093e-08
A11P853U6FIKAM,0.533333,0.020995,0.479532,0.010129,0.490909,0.000000,0.421053,0.000000,0.042424,0.058480,0.421053,5.847954e-02,3.137087e-09
A12O5K3EQ4MC7Z,0.653061,0.053995,0.725490,0.033962,0.489796,0.040816,0.614379,0.059903,0.163265,0.111111,0.411765,3.137255e-01,2.026144e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...
AYT4FJYVCHYLE,0.444444,0.038490,0.733333,0.000000,0.600000,0.066667,0.355556,0.038490,-0.155556,0.377778,0.466667,2.666667e-01,-1.111111e-01
AYVW3O6W8S5S4,0.688889,0.038490,0.333333,0.000000,0.377778,0.038490,0.533333,0.000000,0.311111,-0.200000,0.533333,-2.000000e-01,2.781550e-08
AZD488SA9QMYF,0.800000,0.000000,0.955556,0.038490,0.800000,0.000000,0.911111,0.038490,0.000000,0.044444,0.933333,2.222222e-02,-2.222222e-02
AZJ4DFLH9O4FZ,0.644444,0.038490,0.644444,0.038490,0.666667,0.133333,0.555556,0.101835,-0.022222,0.088889,0.400000,2.444445e-01,1.555556e-01


In [226]:
combined_results2[[("test_accuracy_global", "mean"), ("test_accuracy_local", "mean"), "test_accuracy_mb"]]

  return array(a, dtype, copy=False, order=order)


Unnamed: 0_level_0,"(test_accuracy_global, mean)","(test_accuracy_local, mean)",test_accuracy_mb
user,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A101S5PLO0VRHQ,0.485714,0.466667,0.314286
A10E0V7PGY34UZ,0.933333,0.933333,0.933333
A10O7THJ2O20AG,0.555556,0.800000,0.800000
A11P853U6FIKAM,0.479532,0.421053,0.421053
A12O5K3EQ4MC7Z,0.725490,0.614379,0.411765
...,...,...,...
AYT4FJYVCHYLE,0.733333,0.355556,0.466667
AYVW3O6W8S5S4,0.333333,0.533333,0.533333
AZD488SA9QMYF,0.955556,0.911111,0.933333
AZJ4DFLH9O4FZ,0.644444,0.555556,0.400000


In [223]:
combined_results2.mean()

(eval_accuracy_global, mean)    0.693647
(eval_accuracy_global, std)     0.037806
(test_accuracy_global, mean)    0.694542
(test_accuracy_global, std)     0.035038
(eval_accuracy_local, mean)     0.677737
(eval_accuracy_local, std)      0.037296
(test_accuracy_local, mean)     0.679493
(test_accuracy_local, std)      0.036073
(eval_accuracy_diff, )          0.015910
(test_accuracy_diff, )          0.015049
test_accuracy_mb                0.650542
global_mb_diff                  0.044000
local_mb_diff                   0.028951
dtype: float64

In [224]:
combined_results2.std()

(eval_accuracy_global, mean)    0.172218
(eval_accuracy_global, std)     0.034061
(test_accuracy_global, mean)    0.177386
(test_accuracy_global, std)     0.031738
(eval_accuracy_local, mean)     0.191658
(eval_accuracy_local, std)      0.052216
(test_accuracy_local, mean)     0.190525
(test_accuracy_local, std)      0.050658
(eval_accuracy_diff, )          0.144518
(test_accuracy_diff, )          0.146284
test_accuracy_mb                0.195295
global_mb_diff                  0.171540
local_mb_diff                   0.093981
dtype: float64