In [1]:
import os
import sys

import numpy as np
import torch

sys.path.append("/home/karabelas/workspace/DISCO-AE")  # add the path to the DiffusionNet src
import diffusion_net_vol  # noqa
from diffusion_net_vol.geometry import *
from diffusion_net_vol.utils import read_carp_bin_mesh, sparse_torch_to_np
from utils import find_knn, zoomout_refine

In [2]:
from stage1 import hearts
from hearts import shape_to_device

In [None]:
# Build the heart dataset from the on-disk mesh directory. With use_cache=True
# the constructor consults a per-case cache file and repopulates it when the
# case is not cached yet (see the log output below).
train_dataset = hearts.HeartDataset(
    "/home/karabelas/workspace/DISCO-AE/data/Data_hearts",
    name="hearts",
    k_eig=64,
    n_fmap=50,
    use_cache=True,
)

loading 20 meshes
processing  Case_01
using dataset cache path: /home/karabelas/workspace/DISCO-AE/data/Data_hearts/cache/Case_01_cache.npz
  --> dataset not in cache, repopulating
loading mesh /home/karabelas/workspace/DISCO-AE/data/Data_hearts/Case_01
Number of Points: 85754
Number of elements: 397277
Number of Fibers: 2
[[ -0.73154879  -1.37987551   0.58370212  -1.88298886  -3.6579397
   10.03090147   2.27727985  -1.79826246   3.06792462   5.40923852
    3.09422912  -3.41223913  -4.56049948  -5.96900063   0.15445905
   -1.22538017]
 [-29.51323066   9.32300772   1.34940356   2.41920183   1.38231837
    0.60800938   2.20810243   4.59658123   5.53472348  -1.51304885
   -2.79594312   6.41107216  -4.30303556  -0.50402805   2.21256644
    2.58429965]
 [-19.07210437   4.03201655   0.75392173   0.71343746  -3.07149194
   -8.82212388   0.32309525   0.30408973   6.43894668   1.79745946
    6.65066297   5.01948554  11.40411444   0.74567835  -0.05436975
   -7.16281824]]
processing  Case_02
usin

In [None]:
# batch_size=None turns off the DataLoader's automatic batching, so every
# iteration yields one dataset item as-is; shuffle=True randomizes the order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, shuffle=True)

In [None]:
from model import FMLoss, GeomFMapNet
from utils import augment_batch

In [None]:
import yaml

In [None]:
# Load the experiment configuration. The context manager guarantees the file
# handle is closed once parsing finishes (the bare open() call leaked it).
with open("hearts.yaml", "r") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

In [None]:
# Instantiate the functional-map network from the loaded configuration dict.
fm_net = GeomFMapNet(cfg)

In [None]:
# AdamW over all network parameters; learning rate and betas come from the
# "optimizer" section of the config. `lr` stays a plain float because the
# training loop decays it manually.
opt_cfg = cfg["optimizer"]
lr = float(opt_cfg["lr"])
adam_betas = (opt_cfg["b1"], opt_cfg["b2"])
optimizer = torch.optim.AdamW(fm_net.parameters(), lr=lr, betas=adam_betas)

In [None]:
# Functional-map loss, weighted by the w_bij and w_ortho entries of the
# config's "loss" section.
criterion = FMLoss(w_bij=cfg["loss"]["w_bij"], w_ortho=cfg["loss"]["w_ortho"])

