In [12]:
import numpy as np

In [13]:
def recall_at_k(y_true, y_pred, k):
    """
    Compute Recall@K for multiple samples.
    
    Args:
    y_true: 2D array of true relevance labels, shape (n_samples, n_labels)
    y_pred: 2D array of predicted scores or probabilities, shape (n_samples, n_labels)
    k: The number of top items to consider for Recall@K
    
    Returns:
    Array of Recall@K scores for each sample and the mean Recall@K.
    """
    n_samples = y_true.shape[0]
    recalls = np.zeros(n_samples)
    
    for i in range(n_samples):
        # Total number of relevant items for this sample
        total_relevant = np.sum(y_true[i])
        print(f'r:{total_relevant}')
        total_relevant = min(total_relevant, k)
        print(f'min(r,k):{total_relevant}')
        # Get top K indices for this sample
        top_k_indices = np.argsort(y_pred[i])[::-1][:k]
        
        # Count relevant items in top K
        relevant_in_k = np.sum(y_true[i][top_k_indices])
        
        # Compute Recall@K for this sample
        if total_relevant > 0:
            recalls[i] = relevant_in_k / total_relevant
        else:
            recalls[i] = 0  # If no relevant items, recall is 0
    
    # Compute mean Recall@K
    mean_recall = np.mean(recalls)
    
    return recalls, mean_recall


In [14]:
# Example usage
y_true = np.array([
    [1, 0, 1, 1, 0, 1, 0],
    [1, 1, 0, 0, 1, 0, 1],
    [0, 1, 1, 1, 0, 0, 1],
    [0, 0, 0, 1, 0, 0, 1]
])
y_pred = np.array([
    [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3],
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
    [0.3, 0.6, 0.2, 0.7, 0.1, 0.9, 0.5],
    [0.3, 0.6, 0.2, 0.7, 0.1, 0.9, 0.9]
])

# Compute Recall@3 and Recall@5
k_values = [3, 5]

for k in k_values:
    recalls, mean_recall = recall_at_k(y_true, y_pred, k)
    print(f"\nRecall@{k}:")
    for i, recall in enumerate(recalls):
        print(f"Sample {i+1}: {recall:.4f}")
    print(f"Mean Recall@{k}: {mean_recall:.4f}")

r:4
min(r,k):3
r:4
min(r,k):3
r:4
min(r,k):3
r:2
min(r,k):2

Recall@3:
Sample 1: 0.6667
Sample 2: 0.6667
Sample 3: 0.6667
Sample 4: 1.0000
Mean Recall@3: 0.7500
r:4
min(r,k):4
r:4
min(r,k):4
r:4
min(r,k):4
r:2
min(r,k):2

Recall@5:
Sample 1: 0.7500
Sample 2: 0.5000
Sample 3: 0.7500
Sample 4: 1.0000
Mean Recall@5: 0.7500
