In [None]:
### %matplotlib inline
# Notebook setup: plotting theme, imports, thread/pool settings, working
# directory, module reloads and the global NumPy seed.
from jupyterthemes import jtplot
jtplot.style(theme="onedork", context="notebook", ticks=True, grid=False)

import importlib  # replaces `imp`, which is deprecated and removed in Python 3.12
import os
import sys
print(sys.version)

import numpy as np
from numba import njit
import h5py
# import multiprocessing as mp

# Thread count for the simulation (assigned to simpli.omp_num_threads later).
NUM_PROC = 16

# Optional multiprocessing pool (currently disabled); `pool` stays None.
# if 'pool' in locals():
#     pool.terminate()

# pool = mp.Pool(NUM_PROC)
pool = None

# Absolute path of the current working directory (normally the notebook's
# folder); data paths below are resolved relative to it.
FILE_PATH = os.path.abspath('')
print(FILE_PATH)

import fastpli.simulation
import fastpli.analysis
import fastpli.tools
import fastpli.io
# Re-executes fastpli/__init__.py so source edits are picked up without a
# kernel restart; already-imported submodules are not reloaded by this call.
importlib.reload(fastpli)

import matplotlib.pyplot as plt
import tikzplotlib
# from skimage.external import tifffile as tif

# Local helper module; reloaded so edits take effect without restarting the kernel.
import vector_field_generation as vfg
importlib.reload(vfg)

# Global legacy NumPy seed for reproducible random draws.
np.random.seed(42)

print(fastpli.__version__)

In [None]:
# Resolution settings: base voxel size (in µm) and the refinement factor; the
# high-resolution comparison grid uses a voxel size of voxel_size / scale.
scale, voxel_size = 10, 0.1

In [None]:
# Generate the same fiber model twice -- once at `voxel_size` and once with
# `scale` times finer voxels -- interpolate the coarse vector field onto the
# fine grid, and quantify how far the interpolation deviates from the directly
# generated high-resolution field.
simpli = fastpli.simulation.Simpli()
simpli.omp_num_threads = NUM_PROC
print(FILE_PATH)
simpli.fiber_bundles = fastpli.io.fiber_bundles.load(
    os.path.join(FILE_PATH, '..', 'data', 'models',
                 'cube_2pop_psi_0.5_omega_0.0_.solved.h5'))

simpli.voxel_size = voxel_size  # in µm
# simpli.set_voi([-5] * 3, [5] * 3)  # larger volume, in µm
simpli.set_voi([-2] * 3, [2] * 3)  # in µm

voxel_size_0 = simpli.voxel_size

# Abort threshold for the estimated memory of the high-resolution simulation.
# The previous bound (64**12 MB) could never be reached, so the guard was dead.
# NOTE(review): 64 GiB is an assumed limit -- adjust to the available RAM.
MEMORY_LIMIT_MB = 64 * 1024

# Empirical quantile curve of the interpolation error, one curve per model.
fig_q, ax_q = plt.subplots()