In [None]:
# Training loop: iterates over the shape pairs yielded by train_loader for the
# configured number of epochs, with a manual step decay of the learning rate.
device = torch.device("cpu")
# Keep the model on the same device the batches are moved to, so changing
# `device` above does not lead to a device mismatch (normally a no-op on CPU).
fm_net.to(device)
print("start training")
iterations = 0  # global step counter across epochs, drives the log interval
for epoch in range(1, cfg["training"]["epochs"] + 1):
    # Every decay_iter epochs, scale the learning rate by decay_factor and push
    # the new value into each optimizer parameter group.
    if epoch % cfg["optimizer"]["decay_iter"] == 0:
        lr *= cfg["optimizer"]["decay_factor"]
        print(f"Decaying learning rate, new one: {lr}")
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    fm_net.train()
    for i, data in enumerate(train_loader):
        data = shape_to_device(data, device)

        # data augmentation: rotation angles are 0 and the scale range is
        # fixed to [1.0, 1.0]; presumably only noise (std=0.01) is applied.
        # NOTE(review): confirm how augment_batch treats noise_clip=0.00.
        data = augment_batch(data, rot_x=0, rot_y=0, rot_z=0, std=0.01, noise_clip=0.00, scale_min=1.0, scale_max=1.0)

        # do iteration: the network returns the functional maps in both directions
        C12, C21 = fm_net(data)
        loss = criterion(C12, C21)

        loss.backward()
        optimizer.step()
        # Clear gradients after the step so the next backward() starts from zero.
        optimizer.zero_grad()

        # log
        iterations += 1
        if iterations % cfg["misc"]["log_interval"] == 0:
            # .item() logs the plain number instead of the tensor repr (with grad_fn).
            print(f"#epoch:{epoch}, #batch:{i + 1}, #iteration:{iterations}, loss:{loss.item()}")

    # save model
    # NOTE(review): disabled checkpointing; model_save_path is not defined in
    # the visible cells and would have to be set before re-enabling this.
    #if (epoch + 1) % cfg["misc"]["checkpoint_interval"] == 0:
    #    torch.save(fm_net.state_dict(), model_save_path.format(epoch))

In [None]:
# NOTE(review): scratch probe. torch DataLoader objects are iterables and, as
# far as I know, not subscriptable, so this presumably raises TypeError —
# confirm whether indexing train_dataset was intended.
train_loader[0].grad

In [None]:
# NOTE(review): scratch probe. Other cells read gradX_list / gradY_list /
# gradZ_list from the dataset — confirm that a grad_list attribute exists.
train_dataset.grad_list[0]

In [None]:
# Scale entry (b, i) of the transposed eigenvector matrix of shape 0 by
# massvec[i]; 'bi,i->bi' multiplies element-wise with no summation.
# Presumably this is evecs^T times the diagonal mass matrix — TODO confirm.
torch.einsum('bi,i->bi', train_dataset.evecs_list[0].t(), train_dataset.massvec_list[0])

In [None]:
# Persist the network's state_dict as fm.pth inside the dataset cache directory.
torch.save(fm_net.state_dict(), "/home/karabelas/workspace/DISCO-AE/data/Data_hearts/cache/fm.pth")

In [None]:
# Destination file for the point-to-point maps written by the evaluation cell.
saved_maps = "/home/karabelas/workspace/DISCO-AE/data/Data_hearts/cache/heart_maps.pt"

In [None]:
# Eval loop: for every shape pair, predict the functional maps, convert them to
# point-to-point maps by nearest-neighbour search in the spectral embedding,
# refine them with ZoomOut, and collect everything in `to_save`.
n_fmap = 30  # number of leading eigenvectors used for the nearest-neighbour maps
print("start evaluation...")
to_save = {}

