In [1]:
%matplotlib notebook
import torch
import os
# Turn on CPP Stacktraces for more debug detail
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = "1"
from torch_sym3eig import Sym3Eig as se

In [2]:
from lazy_imports import np
from lazy_imports import plt
from lazy_imports import loadmat, savemat
plt.rcParams["figure.figsize"] = (4, 4) # (w, h)

In [3]:
import pickle
import math
import os
import scipy.io as sio
import pandas as pd
import scipy

In [4]:
import time

In [5]:
from data.io import readRaw, ReadScalars, ReadTensors, WriteTensorNPArray, WriteScalarNPArray, readPath3D
from data.convert import GetNPArrayFromSITK, GetSITKImageFromNP

In [6]:
from lazy_imports import itkwidgets
from lazy_imports import itkview
from lazy_imports import interactive
from lazy_imports import ipywidgets
from lazy_imports import pv


In [7]:
from tqdm import tqdm

from util.RegistrationFunc3DCuda import *
from util.SplitEbinMetric3DCuda import *
#from util.RegistrationFunc3D import *
#from util.SplitEbinMetric3D import *

In [8]:
from util.tensors import tens_6_to_tens_3x3, tens_3x3_to_tens_6, get_framework, fractional_anisotropy

# Setup torch device defaults

In [9]:
cuda_dev = 'cuda:0'
device = torch.device(cuda_dev if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type('torch.cuda.DoubleTensor' if torch.cuda.is_available() else 'torch.DoubleTensor')
print(device)
#cuda_dev = 'cpu'
#device = torch.device(cuda_dev)
#torch.set_default_tensor_type('torch.DoubleTensor')
#print(device)

cuda:0


# Display Configuration

In [10]:
# from colorbrewer2, sequential 9 values pasted together YlGnBu (reverse order) then YlOrRd
# EXCEPT THESE ARE NOT PRINT FRIENDLY OR PHOTOCOPY SAFE!!!
# But skip the yellows in the middle -- too light
geo_colors = ['tab:red', 'tab:pink', 'tab:orange', 'tab:blue', 'tab:purple', 'tab:green', 'tab:cyan']
eul_colors = ['k', 'tab:gray', 'tab:brown', 'm', 'y', 'tab:olive', 'maroon']
#interp_colors = ['#081d58', '#253494', '#225ea8', '#1d91c0', '#41b6c4', '#7fcdbb', '#c7e9b4', '#edf8b1', '#ffffd9',
#                 '#ffffcc', '#ffeda0', '#fed976', '#feb24c', '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026']
interp_colors = ['#081d58', '#253494', '#225ea8', '#1d91c0', '#41b6c4', '#7fcdbb', '#c7e9b4',
                 '#fed976', '#feb24c', '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026']
# Compromise, use 5 class YlGnBu (reverse order) and YlOrBr again leaving out yellow
#interp_colors = ['#253494', '#2c7fb8', '#41b6c4', '#a1dab4', '#fed98e', '#fe9929', '#d95f0e', '#993404']

case_color='tab:orange'
ctrl_color='tab:blue'
train_color='tab:blue'
test_color='tab:orange'


ncycle = 100 # Number of colors 
# winter
cycle_colors = [plt.get_cmap('copper')(1. * i/ncycle) for i in range(ncycle)]

#plt.rc('axes', 
#       prop_cycle=(cycler('color', new_colors) + 
#                   cycler('linestyle', ['-', '--', ':', '-.'])))

plt.rc('axes', 
       prop_cycle=(plt.cycler('color', cycle_colors)))


# Helper Routines

In [11]:
def plot_6_components(np_img, title, figsz=4, filename=None,vmin=-2,vmax=2):
  f, axes = plt.subplots(3, 2, figsize=(figsz*2,figsz*3))
  im0=axes[0,0].imshow(np_img[...,0],vmin=vmin,vmax=vmax)
  axes[0,0].set_title(f'{title} 0,0')
  im1=axes[0,1].imshow(np_img[...,1],vmin=vmin,vmax=vmax)
  axes[0,1].set_title(f'{title} 0,1')
  im2=axes[1,0].imshow(np_img[...,2],vmin=vmin,vmax=vmax)
  axes[1,0].set_title(f'{title} 0,2')
  im3=axes[1,1].imshow(np_img[...,3],vmin=vmin,vmax=vmax)
  axes[1,1].set_title(f'{title} 1,1')
  im4=axes[2,0].imshow(np_img[...,4],vmin=vmin,vmax=vmax)
  axes[2,0].set_title(f'{title} 1,2')
  im5=axes[2,1].imshow(np_img[...,5],vmin=vmin,vmax=vmax)
  axes[2,1].set_title(f'{title} 2,2')
  plt.tight_layout()

  cbar=f.colorbar(im0, ax=axes[0,0],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im1, ax=axes[0,1],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im2, ax=axes[1,0],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im3, ax=axes[1,1],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im4, ax=axes[2,0],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im5, ax=axes[2,1],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  
  if filename:
    f.savefig(filename, bbox_inches='tight', pad_inches=0, 
              dpi=400, transparent=True)

  return(f,axes)

def plot_3x3_components(np_img, title, figsz=4, filename=None,vmin=-2,vmax=2):
  f, axes = plt.subplots(3, 2, figsize=(figsz*2,figsz*3))
  im0=axes[0,0].imshow(np_img[...,0,0],vmin=vmin,vmax=vmax)
  axes[0,0].set_title(f'{title} 0,0')
  im1=axes[0,1].imshow(np_img[...,0,1],vmin=vmin,vmax=vmax)
  axes[0,1].set_title(f'{title} 0,1')
  im2=axes[1,0].imshow(np_img[...,0,2],vmin=vmin,vmax=vmax)
  axes[1,0].set_title(f'{title} 0,2')
  im3=axes[1,1].imshow(np_img[...,1,1],vmin=vmin,vmax=vmax)
  axes[1,1].set_title(f'{title} 1,1')
  im4=axes[2,0].imshow(np_img[...,1,2],vmin=vmin,vmax=vmax)
  axes[2,0].set_title(f'{title} 1,2')
  im5=axes[2,1].imshow(np_img[...,2,2],vmin=vmin,vmax=vmax)
  axes[2,1].set_title(f'{title} 2,2')
  plt.tight_layout()

  cbar=f.colorbar(im0, ax=axes[0,0],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im1, ax=axes[0,1],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im2, ax=axes[1,0],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im3, ax=axes[1,1],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im4, ax=axes[2,0],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  cbar=f.colorbar(im5, ax=axes[2,1],anchor=(0, 0.3), shrink=0.9, format='%-.2g')
  cbar.ax.tick_params(labelsize=6)
  
  if filename:
    f.savefig(filename, bbox_inches='tight', pad_inches=0, 
              dpi=400, transparent=True)

  return(f,axes)


In [12]:
def tensor_cleaning(g, mask, iso_tens, det_threshold=1e-11):
# 1e-8 matches CPU version
#def tensor_cleaning(g, det_threshold=1e-8):
    #g[torch.det(g)<=det_threshold] = torch.eye((3))
    #g[mask==0] = torch.eye((3))
    g[torch.det(g)<=det_threshold] = iso_tens
    g[mask==0] = iso_tens
    # Sylvester's criterion https://en.wikipedia.org/wiki/Sylvester%27s_criterion
    psd_map = torch.where(g[...,0,0]>0, 1, 0) + torch.where(torch.det(g[...,:2,:2])>0, 1, 0) + torch.where(torch.det(g)>0, 1, 0)
    nonpsd_idx = torch.where(psd_map!=3)
    # nonpsd_idx = torch.where(torch.isnan(torch.sum(batch_cholesky(g), (3,4))))
    for i in range(len(nonpsd_idx[0])):
        #g[nonpsd_idx[0][i], nonpsd_idx[1][i], nonpsd_idx[2][i]] = torch.eye((3))
        g[nonpsd_idx[0][i], nonpsd_idx[1][i], nonpsd_idx[2][i]] = iso_tens
    return g


In [13]:
def inverse_masked(tens, mask):
  tens_reshape = tens.reshape((-1,3,3))
  mask_reshape = mask.reshape((-1))
  shape = tens.shape  
  inv_tens = torch.zeros_like(tens_reshape)
  Ind_mask = (mask_reshape > 0).nonzero().reshape(-1)  
  inv_tens[Ind_mask] = torch.inverse(tens_reshape[Ind_mask])
  return(inv_tens.reshape(*shape))  


# Setup for Reading Results

In [14]:
#sim_name = 'noshape'
sim_name = 'noshapeImg'
#sim_name = 'sim1'
#sim_name = 'sim1Img'
data_dir=f'/usr/sci/projects/abcd/simdata/3d_cubics/{sim_name}/'
atlasdir=f'/usr/sci/projects/abcd/simresults/3d_cubics/{sim_name}/'

ann_prefix=f'{data_dir}metpy_annulus_3D_'
ann_tens_file = ann_prefix + 'tens.nhdr'
ann_mask_file = ann_prefix + 'mask.nhdr'

num_cubics=10 
group_prefix=f'{data_dir}metpy_3D_'
cubic_prefix=f'{data_dir}metpy_3D_cubic'
group_log_files = [group_prefix + 'group_1_cubics.txt', group_prefix + 'group_2_cubics.txt']
group_novar_tens_files = [cubic_prefix + '1_novar_tens.nhdr', cubic_prefix + '2_novar_tens.nhdr']
group_novar_mask_files = [cubic_prefix + '1_novar_mask.nhdr', cubic_prefix + '2_novar_mask.nhdr']

group_1_tens_files = [cubic_prefix + f'1_{cc}_tens.nhdr' for cc in range(num_cubics)]
group_1_mask_files = [cubic_prefix + f'1_{cc}_mask.nhdr' for cc in range(num_cubics)]
group_1_eval_files = [cubic_prefix + f'1_{cc}_evals.nhdr' for cc in range(num_cubics)]
group_2_tens_files = [cubic_prefix + f'2_{cc}_tens.nhdr' for cc in range(num_cubics)]
group_2_mask_files = [cubic_prefix + f'2_{cc}_mask.nhdr' for cc in range(num_cubics)]
group_2_eval_files = [cubic_prefix + f'2_{cc}_evals.nhdr' for cc in range(num_cubics)]

cubic_res_prefix=f'{atlasdir}metpy_3D_cubic'
group_novar_scaled_tens_files = [cubic_res_prefix + '1_novar_scaled_orig_tensors_v2.nhdr', cubic_res_prefix + '2_novar_scaled_orig_tensors_v2.nhdr']
group_novar_img_files = [cubic_res_prefix + '1_novar_T1_flip_y.nhdr', cubic_res_prefix + '2_novar_T1_flip_y.nhdr']
group_novar_expct_Ebin_dist_files = [cubic_res_prefix + '1_novar_expected_ebindist_to_mean.nhdr', cubic_res_prefix + '2_novar_expected_ebindist_to_mean.nhdr']
group_novar_expct_logmap_scipy_files = [cubic_res_prefix + '1_novar_expected_ebinlogmap_scipy_at_mean.nhdr', cubic_res_prefix + '2_novar_expected_ebinlogmap_scipy_at_mean.nhdr']
group_novar_expct_inner_prod_scipy_files = [cubic_res_prefix + '1_novar_expected_ebininnerprod_scipy_at_mean.nhdr', cubic_res_prefix + '2_novar_expected_ebininnerprod_scipy_at_mean.nhdr']
group_novar_expct_logmap_files = [cubic_res_prefix + '1_novar_expected_ebinlogmap_at_mean.nhdr', cubic_res_prefix + '2_novar_expected_ebinlogmap_at_mean.nhdr']
group_novar_expct_inner_prod_files = [cubic_res_prefix + '1_novar_expected_ebininnerprod_at_mean.nhdr', cubic_res_prefix + '2_novar_expected_ebininnerprod_at_mean.nhdr']
group_novar_Ebin_dist_files = [cubic_res_prefix + '1_novar_final_ebindist_to_atlas.nhdr', cubic_res_prefix + '2_novar_expected_ebindist_to_mean.nhdr']
group_novar_logmap_scipy_files = [cubic_res_prefix + '1_novar_final_ebinlogmap_scipy_at_atlas.nhdr', cubic_res_prefix + '2_novar_expected_ebinlogmap_scipy_at_mean.nhdr']
group_novar_inner_prod_scipy_files = [cubic_res_prefix + '1_novar_final_ebininnerprod_scipy_at_atlas.nhdr', cubic_res_prefix + '2_novar_expected_ebininnerprod_scipy_at_mean.nhdr']
group_novar_logmap_files = [cubic_res_prefix + '1_novar_final_ebinlogmap_at_atlas.nhdr', cubic_res_prefix + '2_novar_expected_ebinlogmap_at_mean.nhdr']
group_novar_inner_prod_files = [cubic_res_prefix + '1_novar_final_ebininnerprod_at_atlas.nhdr', cubic_res_prefix + '2_novar_expected_ebininnerprod_at_mean.nhdr']

group_1_scaled_tens_files = [cubic_res_prefix + f'1_{cc}_scaled_orig_tensors_v2.nhdr' for cc in range(num_cubics)]
group_1_img_files = [cubic_res_prefix + f'1_{cc}_T1_flip_y.nhdr' for cc in range(num_cubics)]
group_1_alpha_files = [cubic_res_prefix + f'1_{cc}_alpha.nhdr' for cc in range(num_cubics)]
group_1_expct_Ebin_dist_files = [cubic_res_prefix + f'1_{cc}_expected_ebindist_to_mean.nhdr' for cc in range(num_cubics)]
group_1_expct_logmap_scipy_files = [cubic_res_prefix + f'1_{cc}_expected_ebinlogmap_scipy_at_mean.nhdr' for cc in range(num_cubics)]
group_1_expct_inner_prod_scipy_files = [cubic_res_prefix + f'1_{cc}_expected_ebininnerprod_scipy_at_mean.nhdr' for cc in range(num_cubics)]
group_1_expct_logmap_files = [cubic_res_prefix + f'1_{cc}_expected_ebinlogmap_at_mean.nhdr' for cc in range(num_cubics)]
group_1_expct_inner_prod_files = [cubic_res_prefix + f'1_{cc}_expected_ebininnerprod_at_mean.nhdr' for cc in range(num_cubics)]
group_1_Ebin_dist_files = [cubic_res_prefix + f'1_{cc}_final_ebindist_to_atlas.nhdr' for cc in range(num_cubics)]
group_1_logmap_scipy_files = [cubic_res_prefix + f'1_{cc}_final_ebinlogmap_scipy_at_atlas.nhdr' for cc in range(num_cubics)]
group_1_inner_prod_scipy_files = [cubic_res_prefix + f'1_{cc}_final_ebininnerprod_scipy_at_atlas.nhdr' for cc in range(num_cubics)]
group_1_logmap_files = [cubic_res_prefix + f'1_{cc}_final_ebinlogmap_at_atlas.nhdr' for cc in range(num_cubics)]
group_1_inner_prod_files = [cubic_res_prefix + f'1_{cc}_final_ebininnerprod_at_atlas.nhdr' for cc in range(num_cubics)]
group_2_scaled_tens_files = [cubic_res_prefix + f'2_{cc}_scaled_orig_tensors_v2.nhdr' for cc in range(num_cubics)]
group_2_img_files = [cubic_res_prefix + f'2_{cc}_T1_flip_y.nhdr' for cc in range(num_cubics)]
group_2_alpha_files = [cubic_res_prefix + f'2_{cc}_alpha.nhdr' for cc in range(num_cubics)]
group_2_expct_Ebin_dist_files = [cubic_res_prefix + f'2_{cc}_expected_ebindist_to_mean.nhdr' for cc in range(num_cubics)]
group_2_expct_logmap_scipy_files = [cubic_res_prefix + f'2_{cc}_expected_ebinlogmap_scipy_at_mean.nhdr' for cc in range(num_cubics)]
group_2_expct_inner_prod_scipy_files = [cubic_res_prefix + f'2_{cc}_expected_ebininnerprod_scipy_at_mean.nhdr' for cc in range(num_cubics)]
group_2_expct_logmap_files = [cubic_res_prefix + f'2_{cc}_expected_ebinlogmap_at_mean.nhdr' for cc in range(num_cubics)]
group_2_expct_inner_prod_files = [cubic_res_prefix + f'2_{cc}_expected_ebininnerprod_at_atlas.nhdr' for cc in range(num_cubics)]
group_2_Ebin_dist_files = [cubic_res_prefix + f'2_{cc}_final_ebindist_to_atlas.nhdr' for cc in range(num_cubics)]
group_2_logmap_scipy_files = [cubic_res_prefix + f'2_{cc}_final_ebinlogmap_scipy_at_atlas.nhdr' for cc in range(num_cubics)]
group_2_inner_prod_scipy_files = [cubic_res_prefix + f'2_{cc}_final_ebininnerprod_scipy_at_atlas.nhdr' for cc in range(num_cubics)]
group_2_logmap_files = [cubic_res_prefix + f'2_{cc}_final_ebinlogmap_at_atlas.nhdr' for cc in range(num_cubics)]
group_2_inner_prod_files = [cubic_res_prefix + f'2_{cc}_final_ebininnerprod_at_atlas.nhdr' for cc in range(num_cubics)]

expct_atlas = ReadTensors(f'{cubic_res_prefix}_1_2_mean_scaled_orig_tensors_v2.nhdr')
expct_atlas_mask = ReadScalars(f'{cubic_res_prefix}_1_2_mean_orig_mask.nhdr')

est_atlas = ReadTensors(f'{atlasdir}atlas_tens.nhdr')
est_atlas_mask = ReadScalars(f'{atlasdir}atlas_mask.nhdr')
atlas = ReadTensors(f'{atlasdir}final_atlas_tens.nhdr')
atlas_mask = ReadScalars(f'{atlasdir}final_atlas_mask.nhdr')



In [15]:
# Read in cubics, compute Karcher mean
dim = 3
iso_tens = torch.eye(dim,dtype=torch.double).to(device)

img_scale = 1
tens_scale = 1

tens_files = group_1_scaled_tens_files + group_2_scaled_tens_files
img_files = group_1_img_files + group_2_img_files
tensor_lin_list, tensor_met_list, mask_list = [], [], []
first_time = True
for s in range(len(tens_files)):
  tensor_np = ReadTensors(tens_files[s])
  img_np = ReadScalars(img_files[s]) / img_scale
  mask_np = np.zeros_like(img_np) # instead of reading from file
  mask_np[img_np > 0] = 1
    
  tensor_lin_list.append(torch.from_numpy(tens_scale * tensor_np).double().permute(3,2,1,0))
  mask_list.append(torch.from_numpy(mask_np).double().permute(2,1,0).to(device))

  if first_time:
    # since haven't permuted yet, shape is currently depth, width, height
    depth, width, height = tensor_np.shape[:dim]
    first_time = False
    mask_union = torch.zeros_like(mask_list[s])

  mask_union[mask_list[s] > 0] = 1
  
  tensor_met_zeros = torch.zeros(height,width,depth,dim,dim,dtype=torch.float64).to(device)
  tensor_met_zeros[:,:,:,0,0] = tensor_lin_list[s][0]
  tensor_met_zeros[:,:,:,0,1] = tensor_lin_list[s][1]
  tensor_met_zeros[:,:,:,0,2] = tensor_lin_list[s][2]
  tensor_met_zeros[:,:,:,1,0] = tensor_lin_list[s][1]
  tensor_met_zeros[:,:,:,1,1] = tensor_lin_list[s][3]
  tensor_met_zeros[:,:,:,1,2] = tensor_lin_list[s][4]
  tensor_met_zeros[:,:,:,2,0] = tensor_lin_list[s][2]
  tensor_met_zeros[:,:,:,2,1] = tensor_lin_list[s][4]
  tensor_met_zeros[:,:,:,2,2] = tensor_lin_list[s][5]  
    
  tensor_met_zeros = tensor_cleaning(tensor_met_zeros, mask_list[s], iso_tens)  
  tensor_met_list.append(torch.inverse(tensor_met_zeros))
# end for each tensor file  

G = torch.stack(tuple(tensor_met_list)).double()


# Setup for Writing Results

In [16]:
figbase = '/usr/sci/projects/abcd/Figures/simdata/'
figdir=f'{figbase}/{sim_name}'
#save_figs = True
save_figs = False

# Compare Campbell et al 2022 to Clarke 2013 to Gil-Medrano, Michor (GMM) 1991

## $$\operatorname{Exp}_{g_0}^{-1}(g_1)$$
Notation:

$ g_0, g_1 \in M $, $ h, u, b \in T_{g_0} M $


|| GMM 1991 | Clarke 2013 | Campbell 2022 | Old Code | New Code |
|:---:|:---:|:---:|:---:|:---:|:---:|
|1|$$H := (g_0^{-1}h)$$|$$b = g_0 \operatorname{logm} (g_0^{-1} g_1) == h == g_0H$$|$$k =  \operatorname{logm} (g_0^{-1} g_1) == g_0^{-1}h == H$$|$$K = \operatorname{logm} (g_0^{-1} g_1) == H$$|$$b = g_0\operatorname{logm} (g_0^{-1} g_1) == h$$|
|2|$$H_0 = H - \frac{\operatorname{tr}(H)}{n}\operatorname{Id}$$|$$b_T = b - \frac{\operatorname{tr}_{g_0}(b)}{n}g_0$$|$$k_0 = k - \frac{\operatorname{tr}(k)}{n}\operatorname{Id} == H_0$$|$$K_T=K-\frac{\operatorname{tr}(K)}{n} \operatorname{Id} = H_0$$|$$b_T=b-\frac{\operatorname{tr}(g_0^{-1}b)}{n} g_0$$|
|3|$$H_0 = g_0^{-1}h - \frac{\operatorname{tr}(g_0^{-1}h)}{n}\operatorname{Id}$$|$$b_T = b - \frac{\operatorname{tr}(g_0^{-1}b)}{n}g_0$$|$$k_0 = \operatorname{logm} (g_0^{-1} g_1) - \frac{\operatorname{tr}(\operatorname{logm} (g_0^{-1} g_1))}{n}\operatorname{Id}$$|$$K_T = \operatorname{logm}(g_0^{-1} g_1) - \frac{\operatorname{tr}(\operatorname{logm} (g_0^{-1} g_1))}{n}\operatorname{Id}$$|$$b_T = g_0\operatorname{logm}(g_0^{-1} g_1) - \frac{\operatorname{tr}(\operatorname{logm} (g_0^{-1} g_1))}{n}g_0$$|
|4|$$H_0 = g_0^{-1}h - \frac{\operatorname{tr}(g_0^{-1}h)}{n}\operatorname{Id}$$|$$b_T = g_0 \operatorname{logm} (g_0^{-1} g_1) - \frac{\operatorname{tr}(g_0^{-1}g_0 \operatorname{logm} (g_0^{-1} g_1))}{n}g_0$$|$$k_0 = \operatorname{logm} (g_0^{-1} g_1) - \frac{\operatorname{tr}(\operatorname{logm} (g_0^{-1} g_1))}{n}\operatorname{Id}$$|$$K_T = \operatorname{logm}(g_0^{-1} g_1) - \frac{\operatorname{tr}(\operatorname{logm} (g_0^{-1} g_1))}{n}\operatorname{Id}$$|$$b_T = g_0K - \frac{\operatorname{tr}(K)}{n}g_0$$|
|5||$$b_T = g_0H_0 $$|$$k_0 = g_0^{-1}b_T = H_0$$|$$K_T = g_0^{-1}b_T = H_0$$|$$b_T = g_0K_T = g_0H_0$$|
|6|$$\theta= \frac{\sqrt{n \operatorname{tr}(H_0^2)}}{4} $$|$$\theta= \frac{\sqrt{n \operatorname{tr}_{g_0}(b_T^2)}}{4}$$|$$\kappa =  \frac{\sqrt{n \operatorname{tr}(k_0^2)}}{4}$$|$$\theta =  \frac{\sqrt{n \operatorname{tr}((K_T)^2)}}{4}$$|$$\theta = \frac{\sqrt{n \operatorname{tr}((g_0^{-1}b_T)^2)}}{4}$$|
|7|$$\theta= \frac{\sqrt{n \operatorname{tr}(H_0^2)}}{4} $$|$$\theta= \frac{\sqrt{n \operatorname{tr}((g_0^{-1}b_T)^2)}}{4}$$|$$\kappa =  \frac{\sqrt{n \operatorname{tr}((g_0^{-1}b_T)^2)}}{4}$$|$$\theta =  \frac{\sqrt{n \operatorname{tr}((g_0^{-1}b_T)^2)}}{4} $$|$$\theta = \frac{\sqrt{n \operatorname{tr}(K_T^2)}}{4}$$|
|8|$$\gamma=e^{\frac{\operatorname{tr}(H)}{4}}$$|$$\gamma=e^{\frac{\operatorname{tr}_{g_0}(b)}{4}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=e^{\frac{\operatorname{tr}(g_0^{-1}b)}{4}}$$|
|9|$$\gamma=e^{\frac{\operatorname{tr}(g_0^{-1}h)}{4}}$$|$$\gamma=e^{\frac{\operatorname{tr}(g_0^{-1}h)}{4}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=e^{\frac{\operatorname{tr}(K)}{4}}$$|
|10|$$\begin{align}
\psi(H) = \begin{cases}
\frac{4}{n}\left(\gamma\cos\theta-1\right)\operatorname{Id}+\gamma\frac{\sin\theta}{\theta}H_0  & H_0 \neq 0,\\
\frac{4}{n}\left(\gamma-1\right)\operatorname{Id} & \text{o.w.} \\
\end{cases}
\end{align}$$|$$\begin{align}
		\psi(b) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}b_T  & b_T \neq 0,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & b_T = 0
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{Exp}_{g_0}^{-1}g_1\big|_x = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0k_0  & 0<\theta<\pi,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & \theta=0.
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0K_T  & \\
		\frac4n\left((g_0^{-1}g_1)^\frac{n}{4}-1\right)g_0 & 
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0^{-1}b_T  & \\
		\frac4n\left(\gamma-1\right)g_0 & 
		\end{cases}
	\end{align}$$|
|11|$$\begin{align}
		\psi(H) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)\operatorname{Id}+\gamma\frac{\sin\theta}{\theta}H_0  & H_0 \neq 0,\\[.5em]
		\frac4n\left(\gamma-1\right)Id & o.w.
		\end{cases}
	\end{align}$$|$$\begin{align}
		\psi(b) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0H_0  & b_T \neq 0,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & b_T = 0
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{Exp}_{g_0}^{-1}g_1\big|_x = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0H_0  & 0<\theta<\pi,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & \theta=0.
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0H_0  & \\
		\frac4n\left((g_0^{-1}g_1)^\frac{n}{4}-1\right)g_0 & 
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}H_0  & \\
		\frac4n\left(\gamma-1\right)g_0 & 
		\end{cases}
	\end{align}$$|    




## $$\operatorname{Exp}(g_0, u)$$

Case: $g(t,x)$ does NOT pass through $[0]$

TODO: Case where $g(t,x)$ does pass through $[0]$

Notation:

$ g_0, g_1 \in M $, $ h, u, b \in T_{g_0} M $


|| GMM 1991 | Clarke 2013 | Campbell 2022 | Old Code | New Code |
|:---:|:---:|:---:|:---:|:---:|:---:|
|1|$$H := (g_0^{-1}h)$$|$$b = h == g_0H$$|$$k = h$$|$$u = h, U = g_0^{-1}u = H$$|$$u = h, U = g_0^{-1}u = H$$|
|2|$$H_0 = H - \frac{\operatorname{tr}(H)}{n}\operatorname{Id}$$|$$b_T = b - \frac{\operatorname{tr}_{g_0}(b)}{n}g_0$$|$$k_0 = k - \frac{\operatorname{tr}(k)}{n}\operatorname{Id}$$|$$U_T=U-\frac{\operatorname{tr}(U)}{n} \operatorname{Id} = H_0$$|$$U_T=u-\frac{\operatorname{tr}(U)}{n} g_0$$|
|3|$$H_0 = g_0^{-1}h - \frac{\operatorname{tr}(g_0^{-1}h)}{n}\operatorname{Id}$$|$$b_T = h - \frac{\operatorname{tr}(g_0^{-1}h)}{n}g_0$$|$$k_0 = h - \frac{\operatorname{tr}(h)}{n}\operatorname{Id}$$|$$U_T = g_0^{-1} u - \frac{\operatorname{tr}(g_0^{-1} u)}{n}\operatorname{Id}$$|$$U_T = u - \frac{\operatorname{tr}(g_0^{-1} u)}{n}g_0$$|
|4||$$b_T = g_0H_0 $$|$$k_0 = g_0^{-1}b_T = H_0$$|$$U_T = g_0^{-1}b_T = H_0$$|$$U_T = b_T = g_0H_0$$|
|5|$$q= 1 + \frac{t}{4}\operatorname{tr}(H)$$|$$q= 1 + \frac{t}{4}\operatorname{tr}_{g_0}(b)$$|$$q = 1+t( \frac{\sqrt[4]{\operatorname{det}(g_1)}\cos(\kappa)-\sqrt[4]{\operatorname{det}(g_0)}}{\sqrt[4]{\operatorname{det}(g_0)}})$$|$$q = \frac{\operatorname{tr}(U)}{4}+1$$|$$q = \frac{t\operatorname{tr}(U)}{4}+1$$|
|6|$$q= 1 + \frac{t}{4}\operatorname{tr}(g_0^{-1}h)$$|$$q= 1 + \frac{t}{4}\operatorname{tr}(g_0^{-1}h)$$|$$q = 1+t( \frac{\sqrt[4]{\operatorname{det}(g_1)}\cos(\kappa)-\sqrt[4]{\operatorname{det}(g_0)}}{\sqrt[4]{\operatorname{det}(g_0)}})$$|$$q = \frac{\operatorname{tr}(g_0^{-1}u)}{4}+1$$|$$q = \frac{t\operatorname{tr}(g_0^{-1}u)}{4}+1$$|
|7|$$r= \frac{t}{4}\sqrt{n \operatorname{tr}(H_0^2)}$$|$$r= \frac{t}{4}\sqrt{n \operatorname{tr}_{g_0}(b_T^2)}$$|$$r = \frac{t\sqrt[4]{\operatorname{det}(g_1)}\sin \kappa}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$r= \frac{\sqrt{n \operatorname{tr}(U_T^2)}}{4}$$|$$r= \frac{t\sqrt{n \operatorname{tr}((g_0^{-1}U_T)^2)}}{4}$$|
|8|$$r= \frac{t}{4}\sqrt{n \operatorname{tr}(H_0^2)}$$|$$r= \frac{t}{4}\sqrt{n \operatorname{tr}((g_0^{-1}b_T)^2)}$$|$$r = \frac{t\sqrt[4]{\operatorname{det}(g_1)}\sin \kappa}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$r= \frac{\sqrt{n \operatorname{tr}(H_0^2)}}{4}$$|$$r= \frac{t\sqrt{n \operatorname{tr}((g_0^{-1}b_T)^2)}}{4}$$|
|9|$$r= \frac{t}{4}\sqrt{n \operatorname{tr}(H_0^2)}$$|$$r= \frac{t}{4}\sqrt{n \operatorname{tr}(H_0^2)}$$|$$r = \frac{t\sqrt[4]{\operatorname{det}(g_1)}\sin \kappa}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$r= \frac{\sqrt{n \operatorname{tr}(H_0^2)}}{4}$$|$$r= \frac{t\sqrt{n \operatorname{tr}(H_0^2)}}{4}$$|
|10|$$\begin{align}
g(t)= \begin{cases}
       g_0(q^2+r^2)^{2/n}e^{\frac{t\operatorname{atan}(r/q)}{r}H_0} & \operatorname{tr}(H_0^2) \neq 0 \\
       g_0q^{4/n}e^{\frac{tH_0}{q}} & \operatorname{tr}(H_0^2) = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{t\operatorname{atan}(r/q)}{r}b_T} & b_T \neq 0 \\
       q^{4/n}g_0 & \operatorname{tr}(H_0^2) = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{\operatorname{atan}(r/q)}{\kappa}k_0} & \kappa \neq 0 \\
       q^{4/n}g_0 & \kappa = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{\operatorname{atan}(r/q)}{r}U_T} & U_T \neq 0 \\
       q^{4/n}g_0 & \operatorname{tr}(U_T^2) = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{t\operatorname{atan}(r/q)}{r}g_0^{-1}U_T} & U_T \neq 0 \\
       q^{4/n}g_0 & \operatorname{tr}(U_T^2) = 0 \\
       \end{cases}
       \end{align}$$|
