In [1]:
import pandas as pd
import numpy as np
import os

In [2]:
def accuracy_at_k(preds_df, true_items, max_k=25):
    accs = []
    for k in range(1, max_k + 1):
        correct = 0
        for i in range(len(true_items)):
            scores = preds_df.iloc[i]
            top_k_items = scores.sort_values(ascending=False).index[:k]
            if true_items.iloc[i] in top_k_items:
                correct += 1
        accs.append(correct / len(true_items))
    return accs

def fast_accuracy_at_k(preds_df: pd.DataFrame, true_items: pd.Series, max_k: int = 25):
    # Convert DataFrame to numpy array
    scores = preds_df.to_numpy()
    item_ids = np.array(preds_df.columns)

    # Get indices of top-k items for each row
    topk_indices = np.argpartition(-scores, range(max_k), axis=1)[:, :max_k]  # partial sort
    topk_scores = np.take_along_axis(scores, topk_indices, axis=1)

    # Fully sort only the top-k items
    sorted_topk_idx = np.argsort(-topk_scores, axis=1)
    topk_sorted_indices = np.take_along_axis(topk_indices, sorted_topk_idx, axis=1)

    # Map true_items to column indices
    item_to_index = {item: i for i, item in enumerate(preds_df.columns)}
    true_indices = true_items.map(item_to_index).to_numpy()

    # Compute match matrix: shape (num_samples, max_k)
    matches = (topk_sorted_indices == true_indices[:, None])

    # Cumulative accuracy: if true item is in top-k
    acc_at_k = matches.cumsum(axis=1).clip(0, 1).mean(axis=0)

    return acc_at_k.tolist()

In [3]:
# fast_accuracy_at_k(df_results, df_results[0], max_k=top_k)

In [4]:
top_k = 50

results = np.array([0.0 for _ in range(top_k)])

total_count = 0

for file in os.listdir('extended_results'):
    print(f'Loading ./extended_results/{file}')
    df_results = pd.read_parquet(f'./extended_results/{file}')
    print('\tCalculating...')
    results += np.array(fast_accuracy_at_k(df_results, df_results[0], max_k=top_k)) * len(df_results)
    total_count += len(df_results)

Loading ./extended_results/extended_results_ordered_0.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_1.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_2.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_3.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_4.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_5.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_6.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_7.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_8.parquet
	Calculating...
Loading ./extended_results/extended_results_ordered_last.parquet
	Calculating...


In [5]:
np.savetxt('extended_accuracy.txt', results / total_count)

In [29]:
results

array([3., 3., 6., 6., 6., 6., 6., 6., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
       7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
       7., 7., 7., 7., 7., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.])

In [21]:
results

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0])

In [22]:
out

array([0.13043478, 0.13043478, 0.26086957, 0.26086957, 0.26086957,
       0.26086957, 0.26086957, 0.26086957, 0.30434783, 0.30434783,
       0.30434783, 0.30434783, 0.30434783, 0.30434783, 0.30434783,
       0.30434783, 0.30434783, 0.30434783, 0.30434783, 0.30434783,
       0.30434783, 0.30434783, 0.30434783, 0.30434783, 0.30434783,
       0.30434783, 0.30434783, 0.30434783, 0.30434783, 0.30434783,
       0.30434783, 0.30434783, 0.30434783, 0.30434783, 0.30434783,
       0.30434783, 0.30434783, 0.30434783, 0.30434783, 0.34782609,
       0.34782609, 0.34782609, 0.34782609, 0.34782609, 0.34782609,
       0.34782609, 0.34782609, 0.34782609, 0.34782609, 0.34782609])

