In [376]:
import torch
import geoopt
from torch.nn.functional import conv2d, normalize
from torch import sqrt
import tensorflow

ModuleNotFoundError: No module named 'tensorflow'

In [188]:
pcb = geoopt.manifolds.PoincareBall(
            c=1.)

x = geoopt.ManifoldTensor(
    torch.zeros(10, ),
    manifold=pcb,
    requires_grad=True)

y = torch.ones(10, ) * 3



In [200]:
def torch_sqnorm(u: torch.Tensor, keepdims: bool = True, dim: int = 1) -> torch.Tensor:
    """ calculate vector norms over given axis """
    return torch.sum(u * u, dim=dim, keepdims=keepdims)

def torch_cross_correlate(inputs: torch.Tensor, filters: torch.Tensor) -> torch.Tensor:
    return conv2d(inputs, weight=filters, stride=(1, 1), padding='same')

In [373]:
EPS = 1e-15
PROJ_EPS = 1e-3


def torch_hyp_mlr(inputs: torch.Tensor, c: torch.Tensor, P_mlr: torch.Tensor, A_mlr: torch.Tensor) -> torch.Tensor:
    """ Perform hyperbolic MLR to calculate the class probabilies of inputs based on their 
    similarity to the parameters defining the class hyperplanes. The number of classes is
    for the full hierarchy.
    
    args:
        inputs: shape (B, ch, H, W) embeddding model outputs.
        P_mlr: shape (ncls, ch) class hyperplane offsets
        A_mlr: shape (ncls, ch) class hyperplane normals
        
    returns:
        logits: shape (batch, num_classes, height_width)  """
    
    xx = torch_sqnorm(inputs, keepdims=False)  
    
    pp = torch_sqnorm(-P_mlr, keepdims=False, dim=0)  
    
    P_kernel = -P_mlr.T[:, :, None, None]
    
    px = torch_cross_correlate(inputs, filters=P_kernel)  
    
    return px.shape
    
    sqsq = torch.multiply(
        c * xx, c * pp[None, :, None, None])  
    
    A_norm = torch.linalg.norm(A_mlr, dim=0)  
    
    normed_A = normalize(A_mlr, dim=0)
    
    A_kernel = torch.transpose(normed_A, 0, 1)[:, :, None, None]
    
    A = 1 + torch.add(2 * c * px, c * xx) 
    
    B = 1. - c * pp  
    
    D = 1 + torch.add(2 * c * px, sqsq)  
    
    D = torch.maximum(D, torch.full(D.shape, EPS, device=inputs.device))
    
    alpha = A / D  
    
    beta = B[None, :, None, None] / D  
    
    mobaddnorm = ((alpha ** 2 * pp[None, :, None, None]) +
                  (beta ** 2 * xx) + (2 * alpha * beta * px))
    
    maxnorm = (1.0 - PROJ_EPS) / sqrt(c)
    
  
    check = torch.sqrt(mobaddnorm) > maxnorm
    if_true = maxnorm / \
        torch.maximum(torch.sqrt(mobaddnorm),
                      torch.full(mobaddnorm.shape, EPS, device=inputs.device))
    
    if_false = torch.ones_like(mobaddnorm, device=inputs.device)
    
    project_normalized = torch.zeros_like(mobaddnorm, device=inputs.device)
    
    project_normalized[check] = if_true[check]
    
    project_normalized[~check] = if_false[~check]
    
    check = torch.sqrt(mobaddnorm) < maxnorm
    
    if_false = torch.ones_like(mobaddnorm, device=inputs.device) * maxnorm ** 2
    
    mobaddnormprojected = torch.zeros_like(mobaddnorm, device=inputs.device)
    
    mobaddnormprojected[check] = mobaddnorm[check]
    
    mobaddnormprojected[~check] = if_false[~check]
    
    xdota = beta * torch_cross_correlate(inputs, filters=A_kernel)
    
    pdota = (alpha * torch.sum(-P_mlr * normed_A, dim=0)[None, :, None, None])
    
    mobdota = xdota + pdota  
    
    mobdota *= project_normalized  
    
    lamb_px = 2.0 / torch.maximum(
        1. - c * mobaddnormprojected,
        torch.tensor(EPS, device=inputs.device))
    
    
    sineterm = sqrt(c) * mobdota * lamb_px
    
    return 2.0 / sqrt(c) * A_norm.view(1, -1, 1, 1) * torch.asinh(sineterm)

In [374]:
ch = 5
ncls = 2
h, w = 5, 5
batch = 2


embs  = torch.rand(batch, dim, h, w, requires_grad=True)
P_mlr = torch.rand(ncls, ch, requires_grad=True)
A_mlr = torch.rand(ncls, ch, requires_grad=True)

print(P_mlr, P_mlr.shape, '\n\n----------------')

P_mlr = geoopt.ManifoldTensor(
    P_mlr,
    manifold=pcb,
    requires_grad=True)

hyperbolic_optimizer = geoopt.optim.rsgd.RiemannianSGD(
            [P_mlr], # offsets are in hyperbolic space
            lr=0.01,
            momentum=0.997,
            weight_decay=0.997,
            stabilize=1)

c = torch.tensor(1.)

tensor([[0.3093, 0.8054, 0.0877, 0.6038, 0.8264],
        [0.4499, 0.4868, 0.8578, 0.0964, 0.1782]], requires_grad=True) torch.Size([2, 5]) 

----------------


In [375]:
torch_hyp_mlr(embs, c, P_mlr, A_mlr, )

# logits.sum().backward()
# hyperbolic_optimizer.step()
# torch.linalg.norm(P_mlr, dim=1)

RuntimeError: Given groups=1, weight of size [5, 2, 1, 1], expected input[2, 3, 5, 5] to have 2 channels, but got 3 channels instead