|11|$$\begin{align}
g(t)= \begin{cases}
       g_0(q^2+r^2)^{2/n}e^{\frac{t\operatorname{atan}(r/q)}{r}H_0} & \operatorname{tr}(H_0^2) \neq 0 \\
       g_0q^{4/n}e^{\frac{tH_0}{q}} & \operatorname{tr}(H_0^2) = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{t\operatorname{atan}(r/q)}{r}g_0H_0} & b_T \neq 0 \\
       q^{4/n}g_0 & \operatorname{tr}(H_0^2) = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{\operatorname{atan}(r/q)}{\kappa}H_0} & \kappa \neq 0 \\
       q^{4/n}g_0 & \kappa = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{\operatorname{atan}(r/q)}{r}H_0} & U_T \neq 0 \\
       q^{4/n}g_0 & \operatorname{tr}(U_T^2) = 0 \\
       \end{cases}
       \end{align}$$|$$\begin{align}
g(t)= \begin{cases}
       (q^2+r^2)^{2/n}g_0e^{\frac{t\operatorname{atan}(r/q)}{r}H_0} & U_T \neq 0 \\
       q^{4/n}g_0 & \operatorname{tr}(U_T^2) = 0 \\
       \end{cases}
       \end{align}$$|


## $$\operatorname{Dist}(g_0, g_1)$$
Notation:

$ g_0, g_1 \in M $, $ h, u, b \in T_{g_0} M $


|| Clarke 2013 | Campbell 2022 | Old Code | New Code |
|:---:|:---:|:---:|:---:|:---:|
|1|$$\begin{align}
d_x^2(g_0, g_1) = \begin{cases}
\psi(b)_{g_0}^2 & x \in N \cap P, b_T \neq 0 \\
\frac{16}{n} \left|\sqrt[4]{\operatorname{det}(g_1)} - \sqrt[4]{\operatorname{det}(g_0)}\right|^2 & x \in N \cap P, b_T = 0 \\ 
\frac{16}{n} \left(\sqrt[4]{\operatorname{det}(g_0)} + \sqrt[4]{\operatorname{det}(g_1)}\right)^2 & x \not\in N \cap P\\ \end{cases}
\end{align}$$||||
|2|$$\begin{align}
d_x^2(g_0, g_1) = \begin{cases}
\frac{16}{n}\left(\sqrt{\operatorname{det}(g_0)} - 2\sqrt[4]{\operatorname{det}(g_0)}\sqrt[4]{\operatorname{det}(g_1)}\cos\theta + \sqrt{\operatorname{det}(g_1)}\right) & x \in N \cap P, b_T \neq 0 \\
\frac{16}{n} \left|\sqrt[4]{\operatorname{det}(g_1)} - \sqrt[4]{\operatorname{det}(g_0)}\right|^2 & x \in N \cap P, b_T = 0 \\ 
\frac{16}{n} \left(\sqrt[4]{\operatorname{det}(g_0)} + \sqrt[4]{\operatorname{det}(g_1)}\right)^2 & x \not\in N \cap P\\ 
\end{cases}
\end{align}$$||||

$$
$$


|| GMM 1991 | Clarke 2013 | Campbell 2022 | Old Code | New Code |
|:---:|:---:|:---:|:---:|:---:|:---:|
|8|$$\gamma=e^{\frac{\operatorname{tr}(H)}{4}}$$|$$\gamma=e^{\frac{\operatorname{tr}_{g_0}(b)}{4}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=e^{\frac{\operatorname{tr}(g_0^{-1}b)}{4}}$$|
|9|$$\gamma=e^{\frac{\operatorname{tr}(g_0^{-1}h)}{4}}$$|$$\gamma=e^{\frac{\operatorname{tr}(g_0^{-1}h)}{4}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=\frac{\sqrt[4]{\operatorname{det}(g_1)}}{\sqrt[4]{\operatorname{det}(g_0)}}$$|$$\gamma=e^{\frac{\operatorname{tr}(K)}{4}}$$|
|10|$$\begin{align}
\psi(H) = \begin{cases}
\frac{4}{n}\left(\gamma\cos\theta-1\right)\operatorname{Id}+\gamma\frac{\sin\theta}{\theta}H_0  & H_0 \neq 0,\\
\frac{4}{n}\left(\gamma-1\right)\operatorname{Id} & \text{o.w.} \\
\end{cases}
\end{align}$$|$$\begin{align}
		\psi(b) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}b_T  & b_T \neq 0,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & b_T = 0
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{Exp}_{g_0}^{-1}g_1\big|_x = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0k_0  & 0<\theta<\pi,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & \theta=0.
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0K_T  & \\
		\frac4n\left((g_0^{-1}g_1)^\frac{3}{4}-1\right)g_0 & 
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0^{-1}b_T  & \\
		\frac4n\left((\gamma-1\right)g_0 & 
		\end{cases}
	\end{align}$$|
|11|$$\begin{align}
		\psi(H) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)\operatorname{Id}+\gamma\frac{\sin\theta}{\theta}H_0  & H_0 \neq 0,\\[.5em]
		\frac4n\left(\gamma-1\right)Id & o.w.
		\end{cases}
	\end{align}$$|$$\begin{align}
		\psi(b) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0H_0  & b_T \neq 0,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & b_T = 0
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{Exp}_{g_0}^{-1}g_1\big|_x = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0H_0  & 0<\theta<\pi,\\[.5em]
		\frac4n\left(\gamma-1\right)g_0 & \theta=0.
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}g_0H_0  & \\
		\frac4n\left((g_0^{-1}g_1)^\frac{3}{4}-1\right)g_0 & 
		\end{cases}
	\end{align}$$|$$\begin{align}
		\operatorname{invRieExp}(g_0,g_1) = \begin{cases}
		\frac4n\left(\gamma\cos\theta-1\right)g_0+\gamma\frac{\sin\theta}{\theta}K_T  & \\
		\frac4n\left((\gamma-1\right)g_0 & 
		\end{cases}
	\end{align}$$| 

# Run Some Verifying Tests

In [17]:
def orig_se_logm_invB_A(B, A):
#     inputs: A/B.shape = [h, w, d, 3, 3]
#     output: shape = [h, w, d, 3, 3]
    G = torch.linalg.cholesky(B)
    inv_G = torch.inverse(G)
    W = torch.einsum("...ij,...jk,...lk->...il", inv_G, A, inv_G)
    #lamda, Q = torch.symeig(W, eigenvectors=True)
    lamda, Q = se.apply(W)

    log_lamda = torch.zeros((*lamda.shape, lamda.shape[-1]),dtype=torch.double)
    ## for i in range(lamda.shape[-1]):
    ##     log_lamda[:, i, i] = torch.log(lamda[:, i])
    #lamda, Q = torch.linalg.eig(W)#, eigenvectors=True
    #lamda, Q = lamda.real, Q.real
    log_lamda = torch.diag_embed(torch.log(lamda))
    V = torch.einsum('...ji,...jk->...ik', inv_G, Q)
    inv_V = torch.inverse(V)
    return torch.einsum('...ij,...jk,...kl->...il', V, log_lamda, inv_V)


def old_inv_RieExp(g0, g1, a):  # g0,g1: two tensors of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the inverse Riemannian exponential of g1 in the image of the maximal domain of the Riemannian exponential at g0
    '''
    n = g1.size(-1)
    #     matrix multiplication
    inv_g0_g1 = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), g1)  # (s,t,...,3,3)

    def get_u_g0direction(g0, inv_g0_g1):  # (-1,3,3) first reshape g0,g1,inv_g..
        #         permute
        inv_g0_g1 = torch.einsum("...ij->ij...", inv_g0_g1)  # (3,3,-1)
        s = inv_g0_g1[0, 0]  # (-1)
        u = 4 / n * (s ** (n / 4) - 1) * torch.einsum("...ij->ij...", g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        return u.permute(2, 0, 1)  # (-1,3,3)

    def get_u_ng0direction(g0, g1, inv_g0_g1, a):  # (-1,3,3) first reshape g0,g1,inv_g..
        print('u_ng0direction')
        K = orig_se_logm_invB_A(g0, g1)
        KTrless = K - torch.einsum("...ii,kl->...kl", K, torch.eye(n, dtype=torch.double)) / n  # (-1,3,3)
        #         AA^T
        theta = (1 / a * torch.einsum("...ik,...ki->...", KTrless, KTrless)).sqrt() / 4  # (-1)
        gamma = torch.det(g1).pow(1 / 4) / (torch.det(g0).pow(1 / 4))  # (-1)

        A = 4 / n * (gamma * torch.cos(theta) - 1)  # (-1)
        B = 1 / theta * gamma * torch.sin(theta)
        u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", g0, KTrless)  # (-1)@(3,3,-1) -> (3,3,-1)
        #print('FIXED')
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ij->ij...", KTrless)  # (-1)@(3,3,-1) -> (3,3,-1)
        return u.permute(2, 0, 1)  # (-1,3,3)

    inv_g0_g1_trless = inv_g0_g1 - torch.einsum("...ii,kl->...kl", inv_g0_g1, torch.eye(n, dtype=torch.double)) / n  # (s,t,...,2,2)
    norm0 = torch.einsum("...ij,...ij->...", inv_g0_g1_trless, inv_g0_g1_trless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)  # using squeeze results in [1,1]->[]
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    u = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double)  # (-1,3,3)
    if len(Indn0) == 0:
        u = get_u_g0direction(g0.reshape(-1, n, n), inv_g0_g1.reshape(-1, n, n))
    elif len(Ind0) == 0:
        u = get_u_ng0direction(g0.reshape(-1, n, n), g1.reshape(-1, n, n), inv_g0_g1.reshape(-1, n, n), a)
    else:
        u[Ind0] = get_u_g0direction(g0.reshape(-1, n, n)[Ind0], inv_g0_g1.reshape(-1, n, n)[Ind0])
        u[Indn0] = get_u_ng0direction(g0.reshape(-1, n, n)[Indn0], g1.reshape(-1, n, n)[Indn0], inv_g0_g1.reshape(-1, n, n)[Indn0], a)

    return u.reshape(g1.size())

def old_inv_RieExp_scipy(g0, g1, a):  # g0,g1: two tensors of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the inverse Riemannian exponential of g1 in the image of the maximal domain of the Riemannian exponential at g0
    '''
    n = g1.size(-1)
    #     matrix multiplication
    inv_g0_g1 = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), g1)  # (s,t,...,3,3)

    def get_u_g0direction(g0, inv_g0_g1):  # (-1,3,3) first reshape g0,g1,inv_g..
        #         permute
        inv_g0_g1 = torch.einsum("...ij->ij...", inv_g0_g1)  # (3,3,-1)
        s = inv_g0_g1[0, 0]  # (-1)
        u = 4 / n * (s ** (n / 4) - 1) * torch.einsum("...ij->ij...", g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        return u.permute(2, 0, 1)  # (-1,3,3)

    def get_u_ng0direction(g0, g1, inv_g0_g1, a):  # (-1,3,3) first reshape g0,g1,inv_g..
        print('u_ng0direction')
        K = scipy_logm_invB_A(g0, g1)
        KTrless = K - torch.einsum("...ii,kl->...kl", K, torch.eye(n, dtype=torch.double)) / n  # (-1,3,3)
        #         AA^T
        theta = (1 / a * torch.einsum("...ik,...ki->...", KTrless, KTrless)).sqrt() / 4  # (-1)
        gamma = torch.det(g1).pow(1 / 4) / (torch.det(g0).pow(1 / 4))  # (-1)

        A = 4 / n * (gamma * torch.cos(theta) - 1)  # (-1)
        B = 1 / theta * gamma * torch.sin(theta)
        u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", g0, KTrless)  # (-1)@(3,3,-1) -> (3,3,-1)
        #print('FIXED')
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ij->ij...", KTrless)  # (-1)@(3,3,-1) -> (3,3,-1)
        return u.permute(2, 0, 1)  # (-1,3,3)

    inv_g0_g1_trless = inv_g0_g1 - torch.einsum("...ii,kl->...kl", inv_g0_g1, torch.eye(n, dtype=torch.double)) / n  # (s,t,...,2,2)
    norm0 = torch.einsum("...ij,...ij->...", inv_g0_g1_trless, inv_g0_g1_trless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)  # using squeeze results in [1,1]->[]
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    u = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double)  # (-1,3,3)
    if len(Indn0) == 0:
        u = get_u_g0direction(g0.reshape(-1, n, n), inv_g0_g1.reshape(-1, n, n))
    elif len(Ind0) == 0:
        u = get_u_ng0direction(g0.reshape(-1, n, n), g1.reshape(-1, n, n), inv_g0_g1.reshape(-1, n, n), a)
    else:
        u[Ind0] = get_u_g0direction(g0.reshape(-1, n, n)[Ind0], inv_g0_g1.reshape(-1, n, n)[Ind0])
        u[Indn0] = get_u_ng0direction(g0.reshape(-1, n, n)[Indn0], g1.reshape(-1, n, n)[Indn0], inv_g0_g1.reshape(-1, n, n)[Indn0], a)

    return u.reshape(g1.size())

