
'geomstats.geometry.riemannian_metric' has no attribute 'loss' or 'grad' #28

@zhangjipinggom


In `train.py`, lines 103 to 110:

```python
riem_dist = np.sqrt(riem.loss(rtvec.detach().cpu(), rtvec_gt.detach().cpu(), METRIC))

z = Variable(torch.ones(l2_loss.shape)).to(device)
rtvec_grad = torch.autograd.grad(l2_loss, rtvec, grad_outputs=z, only_inputs=True, create_graph=True,
                                 retain_graph=True)[0]

riem_grad = riem.grad(rtvec.detach().cpu(), rtvec_gt.detach().cpu(), METRIC)
```

`riem.loss()` and `riem.grad()` are called here, but the module `geomstats.geometry.riemannian_metric` has no attribute `loss` or `grad`.
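Judging from how `train.py` uses them, `riem.loss` appears to return the squared geodesic distance (the code takes its square root) and `riem.grad` the gradient of that loss with respect to the prediction. Assuming that reading is right, a minimal sketch of local replacements is below. It assumes the metric object exposes `squared_dist(point_a, point_b)` and `log(point, base_point)`; verify both against the installed geomstats version. `EuclideanMetric` is a hypothetical flat stand-in used only so the sketch runs without geomstats.

```python
import math


def loss(y_pred, y_true, metric):
    # Squared geodesic distance between prediction and ground truth.
    # Assumes metric.squared_dist(point_a, point_b) exists.
    return metric.squared_dist(y_pred, y_true)


def grad(y_pred, y_true, metric):
    # Riemannian gradient of d(y_pred, y_true)^2 w.r.t. y_pred, which is
    # -2 * Log_{y_pred}(y_true) (valid inside the injectivity radius).
    # Assumes metric.log(point, base_point) exists.
    tangent_vec = metric.log(y_true, y_pred)
    # Elementwise scaling; with NumPy or torch arrays, -2.0 * tangent_vec does the same.
    return [-2.0 * v for v in tangent_vec]


class EuclideanMetric:
    """Hypothetical flat metric, standing in for SE3_GROUP.left_canonical_metric."""

    def squared_dist(self, point_a, point_b):
        return sum((a - b) ** 2 for a, b in zip(point_a, point_b))

    def log(self, point, base_point):
        # In flat space the Riemannian log is just the difference vector.
        return [p - b for p, b in zip(point, base_point)]


if __name__ == "__main__":
    metric = EuclideanMetric()
    y_pred = [2.0, 3.0, 3.0]
    y_true = [1.0, 1.0, 1.0]
    print(math.sqrt(loss(y_pred, y_true, metric)))  # 3.0
    print(grad(y_pred, y_true, metric))  # [2.0, 4.0, 4.0]
```

In `train.py` the calls would then read `loss(rtvec..., rtvec_gt..., METRIC)` and `grad(...)`. Whether the removed `riem.grad` also multiplied by the metric matrix at `y_pred` is not confirmed here; compare against the geomstats release the repository was written for.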

Additional information (imports and setup):
```python
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.geometry.special_euclidean import SpecialEuclidean
import geomstats.geometry.riemannian_metric as riem

SE3_GROUP = SpecialEuclidean(n=3, point_type='vector')
RiemMetric = RiemannianMetric(dim=6)
METRIC = SE3_GROUP.left_canonical_metric
```
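In this setup `riem` is the module `geomstats.geometry.riemannian_metric`, not the `RiemMetric` instance, so `riem.loss` only resolves if the installed release still defines a module-level function with that name. If the code has to run across geomstats versions, one option is a small resolver that prefers the module-level helper when present and otherwise falls back to the metric's own squared distance. `resolve_helper` and `squared_dist_loss` are hypothetical names, and `squared_dist` on the metric is an assumption to check against the installed version.

```python
from types import SimpleNamespace
from typing import Any, Callable


def squared_dist_loss(y_pred: Any, y_true: Any, metric: Any) -> Any:
    # Fallback loss: squared geodesic distance, assuming the metric
    # object provides squared_dist(point_a, point_b).
    return metric.squared_dist(y_pred, y_true)


def resolve_helper(module: Any, name: str, fallback: Callable) -> Callable:
    # Prefer module.<name> when the installed release still defines it
    # (e.g. riem.loss); otherwise return the fallback unchanged.
    helper = getattr(module, name, None)
    return helper if callable(helper) else fallback


if __name__ == "__main__":
    # Stand-ins so the sketch runs without geomstats installed.
    new_style_module = SimpleNamespace()  # has no module-level `loss`
    toy_metric = SimpleNamespace(squared_dist=lambda a, b: (a - b) ** 2)

    riem_loss = resolve_helper(new_style_module, "loss", squared_dist_loss)
    print(riem_loss(3.0, 1.0, toy_metric))  # 4.0
```

In `train.py` this would look like `riem_loss = resolve_helper(riem, "loss", squared_dist_loss)`, then `np.sqrt(riem_loss(..., METRIC))` in place of the direct `riem.loss` call.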