fm_net.eval()
# Inference only: disabling autograd avoids building a graph for every pair.
with torch.no_grad():
    for i, data in enumerate(train_loader):
        data = shape_to_device(data, device)
        evecs1, evecs2 = data["shape1"]["evecs"], data["shape2"]["evecs"]

        # do iteration
        C12, C21 = fm_net(data)
        # Drop the leading dimension of size 1 (assumed to be the batch axis).
        C12, C21 = C12.squeeze(0), C21.squeeze(0)

        # maps from 2 to 1
        evec1_on_2 = evecs1[:, :n_fmap] @ C12.transpose(0, 1)
        _, pred_labels2to1 = find_knn(evecs2[:, :n_fmap], evec1_on_2, k=1, method='cpu_kd')
        map_21 = pred_labels2to1.flatten()

        # maps from 1 to 2
        evec2_on_1 = evecs2[:, :n_fmap] @ C21.transpose(0, 1)
        _, pred_labels1to2 = find_knn(evecs1[:, :n_fmap], evec2_on_1, k=1, method='cpu_kd')
        map_12 = pred_labels1to2.flatten()

        # zoomout refinement: works on NumPy arrays; the step is chosen so that
        # 10 iterations span the eigenvectors beyond the first n_fmap.
        C12, C21, evecs1, evecs2 = C12.detach().cpu().numpy(), C21.detach().cpu().numpy(), evecs1.cpu().numpy(), evecs2.cpu().numpy()
        _, map_21_ref = zoomout_refine(evecs1, evecs2, C12, nit=10,
                                       step=(evecs1.shape[-1] - n_fmap) // 10, return_p2p=True)
        _, map_12_ref = zoomout_refine(evecs2, evecs1, C21, nit=10,
                                       step=(evecs2.shape[-1] - n_fmap) // 10, return_p2p=True)

        # Keyed by "<shape1 name>_<shape2 name>": raw and refined maps, both directions.
        to_save[f'{data["shape1"]["name"]}_{data["shape2"]["name"]}'] = [map_12, map_12_ref, map_21, map_21_ref]

torch.save(to_save, saved_maps)

In [None]:
# Number of shape pairs for which maps were stored.
len(to_save)

In [None]:
# Vertex count of every mesh in the dataset (first dimension of each vertex
# array), iterating over the list directly instead of indexing by position.
szs = np.array([verts.shape[0] for verts in train_dataset.verts_list])

In [None]:
# Convert the stored gradX operator of the shape at index 1 with
# sparse_torch_to_np (presumably to a SciPy sparse matrix — TODO confirm).
test = sparse_torch_to_np(train_dataset.gradX_list[1])

In [None]:
# Drop explicitly stored zeros (assumes `test` is a SciPy sparse matrix, where
# this call works in place).
test.eliminate_zeros()

In [None]:
# Display the converted operator.
test

In [None]:
# Index of the largest entry of `szs`.
szs.argmax()

In [None]:
# Load the Case_01 mesh directly: read_carp_bin_mesh returns four values, of
# which only the vertex coordinates (xyz) and the connectivity (con) are kept.
bname = "/home/karabelas/workspace/DISCO-AE/data/Data_hearts/Case_01"
xyz, con, _, _ = read_carp_bin_mesh(bname)
# Swap the first two vertex indices of every element.
# NOTE(review): presumably this flips the element orientation — confirm.
con = con[:, [1,0,2,3]]
xyz = diffusion_net_vol.geometry.normalize_positions(xyz)

# normalize the volume scale (the call is normalize_volume_scale, not an area normalization)
xyz = diffusion_net_vol.geometry.normalize_volume_scale(xyz, con)

In [None]:
# Negated cotangent matrix of the mesh, converted to COO format so that the row
# and column index of every stored entry can be read off.
L = (-1.0) * cotmatrix(xyz, con)
L_coo = L.tocoo()
inds_row = L_coo.row
inds_col = L_coo.col

In [None]:
# For meshes, we use the same edges as were used to build the Laplacian:
# every stored (row, col) entry becomes one edge, giving a (2, n_entries) array.
edges = np.stack((inds_row, inds_col), axis=0)
# Edge vector = coordinates of the column vertex minus those of the row vertex.
edge_vecs = xyz[edges[1, :], :] - xyz[edges[0, :], :]


In [None]:
# Assemble the gradient operator for this mesh in CSR format.
# NOTE(review): cells below treat the product with a per-vertex vector as three
# stacked component blocks — confirm the operator's row layout.
gradMat = build_grad_parallel_single(xyz, edges, edge_vecs, format='csr')

In [None]:
# Display the assembled operator.
gradMat

In [None]:
# Reference operators: the dataset's stored per-axis gradient operators for the
# shape at index 0 (presumably Case_01 — TODO confirm), converted from torch.
gradX = sparse_torch_to_np(train_dataset.gradX_list[0])
gradY = sparse_torch_to_np(train_dataset.gradY_list[0])
gradZ = sparse_torch_to_np(train_dataset.gradZ_list[0])

In [None]:
# Reproducible random test data: three standard-normal values per mesh vertex.
rng = np.random.default_rng(seed=42)
x = rng.standard_normal((xyz.shape[0],3))


In [None]:
# Apply each per-axis reference operator to the first column of x.
t1 = gradX @ x[:,0]
t2 = gradY @ x[:,0]
t3 = gradZ @ x[:,0]

In [None]:
# Apply the freshly assembled operator to the same column of x.
tt = gradMat @ x[:, 0]

In [None]:
# Split the flat result into three equal consecutive blocks and place them as
# columns, giving one 3-component row per entry.
grad_v_reshaped = tt.reshape(3, -1).T

In [None]:
# Stack the three reference results as columns for comparison.
# NOTE: this rebinds `test`, which an earlier cell used for a sparse operator.
test=np.stack((t1,t2,t3), axis=1)

In [None]:
# Display the reshaped result of the assembled operator.
grad_v_reshaped

In [None]:
# Relative 2-norm difference between the reference and the assembled operator's
# result; only the third column (index 2) is compared here.
np.linalg.norm(test[:, 2] - grad_v_reshaped[:, 2]) / np.linalg.norm(test[:, 2])

In [None]:
# Length of the CSR column-index array (one entry per stored value).
gradMat.indices.shape

In [None]:
# Length of the CSR row-pointer array (number of rows + 1).
gradMat.indptr.shape

In [None]:
# Convert the SciPy CSR operator into a torch sparse CSR tensor (float32).
# The size is passed explicitly: without it torch infers the smallest shape
# that holds all stored entries, which loses trailing all-zero columns and
# would then break the product with a per-vertex vector.
bsr = torch.sparse_csr_tensor(
    gradMat.indptr,
    gradMat.indices,
    gradMat.data,
    size=gradMat.shape,
    dtype=torch.float32,
)

In [None]:
# Display the torch CSR tensor.
bsr

In [None]:
# float32 torch copy of the random test data.
xtorch = torch.tensor(x, dtype=torch.float)

In [None]:
# Apply the torch operator to the first column and regroup the flat result into
# three columns, mirroring the NumPy reshape used earlier.
(bsr @ xtorch[:, 0]).reshape(3, -1).T

In [None]:
# dtype of the SciPy operator's stored values.
gradMat.data.dtype

In [None]:
# Shape of the torch tensor's compressed row-pointer array.
bsr.crow_indices().shape

In [None]:
# Shape of the torch tensor's column-index array.
bsr.col_indices().shape

In [None]:
# Stored values of the torch CSR tensor.
bsr.values()

In [None]:
# Shape experiment with dense random stand-ins: a (nnodes, nf) matrix
# multiplied from the left by a (3*nnodes, nnodes) matrix.
nf, nnodes = 128, 1000
mat1 = torch.randn(nnodes, nf)
mat2 = torch.randn(3 * nnodes, nnodes)
res = mat2 @ mat1

In [None]:
# Shape check for regrouping the product into three values per (node, feature);
# the feature count is taken from `nf` instead of repeating the literal 128.
# NOTE(review): a plain reshape does not move three stacked row blocks onto the
# last axis (the vector case elsewhere uses reshape(3, -1).T) — confirm the
# intended layout before using the values; only the shape is inspected here.
res.reshape(-1, nf, 3).shape

In [None]:
# Build one tensor per batch entry; each stacks three copies of `res` along a
# new trailing axis.
bs = 8
xgrads = [torch.stack((res, res, res), dim=-1) for _ in range(bs)]

In [None]:
# Combine the per-entry tensors along a new leading axis.
x_grad = torch.stack(xgrads, dim=0)

In [None]:
# Display the shape of the stacked tensor.
x_grad.shape