In [6]:
df_results = pd.read_parquet(f'./extended_results/extended_results_ordered_0.parquet')
df_results

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,3018,3019,3020,3021,3022,3023,3024,3025,3026,3027
0,,236.425583,235.43158,234.761398,235.508194,236.186493,236.109375,236.529495,236.099197,235.616425,...,232.483932,227.166626,230.592346,227.777954,234.953476,231.59613,232.379959,230.952667,229.870041,229.449158
1,,236.859177,235.80658,236.858109,235.742249,236.690247,236.555023,236.942108,236.566254,236.003372,...,232.83728,227.366379,230.809906,227.992416,235.214035,231.983368,232.626953,231.250076,230.187714,229.687515
2,,236.59581,235.60022,235.461243,235.551788,236.339752,236.244263,236.877609,236.338181,235.723984,...,232.66803,227.260391,230.699387,227.88295,235.102203,231.791397,232.505722,231.091461,230.025879,229.574615
3,,237.024307,236.006012,235.565735,235.842072,236.698349,236.64592,239.158386,237.12413,236.399673,...,232.970535,227.582321,230.985031,228.18103,235.56871,232.154037,232.797714,231.469452,230.371597,229.820267
4,,237.324295,236.395462,235.827744,236.120453,237.070953,237.070297,240.613571,237.660721,236.953735,...,233.344315,227.838211,231.214157,228.410492,235.835022,232.417801,233.077118,231.769073,230.642563,230.079803
5,,237.283218,236.316452,235.808792,236.125702,237.010666,237.014252,240.31572,237.489456,236.841019,...,233.304626,227.80632,231.176361,228.376358,235.786682,232.365128,233.041595,231.724777,230.616867,230.047134
6,,237.343369,236.338699,235.870743,236.164337,237.113068,237.104797,240.463867,237.545517,236.904343,...,233.393982,227.854218,231.231873,228.423569,235.842163,232.411179,233.083359,231.764023,230.663422,230.086731
7,,236.717484,235.615921,235.137924,235.514786,236.348709,236.234436,237.80307,236.568253,235.830093,...,232.828857,227.40538,230.838425,228.005875,235.258011,231.913055,232.763809,231.308929,230.119751,229.83429
8,,237.886581,236.245514,236.315689,235.825272,236.983246,236.757599,238.418518,237.626877,236.777191,...,233.218369,227.674103,231.170975,228.323929,235.89621,232.667252,233.077911,231.762802,230.6539,230.159348
9,,237.954742,236.174866,236.176971,235.735123,236.950073,236.658981,237.982986,237.628311,236.640182,...,233.170593,227.648788,231.15593,228.310944,235.843399,232.669113,233.063919,231.760147,230.628372,230.224197


In [2]:
df_results = pd.read_parquet('extended_results_ordered.parquet').astype(float)

In [4]:
df_results.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1640681 entries, 0 to 1640680
Columns: 3028 entries, 0 to 3027
dtypes: float64(3028)
memory usage: 37.0 GB


In [7]:
# # max_k
# accs = []


# for k in range(1, max_k + 1):
#     correct = 0
    

In [8]:
def accuracy_at_k(preds_df, true_items, max_k=25):
    accs = []
    for k in range(1, max_k + 1):
        correct = 0
        for i in range(len(true_items)):
            scores = preds_df.iloc[i]
            top_k_items = scores.sort_values(ascending=False).index[:k]
            if true_items.iloc[i] in top_k_items:
                correct += 1
        accs.append(correct / len(true_items))
    return accs

In [9]:
def fast_accuracy_at_k(preds_df: pd.DataFrame, true_items: pd.Series, max_k: int = 25):
    # Convert DataFrame to numpy array
    scores = preds_df.to_numpy()
    item_ids = np.array(preds_df.columns)

    print(1)
    # Get indices of top-k items for each row
    topk_indices = np.argpartition(-scores, range(max_k), axis=1)[:, :max_k]  # partial sort
    topk_scores = np.take_along_axis(scores, topk_indices, axis=1)
    print(1)
    # Fully sort only the top-k items
    sorted_topk_idx = np.argsort(-topk_scores, axis=1)
    topk_sorted_indices = np.take_along_axis(topk_indices, sorted_topk_idx, axis=1)
    print(1)
    # Map true_items to column indices
    item_to_index = {item: i for i, item in enumerate(preds_df.columns)}
    true_indices = true_items.map(item_to_index).to_numpy()
    print(1)
    # Compute match matrix: shape (num_samples, max_k)
    matches = (topk_sorted_indices == true_indices[:, None])
    print(1)
    # Cumulative accuracy: if true item is in top-k
    acc_at_k = matches.cumsum(axis=1).clip(0, 1).mean(axis=0)
    print(1)
    return acc_at_k.tolist()

In [None]:
fast_accuracy_at_k(df_results, df['track_id'], max_k=1)

1