# (dn, model) pairs to simulate; (-0.002, 'p') is a further candidate.
for dn, model in [(0.004, 'r')]:
    # Same two-layer property list for every bundle.
    # NOTE(review): tuple layout presumably (radius fraction, dn, mu, model);
    # confirm against the fastpli documentation.
    simpli.fiber_bundles_properties = [[(0.75, 0, 0, 'b'), (1.0, dn, 0, model)]
                                       ] * len(simpli.fiber_bundles)

    # Low resolution: the grid that gets interpolated below.
    simpli.voxel_size = voxel_size_0  # in µm
    tissue, optical_axis, tissue_properties = simpli.generate_tissue()

    # High resolution: `scale` times finer voxels over the same volume.
    simpli.voxel_size = voxel_size_0 / scale  # in µm
    memory_mb = simpli.memory_usage('MB')
    print(f"scale: {scale}, model: {model} -> {memory_mb:.0f} MB")
    if memory_mb > MEMORY_LIMIT_MB:
        # In Jupyter, sys.exit() does not stop the kernel (it only raises
        # SystemExit with a warning); an explicit exception aborts the run.
        raise MemoryError(f"MEMORY! estimated {memory_mb:.0f} MB exceeds "
                          f"the limit of {MEMORY_LIMIT_MB} MB")
    tissue_high, optical_axis_high, tissue_properties = simpli.generate_tissue()

    print("IntpVecField")
    # Output buffers on the high-resolution grid; presumably filled in place by
    # the binding below (its return value is not used).
    tissue_int = np.empty_like(tissue_high)
    vf_intp = np.empty_like(optical_axis_high)
    # NOTE(review): private fastpli API. `_Simpli__sim` is the manually mangled
    # name of Simpli.__sim; `__field_interpolation` is not mangled at notebook
    # top level, so the binding must expose that literal attribute name.
    # Python alternative: vf_intp = vfg.IntpVecField(tissue, optical_axis, scale, True, pool)
    simpli._Simpli__sim.__field_interpolation(tissue.shape, tissue_int.shape,
                                              tissue, optical_axis, tissue_int,
                                              vf_intp, False)

    print("diff")
    # Per-voxel norm of vfg.VectorOrientationSubstractionField between the
    # directly generated and the interpolated high-resolution vector fields.
    vf_diff = np.linalg.norm(
        vfg.VectorOrientationSubstractionField(optical_axis_high, vf_intp),
        axis=-1)

    vmax = np.amax(vf_diff)
    print(f"vmax: {vmax}")

    # Statistics over voxels with an even, non-zero tissue label plus every
    # voxel where the two fields disagree.
    selected = np.logical_or(
        np.logical_and(tissue_high > 0, tissue_high % 2 == 0), vf_diff != 0)
    diff_values = vf_diff[selected]  # boolean indexing already yields a 1-D copy
    if diff_values.size == 0:
        # np.min / np.max raise on empty input, so there is nothing to report.
        print("no voxels selected -> no statistics")
        continue

    val = np.mean(diff_values)
    val_std = np.std(diff_values)
    val_min = np.min(diff_values)
    val_max = np.max(diff_values)
    val_med = np.median(diff_values)
    val_qnt = np.quantile(diff_values, [0.25, 0.5, 0.75])

    # The doubled backslash prints a literal "\pm"; a bare "\p" is an invalid
    # escape sequence (SyntaxWarning since Python 3.12).
    print(
        f"val: {val} \\pm {val_std}: {val_med}, {val_qnt}, {val_min}, {val_max}"
    )

    # Sorted differences against their rank fraction, thinned to roughly
    # 1000 points so the plot stays light.
    diff_sorted = np.sort(diff_values)
    rank_fraction = np.arange(diff_sorted.size) / diff_sorted.size
    step = max(rank_fraction.size // 1000, 1)
    ax_q.plot(rank_fraction[::step], diff_sorted[::step], label=f"model {model}")

ax_q.axhline(0, color="gray", linewidth=0.5)  # zero reference line
ax_q.set(xlabel="fraction of selected voxels",
         ylabel="vf_diff (sorted)",
         title=f"interpolation error, scale {scale}")
ax_q.legend()

# Optional export (disabled):
# tikzplotlib.save(f"test_{scale}.tex")

In [None]:
# Central axis-aligned slices of the directly generated high-resolution field,
# the interpolated field, their vector norms, and the difference `vf_diff`.


def _vectors_to_rgb(img):
    """Map vector components from [-1, 1] to [0, 1] so imshow can draw them as RGB.

    imshow ignores vmin/vmax for RGB(A) input and clips float data to [0, 1],
    so passing vmin=-1 / vmax=1 has no effect and negative components are lost.
    """
    return np.clip((img + 1) / 2, 0, 1)


def _vector_norm(img):
    """Euclidean length of the vectors stored along the last axis."""
    return np.linalg.norm(img, axis=-1)


def _plot_center_slices(volume, vmin=None, vmax=None, transform=None, title=None):
    """Plot the central slice of `volume` along each of its first three axes.

    Parameters
    ----------
    volume : array with at least three dimensions (a trailing vector axis is allowed).
    vmin, vmax : color limits forwarded to imshow (imshow ignores them for RGB images).
    transform : optional callable applied to each 2-D slice before plotting.
    title : optional figure title.

    Returns
    -------
    (fig, axs) of the created 1 x 3 subplot grid.
    """
    fig, axs = plt.subplots(1, 3)
    centers = [n // 2 for n in volume.shape[:3]]
    slices = (volume[centers[0], :, :], volume[:, centers[1], :],
              volume[:, :, centers[2]])
    for axis, (ax, img) in enumerate(zip(axs, slices)):
        if transform is not None:
            img = transform(img)
        ax.imshow(img, vmin=vmin, vmax=vmax)
        ax.set_title(f"axis {axis} = {centers[axis]}")
    if title is not None:
        fig.suptitle(title)
    return fig, axs


_plot_center_slices(optical_axis_high, transform=_vectors_to_rgb,
                    title="optical axis, high resolution (RGB = (v + 1) / 2)")
_plot_center_slices(vf_intp, transform=_vectors_to_rgb,
                    title="optical axis, interpolated (RGB = (v + 1) / 2)")
_plot_center_slices(optical_axis_high, vmin=0, vmax=1, transform=_vector_norm,
                    title="|optical axis|, high resolution")
_plot_center_slices(vf_intp, vmin=0, vmax=1, transform=_vector_norm,
                    title="|optical axis|, interpolated")
_plot_center_slices(vf_diff, vmin=0, vmax=vmax, title="vf_diff")
plt.show()

In [None]:
# Central slices of vf_diff with the tissue labels overlaid (gray, semi-transparent).
fig, axs = plt.subplots(1, 3, frameon=False)

# Number of voxels carrying tissue label 4.
print(np.sum(tissue_high == 4))

centers = [n // 2 for n in vf_diff.shape[:3]]
diff_slices = (vf_diff[centers[0], :, :], vf_diff[:, centers[1], :],
               vf_diff[:, :, centers[2]])
# Every panel overlays the matching tissue-label slice (the third panel must
# use tissue_high, not optical_axis_high, whose slice is RGB and would ignore
# cmap/vmin/vmax). To show only even labels, overlay
# np.where(tissue_high % 2 == 1, 0, tissue_high) instead.
label_slices = (tissue_high[centers[0], :, :], tissue_high[:, centers[1], :],
                tissue_high[:, :, centers[2]])
for ax, diff_img, label_img in zip(axs, diff_slices, label_slices):
    ax.imshow(diff_img, vmin=0, vmax=vmax)
    ax.imshow(label_img, vmin=0, vmax=4, cmap='gray', alpha=0.25)
    ax.axis('off')

# Optional export (disabled):
# print(os.path.join(FILE_PATH, "results"))
# os.makedirs(os.path.join(FILE_PATH, "results"), exist_ok=True)
# tikzplotlib.save(
#     os.path.join(FILE_PATH, "results", f"vdiff_{scale}_{model}.tex"))

In [None]:
# Sanity check of vfg.VectorOrientationSubstractionField on a planar sweep:
# vectors (sin(phi) cos(theta), cos(phi) cos(theta), sin(theta)) with theta = 0
# for phi in [-0.1 pi, 1.1 pi] are compared against the reference (1, 0, 0).
phi = np.linspace(-0.1 * np.pi, 1.1 * np.pi, 100, True)
theta = phi * 0
cos_theta = np.cos(theta)
v = np.stack(
    (np.sin(phi) * cos_theta, np.cos(phi) * cos_theta, np.sin(theta)),
    axis=-1)
v0 = np.tile([1.0, 0.0, 0.0], (len(phi), 1))
vf_difff = np.linalg.norm(vfg.VectorOrientationSubstractionField(v0, v), axis=-1)

phi_deg = phi / np.pi * 180
plt.plot(phi_deg, vf_difff)
# NOTE(review): the message ("wiso" presumably German "wieso" = "why") appears
# to question why some differences exceed 1 -- confirm the expected range.
print("WISO > 1 ??")