def old_Rie_Exp(g0, u, a):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the Riemannian exponential of u in the the maximal domain of the Riemannian exponential at g0
    '''
    n = g0.size(-1)

    U = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), u)  # (s,t,...,3,3)
    trU = torch.einsum("...ii->...", U)  # (s,t,...)
    UTrless = U - torch.einsum("...,ij->...ij", trU, torch.eye(n, n, dtype=torch.double)) / n  # (s,t,...,3,3)

    #     in g0 direction:K_0=0
    def get_g1_g0direction(g0, trU):  # first reshape g0 (-1,3,3) and trU (-1)
        g1 = (trU / 4 + 1).pow(4 / n) * torch.einsum("...ij->ij...", g0)  # (3,3,-1)
        return g1.permute(2, 0, 1)  # (-1,3,3)

    #     not in g0 direction SplitEbinMetric.pdf Theorem 1 :K_0\not=0
    def get_g1_ng0direction(g0, trU, UTrless, a):  # first reshape g0,UTrless (-1,3,3) and trU (-1)
        print('g1_ng0direction')
        if len((trU < -4).nonzero().reshape(-1)) != 0:
            warnings.warn('The tangent vector u is out of the maximal domain of the Riemannian exponential.', DeprecationWarning)

        q = trU / 4 + 1  # (-1)
        r = (1 / a * torch.einsum("...ik,...ki->...", UTrless, UTrless)).sqrt() / 4  # (-1)

        ArctanUtrless = torch.atan2(r, q) * torch.einsum("...ij->ij...", UTrless) / r  # use (2,2,-1) for computation
        ExpArctanUtrless = torch.matrix_exp(ArctanUtrless.permute(2, 0, 1)).permute(1, 2, 0)

        g1 = (q ** 2 + r ** 2).pow(2 / n) * torch.einsum("...ik,kj...->ij...", g0, ExpArctanUtrless)  # (2,2,-1)
        return g1.permute(2, 0, 1)  # (-1,2,2)

    #     pointwise multiplication Tr(U^TU)
    norm0 = torch.einsum("...ij,...ij->...", UTrless, UTrless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    #     k_0=0 or \not=0
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    g1 = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double)  # (-1,2,2)
    if len(Indn0) == 0:
        g1 = get_g1_g0direction(g0.reshape(-1, n, n), trU.reshape(-1))
    elif len(Ind0) == 0:
        g1 = get_g1_ng0direction(g0.reshape(-1, n, n), trU.reshape(-1), UTrless.reshape(-1, n, n), a)
    else:
        g1[Ind0] = get_g1_g0direction(g0.reshape(-1, n, n)[Ind0], trU.reshape(-1)[Ind0])
        g1[Indn0] = get_g1_ng0direction(g0.reshape(-1, n, n)[Indn0], trU.reshape(-1)[Indn0], UTrless.reshape(-1, n, n)[Indn0], a)

    return g1.reshape(g0.size())

''' 
The following Riemannian exponential and inverse Riemannian exponential are extended to the case g0=0 
'''
def old_Rie_Exp_extended(g0, u, a):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3)
    size = g0.size()
    g0, u = g0.reshape(-1, *size[-2:]), u.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        g1 = u * g0.size(-1) / 4
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        print('yep, all PD')
        g1 = old_Rie_Exp(g0, u, a)
    else:
        g1 = torch.zeros(g0.size(), dtype=torch.double)
        g1[Ind_g0_is0] = u[Ind_g0_is0] * g0.size(-1) / 4
        g1[Ind_g0_isnot0] = old_Rie_Exp(g0[Ind_g0_isnot0], u[Ind_g0_isnot0], a)
    return g1.reshape(size)


def old_inv_RieExp_extended(g0, g1, a):  # g0, g1: (s,t,...,3,3)
    size = g0.size()
    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        u = g1 * 4 / g0.size(-1)
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        print('yep, all PD')
        u = old_inv_RieExp(g0, g1, a)
    else:
        u = torch.zeros(g0.size(), dtype=torch.double)
        u[Ind_g0_is0] = g1[Ind_g0_is0] * 4 / g0.size(-1)
        u[Ind_g0_isnot0] = old_inv_RieExp(g0[Ind_g0_isnot0], g1[Ind_g0_isnot0], a)
    return u.reshape(size)

def old_inv_RieExp_extended_scipy(g0, g1, a):  # g0, g1: (s,t,...,3,3)
    size = g0.size()
    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        u = g1 * 4 / g0.size(-1)
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        print('yep, all PD')
        u = old_inv_RieExp_scipy(g0, g1, a)
    else:
        u = torch.zeros(g0.size(), dtype=torch.double)
        u[Ind_g0_is0] = g1[Ind_g0_is0] * 4 / g0.size(-1)
        u[Ind_g0_isnot0] = old_inv_RieExp_scipy(g0[Ind_g0_isnot0], g1[Ind_g0_isnot0], a)
    return u.reshape(size)

def old_logm_invB_A(B, A):
#     inputs: A/B.shape = [hxwxd, 3, 3]
#     output: shape = [hxwxd, 3, 3]
    G = batch_cholesky(B)
    nonpsd_idx = torch.where(torch.isnan(G))
    if len(nonpsd_idx[0]) > 0:
      print(len(nonpsd_idx[0]), 'non psd entries found in logm_invB_A', nonpsd_idx)
    for i in range(len(nonpsd_idx[0])):
      G[nonpsd_idx[0][i]] = torch.eye((3)).double().to(device=B.device)

    # KMC The following clamp reduces crashes, but adds many bad artifacts  
    #inv_G = torch.inverse(G.clamp(min=1.0e-10))
    #inv_G = torch.inverse(G)

    # KMC comment out following 4 lines and see if pseudo inverse sufficient instead
    det_G = torch.det(G)
    inv_G = torch.zeros_like(G)
    inv_G[det_G>0.] = torch.pinverse(G[det_G>0.])
    inv_G[det_G<=0.] = torch.eye((3)).double().to(device=B.device)
    #inv_G = torch.pinverse(G)    
    W = torch.einsum("...ij,...jk,...lk->...il", inv_G, A, inv_G)

    #W_sym = (W + torch.transpose(W,len(W.shape)-2,len(W.shape)-1))/2
    # The eigenvector computation becomes inaccurate for matrices close to identity
    # Set those closer than float machine eps from identity matrix to identity
    #W[(torch.abs(W[:]-torch.eye((3)))<1.1921e-7).sum(dim=(1,2))<9] = torch.eye((3))
    lamda, Q = se.apply(W)

    # Get consistent eigenvector sign following approach of:
    # https://www.osti.gov/servlets/purl/920802
    S = torch.ones_like(lamda)
    lQ0outer = torch.einsum('...i,...i,...j->...ij',lamda[...,0].reshape((-1,1)),Q[...,0],Q[...,0])
    lQ1outer = torch.einsum('...i,...i,...j->...ij',lamda[...,1].reshape((-1,1)),Q[...,1],Q[...,1])
    lQ2outer = torch.einsum('...i,...i,...j->...ij',lamda[...,2].reshape((-1,1)),Q[...,2],Q[...,2])
    Y0 = W - lQ1outer - lQ2outer
    Y1 = W - lQ0outer - lQ2outer
    Y2 = W - lQ0outer - lQ1outer
    q0y0 = torch.einsum('...i,...i->...',Q[...,0],Y0[...,0])
    q0y1 = torch.einsum('...i,...i->...',Q[...,0],Y0[...,1])
    q0y2 = torch.einsum('...i,...i->...',Q[...,0],Y0[...,2])
    S[...,0] = torch.sign(q0y0) * q0y0 * q0y0 + torch.sign(q0y1) * q0y1 * q0y1 + torch.sign(q0y2) * q0y2 * q0y2
    q1y0 = torch.einsum('...i,...i->...',Q[...,1],Y1[...,0])
    q1y1 = torch.einsum('...i,...i->...',Q[...,1],Y1[...,1])
    q1y2 = torch.einsum('...i,...i->...',Q[...,1],Y1[...,2])
    S[...,1] = torch.sign(q1y0) * q1y0 * q1y0 + torch.sign(q1y1) * q1y1 * q1y1 + torch.sign(q1y2) * q1y2 * q1y2
    q2y0 = torch.einsum('...i,...i->...',Q[...,2],Y2[...,0])
    q2y1 = torch.einsum('...i,...i->...',Q[...,2],Y2[...,1])
    q2y2 = torch.einsum('...i,...i->...',Q[...,2],Y2[...,2])
    S[...,2] = torch.sign(q2y0) * q2y0 * q2y0 + torch.sign(q2y1) * q2y1 * q2y1 + torch.sign(q2y2) * q2y2 * q2y2
    
    #lamda, Q = se.apply(W_sym)
    #log_lamda = torch.zeros((*lamda.shape, lamda.shape[-1]),dtype=torch.double)
    #log_lamda = torch.diag_embed(torch.log(lamda))
    #log_lamda = torch.diag_embed(torch.log(torch.where(lamda>1.0e-20,lamda,1.0e-20)))
    log_lamda = torch.diag_embed(torch.log(lamda.clamp(min=1.0e-15)))
    #V = torch.einsum('...ji,...jk->...ik', inv_G, Q)
    # include S here to use best signs for Q
    V = torch.einsum('...ji,...jk,...k->...ik', inv_G, Q, torch.sign(S))

    inv_V = torch.inverse(V)
    # KMC comment out following 4 lines and see if pseudo inverse sufficient instead
    #det_V = torch.det(V)
    #inv_V = torch.zeros_like(V)
    #inv_V[det_V>0.] = torch.pinverse(V[det_V>0.])
    #inv_V[det_V>1.e-8] = torch.pinverse(V[det_V>1.e-8])
    #inv_V[det_V<=0.] = torch.eye((3)).double().to(device=B.device)

    result = torch.einsum('...ij,...jk,...kl->...il', V, log_lamda, inv_V)
    ill_cond_idx = (inv_V > 1e20).nonzero().reshape(-1)
    num_ill_cond = len(ill_cond_idx)
    if num_ill_cond > 0:
      dbg_ill = ill_cond_idx[0]
      print('Replacing', num_ill_cond, 'ill-conditioned results in logm_invB_A with identity. First index is', dbg_ill)
      result[ill_cond_idx] = torch.eye((3)).double().to(device=B.device)
    
    return result



In [18]:
def fix_logm_invB_A(B, A):
#     inputs: A/B.shape = [h, w, d, 3, 3]
#     output: shape = [h, w, d, 3, 3]
    #G = torch.linalg.cholesky(B)
    G = batch_cholesky(B)
    inv_G = torch.pinverse(G)
    nonpsd_idx = torch.where(torch.isnan(G))
    if len(nonpsd_idx[0]) > 0:
      print(len(nonpsd_idx[0]), 'non psd entries found in logm_invB_A', nonpsd_idx)
    for i in range(len(nonpsd_idx[0])):
      G[nonpsd_idx[0][i]] = torch.eye((3)).double().to(device=B.device)

    # KMC The following clamp reduces crashes, but adds many bad artifacts  
    #inv_G = torch.inverse(G.clamp(min=1.0e-10))
    #inv_G = torch.inverse(G)

    # KMC comment out following 4 lines and see if pseudo inverse sufficient instead
    #det_G = torch.det(G)
    #print('det_G:',det_G)
    #inv_G = torch.zeros_like(G)
    #inv_G[det_G>0.] = torch.pinverse(G[det_G>0.])
    #inv_G[det_G<=0.] = torch.eye((3)).double().to(device=B.device)
    W = torch.einsum("...ij,...jk,...lk->...il", inv_G, A, inv_G)
    #print('W shape:',W.shape)
    print('torch eig')
    lamda, Q = torch.symeig(W, eigenvectors=True)
    lamda_good = lamda
    Q_good = Q
    print('se eig')
    lamda, Q = se.apply(W)
    lamda_se = lamda
    Q_se = Q
    
    print('lamda good\n:',lamda_good)
    print('lamda se\n:',lamda_se)
    print('Q good\n:',Q_good)
    print('Q se\n:',Q_se)
    print('Using Q, lamda se')
    Q = Q_se
    lamda = lamda_se
    #print('Using Q, lamda good')
    #Q = Q_good
    #lamda = lamda_good
    print('lamda good - lamda se:\n',lamda_good-lamda_se)
    print('Q good - Q se:\n',Q_good-Q_se)
    print('W:\n',W)
    print('Q lamda Q_T good:\n',
          torch.einsum("...ij,...jk,...lk->...il", Q_good, torch.diag_embed(lamda_good), 
                       Q_good))
    print('Q lamda Q_T se:\n',
          torch.einsum("...ij,...jk,...lk->...il", Q_se, torch.diag_embed(lamda_se), 
                       Q_se))
    
    log_lamda = torch.zeros((*lamda.shape, lamda.shape[-1]),dtype=torch.double)
    ## for i in range(lamda.shape[-1]):
    ##     log_lamda[:, i, i] = torch.log(lamda[:, i])
    #lamda, Q = torch.linalg.eig(W)#, eigenvectors=True
    #lamda, Q = lamda.real, Q.real
    log_lamda = torch.diag_embed(torch.log(lamda))
    log_lamda_good = torch.diag_embed(torch.log(lamda_good))
    V = torch.einsum('...ji,...jk->...ik', inv_G, Q)
    V_good = torch.einsum('...ji,...jk->...ik', inv_G, Q_good)
    print('V shape',V.shape)
    print('V:\n',V)
    print('inv_G:\n',inv_G)
    print('Q:\n',Q)
    #inv_V = torch.pinverse(V)
    inv_V_good = torch.pinverse(V_good)
    det_V = torch.abs(torch.det(V))
    inv_V = torch.zeros_like(V)
    #inv_V[det_V>0.] = torch.pinverse(V[det_V>0.])
    inv_V[det_V>1.e-8] = torch.pinverse(V[det_V>1.e-8])
    inv_V[det_V<=1.e-8] = torch.eye((3)).double().to(device=B.device)
   #det_V = torch.det(V)
    #inv_V = torch.zeros_like(V)
    #print('det_V:',det_V)
    #inv_V[det_V>0.] = torch.pinverse(V[det_V>0.])
    #inv_V[det_V>1.e-8] = torch.pinverse(V[det_V>1.e-8])
    #inv_V[det_V<=0.] = torch.eye((3)).double().to(device=B.device)
    #print('inv_V bad:',inv_V)
    #print('inv_V good:',torch.inverse(V))
    ans_good = torch.einsum('...ij,...jk,...kl->...il', V_good, log_lamda_good, inv_V_good)
    ans_se = torch.einsum('...ij,...jk,...kl->...il', V, log_lamda, inv_V)
    print('logm_invB_A good:\n',ans_good)
    print('logm_invB_A se:\n',ans_se)
    print('diff logm_invB_A:\n',ans_good-ans_se)
    return torch.einsum('...ij,...jk,...kl->...il', V, log_lamda, inv_V)


def orig_logm_invB_A(B, A):
#     inputs: A/B.shape = [h, w, d, 3, 3]
#     output: shape = [h, w, d, 3, 3]
    G = torch.linalg.cholesky(B)
    inv_G = torch.inverse(G)
    W = torch.einsum("...ij,...jk,...lk->...il", inv_G, A, inv_G)
    lamda, Q = torch.symeig(W, eigenvectors=True)
    log_lamda = torch.zeros((*lamda.shape, lamda.shape[-1]),dtype=torch.double)
    ## for i in range(lamda.shape[-1]):
    ##     log_lamda[:, i, i] = torch.log(lamda[:, i])
    #lamda, Q = torch.linalg.eig(W)#, eigenvectors=True
    #lamda, Q = lamda.real, Q.real
    log_lamda = torch.diag_embed(torch.log(lamda))
    V = torch.einsum('...ji,...jk->...ik', inv_G, Q)
    inv_V = torch.inverse(V)
    return torch.einsum('...ij,...jk,...kl->...il', V, log_lamda, inv_V)


def orig_inv_RieExp(g0, g1, a):  # g0,g1: two tensors of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the inverse Riemannian exponential of g1 in the image of the maximal domain of the Riemannian exponential at g0
    '''
    n = g1.size(-1)
    #     matrix multiplication
    inv_g0_g1 = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), g1)  # (s,t,...,3,3)

    def get_u_g0direction(g0, inv_g0_g1):  # (-1,3,3) first reshape g0,g1,inv_g..
        #         permute
        inv_g0_g1 = torch.einsum("...ij->ij...", inv_g0_g1)  # (3,3,-1)
        s = inv_g0_g1[0, 0]  # (-1)
        u = 4 / n * (s ** (n / 4) - 1) * torch.einsum("...ij->ij...", g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        return u.permute(2, 0, 1)  # (-1,3,3)

    def get_u_ng0direction(g0, g1, inv_g0_g1, a):  # (-1,3,3) first reshape g0,g1,inv_g..
        K = orig_logm_invB_A(g0, g1)
        #print('Using fixed old_logm_invB_A')
        #K = old_logm_invB_A(g0, g1)
        KTrless = K - torch.einsum("...ii,kl->...kl", K, torch.eye(n, dtype=torch.double)) / n  # (-1,3,3)
        #         AA^T
        theta = (1 / a * torch.einsum("...ik,...ki->...", KTrless, KTrless)).sqrt() / 4  # (-1)
        gamma = torch.det(g1).pow(1 / 4) / (torch.det(g0).pow(1 / 4))  # (-1)

        A = 4 / n * (gamma * torch.cos(theta) - 1)  # (-1)
        B = 1 / theta * gamma * torch.sin(theta)
        u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", g0, KTrless)  # (-1)@(3,3,-1) -> (3,3,-1)
        return u.permute(2, 0, 1)  # (-1,3,3)

    inv_g0_g1_trless = inv_g0_g1 - torch.einsum("...ii,kl->...kl", inv_g0_g1, torch.eye(n, dtype=torch.double)) / n  # (s,t,...,2,2)
    norm0 = torch.einsum("...ij,...ij->...", inv_g0_g1_trless, inv_g0_g1_trless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)  # using squeeze results in [1,1]->[]
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    u = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double)  # (-1,3,3)
    if len(Indn0) == 0:
        u = get_u_g0direction(g0.reshape(-1, n, n), inv_g0_g1.reshape(-1, n, n))
    elif len(Ind0) == 0:
        u = get_u_ng0direction(g0.reshape(-1, n, n), g1.reshape(-1, n, n), inv_g0_g1.reshape(-1, n, n), a)
    else:
        u[Ind0] = get_u_g0direction(g0.reshape(-1, n, n)[Ind0], inv_g0_g1.reshape(-1, n, n)[Ind0])
        u[Indn0] = get_u_ng0direction(g0.reshape(-1, n, n)[Indn0], g1.reshape(-1, n, n)[Indn0], inv_g0_g1.reshape(-1, n, n)[Indn0], a)

    return u.reshape(g1.size())

def orig_inv_RieExp_extended(g0, g1, a):  # g0, g1: (s,t,...,3,3)
    size = g0.size()
    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        u = g1 * 4 / g0.size(-1)
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        u = orig_inv_RieExp(g0, g1, a)
    else:
        u = torch.zeros(g0.size(), dtype=torch.double)
        u[Ind_g0_is0] = g1[Ind_g0_is0] * 4 / g0.size(-1)
        u[Ind_g0_isnot0] = orig_inv_RieExp(g0[Ind_g0_isnot0], g1[Ind_g0_isnot0], a)
    return u.reshape(size)

def orig_Rie_Exp(g0, u, a):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the Riemannian exponential of u in the the maximal domain of the Riemannian exponential at g0
    '''
    n = g0.size(-1)

    U = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), u)  # (s,t,...,3,3)
    trU = torch.einsum("...ii->...", U)  # (s,t,...)
    UTrless = U - torch.einsum("...,ij->...ij", trU, torch.eye(n, n, dtype=torch.double)) / n  # (s,t,...,3,3)

    #     in g0 direction:K_0=0
    def get_g1_g0direction(g0, trU):  # first reshape g0 (-1,3,3) and trU (-1)
        g1 = (trU / 4 + 1).pow(4 / n) * torch.einsum("...ij->ij...", g0)  # (3,3,-1)
        return g1.permute(2, 0, 1)  # (-1,3,3)

    #     not in g0 direction SplitEbinMetric.pdf Theorem 1 :K_0\not=0
    def get_g1_ng0direction(g0, trU, UTrless, a):  # first reshape g0,UTrless (-1,3,3) and trU (-1)
        if len((trU < -4).nonzero().reshape(-1)) != 0:
            warnings.warn('The tangent vector u is out of the maximal domain of the Riemannian exponential.', DeprecationWarning)

        q = trU / 4 + 1  # (-1)
        r = (1 / a * torch.einsum("...ik,...ki->...", UTrless, UTrless)).sqrt() / 4  # (-1)

        ArctanUtrless = torch.atan2(r, q) * torch.einsum("...ij->ij...", UTrless) / r  # use (2,2,-1) for computation
        ExpArctanUtrless = torch.matrix_exp(ArctanUtrless.permute(2, 0, 1)).permute(1, 2, 0)

        g1 = (q ** 2 + r ** 2).pow(2 / n) * torch.einsum("...ik,kj...->ij...", g0, ExpArctanUtrless)  # (2,2,-1)
        return g1.permute(2, 0, 1)  # (-1,2,2)

    #     pointwise multiplication Tr(U^TU)
    norm0 = torch.einsum("...ij,...ij->...", UTrless, UTrless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    #     k_0=0 or \not=0
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    g1 = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double)  # (-1,2,2)
    if len(Indn0) == 0:
        g1 = get_g1_g0direction(g0.reshape(-1, n, n), trU.reshape(-1))
    elif len(Ind0) == 0:
        g1 = get_g1_ng0direction(g0.reshape(-1, n, n), trU.reshape(-1), UTrless.reshape(-1, n, n), a)
    else:
        g1[Ind0] = get_g1_g0direction(g0.reshape(-1, n, n)[Ind0], trU.reshape(-1)[Ind0])
        g1[Indn0] = get_g1_ng0direction(g0.reshape(-1, n, n)[Indn0], trU.reshape(-1)[Indn0], UTrless.reshape(-1, n, n)[Indn0], a)

    return g1.reshape(g0.size())


''' 
The following Riemannian exponential and inverse Riemannian exponential are extended to the case g0=0 
'''
def orig_Rie_Exp_extended(g0, u, a):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3)
    size = g0.size()
    g0, u = g0.reshape(-1, *size[-2:]), u.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        g1 = u * g0.size(-1) / 4
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        g1 = orig_Rie_Exp(g0, u, a)
    else:
        g1 = torch.zeros(g0.size(), dtype=torch.double)
        g1[Ind_g0_is0] = u[Ind_g0_is0] * g0.size(-1) / 4
        g1[Ind_g0_isnot0] = orig_Rie_Exp(g0[Ind_g0_isnot0], u[Ind_g0_isnot0], a)
    return g1.reshape(size)


In [19]:
def new_scipy_logm_invB_A(B, A):
#     inputs: A/B.shape = [hxwxd, 3, 3]
#     output: shape = [hxwxd, 3, 3]
    G = batch_cholesky(B)
    nonpsd_idx = torch.where(torch.isnan(G))
    if len(nonpsd_idx[0]) > 0:
      print(len(nonpsd_idx[0]), 'non psd entries found in logm_invB_A', nonpsd_idx)
    for i in range(len(nonpsd_idx[0])):
      G[nonpsd_idx[0][i]] = torch.eye((3)).double().to(device=B.device)

    # KMC The following clamp reduces crashes, but adds many bad artifacts  
    #inv_G = torch.inverse(G.clamp(min=1.0e-10))
    #inv_G = torch.inverse(G)

    # KMC comment out following 4 lines and see if pseudo inverse sufficient instead
    #det_G = torch.det(G)
    #inv_G = torch.zeros_like(G)
    #inv_G[det_G>0.] = torch.pinverse(G[det_G>0.])
    #inv_G[det_G<=0.] = torch.eye((3)).double().to(device=B.device)
    inv_G = torch.pinverse(G)
    
    W = torch.einsum("...ij,...jk,...lk->...il", inv_G, A, inv_G)

 
    logm_W = torch.from_numpy(vectorized_logm(W.cpu().detach().numpy())).double().to(device=B.device)
    return(logm_W)


def new_inv_RieExp_extended(g0, g1, a):  # g0, g1: (s,t,...,3,3)
    size = g0.size()
    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        u = g1 * 4 / g0.size(-1)
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        u = new_inv_RieExp(g0, g1, a)
    else:
        u = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        u[Ind_g0_is0] = g1[Ind_g0_is0] * 4 / g0.size(-1)
        u[Ind_g0_isnot0] = new_inv_RieExp(g0[Ind_g0_isnot0], g1[Ind_g0_isnot0], a)
    return u.reshape(size)

def new_inv_RieExp(g0, g1, a):  # g0,g1: two tensors of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the inverse Riemannian exponential of g1 in the image of the maximal domain of the Riemannian exponential at g0
    '''
    n = g1.size(-1)
    #inv_g0_g1 = make_pos_def(torch.einsum("...ik,...kj->...ij", torch.inverse(g0), g1),None, 1.0e-10)  # (s,t,...,3,3)
    logm_invg0_g1 = fix_logm_invB_A(g0, g1)
    #logm_invg0_g1 = logm_invB_A(g0, g1)
    tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
    b = torch.einsum("...ik,...kj->...ij",g0, logm_invg0_g1)
    bT = b - torch.einsum("...,...ij->...ij", tr_g0_b, g0) * a
    
    def get_u_g0direction(g0, logm_invg0_g1):  # (-1,3,3) first reshape g0,g1,inv_g..
        #         permute
        #inv_g0_g1 = torch.einsum("...ij->ij...", inv_g0_g1)  # (3,3,-1)
        #s = inv_g0_g1[0, 0].clamp(min=1.0e-15)  # (-1)
        #u = 4 / n * (s ** (n / 4) - 1) * torch.einsum("...ij->ij...", g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
        u = 4 / n * (torch.exp(tr_g0_b / 4.0) - 1) * g0
        
        #return u.permute(2, 0, 1)  # (-1,3,3)
        return u  # (-1,3,3)

    def get_u_ng0direction(g0, g1, logm_invg0_g1, bT, a):  # (-1,3,3) first reshape g0,g1,inv_g..
        det_threshold=1e-11
        where_below = torch.where(torch.det(g0)<=det_threshold)
        num_below = len(where_below[0])
        if num_below > 0:
          print('inv_RieExp num det(g0) below thresh:', num_below)
        #K = scipy_logm_invB_A(g0, g1)
        tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
        expTrg0invb = torch.exp(tr_g0_b / 4.0)
        #         AA^T
        g0inv = torch.inverse(g0)
        g0bT = torch.einsum("...ik,...kj->...ij", g0inv, bT)
        theta = ((1. / a * torch.einsum("...ik,...ki->...", g0bT, g0bT)).clamp(min=1.0e-15).sqrt() / 4.).clamp(min=1.0e-15)  # (-1)

        A = 4. / n * (expTrg0invb * torch.cos(theta) - 1)  # (-1)
        B = 1. / theta * expTrg0invb * torch.sin(theta)
        # Clarke
        print('Clarke!')
        u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ij->ij...", bT)  # (-1)@(3,3,-1) -> (3,3,-1)
        # Kris
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", bT, g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kl,...lj->ij...", g0inv, bT, g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        # Kris trying to use conversion between GMM and Clarke, BT = g0H0
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", g0inv, bT)  # (-1)@(3,3,-1) -> (3,3,-1)

        return u.permute(2, 0, 1)  # (-1,3,3)

    #inv_g0_g1_trless = inv_g0_g1 - torch.einsum("...ii,kl->...kl", inv_g0_g1, torch.eye(n, dtype=torch.double, device=g0.device))  # (s,t,...,2,2)
    #norm0 = torch.einsum("...ij,...ij->...", inv_g0_g1_trless, inv_g0_g1_trless).reshape(-1)  # (-1)
    norm0 = torch.einsum("...ij,...ij->...", bT, bT).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)  # using squeeze results in [1,1]->[]
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    u = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double, device=g0.device)  # (-1,3,3)
    if len(Indn0) == 0:
        u = get_u_g0direction(g0.reshape(-1, n, n), logm_invg0_g1.reshape(-1, n, n))
    elif len(Ind0) == 0:
        u = get_u_ng0direction(g0.reshape(-1, n, n), g1.reshape(-1, n, n), 
                               logm_invg0_g1.reshape(-1, n, n), bT.reshape(-1, n, n), a)
    else:
        u[Ind0] = get_u_g0direction(g0.reshape(-1, n, n)[Ind0], logm_invg0_g1.reshape(-1, n, n)[Ind0])
        u[Indn0] = get_u_ng0direction(g0.reshape(-1, n, n)[Indn0], g1.reshape(-1, n, n)[Indn0], 
                                      logm_invg0_g1.reshape(-1, n, n)[Indn0], bT.reshape(-1, n, n)[Indn0], a)

    return u.reshape(g1.size())

