In [2]:
import sys 
sys.path.insert(1, "../")
from workloads.util import use_plots, use_results

import seaborn as sns 
from matplotlib import pyplot as plt
sns.set()

import pandas as pd
%load_ext autoreload
%autoreload 2

In [3]:
sns.set_theme(style="whitegrid")
sns.set_theme(style="whitegrid", font_scale=1.7)
plt.tight_layout()

<Figure size 432x288 with 0 Axes>

## Get STL Results 

In [4]:
data_dir = use_results("yahoo_A1_window_48_keys_67_length_700/round_robin_0_A1", download=False)
filename = f"{data_dir}/simulation_predictions.csv"
stl_df = pd.read_csv(filename)
stl_df.columns = ['ts', 'y_pred', 'y_true', 'staleness', 'key']
stl_df

Unnamed: 0,ts,y_pred,y_true,staleness,key
0,0,0.032859,0.076970,0,1
1,1,0.093368,0.063933,1,1
2,2,0.071843,0.149733,1,1
3,3,0.083161,0.041479,1,1
4,4,0.093230,0.089318,1,1
...,...,...,...,...,...
43679,647,-0.708341,1.000000,1,67
43680,648,5.499994,12.000000,1,67
43681,649,42.166662,19.000000,1,67
43682,650,25.583331,45.000000,1,67


In [5]:
#Fit model
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, LogisticRegressionCV, LassoCV, ElasticNet
from sklearn.metrics import mean_squared_error
from sktime.performance_metrics.forecasting import mean_absolute_scaled_error
import numpy as np

def compute_window_error(df, W):
    unique_keys = df['key'].unique()
    results = []
    
    for key in unique_keys:
        key_df = df[df['key'] == key]

        # Iterate through the rows within the rolling window
        for i in range(W, len(key_df), W):
            window = key_df.iloc[i-W:i]

            # Check if the window has enough data points
            if len(window) < W:
                error = np.nan  # Fill with NaN if the window is too small
            else:
                error = mean_absolute_scaled_error(
                    window['y_pred'], 
                    window['y_true'], 
                    y_train=key_df.y_true
                )

            # Append the result to the list
            results.append({"key": key, "error": error, "ts": key_df.iloc[i-W]["ts"]})

    # Add the results as a new column in the original DataFrame
    return pd.DataFrame(results) 


results = []

for W in [2, 4, 8, 16, 32]:
        
    error_df = compute_window_error(stl_df, W).dropna()
    print(error_df)
        
    for offset in [0, 2, 4, 8]: 

        x = []
        y = []
        # Compute rolling window
        for user_id in error_df['key'].unique()[:20]: # only for a single key 
            user_df = error_df[error_df["key"] == user_id]
            print(user_df)
            #x += [error_df['error'].iloc[i:i+W].tolist() for i in range(len(user_df) - W - offset)]
            #y += [error_df['error'].iloc[i+W+offset].tolist() for i in range(len(user_df) - W - offset)]
            x += [[error_df['error'].iloc[i].tolist()] for i in range(len(user_df) - offset - 1)]
            y += [error_df['error'].iloc[i+1+offset].tolist() for i in range(len(user_df) - offset - 1)]

        # Split the data into training and testing sets (50/50 split)
        #print(x, y)
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5, random_state=42)

        # Initialize and fit a linear regression model
        model = ElasticNet()
        model.fit(x_train, y_train)

        # Make predictions on the test set
        y_pred = model.predict(x_test)

        # Calculate Mean Squared Error (MSE)
        mse = mean_squared_error(y_test, y_pred)
        y_avg = sum(y_train) / len(y_train)
        baseline_mse = mean_squared_error(y_test, len(y_test) * [y_avg])
        print(offset, W, mse, "data size", len(y), "baseline", baseline_mse)
        results.append({"mse": mse, "window_size": W, "baseline_mse": baseline_mse, "offset": offset, "key": user_id})


       key     error     ts
0        1  0.940041    0.0
1        1  1.528355    2.0
2        1  0.529310    4.0
3        1  0.367171    6.0
4        1  1.804411    8.0
...    ...       ...    ...
21770   67  0.720851  640.0
21771   67  0.895494  642.0
21772   67  0.408219  644.0
21773   67  0.253700  646.0
21774   67  0.511711  648.0

[21775 rows x 3 columns]
     key     error     ts
