In [1]:
import torch
import torch.nn.functional as F

class OPAMetric:
    def __init__(self, name):
        """Constructor."""
        self._name = name
    
    @property
    def name(self):
        """The metric name."""
        return self._name
    
    def compute(self, labels, predictions, weights, mask):
        """
        Args:
            labels (Tensor): The ground truth labels.
            predictions (Tensor): The predicted scores.
            weights (Tensor): The weights for each element.
            mask (Tensor): A mask tensor indicating valid elements.
            
        Returns:
            per_list_opa (Tensor): The OPA for each list.
            per_list_weights (Tensor): The weights for each list.
        """
        
        valid_pair = mask.unsqueeze(2) & mask.unsqueeze(1)
        
        pair_label_diff = labels.unsqueeze(2) - labels.unsqueeze(1)
        pair_pred_diff = predictions.unsqueeze(2) - predictions.unsqueeze(1)
        
        # Construct correct pairs
        correct_pairs = (pair_label_diff > 0).float() * (pair_pred_diff > 0).float()
        
        # Compute pair weights
        pair_weights = (pair_label_diff > 0).float() * weights.unsqueeze(2) * valid_pair.float()
        
        # Sum over the pairs dimension to get per_list_weights and per_list_opa
        per_list_weights = pair_weights.sum(dim=[1, 2]).unsqueeze(1)
        per_list_opa = (correct_pairs * pair_weights).sum(dim=[1, 2]).unsqueeze(1) / (per_list_weights + 1e-8)
        
        return per_list_opa, per_list_weights


# Example Usage
labels = torch.tensor([[1.0, 0.0, 2.0], [0.0, 1.0, 2.0]])
preds = torch.tensor([[0.2, 0.4, 0.1], [0.5, 0.3, 0.7]])
weights = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
mask = torch.tensor([[True, True, True], [True, True, True]], dtype=torch.bool)

opa_metric = OPAMetric(name='opa')
per_list_opa, per_list_weights = opa_metric.compute(labels, preds, weights, mask)

print("Per List OPA: ", per_list_opa)
print("Per List Weights: ", per_list_weights)

Per List OPA:  tensor([[0.0000],
        [0.6667]])
Per List Weights:  tensor([[3.],
        [3.]])


In [None]:
import torch
import torchmetrics

class OPAMetric(torchmetrics.Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        
        self.add_state("correct_pairs_sum", default=torch.tensor([0.0]), dist_reduce_fx="sum")
    
    def update(self, labels: torch.Tensor, predictions: torch.Tensor):        
        pair_label_diff = labels.unsqueeze(2) - labels.unsqueeze(1)
        pair_pred_diff = predictions.unsqueeze(2) - predictions.unsqueeze(1)

        correct_pairs = (pair_label_diff > 0).float() * (pair_pred_diff > 0).float()
        print(correct_pairs.shape)
        self.correct_pairs_sum += correct_pairs.sum(dim=1)
        print(self.correct_pairs_sum)
        
    def compute(self):
        return self.correct_pairs_sum


preds = torch.randn(100,)
labels = torch.randn(100,)

opa_metric = OPAMetric()
opa_metric.update(labels, preds)
print("OPA: ", opa_metric.compute())


In [110]:
import torch
import torchmetrics

class OPAMetric(torchmetrics.Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        
        self.add_state("correct_pairs_sum", default=torch.tensor([0.0]), dist_reduce_fx="sum")
        self.add_state("pair_weights_sum", default=torch.tensor([0.0]), dist_reduce_fx="sum")
    
    def update(self, preds: torch.Tensor, labels: torch.Tensor):
        pair_label_diff = labels.unsqueeze(-1) - labels.unsqueeze(1)
        pair_pred_diff = preds.unsqueeze(-1) - preds.unsqueeze(1)
                
        correct_pairs = ((pair_label_diff > 0).float() * (pair_pred_diff > 0).float())
        pair_weights = (pair_label_diff > 0).float()

        self.correct_pairs_sum += (correct_pairs * pair_weights).sum(dim=[1, 2])
        self.pair_weights_sum += pair_weights.sum(dim=[1, 2])

    def compute(self):
        opa = self.correct_pairs_sum / self.pair_weights_sum
        return opa

# Example Usage

labels = torch.tensor([0, 1, 2])
predictions = torch.tensor([0, 2, 1])

opa_metric = OPAMetric()
opa_metric.update(predictions[None], labels[None])
print("OPA: ", opa_metric.compute())

OPA:  tensor(0.6667)


In [55]:
import tensorflow_ranking as tfr


In [111]:
import tensorflow as tf


y_true = [[0., 1., 2.]]
y_pred = [[0., 2., 1.]]
opa = tfr.keras.metrics.OPAMetric()
opa(y_true, y_pred).numpy()

0.6666667

In [127]:
opa_metric = OPAMetric()
tfr_opa = tfr.keras.metrics.OPAMetric(name='opa_metric')

for _ in range(100):
    labels = torch.randn(5,)
    preds = torch.randn(5,)
    
    pt_result = opa_metric(preds[None], labels[None])
    tfr_result = tfr_opa(labels.numpy()[None], preds.numpy()[None])
    
    print(pt_result, tfr_result)
    assert pt_result.item() == tfr_result.numpy().item()


tensor(0.8000) tf.Tensor(0.0, shape=(), dtype=float32)


AssertionError: 

In [53]:
torch.argsort(labels), torch.argsort(preds)

(tensor([1, 8, 9, 5, 2, 4, 3, 0, 6, 7]),
 tensor([5, 0, 9, 6, 4, 7, 3, 8, 2, 1]))