def new_inv_RieExp_extended_scipy(g0, g1, a):  # g0, g1: (s,t,...,3,3)
    size = g0.size()
    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        u = g1 * 4 / g0.size(-1)
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        u = new_inv_RieExp_scipy(g0, g1, a)
    else:
        u = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        u[Ind_g0_is0] = g1[Ind_g0_is0] * 4 / g0.size(-1)
        u[Ind_g0_isnot0] = new_inv_RieExp_scipy(g0[Ind_g0_isnot0], g1[Ind_g0_isnot0], a)
    return u.reshape(size)

def new_inv_RieExp_scipy(g0, g1, a):  # g0,g1: two tensors of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the inverse Riemannian exponential of g1 in the image of the maximal domain of the Riemannian exponential at g0
    '''
    n = g1.size(-1)
    #inv_g0_g1 = make_pos_def(torch.einsum("...ik,...kj->...ij", torch.inverse(g0), g1),None, 1.0e-10)  # (s,t,...,3,3)
    logm_invg0_g1 = new_scipy_logm_invB_A(g0, g1)
    tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
    b = torch.einsum("...ik,...kj->...ij",g0, logm_invg0_g1)
    bT = b - torch.einsum("...,...ij->...ij", tr_g0_b, g0) * a
    
    def get_u_g0direction(g0, logm_invg0_g1):  # (-1,3,3) first reshape g0,g1,inv_g..
        #         permute
        #inv_g0_g1 = torch.einsum("...ij->ij...", inv_g0_g1)  # (3,3,-1)
        #s = inv_g0_g1[0, 0].clamp(min=1.0e-15)  # (-1)
        #u = 4 / n * (s ** (n / 4) - 1) * torch.einsum("...ij->ij...", g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
        u = 4 / n * (torch.exp(tr_g0_b / 4.0) - 1) * g0
        
        #return u.permute(2, 0, 1)  # (-1,3,3)
        return u  # (-1,3,3)

    def get_u_ng0direction(g0, g1, logm_invg0_g1, bT, a):  # (-1,3,3) first reshape g0,g1,inv_g..
        det_threshold=1e-11
        where_below = torch.where(torch.det(g0)<=det_threshold)
        num_below = len(where_below[0])
        if num_below > 0:
          print('inv_RieExp num det(g0) below thresh:', num_below)
        #K = scipy_logm_invB_A(g0, g1)
        tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
        expTrg0invb = torch.exp(tr_g0_b / 4.0)
        #         AA^T
        g0inv = torch.inverse(g0)
        g0bT = torch.einsum("...ik,...kj->...ij", g0inv, bT)
        theta = ((1. / a * torch.einsum("...ik,...ki->...", g0bT, g0bT)).clamp(min=1.0e-15).sqrt() / 4.).clamp(min=1.0e-15)  # (-1)

        A = 4. / n * (expTrg0invb * torch.cos(theta) - 1)  # (-1)
        B = 1. / theta * expTrg0invb * torch.sin(theta)
        # Clarke
        print('Clarke!')
        u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ij->ij...", bT)  # (-1)@(3,3,-1) -> (3,3,-1)
        # Kris
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", bT, g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kl,...lj->ij...", g0inv, bT, g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        # Kris trying to use conversion between GMM and Clarke, BT = g0H0
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", g0inv, bT)  # (-1)@(3,3,-1) -> (3,3,-1)

        return u.permute(2, 0, 1)  # (-1,3,3)

    #inv_g0_g1_trless = inv_g0_g1 - torch.einsum("...ii,kl->...kl", inv_g0_g1, torch.eye(n, dtype=torch.double, device=g0.device))  # (s,t,...,2,2)
    #norm0 = torch.einsum("...ij,...ij->...", inv_g0_g1_trless, inv_g0_g1_trless).reshape(-1)  # (-1)
    norm0 = torch.einsum("...ij,...ij->...", bT, bT).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)  # using squeeze results in [1,1]->[]
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    u = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double, device=g0.device)  # (-1,3,3)
    if len(Indn0) == 0:
        u = get_u_g0direction(g0.reshape(-1, n, n), logm_invg0_g1.reshape(-1, n, n))
    elif len(Ind0) == 0:
        u = get_u_ng0direction(g0.reshape(-1, n, n), g1.reshape(-1, n, n), 
                               logm_invg0_g1.reshape(-1, n, n), bT.reshape(-1, n, n), a)
    else:
        u[Ind0] = get_u_g0direction(g0.reshape(-1, n, n)[Ind0], logm_invg0_g1.reshape(-1, n, n)[Ind0])
        u[Indn0] = get_u_ng0direction(g0.reshape(-1, n, n)[Indn0], g1.reshape(-1, n, n)[Indn0], 
                                      logm_invg0_g1.reshape(-1, n, n)[Indn0], bT.reshape(-1, n, n)[Indn0], a)

    return u.reshape(g1.size())

def new_Rie_Exp(g0, u, a, t=1.0):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the Riemannian exponential of u in the the maximal domain of the Riemannian exponential at g0
    '''
    n = g0.size(-1)
    g0inv = torch.inverse(g0)
    U = torch.einsum("...ik,...kj->...ij", g0inv, u)  # (s,t,...,3,3)
    #Ug0inv = torch.einsum("...ik,...kj->...ij", U, g0inv)
    trU = torch.einsum("...ii->...", U)  # (s,t,...)
    # GMM UTrless = H0
    #UTrless = U - torch.einsum("...,ij->...ij", trU, torch.eye(n, n, dtype=torch.double, device=g0.device)) / n  # (s,t,...,3,3)
    # Clarke version of UTrless = BT   
    UTrless = u - torch.einsum("...,...ij->...ij", trU, g0) / n  # (s,t,...,3,3)

    #     in g0 direction:K_0=0
    def get_g1_g0direction(g0, trU, t):  # first reshape g0 (-1,3,3) and trU (-1)
        g1 = (t * trU / 4. + 1).pow(4. / n) * torch.einsum("...ij->ij...", g0)  # (3,3,-1)
        return g1.permute(2, 0, 1)  # (-1,3,3)

    #     not in g0 direction SplitEbinMetric.pdf Theorem 1 :K_0\not=0
    def get_g1_ng0direction(g0, trU, UTrless, a, t):  # first reshape g0,UTrless (-1,3,3) and trU (-1)
        if len((trU < -4).nonzero().reshape(-1)) != 0:
            warnings.warn('The tangent vector u is out of the maximal domain of the Riemannian exponential.', DeprecationWarning)

        # GMM and Clarke q match    
        q = t * trU / 4. + 1  # (-1)
        g0UTrless = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), UTrless)
        # GMM r
        #r = t * (1. / a * torch.einsum("...ik,...ki->...", UTrless, UTrless)).clamp(min=1.0e-15).sqrt() / 4.  # (-1)
        # Clarke r
        r = t * (1. / a * torch.einsum("...ik,...ki->...", g0UTrless, g0UTrless)).clamp(min=1.0e-15).sqrt() / 4.  # (-1)
        
        #ArctanUtrless = (torch.atan2(r, q) * torch.einsum("...ij->ij...", UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        # GMM Arctan
        #ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ij->ij...", UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        # Clarke Arctan
        ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ij->ij...", g0UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        # Kris Arctan
        #ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ik,...kj->ij...", g0UTrless, g0) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        #ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ik,...kj->ij...", g0inv, g0UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation

        ExpArctanUtrless = torch.nan_to_num(torch.matrix_exp(ArctanUtrless.permute(2, 0, 1)).permute(1, 2, 0))
        ExpArctanUtrless[torch.abs(ExpArctanUtrless) > 1e12] = 0

        # GMM, Clarke At
        At = 2./n * torch.einsum("...,ij->ij...", torch.log(q**2 + r**2), torch.eye(n, n, dtype=torch.double, device=g0.device))
        # Kris At
        #At = 2./n * torch.log(q**2 + r**2) * torch.einsum("...ij->ij...", g0)
        ExpAt = torch.nan_to_num(torch.matrix_exp(At.permute(2, 0, 1)).permute(1, 2, 0))
        ExpAt[torch.abs(ExpAt) > 1e12] = 0
        # GMM, Clarke g1
        g1 = (q ** 2 + r ** 2).pow(2. / n) * torch.einsum("...ik,kj...->ij...", g0, ExpArctanUtrless)  # (2,2,-1)
        #g1 = torch.einsum("...ik,kl...,lj...->ij...", g0, ExpAt, ExpArctanUtrless)  # (2,2,-1)
        # Kris g1
        #g1 = torch.einsum("...ik,kl...,lm...,...mj->ij...", g0, ExpAt, ExpArctanUtrless,g0)  # (2,2,-1)
        #g1 = torch.einsum("kl...,lm...->km...", ExpAt, ExpArctanUtrless)  # (2,2,-1)

        return g1.permute(2, 0, 1)  # (-1,2,2)

    #     pointwise multiplication Tr(U^TU)
    #UMinusTrU = U - torch.einsum("...,ij->...ij", trU, torch.eye(n, n, dtype=torch.double, device=g0.device))  # (s,t,...,3,3)

    #norm0 = torch.einsum("...ij,...ij->...", UMinusTrU, UMinusTrU).reshape(-1)  # (-1)
    norm0 = torch.einsum("...ij,...ij->...", UTrless, UTrless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    #     k_0=0 or \not=0
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    g1 = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double, device=g0.device)  # (-1,2,2)
    if len(Indn0) == 0:
        g1 = get_g1_g0direction(g0.reshape(-1, n, n), trU.reshape(-1), t)
    elif len(Ind0) == 0:
        g1 = get_g1_ng0direction(g0.reshape(-1, n, n), trU.reshape(-1), UTrless.reshape(-1, n, n), a, t)
    else:
        g1[Ind0] = get_g1_g0direction(g0.reshape(-1, n, n)[Ind0], trU.reshape(-1)[Ind0], t)
        g1[Indn0] = get_g1_ng0direction(g0.reshape(-1, n, n)[Indn0], trU.reshape(-1)[Indn0], UTrless.reshape(-1, n, n)[Indn0], a, t)
    return g1.reshape(g0.size())



''' 
The following Riemannian exponential and inverse Riemannian exponential are extended to the case g0=0 
'''
def new_Rie_Exp_extended(g0, u, a, t=1):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3)
    size = g0.size()
    g0, u = g0.reshape(-1, *size[-2:]), u.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        g1 = u * g0.size(-1) / 4
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        g1 = new_Rie_Exp(g0, u, a, t)
    else:
        g1 = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        g1[Ind_g0_is0] = u[Ind_g0_is0] * g0.size(-1) / 4
        g1[Ind_g0_isnot0] = new_Rie_Exp(g0[Ind_g0_isnot0], u[Ind_g0_isnot0], a, t)
    
    return g1.reshape(size)

In [20]:
print(G.device)

cuda:0


In [21]:
e = torch.exp(torch.tensor(1))
ident = torch.eye(dim).reshape((1,dim,dim))
oldinvREIde = old_inv_RieExp_extended(ident,e * ident, 1./dim)
print('e*Id\n',e*ident)
print('OLD invRieExp(Id, e*Id)\n',oldinvREIde)
oldRE_invREIde = old_Rie_Exp_extended(ident,oldinvREIde, 1./dim)
print('OLD RieExp(invRieExp(Id, e*Id))\n',oldRE_invREIde)

newinvREIde = new_inv_RieExp_extended(ident,e * ident, 1./dim)
print('e*Id\n',e*ident)
print('new invRieExp(Id, e*Id)\n',newinvREIde)
newRE_invREIde = new_Rie_Exp_extended(ident,newinvREIde, 1./dim)
print('new RieExp(invRieExp(Id, e*Id))\n',newRE_invREIde)


oldinvREeId = old_inv_RieExp_extended(e * ident,ident, 1./dim)
print('e*Id\n',e*ident)
# expect to get u = 4/3 * (e-(3/4)-1)*e*Id
expectu = 4./dim * (e**(-dim/4.) - 1.)*e*ident
print('Expect invRieExp(e*Id, Id)\n',expectu)
print('OLD invRieExp(e*Id, Id)\n',oldinvREeId)
# Should get Identity for both of following
print('Expect RieExp(Expect invRieExp(e*Id, Id))\n',ident)
oldRE_expectu = old_Rie_Exp_extended(e*ident,expectu, 1./dim)
print('OLD RieExp(Expect invRieExp(e*Id, Id))\n',oldRE_expectu)
oldRE_invREeId = old_Rie_Exp_extended(e*ident,oldinvREeId, 1./dim)
print('OLD RieExp(invRieExp(e*Id, Id))\n',oldRE_invREeId)

newinvREeId = new_inv_RieExp_extended(e * ident,ident, 1./dim)
print('e*Id\n',e*ident)
# expect to get u = 4/3 * (e-(3/4)-1)*e*Id
expectu = 4./dim * (e**(-dim/4.) - 1.)*e*ident
print('Expect invRieExp(e*Id, Id)\n',expectu)
print('new invRieExp(e*Id, Id)\n',newinvREeId)
# Should get Identity for both of following
print('Expect RieExp(Expect invRieExp(e*Id, Id))\n',ident)
newRE_expectu = new_Rie_Exp_extended(e*ident,expectu, 1./dim)
print('new RieExp(Expect invRieExp(e*Id, Id))\n',newRE_expectu)
newRE_invREeId = new_Rie_Exp_extended(e*ident,newinvREeId, 1./dim)
print('new RieExp(invRieExp(e*Id, Id))\n',newRE_invREeId)

originvREeId = orig_inv_RieExp_extended(e * ident,ident, 1./dim)
print('e*Id\n',e*ident)
# expect to get u = 4/3 * (e-(3/4)-1)*e*Id
expectu = 4./dim * (e**(-dim/4.) - 1.)*e*ident
print('Expect invRieExp(e*Id, Id)\n',expectu)
print('orig invRieExp(e*Id, Id)\n',originvREeId)
# Should get Identity for both of following
print('Expect RieExp(Expect invRieExp(e*Id, Id))\n',ident)
origRE_expectu = orig_Rie_Exp_extended(e*ident,expectu, 1./dim)
print('orig RieExp(Expect invRieExp(e*Id, Id))\n',origRE_expectu)
origRE_invREeId = orig_Rie_Exp_extended(e*ident,originvREeId, 1./dim)
print('orig RieExp(invRieExp(e*Id, Id))\n',origRE_invREeId)


yep, all PD
e*Id
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
OLD invRieExp(Id, e*Id)
 tensor([[[1.4893, 0.0000, 0.0000],
         [0.0000, 1.4893, 0.0000],
         [0.0000, 0.0000, 1.4893]]])
yep, all PD
OLD RieExp(invRieExp(Id, e*Id))
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
torch eig
se eig
lamda good
: tensor([[2.7183, 2.7183, 2.7183]])
lamda se
: tensor([[2.7183, 2.7183, 2.7183]])
Q good
: tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
Q se
: tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
Using Q, lamda se
lamda good - lamda se:
 tensor([[0., 0., 0.]])
Q good - Q se:
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
W:
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
Q lamda Q_T good:
 tensor([[[2.7183, 0.0000, 0.0000],
     

In [22]:
oldRE_invRE00 = old_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
oldRE_invRE11 = old_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('old RieExp[0,0]\n',oldRE_invRE00)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('old RieExp[1,1]\n',oldRE_invRE11)


oldinvRE00 = old_inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
oldinvRE11 = old_inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('old inv_RieExp[0,0]\n',oldinvRE00)
print('old inv_RieExp[1,1]\n',oldinvRE11)

oldinvRE01 = old_inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
oldinvRE10 = old_inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('old inv_RieExp[0,1]\n',oldinvRE01)
print('old inv_RieExp[1,0]\n',oldinvRE10)


oldRE_invRE01 = old_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), oldinvRE01, 1./dim)
oldRE_invRE10 = old_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), oldinvRE10, 1./dim)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('old RieExp(inv_RieExp)[0,1]\n',oldRE_invRE01)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('old RieExp(inv_RieExp)[1,0]\n',oldRE_invRE10)

yep, all PD
yep, all PD
g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
old RieExp[0,0]
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
old RieExp[1,1]
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
yep, all PD
yep, all PD
old inv_RieExp[0,0]
 tensor([[[0., -0., 0.],
         [-0., 0., 0.],
         [0., 0., 0.]]])
old inv_RieExp[1,1]
 tensor([[[0., -0., 0.],
         [-0., 0., 0.],
         [0., 0., 0.]]])
yep, all PD
u_ng0direction
yep, all PD
u_ng0direction
old inv_RieExp[0,1]
 tensor([[[-0.0948,  0.1113,  0.0000],
         [ 0.1113, -0.1431,  0.0000],
         [ 0.0000,  0.0000, -0.1725]]])
old inv_RieExp[1,0]
 tensor([[[ 0.0910, -0.1079,  0.0000],
        

In [23]:
print('\n---------------------------\n')

newRE_invRE00 = new_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
newRE_invRE11 = new_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('new RieExp[0,0]\n',newRE_invRE00)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('new RieExp[1,1]\n',newRE_invRE11)


newinvRE00 = new_inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
newinvRE11 = new_inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('new inv_RieExp[0,0]\n',newinvRE00)
print('new inv_RieExp[1,1]\n',newinvRE11)

newinvRE01 = new_inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
newinvRE10 = new_inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('new inv_RieExp[0,1]\n',newinvRE01)
print('new inv_RieExp[1,0]\n',newinvRE10)


newRE_invRE01 = new_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), newinvRE01, 1./dim)
newRE_invRE10 = new_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), newinvRE10, 1./dim)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('new RieExp(inv_RieExp)[0,1]\n',newRE_invRE01)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('new RieExp(inv_RieExp)[1,0]\n',newRE_invRE10)


---------------------------

g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
new RieExp[0,0]
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
new RieExp[1,1]
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
torch eig
se eig
lamda good
: tensor([[1.0000, 1.0000, 1.0000]])
lamda se
: tensor([[1.0000, 1.0000, 1.0000]])
Q good
: tensor([[[-0.7071, -0.7071,  0.0000],
         [-0.7071,  0.7071,  0.0000],
         [-0.0000,  0.0000,  1.0000]]])
Q se
: tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
Using Q, lamda se
lamda good - lamda se:
 tensor([[-8.8818e-16,  0.0000e+00,  6.6613e-16]])
Q good - Q se:
 tensor([[[-1.7071, -0.7071,  0.0000],
    

 tensor([[[ 9.0987e-02, -1.0791e-01,  5.0117e-19],
         [-1.0791e-01,  1.3785e-01,  4.0401e-19],
         [ 1.8087e-17,  1.4580e-17,  1.6655e-01]]])
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
new RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.3418e+00, -1.1340e+00,  1.8864e-17],
         [-1.1340e+00,  1.8344e+00, -2.3401e-17],
         [ 1.3973e-17, -1.7333e-17,  2.0359e+00]]])
g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
new RieExp(inv_RieExp)[1,0]
 tensor([[[ 1.4348e+00, -1.2436e+00,  5.0264e-19],
         [-1.2436e+00,  1.9749e+00,  4.0520e-19],
         [ 1.8140e-17,  1.4623e-17,  2.2054e+00]]])


In [24]:
print('\n---------------------------\n')

origRE_invRE00 = orig_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
origRE_invRE11 = orig_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('orig RieExp[0,0]\n',origRE_invRE00)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('orig RieExp[1,1]\n',origRE_invRE11)


originvRE00 = orig_inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
originvRE11 = orig_inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('orig inv_RieExp[0,0]\n',originvRE00)
print('orig inv_RieExp[1,1]\n',originvRE11)

originvRE01 = orig_inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
originvRE10 = orig_inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('orig inv_RieExp[0,1]\n',originvRE01)
print('orig inv_RieExp[1,0]\n',originvRE10)


origRE_invRE01 = orig_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), originvRE01, 1./dim)
origRE_invRE10 = orig_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), originvRE10, 1./dim)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('orig RieExp(inv_RieExp)[0,1]\n',origRE_invRE01)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('orig RieExp(inv_RieExp)[1,0]\n',origRE_invRE10)


---------------------------

g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
orig RieExp[0,0]
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
orig RieExp[1,1]
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
orig inv_RieExp[0,0]
 tensor([[[0., -0., 0.],
         [-0., 0., 0.],
         [0., 0., 0.]]])
orig inv_RieExp[1,1]
 tensor([[[0., -0., 0.],
         [-0., 0., 0.],
         [0., 0., 0.]]])
orig inv_RieExp[0,1]
 tensor([[[-0.0948,  0.1113,  0.0000],
         [ 0.1113, -0.1431,  0.0000],
         [ 0.0000,  0.0000, -0.1725]]])
orig inv_RieExp[1,0]
 tensor([[[ 0.0910, -0.1079,  0.0000],
         [-0.1079,  0.1379,  0.0000],
         [ 0.0000,  0.0000,  0.1665]

In [25]:
print('\n---------------------------\n')

RE_invRE00 = Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
RE_invRE11 = Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('RieExp[0,0]\n',RE_invRE00)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('RieExp[1,1]\n',RE_invRE11)


invRE00 = inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE11 = inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,0]\n',invRE00)
print('inv_RieExp[1,1]\n',invRE11)

invRE01 = inv_RieExp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE10 = inv_RieExp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,1]\n',invRE01)
print('inv_RieExp[1,0]\n',invRE10)


RE_invRE01 = Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), invRE01, 1./dim)
RE_invRE10 = Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), invRE10, 1./dim)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('RieExp(inv_RieExp)[0,1]\n',RE_invRE01)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('RieExp(inv_RieExp)[1,0]\n',RE_invRE10)


---------------------------

g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
RieExp[0,0]
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
RieExp[1,1]
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
inv_RieExp[0,0]
 tensor([[[0., -0., 0.],
         [-0., 0., 0.],
         [0., 0., 0.]]])
inv_RieExp[1,1]
 tensor([[[0., -0., 0.],
         [-0., 0., 0.],
         [0., 0., 0.]]])
inv_RieExp[0,1]
 tensor([[[-9.4783e-02,  1.1132e-01,  1.9580e-17],
         [ 1.1132e-01, -1.4313e-01, -2.4289e-17],
         [ 1.4504e-17, -1.7991e-17, -1.7249e-01]]])