0      1  0.940041    0.0
1      1  1.528355    2.0
2      1  0.529310    4.0
3      1  0.367171    6.0
4      1  1.804411    8.0
..   ...       ...    ...
320    1  1.312024  640.0
321    1  2.179110  642.0
322    1  2.003231  644.0
323    1  2.801045  646.0
324    1  4.133106  648.0

[325 rows x 3 columns]
     key     error     ts
325    2  1.122780    0.0
326    2  1.177418    2.0
327    2  1.125345    4.0
328    2  1.418032    6.0
329    2  0.888964    8.0
..   ...       ...    ...
645    2  4.472779  640.0
646    2  5.865159  642.0
647    2  4.300912  644.0
648    2  4.191378  646.0
649    2  6.3150

[325 rows x 3 columns]
      key     error     ts
975     4  3.217575    0.0
976     4  1.032861    2.0
977     4  4.591096    4.0
978     4  3.221655    6.0
979     4  1.414319    8.0
...   ...       ...    ...
1295    4  0.832952  640.0
1296    4  0.800994  642.0
1297    4  1.508833  644.0
1298    4  0.626924  646.0
1299    4  0.607205  648.0

[325 rows x 3 columns]
      key     error     ts
1300    5  1.039746    0.0
1301    5  2.998695    2.0
1302    5  3.313680    4.0
1303    5  3.287925    6.0
1304    5  2.441895    8.0
...   ...       ...    ...
1620    5  0.413144  640.0
1621    5  0.596114  642.0
1622    5  1.116542  644.0
1623    5  2.703025  646.0
1624    5  3.236330  648.0

[325 rows x 3 columns]
      key     error     ts
1625    6  0.948342    0.0
1626    6  1.367800    2.0
1627    6  1.427072    4.0
1628    6  2.870861    6.0
1629    6  1.563852    8.0
...   ...       ...    ...
1945    6  0.500007  640.0
1946    6  1.335885  642.0
1947    6  1.471145  644.0
1948    6  

      key     error     ts
2925   10  2.247214    0.0
2926   10  1.974999    2.0
2927   10  2.267025    4.0
2928   10  1.923405    6.0
2929   10  1.044929    8.0
...   ...       ...    ...
3245   10  1.354434  640.0
3246   10  1.930657  642.0
3247   10  1.061492  644.0
3248   10  1.746113  646.0
3249   10  1.175535  648.0

[325 rows x 3 columns]
      key      error     ts
3250   11   3.260583    0.0
3251   11   5.324245    2.0
3252   11   3.563672    4.0
3253   11   5.926479    6.0
3254   11  11.820853    8.0
...   ...        ...    ...
3570   11   4.867161  640.0
3571   11   3.076792  642.0
3572   11   2.961099  644.0
3573   11   3.597580  646.0
3574   11   0.803543  648.0

[325 rows x 3 columns]
      key     error     ts
3575   12  1.967860    0.0
3576   12  3.799023    2.0
3577   12  2.320992    4.0
3578   12  2.577298    6.0
3579   12  2.628560    8.0
...   ...       ...    ...
3895   12  0.650732  640.0
3896   12  0.579536  642.0
3897   12  0.283361  644.0
3898   12  0.743287  6

[162 rows x 3 columns]
     key     error     ts
810    6  1.158071    0.0
811    6  2.148966    4.0
812    6  1.692273    8.0
813    6  0.925545   12.0
814    6  1.129955   16.0
..   ...       ...    ...
967    6  1.334365  628.0
968    6  1.554733  632.0
969    6  0.702898  636.0
970    6  0.917946  640.0
971    6  0.953661  644.0

[162 rows x 3 columns]
      key     error     ts
972     7  5.713135    0.0
973     7  3.584198    4.0
974     7  1.870274    8.0
975     7  4.695951   12.0
976     7  0.992471   16.0
...   ...       ...    ...
1129    7  0.066725  628.0
1130    7  0.082376  632.0
1131    7  0.094239  636.0
1132    7  0.083365  640.0
1133    7  0.020759  644.0

[162 rows x 3 columns]
      key     error     ts
1134    8  0.862065    0.0
1135    8  1.591467    4.0
1136    8  0.589782    8.0
1137    8  0.971873   12.0
1138    8  1.188010   16.0
...   ...       ...    ...
1291    8  0.743314  628.0
1292    8  0.854612  632.0
1293    8  2.809290  636.0
1294    8  3.078592  64

      key     error     ts
