In [27]:
import torch

def batch_pearson_coherency(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Computes the Pearson correlation coefficient for a batch of matrix pairs.

    Every sample is flattened to a vector, mean-centred, and the coefficient
    is evaluated as the cosine similarity of the two centred vectors. This is
    algebraically the same as ``cov / (std_a * std_b)`` because the
    ``1 / (n - 1)`` factors cancel.

    Args:
        A (torch.Tensor): Floating-point tensor of shape (batch_size, m, n).
            Any trailing shape is accepted; all dimensions after the first
            are flattened. Need not be contiguous.
        B (torch.Tensor): Tensor with the same shape as ``A``.

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) holding the PCC of each
        pair. A pair in which either sample has zero variance (constant
        values, or a single element) yields 0.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors, copying only when it has to.
    a = A.reshape(A.shape[0], -1)
    b = B.reshape(B.shape[0], -1)

    # Mean-centre every sample
    a_centered = a - a.mean(dim=1, keepdim=True)
    b_centered = b - b.mean(dim=1, keepdim=True)

    # L2 norm of each centred sample, kept as (batch_size, 1) for broadcasting
    norm_a = torch.linalg.vector_norm(a_centered, dim=1, keepdim=True)
    norm_b = torch.linalg.vector_norm(b_centered, dim=1, keepdim=True)

    # Avoid division by zero without biasing the result: adding a fixed
    # epsilon to the denominator makes the coefficient depend on the scale of
    # the data (small-magnitude inputs collapse towards 0). Only an exactly
    # zero norm is replaced; the centred vector is all zeros in that case, so
    # the pair evaluates to 0.
    safe_norm_a = norm_a.masked_fill(norm_a == 0, 1.0)
    safe_norm_b = norm_b.masked_fill(norm_b == 0, 1.0)

    # Normalising each vector separately (rather than dividing by the product
    # of the norms) keeps the denominator from underflowing for tiny inputs.
    rho = ((a_centered / safe_norm_a) * (b_centered / safe_norm_b)).sum(dim=1)
    return rho


# Example usage
# NOTE(review): `pearson_coherency` is not defined in the visible code (only
# `batch_pearson_coherency` is) - presumably it comes from an earlier cell; confirm.

a = torch.randn((5,5), dtype=torch.float32)
b = a + torch.randn((5,5), dtype=torch.float32) * 1 # Unit-variance noise added to a unit-variance signal

print(pearson_coherency(a, b))  # Population correlation is 1/sqrt(2) ~ 0.71 (not ~1.0); with 25 samples the estimate fluctuates

tensor(0.6348)