inv_RieExp[1,0]
 tensor([[[ 9.0987e-02, -1.0791e-01,  5.0117e-19],
         [-1.0791e-01,  1.3785e-01,  4.0401e-19],
      

In [30]:
g0 = torch.tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
g1 = torch.tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
print('g0:\n',g0)
print('g1:\n',g1)

oldinvRE01 = old_inv_RieExp_extended(g0.reshape((1,3,3)), g1.reshape((1,3,3)), 1./dim)
oldinvRE10 = old_inv_RieExp_extended(g1.reshape((1,3,3)), g0.reshape((1,3,3)), 1./dim)
print('old inv_RieExp[0,1]\n',oldinvRE01)
print('old inv_RieExp[1,0]\n',oldinvRE10)


oldRE_invRE01 = old_Rie_Exp_extended(g0.reshape((1,3,3)), oldinvRE01, 1./dim)
oldRE_invRE10 = old_Rie_Exp_extended(g1.reshape((1,3,3)), oldinvRE10, 1./dim)
print('Expected RieExp(inv_RieExp)[0,1]\n',g1.reshape((1,3,3)))
print('old RieExp(inv_RieExp)[0,1]\n',oldRE_invRE01)
print('Expected RieExp(inv_RieExp)[1,0]\n',g0.reshape((1,3,3)))
print('old RieExp(inv_RieExp)[1,0]\n',oldRE_invRE10)

print('\n-----------------------------\n')
print('g0:\n',g0)
print('g1:\n',g1)

originvRE01 = orig_inv_RieExp_extended(g0.reshape((1,3,3)), g1.reshape((1,3,3)), 1./dim)
originvRE10 = orig_inv_RieExp_extended(g1.reshape((1,3,3)), g0.reshape((1,3,3)), 1./dim)
print('orig inv_RieExp[0,1]\n',originvRE01)
print('orig inv_RieExp[1,0]\n',originvRE10)


origRE_invRE01 = orig_Rie_Exp_extended(g0.reshape((1,3,3)), originvRE01, 1./dim)
origRE_invRE10 = orig_Rie_Exp_extended(g1.reshape((1,3,3)), originvRE10, 1./dim)
print('Expected RieExp(inv_RieExp)[0,1]\n',g1.reshape((1,3,3)))
print('orig RieExp(inv_RieExp)[0,1]\n',origRE_invRE01)
print('Expected RieExp(inv_RieExp)[1,0]\n',g0.reshape((1,3,3)))
print('orig RieExp(inv_RieExp)[1,0]\n',origRE_invRE10)

origRE10_invRE01 = orig_Rie_Exp_extended(g1.reshape((1,3,3)), -originvRE01, 1./dim)
origRE01_invRE10 = orig_Rie_Exp_extended(g0.reshape((1,3,3)), -originvRE10, 1./dim)
print('Expected RieExp(-inv_RieExp[0,1])[1,0]\n',g0.reshape((1,3,3)))
print('orig RieExp(-inv_RieExp[0,1])[1,0]\n',origRE10_invRE01)
print('Expected RieExp(-inv_RieExp[1,0])[0,1]\n',g1.reshape((1,3,3)))
print('orig RieExp(-inv_RieExp[1,0])[0,1]\n',origRE01_invRE10)


g0:
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
g1:
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
yep, all PD
u_ng0direction
yep, all PD
u_ng0direction
old inv_RieExp[0,1]
 tensor([[[-0.0949,  0.1113,  0.0000],
         [ 0.1113, -0.1431,  0.0000],
         [ 0.0000,  0.0000, -0.1724]]])
old inv_RieExp[1,0]
 tensor([[[ 0.0911, -0.1079,  0.0000],
         [-0.1079,  0.1378,  0.0000],
         [ 0.0000,  0.0000,  0.1666]]])
yep, all PD
g1_ng0direction
yep, all PD
g1_ng0direction
Expected RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
old RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
Expected RieExp(inv_RieExp)[1,0]
 tensor([[[ 1.4348, -1.2436,  0.0000],
       

In [29]:
g0 = torch.tensor([[[ 5.2354, -1.4173,  0.0000],
         [-1.4173,  1.5663,  0.0000],
         [ 0.0000,  0.0000, 10.1547]]]).double()
g1 = torch.tensor([[[ 1.4237, -0.3782,  0.0000],
         [-0.3782,  0.4445,  0.0000],
         [ 0.0000,  0.0000,  2.7570]]]).double()
print('g0:\n',g0)
print('g1:\n',g1)

newinvRE01 = new_inv_RieExp_extended_scipy(g0.reshape((1,3,3)), g1.reshape((1,3,3)), 1./dim)
newinvRE10 = new_inv_RieExp_extended_scipy(g1.reshape((1,3,3)), g0.reshape((1,3,3)), 1./dim)
print('new inv_RieExp[0,1]\n',newinvRE01)
print('new inv_RieExp[1,0]\n',newinvRE10)


newRE_invRE01 = new_Rie_Exp_extended(g0.reshape((1,3,3)), newinvRE01, 1./dim)
newRE_invRE10 = new_Rie_Exp_extended(g1.reshape((1,3,3)), newinvRE10, 1./dim)
print('Expected RieExp(inv_RieExp)[0,1]\n',g1.reshape((1,3,3)))
print('new RieExp(inv_RieExp)[0,1]\n',newRE_invRE01)
print('Expected RieExp(inv_RieExp)[1,0]\n',g0.reshape((1,3,3)))
print('new RieExp(inv_RieExp)[1,0]\n',newRE_invRE10)

print('\n-----------------------------\n')
print('g0:\n',g0)
print('g1:\n',g1)

originvRE01 = orig_inv_RieExp_extended(g0.reshape((1,3,3)), g1.reshape((1,3,3)), 1./dim)
originvRE10 = orig_inv_RieExp_extended(g1.reshape((1,3,3)), g0.reshape((1,3,3)), 1./dim)
print('orig inv_RieExp[0,1]\n',originvRE01)
print('orig inv_RieExp[1,0]\n',originvRE10)


origRE_invRE01 = orig_Rie_Exp_extended(g0.reshape((1,3,3)), originvRE01, 1./dim)
origRE_invRE10 = orig_Rie_Exp_extended(g1.reshape((1,3,3)), originvRE10, 1./dim)
print('Expected RieExp(inv_RieExp)[0,1]\n',g1.reshape((1,3,3)))
print('orig RieExp(inv_RieExp)[0,1]\n',origRE_invRE01)
print('Expected RieExp(inv_RieExp)[1,0]\n',g0.reshape((1,3,3)))
print('orig RieExp(inv_RieExp)[1,0]\n',origRE_invRE10)


g0:
 tensor([[[ 5.2354, -1.4173,  0.0000],
         [-1.4173,  1.5663,  0.0000],
         [ 0.0000,  0.0000, 10.1547]]])
g1:
 tensor([[[ 1.4237, -0.3782,  0.0000],
         [-0.3782,  0.4445,  0.0000],
         [ 0.0000,  0.0000,  2.7570]]])
Clarke!
Clarke!
new inv_RieExp[0,1]
 tensor([[[-4.3587,  1.1625,  0.0000],
         [ 1.1846, -1.2674,  0.0000],
         [ 0.0000,  0.0000, -8.4495]]])
new inv_RieExp[1,0]
 tensor([[[ 3.1507, -0.8073,  0.0000],
         [-0.8466,  0.9128,  0.0000],
         [ 0.0000,  0.0000,  6.0929]]])
Expected RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.4237, -0.3782,  0.0000],
         [-0.3782,  0.4445,  0.0000],
         [ 0.0000,  0.0000,  2.7570]]])
new RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.4196, -0.3972,  0.0000],
         [-0.3809,  0.4516,  0.0000],
         [ 0.0000,  0.0000,  2.7570]]])
Expected RieExp(inv_RieExp)[1,0]
 tensor([[[ 5.2354, -1.4173,  0.0000],
         [-1.4173,  1.5663,  0.0000],
         [ 0.0000,  0.0000, 10.1547]]])
new RieExp(inv_RieExp)[

In [28]:
g0 = torch.zeros((3,3)).double()
print('g0:\n',g0)

newinvREzero = new_inv_RieExp_extended(g0.reshape((1,3,3)), g0.reshape((1,3,3)), 1./dim)
print('new inv_RieExp zero\n',newinvREzero)

newRE_invREzero = new_Rie_Exp_extended(g0.reshape((1,3,3)), newinvREzero, 1./dim)
print('Expected RieExp(inv_RieExp) zero\n',g0.reshape((1,3,3)))
print('new RieExp(inv_RieExp) zero\n',newRE_invREzero)


g0:
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
new inv_RieExp zero
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
Expected RieExp(inv_RieExp) zero
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
new RieExp(inv_RieExp) zero
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])


# View and QC Images

# Determine Why Getting Isolated Large Value Pixels in Estimated Atlas
  1. Is it Karcher mean calculation?
     - Compute Karcher mean of noshape cubics, do the large value pixels appear? Yes
     - Found bug in karcher_mean_scipy where it wasn't calling inv_RieExp_extended_scipy
     - Results look better, but not mean is not fully symmetric.  Why not?
  2. Is it the estimation of the diffeomorphism itself?
     - Inspect estimated diffeomorphisms, any large / unexpected diffeomorphism values?
  3. Is it while computing diffeomorphism acting on metric?
     - Apply identity diffeomorphisms to metrics, do the large value pixels appear?
     - Apply estimated diffeomorphisms to metrics, do the large value pixels appear?



## 1. Check Karcher mean calculation

In [41]:
%%time
comp_atlas = get_karcher_mean(G, 1./dim, mask=mask_union, device=device) # time 58.5 ms
#comp_atlas = get_karcher_mean(G, 1./dim, device=device) # time 213 ms

atlas_inv = inverse_masked(comp_atlas, mask_union) / tens_scale

atlas_lin = np.zeros((6,height,width,depth))
atlas_lin[0] = atlas_inv[:,:,:,0,0].cpu()
atlas_lin[1] = atlas_inv[:,:,:,0,1].cpu()
atlas_lin[2] = atlas_inv[:,:,:,0,2].cpu()
atlas_lin[3] = atlas_inv[:,:,:,1,1].cpu()
atlas_lin[4] = atlas_inv[:,:,:,1,2].cpu()
atlas_lin[5] = atlas_inv[:,:,:,2,2].cpu()
comp_atlas_np = np.transpose(atlas_lin,(3,2,1,0))

RuntimeError: inverse_cpu: For batch 22: U(3,3) is zero, singular U.
Exception raised from batchCheckErrors at /pytorch/aten/src/ATen/native/LinearAlgebraUtils.h:123 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f4f58f3a2f2 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x7f4f58f3767b in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0xe77e73 (0x7f4f95d41e73 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #3: at::native::_inverse_helper_cpu(at::Tensor const&) + 0x1ac (0x7f4f95d4483c in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x169137c (0x7f4f9655b37c in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x1447f69 (0x7f4f96311f69 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #6: at::_inverse_helper(at::Tensor const&) + 0x4b (0x7f4f9631823b in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::native::inverse(at::Tensor const&) + 0x44 (0x7f4f95d39ee4 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x1753770 (0x7f4f9661d770 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x17537bc (0x7f4f9661d7bc in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x1447f69 (0x7f4f96311f69 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::inverse(at::Tensor const&) + 0x4b (0x7f4f9631838b in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x2d8a174 (0x7f4f97c54174 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x2d8a42c (0x7f4f97c5442c in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x18864e9 (0x7f4f967504e9 in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #15: at::Tensor::inverse() const + 0x4b (0x7f4f96755cbb in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x42f75b (0x7f4fa783075b in /home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #17: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5b4316]
frame #18: _PyObject_MakeTpCall + 0x79 (0x4fbf69 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #19: _PyEval_EvalFrameDefault + 0x5022 (0x561e32 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #20: _PyFunction_Vectorcall + 0xc5 (0x4fc555 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #21: _PyEval_EvalFrameDefault + 0x3d6 (0x55d1e6 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #22: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x55c2d5]
frame #23: _PyFunction_Vectorcall + 0x149 (0x4fc5d9 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x3d6 (0x55d1e6 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #25: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x55c2d5]
frame #26: _PyFunction_Vectorcall + 0x149 (0x4fc5d9 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0x1083 (0x55de93 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #28: _PyEval_EvalCodeWithName + 0x3c8 (0x55b378 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #29: PyEval_EvalCode + 0x23 (0x55afa3 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #30: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x643565]
frame #31: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5b4222]
frame #32: _PyEval_EvalFrameDefault + 0x3d6 (0x55d1e6 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #33: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x55c2d5]
frame #34: _PyFunction_Vectorcall + 0x149 (0x4fc5d9 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #35: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5ade61]
frame #36: _PyObject_Call + 0x11b (0x4fd77b in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #37: _PyEval_EvalFrameDefault + 0x2075 (0x55ee85 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #38: _PyFunction_Vectorcall + 0xc5 (0x4fc555 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #39: _PyEval_EvalFrameDefault + 0x674 (0x55d484 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #40: _PyEval_EvalCodeWithName + 0x3c8 (0x55b378 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #41: PyEval_EvalCode + 0x23 (0x55afa3 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #42: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x643565]
frame #43: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5b4222]
frame #44: _PyEval_EvalFrameDefault + 0x3d6 (0x55d1e6 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #45: _PyGen_Send + 0x15b (0x5b2b1b in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #46: _PyEval_EvalFrameDefault + 0x59ec (0x5627fc in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #47: _PyGen_Send + 0x15b (0x5b2b1b in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #48: _PyEval_EvalFrameDefault + 0x59ec (0x5627fc in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #49: _PyGen_Send + 0x15b (0x5b2b1b in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #50: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5b1516]
frame #51: _PyEval_EvalFrameDefault + 0x674 (0x55d484 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #52: _PyFunction_Vectorcall + 0xc5 (0x4fc555 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #53: _PyEval_EvalFrameDefault + 0x3d6 (0x55d1e6 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #54: _PyFunction_Vectorcall + 0xc5 (0x4fc555 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #55: _PyEval_EvalFrameDefault + 0x674 (0x55d484 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #56: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x55c2d5]
frame #57: _PyFunction_Vectorcall + 0x149 (0x4fc5d9 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #58: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5ade61]
frame #59: _PyObject_Call + 0x11b (0x4fd77b in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #60: _PyEval_EvalFrameDefault + 0x2075 (0x55ee85 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #61: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x55c2d5]
frame #62: _PyFunction_Vectorcall + 0x149 (0x4fc5d9 in /home/sci/kris/Software/py3_venv_dhcp19220/bin/python)
frame #63: /home/sci/kris/Software/py3_venv_dhcp19220/bin/python() [0x5ade61]


In [42]:
def dbg_karcher_mean(G, a, mask=None, scale_factor=1.0, filename='', device='cuda:0'):
  size = G.size()
  G = G.reshape(size[0], -1, *size[-2:])  # (T,-1,3,3)
  if mask is not None:
    mask = mask.reshape(-1)  # (-1,1)

  def get_karcher_mean_masked(G, a, scale_factor, filename, device):
      gm = G[0].to(device)

      for i in range(1, G.size(0)):
        #print('logm_invB_A, i', i, 'max gm', torch.max(gm))
        print(i)
        G_i = G[i].to(device)
        #U = logm_invB_A(gm, G_i)
        #U = logm_invB_A(make_pos_def(gm, mask.reshape(-1), 1.0e-10, skip_small_eval=True), G[i])
        #UTrless = U - torch.einsum("...ii,kl->...kl", U, torch.eye(size[-1], dtype=torch.double, device=device)) / size[-1]  # (...,2,2)

        #theta = ((torch.einsum("...ik,...ki->...", UTrless, UTrless) / a).sqrt() / 4 - np.pi)
        #tr_gm_b = torch.einsum("...kii->...k", U)
        #b = torch.einsum("...ik,...kj->...ij",gm, U)
        #bT = b - torch.einsum("...,...ij->...ij", tr_gm_b, gm) * a
        #theta = ((torch.einsum("...ik,...ki->...", bT, bT) / a).sqrt() / 4 - np.pi)

        logm_invgm_gi = fix_logm_invB_A(gm, G_i)
        tr_gm_b = torch.einsum("...kii->...k", logm_invgm_gi)
        b = torch.einsum("...ik,...kj->...ij",gm, logm_invgm_gi)
        bT = b - torch.einsum("...,...ij->...ij", tr_gm_b, gm) * a
        gminv = torch.inverse(gm)
        gmbT = torch.einsum("...ik,...kj->...ij", gminv, bT)
        theta = torch.einsum("...ik,...ki->...", gmbT, gmbT)
        #print('tr(b):',tr_gm_b)
        #print('theta(b):',theta)
    

        #for di in dbgi:
        #  if U.shape[0] > di:
        #    print('\nG[',i,',',di,'] =\n', G[i,di],'\nU[',di,'] =\n', U[di],'\nUTrless[',di,'] =\n', UTrless[di],'\ntheta[',di,'] =\n', theta[di], '\n')
        if filename:
          sitk.WriteImage(sitk.GetImageFromArray(np.transpose(theta.reshape(*size[1:-2]).cpu(),(2,1,0))), filename+'_theta.nhdr')
          U_lin = np.zeros((6,*size[1:4])).cpu()
          #U_inv = torch.inverse(U) / tens_scale
          U_lin[0] = U.reshape(*size[1:])[:,:,:,0,0].cpu()
          U_lin[1] = U.reshape(*size[1:])[:,:,:,0,1].cpu()
          U_lin[2] = U.reshape(*size[1:])[:,:,:,0,2].cpu()
          U_lin[3] = U.reshape(*size[1:])[:,:,:,1,1].cpu()
          U_lin[4] = U.reshape(*size[1:])[:,:,:,1,2].cpu()
          U_lin[5] = U.reshape(*size[1:])[:,:,:,2,2].cpu()
          WriteTensorNPArray(np.transpose(U_lin,(3,2,1,0)), filename+'_U.nhdr')

        #thresh = 0
        thresh = a * (4*np.pi)**2
  
        Ind_inRange = (theta < thresh).nonzero().reshape(-1)  ## G[i] is in the range of the exponential map at gm
        Ind_notInRange = (theta >= thresh).nonzero().reshape(-1)  ## G[i] is not in the range

        # when g1 = 0, len(Ind_notInRange) and len(Ind_inRange) are both zero. So check len(Ind_notInRange) first
        if len(Ind_notInRange) == 0:  # all in the range
            #print('Before Rie_Exp_extended, i', i, 'max gm', torch.max(gm))
            #gm = Rie_Exp_extended(gm, inv_RieExp_extended(gm, G_i, a) / (i + 1), a)
            gm = new_Rie_Exp_extended(gm, new_inv_RieExp_extended(gm, G_i, a), a, 1.0 / (i + 1))
            #print('after Rie_Exp_extended, i', i, 'max gm', torch.max(gm))
        elif len(Ind_inRange) == 0:  # all not in range
            #print('Before ptPick_notInRange, i', i, 'max gm', torch.max(gm))
            print('entering ptPick_notInRange')
            gm = ptPick_notInRange(gm, G_i, logm_invgm_gi, i)
            #print('after ptPick_notInRange, i', i, 'max gm', torch.max(gm))
        else:
            #print('Before Rie_Exp_extended, ptPick_notInRange, i', i, 'max gm', torch.max(gm))
            #gm[Ind_inRange] = Rie_Exp_extended(gm[Ind_inRange],
            #                                   inv_RieExp_extended(gm[Ind_inRange], G_i[Ind_inRange], a) / (i + 1),
            #                                   a)  # stop here
            gm[Ind_inRange] = new_Rie_Exp_extended(gm[Ind_inRange],
                                               new_inv_RieExp_extended(gm[Ind_inRange], G_i[Ind_inRange], a),
                                                                   a, 1.0 / (i + 1))  # stop here
            #print('after Rie_Exp_extended, i', i, 'max gm', torch.max(gm))
            print('entering ptPick_notInRange')
            gm[Ind_notInRange] = ptPick_notInRange(gm[Ind_notInRange], G_i[Ind_notInRange], logm_invgm_gi[Ind_notInRange], i)
#             print('end')
            #print('after ptPick_notInRange, i', i, 'max gm', torch.max(gm))
        #print('get_karcher_mean num zeros', len(torch.where(gm[:] == torch.zeros((size[-2],size[-2])))[0]))
        #print("WARNING! Don't know why need to scale atlas by scale_factor")
        gm[:] = torch.where(gm[:] == torch.zeros((size[-2],size[-2]), device=device),
                       torch.eye(size[-2], dtype=G.dtype, device=device), gm[:])

        #del G_i
        #torch.cuda.empty_cache()
      return(gm)      
  # end def get_karcher_mean_masked

  if mask is None:
    gm = get_karcher_mean_masked(G, a, scale_factor, filename, device)
  else:
    Ind_inRange = (mask > 0.1).nonzero().reshape(-1)
    gm = torch.zeros_like(G[0])
    gm[Ind_inRange] = get_karcher_mean_masked(G[:,Ind_inRange], a, scale_factor, filename, device)

  #return gm.reshape(*size[1:])
  #for di in dbgi:
  #  if gm.shape[0] > di:
  #    print('\ngm[',di,'] =\n', gm[di])

  #gm_cpu = gm.cpu()
  #del gm
  torch.cuda.empty_cache()

  #return gm_cpu.reshape(*size[1:])
  return gm.reshape(*size[1:])
  #return(torch.where(gm[:] == torch.zeros((size[-2],size[-2])),
  #                   scale_factor * torch.eye(size[-2], dtype=G.dtype), scale_factor * gm[:]).reshape(*size[1:]))


In [43]:
# problem batches
batch=246
batch=34
batch=22

print(G.reshape((20,-1,3,3))[:,batch:batch+1,:].shape)
print(mask_union.shape)
mu = mask_union.reshape((-1,1))
#print(mu[246:248].to('cpu'))
print(mask_union.reshape((-1,1))[batch:batch+1] > torch.Tensor([0.1]).to(device))
Ind_inRange = (mask_union.reshape(-1) > 0.1).nonzero().reshape(-1)
G_in = G.reshape((20,-1,3,3))[:,Ind_inRange]
m_in = mask_union.reshape(-1)[Ind_inRange]
comp_atlas = dbg_karcher_mean(G_in[0:10,batch,:].reshape((10,1,3,3)), 1./dim, 
                              mask=m_in[batch].reshape(1,1), device=device)

#g0 = torch.tensor([[[ 1.1746, -0.5826, -0.0000],
#         [-0.5826,  0.6959, -0.0000],
#         [ 0.0000,  0.0000,  1.3303]]])

#g1 = torch.tensor([[[ 1.1213, -0.5516, -0.0000],
#         [-0.5516,  0.6682, -0.0000],
#         [ 0.0000,  0.0000,  1.2674]]])

#g2 = torch.tensor([[[ 1.5406, -0.7950, -0.0000],
#         [-0.7950,  0.8874, -0.0000],
#         [ 0.0000,  0.0000,  1.7625]]])

#g3 = torch.tensor([[[ 1.2569, -0.6345, -0.0000],
#         [-0.6345,  0.7356, -0.0000],
#         [ 0.0000,  0.0000,  1.4299]]])

#g4 = torch.tensor([[[ 1.3523, -0.6979, -0.0000],
#         [-0.6979,  0.7789, -0.0000],
#         [ 0.0000,  0.0000,  1.5471]]])

#comp_atlas = dbg_karcher_mean(torch.stack([g0,g1,g2,g3,g4]).reshape((5,1,3,3)), 1./dim,
#                              mask=m_in[batch].reshape(1,1), device=device)
print(comp_atlas)

torch.Size([20, 1, 3, 3])
torch.Size([27, 60, 60])
tensor([[False]])
1
torch eig
se eig
lamda good
: tensor([[0.7981, 0.7981, 0.8564]])
lamda se
: tensor([[0.8564, 0.7981, 0.7981]])
Q good
: tensor([[[-0.9750,  0.0000,  0.2220],
         [ 0.2220,  0.0000,  0.9750],
         [ 0.0000,  1.0000,  0.0000]]])
Q se
: tensor([[[ 0.2220, -0.0000,  0.9750],
         [ 0.9750, -0.0000, -0.2220],
         [ 0.0000,  1.0000, -0.0000]]])
Using Q, lamda se
lamda good - lamda se:
 tensor([[-5.8308e-02, -1.1824e-09,  5.8308e-02]])
Q good - Q se:
 tensor([[[-1.1971,  0.0000, -0.7530],
         [-0.7530,  0.0000,  1.1971],
         [ 0.0000,  0.0000,  0.0000]]])
W:
 tensor([[[0.8010, 0.0126, 0.0000],
         [0.0126, 0.8535, 0.0000],
         [0.0000, 0.0000, 0.7981]]])
Q lamda Q_T good:
 tensor([[[0.8010, 0.0126, 0.0000],
         [0.0126, 0.8535, 0.0000],
         [0.0000, 0.0000, 0.7981]]])
Q lamda Q_T se:
 tensor([[[0.8010, 0.0126, 0.0000],
         [0.0126, 0.8535, 0.0000],
         [0.0000, 0.00

In [29]:
%%time
shuffle_atlas = get_karcher_mean_shuffle(G, 1./dim, mask=mask_union, device=device)
atlas_lin = np.zeros((6,height,width,depth))
atlas_inv = inverse_masked(shuffle_atlas, mask_union) / tens_scale
atlas_lin[0] = atlas_inv[:,:,:,0,0].cpu()
atlas_lin[1] = atlas_inv[:,:,:,0,1].cpu()
atlas_lin[2] = atlas_inv[:,:,:,0,2].cpu()
atlas_lin[3] = atlas_inv[:,:,:,1,1].cpu()
atlas_lin[4] = atlas_inv[:,:,:,1,2].cpu()
atlas_lin[5] = atlas_inv[:,:,:,2,2].cpu()
shuffle_atlas_np = np.transpose(atlas_lin,(3,2,1,0))

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/sci/kris/Software/py3_venv_dhcp19220/lib/python3.9/site-packages/IPython/core/magics/execution.py", line 1316, in time
    exec(code, glob, local_ns)
  File "<timed exec>", line 1, in <module>
  File "/home/sci/kris/Software/DiffeomorphicMetricMatching/util/SplitEbinMetric3DCuda.py", line 1515, in get_karcher_mean_shuffle
  File "/home/sci/kris/Software/DiffeomorphicMetricMatching/util/SplitEbinMetric3DCuda.py", line 1469, in get_karcher_mean_shuffle_masked
    tr_gm_b = torch.einsum("...kii->...k", logm_invgm_gi)
  File "/home/sci/kris/Software/DiffeomorphicMetricMatching/util/SplitEbinMetric3DCuda.py", line 197, in logm_invB_A
    inv_V = torch.inverse(V)
RuntimeError: inverse_cpu: For batch 11: U(2,2) is zero, singular U.
Exception raised from batchCheckErrors at /pytorch/aten/src/ATen/native/LinearAlgebraUtils.h:123 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f84e88542f2 in /home

In [20]:
%%time
comp_scipy_atlas = get_karcher_mean_scipy(G, 1./dim, mask=mask_union, device=device)
atlas_lin = np.zeros((6,height,width,depth))

atlas_inv = inverse_masked(comp_scipy_atlas, mask_union) / tens_scale
atlas_lin[0] = atlas_inv[:,:,:,0,0].cpu()
atlas_lin[1] = atlas_inv[:,:,:,0,1].cpu()
atlas_lin[2] = atlas_inv[:,:,:,0,2].cpu()
atlas_lin[3] = atlas_inv[:,:,:,1,1].cpu()
atlas_lin[4] = atlas_inv[:,:,:,1,2].cpu()
atlas_lin[5] = atlas_inv[:,:,:,2,2].cpu()
comp_scipy_atlas_np = np.transpose(atlas_lin,(3,2,1,0))

shuffle_scipy_atlas = get_karcher_mean_shuffle_scipy(G, 1./dim, mask=mask_union, device=device)
atlas_lin = np.zeros((6,height,width,depth))
atlas_inv = inverse_masked(shuffle_scipy_atlas, mask_union) / tens_scale
atlas_lin[0] = atlas_inv[:,:,:,0,0].cpu()
atlas_lin[1] = atlas_inv[:,:,:,0,1].cpu()
atlas_lin[2] = atlas_inv[:,:,:,0,2].cpu()
atlas_lin[3] = atlas_inv[:,:,:,1,1].cpu()
atlas_lin[4] = atlas_inv[:,:,:,1,2].cpu()
atlas_lin[5] = atlas_inv[:,:,:,2,2].cpu()
shuffle_scipy_atlas_np = np.transpose(atlas_lin,(3,2,1,0))

CPU times: user 7min 1s, sys: 3min 29s, total: 10min 30s
Wall time: 3min 30s


In [53]:
print(comp_atlas_np[22,60-37-1,13])
print(shuffle_atlas_np[16,60-37-1,13])

[-6.0669059   5.36089705  0.33815651  7.4538336   0.46740723  1.22329408]
[ 1.52246037e+02  8.27568923e+01 -2.15080423e-14  1.98011960e+02
  1.63679672e-14  5.25469516e+01]


In [63]:
# Good pixel
print(G[:,13,60-34-1,22])
print('mean\n',get_karcher_mean(G[:,13,60-34-1,22], 1./dim, device=device))
print('scipy mean\n',get_karcher_mean_scipy(G[:,13,60-34-1,22], 1./dim, device=device))


tensor([[[ 0.6795, -0.4731,  0.0000],
         [-0.4731,  1.1383,  0.0000],
         [ 0.0000,  0.0000,  1.9738]],

        [[ 0.6560, -0.4318,  0.0000],
         [-0.4318,  1.0748,  0.0000],
         [ 0.0000,  0.0000,  1.8510]],

        [[ 0.6479, -0.4184,  0.0000],
         [-0.4184,  1.0538,  0.0000],
         [ 0.0000,  0.0000,  1.8105]],

        [[ 0.7148, -0.5170,  0.0000],
         [-0.5170,  1.2162,  0.0000],
         [ 0.0000,  0.0000,  2.1189]],

        [[ 0.6454, -0.4428,  0.0000],
         [-0.4428,  1.0749,  0.0000],
         [ 0.0000,  0.0000,  1.8606]],

        [[ 0.6289, -0.4550,  0.0000],
         [-0.4550,  1.0702,  0.0000],
         [ 0.0000,  0.0000,  1.8646]],

        [[ 0.6844, -0.4452,  0.0000],
         [-0.4452,  1.1161,  0.0000],
         [ 0.0000,  0.0000,  1.9193]],

        [[ 0.6944, -0.5103,  0.0000],
         [-0.5103,  1.1893,  0.0000],
         [ 0.0000,  0.0000,  2.0761]],

        [[ 1.2585, -1.3817,  0.0000],
         [-1.3817,  2.5985,  0.000

In [18]:
# Error pixel :2, :3 good, :4, :5 ok, :6, :7, :8, :10 bad, 
print(G[:10,13,60-37-1,22])
print('mean\n',get_karcher_mean(G[:10,13,60-37-1,22], 1./dim, device=device))
print('scipy mean\n',get_karcher_mean_scipy(G[:10,13,60-37-1,22], 1./dim, device=device))
for ii in range(10):
  print('shuffle', ii)
  print('mean\n',get_karcher_mean_shuffle(G[:10,13,60-37-1,22], 1./dim, device=device))
  print('scipy mean\n',get_karcher_mean_shuffle_scipy(G[:10,13,60-37-1,22], 1./dim, device=device))


tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]],

        [[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]],

        [[ 1.3123, -1.0994,  0.0000],
         [-1.0994,  1.7898,  0.0000],
         [ 0.0000,  0.0000,  1.9822]],

        [[ 1.5353, -1.3539,  0.0000],
         [-1.3539,  2.1233,  0.0000],
         [ 0.0000,  0.0000,  2.3812]],

        [[ 1.3524, -1.1644,  0.0000],
         [-1.1644,  1.8581,  0.0000],
         [ 0.0000,  0.0000,  2.0716]],

        [[ 1.3527, -1.1931,  0.0000],
         [-1.1931,  1.8709,  0.0000],
         [ 0.0000,  0.0000,  2.0983]],

        [[ 1.3923, -1.1702,  0.0000],
         [-1.1702,  1.9005,  0.0000],
         [ 0.0000,  0.0000,  2.1065]],

        [[ 1.4861, -1.3199,  0.0000],
         [-1.3199,  2.0594,  0.0000],
         [ 0.0000,  0.0000,  2.3136]],

        [[ 1.3383, -1.1266,  0.0000],
         [-1.1266,  1.8276,  0.000

In [22]:
# Error pixel :2, :3 good, :4, :5 ok, :6, :7, :8, :10 bad, 
print(G[:10,13,60-37-1,22])
print('mean\n',get_karcher_mean(G[:10,13,60-37-1,22], 1./dim, device=device))
print('scipy mean\n',get_karcher_mean_scipy(G[:10,13,60-37-1,22], 1./dim, device=device))
for ii in range(10):
  print('shuffle', ii)
  print('mean\n',get_karcher_mean_shuffle(G[:10,13,60-37-1,22], 1./dim, device=device))
  print('scipy mean\n',get_karcher_mean_shuffle_scipy(G[:10,13,60-37-1,22], 1./dim, device=device))


tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]],

        [[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]],

        [[ 1.3123, -1.0994,  0.0000],
         [-1.0994,  1.7898,  0.0000],
         [ 0.0000,  0.0000,  1.9822]],

        [[ 1.5353, -1.3539,  0.0000],
         [-1.3539,  2.1233,  0.0000],
         [ 0.0000,  0.0000,  2.3812]],

        [[ 1.3524, -1.1644,  0.0000],
         [-1.1644,  1.8581,  0.0000],
         [ 0.0000,  0.0000,  2.0716]],

        [[ 1.3527, -1.1931,  0.0000],
         [-1.1931,  1.8709,  0.0000],
         [ 0.0000,  0.0000,  2.0983]],

        [[ 1.3923, -1.1702,  0.0000],
         [-1.1702,  1.9005,  0.0000],
         [ 0.0000,  0.0000,  2.1065]],

        [[ 1.4861, -1.3199,  0.0000],
         [-1.3199,  2.0594,  0.0000],
         [ 0.0000,  0.0000,  2.3136]],

        [[ 1.3383, -1.1266,  0.0000],
         [-1.1266,  1.8276,  0.000

In [17]:
print('logm_invB_A [0,1]\n',logm_invB_A(G[0,0,60-0-1,0].reshape((1,3,3)), G[1,0,60-0-1,0].reshape((1,3,3))))
print('scipy_logm_invB_A [0,1]\n',scipy_logm_invB_A(G[0,0,60-0-1,0].reshape((1,3,3)), G[1,0,60-0-1,0].reshape((1,3,3))))


logm_invB_A [0,1]
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
scipy_logm_invB_A [0,1]
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])


In [73]:
print(G[:2,13,60-37-1,22])
print('logm_invB_A [0,1]\n',scipy_logm_invB_A(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3))))
print('logm_invB_A [1,0]\n',scipy_logm_invB_A(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3))))
K01 = scipy_logm_invB_A(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)))
#K01g0inv = torch.einsum("...ik,...kj->...ij", K01, torch.inverse(G[0,13,60-37-1,22].reshape((1,3,3))))
K01g0inv = torch.einsum("...ij,...jk->...ik", torch.inverse(G[0,13,60-37-1,22].reshape((1,3,3))), K01)
trK01g0inv = torch.einsum("...ii->", K01g0inv)
#trK01g0inv = torch.einsum("...ij,...il,...lm,...jm->", invG0, G[0,13,60-37-1,22].reshape((1,3,3)), invG0, K01)
#expTrK01g0inv = torch.exp(torch.einsum("...ii->", K01g0inv) / 4.0)
expTrK01g0inv = torch.exp(trK01g0inv / 4.0)
#KTrless01 = K01 - torch.einsum("...ii,kl->...kl", K01, torch.eye(dim, dtype=torch.double, device=G.device)) / dim  # (-1,3,3)
#KTrless01 = K01 - torch.einsum("...ii,kl->...kl", K01g0inv, torch.eye(dim, dtype=torch.double, device=G.device)) / dim  # (-1,3,3)
KTrless01 = K01 - trK01g0inv * torch.eye(dim, dtype=torch.double, device=G.device) / dim  # (-1,3,3)
theta01 = ((dim * torch.einsum("...ik,...ki->...", KTrless01, KTrless01)).clamp(min=1.0e-15).sqrt() / 4.).clamp(min=1.0e-15)  # (-1)
gamma01 = torch.det(G[1,13,60-37-1,22].reshape((1,3,3))).pow(1 / 4.) / (torch.det(G[0,13,60-37-1,22].reshape((1,3,3))).clamp(min=1.0e-15).pow(1 / 4.))  # (-1)
#A01 = 4 / dim * (gamma01 * torch.cos(theta01) - 1)  # (-1)
A01 = 4. / dim * (expTrK01g0inv * torch.cos(theta01) - 1)  # (-1)
#B01 = 1 / theta01 * gamma01 * torch.sin(theta01)
B01 = 1. / theta01 * expTrK01g0inv * torch.sin(theta01)
u01A = A01 * torch.einsum("...ij->ij...", G[0,13,60-37-1,22].reshape((1,3,3)))
#u01B = B01 * torch.einsum("...ik,...kj->ij...", G[0,13,60-37-1,22].reshape((1,3,3)), KTrless01) 
#u01A = A01 * G[0,13,60-37-1,22].reshape((1,3,3))
u01B = B01 * torch.einsum("...ik,...kl,...lj->ij...", G[0,13,60-37-1,22].reshape((1,3,3)), KTrless01, G[0,13,60-37-1,22].reshape((1,3,3))) 
u01 = u01A + u01B
#u01 = u01A + torch.einsum("...ij->ij...", u01B)
K10 = scipy_logm_invB_A(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)))
#K10g0inv = torch.einsum("...ik,...kj->...ij", K10, torch.inverse(G[1,13,60-37-1,22].reshape((1,3,3))))
K10g0inv = torch.einsum("...ik,...kj->...ij", torch.inverse(G[1,13,60-37-1,22].reshape((1,3,3))), K10)
trK10g0inv = torch.einsum("...ii->", K10g0inv)
expTrK10g0inv = torch.exp(trK10g0inv / 4.0)
#KTrless10 = K10 - torch.einsum("...ii,kl->...kl", K10, torch.eye(dim, dtype=torch.double, device=G.device)) / dim  # (-1,3,3)
#KTrless10 = K10 - torch.einsum("...ii,kl->...kl", K10g0inv, torch.eye(dim, dtype=torch.double, device=G.device)) / dim  # (-1,3,3)
KTrless10 = K10 - trK10g0inv * torch.eye(dim, dtype=torch.double, device=G.device) / dim  # (-1,3,3)
theta10 = ((dim * torch.einsum("...ik,...ki->...", KTrless10, KTrless10)).clamp(min=1.0e-15).sqrt() / 4.).clamp(min=1.0e-15)  # (-1)
gamma10 = torch.det(G[0,13,60-37-1,22].reshape((1,3,3))).pow(1 / 4.) / (torch.det(G[1,13,60-37-1,22].reshape((1,3,3))).clamp(min=1.0e-15).pow(1 / 4.))  # (-1)
#A10 = 4 / dim * (gamma10 * torch.cos(theta10) - 1)  # (-1)
A10 = 4. / dim * (expTrK10g0inv * torch.cos(theta10) - 1)  # (-1)
#B10 = 1 / theta10 * gamma10 * torch.sin(theta10)
B10 = 1. / theta10 * expTrK10g0inv * torch.sin(theta10)
u10A = A10 * torch.einsum("...ij->ij...", G[1,13,60-37-1,22].reshape((1,3,3)))
u10B = B10 * torch.einsum("...ik,...kl,...lj->ij...", G[1,13,60-37-1,22].reshape((1,3,3)), KTrless10, G[1,13,60-37-1,22].reshape((1,3,3))) 
#u10B = B10 * G[1,13,60-37-1,22].reshape((1,3,3))*KTrless10 
u10 = u10A + u10B
#u10 = u10A + torch.einsum("...ij->ij...",u10B)
print('K[0,1]\n',K01)
print('K[1,0]\n',K10)
print('KTrless[0,1]\n',KTrless01)
print('KTrless[1,0]\n',KTrless10)
print('theta[0,1]\n',theta01)
print('theta[1,0]\n',theta10)
print('gamma[0,1]\n',gamma01)
print('gamma[1,0]\n',gamma10)
print('tr(Kg0inv)[0,1]\n',torch.einsum("...ii->",K01g0inv))
print('tr(Kg0inv)[1,0]\n',torch.einsum("...ii->",K01g0inv))
print('trKg0inv[0,1]\n',trK01g0inv)
print('trKg0inv[1,0]\n',trK01g0inv)
print('A[0,1]\n',A01)
print('A[1,0]\n',A10)
print('B[0,1]\n',B01)
print('B[1,0]\n',B10)
print('uA[0,1]\n',u01A)
print('uA[1,0]\n',u10A)
print('uB[0,1]\n',u01B)
print('uB[1,0]\n',u10B)
print('u[0,1]\n',u01)
print('u[1,0]\n',u10)
print('g0k0g0[0,1]\n',G[0,13,60-37-1,22]@KTrless01@G[0,13,60-37-1,22])
print('g0k0g0[1,0]\n',G[1,13,60-37-1,22]@KTrless10@G[1,13,60-37-1,22])