0       1  0.841219    0.0
1       1  1.048857    8.0
2       1  1.237754   16.0
3       1  0.710598   24.0
4       1  0.732005   32.0
...   ...       ...    ...
5422   67  0.984972  608.0
5423   67  0.944545  616.0
5424   67  0.577112  624.0
5425   67  0.898369  632.0
5426   67  0.569566  640.0

[5427 rows x 3 columns]
    key     error     ts
0     1  0.841219    0.0
1     1  1.048857    8.0
2     1  1.237754   16.0
3     1  0.710598   24.0
4     1  0.732005   32.0
..  ...       ...    ...
76    1  1.161552  608.0
77    1  2.147283  616.0
78    1  1.797483  624.0
79    1  2.292879  632.0
80    1  2.073853  640.0

[81 rows x 3 columns]
     key     error     ts
81     2  1.210894    0.0
82     2  0.860779    8.0
83     2  0.756472   16.0
84     2  0.655628   24.0
85     2  0.621832   32.0
..   ...       ...    ...
157    2  0.395166  608.0
158    2  1.268482  616.0
159    2  4.036890  624.0
160    2  3.793551  632.0
161    2  4.707557  640.0

[81 rows x 3 co

[81 rows x 3 columns]
     key     error     ts
648    9  1.292931    0.0
649    9  1.169715    8.0
650    9  1.278538   16.0
651    9  1.000007   24.0
652    9  1.113768   32.0
..   ...       ...    ...
724    9  2.744595  608.0
725    9  6.296523  616.0
726    9  7.390802  624.0
727    9  3.690744  632.0
728    9  2.788934  640.0

[81 rows x 3 columns]
     key     error     ts
729   10  2.103161    0.0
730   10  1.269122    8.0
731   10  1.165570   16.0
732   10  0.710038   24.0
733   10  1.396491   32.0
..   ...       ...    ...
805   10  0.834515  608.0
806   10  1.272702  616.0
807   10  0.784469  624.0
808   10  0.952075  632.0
809   10  1.523174  640.0

[81 rows x 3 columns]
     key     error     ts
810   11  4.518745    0.0
811   11  5.008567    8.0
812   11  5.313233   16.0
813   11  5.873927   24.0
814   11  6.735556   32.0
..   ...       ...    ...
886   11  1.576374  608.0
887   11  1.693052  616.0
888   11  1.020481  624.0
889   11  2.653575  632.0
890   11  3.625658  64

      key     error     ts
0       1  0.945038    0.0
1       1  0.974176   16.0
2       1  1.043454   32.0
3       1  1.267560   48.0
4       1  0.977076   64.0
...   ...       ...    ...
2675   67  1.251967  560.0
2676   67  1.323118  576.0
2677   67  0.751486  592.0
2678   67  0.964759  608.0
2679   67  0.737741  624.0

[2680 rows x 3 columns]
    key     error     ts
0     1  0.945038    0.0
1     1  0.974176   16.0
2     1  1.043454   32.0
3     1  1.267560   48.0
4     1  0.977076   64.0
5     1  0.963335   80.0
6     1  1.002366   96.0
7     1  1.537941  112.0
8     1  0.914322  128.0
9     1  0.898238  144.0
10    1  1.417750  160.0
11    1  1.164312  176.0
12    1  1.262044  192.0
13    1  1.063950  208.0
14    1  1.332107  224.0
15    1  1.489580  240.0
16    1  1.258209  256.0
17    1  1.162388  272.0
18    1  1.382747  288.0
19    1  1.483388  304.0
20    1  1.195573  320.0
21    1  1.041604  336.0
22    1  0.771249  352.0
23    1  0.992158  368.0
24    1  0.744099  384.0
2

519   13  0.790950  624.0
     key     error     ts
520   14  3.385494    0.0
521   14  3.703702   16.0
522   14  2.376992   32.0
523   14  1.898613   48.0
524   14  3.238735   64.0
525   14  0.772708   80.0
526   14  1.950859   96.0
527   14  1.196166  112.0
528   14  0.486596  128.0
529   14  0.702446  144.0
530   14  0.733321  160.0
531   14  0.933715  176.0
532   14  1.112699  192.0
533   14  0.990564  208.0
534   14  0.972819  224.0
535   14  2.939520  240.0
536   14  2.843772  256.0
537   14  1.667421  272.0
538   14  1.031062  288.0
539   14  0.641509  304.0
540   14  1.250191  320.0
541   14  0.933206  336.0
542   14  1.065969  352.0
543   14  1.343324  368.0
544   14  1.023638  384.0
545   14  0.790029  400.0
546   14  0.731469  416.0
547   14  0.876788  432.0
548   14  0.953762  448.0
549   14  0.943989  464.0
550   14  1.319998  480.0
551   14  0.630899  496.0
552   14  0.732498  512.0
553   14  0.717543  528.0
554   14  0.766214  544.0
555   14  0.711461  560.0
556   14  1.

      key     error     ts
0       1  0.959607    0.0
1       1  1.155507   32.0
2       1  0.970205   64.0
3       1  1.270153   96.0
4       1  0.906280  128.0
...   ...       ...    ...
1335   67  1.116583  480.0
1336   67  0.807140  512.0
1337   67  1.132844  544.0
1338   67  1.037302  576.0
1339   67  0.851250  608.0

[1340 rows x 3 columns]
    key     error     ts
0     1  0.959607    0.0
1     1  1.155507   32.0
2     1  0.970205   64.0
3     1  1.270153   96.0
4     1  0.906280  128.0
5     1  1.291031  160.0
6     1  1.162997  192.0
7     1  1.410844  224.0
8     1  1.210299  256.0
9     1  1.433067  288.0
10    1  1.118588  320.0
11    1  0.881704  352.0
12    1  0.699125  384.0
13    1  0.830217  416.0
14    1  1.153634  448.0
15    1  0.765306  480.0
16    1  1.048738  512.0
17    1  1.410735  544.0
18    1  1.353215  576.0
19    1  1.849799  608.0
    key      error     ts
20    2   0.870943    0.0
21    2   1.189338   32.0
22    2   3.152817   64.0
23    2   3.774681   9

8 32 0.10641498115462097 data size 220 baseline 0.10641498115462095


In [6]:
window_error_df = pd.DataFrame(results)
stl_error_df = window_error_df.groupby(["window_size", "offset"]).mse.mean().reset_index()

In [9]:
stl_error_df

Unnamed: 0,window_size,offset,mse
0,2,0,0.500271
1,2,2,0.515779
2,2,4,0.504318
3,2,8,0.529603
4,4,0,0.286255
5,4,2,0.293437
6,4,4,0.29128
7,4,8,0.280977
8,8,0,0.146393
9,8,2,0.141264


## Get ALS Results 

In [86]:
data_dir = use_results("ml-1m", download=False)
filename = f"{data_dir}/round_robin_None_60_split_0.5_results.csv"
als_df = pd.read_csv(filename)
als_df

Unnamed: 0.1,Unnamed: 0,y_true,y_pred,user_id,movie_id,timestamp
0,0,4,4.134009,2783,1396,0
1,1,5,1.359328,2783,2901,0
2,2,4,4.235560,3970,3408,1
3,3,4,1.000000,3970,2890,2
4,4,4,1.000000,2782,1265,10
...,...,...,...,...,...,...
12018,12018,5,1.000000,2689,1396,5509
12019,12019,5,1.000000,2689,3347,5509
12020,12020,5,1.000000,2689,3361,5509
12021,12021,5,1.000000,2689,356,5509


In [87]:
top_users = als_df['user_id'].iloc[8000:].value_counts().nlargest(20).index
top_users

Int64Index([3471, 2731, 4169, 2700, 2724, 2752, 2730, 4728, 2692, 2719, 2701,
            2693, 2748, 2703, 2727, 2726, 2718, 2699, 2689, 2694],
           dtype='int64')

In [88]:
#als_df['squared_error'] = (als_df['y_true'] / als_df['y_true'].mean() - als_df['y_pred'] / als_df['y_true'].mean()) ** 2
als_df['squared_error'] = (als_df['y_true'] - als_df['y_pred']) ** 2 
als_df["error"] = als_df['squared_error'] 
als_df["error"]

0         0.017958
1        13.254491
2         0.055489
3         9.000000
4         9.000000
           ...    
12018    16.000000
12019    16.000000
12020    16.000000
12021    16.000000
12022     9.000000
Name: error, Length: 12023, dtype: float64

In [89]:
def fill_timestamps(df): 
    min_timestamp = df.index.min()
    max_timestamp = df.index.max()

    # Create a new index with all integers between the min and max timestamps
    new_index = pd.RangeIndex(start=min_timestamp, stop=max_timestamp+1)
    # Reindex the Series with the new index and fill missing values with 0
    df = df.reindex(new_index, fill_value=0)
    df["timestamp"] = new_index
    return df

In [90]:
cache = {}

In [None]:
#Fit model
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, LogisticRegressionCV, LassoCV, ElasticNet
from sklearn.metrics import mean_squared_error

results = []
for offset in [0, 1, 2, 4, 8, 32]: 

    for W in [1, 2, 3, 4, 5, 6, 7, 8, 16, 32]:
        
        if f"{offset}_{W}" in cache: 
            results.append(cache[f"{offset}_{W}"])
            continue 

        # Compute rolling window
        x = []
        y = []
        for user_id in top_users: 
            user_df = als_df[als_df["user_id"] == user_id]
            user_df = fill_timestamps(user_df)
            x += [als_df['error'].iloc[i:i+W].tolist() for i in range(len(als_df) - W - offset)]
            y += [als_df['error'].iloc[i+W+offset].tolist() for i in range(len(als_df) - W - offset)]

        # Split the data into training and testing sets (50/50 split)
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5, random_state=42)

        # Initialize and fit a linear regression model
        model = ElasticNet()
        model.fit(x_train, y_train)

        # Make predictions on the test set
        y_pred = model.predict(x_test)

        # Calculate Mean Squared Error (MSE)
        mse = mean_squared_error(y_test, y_pred)
        y_avg = sum(y_train) / len(y_train)
        baseline_mse = mean_squared_error(y_test, len(y_test) * [y_avg])
        print(offset, W, mse, "data size", len(y), "baseline", baseline_mse)
        results.append({"mse": mse, "window_size": W, "baseline_mse": baseline_mse, "offset": offset})
        cache[f"{offset}_{W}"] = {"mse": mse, "window_size": W, "baseline_mse": baseline_mse, "offset": offset}

0 1 27.300735335302 data size 240440 baseline 31.709058256923615
0 2 25.726194999522313 data size 240420 baseline 31.512048925494444
0 3 25.22449165163593 data size 240400 baseline 31.513914242443853
0 4 24.95649298600221 data size 240380 baseline 31.58903015073394
0 5 24.78885979890943 data size 240360 baseline 31.53972336940853
0 6 24.651872504604178 data size 240340 baseline 31.54657149752334
0 7 24.477754269414326 data size 240320 baseline 31.50335750062525
0 8 24.32967335450473 data size 240300 baseline 31.583528404273
0 16 23.88843896683338 data size 240140 baseline 31.460622916120503
0 32 23.910372027913994 data size 239820 baseline 31.555170687999123
1 1 27.96340404546638 data size 240420 baseline 31.512048925494444
1 2 26.889865360601213 data size 240400 baseline 31.513914242443853
1 3 26.459532869107203 data size 240380 baseline 31.58903015073394
1 4 26.104019588877208 data size 240360 baseline 31.53972336940853
1 5 25.90461451223268 data size 240340 baseline 31.5465714975233

In [None]:
als_error_df = pd.DataFrame(results)

# Plot 

In [None]:
plots_dir = use_plots("", download=False)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12,6))
linewidth = 2


# plot ALS 
g1 = sns.lineplot(data=als_error_df, marker="o", y="mse", x="window_size", hue="offset", ax=axes[1], linewidth=linewidth)
axes[1].set_xlabel("Window Size (Number of Timesteps)", fontsize=22, fontweight="bold")
axes[1].set_ylabel("MSE", fontsize=22, fontweight="bold")
axes[1].set_title("Recommendation", fontsize=26)

# plot STL 
g0 = sns.lineplot(data=stl_error_df , marker="o", y="mse", x="window_size", hue="offset", ax=axes[0], linewidth=linewidth)                                                                                                
axes[0].set_xlabel("Window Size (Number of Timesteps)", fontsize=22, fontweight="bold")
axes[0].set_ylabel("MSE", fontsize=22, fontweight="bold")
axes[0].set_title("Anomaly Detection (Yahoo A1)", fontsize=26)

# Add legends based on the "legend_label:" comments
#handles, labels = axes.get_legend_handles_labels()
#print(handles, labels)
#axes.legend(handles, labels)

fig.savefig(f'{plots_dir}/predict_error.pdf', dpi=300, bbox_inches = "tight")

fig.tight_layout()