print('inv_RieExp[0,1]\n',inv_RieExp_extended_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim))
print('inv_RieExp[1,0]\n',inv_RieExp_extended_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim))
print('mean\n',get_karcher_mean_scipy(G[:2,13,60-37-1,22], 1./dim, device=device))
print('mean\n',get_karcher_mean_scipy(torch.stack([G[1,13,60-37-1,22],G[0,13,60-37-1,22]]) , 1./dim, device=device))


tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]],

        [[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
logm_invB_A [0,1]
 tensor([[[-0.0673,  0.0268,  0.0000],
         [ 0.0268, -0.0232,  0.0000],
         [ 0.0000,  0.0000, -0.0800]]])
logm_invB_A [1,0]
 tensor([[[ 0.0666, -0.0274,  0.0000],
         [-0.0274,  0.0240,  0.0000],
         [ 0.0000,  0.0000,  0.0800]]])
K[0,1]
 tensor([[[-0.0673,  0.0268,  0.0000],
         [ 0.0268, -0.0232,  0.0000],
         [ 0.0000,  0.0000, -0.0800]]])
K[1,0]
 tensor([[[ 0.0666, -0.0274,  0.0000],
         [-0.0274,  0.0240,  0.0000],
         [ 0.0000,  0.0000,  0.0800]]])
KTrless[0,1]
 tensor([[[-0.0294,  0.0268,  0.0000],
         [ 0.0268,  0.0146,  0.0000],
         [ 0.0000,  0.0000, -0.0421]]])
KTrless[1,0]
 tensor([[[ 0.0274, -0.0274,  0.0000],
         [-0.0274, -0.0152,  0.0000],
         [ 0.0000,  0.0000,  0

In [49]:
print(torch.einsum("...ii,kl->...kl", K10g0inv, torch.eye(dim, dtype=torch.double, device=G.device)))
print(torch.einsum("...ii,kl->...kl", K10g0inv, torch.eye(dim, dtype=torch.double, device=G.device))/3)


tensor([[[0.1176, 0.0000, 0.0000],
         [0.0000, 0.1176, 0.0000],
         [0.0000, 0.0000, 0.1176]]])
tensor([[[0.0392, 0.0000, 0.0000],
         [0.0000, 0.0392, 0.0000],
         [0.0000, 0.0000, 0.0392]]])


In [17]:
print(G[:2,13,60-37-1,22])
print('logm_invB_A [0,1]\n',scipy_logm_invB_A(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3))))
print('logm_invB_A [1,0]\n',scipy_logm_invB_A(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3))))
#K01 = scipy_logm_invB_A(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)))
#KTrless01 = K01 - torch.einsum("...ii,kl->...kl", K01, torch.eye(dim, dtype=torch.double, device=G.device)) / dim  # (-1,3,3)
#theta01 = ((dim * torch.einsum("...ik,...ki->...", KTrless01, KTrless01)).clamp(min=1.0e-15).sqrt() / 4).clamp(min=1.0e-15)  # (-1)
#gamma01 = torch.det(G[1,13,60-37-1,22].reshape((1,3,3))).pow(1 / 4) / (torch.det(G[0,13,60-37-1,22].reshape((1,3,3))).clamp(min=1.0e-15).pow(1 / 4))  # (-1)
#A01 = 4 / dim * (gamma01 * torch.cos(theta01) - 1)  # (-1)
#B01 = 1 / theta01 * gamma01 * torch.sin(theta01)
#u01A = A01 * torch.einsum("...ij->ij...", G[0,13,60-37-1,22].reshape((1,3,3)))
#u01B = B01 * torch.einsum("...ij,...ij->ij...", G[0,13,60-37-1,22].reshape((1,3,3)), KTrless01) 
##u01B = B01 * G[0,13,60-37-1,22].reshape((1,3,3))*KTrless01 
#u01 = u01A + u01B
##u01 = u01A + torch.einsum("...ij->ij...", u01B)
#K10 = scipy_logm_invB_A(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)))
#KTrless10 = K10 - torch.einsum("...ii,kl->...kl", K10, torch.eye(dim, dtype=torch.double, device=G.device)) / dim  # (-1,3,3)
#theta10 = ((dim * torch.einsum("...ik,...ki->...", KTrless10, KTrless10)).clamp(min=1.0e-15).sqrt() / 4).clamp(min=1.0e-15)  # (-1)
#gamma10 = torch.det(G[0,13,60-37-1,22].reshape((1,3,3))).pow(1 / 4) / (torch.det(G[1,13,60-37-1,22].reshape((1,3,3))).clamp(min=1.0e-15).pow(1 / 4))  # (-1)
#A10 = 4 / dim * (gamma10 * torch.cos(theta10) - 1)  # (-1)
#B10 = 1 / theta10 * gamma10 * torch.sin(theta10)
#u10A = A10 * torch.einsum("...ij->ij...", G[1,13,60-37-1,22].reshape((1,3,3)))
#u10B = B10 * torch.einsum("...ij,...ij->ij...", G[1,13,60-37-1,22].reshape((1,3,3)), KTrless10) 
##u10B = B10 * G[1,13,60-37-1,22].reshape((1,3,3))*KTrless10 
#u10 = u10A + u10B
##u10 = u10A + torch.einsum("...ij->ij...",u10B)
#print('K[0,1]\n',K01)
#print('K[1,0]\n',K10)
#print('KTrless[0,1]\n',KTrless01)
#print('KTrless[1,0]\n',KTrless10)
#print('theta[0,1]\n',theta01)
#print('theta[1,0]\n',theta10)
#print('gamma[0,1]\n',gamma01)
#print('gamma[1,0]\n',gamma10)
#print('A[0,1]\n',A01)
#print('A[1,0]\n',A10)
#print('B[0,1]\n',B01)
#print('B[1,0]\n',B10)
#print('uA[0,1]\n',u01A)
#print('uA[1,0]\n',u10A)
#print('uB[0,1]\n',u01B)
#print('uB[1,0]\n',u10B)
#print('u[0,1]\n',u01)
#print('u[1,0]\n',u10)

scale = 4.0

invRE01 = inv_RieExp_extended_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE10 = inv_RieExp_extended_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,1]\n',invRE01)
print('inv_RieExp[1,0]\n',invRE10)

RE_invRE01 = Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), invRE01 / scale, 1./dim)
RE_invRE10 = Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), invRE10, 1./dim)
print('RieExp[0,1]\n',RE_invRE01)
print('RieExp[1,0]\n',RE_invRE10)

g0inv = torch.inverse(G[0,13,60-37-1,22].reshape((1,3,3)))
U01 = torch.einsum("...ik,...kj->...ij", g0inv, invRE01 / scale)  # (s,t,...,3,3)
Ug0inv01 = torch.einsum("...ik,...kj->...ij", U01, g0inv)  # (s,t,...,3,3)
trU01 = torch.einsum("...ii->...", Ug0inv01)  # (s,t,...)
UTrless01 = U01 - torch.einsum("...,ij->...ij", trU01, torch.eye(dim, dim, dtype=torch.double, device=G.device)) / dim  # (s,t,...,3,3)
q01 = trU01 / 4. + 1  # (-1)
r01 = (dim * torch.einsum("...ik,...ki->...", UTrless01, UTrless01)).clamp(min=1.0e-15).sqrt() / 4  # (-1)
ArctanUtrless01 = (torch.atan2(r01, q01) * torch.einsum("...ij->ij...", UTrless01) / r01.clamp(min=1.0e-15))  # use (2,2,-1) for computation

ExpArctanUtrless01 = torch.nan_to_num(torch.matrix_exp(ArctanUtrless01.permute(2, 0, 1)).permute(1, 2, 0))
RE01 = (q01 ** 2 + r01 ** 2).pow(2 / dim) * torch.einsum("...ik,kj...->ij...", G[0,13,60-37-1,22].reshape((1,3,3)), 
                                                         ExpArctanUtrless01)
#RE01 = (q01 ** 2 + r01 ** 2).pow(2 / dim) * torch.einsum("...ik,kl...,...lj->ij...", G[0,13,60-37-1,22].reshape((1,3,3)), 
#                                                         ExpArctanUtrless01, G[0,13,60-37-1,22].reshape((1,3,3)))

print('U[0,1]\n',U01)
print('trU[0,1]\n',trU01)
print('UTrless[0,1]\n',UTrless01)
print('q[0,1]\n',q01)
print('r[0,1]\n',r01)
print('AtanUtrless[0,1]\n',ArctanUtrless01)
print('ExpAtanUtrless[0,1]\n',ExpArctanUtrless01)
print('RieExp[0,1]\n',RE01)


print('mean\n',get_karcher_mean_scipy(G[:2,13,60-37-1,22], 1./dim, device=device))
print('mean\n',get_karcher_mean_scipy(torch.stack([G[1,13,60-37-1,22],G[0,13,60-37-1,22]]) , 1./dim, device=device))


tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]],

        [[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
logm_invB_A [0,1]
 tensor([[[-0.0673,  0.0268,  0.0000],
         [ 0.0268, -0.0232,  0.0000],
         [ 0.0000,  0.0000, -0.0800]]])
logm_invB_A [1,0]
 tensor([[[ 0.0666, -0.0274,  0.0000],
         [-0.0274,  0.0240,  0.0000],
         [ 0.0000,  0.0000,  0.0800]]])
inv_RieExp[0,1]
 tensor([[[-0.0904,  0.0954,  0.0000],
         [ 0.0954, -0.0785,  0.0000],
         [ 0.0000,  0.0000, -0.1458]]])
inv_RieExp[1,0]
 tensor([[[ 0.0875, -0.0940,  0.0000],
         [-0.0940,  0.0715,  0.0000],
         [ 0.0000,  0.0000,  0.1415]]])
RieExp[0,1]
 tensor([[[ 1.4123, -1.2199,  0.0000],
         [-1.2199,  1.9555,  0.0000],
         [ 0.0000,  0.0000,  2.1692]]])
RieExp[1,0]
 tensor([[[ 1.4319, -1.2299,  0.0000],
         [-1.2299,  1.9096,  0.0000],
         [ 0.0

### Confirm identities: 
  - RieExp(g0, 0) = g0
  - Inv_RieExp(g0, g0) = 0  


In [16]:
def dbg_inv_RieExp_extended_scipy(g0, g1, a):  # g0, g1: (s,t,...,3,3)
    size = g0.size()
    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        u = g1 * 4 / g0.size(-1)
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        u = dbg_inv_RieExp_scipy(g0, g1, a)
    else:
        u = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        u[Ind_g0_is0] = g1[Ind_g0_is0] * 4 / g0.size(-1)
        u[Ind_g0_isnot0] = dbg_inv_RieExp_scipy(g0[Ind_g0_isnot0], g1[Ind_g0_isnot0], a)
    return u.reshape(size)

def dbg_inv_RieExp_scipy(g0, g1, a):  # g0,g1: two tensors of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the inverse Riemannian exponential of g1 in the image of the maximal domain of the Riemannian exponential at g0
    '''
    n = g1.size(-1)
    #inv_g0_g1 = make_pos_def(torch.einsum("...ik,...kj->...ij", torch.inverse(g0), g1),None, 1.0e-10)  # (s,t,...,3,3)
    logm_invg0_g1 = scipy_logm_invB_A(g0, g1)
    tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
    b = torch.einsum("...ik,...kj->...ij",g0, logm_invg0_g1)
    bT = b - torch.einsum("...,...ij->...ij", tr_g0_b, g0) * a
    
    def get_u_g0direction(g0, logm_invg0_g1):  # (-1,3,3) first reshape g0,g1,inv_g..
        #         permute
        #inv_g0_g1 = torch.einsum("...ij->ij...", inv_g0_g1)  # (3,3,-1)
        #s = inv_g0_g1[0, 0].clamp(min=1.0e-15)  # (-1)
        #u = 4 / n * (s ** (n / 4) - 1) * torch.einsum("...ij->ij...", g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
        u = 4 / n * (torch.exp(tr_g0_b / 4.0) - 1) * g0
        
        #return u.permute(2, 0, 1)  # (-1,3,3)
        return u  # (-1,3,3)

    def get_u_ng0direction(g0, g1, logm_invg0_g1, bT, a):  # (-1,3,3) first reshape g0,g1,inv_g..
        det_threshold=1e-11
        where_below = torch.where(torch.det(g0)<=det_threshold)
        num_below = len(where_below[0])
        if num_below > 0:
          print('inv_RieExp num det(g0) below thresh:', num_below)
        #K = scipy_logm_invB_A(g0, g1)
        tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
        expTrg0invb = torch.exp(tr_g0_b / 4.0)
        #         AA^T
        g0inv = torch.inverse(g0)
        g0bT = torch.einsum("...ik,...kj->...ij", g0inv, bT)
        theta = ((1. / a * torch.einsum("...ik,...ki->...", g0bT, g0bT)).clamp(min=1.0e-15).sqrt() / 4.).clamp(min=1.0e-15)  # (-1)

        A = 4. / n * (expTrg0invb * torch.cos(theta) - 1)  # (-1)
        B = 1. / theta * expTrg0invb * torch.sin(theta)
        # Clarke
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ij->ij...", bT)  # (-1)@(3,3,-1) -> (3,3,-1)
        # Kris
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", bT, g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        #u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kl,...lj->ij...", g0inv, bT, g0)  # (-1)@(3,3,-1) -> (3,3,-1)
        # Kris trying to use conversion between GMM and Clarke, BT = g0H0
        u = A * torch.einsum("...ij->ij...", g0) + B * torch.einsum("...ik,...kj->ij...", g0inv, bT)  # (-1)@(3,3,-1) -> (3,3,-1)

        return u.permute(2, 0, 1)  # (-1,3,3)

    #inv_g0_g1_trless = inv_g0_g1 - torch.einsum("...ii,kl->...kl", inv_g0_g1, torch.eye(n, dtype=torch.double, device=g0.device))  # (s,t,...,2,2)
    #norm0 = torch.einsum("...ij,...ij->...", inv_g0_g1_trless, inv_g0_g1_trless).reshape(-1)  # (-1)
    norm0 = torch.einsum("...ij,...ij->...", bT, bT).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)  # using squeeze results in [1,1]->[]
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    u = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double, device=g0.device)  # (-1,3,3)
    if len(Indn0) == 0:
        u = get_u_g0direction(g0.reshape(-1, n, n), logm_invg0_g1.reshape(-1, n, n))
    elif len(Ind0) == 0:
        u = get_u_ng0direction(g0.reshape(-1, n, n), g1.reshape(-1, n, n), 
                               logm_invg0_g1.reshape(-1, n, n), bT.reshape(-1, n, n), a)
    else:
        u[Ind0] = get_u_g0direction(g0.reshape(-1, n, n)[Ind0], logm_invg0_g1.reshape(-1, n, n)[Ind0])
        u[Indn0] = get_u_ng0direction(g0.reshape(-1, n, n)[Indn0], g1.reshape(-1, n, n)[Indn0], 
                                      logm_invg0_g1.reshape(-1, n, n)[Indn0], bT.reshape(-1, n, n)[Indn0], a)

    return u.reshape(g1.size())

def dbg_Rie_Exp(g0, u, a, t=1.0):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3), where g0\neq 0
    '''this function is to calculate the Riemannian exponential of u in the the maximal domain of the Riemannian exponential at g0
    '''
    n = g0.size(-1)
    g0inv = torch.inverse(g0)
    U = torch.einsum("...ik,...kj->...ij", g0inv, u)  # (s,t,...,3,3)
    #Ug0inv = torch.einsum("...ik,...kj->...ij", U, g0inv)
    trU = torch.einsum("...ii->...", U)  # (s,t,...)
    # GMM UTrless = H0
    #UTrless = U - torch.einsum("...,ij->...ij", trU, torch.eye(n, n, dtype=torch.double, device=g0.device)) / n  # (s,t,...,3,3)
    # Clarke version of UTrless = BT   
    UTrless = u - torch.einsum("...,...ij->...ij", trU, g0) / n  # (s,t,...,3,3)

    #     in g0 direction:K_0=0
    def get_g1_g0direction(g0, trU, t):  # first reshape g0 (-1,3,3) and trU (-1)
        g1 = (t * trU / 4. + 1).pow(4. / n) * torch.einsum("...ij->ij...", g0)  # (3,3,-1)
        return g1.permute(2, 0, 1)  # (-1,3,3)

    #     not in g0 direction SplitEbinMetric.pdf Theorem 1 :K_0\not=0
    def get_g1_ng0direction(g0, trU, UTrless, a, t):  # first reshape g0,UTrless (-1,3,3) and trU (-1)
        if len((trU < -4).nonzero().reshape(-1)) != 0:
            warnings.warn('The tangent vector u is out of the maximal domain of the Riemannian exponential.', DeprecationWarning)

        # GMM and Clarke q match    
        q = t * trU / 4. + 1  # (-1)
        g0UTrless = torch.einsum("...ik,...kj->...ij", torch.inverse(g0), UTrless)
        # GMM r
        #r = t * (1. / a * torch.einsum("...ik,...ki->...", UTrless, UTrless)).clamp(min=1.0e-15).sqrt() / 4.  # (-1)
        # Clarke r
        r = t * (1. / a * torch.einsum("...ik,...ki->...", g0UTrless, g0UTrless)).clamp(min=1.0e-15).sqrt() / 4.  # (-1)
        
        #ArctanUtrless = (torch.atan2(r, q) * torch.einsum("...ij->ij...", UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        # GMM Arctan
        #ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ij->ij...", UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        # Clarke Arctan
        ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ij->ij...", g0UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        # Kris Arctan
        #ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ik,...kj->ij...", g0UTrless, g0) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation
        #ArctanUtrless = t * (torch.atan2(r, q) * torch.einsum("...ik,...kj->ij...", g0inv, g0UTrless) / r.clamp(min=1.0e-15))  # use (2,2,-1) for computation

        ExpArctanUtrless = torch.nan_to_num(torch.matrix_exp(ArctanUtrless.permute(2, 0, 1)).permute(1, 2, 0))
        ExpArctanUtrless[torch.abs(ExpArctanUtrless) > 1e12] = 0

        # GMM, Clarke At
        At = 2./n * torch.einsum("...,ij->ij...", torch.log(q**2 + r**2), torch.eye(n, n, dtype=torch.double, device=g0.device))
        # Kris At
        #At = 2./n * torch.log(q**2 + r**2) * torch.einsum("...ij->ij...", g0)
        ExpAt = torch.nan_to_num(torch.matrix_exp(At.permute(2, 0, 1)).permute(1, 2, 0))
        ExpAt[torch.abs(ExpAt) > 1e12] = 0
        # GMM, Clarke g1
        g1 = (q ** 2 + r ** 2).pow(2. / n) * torch.einsum("...ik,kj...->ij...", g0, ExpArctanUtrless)  # (2,2,-1)
        #g1 = torch.einsum("...ik,kl...,lj...->ij...", g0, ExpAt, ExpArctanUtrless)  # (2,2,-1)
        # Kris g1
        #g1 = torch.einsum("...ik,kl...,lm...,...mj->ij...", g0, ExpAt, ExpArctanUtrless,g0)  # (2,2,-1)
        #g1 = torch.einsum("kl...,lm...->km...", ExpAt, ExpArctanUtrless)  # (2,2,-1)

        return g1.permute(2, 0, 1)  # (-1,2,2)

    #     pointwise multiplication Tr(U^TU)
    #UMinusTrU = U - torch.einsum("...,ij->...ij", trU, torch.eye(n, n, dtype=torch.double, device=g0.device))  # (s,t,...,3,3)

    #norm0 = torch.einsum("...ij,...ij->...", UMinusTrU, UMinusTrU).reshape(-1)  # (-1)
    norm0 = torch.einsum("...ij,...ij->...", UTrless, UTrless).reshape(-1)  # (-1)

    # find the indices for which the entries are 0s and non0s
    #     k_0=0 or \not=0
    Ind0 = (norm0 <= 1e-12).nonzero().reshape(-1)
    Indn0 = (norm0 > 1e-12).nonzero().reshape(-1)

    g1 = torch.zeros(g0.reshape(-1, n, n).size(), dtype=torch.double, device=g0.device)  # (-1,2,2)
    if len(Indn0) == 0:
        g1 = get_g1_g0direction(g0.reshape(-1, n, n), trU.reshape(-1), t)
    elif len(Ind0) == 0:
        g1 = get_g1_ng0direction(g0.reshape(-1, n, n), trU.reshape(-1), UTrless.reshape(-1, n, n), a, t)
    else:
        g1[Ind0] = get_g1_g0direction(g0.reshape(-1, n, n)[Ind0], trU.reshape(-1)[Ind0], t)
        g1[Indn0] = get_g1_ng0direction(g0.reshape(-1, n, n)[Indn0], trU.reshape(-1)[Indn0], UTrless.reshape(-1, n, n)[Indn0], a, t)
    return g1.reshape(g0.size())



''' 
The following Riemannian exponential and inverse Riemannian exponential are extended to the case g0=0 
'''
def dbg_Rie_Exp_extended(g0, u, a, t=1):  # here g0 is of size (s,t,...,3,3) and u is of size (s,t,...,3,3)
    size = g0.size()
    g0, u = g0.reshape(-1, *size[-2:]), u.reshape(-1, *size[-2:])  # (-1,3,3)
    detg0 = torch.det(g0)

    Ind_g0_is0 = (detg0 == 0).nonzero().reshape(-1)
    Ind_g0_isnot0 = (detg0 != 0).nonzero().reshape(-1)

    if len(Ind_g0_isnot0) == 0:  # g0x are 0s for all x
        g1 = u * g0.size(-1) / 4
    elif len(Ind_g0_is0) == 0:  # g0x are PD for all x
        g1 = dbg_Rie_Exp(g0, u, a, t)
    else:
        g1 = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        g1[Ind_g0_is0] = u[Ind_g0_is0] * g0.size(-1) / 4
        g1[Ind_g0_isnot0] = dbg_Rie_Exp(g0[Ind_g0_isnot0], u[Ind_g0_isnot0], a, t)
    
    return g1.reshape(size)




In [17]:
e = torch.exp(torch.tensor(1))
ident = torch.eye(3).reshape((1,3,3))
invREIde = dbg_inv_RieExp_extended_scipy(ident,e * ident, 1./dim)
print('e*Id\n',e*ident)
print('invRieExp(Id, e*Id)\n',invREIde)
RE_invREIde = dbg_Rie_Exp_extended(ident,invREIde, 1./dim)
print('RieExp(invRieExp(Id, e*Id))\n',RE_invREIde)

e = torch.exp(torch.tensor(1))
ident = torch.eye(3).reshape((1,3,3))
invREeId = dbg_inv_RieExp_extended_scipy(e * ident,ident, 1./dim)
print('e*Id\n',e*ident)
# expect to get u = 4/3 * (e-(3/4)-1)*e*Id
expectu = 4./dim * (e**(-dim/4.) - 1.)*e*ident
print('Expect invRieExp(e*Id, Id)\n',expectu)
print('invRieExp(e*Id, Id)\n',invREeId)
# Should get Identity for both of following
print('Expect RieExp(Expect invRieExp(e*Id, Id))\n',ident)
RE_expectu = dbg_Rie_Exp_extended(e*ident,expectu, 1./dim)
print('RieExp(Expect invRieExp(e*Id, Id))\n',RE_expectu)
RE_invREeId = dbg_Rie_Exp_extended(e*ident,invREeId, 1./dim)
print('RieExp(invRieExp(e*Id, Id))\n',RE_invREeId)

RE_invRE00 = dbg_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
RE_invRE11 = dbg_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('RieExp[0,0]\n',RE_invRE00)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('RieExp[1,1]\n',RE_invRE11)


invRE00 = dbg_inv_RieExp_extended_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE11 = dbg_inv_RieExp_extended_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,0]\n',invRE00)
print('inv_RieExp[1,1]\n',invRE11)

invRE01 = dbg_inv_RieExp_extended_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE10 = dbg_inv_RieExp_extended_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,1]\n',invRE01)
print('inv_RieExp[1,0]\n',invRE10)


RE_invRE01 = dbg_Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), invRE01, 1./dim)
RE_invRE10 = dbg_Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), invRE10, 1./dim)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('RieExp(inv_RieExp)[0,1]\n',RE_invRE01)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('RieExp(inv_RieExp)[1,0]\n',RE_invRE10)






e*Id
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
invRieExp(Id, e*Id)
 tensor([[[1.4893, 0.0000, 0.0000],
         [0.0000, 1.4893, 0.0000],
         [0.0000, 0.0000, 1.4893]]])
RieExp(invRieExp(Id, e*Id))
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
e*Id
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
Expect invRieExp(e*Id, Id)
 tensor([[[-1.9123, -0.0000, -0.0000],
         [-0.0000, -1.9123, -0.0000],
         [-0.0000, -0.0000, -1.9123]]])
invRieExp(e*Id, Id)
 tensor([[[-1.9123, -0.0000, -0.0000],
         [-0.0000, -1.9123, -0.0000],
         [-0.0000, -0.0000, -1.9123]]])
Expect RieExp(Expect invRieExp(e*Id, Id))
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
RieExp(Expect invRieExp(e*Id, Id))
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
RieExp(i

In [30]:
# Results from GMM Implementation:
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.2946, -1.1249,  0.0000],
         [-1.1249,  1.8516,  0.0000],
         [ 0.0000,  0.0000,  1.9788]]])
g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
RieExp(inv_RieExp)[1,0]
 tensor([[[ 1.4831, -1.2529,  0.0000],
         [-1.2529,  1.9600,  0.0000],
         [ 0.0000,  0.0000,  2.2578]]])
    
# Results from Clarke Implementation:
g1
 tensor([[[ 1.3418, -1.1340,  0.0000],
         [-1.1340,  1.8344,  0.0000],
         [ 0.0000,  0.0000,  2.0359]]])
RieExp(inv_RieExp)[0,1]
 tensor([[[ 1.2946, -1.1249,  0.0000],
         [-1.1249,  1.8516,  0.0000],
         [ 0.0000,  0.0000,  1.9788]]])
g0
 tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
RieExp(inv_RieExp)[1,0]
 tensor([[[ 1.4831, -1.2529,  0.0000],
         [-1.2529,  1.9600,  0.0000],
         [ 0.0000,  0.0000,  2.2578]]])

tensor(0.4741)

In [30]:
e = torch.exp(torch.tensor(1))
ident = torch.eye(3).reshape((1,3,3))
invREIde = inv_RieExp_extended_scipy(ident,e * ident, 1./dim)
print('e*Id\n',e*ident)
print('invRieExp(Id, e*Id)\n',invREIde)
RE_invREIde = Rie_Exp_extended(ident,invREIde, 1./dim)
print('RieExp(invRieExp(Id, e*Id))\n',RE_invREIde)

e = torch.exp(torch.tensor(1))
ident = torch.eye(3).reshape((1,3,3))
invREeId = inv_RieExp_extended_scipy(e * ident,ident, 1./dim)
print('e*Id\n',e*ident)
# expect to get u = 4/3 * (e-(3/4)-1)*e*Id
expectu = 4./dim * (e**(-dim/4.) - 1.)*e*ident
print('Expect invRieExp(e*Id, Id)\n',expectu)
print('invRieExp(e*Id, Id)\n',invREeId)
# Should get Identity for both of following
print('Expect RieExp(Expect invRieExp(e*Id, Id))\n',ident)
RE_expectu = Rie_Exp_extended(e*ident,expectu, 1./dim)
print('RieExp(Expect invRieExp(e*Id, Id))\n',RE_expectu)
RE_invREeId = Rie_Exp_extended(e*ident,invREeId, 1./dim)
print('RieExp(invRieExp(e*Id, Id))\n',RE_invREeId)

RE_invRE00 = Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
RE_invRE11 = Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), torch.zeros((1,3,3)), 1./dim)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('RieExp[0,0]\n',RE_invRE00)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('RieExp[1,1]\n',RE_invRE11)


invRE00 = inv_RieExp_extended_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE11 = inv_RieExp_extended_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,0]\n',invRE00)
print('inv_RieExp[1,1]\n',invRE11)

invRE01 = inv_RieExp_extended_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim)
invRE10 = inv_RieExp_extended_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim)
print('inv_RieExp[0,1]\n',invRE01)
print('inv_RieExp[1,0]\n',invRE10)


RE_invRE01 = Rie_Exp_extended(G[0,13,60-37-1,22].reshape((1,3,3)), invRE01, 1./dim)
RE_invRE10 = Rie_Exp_extended(G[1,13,60-37-1,22].reshape((1,3,3)), invRE10, 1./dim)
print('g1\n',G[1,13,60-37-1,22].reshape((1,3,3)))
print('RieExp(inv_RieExp)[0,1]\n',RE_invRE01)
print('g0\n',G[0,13,60-37-1,22].reshape((1,3,3)))
print('RieExp(inv_RieExp)[1,0]\n',RE_invRE10)



e*Id
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
invRieExp(Id, e*Id)
 tensor([[[1.4893, 0.0000, 0.0000],
         [0.0000, 1.4893, 0.0000],
         [0.0000, 0.0000, 1.4893]]])
RieExp(invRieExp(Id, e*Id))
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
e*Id
 tensor([[[2.7183, 0.0000, 0.0000],
         [0.0000, 2.7183, 0.0000],
         [0.0000, 0.0000, 2.7183]]])
Expect invRieExp(e*Id, Id)
 tensor([[[-1.9123, -0.0000, -0.0000],
         [-0.0000, -1.9123, -0.0000],
         [-0.0000, -0.0000, -1.9123]]])
invRieExp(e*Id, Id)
 tensor([[[-1.9123, -0.0000, -0.0000],
         [-0.0000, -1.9123, -0.0000],
         [-0.0000, -0.0000, -1.9123]]])
Expect RieExp(Expect invRieExp(e*Id, Id))
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
RieExp(Expect invRieExp(e*Id, Id))
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
RieExp(i

In [107]:
GG0 = batch_cholesky(G[0,13,60-37-1,22].reshape((1,3,3)))
nonpsd_idx = torch.where(torch.isnan(GG0))
if len(nonpsd_idx[0]) > 0:
  print(len(nonpsd_idx[0]), 'non psd entries found in logm_invB_A', nonpsd_idx)
for i in range(len(nonpsd_idx[0])):
  GG0[nonpsd_idx[0][i]] = torch.eye((3)).double().to(device=G.device)


# KMC comment out following 4 lines and see if pseudo inverse sufficient instead
det_G = torch.det(GG0)
inv_G = torch.zeros_like(GG0)
inv_G[det_G>0.] = torch.pinverse(GG0[det_G>0.])
inv_G[det_G<=0.] = torch.eye((3)).double().to(device=G.device)
#inv_G = torch.pinverse(GG0)    
W = torch.einsum("...ij,...jk,...lk->...il", inv_G, G[1,13,60-37-1,22], inv_G)

#W_sym = (W + torch.transpose(W,len(W.shape)-2,len(W.shape)-1))/2
# The eigenvector computation becomes inaccurate for matrices close to identity
# Set those closer than float machine eps from identity matrix to identity
#W[(torch.abs(W[:]-torch.eye((3)))<1.1921e-7).sum(dim=(1,2))<9] = torch.eye((3))
lamda, Q = se.apply(W)

print(lamda)
print(Q)
lamda[0][1]==lamda[0][2]
print(torch.symeig(W,eigenvectors=True))

tensor([[0.9895, 0.9231, 0.9231]])
tensor([[[ 0.4273, -0.0000, -0.9041],
         [ 0.9041, -0.0000,  0.4273],
         [ 0.0000,  1.0000,  0.0000]]])
torch.return_types.symeig(
eigenvalues=tensor([[0.9231, 0.9231, 0.9895]]),
eigenvectors=tensor([[[-0.9041,  0.0000,  0.4273],
         [ 0.4273,  0.0000,  0.9041],
         [ 0.0000,  1.0000,  0.0000]]]))


In [21]:
# before change
geo01=get_geo(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[0,13,60-37-1,22].reshape((1,3,3)))
print(geo01)
print(G[1,13,60-37-1,22].reshape((1,3,3)))
geo01s=get_geo_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[0,13,60-37-1,22].reshape((1,3,3)))
print(geo01s)
print(G[1,13,60-37-1,22].reshape((1,3,3)))

geo10=get_geo(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[1,13,60-37-1,22].reshape((1,3,3)))
print(geo10)
print(G[0,13,60-37-1,22].reshape((1,3,3)))
geo10s=get_geo_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[1,13,60-37-1,22].reshape((1,3,3)))
print(geo10s)
print(G[0,13,60-37-1,22].reshape((1,3,3)))


print('mean[0,1]\n',geo01[5])
print('mean scipy[0,1]\n',geo01s[5])
print('mean[1,0]\n',geo10[5])
print('mean scipy[1,0]\n',geo10s[5])

tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
tensor([[[[ 1.4348e+00, -1.2436e+00,  0.0000e+00],
          [-1.2436e+00,  1.9749e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  2.2054e+00]]],


        [[[ 1.4253e+00, -1.2325e+00,  1.9509e-18],
          [-1.2325e+00,  1.9606e+00, -2.4201e-18],
          [ 1.4451e-18, -1.7926e-18,  2.1882e+00]]],


        [[[ 1.4159e+00, -1.2214e+00,  3.8877e-18],
          [-1.2214e+00,  1.9464e+00, -4.8226e-18],
          [ 2.8797e-18, -3.5722e-18,  2.1711e+00]]],


        [[[ 1.4065e+00, -1.2104e+00,  5.8101e-18],
          [-1.2104e+00,  1.9322e+00, -7.2073e-18],
          [ 4.3037e-18, -5.3386e-18,  2.1540e+00]]],


        [[[ 1.3971e+00, -1.1993e+00,  7.7183e-18],
          [-1.1993e+00,  1.9181e+00, -9.5743e-18],
          [ 5.7171e-18, -7.0919e-18,  2.1369e+00]]],


        [[[ 1.3878e+00, -1.1884e+00,  9.6121e-18],
          [-1.1884e+00,  1.9040e+00, -1.1924e-17

In [31]:
# after change
geo01=get_geo(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[0,13,60-37-1,22].reshape((1,3,3)))
print(geo01)
print(G[1,13,60-37-1,22].reshape((1,3,3)))
geo01s=get_geo_scipy(G[0,13,60-37-1,22].reshape((1,3,3)), G[1,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[0,13,60-37-1,22].reshape((1,3,3)))
print(geo01s)
print(G[1,13,60-37-1,22].reshape((1,3,3)))

geo10=get_geo(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[1,13,60-37-1,22].reshape((1,3,3)))
print(geo10)
print(G[0,13,60-37-1,22].reshape((1,3,3)))
geo10s=get_geo_scipy(G[1,13,60-37-1,22].reshape((1,3,3)), G[0,13,60-37-1,22].reshape((1,3,3)), 1./dim,11)
print(G[1,13,60-37-1,22].reshape((1,3,3)))
print(geo10s)
print(G[0,13,60-37-1,22].reshape((1,3,3)))


print('mean[0,1]\n',geo01[5])
print('mean scipy[0,1]\n',geo01s[5])
print('mean[1,0]\n',geo10[5])
print('mean scipy[1,0]\n',geo10s[5])

tensor([[[ 1.4348, -1.2436,  0.0000],
         [-1.2436,  1.9749,  0.0000],
         [ 0.0000,  0.0000,  2.2054]]])
tensor([[[[ 1.4348e+00, -1.2436e+00,  0.0000e+00],
          [-1.2436e+00,  1.9749e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  2.2054e+00]]],


        [[[ 1.4286e+00, -1.2334e+00,  6.5474e-19],
          [-1.2334e+00,  1.9643e+00, -8.1218e-19],
          [ 6.5474e-19, -8.1218e-19,  2.1909e+00]]],


        [[[ 1.4224e+00, -1.2233e+00,  1.3037e-18],
          [-1.2233e+00,  1.9537e+00, -1.6172e-18],
          [ 1.3037e-18, -1.6172e-18,  2.1765e+00]]],


        [[[ 1.4164e+00, -1.2131e+00,  1.9469e-18],
          [-1.2131e+00,  1.9433e+00, -2.4151e-18],
          [ 1.9469e-18, -2.4151e-18,  2.1621e+00]]],


        [[[ 1.4104e+00, -1.2031e+00,  2.5844e-18],
          [-1.2031e+00,  1.9329e+00, -3.2059e-18],
          [ 2.5844e-18, -3.2059e-18,  2.1479e+00]]],


        [[[ 1.4045e+00, -1.1931e+00,  3.2163e-18],
          [-1.1931e+00,  1.9226e+00, -3.9897e-18

In [48]:
def dbg_get_karcher_mean_scipy(G, a, mask=None, scale_factor=1.0, filename='', device='cuda:0'):
    size = G.size()
    G = G.reshape(size[0], -1, *size[-2:])  # (T,-1,3,3)
    gm = torch.clone(G[0]).to(device)
    for i in range(1, G.size(0)):
        G_i = G[i].to(device)
        U = scipy_logm_invB_A(gm, G_i)
        #U = logm_invB_A(make_pos_def(gm, mask.reshape(-1), 1.0e-10, skip_small_eval=True), G[i])
        #UTrless = U - torch.einsum("...ii,kl->...kl", U, torch.eye(size[-1], dtype=torch.double, device=device)) / size[-1]  # (...,2,2)

        #theta = ((torch.einsum("...ik,...ki->...", UTrless, UTrless) / a).sqrt() / 4 - np.pi)
        tr_gm_b = torch.einsum("...kii->...k", U)
        b = torch.einsum("...ik,...kj->...ij",gm, U)
        bT = b - torch.einsum("...,...ij->...ij", tr_gm_b, gm) * a
        theta = ((torch.einsum("...ik,...ki->...", bT, bT) / a).sqrt() / 4 - np.pi)

        if filename:
          sitk.WriteImage(sitk.GetImageFromArray(np.transpose(theta.reshape(*size[1:-2]).cpu(),(2,1,0))), filename+'_theta.nhdr')
          U_lin = np.zeros((6,*size[1:4])).cpu()
          #U_inv = torch.inverse(U) / tens_scale
          U_lin[0] = U.reshape(*size[1:])[:,:,:,0,0].cpu()
          U_lin[1] = U.reshape(*size[1:])[:,:,:,0,1].cpu()
          U_lin[2] = U.reshape(*size[1:])[:,:,:,0,2].cpu()
          U_lin[3] = U.reshape(*size[1:])[:,:,:,1,1].cpu()
          U_lin[4] = U.reshape(*size[1:])[:,:,:,1,2].cpu()
          U_lin[5] = U.reshape(*size[1:])[:,:,:,2,2].cpu()
          WriteTensorNPArray(np.transpose(U_lin,(3,2,1,0)), filename+'_U.nhdr')
        Ind_inRange = (theta < 0).nonzero().reshape(-1)  ## G[i] is in the range of the exponential map at gm
        Ind_notInRange = (theta >= 0).nonzero().reshape(-1)  ## G[i] is not in the range

        # when g1 = 0, len(Ind_notInRange) and len(Ind_inRange) are both zero. So check len(Ind_notInRange) first
        if len(Ind_notInRange) == 0:  # all in the range
            gm = Rie_Exp_extended(gm, inv_RieExp_extended_scipy(gm, G_i, a), a, 1.0 / (i + 1))
        elif len(Ind_inRange) == 0:  # all not in range
            print('None in range, for index,', i)
            #gm = ptPick_notInRange(gm, G_i, i)
            gm = Rie_Exp_extended(gm, inv_RieExp_extended_scipy(gm, G_i, a), a, 1.0 / (i + 1))
        else:
            print('Some not in range, for index,', i)
            print("Ind not in Range[0]:", Ind_notInRange[0])
            
            gm[Ind_inRange] = Rie_Exp_extended(gm[Ind_inRange],
                                               inv_RieExp_extended_scipy(gm[Ind_inRange], G_i[Ind_inRange], a),
                                               a, 1.0 / (i + 1))  # stop here
            #gm[Ind_notInRange] = ptPick_notInRange(gm[Ind_notInRange], G_i[Ind_notInRange], i)
            gm[Ind_notInRange] = Rie_Exp_extended(gm[Ind_notInRange],
                                               inv_RieExp_extended_scipy(gm[Ind_notInRange], G_i[Ind_notInRange], a),
                                               a, 1.0 / (i + 1))  # stop here
    torch.cuda.empty_cache()

    return gm.reshape(*size[1:])

def ptPick_notInRange(g0, g1, logm_invg0_g1, i):  # (-1,3,3)
    #alpha = torch.det(g1).clamp(min=1.0e-15).pow(1 / 4) / torch.det(g0).clamp(min=1.0e-15).pow(1 / 4)  # (-1)
    tr_g0_b = torch.einsum("...ii->", logm_invg0_g1)
    alpha = torch.exp(tr_g0_b / 4.0)
    #print('ptPick_notInRange, g0 NaN?', g0.isnan().any(), 'g1 NaN?', g1.isnan().any(), 'alpha NaN?', alpha.isnan().any())
    #print('ptPick_notInRange, g0 Inf?', g0.isinf().any(), 'g1 Inf?', g1.isinf().any(), 'alpha Inf?', alpha.isinf().any())
    #Ind_close_to_g0 = (alpha <= i).nonzero().reshape(-1)
    #Ind_close_to_g1 = (alpha > i).nonzero().reshape(-1)
    Ind_close_to_g0 = (tr_g0_b <= i).nonzero().reshape(-1)
    Ind_close_to_g1 = (tr_g0_b > i).nonzero().reshape(-1)

    def get_gm_inLine_0g0(alpha, g0, i):
        kn_over4 = -(1 + alpha) / (i + 1)  # (-1)
        gm = (1 + kn_over4) ** (4 / g0.size(-1)) * torch.einsum("...ij->ij...", g0)  # (3,3,-1)
        return gm.permute(2, 0, 1)  # (-1,3,3)

    def get_gm_inLine_0g1(alpha, g1, i):
        kn_over4 = -i * (1 + 1 / alpha) / (i + 1)  # (-1)
        gm = (1 + kn_over4) ** (4 / g1.size(-1)) * torch.einsum("...ij->ij...", g1)  # (3,3,-1)
        return gm.permute(2, 0, 1)

    if len(Ind_close_to_g1) == 0:  # all are close to g0
        gm = get_gm_inLine_0g0(alpha, g0, i)
    elif len(Ind_close_to_g0) == 0:
        gm = get_gm_inLine_0g1(alpha, g1, i)
    else:
        gm = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        gm[Ind_close_to_g0] = get_gm_inLine_0g0(alpha[Ind_close_to_g0], g0[Ind_close_to_g0], i)
        gm[Ind_close_to_g1] = get_gm_inLine_0g1(alpha[Ind_close_to_g1], g1[Ind_close_to_g1], i)
    return gm
    
def dbg_get_geo_scipy(g0, g1, a, Tpts):  # (s,t,...,,3,3)
    '''
    use odd number Tpts of time points since the geodesic may go
    though the zero matrix which will give the middle point of the geodesic
    '''
    size = g0.size()

    g0, g1 = g0.reshape(-1, *size[-2:]), g1.reshape(-1, *size[-2:])  # (-1,3,3)

    Time = torch.arange(Tpts, out=torch.DoubleTensor()) / (Tpts - 1)  # (Tpts)

    #U = scipy_logm_invB_A(g0, g1)
    #UTrless = U - torch.einsum("...ii,kl->...kl", U, torch.eye(g1.size(-1), dtype=torch.double, device=g0.device)) / g1.size(-1)  # (...,3,3)
    #theta = ((1 / a * torch.einsum("...ik,...ki->...", UTrless, UTrless)).sqrt() / 4 - np.pi)
    
    
    logm_invg0_g1 = scipy_logm_invB_A(g0, g1)
    tr_g0_b = torch.einsum("...kii->...k", logm_invg0_g1)
    b = torch.einsum("...ik,...kj->...ij",g0, logm_invg0_g1)
    bT = b - torch.einsum("...,...ij->...ij", tr_g0_b, g0) * a
    g0inv = torch.inverse(g0)
    g0bT = torch.einsum("...ik,...kj->...ij", g0inv, bT)
    theta = torch.einsum("...ik,...ki->...", g0bT, g0bT)
    print('tr(b):',tr_g0_b)
    print('theta(b):',theta)
    
    thresh = a * (4*np.pi)**2
    Ind_inRange = (theta < thresh).nonzero().reshape(-1)
    Ind_notInRange = (theta >= thresh).nonzero().reshape(-1)

    def geo_in_range(g0, g1, a, Tpts):
        u = inv_RieExp_extended_scipy(g0, g1, a)  # (-1,3,3)
        geo = torch.zeros(Tpts, *g0.size(), dtype=torch.double, device=g0.device)  # (Tpts,-1,3,3)
        geo[0], geo[-1] = g0, g1
        for i in range(1, Tpts - 1):
            #geo[i] = Rie_Exp_extended(g0, u * Time[i], a)
            geo[i] = Rie_Exp_extended(g0, u, a, Time[i])
        return geo  # (Tpts,-1,2,2)

    def geo_not_in_range(g0, g1, a, Tpts):  # (-1,3,3)
        print('geo not in range')
        m0 = torch.zeros(g0.size(), dtype=torch.double, device=g0.device)
        m0[...,0,0] = 1
        m0[...,1,1] = 1
        m0[...,2,2] = 1
        
        u0 = inv_RieExp_extended_scipy(g0, m0, a)
        u1 = inv_RieExp_extended_scipy(g1, m0, a)

        geo = torch.zeros(Tpts, *g0.size(), dtype=torch.double, device=g0.device)  # (Tpts,-1,3,3)
        geo[0], geo[-1] = g0, g1

        for i in range(1, int((Tpts - 1) / 2)):
            #geo[i] = Rie_Exp_extended(g0, u0 * Time[i], a)
            geo[i] = Rie_Exp_extended(g0, u0, a, Time[i])
        for j in range(-int((Tpts - 1) / 2), -1):
            #geo[j] = Rie_Exp_extended(g1, u1 * (1 - Time[j]), a)
            geo[j] = Rie_Exp_extended(g1, u1, a, (1 - Time[j]))
        return geo  # (Tpts,-1,2,2)

    # If g1 = 0, len(Ind_notInRange) and len(Ind_inRange) are both zero. In this case we say that g1 is in the range
    if (len(Ind_notInRange) == 0): # all in the range
        geo = geo_in_range(g0, g1, a, Tpts)
    elif (len(Ind_inRange) == 0):  # all not in range
        geo = geo_not_in_range(g0, g1, a, Tpts)
    else:
        geo = torch.zeros(Tpts, *g0.size(), dtype=torch.double, device=g0.device)  # (Tpts,-1,3,3)
        geo[:, Ind_inRange] = geo_in_range(g0[Ind_inRange], g1[Ind_inRange], a, Tpts)
        geo[:, Ind_notInRange] = geo_not_in_range(g0[Ind_notInRange], g1[Ind_notInRange], a, Tpts)
    return geo.reshape(Tpts, *size)


In [49]:
z = torch.zeros((2,1,dim,dim))
ns = G.shape[0]
dbg_meanG13 = dbg_get_karcher_mean_scipy(G[:,13,60-37-1:60-32-1,22:26],1.0/dim)
#dbg_meanG13 = get_karcher_mean_scipy(G[:,13,60-37-1:60-35-1,22],1.0/dim)
print('dbg karcher mean(G13):\n',dbg_meanG13)
#for ii in range(G.shape[1]):
#  print(ii)  
#  dbg_meanGii = get_karcher_mean_scipy(G[:,ii,60-37-1,22],1.0/dim)
#  print('karcher mean(Gii):\n',dbg_meanGii)
meanG13 = get_karcher_mean_scipy(G[:,13,60-37-1:60-32-1,22:26],1.0/dim)
print('karcher mean(G13):\n',meanG13)
meandiff = dbg_meanG13 - meanG13
print('min, max diff:',torch.min(meandiff), torch.max(meandiff))
#print(meandiff)

dbg karcher mean(G13):
 tensor([[[[ 1.2457, -1.0630,  0.0000],
          [-1.0630,  1.8027,  0.0000],
          [ 0.0000,  0.0000,  2.0779]],

         [[ 0.9331, -0.7799,  0.0000],
          [-0.7799,  1.5090,  0.0000],
          [ 0.0000,  0.0000,  3.1576]],

         [[ 1.0000,  0.0000,  0.0000],
          [ 0.0000,  1.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0000]],

         [[ 1.0000,  0.0000,  0.0000],
          [ 0.0000,  1.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0000]]],


        [[[ 1.0412, -0.7800,  0.0000],
          [-0.7800,  1.5506,  0.0000],
          [ 0.0000,  0.0000,  7.8031]],

         [[ 1.1002, -0.9930,  0.0000],
          [-0.9930,  2.0498,  0.0000],
          [ 0.0000,  0.0000,  1.6065]],

         [[ 0.9780, -0.9456,  0.0000],
          [-0.9456,  2.2239,  0.0000],
          [ 0.0000,  0.0000,  2.5487]],

         [[ 0.7238, -0.7715,  0.0000],
          [-0.7715,  2.1468,  0.0000],
          [ 0.0000,  0.0000,  2.2210]]],


        [[[ 0.9614, 

In [47]:
#for ii in range(G.shape[1]):
#    print(ii)
# 60-37-1:60-32-1,22:26
dbg_meanG15 = dbg_get_karcher_mean_scipy(G[9:11,15,60-36-1:60-35-1,22],1.0/dim)
print(G[9:11,15,60-36-1:60-35-1,22])
print(dbg_meanG15)
geoG16=dbg_get_geo_scipy(G[10,15,60-36-1:60-35-1,22].reshape((1,3,3)), G[9,15,60-36-1:60-35-1,22].reshape((1,3,3)), 1./dim,11)
print(geoG16[5])
print(geoG16)




None in range, for index, 1
tensor([[[[ 1.5492, -1.6283,  0.0000],
          [-1.6283,  2.5605,  0.0000],
          [ 0.0000,  0.0000, 48.2103]]],


        [[[ 0.8427, -0.6692,  0.0000],
          [-0.6692,  1.2583,  0.0000],
          [ 0.0000,  0.0000,  3.3507]]]])
tensor([[[ 0.9977, -0.8911,  0.0000],
         [-0.8911,  1.6928,  0.0000],
         [ 0.0000,  0.0000, 28.1103]]])
tr(b): tensor([3.4307])
theta(b): tensor([3.7702])
thresh: 0.1
geo not in range
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
tensor([[[[ 0.8427, -0.6692,  0.0000],
          [-0.6692,  1.2583,  0.0000],
          [ 0.0000,  0.0000,  3.3507]]],


        [[[ 0.8185, -0.5849,  0.0000],
          [-0.5849,  1.2779,  0.0000],
          [ 0.0000,  0.0000,  3.1423]]],


        [[[ 0.8104, -0.5049,  0.0000],
          [-0.5049,  1.3288,  0.0000],
          [ 0.0000,  0.0000,  2.9882]]],


        [[[ 0.8173, -0.4261,  0.0000],
          [-0.4261,  1.4111,  0.0000],
          [ 0.0000,  0.

In [None]:
37:36 karcher mean(G13):
 tensor([[[ 1.2457, -1.0630,  0.0000],
         [-1.0630,  1.8027,  0.0000],
         [ 0.0000,  0.0000,  2.0779]]])
36:35 karcher mean(G13):
 tensor([[[ 1.0412, -0.7800,  0.0000],
         [-0.7800,  1.5506,  0.0000],
         [ 0.0000,  0.0000,  7.8031]]])
37:35 karcher mean(G13):
 tensor([[[ 1.2137, -1.0032,  0.0000],
         [-1.0032,  1.7226,  0.0000],
         [ 0.0000,  0.0000,  1.9747]],

        [[ 1.0302, -0.7463,  0.0000],
         [-0.7463,  1.5000,  0.0000],
         [ 0.0000,  0.0000,  7.3529]]])

In [26]:
# 44,33

strti = 0
stpi = 3
print(G[strti:stpi,13,60-33-1,44])

dbg_meanG13 = get_karcher_mean(G[strti:stpi,13,60-33-1,44],1.0/dim)
print(dbg_meanG13)
dbg_meanG13_shuffle = get_karcher_mean_shuffle(G[strti:stpi,13,60-33-1,44],1.0/dim)
print(dbg_meanG13_shuffle)
dbg_meanG13_scipy = get_karcher_mean_scipy(G[strti:stpi,13,60-33-1,44],1.0/dim)
print(dbg_meanG13_scipy)
dbg_meanG13_shuffle_scipy = get_karcher_mean_shuffle_scipy(G[strti:stpi,13,60-33-1,44],1.0/dim)
print(dbg_meanG13_shuffle_scipy)



tensor([[[ 0.1430, -0.1809,  0.0000],
         [-0.1809,  0.5654,  0.0000],
         [ 0.0000,  0.0000,  0.4645]],

        [[ 0.1589, -0.1918,  0.0000],
         [-0.1918,  0.6069,  0.0000],
         [ 0.0000,  0.0000,  0.4980]],

        [[ 0.1684, -0.2002,  0.0000],
         [-0.2002,  0.6359,  0.0000],
         [ 0.0000,  0.0000,  0.5215]]])
tensor([[ 1.7659e-01, -1.7845e-01, -1.6478e-18],
        [-1.7845e-01,  5.9336e-01,  4.4575e-18],
        [-1.6478e-18,  4.4575e-18,  4.8108e-01]])
tensor([[ 0.1429, -0.1938,  0.0417],
        [-0.1945,  0.5986,  0.0020],
        [-0.0009, -0.0109,  0.4965]])
tensor([[ 0.1723, -0.1971,  0.0000],
        [-0.1971,  0.5905,  0.0000],
        [ 0.0000,  0.0000,  0.5049]])
tensor([[ 0.1655, -0.1909,  0.0000],
        [-0.1909,  0.6268,  0.0000],
        [ 0.0000,  0.0000,  0.5080]])


In [26]:
du = torch.stack((torch.zeros((dim,dim)), torch.eye(dim)))
print(du.shape,'\n',du)
print('tr(du0):',torch.einsum('...ii->',du[0]))
print('tr(du1):',torch.einsum('...ii->',du[1]))
print('tr(du01):',torch.einsum('...kii->...k',du))
trdu= torch.einsum('...kii->...k',du)
trdudu = torch.einsum('...,...ij->...ij',trdu,du)
print('tr(du01)*du:\n',trdudu)
print('tr(du01)*du:\n',(trdu*torch.einsum('...ij->ij...',du)).permute(2,0,1))


torch.Size([2, 3, 3]) 
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
tr(du0): tensor(0.)
tr(du1): tensor(3.)
tr(du01): tensor([0., 3.])
tr(du01)*du:
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[3., 0., 0.],
         [0., 3., 0.],
         [0., 0., 3.]]])
tr(du01)*du:
 tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[3., 0., 0.],
         [0., 3., 0.],
         [0., 0., 3.]]])


In [74]:
dbg_mean01 = get_karcher_mean_shuffle_scipy(G[0:2,13,60-37-1,22], 1.0/dim)
print('Karcher mean[0,1]:\n',dbg_mean01)
print('geo midpoint[0,1]:\n',geo01s[5])
dbg_mean10 = get_karcher_mean_shuffle_scipy(torch.stack((G[1,13,60-37-1,22],G[0,13,60-37-1,22])), 1.0/dim)
print('Karcher mean[1,0]:\n',dbg_mean10)
print('geo midpoint[1,0]:\n',geo10s[5])
dbg_mean00 = get_karcher_mean_shuffle_scipy(torch.stack((G[0,13,60-37-1,22],G[0,13,60-37-1,22])), 1.0/dim)
print('Karcher mean[0,0]:\n',dbg_mean00)
print('geo midpoint[0,0]:\n',G[0,13,60-37-1,22])
dbg_mean11 = get_karcher_mean_shuffle_scipy(torch.stack((G[1,13,60-37-1,22],G[1,13,60-37-1,22])), 1.0/dim)
print('Karcher mean[1,1]:\n',dbg_mean11)
print('geo midpoint[1,1]:\n',G[1,13,60-37-1,22])


entering inv_RieExp_scipy
Karcher mean[0,1]:
 tensor([[ 1.3862, -1.1815,  0.0000],
        [-1.1815,  1.8711,  0.0000],
        [ 0.0000,  0.0000,  2.1076]])
geo midpoint[0,1]:
 tensor([[[ 1.3901, -1.1964,  0.0000],
         [-1.1964,  1.9365,  0.0000],
         [ 0.0000,  0.0000,  2.1334]]])
entering inv_RieExp_scipy
Karcher mean[1,0]:
 tensor([[ 1.3862, -1.1815,  0.0000],
        [-1.1815,  1.8711,  0.0000],
        [ 0.0000,  0.0000,  2.1076]])
geo midpoint[1,0]:
 tensor([[[ 1.3862, -1.1815,  0.0000],
         [-1.1815,  1.8711,  0.0000],
         [ 0.0000,  0.0000,  2.1076]]])
entering inv_RieExp_scipy
Karcher mean[0,0]:
 tensor([[ 1.4348, -1.2436,  0.0000],
        [-1.2436,  1.9749,  0.0000],
        [ 0.0000,  0.0000,  2.2054]])
geo midpoint[0,0]:
 tensor([[ 1.4348, -1.2436,  0.0000],
        [-1.2436,  1.9749,  0.0000],
        [ 0.0000,  0.0000,  2.2054]])
entering inv_RieExp_scipy
Karcher mean[1,1]:
 tensor([[ 1.3418, -1.1340,  0.0000],
        [-1.1340,  1.8344,  0.0000],
  

In [66]:
print(G[:,13,60-37-1:60-35-1,22].shape)
print(G[:,13,60-37-1,22].shape)


torch.Size([20, 2, 3, 3])
torch.Size([20, 3, 3])


In [24]:
print(np.sum(mask_union))
print(mask_union.shape)
print(G.shape)
print(60*60*27)
mask_torch = torch.from_numpy(mask_union.reshape((-1)))
print(mask_torch.shape)
print(device)
Ind_mask =  (mask_torch > 0).nonzero().reshape(-1) 
Ind_outmask = (mask_torch < 0.5).nonzero().reshape(-1) 
#gm_mask = torch.zeros((60*60*27,3,3),device=device, dtype=torch.double)
Greshape = G.reshape((G.shape[0],-1,3,3))
print(Greshape[0,Ind_mask].shape)
gm_mask = torch.zeros_like(Greshape[0])
print(gm_mask.shape)
print(Ind_mask.shape)
gm_mask[Ind_mask] = Greshape[0,Ind_mask]


2341.0
(60, 60, 27)
torch.Size([20, 27, 60, 60, 3, 3])
97200
torch.Size([97200])
cuda:0
torch.Size([2341, 3, 3])
torch.Size([97200, 3, 3])
torch.Size([2341])


## 2. Check estimated diffeomorphisms

In [18]:
# TODO get ptPick_notInRange code correct, run by Martin.  Thresholds, formulas, etc
# TODO update squared ebin distance calculation in same manner as rest of changes
# TODO redo expected dist and expected log map calculations
# TODO figure out how best to report these changes everywhere

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [46]:
z

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cpu')

## 3. Check diffeomorphism action